summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2011-02-27 18:29:38 +0000
committerTim King <taking@cs.nyu.edu>2011-02-27 18:29:38 +0000
commit57fe149cf7915d721912e1d1866c31346f66e2f8 (patch)
treed6192a8c3d3ed8eab9d275523a0d254a85279a8d /src
parent8d52dbabb099cb66cfffe0d63397764d8a53b21c (diff)
- Makes VarCoeffPair a class instead of a typedef of pair<ArithVar, Rational>. This addresses a point Dejan brought up in the code review.
Diffstat (limited to 'src')
-rw-r--r--src/theory/arith/row_vector.cpp120
-rw-r--r--src/theory/arith/row_vector.h31
-rw-r--r--src/theory/arith/simplex.cpp24
3 files changed, 109 insertions, 66 deletions
diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp
index 78ec55c2a..090938f28 100644
--- a/src/theory/arith/row_vector.cpp
+++ b/src/theory/arith/row_vector.cpp
@@ -11,10 +11,10 @@ bool ReducedRowVector::isSorted(const VarCoeffArray& arr, bool strictlySorted) {
if(arr.size() >= 2){
const_iterator curr = arr.begin();
const_iterator end = arr.end();
- ArithVar prev = getArithVar(*curr);
+ ArithVar prev = (*curr).getArithVar();
++curr;
for(;curr != end; ++curr){
- ArithVar v = getArithVar(*curr);
+ ArithVar v = (*curr).getArithVar();
if(strictlySorted && prev > v) return false;
if(!strictlySorted && prev >= v) return false;
prev = v;
@@ -31,7 +31,7 @@ ReducedRowVector::~ReducedRowVector(){
const_iterator curr = begin();
const_iterator endEntries = end();
for(;curr != endEntries; ++curr){
- ArithVar v = getArithVar(*curr);
+ ArithVar v = (*curr).getArithVar();
Assert(d_rowCount[v] >= 1);
d_columnMatrix[v].remove(basic());
--(d_rowCount[v]);
@@ -43,7 +43,7 @@ ReducedRowVector::~ReducedRowVector(){
bool ReducedRowVector::matchingCounts() const{
for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
- ArithVar v = getArithVar(*i);
+ ArithVar v = (*i).getArithVar();
if(d_columnMatrix[v].size() != d_rowCount[v]){
return false;
}
@@ -54,7 +54,7 @@ bool ReducedRowVector::matchingCounts() const{
bool ReducedRowVector::noZeroCoefficients(const VarCoeffArray& arr){
for(const_iterator curr = arr.begin(), endEntries = arr.end();
curr != endEntries; ++curr){
- const Rational& coeff = getCoefficient(*curr);
+ const Rational& coeff = (*curr).getCoefficient();
if(coeff == 0) return false;
}
return true;
@@ -74,7 +74,7 @@ void ReducedRowVector::zip(const std::vector< ArithVar >& variables,
const Rational& coeff = *coeffIter;
ArithVar var_i = *varIter;
- output.push_back(make_pair(var_i, coeff));
+ output.push_back(VarCoeffPair(var_i, coeff));
}
}
@@ -95,13 +95,14 @@ void ReducedRowVector::multiply(const Rational& c){
Assert(c != 0);
for(iterator i = d_entries.begin(), end = d_entries.end(); i != end; ++i){
- getCoefficient(*i) *= c;
+ (*i).getCoefficient() *= c;
}
}
void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVector& other){
Assert(c != 0);
Assert(d_buffer.empty());
+ Assert(wellFormed());
d_buffer.reserve(other.d_entries.size());
@@ -112,32 +113,34 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
const_iterator end2 = other.d_entries.end();
while(curr1 != end1 && curr2 != end2){
- if(getArithVar(*curr1) < getArithVar(*curr2)){
+ ArithVar var1 = (*curr1).getArithVar();
+ ArithVar var2 = (*curr2).getArithVar();
+
+ if(var1 < var2){
d_buffer.push_back(*curr1);
++curr1;
- }else if(getArithVar(*curr1) > getArithVar(*curr2)){
+ }else if(var1 > var2){
- ++d_rowCount[getArithVar(*curr2)];
- if(d_basic != ARITHVAR_SENTINEL){
- d_columnMatrix[getArithVar(*curr2)].add(d_basic);
- }
+ ++d_rowCount[var2];
+ d_columnMatrix[var2].add(d_basic);
- addArithVar(d_contains, getArithVar(*curr2));
- d_buffer.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
+ addArithVar(d_contains, var2);
+ const Rational& coeff2 = (*curr2).getCoefficient();
+ d_buffer.push_back( VarCoeffPair(var2, c * coeff2));
++curr2;
}else{
- Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2);
+ Assert(var1 == var2);
+ const Rational& coeff1 = (*curr1).getCoefficient();
+ const Rational& coeff2 = (*curr2).getCoefficient();
+ Rational res = coeff1 + (c * coeff2);
if(res != 0){
//The variable is not new so the count stays the same
-
- d_buffer.push_back(make_pair(getArithVar(*curr1), res));
+ d_buffer.push_back(VarCoeffPair(var1, res));
}else{
- removeArithVar(d_contains, getArithVar(*curr2));
+ removeArithVar(d_contains, var1);
- --d_rowCount[getArithVar(*curr2)];
- if(d_basic != ARITHVAR_SENTINEL){
- d_columnMatrix[getArithVar(*curr2)].remove(d_basic);
- }
+ --d_rowCount[var1];
+ d_columnMatrix[var1].remove(d_basic);
}
++curr1;
++curr2;
@@ -148,14 +151,14 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
++curr1;
}
while(curr2 != end2){
- ++d_rowCount[getArithVar(*curr2)];
- if(d_basic != ARITHVAR_SENTINEL){
- d_columnMatrix[getArithVar(*curr2)].add(d_basic);
- }
+ ArithVar var2 = (*curr2).getArithVar();
+ const Rational& coeff2 = (*curr2).getCoefficient();
+ ++d_rowCount[var2];
+ d_columnMatrix[var2].add(d_basic);
- addArithVar(d_contains, getArithVar(*curr2));
+ addArithVar(d_contains, var2);
- d_buffer.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2)));
+ d_buffer.push_back(VarCoeffPair(var2, c * coeff2));
++curr2;
}
@@ -167,8 +170,9 @@ void ReducedRowVector::addRowTimesConstant(const Rational& c, const ReducedRowVe
void ReducedRowVector::printRow(){
for(const_iterator i = begin(); i != end(); ++i){
- ArithVar nb = getArithVar(*i);
- Debug("row::print") << "{" << nb << "," << getCoefficient(*i) << "}";
+ ArithVar nb = (*i).getArithVar();
+ const Rational& coeff = (*i).getCoefficient();
+ Debug("row::print") << "{" << nb << "," << coeff << "}";
}
Debug("row::print") << std::endl;
}
@@ -182,14 +186,15 @@ ReducedRowVector::ReducedRowVector(ArithVar basic,
d_basic(basic), d_rowCount(counts), d_columnMatrix(cm)
{
zip(variables, coefficients, d_entries);
- d_entries.push_back(make_pair(basic, Rational(-1)));
+ d_entries.push_back(VarCoeffPair(basic, Rational(-1)));
- std::sort(d_entries.begin(), d_entries.end(), cmp);
+ std::sort(d_entries.begin(), d_entries.end());
for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
- ++d_rowCount[getArithVar(*i)];
- addArithVar(d_contains, getArithVar(*i));
- d_columnMatrix[getArithVar(*i)].add(d_basic);
+ ArithVar var = (*i).getArithVar();
+ ++d_rowCount[var];
+ addArithVar(d_contains, var);
+ d_columnMatrix[var].add(d_basic);
}
Assert(isSorted(d_entries, true));
@@ -225,8 +230,9 @@ void ReducedRowVector::pivot(ArithVar x_j){
multiply(negInverseA_rs);
for(const_iterator i=begin(), endEntries=end(); i != endEntries; ++i){
- d_columnMatrix[getArithVar(*i)].remove(d_basic);
- d_columnMatrix[getArithVar(*i)].add(x_j);
+ ArithVar var = (*i).getArithVar();
+ d_columnMatrix[var].remove(d_basic);
+ d_columnMatrix[var].add(x_j);
}
d_basic = x_j;
@@ -243,15 +249,40 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{
using namespace CVC4::kind;
Assert(size() >= 2);
+
+ vector<Node> nonBasicPairs;
+ for(const_iterator i = begin(); i != end(); ++i){
+ ArithVar nb = (*i).getArithVar();
+ if(nb == basic()) continue;
+ Node var = (map.find(nb))->second;
+ Node coeff = mkRationalNode((*i).getCoefficient());
+
+ Node mult = NodeBuilder<2>(MULT) << coeff << var;
+ nonBasicPairs.push_back(mult);
+ }
+
+ Node sum = Node::null();
+ if(nonBasicPairs.size() == 1 ){
+ sum = nonBasicPairs.front();
+ }else{
+ Assert(nonBasicPairs.size() >= 2);
+ NodeBuilder<> sumBuilder(PLUS);
+ sumBuilder.append(nonBasicPairs);
+ sum = sumBuilder;
+ }
+ Node basicVar = (map.find(basic()))->second;
+ return NodeBuilder<2>(EQUAL) << basicVar << sum;
+
+ /*
Node sum = Node::null();
if(size() > 2){
NodeBuilder<> sumBuilder(PLUS);
for(const_iterator i = begin(); i != end(); ++i){
- ArithVar nb = getArithVar(*i);
+ ArithVar nb = (*i).getArithVar();
if(nb == basic()) continue;
Node var = (map.find(nb))->second;
- Node coeff = mkRationalNode(getCoefficient(*i));
+ Node coeff = mkRationalNode((*i).getCoefficient());
Node mult = NodeBuilder<2>(MULT) << coeff << var;
sumBuilder << mult;
@@ -260,15 +291,16 @@ Node ReducedRowVector::asEquality(const ArithVarToNodeMap& map) const{
}else{
Assert(size() == 2);
const_iterator i = begin();
- if(getArithVar(*i) == basic()){
+ if((*i).getArithVar() == basic()){
++i;
}
- Assert(getArithVar(*i) != basic());
- Node var = (map.find(getArithVar(*i)))->second;
- Node coeff = mkRationalNode(getCoefficient(*i));
+ Assert((*i).getArithVar() != basic());
+ Node var = (map.find((*i).getArithVar()))->second;
+ Node coeff = mkRationalNode((*i).getCoefficient());
sum = NodeBuilder<2>(MULT) << coeff << var;
}
Node basicVar = (map.find(basic()))->second;
return NodeBuilder<2>(EQUAL) << basicVar << sum;
+*/
}
diff --git a/src/theory/arith/row_vector.h b/src/theory/arith/row_vector.h
index 983e19a0a..0fdfd7f0c 100644
--- a/src/theory/arith/row_vector.h
+++ b/src/theory/arith/row_vector.h
@@ -14,15 +14,26 @@ namespace CVC4 {
namespace theory {
namespace arith {
-typedef std::pair<ArithVar, Rational> VarCoeffPair;
+class VarCoeffPair {
+private:
+ ArithVar d_variable;
+ Rational d_coeff;
+
+public:
+ VarCoeffPair(ArithVar v, const Rational& q): d_variable(v), d_coeff(q) {}
-inline ArithVar getArithVar(const VarCoeffPair& v) { return v.first; }
-inline Rational& getCoefficient(VarCoeffPair& v) { return v.second; }
-inline const Rational& getCoefficient(const VarCoeffPair& v) { return v.second; }
+ ArithVar getArithVar() const { return d_variable; }
+ Rational& getCoefficient() { return d_coeff; }
+ const Rational& getCoefficient() const { return d_coeff; }
-inline bool cmp(const VarCoeffPair& a, const VarCoeffPair& b){
- return getArithVar(a) < getArithVar(b);
-}
+ bool operator<(const VarCoeffPair& other) const{
+ return getArithVar() < other.getArithVar();
+ }
+
+ static bool variableLess(const VarCoeffPair& a, const VarCoeffPair& b){
+ return a < b;
+ }
+};
/**
* ReducedRowVector is a sparse vector representation that represents the
@@ -109,7 +120,7 @@ public:
Assert(has(x_j));
Assert(hasInEntries(x_j));
const_iterator lb = lower_bound(x_j);
- return getCoefficient(*lb);
+ return (*lb).getCoefficient();
}
@@ -190,7 +201,7 @@ private:
bool matchingCounts() const;
const_iterator lower_bound(ArithVar x_j) const{
- return std::lower_bound(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp);
+ return std::lower_bound(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j, 0));
}
/** Debugging code */
@@ -207,7 +218,7 @@ private:
/** Debugging code. */
bool hasInEntries(ArithVar x_j) const {
- return std::binary_search(d_entries.begin(), d_entries.end(), std::make_pair(x_j,0), cmp);
+ return std::binary_search(d_entries.begin(), d_entries.end(), VarCoeffPair(x_j,0));
}
}; /* class ReducedRowVector */
diff --git a/src/theory/arith/simplex.cpp b/src/theory/arith/simplex.cpp
index 02ce310ff..0809e0788 100644
--- a/src/theory/arith/simplex.cpp
+++ b/src/theory/arith/simplex.cpp
@@ -257,8 +257,8 @@ void SimplexDecisionProcedure::pivotAndUpdate(ArithVar x_i, ArithVar x_j, DeltaR
varIter != row_k.end();
++varIter){
- ArithVar var = varIter->first;
- const Rational& coeff = varIter->second;
+ ArithVar var = (*varIter).getArithVar();
+ const Rational& coeff = (*varIter).getCoefficient();
DeltaRational beta = d_partialModel.getAssignment(var);
Debug("arith::pivotAndUpdate") << var << beta << coeff;
if(d_partialModel.hasLowerBound(var)){
@@ -334,10 +334,10 @@ ArithVar SimplexDecisionProcedure::selectSlack(ArithVar x_i, bool first){
for(ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
nbi != end; ++nbi){
- ArithVar nonbasic = getArithVar(*nbi);
+ ArithVar nonbasic = (*nbi).getArithVar();
if(nonbasic == x_i) continue;
- const Rational& a_ij = nbi->second;
+ const Rational& a_ij = (*nbi).getCoefficient();
int cmp = a_ij.cmp(d_constants.d_ZERO);
if(above){ // beta(x_i) > u_i
if( cmp < 0 && d_partialModel.strictlyBelowUpperBound(nonbasic)){
@@ -566,10 +566,10 @@ Node SimplexDecisionProcedure::generateConflictAbove(ArithVar conflictVar){
ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
for(; nbi != end; ++nbi){
- ArithVar nonbasic = getArithVar(*nbi);
+ ArithVar nonbasic = (*nbi).getArithVar();
if(nonbasic == conflictVar) continue;
- const Rational& a_ij = nbi->second;
+ const Rational& a_ij = (*nbi).getCoefficient();
Assert(a_ij != d_constants.d_ZERO);
@@ -606,10 +606,10 @@ Node SimplexDecisionProcedure::generateConflictBelow(ArithVar conflictVar){
ReducedRowVector::const_iterator nbi = row_i.begin(), end = row_i.end();
for(; nbi != end; ++nbi){
- ArithVar nonbasic = getArithVar(*nbi);
+ ArithVar nonbasic = (*nbi).getArithVar();
if(nonbasic == conflictVar) continue;
- const Rational& a_ij = nbi->second;
+ const Rational& a_ij = (*nbi).getCoefficient();
Assert(a_ij != d_constants.d_ZERO);
@@ -643,9 +643,9 @@ DeltaRational SimplexDecisionProcedure::computeRowValue(ArithVar x, bool useSafe
ReducedRowVector& row = d_tableau.lookup(x);
for(ReducedRowVector::const_iterator i = row.begin(), end = row.end();
i != end;++i){
- ArithVar nonbasic = getArithVar(*i);
+ ArithVar nonbasic = (*i).getArithVar();
if(nonbasic == row.basic()) continue;
- const Rational& coeff = getCoefficient(*i);
+ const Rational& coeff = (*i).getCoefficient();
const DeltaRational& assignment = d_partialModel.getAssignment(nonbasic, useSafe);
sum = sum + (assignment * coeff);
@@ -671,10 +671,10 @@ void SimplexDecisionProcedure::checkTableau(){
for(ReducedRowVector::const_iterator nonbasicIter = row_k.begin();
nonbasicIter != row_k.end();
++nonbasicIter){
- ArithVar nonbasic = nonbasicIter->first;
+ ArithVar nonbasic = (*nonbasicIter).getArithVar();
if(basic == nonbasic) continue;
- const Rational& coeff = nonbasicIter->second;
+ const Rational& coeff = (*nonbasicIter).getCoefficient();
DeltaRational beta = d_partialModel.getAssignment(nonbasic);
Debug("paranoid:check_tableau") << nonbasic << beta << coeff<<endl;
sum = sum + (beta*coeff);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback