From 83e80cf4691972af9d2f40d421de65f7708e01c8 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Fri, 13 Dec 2024 14:50:57 +0100 Subject: [PATCH] [CP-SAT] more work on hints --- ortools/sat/cp_model_presolve.cc | 39 +++++++++++++++++++++++++- ortools/sat/cp_model_solver_helpers.cc | 4 +-- ortools/sat/cp_model_solver_helpers.h | 3 +- ortools/sat/cuts.cc | 8 +++--- ortools/sat/cuts.h | 2 +- ortools/sat/diffn.cc | 2 +- ortools/sat/diffn.h | 3 +- ortools/sat/diffn_util_test.cc | 16 +++++------ ortools/sat/integer_expr.h | 14 ++++++--- ortools/sat/integer_search.cc | 9 ++++-- ortools/sat/integer_search.h | 4 +-- ortools/sat/sat_solver.cc | 6 ++-- ortools/sat/sat_solver.h | 19 +++++++------ 13 files changed, 89 insertions(+), 40 deletions(-) diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index d9b02452e7..cd717a16d6 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -2942,6 +2942,34 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) { return false; } +namespace { +// Set the hint in `context` for the variable in `equality` that has no hint, if +// there is exactly one. Otherwise do nothing. +void MaybeComputeMissingHint(PresolveContext* context, + const LinearConstraintProto& equality) { + DCHECK(equality.domain_size() == 2 && + equality.domain(0) == equality.domain(1)); + if (!context->HintIsLoaded()) return; + int term_with_missing_hint = -1; + int64_t missing_term_value = equality.domain(0); + for (int i = 0; i < equality.vars_size(); ++i) { + if (context->VarHasSolutionHint(equality.vars(i))) { + missing_term_value -= + context->SolutionHint(equality.vars(i)) * equality.coeffs(i); + } else if (term_with_missing_hint == -1) { + term_with_missing_hint = i; + } else { + // More than one variable has a missing hint. + return; + } + } + if (term_with_missing_hint == -1) return; + context->SetNewVariableHint( + equality.vars(term_with_missing_hint), + missing_term_value / equality.coeffs(term_with_missing_hint)); +} +} // namespace + bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { if (ct->constraint_case() != ConstraintProto::kLinear) return false; if (ct->linear().vars().size() <= 1) return false; @@ -3064,6 +3092,15 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { } } context_->InitializeNewDomains(); + // Scan the new constraints added above in reverse order so that the hint of + // `new_variables[k]` can be computed from the hint of the existing variables + // and from the hints of `new_variables[k']`, with k' > k. + const int num_constraints = context_->working_model->constraints_size(); + for (int i = 0; i < num_replaced_variables; ++i) { + MaybeComputeMissingHint( + context_, + context_->working_model->constraints(num_constraints - 1 - i).linear()); + } if (VLOG_IS_ON(2)) { std::string log_eq = absl::StrCat(linear_constraint.domain(0), " = "); @@ -7457,7 +7494,7 @@ void CpModelPresolver::Probe() { namespace { bool FixFromAssignment(const VariablesAssignment& assignment, - const std::vector& var_mapping, + absl::Span var_mapping, PresolveContext* context) { const int num_vars = assignment.NumberOfVariables(); for (int i = 0; i < num_vars; ++i) { diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 5ad19abe82..c95bafd6c5 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -337,7 +337,7 @@ IntegerVariable GetOrCreateVariableWithTightBound( } IntegerVariable GetOrCreateVariableLinkedToSumOf( - const std::vector>& terms, + absl::Span> terms, bool lb_required, bool ub_required, Model* model) { if (terms.empty()) return model->Add(ConstantIntegerVariable(0)); if (terms.size() == 1 && terms.front().second == 1) { @@ -1862,7 +1862,7 @@ void PostsolveResponseWithFullSolver(int num_variables_in_original_model, void PostsolveResponseWrapper(const SatParameters& params, int num_variable_in_original_model, const CpModelProto& mapping_proto, - const std::vector& postsolve_mapping, + absl::Span postsolve_mapping, std::vector* solution) { if (params.debug_postsolve_with_full_solver()) { PostsolveResponseWithFullSolver(num_variable_in_original_model, diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index 55403be916..01ff1892d1 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -22,6 +22,7 @@ #include #include "absl/flags/declare.h" +#include "absl/types/span.h" #include "ortools/base/timer.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/integer_base.h" @@ -128,7 +129,7 @@ int RegisterClausesLevelZeroImport(int id, void PostsolveResponseWrapper(const SatParameters& params, int num_variable_in_original_model, const CpModelProto& mapping_proto, - const std::vector& postsolve_mapping, + absl::Span postsolve_mapping, std::vector* solution); // Try to find a solution by following the hint and using a low conflict limit. diff --git a/ortools/sat/cuts.cc b/ortools/sat/cuts.cc index 31e1d3cbf4..4fe4f7f8d2 100644 --- a/ortools/sat/cuts.cc +++ b/ortools/sat/cuts.cc @@ -2414,7 +2414,7 @@ IntegerValue SumOfAllDiffLowerBounder::GetBestLowerBound(std::string& suffix) { namespace { void TryToGenerateAllDiffCut( - const std::vector>& sorted_exprs_lp, + absl::Span> sorted_exprs_lp, const IntegerTrail& integer_trail, const util_intops::StrongVector& lp_values, TopNCuts& top_n_cuts, Model* model) { @@ -2527,8 +2527,8 @@ IntegerValue MaxCornerDifference(const IntegerVariable var, // target expr I(i), max expr k. // The coefficient of zk is Sum(i=1..n)(MPlusCoefficient_ki) + bk IntegerValue MPlusCoefficient( - const std::vector& x_vars, - const std::vector& exprs, + absl::Span x_vars, + absl::Span exprs, const util_intops::StrongVector& variable_partition, const int max_index, const IntegerTrail& integer_trail) { IntegerValue coeff = exprs[max_index].offset; @@ -2659,7 +2659,7 @@ IntegerValue EvaluateMaxAffine( bool BuildMaxAffineUpConstraint( const LinearExpression& target, IntegerVariable var, - const std::vector>& affines, + absl::Span> affines, Model* model, LinearConstraintBuilder* builder) { auto* integer_trail = model->GetOrCreate(); const IntegerValue x_min = integer_trail->LevelZeroLowerBound(var); diff --git a/ortools/sat/cuts.h b/ortools/sat/cuts.h index c13957886b..c6d3636784 100644 --- a/ortools/sat/cuts.h +++ b/ortools/sat/cuts.h @@ -702,7 +702,7 @@ CutGenerator CreateLinMaxCutGenerator( // This function will reset the bounds of the builder. bool BuildMaxAffineUpConstraint( const LinearExpression& target, IntegerVariable var, - const std::vector>& affines, + absl::Span> affines, Model* model, LinearConstraintBuilder* builder); // By definition, the Max of affine functions is convex. The linear polytope is diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 7f3bdc01a8..ef74f4ecc4 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -530,7 +530,7 @@ int NonOverlappingRectanglesEnergyPropagator::RegisterWith( } bool NonOverlappingRectanglesEnergyPropagator::BuildAndReportEnergyTooLarge( - const std::vector& ranges) { + absl::Span ranges) { if (ranges.size() == 2) { num_conflicts_two_boxes_++; return ClearAndAddTwoBoxesConflictReason(ranges[0].box_index, diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index d874af62bb..f7571bbd3a 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -63,8 +63,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { std::vector GeneralizeExplanation(const Conflict& conflict); - bool BuildAndReportEnergyTooLarge( - const std::vector& ranges); + bool BuildAndReportEnergyTooLarge(absl::Span ranges); SchedulingConstraintHelper& x_; SchedulingConstraintHelper& y_; diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc index 9788e1ebf1..c5670ee255 100644 --- a/ortools/sat/diffn_util_test.cc +++ b/ortools/sat/diffn_util_test.cc @@ -315,11 +315,11 @@ std::vector> GetOverlappingIntervalComponentsBruteForce( components[component_indices[i]].push_back(i); } // Sort the components by start, like GetOverlappingIntervalComponents(). - absl::c_sort(components, [intervals](const std::vector& c1, - const std::vector& c2) { - CHECK(!c1.empty() && !c2.empty()); - return intervals[c1[0]].start < intervals[c2[0]].start; - }); + absl::c_sort(components, + [intervals](absl::Span c1, absl::Span c2) { + CHECK(!c1.empty() && !c2.empty()); + return intervals[c1[0]].start < intervals[c2[0]].start; + }); // Inside each component, the intervals should be sorted, too. // Moreover, we need to convert our indices to IntervalIndex.index. for (std::vector& component : components) { @@ -736,7 +736,7 @@ void ReduceUntilDone(ProbingRectangle& ranges, absl::BitGen& random) { // detect a conflict even if there is one by looking only at those rectangles, // see the ProbingRectangleTest.CounterExample unit test for a concrete example. std::optional FindRectangleWithEnergyTooLargeExhaustive( - const std::vector& box_ranges) { + absl::Span box_ranges) { int num_boxes = box_ranges.size(); std::vector x; x.reserve(num_boxes * 4); @@ -957,8 +957,8 @@ TEST(FindPartialIntersections, Simple) { } bool GraphsDefineSameConnectedComponents( - const std::vector>& graph1, - const std::vector>& graph2) { + absl::Span> graph1, + absl::Span> graph2) { int max = -1; int max2 = -1; for (const auto& [a, b] : graph1) { diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 4dcf7c3a55..2de0b04b8d 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -603,10 +603,16 @@ inline std::function ConditionalWeightedSumLowerOrEqual( }; } inline std::function ConditionalWeightedSumGreaterOrEqual( - const std::vector& enforcement_literals, - const std::vector& vars, - const std::vector& coefficients, int64_t upper_bound) { - return [=](Model* model) { + absl::Span enforcement_literals, + absl::Span vars, + absl::Span coefficients, int64_t upper_bound) { + return [=, + coefficients = + std::vector(coefficients.begin(), coefficients.end()), + vars = std::vector(vars.begin(), vars.end()), + enforcement_literals = + std::vector(enforcement_literals.begin(), + enforcement_literals.end())](Model* model) { AddWeightedSumGreaterOrEqual(enforcement_literals, vars, coefficients, upper_bound, model); }; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 3a666bcb93..7a833b8b03 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -1116,8 +1116,8 @@ std::function RandomizeOnRestartHeuristic( } std::function FollowHint( - const std::vector& vars, - const std::vector& values, Model* model) { + absl::Span vars, + absl::Span values, Model* model) { auto* trail = model->GetOrCreate(); auto* integer_trail = model->GetOrCreate(); auto* rev_int_repo = model->GetOrCreate(); @@ -1130,7 +1130,10 @@ std::function FollowHint( int* rev_start_index = model->TakeOwnership(new int); *rev_start_index = 0; - return [=]() { + return [=, + vars = + std::vector(vars.begin(), vars.end()), + values = std::vector(values.begin(), values.end())]() { rev_int_repo->SaveState(rev_start_index); for (int i = *rev_start_index; i < vars.size(); ++i) { const IntegerValue value = values[i]; diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index f8cfb7cf88..fca8e221ae 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -203,8 +203,8 @@ struct BooleanOrIntegerVariable { IntegerVariable int_var = kNoIntegerVariable; }; std::function FollowHint( - const std::vector& vars, - const std::vector& values, Model* model); + absl::Span vars, + absl::Span values, Model* model); // Combines search heuristics in order: if the i-th one returns kNoLiteralIndex, // ask the (i+1)-th. If every heuristic returned kNoLiteralIndex, diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 9a8355ac51..13c2d10062 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -538,7 +538,7 @@ namespace { // Returns true iff 'b' is subsumed by 'a' (i.e 'a' is included in 'b'). // This is slow and only meant to be used in DCHECKs. -bool ClauseSubsumption(const std::vector& a, SatClause* b) { +bool ClauseSubsumption(absl::Span a, SatClause* b) { std::vector superset(b->begin(), b->end()); std::vector subset(a.begin(), a.end()); std::sort(superset.begin(), superset.end()); @@ -1062,7 +1062,7 @@ void SatSolver::Backtrack(int target_level) { last_decision_or_backtrack_trail_index_ = trail_->Index(); } -bool SatSolver::AddBinaryClauses(const std::vector& clauses) { +bool SatSolver::AddBinaryClauses(absl::Span clauses) { SCOPED_TIME_STAT(&stats_); CHECK_EQ(CurrentDecisionLevel(), 0); for (const BinaryClause c : clauses) { @@ -1684,7 +1684,7 @@ void SatSolver::UpdateClauseActivityIncrement() { clause_activity_increment_ *= 1.0 / parameters_->clause_activity_decay(); } -bool SatSolver::IsConflictValid(const std::vector& literals) { +bool SatSolver::IsConflictValid(absl::Span literals) { SCOPED_TIME_STAT(&stats_); if (literals.empty()) return false; const int highest_level = DecisionLevel(literals[0].Variable()); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 3bd16ace19..b0606c49c5 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -374,7 +374,7 @@ class SatSolver { // Functions to manage the set of learned binary clauses. // Only clauses added/learned when TrackBinaryClause() is true are managed. void TrackBinaryClauses(bool value) { track_binary_clauses_ = value; } - bool AddBinaryClauses(const std::vector& clauses); + bool AddBinaryClauses(absl::Span clauses); const std::vector& NewlyAddedBinaryClauses(); void ClearNewlyAddedBinaryClauses(); @@ -690,7 +690,7 @@ class SatSolver { // - This literal appears in the first position. // - All the other literals are of smaller decision level. // - There is no literal with a decision level of zero. - bool IsConflictValid(const std::vector& literals); + bool IsConflictValid(absl::Span literals); // Given the learned clause after a conflict, this computes the correct // backtrack level to call Backtrack() with. @@ -912,8 +912,9 @@ inline std::function CardinalityConstraint( } inline std::function ExactlyOneConstraint( - const std::vector& literals) { - return [=](Model* model) { + absl::Span literals) { + return [=, literals = std::vector(literals.begin(), literals.end())]( + Model* model) { std::vector cst; cst.reserve(literals.size()); for (const Literal l : literals) { @@ -926,8 +927,9 @@ inline std::function ExactlyOneConstraint( } inline std::function AtMostOneConstraint( - const std::vector& literals) { - return [=](Model* model) { + absl::Span literals) { + return [=, literals = std::vector(literals.begin(), literals.end())]( + Model* model) { std::vector cst; cst.reserve(literals.size()); for (const Literal l : literals) { @@ -997,8 +999,9 @@ inline std::function EnforcedClause( // // Note(user): we could have called ReifiedBoolOr() with everything negated. inline std::function ReifiedBoolAnd( - const std::vector& literals, Literal r) { - return [=](Model* model) { + absl::Span literals, Literal r) { + return [=, literals = std::vector(literals.begin(), literals.end())]( + Model* model) { std::vector clause; for (const Literal l : literals) { model->Add(Implication(r, l)); // r => l.