summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2011-02-22 01:13:56 +0000
committerTim King <taking@cs.nyu.edu>2011-02-22 01:13:56 +0000
commitc40d5678a4bbd73bde711149004206e37176661b (patch)
tree8df1349d7568768e7e8f9f58b2361884dc9fd830 /src
parenta101b2e309dd2818a85c954e45af586e530e289a (diff)
- Adds column based iterators.
Diffstat (limited to 'src')
-rw-r--r--src/theory/arith/arithvar_set.h8
-rw-r--r--src/theory/arith/row_vector.cpp71
-rw-r--r--src/theory/arith/row_vector.h17
-rw-r--r--src/theory/arith/simplex.cpp67
-rw-r--r--src/theory/arith/tableau.cpp17
-rw-r--r--src/theory/arith/tableau.h33
6 files changed, 176 insertions, 37 deletions
diff --git a/src/theory/arith/arithvar_set.h b/src/theory/arith/arithvar_set.h
index de215696e..ff75b373a 100644
--- a/src/theory/arith/arithvar_set.h
+++ b/src/theory/arith/arithvar_set.h
@@ -37,8 +37,9 @@ namespace arith {
*/
class ArithVarSet {
-private:
+public:
typedef std::vector<ArithVar> VarList;
+private:
//List of the ArithVars in the set.
VarList d_list;
@@ -49,7 +50,7 @@ private:
public:
typedef VarList::const_iterator iterator;
- ArithVarSet() : d_list(), d_posVector() {}
+ ArithVarSet() : d_list(), d_posVector() {}
size_t size() const {
return d_list.size();
@@ -95,6 +96,9 @@ public:
iterator begin() const{ return d_list.begin(); }
iterator end() const{ return d_list.end(); }
+ const VarList& getList() const{
+ return d_list;
+ }
/** Invalidates iterators */
void remove(ArithVar x){
diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp
index 2af03bf08..2463adf47 100644
--- a/src/theory/arith/row_vector.cpp
+++ b/src/theory/arith/row_vector.cpp
@@ -29,6 +29,18 @@ RowVector::~RowVector(){
Assert(d_rowCount[v] >= 1);
--(d_rowCount[v]);
}
+
+ Assert(matchingCounts());
+}
+
+bool RowVector::matchingCounts() const{
+ for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+ ArithVar v = getArithVar(*i);
+ if(d_columnMatrix[v].size() != d_rowCount[v]){
+ return false;
+ }
+ }
+ return true;
}
bool RowVector::noZeroCoefficients(const VarCoeffArray& arr){
@@ -61,8 +73,9 @@ void RowVector::zip(const std::vector< ArithVar >& variables,
RowVector::RowVector(const std::vector< ArithVar >& variables,
const std::vector< Rational >& coefficients,
- std::vector<uint32_t>& counts):
- d_rowCount(counts)
+ std::vector<uint32_t>& counts,
+ std::vector<ArithVarSet>& cm):
+ d_rowCount(counts), d_columnMatrix(cm)
{
zip(variables, coefficients, d_entries);
@@ -94,7 +107,9 @@ void RowVector::merge(VarCoeffArray& arr,
ArithVarContainsSet& contains,
const VarCoeffArray& other,
const Rational& c,
- std::vector<uint32_t>& counts){
+ std::vector<uint32_t>& counts,
+ std::vector<ArithVarSet>& columnMatrix,
+ ArithVar basic){
VarCoeffArray copy = arr;
arr.clear();
@@ -109,7 +124,11 @@ void RowVector::merge(VarCoeffArray& arr,
arr.push_back(*curr1);
++curr1;
}else if(getArithVar(*curr1) > getArithVar(*curr2)){
+
++counts[getArithVar(*curr2)];
+ if(basic != ARITHVAR_SENTINEL){
+ columnMatrix[getArithVar(*curr2)].add(basic);
+ }
addArithVar(contains, getArithVar(*curr2));
arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
@@ -118,12 +137,15 @@ void RowVector::merge(VarCoeffArray& arr,
Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2);
if(res != 0){
//The variable is not new so the count stays the same
- //bug: ++counts[getArithVar(*curr2)];
arr.push_back(make_pair(getArithVar(*curr1), res));
}else{
removeArithVar(contains, getArithVar(*curr2));
+
--counts[getArithVar(*curr2)];
+ if(basic != ARITHVAR_SENTINEL){
+ columnMatrix[getArithVar(*curr2)].remove(basic);
+ }
}
++curr1;
++curr2;
@@ -135,6 +157,9 @@ void RowVector::merge(VarCoeffArray& arr,
}
while(curr2 != end2){
++counts[getArithVar(*curr2)];
+ if(basic != ARITHVAR_SENTINEL){
+ columnMatrix[getArithVar(*curr2)].add(basic);
+ }
addArithVar(contains, getArithVar(*curr2));
@@ -151,10 +176,10 @@ void RowVector::multiply(const Rational& c){
}
}
-void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){
+void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic){
Assert(c != 0);
- merge(d_entries, d_contains, other.d_entries, c, d_rowCount);
+ merge(d_entries, d_contains, other.d_entries, c, d_rowCount, d_columnMatrix, basic);
}
void RowVector::printRow(){
@@ -165,18 +190,27 @@ void RowVector::printRow(){
Debug("row::print") << std::endl;
}
+
ReducedRowVector::ReducedRowVector(ArithVar basic,
const std::vector<ArithVar>& variables,
const std::vector<Rational>& coefficients,
- std::vector<uint32_t>& count):
- RowVector(variables, coefficients, count), d_basic(basic){
+ std::vector<uint32_t>& count,
+ std::vector<ArithVarSet>& columnMatrix):
+ RowVector(variables, coefficients, count, columnMatrix), d_basic(basic){
+ for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+ //basic is not yet in d_entries
+ Assert(getArithVar(*i) != d_basic);
+ d_columnMatrix[getArithVar(*i)].add(d_basic);
+ }
+
VarCoeffArray justBasic;
justBasic.push_back(make_pair(basic, Rational(-1)));
- merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount);
+ merge(d_entries, d_contains, justBasic, Rational(1), d_rowCount, d_columnMatrix, d_basic);
+ Assert(matchingCounts());
Assert(wellFormed());
Assert(d_rowCount[d_basic] == 1);
}
@@ -190,10 +224,12 @@ void ReducedRowVector::substitute(const ReducedRowVector& row_s){
Rational a_rs = lookup(x_s);
Assert(a_rs != 0);
- addRowTimesConstant(a_rs, row_s);
+ addRowTimesConstant(a_rs, row_s, basic());
+
Assert(!has(x_s));
Assert(wellFormed());
+ Assert(matchingCounts());
Assert(d_rowCount[basic()] == 1);
}
@@ -202,8 +238,15 @@ void ReducedRowVector::pivot(ArithVar x_j){
Assert(basic() != x_j);
Rational negInverseA_rs = -(lookup(x_j).inverse());
multiply(negInverseA_rs);
+
+ for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){
+ d_columnMatrix[getArithVar(*i)].remove(d_basic);
+ d_columnMatrix[getArithVar(*i)].add(x_j);
+ }
+
d_basic = x_j;
+ Assert(matchingCounts());
Assert(wellFormed());
//The invariant Assert(d_rowCount[basic()] == 1); does not hold.
//This is because the pivot is within the row first then
@@ -249,4 +292,12 @@ ReducedRowVector::~ReducedRowVector(){
//This executes before the super classes destructor RowVector,
// which will set this to 0.
Assert(d_rowCount[basic()] == 1);
+
+ NonZeroIterator curr = beginNonZero();
+ NonZeroIterator end = endNonZero();
+ for(;curr != end; ++curr){
+ ArithVar v = getArithVar(*curr);
+ Assert(d_rowCount[v] >= 1);
+ d_columnMatrix[v].remove(basic());
+ }
}
diff --git a/src/theory/arith/row_vector.h b/src/theory/arith/row_vector.h
index 85a188063..29b79ddd5 100644
--- a/src/theory/arith/row_vector.h
+++ b/src/theory/arith/row_vector.h
@@ -6,6 +6,7 @@
#define __CVC4__THEORY__ARITH__ROW_VECTOR_H
#include "theory/arith/arith_utilities.h"
+#include "theory/arith/arithvar_set.h"
#include "util/rational.h"
#include <vector>
@@ -52,7 +53,9 @@ public:
ArithVarContainsSet& contains,
const VarCoeffArray& other,
const Rational& c,
- std::vector<uint32_t>& count);
+ std::vector<uint32_t>& count,
+ std::vector<ArithVarSet>& columnMatrix,
+ ArithVar basic);
protected:
/**
@@ -62,6 +65,9 @@ protected:
*/
static bool noZeroCoefficients(const VarCoeffArray& arr);
+ /** Debugging code.*/
+ bool matchingCounts() const;
+
/**
* Invariants:
* - isSorted(d_entries, true)
@@ -76,6 +82,7 @@ protected:
ArithVarContainsSet d_contains;
std::vector<uint32_t>& d_rowCount;
+ std::vector<ArithVarSet>& d_columnMatrix;
NonZeroIterator lower_bound(ArithVar x_j) const{
return std::lower_bound(d_entries.begin(), d_entries.end(), make_pair(x_j,0), cmp);
@@ -87,7 +94,8 @@ public:
RowVector(const std::vector< ArithVar >& variables,
const std::vector< Rational >& coefficients,
- std::vector<uint32_t>& counts);
+ std::vector<uint32_t>& counts,
+ std::vector<ArithVarSet>& columnMatrix);
~RowVector();
@@ -135,7 +143,7 @@ public:
* Updates the current row to be the sum of itself and
* another vector times c (c != 0).
*/
- void addRowTimesConstant(const Rational& c, const RowVector& other);
+ void addRowTimesConstant(const Rational& c, const RowVector& other, ArithVar basic);
void printRow();
@@ -176,7 +184,8 @@ public:
ReducedRowVector(ArithVar basic,
const std::vector< ArithVar >& variables,
const std::vector< Rational >& coefficients,
- std::vector<uint32_t>& count);
+ std::vector<uint32_t>& count,
+ std::vector<ArithVarSet>& columnMatrix);
~ReducedRowVector();
diff --git a/src/theory/arith/simplex.cpp b/src/theory/arith/simplex.cpp
index d837d7ac0..2785222e3 100644
--- a/src/theory/arith/simplex.cpp
+++ b/src/theory/arith/simplex.cpp
@@ -168,6 +168,37 @@ bool SimplexDecisionProcedure::AssertEquality(ArithVar x_i, const DeltaRational&
return false;
}
+set<ArithVar> tableauAndHasSet(Tableau& tab, ArithVar v){
+ set<ArithVar> has;
+ for(ArithVarSet::iterator basicIter = tab.begin();
+ basicIter != tab.end();
+ ++basicIter){
+ ArithVar basic = *basicIter;
+ ReducedRowVector& row = tab.lookup(basic);
+
+ if(row.has(v)){
+ has.insert(basic);
+ }
+ }
+ return has;
+}
+
+set<ArithVar> columnIteratorSet(Tableau& tab,ArithVar v){
+ set<ArithVar> has;
+ ArithVarSet::iterator basicIter = tab.beginColumn(v);
+ ArithVarSet::iterator endIter = tab.endColumn(v);
+ for(; basicIter != endIter; ++basicIter){
+ ArithVar basic = *basicIter;
+ has.insert(basic);
+ }
+ return has;
+}
+
+
+bool matchingSets(Tableau& tab, ArithVar v){
+ return tableauAndHasSet(tab, v) == columnIteratorSet(tab, v);
+}
+
void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){
Assert(!d_tableau.isBasic(x_i));
DeltaRational assignment_x_i = d_partialModel.getAssignment(x_i);
@@ -177,22 +208,21 @@ void SimplexDecisionProcedure::update(ArithVar x_i, const DeltaRational& v){
<< assignment_x_i << "|-> " << v << endl;
DeltaRational diff = v - assignment_x_i;
- for(ArithVarSet::iterator basicIter = d_tableau.begin();
- basicIter != d_tableau.end();
- ++basicIter){
+ Assert(matchingSets(d_tableau, x_i));
+ ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_i);
+ ArithVarSet::iterator endIter = d_tableau.endColumn(x_i);
+ for(; basicIter != endIter; ++basicIter){
ArithVar x_j = *basicIter;
ReducedRowVector& row_j = d_tableau.lookup(x_j);
- if(row_j.has(x_i)){
- const Rational& a_ji = row_j.lookup(x_i);
+ Assert(row_j.has(x_i));
+ const Rational& a_ji = row_j.lookup(x_i);
- const DeltaRational& assignment = d_partialModel.getAssignment(x_j);
- DeltaRational nAssignment = assignment+(diff * a_ji);
- d_partialModel.setAssignment(x_j, nAssignment);
+ const DeltaRational& assignment = d_partialModel.getAssignment(x_j);
+ DeltaRational nAssignment = assignment+(diff * a_ji);
+ d_partialModel.setAssignment(x_j, nAssignment);
- d_queue.enqueueIfInconsistent(x_j);
- //checkBasicVariable(x_j);
- }
+ d_queue.enqueueIfInconsistent(x_j);
}
d_partialModel.setAssignment(x_i, v);
@@ -250,12 +280,21 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR
DeltaRational tmp = d_partialModel.getAssignment(x_j) + theta;
d_partialModel.setAssignment(x_j, tmp);
- ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end();
- for(; basicIter != end; ++basicIter){
+
+ Assert(matchingSets(d_tableau, x_j));
+ ArithVarSet::iterator basicIter = d_tableau.beginColumn(x_j);
+ ArithVarSet::iterator endIter = d_tableau.endColumn(x_j);
+ for(; basicIter != endIter; ++basicIter){
+
+ //ArithVarSet::iterator basicIter = d_tableau.begin(), end = d_tableau.end();
+ //for(; basicIter != end; ++basicIter){
ArithVar x_k = *basicIter;
ReducedRowVector& row_k = d_tableau.lookup(x_k);
- if(x_k != x_i && row_k.has(x_j)){
+ Assert(row_k.has(x_j));
+
+ //if(x_k != x_i && row_k.has(x_j)){
+ if(x_k != x_i ){
const Rational& a_kj = row_k.lookup(x_j);
DeltaRational nextAssignment = d_partialModel.getAssignment(x_k) + (theta * a_kj);
d_partialModel.setAssignment(x_k, nextAssignment);
diff --git a/src/theory/arith/tableau.cpp b/src/theory/arith/tableau.cpp
index d318a70e6..ebf7dbee8 100644
--- a/src/theory/arith/tableau.cpp
+++ b/src/theory/arith/tableau.cpp
@@ -41,7 +41,7 @@ void Tableau::addRow(ArithVar basicVar,
//The new basic variable cannot already be a basic variable
Assert(!d_basicVariables.isMember(basicVar));
d_basicVariables.add(basicVar);
- ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount);
+ ReducedRowVector* row_current = new ReducedRowVector(basicVar,variables, coeffs,d_rowCount, d_columnMatrix);
d_rowsTable[basicVar] = row_current;
//A variable in the row may have been made non-basic already.
@@ -90,17 +90,22 @@ void Tableau::pivot(ArithVar x_r, ArithVar x_s){
row_s->pivot(x_s);
- for(ArithVarSet::iterator basicIter = begin(), endIter = end();
- basicIter != endIter; ++basicIter){
+ ArithVarSet::VarList copy(getColumn(x_s).getList());
+ vector<ArithVar>::iterator basicIter = copy.begin(), endIter = copy.end();
+
+ for(; basicIter != endIter; ++basicIter){
ArithVar basic = *basicIter;
if(basic == x_s) continue;
ReducedRowVector& row_k = lookup(basic);
- if(row_k.has(x_s)){
- row_k.substitute(*row_s);
- }
+ Assert(row_k.has(x_s));
+
+ row_k.substitute(*row_s);
}
+ Assert(getColumn(x_s).size() == 1);
+ Assert(getRowCount(x_s) == 1);
}
+
void Tableau::printTableau(){
Debug("tableau") << "Tableau::d_activeRows" << endl;
diff --git a/src/theory/arith/tableau.h b/src/theory/arith/tableau.h
index 36d61ba25..27aa1305c 100644
--- a/src/theory/arith/tableau.h
+++ b/src/theory/arith/tableau.h
@@ -37,6 +37,10 @@ namespace CVC4 {
namespace theory {
namespace arith {
+typedef ArithVarSet Column;
+
+typedef std::vector<Column> ColumnMatrix;
+
class Tableau {
private:
@@ -47,6 +51,7 @@ private:
ArithVarSet d_basicVariables;
std::vector<uint32_t> d_rowCount;
+ ColumnMatrix d_columnMatrix;
public:
/**
@@ -55,7 +60,8 @@ public:
Tableau() :
d_rowsTable(),
d_basicVariables(),
- d_rowCount()
+ d_rowCount(),
+ d_columnMatrix()
{}
~Tableau();
@@ -67,6 +73,16 @@ public:
d_basicVariables.increaseSize();
d_rowsTable.push_back(NULL);
d_rowCount.push_back(0);
+
+ d_columnMatrix.push_back(ArithVarSet());
+
+ //TODO replace with version of ArithVarSet that handles misses as non-entries
+ // not as buffer overflows
+ ColumnMatrix::iterator i = d_columnMatrix.begin(), end = d_columnMatrix.end();
+ for(; i != end; ++i){
+ Column& col = *i;
+ col.increaseSize(d_columnMatrix.size());
+ }
}
bool isBasic(ArithVar v) const {
@@ -81,6 +97,19 @@ public:
return d_basicVariables.end();
}
+ const Column& getColumn(ArithVar v){
+ Assert(v < d_columnMatrix.size());
+ return d_columnMatrix[v];
+ }
+
+ Column::iterator beginColumn(ArithVar v){
+ return getColumn(v).begin();
+ }
+ Column::iterator endColumn(ArithVar v){
+ return getColumn(v).end();
+ }
+
+
ReducedRowVector& lookup(ArithVar var){
Assert(d_basicVariables.isMember(var));
Assert(d_rowsTable[var] != NULL);
@@ -90,6 +119,8 @@ public:
public:
uint32_t getRowCount(ArithVar x){
Assert(x < d_rowCount.size());
+ AlwaysAssert(d_rowCount[x] == getColumn(x).size());
+
return d_rowCount[x];
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback