diff options
author | Tim King <taking@cs.nyu.edu> | 2013-04-26 17:10:21 -0400 |
---|---|---|
committer | Morgan Deters <mdeters@cs.nyu.edu> | 2013-04-26 17:10:21 -0400 |
commit | 9098391fe334d829ec4101f190b8f1fa21c30752 (patch) | |
tree | b134fc1fe1c767a50013e1449330ca6a7ee18a3d /src/theory/arith/matrix.h | |
parent | a9174ce4dc3939bbe14c9aa1fd11c79c7877eb16 (diff) |
FCSimplex branch merge
Diffstat (limited to 'src/theory/arith/matrix.h')
-rw-r--r-- | src/theory/arith/matrix.h | 320 |
1 files changed, 183 insertions, 137 deletions
diff --git a/src/theory/arith/matrix.h b/src/theory/arith/matrix.h index 51c2114a0..100f999e0 100644 --- a/src/theory/arith/matrix.h +++ b/src/theory/arith/matrix.h @@ -19,14 +19,9 @@ #pragma once -#include "expr/node.h" - #include "util/index.h" #include "util/dense_map.h" - #include "theory/arith/arithvar.h" -#include "theory/arith/arithvar_node_map.h" -#include "theory/arith/normal_form.h" #include <queue> #include <vector> @@ -42,6 +37,20 @@ const EntryID ENTRYID_SENTINEL = std::numeric_limits<EntryID>::max(); typedef Index RowIndex; const RowIndex ROW_INDEX_SENTINEL = std::numeric_limits<RowIndex>::max(); +class CoefficientChangeCallback { +public: + virtual void update(RowIndex basic, ArithVar nb, int oldSgn, int currSgn) = 0; + virtual void swap(ArithVar basic, ArithVar nb, int nbSgn) = 0; + virtual bool canUseRow(RowIndex ridx) const = 0; +}; + +class NoEffectCCCB : public CoefficientChangeCallback { +public: + void update(RowIndex ridx, ArithVar nb, int oldSgn, int currSgn); + void swap(ArithVar basic, ArithVar nb, int nbSgn); + bool canUseRow(RowIndex ridx) const; +}; + template<class T> class MatrixEntry { private: @@ -334,6 +343,8 @@ public: typedef typename SuperT::const_iterator const_iterator; RowVector(MatrixEntryVector<T>* mev) : SuperT(mev){} + RowVector(EntryID head, uint32_t size, MatrixEntryVector<T>* mev) + : SuperT(head, size, mev){} };/* class RowVector<T> */ template <class T> @@ -345,6 +356,8 @@ public: typedef typename SuperT::const_iterator const_iterator; ColumnVector(MatrixEntryVector<T>* mev) : SuperT(mev){} + ColumnVector(EntryID head, uint32_t size, MatrixEntryVector<T>* mev) + : SuperT(head, size, mev){} };/* class ColumnVector<T> */ template <class T> @@ -406,6 +419,45 @@ public: d_zero(zero) {} + Matrix(const Matrix& m) + : d_rows(), + d_columns(), + d_mergeBuffer(m.d_mergeBuffer), + d_rowInMergeBuffer(m.d_rowInMergeBuffer), + d_entriesInUse(m.d_entriesInUse), + d_entries(m.d_entries), + d_zero(m.d_zero) + { + d_columns.clear(); + for(typename ColumnTable::const_iterator c=m.d_columns.begin(), cend = m.d_columns.end(); c!=cend; ++c){ + const ColumnVectorT& col = *c; + d_columns.push_back(ColumnVectorT(col.getHead(),col.getSize(),&d_entries)); + } + d_rows.clear(); + for(typename RowTable::const_iterator r=m.d_rows.begin(), rend = m.d_rows.end(); r!=rend; ++r){ + const RowVectorT& row = *r; + d_rows.push_back(RowVectorT(row.getHead(),row.getSize(),&d_entries)); + } + } + + Matrix& operator=(const Matrix& m){ + d_mergeBuffer = (m.d_mergeBuffer); + d_rowInMergeBuffer = (m.d_rowInMergeBuffer); + d_entriesInUse = (m.d_entriesInUse); + d_entries = (m.d_entries); + d_zero = (m.d_zero); + d_columns.clear(); + for(typename ColumnTable::const_iterator c=m.d_columns.begin(), cend = m.d_columns.end(); c!=cend; ++c){ + const ColumnVector<T>& col = *c; + d_columns.push_back(ColumnVector<T>(col.getHead(), col.getSize(), &d_entries)); + } + d_rows.clear(); + for(typename RowTable::const_iterator r=m.d_rows.begin(), rend = m.d_rows.end(); r!=rend; ++r){ + const RowVector<T>& row = *r; + d_rows.push_back(RowVector<T>(row.getHead(), row.getSize(), &d_entries)); + } + return *this; + } protected: @@ -511,12 +563,12 @@ public: //RowIndex ridx = d_rows.size(); //d_rows.push_back(RowVectorT(&d_entries)); - std::vector<Rational>::const_iterator coeffIter = coeffs.begin(); + typename std::vector<T>::const_iterator coeffIter = coeffs.begin(); std::vector<ArithVar>::const_iterator varsIter = variables.begin(); std::vector<ArithVar>::const_iterator varsEnd = variables.end(); for(; varsIter != varsEnd; ++coeffIter, ++varsIter){ - const Rational& coeff = *coeffIter; + const T& coeff = *coeffIter; ArithVar var_i = *varsIter; Assert(var_i < getNumColumns()); addEntry(ridx, var_i, coeff); @@ -578,9 +630,10 @@ public: d_mergeBuffer.get(colVar).second = true; const Entry& other = d_entries.get(bufferEntry); - entry.getCoefficient() += mult * other.getCoefficient(); + T& coeff = entry.getCoefficient(); + coeff += mult * other.getCoefficient(); - if(entry.getCoefficient() == d_zero){ + if(coeff.sgn() == 0){ removeEntry(id); } } @@ -607,6 +660,74 @@ public: if(Debug.isOn("matrix")) { printMatrix(); } } + /** to += mult * buffer. */ + void rowPlusBufferTimesConstant(RowIndex to, const T& mult, CoefficientChangeCallback& cb){ + Assert(d_rowInMergeBuffer != ROW_INDEX_SENTINEL); + Assert(to != ROW_INDEX_SENTINEL); + + Debug("tableau") << "rowPlusRowTimesConstant(" + << to << "," << mult << "," << d_rowInMergeBuffer << ")" + << std::endl; + + Assert(debugNoZeroCoefficients(to)); + Assert(debugNoZeroCoefficients(d_rowInMergeBuffer)); + + Assert(mult != 0); + + + RowIterator i = getRow(to).begin(); + RowIterator i_end = getRow(to).end(); + while(i != i_end){ + EntryID id = i.getID(); + Entry& entry = d_entries.get(id); + ArithVar colVar = entry.getColVar(); + + ++i; + + if(d_mergeBuffer.isKey(colVar)){ + EntryID bufferEntry = d_mergeBuffer[colVar].first; + Assert(!d_mergeBuffer[colVar].second); + d_mergeBuffer.get(colVar).second = true; + + const Entry& other = d_entries.get(bufferEntry); + T& coeff = entry.getCoefficient(); + int coeffOldSgn = coeff.sgn(); + coeff += mult * other.getCoefficient(); + int coeffNewSgn = coeff.sgn(); + + if(coeffOldSgn != coeffNewSgn){ + cb.update(to, colVar, coeffOldSgn, coeffNewSgn); + + if(coeffNewSgn == 0){ + removeEntry(id); + } + } + } + } + + i = getRow(d_rowInMergeBuffer).begin(); + i_end = getRow(d_rowInMergeBuffer).end(); + + for(; i != i_end; ++i){ + const Entry& entry = *i; + ArithVar colVar = entry.getColVar(); + + if(d_mergeBuffer[colVar].second){ + d_mergeBuffer.get(colVar).second = false; + }else{ + Assert(!(d_mergeBuffer[colVar]).second); + T newCoeff = mult * entry.getCoefficient(); + addEntry(to, colVar, newCoeff); + + cb.update(to, colVar, 0, newCoeff.sgn()); + } + } + + Assert(mergeBufferIsClear()); + + if(Debug.isOn("matrix")) { printMatrix(); } + } + bool mergeBufferIsClear() const{ RowToPosUsedPairMap::const_iterator i = d_mergeBuffer.begin(); RowToPosUsedPairMap::const_iterator i_end = d_mergeBuffer.end(); @@ -621,7 +742,7 @@ public: protected: - EntryID findOnRow(RowIndex rid, ArithVar column){ + EntryID findOnRow(RowIndex rid, ArithVar column) const { RowIterator i = d_rows[rid].begin(), i_end = d_rows[rid].end(); for(; i != i_end; ++i){ EntryID id = i.getID(); @@ -635,7 +756,7 @@ protected: return ENTRYID_SENTINEL; } - EntryID findOnCol(RowIndex rid, ArithVar column){ + EntryID findOnCol(RowIndex rid, ArithVar column) const{ ColIterator i = d_columns[column].begin(), i_end = d_columns[column].end(); for(; i != i_end; ++i){ EntryID id = i.getID(); @@ -649,63 +770,59 @@ protected: return ENTRYID_SENTINEL; } + EntryID findEntryID(RowIndex rid, ArithVar col) const{ + bool colIsShorter = getColLength(col) < getRowLength(rid); + EntryID id = colIsShorter ? findOnCol(rid, col) : findOnRow(rid,col); + return id; + } MatrixEntry<T> d_failedFind; public: /** If the find fails, isUnused is true on the entry. */ - const MatrixEntry<T>& findEntry(RowIndex rid, ArithVar col){ - bool colIsShorter = getColLength(col) < getRowLength(rid); - EntryID id = colIsShorter ? findOnCol(rid, col) : findOnRow(rid,col); + const MatrixEntry<T>& findEntry(RowIndex rid, ArithVar col) const{ + EntryID id = findEntryID(rid, col); if(id == ENTRYID_SENTINEL){ return d_failedFind; }else{ - return d_entries.get(id); + return d_entries[id]; } } /** * Prints the contents of the Matrix to Debug("matrix") */ - void printMatrix() const { - Debug("matrix") << "Matrix::printMatrix" << std::endl; + void printMatrix(std::ostream& out) const { + out << "Matrix::printMatrix" << std::endl; for(RowIndex i = 0, N = d_rows.size(); i < N; ++i){ - printRow(i); + printRow(i, out); } } + void printMatrix() const { + printMatrix(Debug("matrix")); + } - void printRow(RowIndex rid) const { - Debug("matrix") << "{" << rid << ":"; + void printRow(RowIndex rid, std::ostream& out) const { + out << "{" << rid << ":"; const RowVector<T>& row = getRow(rid); RowIterator i = row.begin(); RowIterator i_end = row.end(); for(; i != i_end; ++i){ - printEntry(*i); - Debug("matrix") << ","; + printEntry(*i, out); + out << ","; } - Debug("matrix") << "}" << std::endl; + out << "}" << std::endl; + } + void printRow(RowIndex rid) const { + printRow(rid, Debug("matrix")); } + void printEntry(const MatrixEntry<T>& entry, std::ostream& out) const { + out << entry.getColVar() << "*" << entry.getCoefficient(); + } void printEntry(const MatrixEntry<T>& entry) const { - Debug("matrix") << entry.getColVar() << "*" << entry.getCoefficient(); + printEntry(entry, Debug("matrix")); } - - -protected: - - // static bool bufferPairIsNotEmpty(const PosUsedPair& p){ - // return !(p.first == ENTRYID_SENTINEL && p.second == false); - // } - - // static bool bufferPairIsEmpty(const PosUsedPair& p){ - // return (p.first == ENTRYID_SENTINEL && p.second == false); - // } - // bool mergeBufferIsEmpty() const { - // return d_mergeBuffer.end() == std::find_if(d_mergeBuffer.begin(), - // d_mergeBuffer.end(), - // bufferPairIsNotEmpty); - // } - public: uint32_t size() const { return d_entriesInUse; @@ -717,6 +834,31 @@ public: return d_entries.capacity(); } + void manipulateRowEntry(RowIndex row, ArithVar col, const T& c, CoefficientChangeCallback& cb){ + int coeffOldSgn; + int coeffNewSgn; + + EntryID id = findEntryID(row, col); + if(id == ENTRYID_SENTINEL){ + coeffOldSgn = 0; + addEntry(row, col, c); + coeffNewSgn = c.sgn(); + }else{ + Entry& e = d_entries.get(id); + T& t = e.getCoefficient(); + coeffOldSgn = t.sgn(); + t += c; + coeffNewSgn = t.sgn(); + } + + if(coeffOldSgn != coeffNewSgn){ + cb.update(row, col, coeffOldSgn, coeffNewSgn); + } + if(coeffNewSgn == 0){ + removeEntry(id); + } + } + void removeRow(RowIndex rid){ RowIterator i = getRow(rid).begin(); RowIterator i_end = getRow(rid).end(); @@ -822,102 +964,6 @@ protected: };/* class Matrix<T> */ - -/** - * A Tableau is a Rational matrix that keeps its rows in solved form. - * Each row has a basic variable with coefficient -1 that is solved. - * Tableau is optimized for pivoting. - * The tableau should only be updated via pivot calls. - */ -class Tableau : public Matrix<Rational> { -public: -private: - typedef DenseMap<RowIndex> BasicToRowMap; - // Set of all of the basic variables in the tableau. - // ArithVarMap<RowIndex> : ArithVar |-> RowIndex - BasicToRowMap d_basic2RowIndex; - - // RowIndex |-> Basic Variable - typedef DenseMap<ArithVar> RowIndexToBasicMap; - RowIndexToBasicMap d_rowIndex2basic; - -public: - - Tableau() : Matrix<Rational>(Rational(0)) {} - - typedef Matrix<Rational>::ColIterator ColIterator; - typedef Matrix<Rational>::RowIterator RowIterator; - typedef BasicToRowMap::const_iterator BasicIterator; - - typedef MatrixEntry<Rational> Entry; - - bool isBasic(ArithVar v) const{ - return d_basic2RowIndex.isKey(v); - } - - void debugPrintIsBasic(ArithVar v) const { - if(isBasic(v)){ - Warning() << v << " is basic." << std::endl; - }else{ - Warning() << v << " is non-basic." << std::endl; - } - } - - BasicIterator beginBasic() const { - return d_basic2RowIndex.begin(); - } - BasicIterator endBasic() const { - return d_basic2RowIndex.end(); - } - - RowIndex basicToRowIndex(ArithVar x) const { - return d_basic2RowIndex[x]; - } - - ArithVar rowIndexToBasic(RowIndex rid) const { - Assert(rid < d_rowIndex2basic.size()); - return d_rowIndex2basic[rid]; - } - - ColIterator colIterator(ArithVar x) const { - return getColumn(x).begin(); - } - - RowIterator basicRowIterator(ArithVar basic) const { - return getRow(basicToRowIndex(basic)).begin(); - } - - /** - * Adds a row to the tableau. - * The new row is equivalent to: - * basicVar = \f$\sum_i\f$ coeffs[i] * variables[i] - * preconditions: - * basicVar is already declared to be basic - * basicVar does not have a row associated with it in the tableau. - * - * Note: each variables[i] does not have to be non-basic. - * Pivoting will be mimicked if it is basic. - */ - void addRow(ArithVar basicVar, - const std::vector<Rational>& coeffs, - const std::vector<ArithVar>& variables); - - /** - * preconditions: - * x_r is basic, - * x_s is non-basic, and - * a_rs != 0. - */ - void pivot(ArithVar basicOld, ArithVar basicNew); - - void removeBasicRow(ArithVar basic); - -private: - /* Changes the basic variable on the row for basicOld to basicNew. */ - void rowPivot(ArithVar basicOld, ArithVar basicNew); - -};/* class Tableau */ - }/* CVC4::theory::arith namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ |