summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2011-02-17 21:30:57 +0000
committerTim King <taking@cs.nyu.edu>2011-02-17 21:30:57 +0000
commitb2eba85abe17f3cb661b537d4ac6c55c2e222c65 (patch)
tree45be6dccebf40921566aceb3db2c0c5dbc4bacbc /src
parent595024febc8dc014518db8e74a489d3c6d169493 (diff)
This commit merges the branch branches/arithmetic/quick-row-has into trunk. quick-row-has has an optimization to make checking if a variable is in a row faster.
Diffstat (limited to 'src')
-rw-r--r--src/theory/arith/row_vector.cpp24
-rw-r--r--src/theory/arith/row_vector.h60
2 files changed, 74 insertions, 10 deletions
diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp
index 01131c4c9..6486077fb 100644
--- a/src/theory/arith/row_vector.cpp
+++ b/src/theory/arith/row_vector.cpp
@@ -48,6 +48,7 @@ void RowVector::zip(const std::vector< ArithVar >& variables,
}
}
+
RowVector::RowVector(const std::vector< ArithVar >& variables,
const std::vector< Rational >& coefficients,
std::vector<uint32_t>& counts):
@@ -59,13 +60,28 @@ RowVector::RowVector(const std::vector< ArithVar >& variables,
for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
++d_rowCount[getArithVar(*i)];
+ addArithVar(d_contains, getArithVar(*i));
}
Assert(isSorted(d_entries, true));
Assert(noZeroCoefficients(d_entries));
}
+void RowVector::addArithVar(ArithVarContainsSet& contains, ArithVar v){
+ if(v >= contains.size()){
+ contains.resize(v+1, false);
+ }
+ contains[v] = true;
+}
+
+void RowVector::removeArithVar(ArithVarContainsSet& contains, ArithVar v){
+ Assert(v < contains.size());
+ Assert(contains[v]);
+ contains[v] = false;
+}
+
void RowVector::merge(VarCoeffArray& arr,
+ ArithVarContainsSet& contains,
const VarCoeffArray& other,
const Rational& c,
std::vector<uint32_t>& counts){
@@ -85,6 +101,7 @@ void RowVector::merge(VarCoeffArray& arr,
}else if(getArithVar(*curr1) > getArithVar(*curr2)){
++counts[getArithVar(*curr2)];
+ addArithVar(contains, getArithVar(*curr2));
arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
++curr2;
}else{
@@ -94,6 +111,7 @@ void RowVector::merge(VarCoeffArray& arr,
arr.push_back(make_pair(getArithVar(*curr1), res));
}else{
+ removeArithVar(contains, getArithVar(*curr2));
--counts[getArithVar(*curr2)];
}
++curr1;
@@ -107,6 +125,8 @@ void RowVector::merge(VarCoeffArray& arr,
while(curr2 != end2){
++counts[getArithVar(*curr2)];
+ addArithVar(contains, getArithVar(*curr2));
+
arr.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
++curr2;
}
@@ -123,7 +143,7 @@ void RowVector::multiply(const Rational& c){
void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){
Assert(c != 0);
- merge(d_entries, other.d_entries, c, d_rowCount);
+ merge(d_entries, d_contains, other.d_entries, c, d_rowCount);
}
void RowVector::printRow(){
@@ -144,7 +164,7 @@ ReducedRowVector::ReducedRowVector(ArithVar basic,
VarCoeffArray justBasic;
justBasic.push_back(make_pair(basic, Rational(-1)));
- merge(d_entries,justBasic, Rational(1), d_rowCount);
+ merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount);
Assert(wellFormed());
}
diff --git a/src/theory/arith/row_vector.h b/src/theory/arith/row_vector.h
index a967f8d68..05ceeb986 100644
--- a/src/theory/arith/row_vector.h
+++ b/src/theory/arith/row_vector.h
@@ -31,6 +31,8 @@ public:
typedef std::vector<VarCoeffPair> VarCoeffArray;
typedef VarCoeffArray::const_iterator NonZeroIterator;
+ typedef std::vector<bool> ArithVarContainsSet;
+
/**
* Let c be -1 if strictlySorted is true and c be 0 otherwise.
* isSorted(arr, strictlySorted) is then equivalent to
@@ -39,12 +41,6 @@ public:
static bool isSorted(const VarCoeffArray& arr, bool strictlySorted);
/**
- * noZeroCoefficients(arr) is equivalent to
- * 0 != getCoefficient(arr[i]) for all i.
- */
- static bool noZeroCoefficients(const VarCoeffArray& arr);
-
- /**
* Zips together an array of variables and coefficients and appends
* it to the end of an output vector.
*/
@@ -52,17 +48,33 @@ public:
const std::vector< Rational >& coefficients,
VarCoeffArray& output);
- static void merge(VarCoeffArray& arr, const VarCoeffArray& other, const Rational& c, std::vector<uint32_t>& count);
-
+ static void merge(VarCoeffArray& arr,
+ ArithVarContainsSet& contains,
+ const VarCoeffArray& other,
+ const Rational& c,
+ std::vector<uint32_t>& count);
protected:
/**
+ * Debugging code.
+ * noZeroCoefficients(arr) is equivalent to
+ * 0 != getCoefficient(arr[i]) for all i.
+ */
+ static bool noZeroCoefficients(const VarCoeffArray& arr);
+
+ /**
* Invariants:
* - isSorted(d_entries, true)
* - noZeroCoefficients(d_entries)
*/
VarCoeffArray d_entries;
+ /**
+ * Invariants:
+ * - This set is the same as the set maintained in d_entries.
+ */
+ ArithVarContainsSet d_contains;
+
std::vector<uint32_t>& d_rowCount;
NonZeroIterator lower_bound(ArithVar x_j) const{
@@ -89,14 +101,26 @@ public:
/** Returns true if the variable is in the row. */
bool has(ArithVar x_j) const{
+ if(x_j >= d_contains.size()){
+ return false;
+ }else{
+ return d_contains[x_j];
+ }
+ }
+
+private:
+ /** Debugging code. */
+ bool hasInEntries(ArithVar x_j) const {
return std::binary_search(d_entries.begin(), d_entries.end(), make_pair(x_j,0), cmp);
}
+public:
/**
* Returns the coefficient of a variable in the row.
*/
const Rational& lookup(ArithVar x_j) const{
Assert(has(x_j));
+ Assert(hasInEntries(x_j));
NonZeroIterator lb = lower_bound(x_j);
return getCoefficient(*lb);
}
@@ -113,6 +137,17 @@ public:
void addRowTimesConstant(const Rational& c, const RowVector& other);
void printRow();
+
+protected:
+ /**
+ * Adds v to d_contains.
+ * This may resize d_contains.
+ */
+ static void addArithVar(ArithVarContainsSet& contains, ArithVar v);
+
+ /** Removes v from d_contains. */
+ static void removeArithVar(ArithVarContainsSet& contains, ArithVar v);
+
}; /* class RowVector */
/**
@@ -148,6 +183,15 @@ public:
return d_basic;
}
+ /** Return true if x is in the row and is not the basic variable. */
+ bool hasNonBasic(ArithVar x) const {
+ if(x == basic()){
+ return false;
+ }else{
+ return has(x);
+ }
+ }
+
void pivot(ArithVar x_j);
void substitute(const ReducedRowVector& other);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback