Skip to content

Commit

Permalink
[CP-SAT] more work on hints
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Dec 13, 2024
1 parent 8e4ce6c commit 83e80cf
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 40 deletions.
39 changes: 38 additions & 1 deletion ortools/sat/cp_model_presolve.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2942,6 +2942,34 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) {
return false;
}

namespace {
// Set the hint in `context` for the variable in `equality` that has no hint, if
// there is exactly one. Otherwise do nothing.
void MaybeComputeMissingHint(PresolveContext* context,
const LinearConstraintProto& equality) {
DCHECK(equality.domain_size() == 2 &&
equality.domain(0) == equality.domain(1));
if (!context->HintIsLoaded()) return;
int term_with_missing_hint = -1;
int64_t missing_term_value = equality.domain(0);
for (int i = 0; i < equality.vars_size(); ++i) {
if (context->VarHasSolutionHint(equality.vars(i))) {
missing_term_value -=
context->SolutionHint(equality.vars(i)) * equality.coeffs(i);
} else if (term_with_missing_hint == -1) {
term_with_missing_hint = i;
} else {
// More than one variable has a missing hint.
return;
}
}
if (term_with_missing_hint == -1) return;
context->SetNewVariableHint(
equality.vars(term_with_missing_hint),
missing_term_value / equality.coeffs(term_with_missing_hint));
}
} // namespace

bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) {
if (ct->constraint_case() != ConstraintProto::kLinear) return false;
if (ct->linear().vars().size() <= 1) return false;
Expand Down Expand Up @@ -3064,6 +3092,15 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) {
}
}
context_->InitializeNewDomains();
// Scan the new constraints added above in reverse order so that the hint of
// `new_variables[k]` can be computed from the hint of the existing variables
// and from the hints of `new_variables[k']`, with k' > k.
const int num_constraints = context_->working_model->constraints_size();
for (int i = 0; i < num_replaced_variables; ++i) {
MaybeComputeMissingHint(
context_,
context_->working_model->constraints(num_constraints - 1 - i).linear());
}

if (VLOG_IS_ON(2)) {
std::string log_eq = absl::StrCat(linear_constraint.domain(0), " = ");
Expand Down Expand Up @@ -7457,7 +7494,7 @@ void CpModelPresolver::Probe() {
namespace {

bool FixFromAssignment(const VariablesAssignment& assignment,
const std::vector<int>& var_mapping,
absl::Span<const int> var_mapping,
PresolveContext* context) {
const int num_vars = assignment.NumberOfVariables();
for (int i = 0; i < num_vars; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions ortools/sat/cp_model_solver_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ IntegerVariable GetOrCreateVariableWithTightBound(
}

IntegerVariable GetOrCreateVariableLinkedToSumOf(
const std::vector<std::pair<IntegerVariable, int64_t>>& terms,
absl::Span<const std::pair<IntegerVariable, int64_t>> terms,
bool lb_required, bool ub_required, Model* model) {
if (terms.empty()) return model->Add(ConstantIntegerVariable(0));
if (terms.size() == 1 && terms.front().second == 1) {
Expand Down Expand Up @@ -1862,7 +1862,7 @@ void PostsolveResponseWithFullSolver(int num_variables_in_original_model,
void PostsolveResponseWrapper(const SatParameters& params,
int num_variable_in_original_model,
const CpModelProto& mapping_proto,
const std::vector<int>& postsolve_mapping,
absl::Span<const int> postsolve_mapping,
std::vector<int64_t>* solution) {
if (params.debug_postsolve_with_full_solver()) {
PostsolveResponseWithFullSolver(num_variable_in_original_model,
Expand Down
3 changes: 2 additions & 1 deletion ortools/sat/cp_model_solver_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "absl/flags/declare.h"
#include "absl/types/span.h"
#include "ortools/base/timer.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/integer_base.h"
Expand Down Expand Up @@ -128,7 +129,7 @@ int RegisterClausesLevelZeroImport(int id,
void PostsolveResponseWrapper(const SatParameters& params,
int num_variable_in_original_model,
const CpModelProto& mapping_proto,
const std::vector<int>& postsolve_mapping,
absl::Span<const int> postsolve_mapping,
std::vector<int64_t>* solution);

// Try to find a solution by following the hint and using a low conflict limit.
Expand Down
8 changes: 4 additions & 4 deletions ortools/sat/cuts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2414,7 +2414,7 @@ IntegerValue SumOfAllDiffLowerBounder::GetBestLowerBound(std::string& suffix) {
namespace {

void TryToGenerateAllDiffCut(
const std::vector<std::pair<double, AffineExpression>>& sorted_exprs_lp,
absl::Span<const std::pair<double, AffineExpression>> sorted_exprs_lp,
const IntegerTrail& integer_trail,
const util_intops::StrongVector<IntegerVariable, double>& lp_values,
TopNCuts& top_n_cuts, Model* model) {
Expand Down Expand Up @@ -2527,8 +2527,8 @@ IntegerValue MaxCornerDifference(const IntegerVariable var,
// target expr I(i), max expr k.
// The coefficient of zk is Sum(i=1..n)(MPlusCoefficient_ki) + bk
IntegerValue MPlusCoefficient(
const std::vector<IntegerVariable>& x_vars,
const std::vector<LinearExpression>& exprs,
absl::Span<const IntegerVariable> x_vars,
absl::Span<const LinearExpression> exprs,
const util_intops::StrongVector<IntegerVariable, int>& variable_partition,
const int max_index, const IntegerTrail& integer_trail) {
IntegerValue coeff = exprs[max_index].offset;
Expand Down Expand Up @@ -2659,7 +2659,7 @@ IntegerValue EvaluateMaxAffine(

bool BuildMaxAffineUpConstraint(
const LinearExpression& target, IntegerVariable var,
const std::vector<std::pair<IntegerValue, IntegerValue>>& affines,
absl::Span<const std::pair<IntegerValue, IntegerValue>> affines,
Model* model, LinearConstraintBuilder* builder) {
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
const IntegerValue x_min = integer_trail->LevelZeroLowerBound(var);
Expand Down
2 changes: 1 addition & 1 deletion ortools/sat/cuts.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ CutGenerator CreateLinMaxCutGenerator(
// This function will reset the bounds of the builder.
bool BuildMaxAffineUpConstraint(
const LinearExpression& target, IntegerVariable var,
const std::vector<std::pair<IntegerValue, IntegerValue>>& affines,
absl::Span<const std::pair<IntegerValue, IntegerValue>> affines,
Model* model, LinearConstraintBuilder* builder);

// By definition, the Max of affine functions is convex. The linear polytope is
Expand Down
2 changes: 1 addition & 1 deletion ortools/sat/diffn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ int NonOverlappingRectanglesEnergyPropagator::RegisterWith(
}

bool NonOverlappingRectanglesEnergyPropagator::BuildAndReportEnergyTooLarge(
const std::vector<RectangleInRange>& ranges) {
absl::Span<const RectangleInRange> ranges) {
if (ranges.size() == 2) {
num_conflicts_two_boxes_++;
return ClearAndAddTwoBoxesConflictReason(ranges[0].box_index,
Expand Down
3 changes: 1 addition & 2 deletions ortools/sat/diffn.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface {

std::vector<RectangleInRange> GeneralizeExplanation(const Conflict& conflict);

bool BuildAndReportEnergyTooLarge(
const std::vector<RectangleInRange>& ranges);
bool BuildAndReportEnergyTooLarge(absl::Span<const RectangleInRange> ranges);

SchedulingConstraintHelper& x_;
SchedulingConstraintHelper& y_;
Expand Down
16 changes: 8 additions & 8 deletions ortools/sat/diffn_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ std::vector<std::vector<int>> GetOverlappingIntervalComponentsBruteForce(
components[component_indices[i]].push_back(i);
}
// Sort the components by start, like GetOverlappingIntervalComponents().
absl::c_sort(components, [intervals](const std::vector<int>& c1,
const std::vector<int>& c2) {
CHECK(!c1.empty() && !c2.empty());
return intervals[c1[0]].start < intervals[c2[0]].start;
});
absl::c_sort(components,
[intervals](absl::Span<const int> c1, absl::Span<const int> c2) {
CHECK(!c1.empty() && !c2.empty());
return intervals[c1[0]].start < intervals[c2[0]].start;
});
// Inside each component, the intervals should be sorted, too.
// Moreover, we need to convert our indices to IntervalIndex.index.
for (std::vector<int>& component : components) {
Expand Down Expand Up @@ -736,7 +736,7 @@ void ReduceUntilDone(ProbingRectangle& ranges, absl::BitGen& random) {
// detect a conflict even if there is one by looking only at those rectangles,
// see the ProbingRectangleTest.CounterExample unit test for a concrete example.
std::optional<Rectangle> FindRectangleWithEnergyTooLargeExhaustive(
const std::vector<RectangleInRange>& box_ranges) {
absl::Span<const RectangleInRange> box_ranges) {
int num_boxes = box_ranges.size();
std::vector<IntegerValue> x;
x.reserve(num_boxes * 4);
Expand Down Expand Up @@ -957,8 +957,8 @@ TEST(FindPartialIntersections, Simple) {
}

bool GraphsDefineSameConnectedComponents(
const std::vector<std::pair<int, int>>& graph1,
const std::vector<std::pair<int, int>>& graph2) {
absl::Span<const std::pair<int, int>> graph1,
absl::Span<const std::pair<int, int>> graph2) {
int max = -1;
int max2 = -1;
for (const auto& [a, b] : graph1) {
Expand Down
14 changes: 10 additions & 4 deletions ortools/sat/integer_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,10 +603,16 @@ inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
};
}
inline std::function<void(Model*)> ConditionalWeightedSumGreaterOrEqual(
const std::vector<Literal>& enforcement_literals,
const std::vector<IntegerVariable>& vars,
const std::vector<int64_t>& coefficients, int64_t upper_bound) {
return [=](Model* model) {
absl::Span<const Literal> enforcement_literals,
absl::Span<const IntegerVariable> vars,
absl::Span<const int64_t> coefficients, int64_t upper_bound) {
return [=,
coefficients =
std::vector<int64_t>(coefficients.begin(), coefficients.end()),
vars = std::vector<IntegerVariable>(vars.begin(), vars.end()),
enforcement_literals =
std::vector<Literal>(enforcement_literals.begin(),
enforcement_literals.end())](Model* model) {
AddWeightedSumGreaterOrEqual(enforcement_literals, vars, coefficients,
upper_bound, model);
};
Expand Down
9 changes: 6 additions & 3 deletions ortools/sat/integer_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,8 @@ std::function<BooleanOrIntegerLiteral()> RandomizeOnRestartHeuristic(
}

std::function<BooleanOrIntegerLiteral()> FollowHint(
const std::vector<BooleanOrIntegerVariable>& vars,
const std::vector<IntegerValue>& values, Model* model) {
absl::Span<const BooleanOrIntegerVariable> vars,
absl::Span<const IntegerValue> values, Model* model) {
auto* trail = model->GetOrCreate<Trail>();
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
auto* rev_int_repo = model->GetOrCreate<RevIntRepository>();
Expand All @@ -1130,7 +1130,10 @@ std::function<BooleanOrIntegerLiteral()> FollowHint(
int* rev_start_index = model->TakeOwnership(new int);
*rev_start_index = 0;

return [=]() {
return [=,
vars =
std::vector<BooleanOrIntegerVariable>(vars.begin(), vars.end()),
values = std::vector<IntegerValue>(values.begin(), values.end())]() {
rev_int_repo->SaveState(rev_start_index);
for (int i = *rev_start_index; i < vars.size(); ++i) {
const IntegerValue value = values[i];
Expand Down
4 changes: 2 additions & 2 deletions ortools/sat/integer_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ struct BooleanOrIntegerVariable {
IntegerVariable int_var = kNoIntegerVariable;
};
std::function<BooleanOrIntegerLiteral()> FollowHint(
const std::vector<BooleanOrIntegerVariable>& vars,
const std::vector<IntegerValue>& values, Model* model);
absl::Span<const BooleanOrIntegerVariable> vars,
absl::Span<const IntegerValue> values, Model* model);

// Combines search heuristics in order: if the i-th one returns kNoLiteralIndex,
// ask the (i+1)-th. If every heuristic returned kNoLiteralIndex,
Expand Down
6 changes: 3 additions & 3 deletions ortools/sat/sat_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ namespace {

// Returns true iff 'b' is subsumed by 'a' (i.e 'a' is included in 'b').
// This is slow and only meant to be used in DCHECKs.
bool ClauseSubsumption(const std::vector<Literal>& a, SatClause* b) {
bool ClauseSubsumption(absl::Span<const Literal> a, SatClause* b) {
std::vector<Literal> superset(b->begin(), b->end());
std::vector<Literal> subset(a.begin(), a.end());
std::sort(superset.begin(), superset.end());
Expand Down Expand Up @@ -1062,7 +1062,7 @@ void SatSolver::Backtrack(int target_level) {
last_decision_or_backtrack_trail_index_ = trail_->Index();
}

bool SatSolver::AddBinaryClauses(const std::vector<BinaryClause>& clauses) {
bool SatSolver::AddBinaryClauses(absl::Span<const BinaryClause> clauses) {
SCOPED_TIME_STAT(&stats_);
CHECK_EQ(CurrentDecisionLevel(), 0);
for (const BinaryClause c : clauses) {
Expand Down Expand Up @@ -1684,7 +1684,7 @@ void SatSolver::UpdateClauseActivityIncrement() {
clause_activity_increment_ *= 1.0 / parameters_->clause_activity_decay();
}

bool SatSolver::IsConflictValid(const std::vector<Literal>& literals) {
bool SatSolver::IsConflictValid(absl::Span<const Literal> literals) {
SCOPED_TIME_STAT(&stats_);
if (literals.empty()) return false;
const int highest_level = DecisionLevel(literals[0].Variable());
Expand Down
19 changes: 11 additions & 8 deletions ortools/sat/sat_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ class SatSolver {
// Functions to manage the set of learned binary clauses.
// Only clauses added/learned when TrackBinaryClause() is true are managed.
void TrackBinaryClauses(bool value) { track_binary_clauses_ = value; }
bool AddBinaryClauses(const std::vector<BinaryClause>& clauses);
bool AddBinaryClauses(absl::Span<const BinaryClause> clauses);
const std::vector<BinaryClause>& NewlyAddedBinaryClauses();
void ClearNewlyAddedBinaryClauses();

Expand Down Expand Up @@ -690,7 +690,7 @@ class SatSolver {
// - This literal appears in the first position.
// - All the other literals are of smaller decision level.
// - There is no literal with a decision level of zero.
bool IsConflictValid(const std::vector<Literal>& literals);
bool IsConflictValid(absl::Span<const Literal> literals);

// Given the learned clause after a conflict, this computes the correct
// backtrack level to call Backtrack() with.
Expand Down Expand Up @@ -912,8 +912,9 @@ inline std::function<void(Model*)> CardinalityConstraint(
}

inline std::function<void(Model*)> ExactlyOneConstraint(
const std::vector<Literal>& literals) {
return [=](Model* model) {
absl::Span<const Literal> literals) {
return [=, literals = std::vector<Literal>(literals.begin(), literals.end())](
Model* model) {
std::vector<LiteralWithCoeff> cst;
cst.reserve(literals.size());
for (const Literal l : literals) {
Expand All @@ -926,8 +927,9 @@ inline std::function<void(Model*)> ExactlyOneConstraint(
}

inline std::function<void(Model*)> AtMostOneConstraint(
const std::vector<Literal>& literals) {
return [=](Model* model) {
absl::Span<const Literal> literals) {
return [=, literals = std::vector<Literal>(literals.begin(), literals.end())](
Model* model) {
std::vector<LiteralWithCoeff> cst;
cst.reserve(literals.size());
for (const Literal l : literals) {
Expand Down Expand Up @@ -997,8 +999,9 @@ inline std::function<void(Model*)> EnforcedClause(
//
// Note(user): we could have called ReifiedBoolOr() with everything negated.
inline std::function<void(Model*)> ReifiedBoolAnd(
const std::vector<Literal>& literals, Literal r) {
return [=](Model* model) {
absl::Span<const Literal> literals, Literal r) {
return [=, literals = std::vector<Literal>(literals.begin(), literals.end())](
Model* model) {
std::vector<Literal> clause;
for (const Literal l : literals) {
model->Add(Implication(r, l)); // r => l.
Expand Down

0 comments on commit 83e80cf

Please sign in to comment.