diff options
author | Tim King <taking@cs.nyu.edu> | 2013-04-30 00:46:14 -0400 |
---|---|---|
committer | Tim King <taking@cs.nyu.edu> | 2013-04-30 00:46:14 -0400 |
commit | 2b9e032cc93a96dccab8757326645da82b5866e5 (patch) | |
tree | 3d579a615f0d3acbf7edadc7cf81a237c4888f43 /src/theory/arith/linear_equality.cpp | |
parent | 9098391fe334d829ec4101f190b8f1fa21c30752 (diff) |
Adding has bound counts and tracking for rows.
Diffstat (limited to 'src/theory/arith/linear_equality.cpp')
-rw-r--r-- | src/theory/arith/linear_equality.cpp | 225 |
1 files changed, 95 insertions, 130 deletions
diff --git a/src/theory/arith/linear_equality.cpp b/src/theory/arith/linear_equality.cpp index 42d8b41f8..eda3bf682 100644 --- a/src/theory/arith/linear_equality.cpp +++ b/src/theory/arith/linear_equality.cpp @@ -31,12 +31,6 @@ template void LinearEqualityModule::propagateNonbasics<false>(ArithVar basic, Co template ArithVar LinearEqualityModule::selectSlack<true>(ArithVar x_i, VarPreferenceFunction pf) const; template ArithVar LinearEqualityModule::selectSlack<false>(ArithVar x_i, VarPreferenceFunction pf) const; -// template bool LinearEqualityModule::preferNonDegenerate<true>(const UpdateInfo& a, const UpdateInfo& b) const; -// template bool LinearEqualityModule::preferNonDegenerate<false>(const UpdateInfo& a, const UpdateInfo& b) const; - -// template bool LinearEqualityModule::preferErrorsFixed<true>(const UpdateInfo& a, const UpdateInfo& b) const; -// template bool LinearEqualityModule::preferErrorsFixed<false>(const UpdateInfo& a, const UpdateInfo& b) const; - template bool LinearEqualityModule::preferWitness<true>(const UpdateInfo& a, const UpdateInfo& b) const; template bool LinearEqualityModule::preferWitness<false>(const UpdateInfo& a, const UpdateInfo& b) const; @@ -57,14 +51,14 @@ void Border::output(std::ostream& out) const{ << "}"; } -LinearEqualityModule::LinearEqualityModule(ArithVariables& vars, Tableau& t, BoundCountingVector& boundTracking, BasicVarModelUpdateCallBack f): +LinearEqualityModule::LinearEqualityModule(ArithVariables& vars, Tableau& t, BoundInfoMap& boundsTracking, BasicVarModelUpdateCallBack f): d_variables(vars), d_tableau(t), d_basicVariableUpdates(f), d_increasing(1), d_decreasing(-1), d_relevantErrorBuffer(), - d_boundTracking(boundTracking), + d_btracking(boundsTracking), d_areTracking(false), d_trackCallback(this) {} @@ -103,31 +97,24 @@ LinearEqualityModule::Statistics::~Statistics(){ StatisticsRegistry::unregisterStat(&d_weakenings); StatisticsRegistry::unregisterStat(&d_weakenTime); } -void LinearEqualityModule::includeBoundCountChange(ArithVar nb, BoundCounts prev){ - if(d_tableau.isBasic(nb)){ - return; - } - Assert(!d_tableau.isBasic(nb)); +void LinearEqualityModule::includeBoundUpdate(ArithVar v, const BoundsInfo& prev){ Assert(!d_areTracking); - BoundCounts curr = d_variables.boundCounts(nb); + BoundsInfo curr = d_variables.boundsInfo(v); Assert(prev != curr); - Tableau::ColIterator basicIter = d_tableau.colIterator(nb); + Tableau::ColIterator basicIter = d_tableau.colIterator(v); for(; !basicIter.atEnd(); ++basicIter){ const Tableau::Entry& entry = *basicIter; - Assert(entry.getColVar() == nb); + Assert(entry.getColVar() == v); int a_ijSgn = entry.getCoefficient().sgn(); - ArithVar basic = d_tableau.rowIndexToBasic(entry.getRowIndex()); - - BoundCounts& counts = d_boundTracking.get(basic); - Debug("includeBoundCountChange") << basic << " " << counts << " to " ; - counts -= prev.multiplyBySgn(a_ijSgn); - counts += curr.multiplyBySgn(a_ijSgn); - Debug("includeBoundCountChange") << counts << " " << a_ijSgn << std::endl; + RowIndex ridx = entry.getRowIndex(); + BoundsInfo& counts = d_btracking.get(ridx); + Debug("includeBoundUpdate") << d_tableau.rowIndexToBasic(ridx) << " " << counts << " to " ; + counts.addInChange(a_ijSgn, prev, curr); + Debug("includeBoundUpdate") << counts << " " << a_ijSgn << std::endl; } - d_boundTracking.set(nb, curr); } void LinearEqualityModule::updateMany(const DenseMap<DeltaRational>& many){ @@ -231,9 +218,9 @@ void LinearEqualityModule::updateTracked(ArithVar x_i, const DeltaRational& v){ << d_variables.getAssignment(x_i) << "|-> " << v << endl; - BoundCounts before = d_variables.boundCounts(x_i); + BoundCounts before = d_variables.atBoundCounts(x_i); d_variables.setAssignment(x_i, v); - BoundCounts after = d_variables.boundCounts(x_i); + BoundCounts after = d_variables.atBoundCounts(x_i); bool anyChange = before != after; @@ -242,17 +229,24 @@ void LinearEqualityModule::updateTracked(ArithVar x_i, const DeltaRational& v){ const Tableau::Entry& entry = *colIter; Assert(entry.getColVar() == x_i); - ArithVar x_j = d_tableau.rowIndexToBasic(entry.getRowIndex()); + RowIndex ridx = entry.getRowIndex(); + ArithVar x_j = d_tableau.rowIndexToBasic(ridx); const Rational& a_ji = entry.getCoefficient(); const DeltaRational& assignment = d_variables.getAssignment(x_j); DeltaRational nAssignment = assignment+(diff * a_ji); Debug("update") << x_j << " " << a_ji << assignment << " -> " << nAssignment << endl; + BoundCounts xjBefore = d_variables.atBoundCounts(x_j); d_variables.setAssignment(x_j, nAssignment); + BoundCounts xjAfter = d_variables.atBoundCounts(x_j); - if(anyChange && basicIsTracked(x_j)){ - BoundCounts& next_bc_k = d_boundTracking.get(x_j); - next_bc_k.addInChange(a_ji.sgn(), before, after); + Assert(rowIndexIsTracked(ridx)); + BoundsInfo& next_bc_k = d_btracking.get(ridx); + if(anyChange){ + next_bc_k.addInAtBoundChange(a_ji.sgn(), before, after); + } + if(xjBefore != xjAfter){ + next_bc_k.addInAtBoundChange(-1, xjBefore, xjAfter); } d_basicVariableUpdates(x_j); @@ -332,7 +326,7 @@ void LinearEqualityModule::debugCheckTracking(){ ArithVar var = entry.getColVar(); const Rational& coeff = entry.getCoefficient(); DeltaRational beta = d_variables.getAssignment(var); - Debug("arith::tracking") << var << " " << d_variables.boundCounts(var) + Debug("arith::tracking") << var << " " << d_variables.boundsInfo(var) << " " << beta << coeff; if(d_variables.hasLowerBound(var)){ Debug("arith::tracking") << "(lb " << d_variables.getLowerBound(var) << ")"; @@ -345,11 +339,12 @@ void LinearEqualityModule::debugCheckTracking(){ Debug("arith::tracking") << "end row"<< endl; if(basicIsTracked(basic)){ - BoundCounts computed = computeBoundCounts(basic); + RowIndex ridx = d_tableau.basicToRowIndex(basic); + BoundsInfo computed = computeRowBoundInfo(ridx, false); Debug("arith::tracking") << "computed " << computed - << " tracking " << d_boundTracking[basic] << endl; - Assert(computed == d_boundTracking[basic]); + << " tracking " << d_btracking[ridx] << endl; + Assert(computed == d_btracking[ridx]); } } @@ -745,60 +740,34 @@ void LinearEqualityModule::stopTrackingBoundCounts(){ } -void LinearEqualityModule::trackVariable(ArithVar x_i){ - Assert(!basicIsTracked(x_i)); - BoundCounts counts(0,0); - - for(Tableau::RowIterator iter = d_tableau.basicRowIterator(x_i); !iter.atEnd(); ++iter){ - const Tableau::Entry& entry = *iter; - ArithVar nonbasic = entry.getColVar(); - if(nonbasic == x_i) continue; - - const Rational& a_ij = entry.getCoefficient(); - counts += (d_variables.oldBoundCounts(nonbasic)).multiplyBySgn(a_ij.sgn()); - } - d_boundTracking.set(x_i, counts); +void LinearEqualityModule::trackRowIndex(RowIndex ridx){ + Assert(!rowIndexIsTracked(ridx)); + BoundsInfo bi = computeRowBoundInfo(ridx, true); + d_btracking.set(ridx, bi); } -BoundCounts LinearEqualityModule::computeBoundCounts(ArithVar x_i) const{ - BoundCounts counts(0,0); +BoundsInfo LinearEqualityModule::computeRowBoundInfo(RowIndex ridx, bool inQueue) const{ + BoundsInfo bi; - for(Tableau::RowIterator iter = d_tableau.basicRowIterator(x_i); !iter.atEnd(); ++iter){ + for(Tableau::RowIterator iter = d_tableau.getRow(ridx).begin(); !iter.atEnd(); ++iter){ const Tableau::Entry& entry = *iter; - ArithVar nonbasic = entry.getColVar(); - if(nonbasic == x_i) continue; - + ArithVar v = entry.getColVar(); const Rational& a_ij = entry.getCoefficient(); - counts += (d_variables.boundCounts(nonbasic)).multiplyBySgn(a_ij.sgn()); + bi += (d_variables.selectBoundsInfo(v, inQueue)).multiplyBySgn(a_ij.sgn()); } - - return counts; + return bi; } -// BoundCounts LinearEqualityModule::cachingCountBounds(ArithVar x_i) const{ -// if(d_boundTracking.isKey(x_i)){ -// return d_boundTracking[x_i]; -// }else{ -// return computeBoundCounts(x_i); -// } -// } -BoundCounts LinearEqualityModule::_countBounds(ArithVar x_i) const { - Assert(d_boundTracking.isKey(x_i)); - return d_boundTracking[x_i]; +BoundCounts LinearEqualityModule::debugBasicAtBoundCount(ArithVar x_i) const { + return d_btracking[d_tableau.basicToRowIndex(x_i)].atBounds(); } -// BoundCounts LinearEqualityModule::countBounds(ArithVar x_i){ -// if(d_boundTracking.isKey(x_i)){ -// return d_boundTracking[x_i]; -// }else{ -// BoundCounts bc = computeBoundCounts(x_i); -// if(d_areTracking){ -// d_boundTracking.set(x_i,bc); -// } -// return bc; -// } -// } - +/** + * If the pivot described in u were performed, + * then the row would qualify as being either at the minimum/maximum + * to the non-basics being at their bounds. + * The minimum/maximum is determined by the direction the non-basic is changing. + */ bool LinearEqualityModule::basicsAtBounds(const UpdateInfo& u) const { Assert(u.describesPivot()); @@ -814,79 +783,78 @@ bool LinearEqualityModule::basicsAtBounds(const UpdateInfo& u) const { int toLB = (c->getType() == LowerBound || c->getType() == Equality) ? 1 : 0; + RowIndex ridx = d_tableau.basicToRowIndex(basic); - BoundCounts bcs = d_boundTracking[basic]; + BoundCounts bcs = d_btracking[ridx].atBounds(); // x = c*n + \sum d*m - // n = 1/c * x + -1/c * (\sum d*m) - BoundCounts nonb = bcs - d_variables.boundCounts(nonbasic).multiplyBySgn(coeffSgn); + // 0 = -x + c*n + \sum d*m + // n = 1/c * x + -1/c * (\sum d*m) + BoundCounts nonb = bcs - d_variables.atBoundCounts(nonbasic).multiplyBySgn(coeffSgn); + nonb.addInChange(-1, d_variables.atBoundCounts(basic), BoundCounts(toLB, toUB)); nonb = nonb.multiplyBySgn(-coeffSgn); - nonb += BoundCounts(toLB, toUB).multiplyBySgn(coeffSgn); uint32_t length = d_tableau.basicRowLength(basic); Debug("basicsAtBounds") << "bcs " << bcs << "nonb " << nonb << "length " << length << endl; - + // nonb has nb excluded. if(nbdir < 0){ - return bcs.atLowerBounds() + 1 == length; + return nonb.lowerBoundCount() + 1 == length; }else{ Assert(nbdir > 0); - return bcs.atUpperBounds() + 1 == length; + return nonb.upperBoundCount() + 1 == length; } } bool LinearEqualityModule::nonbasicsAtLowerBounds(ArithVar basic) const { Assert(basicIsTracked(basic)); - BoundCounts bcs = d_boundTracking[basic]; + RowIndex ridx = d_tableau.basicToRowIndex(basic); + + BoundCounts bcs = d_btracking[ridx].atBounds(); uint32_t length = d_tableau.basicRowLength(basic); - return bcs.atLowerBounds() + 1 == length; + // return true if excluding the basic is every element is at its "lowerbound" + // The psuedo code is: + // bcs -= basic.count(basic, basic's sgn) + // return bcs.lowerBoundCount() + 1 == length + // As basic's sign is always -1, we can pull out the pieces of the count: + // bcs.lowerBoundCount() - basic.atUpperBoundInd() + 1 == length + // basic.atUpperBoundInd() is either 0 or 1 + uint32_t lbc = bcs.lowerBoundCount(); + return (lbc == length) || + (lbc + 1 == length && d_variables.cmpAssignmentUpperBound(basic) != 0); } bool LinearEqualityModule::nonbasicsAtUpperBounds(ArithVar basic) const { Assert(basicIsTracked(basic)); - BoundCounts bcs = d_boundTracking[basic]; + RowIndex ridx = d_tableau.basicToRowIndex(basic); + BoundCounts bcs = d_btracking[ridx].atBounds(); uint32_t length = d_tableau.basicRowLength(basic); + uint32_t ubc = bcs.upperBoundCount(); + // See the comment for nonbasicsAtLowerBounds() - return bcs.atUpperBounds() + 1 == length; + return (ubc == length) || + (ubc + 1 == length && d_variables.cmpAssignmentLowerBound(basic) != 0); } -void LinearEqualityModule::trackingSwap(ArithVar basic, ArithVar nb, int nbSgn) { - Assert(basicIsTracked(basic)); - - // z = a*x + \sum b*y - // x = (1/a) z + \sum (-1/a)*b*y - // basicCount(z) = bc(a*x) + bc(\sum b y) - // basicCount(x) = bc(z/a) + bc(\sum -b/a * y) - - // sgn(1/a) = sgn(a) - // bc(a*x) = bc(x).multiply(sgn(a)) - // bc(z/a) = bc(z).multiply(sgn(a)) - // bc(\sum -b/a * y) = bc(\sum b y).multiplyBySgn(-sgn(a)) - // bc(\sum b y) = basicCount(z) - bc(a*x) - // basicCount(x) = - // = bc(z).multiply(sgn(a)) + (basicCount(z) - bc(a*x)).multiplyBySgn(-sgn(a)) - - BoundCounts bc = d_boundTracking[basic]; - bc -= (d_variables.boundCounts(nb)).multiplyBySgn(nbSgn); - bc = bc.multiplyBySgn(-nbSgn); - bc += d_variables.boundCounts(basic).multiplyBySgn(nbSgn); - d_boundTracking.set(nb, bc); - d_boundTracking.remove(basic); +void LinearEqualityModule::trackingMultiplyRow(RowIndex ridx, int sgn) { + Assert(rowIndexIsTracked(ridx)); + Assert(sgn != 0); + if(sgn < 0){ + BoundsInfo& bi = d_btracking.get(ridx); + bi = bi.multiplyBySgn(sgn); + } } void LinearEqualityModule::trackingCoefficientChange(RowIndex ridx, ArithVar nb, int oldSgn, int currSgn){ Assert(oldSgn != currSgn); - BoundCounts nb_bc = d_variables.boundCounts(nb); + BoundsInfo nb_inf = d_variables.boundsInfo(nb); - if(!nb_bc.isZero()){ - ArithVar basic = d_tableau.rowIndexToBasic(ridx); - Assert(basicIsTracked(basic)); + Assert(rowIndexIsTracked(ridx)); - BoundCounts& basic_bc = d_boundTracking.get(basic); - basic_bc.addInSgn(nb_bc, oldSgn, currSgn); - } + BoundsInfo& row_bi = d_btracking.get(ridx); + row_bi.addInSgn(nb_inf, oldSgn, currSgn); } void LinearEqualityModule::computeSafeUpdate(UpdateInfo& inf, VarPreferenceFunction pref){ @@ -895,9 +863,6 @@ void LinearEqualityModule::computeSafeUpdate(UpdateInfo& inf, VarPreferenceFunct Assert(sgn != 0); Assert(!d_tableau.isBasic(nb)); - //inf.setErrorsChange(0); - //inf.setlimiting = NullConstraint; - // Error variables moving in the correct direction Assert(d_relevantErrorBuffer.empty()); @@ -1188,8 +1153,9 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr // Assume past this point, nb will be in error if this pivot is done ArithVar nb = entry.getColVar(); - ArithVar basic = d_tableau.rowIndexToBasic(entry.getRowIndex()); - Assert(basicIsTracked(basic)); + RowIndex ridx = entry.getRowIndex(); + ArithVar basic = d_tableau.rowIndexToBasic(ridx); + Assert(rowIndexIsTracked(ridx)); int coeffSgn = entry.getCoefficient().sgn(); @@ -1201,12 +1167,11 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr // 2) -a * x = -y + \sum b * z // 3) x = (-1/a) * ( -y + \sum b * z) - Assert(basicIsTracked(basic)); - BoundCounts bc = d_boundTracking[basic]; + BoundCounts bc = d_btracking[ridx].atBounds(); // 1) y = a * x + \sum b * z // Get bc(\sum b * z) - BoundCounts sumOnly = bc - d_variables.boundCounts(nb).multiplyBySgn(coeffSgn); + BoundCounts sumOnly = bc - d_variables.atBoundCounts(nb).multiplyBySgn(coeffSgn); // y's bounds in the proposed model int yWillBeAtUb = (bToUB || d_variables.boundsAreEqual(basic)) ? 1 : 0; @@ -1215,19 +1180,19 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr // 2) -a * x = -y + \sum b * z // Get bc(-y + \sum b * z) - BoundCounts withNegY = sumOnly + ysBounds.multiplyBySgn(-1); + sumOnly.addInChange(-1, d_variables.atBoundCounts(basic), ysBounds); // 3) x = (-1/a) * ( -y + \sum b * z) // Get bc((-1/a) * ( -y + \sum b * z)) - BoundCounts xsBoundsAfterPivot = withNegY.multiplyBySgn(-coeffSgn); + BoundCounts xsBoundsAfterPivot = sumOnly.multiplyBySgn(-coeffSgn); uint32_t length = d_tableau.basicRowLength(basic); if(nbSgn > 0){ // Only check for the upper bound being violated - return xsBoundsAfterPivot.atLowerBounds() + 1 == length; + return xsBoundsAfterPivot.lowerBoundCount() + 1 == length; }else{ // Only check for the lower bound being violated - return xsBoundsAfterPivot.atUpperBounds() + 1 == length; + return xsBoundsAfterPivot.upperBoundCount() + 1 == length; } } |