diff options
Diffstat (limited to 'src/theory/arith/row_vector.cpp')
-rw-r--r-- | src/theory/arith/row_vector.cpp | 56 |
1 files changed, 50 insertions, 6 deletions
diff --git a/src/theory/arith/row_vector.cpp b/src/theory/arith/row_vector.cpp index 6835fc435..f3b979bfd 100644 --- a/src/theory/arith/row_vector.cpp +++ b/src/theory/arith/row_vector.cpp @@ -5,6 +5,31 @@ using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::arith ; +bool RowVector::isSorted(const VarCoeffArray& arr, bool strictlySorted) { + if(arr.size() >= 2){ + NonZeroIterator curr = arr.begin(); + NonZeroIterator end = arr.end(); + ArithVar prev = getArithVar(*curr); + ++curr; + for(;curr != end; ++curr){ + ArithVar v = getArithVar(*curr); + if(strictlySorted && prev > v) return false; + if(!strictlySorted && prev >= v) return false; + prev = v; + } + } + return true; +} + +bool RowVector::noZeroCoefficients(const VarCoeffArray& arr){ + for(NonZeroIterator curr = arr.begin(), end = arr.end(); + curr != end; ++curr){ + const Rational& coeff = getCoefficient(*curr); + if(coeff == 0) return false; + } + return true; +} + void RowVector::zip(const vector< ArithVar >& variables, const vector< Rational >& coefficients, VarCoeffArray& output){ @@ -24,16 +49,26 @@ void RowVector::zip(const vector< ArithVar >& variables, } RowVector::RowVector(const vector< ArithVar >& variables, - const vector< Rational >& coefficients){ + const vector< Rational >& coefficients, + std::vector<uint32_t>& counts): + d_rowCount(counts) +{ zip(variables, coefficients, d_entries); std::sort(d_entries.begin(), d_entries.end(), cmp); + for(NonZeroIterator i=beginNonZero(), end=endNonZero(); i != end; ++i){ + ++d_rowCount[getArithVar(*i)]; + } + Assert(isSorted(d_entries, true)); Assert(noZeroCoefficients(d_entries)); } -void RowVector::merge(VarCoeffArray& arr, const VarCoeffArray& other, const Rational& c){ +void RowVector::merge(VarCoeffArray& arr, + const VarCoeffArray& other, + const Rational& c, + std::vector<uint32_t>& counts){ VarCoeffArray copy = arr; arr.clear(); @@ -48,12 +83,18 @@ void RowVector::merge(VarCoeffArray& arr, const VarCoeffArray& other, const Rati arr.push_back(*curr1); ++curr1; }else if(getArithVar(*curr1) > getArithVar(*curr2)){ + ++counts[getArithVar(*curr2)]; + arr.push_back( make_pair(getArithVar(*curr2), c * getCoefficient(*curr2))); ++curr2; }else{ Rational res = getCoefficient(*curr1) + c * getCoefficient(*curr2); if(res != 0){ + ++counts[getArithVar(*curr2)]; + arr.push_back(make_pair(getArithVar(*curr1), res)); + }else{ + --counts[getArithVar(*curr2)]; } ++curr1; ++curr2; @@ -64,6 +105,8 @@ void RowVector::merge(VarCoeffArray& arr, const VarCoeffArray& other, const Rati ++curr1; } while(curr2 != end2){ + ++counts[getArithVar(*curr2)]; + arr.push_back(make_pair(getArithVar(*curr2), c * getCoefficient(*curr2))); ++curr2; } @@ -80,7 +123,7 @@ void RowVector::multiply(const Rational& c){ void RowVector::addRowTimesConstant(const Rational& c, const RowVector& other){ Assert(c != 0); - merge(d_entries, other.d_entries, c); + merge(d_entries, other.d_entries, c, d_rowCount); } void RowVector::printRow(){ @@ -93,14 +136,15 @@ void RowVector::printRow(){ ReducedRowVector::ReducedRowVector(ArithVar basic, const vector<ArithVar>& variables, - const vector<Rational>& coefficients): - RowVector(variables, coefficients), d_basic(basic){ + const vector<Rational>& coefficients, + std::vector<uint32_t>& count): + RowVector(variables, coefficients, count), d_basic(basic){ VarCoeffArray justBasic; justBasic.push_back(make_pair(basic, Rational(-1))); - merge(d_entries,justBasic, Rational(1)); + merge(d_entries,justBasic, Rational(1), d_rowCount); Assert(wellFormed()); } |