summaryrefslogtreecommitdiff
path: root/src/theory/arith/linear_equality.cpp
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2013-04-30 00:46:14 -0400
committerTim King <taking@cs.nyu.edu>2013-04-30 00:46:14 -0400
commit2b9e032cc93a96dccab8757326645da82b5866e5 (patch)
tree3d579a615f0d3acbf7edadc7cf81a237c4888f43 /src/theory/arith/linear_equality.cpp
parent9098391fe334d829ec4101f190b8f1fa21c30752 (diff)
Adding has bound counts and tracking for rows.
Diffstat (limited to 'src/theory/arith/linear_equality.cpp')
-rw-r--r--src/theory/arith/linear_equality.cpp225
1 files changed, 95 insertions, 130 deletions
diff --git a/src/theory/arith/linear_equality.cpp b/src/theory/arith/linear_equality.cpp
index 42d8b41f8..eda3bf682 100644
--- a/src/theory/arith/linear_equality.cpp
+++ b/src/theory/arith/linear_equality.cpp
@@ -31,12 +31,6 @@ template void LinearEqualityModule::propagateNonbasics<false>(ArithVar basic, Co
template ArithVar LinearEqualityModule::selectSlack<true>(ArithVar x_i, VarPreferenceFunction pf) const;
template ArithVar LinearEqualityModule::selectSlack<false>(ArithVar x_i, VarPreferenceFunction pf) const;
-// template bool LinearEqualityModule::preferNonDegenerate<true>(const UpdateInfo& a, const UpdateInfo& b) const;
-// template bool LinearEqualityModule::preferNonDegenerate<false>(const UpdateInfo& a, const UpdateInfo& b) const;
-
-// template bool LinearEqualityModule::preferErrorsFixed<true>(const UpdateInfo& a, const UpdateInfo& b) const;
-// template bool LinearEqualityModule::preferErrorsFixed<false>(const UpdateInfo& a, const UpdateInfo& b) const;
-
template bool LinearEqualityModule::preferWitness<true>(const UpdateInfo& a, const UpdateInfo& b) const;
template bool LinearEqualityModule::preferWitness<false>(const UpdateInfo& a, const UpdateInfo& b) const;
@@ -57,14 +51,14 @@ void Border::output(std::ostream& out) const{
<< "}";
}
-LinearEqualityModule::LinearEqualityModule(ArithVariables& vars, Tableau& t, BoundCountingVector& boundTracking, BasicVarModelUpdateCallBack f):
+LinearEqualityModule::LinearEqualityModule(ArithVariables& vars, Tableau& t, BoundInfoMap& boundsTracking, BasicVarModelUpdateCallBack f):
d_variables(vars),
d_tableau(t),
d_basicVariableUpdates(f),
d_increasing(1),
d_decreasing(-1),
d_relevantErrorBuffer(),
- d_boundTracking(boundTracking),
+ d_btracking(boundsTracking),
d_areTracking(false),
d_trackCallback(this)
{}
@@ -103,31 +97,24 @@ LinearEqualityModule::Statistics::~Statistics(){
StatisticsRegistry::unregisterStat(&d_weakenings);
StatisticsRegistry::unregisterStat(&d_weakenTime);
}
-void LinearEqualityModule::includeBoundCountChange(ArithVar nb, BoundCounts prev){
- if(d_tableau.isBasic(nb)){
- return;
- }
- Assert(!d_tableau.isBasic(nb));
+void LinearEqualityModule::includeBoundUpdate(ArithVar v, const BoundsInfo& prev){
Assert(!d_areTracking);
- BoundCounts curr = d_variables.boundCounts(nb);
+ BoundsInfo curr = d_variables.boundsInfo(v);
Assert(prev != curr);
- Tableau::ColIterator basicIter = d_tableau.colIterator(nb);
+ Tableau::ColIterator basicIter = d_tableau.colIterator(v);
for(; !basicIter.atEnd(); ++basicIter){
const Tableau::Entry& entry = *basicIter;
- Assert(entry.getColVar() == nb);
+ Assert(entry.getColVar() == v);
int a_ijSgn = entry.getCoefficient().sgn();
- ArithVar basic = d_tableau.rowIndexToBasic(entry.getRowIndex());
-
- BoundCounts& counts = d_boundTracking.get(basic);
- Debug("includeBoundCountChange") << basic << " " << counts << " to " ;
- counts -= prev.multiplyBySgn(a_ijSgn);
- counts += curr.multiplyBySgn(a_ijSgn);
- Debug("includeBoundCountChange") << counts << " " << a_ijSgn << std::endl;
+ RowIndex ridx = entry.getRowIndex();
+ BoundsInfo& counts = d_btracking.get(ridx);
+ Debug("includeBoundUpdate") << d_tableau.rowIndexToBasic(ridx) << " " << counts << " to " ;
+ counts.addInChange(a_ijSgn, prev, curr);
+ Debug("includeBoundUpdate") << counts << " " << a_ijSgn << std::endl;
}
- d_boundTracking.set(nb, curr);
}
void LinearEqualityModule::updateMany(const DenseMap<DeltaRational>& many){
@@ -231,9 +218,9 @@ void LinearEqualityModule::updateTracked(ArithVar x_i, const DeltaRational& v){
<< d_variables.getAssignment(x_i) << "|-> " << v << endl;
- BoundCounts before = d_variables.boundCounts(x_i);
+ BoundCounts before = d_variables.atBoundCounts(x_i);
d_variables.setAssignment(x_i, v);
- BoundCounts after = d_variables.boundCounts(x_i);
+ BoundCounts after = d_variables.atBoundCounts(x_i);
bool anyChange = before != after;
@@ -242,17 +229,24 @@ void LinearEqualityModule::updateTracked(ArithVar x_i, const DeltaRational& v){
const Tableau::Entry& entry = *colIter;
Assert(entry.getColVar() == x_i);
- ArithVar x_j = d_tableau.rowIndexToBasic(entry.getRowIndex());
+ RowIndex ridx = entry.getRowIndex();
+ ArithVar x_j = d_tableau.rowIndexToBasic(ridx);
const Rational& a_ji = entry.getCoefficient();
const DeltaRational& assignment = d_variables.getAssignment(x_j);
DeltaRational nAssignment = assignment+(diff * a_ji);
Debug("update") << x_j << " " << a_ji << assignment << " -> " << nAssignment << endl;
+ BoundCounts xjBefore = d_variables.atBoundCounts(x_j);
d_variables.setAssignment(x_j, nAssignment);
+ BoundCounts xjAfter = d_variables.atBoundCounts(x_j);
- if(anyChange && basicIsTracked(x_j)){
- BoundCounts& next_bc_k = d_boundTracking.get(x_j);
- next_bc_k.addInChange(a_ji.sgn(), before, after);
+ Assert(rowIndexIsTracked(ridx));
+ BoundsInfo& next_bc_k = d_btracking.get(ridx);
+ if(anyChange){
+ next_bc_k.addInAtBoundChange(a_ji.sgn(), before, after);
+ }
+ if(xjBefore != xjAfter){
+ next_bc_k.addInAtBoundChange(-1, xjBefore, xjAfter);
}
d_basicVariableUpdates(x_j);
@@ -332,7 +326,7 @@ void LinearEqualityModule::debugCheckTracking(){
ArithVar var = entry.getColVar();
const Rational& coeff = entry.getCoefficient();
DeltaRational beta = d_variables.getAssignment(var);
- Debug("arith::tracking") << var << " " << d_variables.boundCounts(var)
+ Debug("arith::tracking") << var << " " << d_variables.boundsInfo(var)
<< " " << beta << coeff;
if(d_variables.hasLowerBound(var)){
Debug("arith::tracking") << "(lb " << d_variables.getLowerBound(var) << ")";
@@ -345,11 +339,12 @@ void LinearEqualityModule::debugCheckTracking(){
Debug("arith::tracking") << "end row"<< endl;
if(basicIsTracked(basic)){
- BoundCounts computed = computeBoundCounts(basic);
+ RowIndex ridx = d_tableau.basicToRowIndex(basic);
+ BoundsInfo computed = computeRowBoundInfo(ridx, false);
Debug("arith::tracking")
<< "computed " << computed
- << " tracking " << d_boundTracking[basic] << endl;
- Assert(computed == d_boundTracking[basic]);
+ << " tracking " << d_btracking[ridx] << endl;
+ Assert(computed == d_btracking[ridx]);
}
}
@@ -745,60 +740,34 @@ void LinearEqualityModule::stopTrackingBoundCounts(){
}
-void LinearEqualityModule::trackVariable(ArithVar x_i){
- Assert(!basicIsTracked(x_i));
- BoundCounts counts(0,0);
-
- for(Tableau::RowIterator iter = d_tableau.basicRowIterator(x_i); !iter.atEnd(); ++iter){
- const Tableau::Entry& entry = *iter;
- ArithVar nonbasic = entry.getColVar();
- if(nonbasic == x_i) continue;
-
- const Rational& a_ij = entry.getCoefficient();
- counts += (d_variables.oldBoundCounts(nonbasic)).multiplyBySgn(a_ij.sgn());
- }
- d_boundTracking.set(x_i, counts);
+void LinearEqualityModule::trackRowIndex(RowIndex ridx){
+ Assert(!rowIndexIsTracked(ridx));
+ BoundsInfo bi = computeRowBoundInfo(ridx, true);
+ d_btracking.set(ridx, bi);
}
-BoundCounts LinearEqualityModule::computeBoundCounts(ArithVar x_i) const{
- BoundCounts counts(0,0);
+BoundsInfo LinearEqualityModule::computeRowBoundInfo(RowIndex ridx, bool inQueue) const{
+ BoundsInfo bi;
- for(Tableau::RowIterator iter = d_tableau.basicRowIterator(x_i); !iter.atEnd(); ++iter){
+ for(Tableau::RowIterator iter = d_tableau.getRow(ridx).begin(); !iter.atEnd(); ++iter){
const Tableau::Entry& entry = *iter;
- ArithVar nonbasic = entry.getColVar();
- if(nonbasic == x_i) continue;
-
+ ArithVar v = entry.getColVar();
const Rational& a_ij = entry.getCoefficient();
- counts += (d_variables.boundCounts(nonbasic)).multiplyBySgn(a_ij.sgn());
+ bi += (d_variables.selectBoundsInfo(v, inQueue)).multiplyBySgn(a_ij.sgn());
}
-
- return counts;
+ return bi;
}
-// BoundCounts LinearEqualityModule::cachingCountBounds(ArithVar x_i) const{
-// if(d_boundTracking.isKey(x_i)){
-// return d_boundTracking[x_i];
-// }else{
-// return computeBoundCounts(x_i);
-// }
-// }
-BoundCounts LinearEqualityModule::_countBounds(ArithVar x_i) const {
- Assert(d_boundTracking.isKey(x_i));
- return d_boundTracking[x_i];
+BoundCounts LinearEqualityModule::debugBasicAtBoundCount(ArithVar x_i) const {
+ return d_btracking[d_tableau.basicToRowIndex(x_i)].atBounds();
}
-// BoundCounts LinearEqualityModule::countBounds(ArithVar x_i){
-// if(d_boundTracking.isKey(x_i)){
-// return d_boundTracking[x_i];
-// }else{
-// BoundCounts bc = computeBoundCounts(x_i);
-// if(d_areTracking){
-// d_boundTracking.set(x_i,bc);
-// }
-// return bc;
-// }
-// }
-
+/**
+ * If the pivot described in u were performed,
+ * then the row would qualify as being either at the minimum/maximum
+ * to the non-basics being at their bounds.
+ * The minimum/maximum is determined by the direction the non-basic is changing.
+ */
bool LinearEqualityModule::basicsAtBounds(const UpdateInfo& u) const {
Assert(u.describesPivot());
@@ -814,79 +783,78 @@ bool LinearEqualityModule::basicsAtBounds(const UpdateInfo& u) const {
int toLB = (c->getType() == LowerBound ||
c->getType() == Equality) ? 1 : 0;
+ RowIndex ridx = d_tableau.basicToRowIndex(basic);
- BoundCounts bcs = d_boundTracking[basic];
+ BoundCounts bcs = d_btracking[ridx].atBounds();
// x = c*n + \sum d*m
- // n = 1/c * x + -1/c * (\sum d*m)
- BoundCounts nonb = bcs - d_variables.boundCounts(nonbasic).multiplyBySgn(coeffSgn);
+ // 0 = -x + c*n + \sum d*m
+ // n = 1/c * x + -1/c * (\sum d*m)
+ BoundCounts nonb = bcs - d_variables.atBoundCounts(nonbasic).multiplyBySgn(coeffSgn);
+ nonb.addInChange(-1, d_variables.atBoundCounts(basic), BoundCounts(toLB, toUB));
nonb = nonb.multiplyBySgn(-coeffSgn);
- nonb += BoundCounts(toLB, toUB).multiplyBySgn(coeffSgn);
uint32_t length = d_tableau.basicRowLength(basic);
Debug("basicsAtBounds")
<< "bcs " << bcs
<< "nonb " << nonb
<< "length " << length << endl;
-
+ // nonb has nb excluded.
if(nbdir < 0){
- return bcs.atLowerBounds() + 1 == length;
+ return nonb.lowerBoundCount() + 1 == length;
}else{
Assert(nbdir > 0);
- return bcs.atUpperBounds() + 1 == length;
+ return nonb.upperBoundCount() + 1 == length;
}
}
bool LinearEqualityModule::nonbasicsAtLowerBounds(ArithVar basic) const {
Assert(basicIsTracked(basic));
- BoundCounts bcs = d_boundTracking[basic];
+ RowIndex ridx = d_tableau.basicToRowIndex(basic);
+
+ BoundCounts bcs = d_btracking[ridx].atBounds();
uint32_t length = d_tableau.basicRowLength(basic);
- return bcs.atLowerBounds() + 1 == length;
+ // return true if excluding the basic is every element is at its "lowerbound"
+ // The psuedo code is:
+ // bcs -= basic.count(basic, basic's sgn)
+ // return bcs.lowerBoundCount() + 1 == length
+ // As basic's sign is always -1, we can pull out the pieces of the count:
+ // bcs.lowerBoundCount() - basic.atUpperBoundInd() + 1 == length
+ // basic.atUpperBoundInd() is either 0 or 1
+ uint32_t lbc = bcs.lowerBoundCount();
+ return (lbc == length) ||
+ (lbc + 1 == length && d_variables.cmpAssignmentUpperBound(basic) != 0);
}
bool LinearEqualityModule::nonbasicsAtUpperBounds(ArithVar basic) const {
Assert(basicIsTracked(basic));
- BoundCounts bcs = d_boundTracking[basic];
+ RowIndex ridx = d_tableau.basicToRowIndex(basic);
+ BoundCounts bcs = d_btracking[ridx].atBounds();
uint32_t length = d_tableau.basicRowLength(basic);
+ uint32_t ubc = bcs.upperBoundCount();
+ // See the comment for nonbasicsAtLowerBounds()
- return bcs.atUpperBounds() + 1 == length;
+ return (ubc == length) ||
+ (ubc + 1 == length && d_variables.cmpAssignmentLowerBound(basic) != 0);
}
-void LinearEqualityModule::trackingSwap(ArithVar basic, ArithVar nb, int nbSgn) {
- Assert(basicIsTracked(basic));
-
- // z = a*x + \sum b*y
- // x = (1/a) z + \sum (-1/a)*b*y
- // basicCount(z) = bc(a*x) + bc(\sum b y)
- // basicCount(x) = bc(z/a) + bc(\sum -b/a * y)
-
- // sgn(1/a) = sgn(a)
- // bc(a*x) = bc(x).multiply(sgn(a))
- // bc(z/a) = bc(z).multiply(sgn(a))
- // bc(\sum -b/a * y) = bc(\sum b y).multiplyBySgn(-sgn(a))
- // bc(\sum b y) = basicCount(z) - bc(a*x)
- // basicCount(x) =
- // = bc(z).multiply(sgn(a)) + (basicCount(z) - bc(a*x)).multiplyBySgn(-sgn(a))
-
- BoundCounts bc = d_boundTracking[basic];
- bc -= (d_variables.boundCounts(nb)).multiplyBySgn(nbSgn);
- bc = bc.multiplyBySgn(-nbSgn);
- bc += d_variables.boundCounts(basic).multiplyBySgn(nbSgn);
- d_boundTracking.set(nb, bc);
- d_boundTracking.remove(basic);
+void LinearEqualityModule::trackingMultiplyRow(RowIndex ridx, int sgn) {
+ Assert(rowIndexIsTracked(ridx));
+ Assert(sgn != 0);
+ if(sgn < 0){
+ BoundsInfo& bi = d_btracking.get(ridx);
+ bi = bi.multiplyBySgn(sgn);
+ }
}
void LinearEqualityModule::trackingCoefficientChange(RowIndex ridx, ArithVar nb, int oldSgn, int currSgn){
Assert(oldSgn != currSgn);
- BoundCounts nb_bc = d_variables.boundCounts(nb);
+ BoundsInfo nb_inf = d_variables.boundsInfo(nb);
- if(!nb_bc.isZero()){
- ArithVar basic = d_tableau.rowIndexToBasic(ridx);
- Assert(basicIsTracked(basic));
+ Assert(rowIndexIsTracked(ridx));
- BoundCounts& basic_bc = d_boundTracking.get(basic);
- basic_bc.addInSgn(nb_bc, oldSgn, currSgn);
- }
+ BoundsInfo& row_bi = d_btracking.get(ridx);
+ row_bi.addInSgn(nb_inf, oldSgn, currSgn);
}
void LinearEqualityModule::computeSafeUpdate(UpdateInfo& inf, VarPreferenceFunction pref){
@@ -895,9 +863,6 @@ void LinearEqualityModule::computeSafeUpdate(UpdateInfo& inf, VarPreferenceFunct
Assert(sgn != 0);
Assert(!d_tableau.isBasic(nb));
- //inf.setErrorsChange(0);
- //inf.setlimiting = NullConstraint;
-
// Error variables moving in the correct direction
Assert(d_relevantErrorBuffer.empty());
@@ -1188,8 +1153,9 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr
// Assume past this point, nb will be in error if this pivot is done
ArithVar nb = entry.getColVar();
- ArithVar basic = d_tableau.rowIndexToBasic(entry.getRowIndex());
- Assert(basicIsTracked(basic));
+ RowIndex ridx = entry.getRowIndex();
+ ArithVar basic = d_tableau.rowIndexToBasic(ridx);
+ Assert(rowIndexIsTracked(ridx));
int coeffSgn = entry.getCoefficient().sgn();
@@ -1201,12 +1167,11 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr
// 2) -a * x = -y + \sum b * z
// 3) x = (-1/a) * ( -y + \sum b * z)
- Assert(basicIsTracked(basic));
- BoundCounts bc = d_boundTracking[basic];
+ BoundCounts bc = d_btracking[ridx].atBounds();
// 1) y = a * x + \sum b * z
// Get bc(\sum b * z)
- BoundCounts sumOnly = bc - d_variables.boundCounts(nb).multiplyBySgn(coeffSgn);
+ BoundCounts sumOnly = bc - d_variables.atBoundCounts(nb).multiplyBySgn(coeffSgn);
// y's bounds in the proposed model
int yWillBeAtUb = (bToUB || d_variables.boundsAreEqual(basic)) ? 1 : 0;
@@ -1215,19 +1180,19 @@ bool LinearEqualityModule::willBeInConflictAfterPivot(const Tableau::Entry& entr
// 2) -a * x = -y + \sum b * z
// Get bc(-y + \sum b * z)
- BoundCounts withNegY = sumOnly + ysBounds.multiplyBySgn(-1);
+ sumOnly.addInChange(-1, d_variables.atBoundCounts(basic), ysBounds);
// 3) x = (-1/a) * ( -y + \sum b * z)
// Get bc((-1/a) * ( -y + \sum b * z))
- BoundCounts xsBoundsAfterPivot = withNegY.multiplyBySgn(-coeffSgn);
+ BoundCounts xsBoundsAfterPivot = sumOnly.multiplyBySgn(-coeffSgn);
uint32_t length = d_tableau.basicRowLength(basic);
if(nbSgn > 0){
// Only check for the upper bound being violated
- return xsBoundsAfterPivot.atLowerBounds() + 1 == length;
+ return xsBoundsAfterPivot.lowerBoundCount() + 1 == length;
}else{
// Only check for the lower bound being violated
- return xsBoundsAfterPivot.atUpperBounds() + 1 == length;
+ return xsBoundsAfterPivot.upperBoundCount() + 1 == length;
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback