summaryrefslogtreecommitdiff
path: root/src/theory/arith/matrix.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/arith/matrix.h')
-rw-r--r--src/theory/arith/matrix.h320
1 files changed, 183 insertions, 137 deletions
diff --git a/src/theory/arith/matrix.h b/src/theory/arith/matrix.h
index 51c2114a0..100f999e0 100644
--- a/src/theory/arith/matrix.h
+++ b/src/theory/arith/matrix.h
@@ -19,14 +19,9 @@
#pragma once
-#include "expr/node.h"
-
#include "util/index.h"
#include "util/dense_map.h"
-
#include "theory/arith/arithvar.h"
-#include "theory/arith/arithvar_node_map.h"
-#include "theory/arith/normal_form.h"
#include <queue>
#include <vector>
@@ -42,6 +37,20 @@ const EntryID ENTRYID_SENTINEL = std::numeric_limits<EntryID>::max();
typedef Index RowIndex;
const RowIndex ROW_INDEX_SENTINEL = std::numeric_limits<RowIndex>::max();
+class CoefficientChangeCallback {
+public:
+ virtual void update(RowIndex basic, ArithVar nb, int oldSgn, int currSgn) = 0;
+ virtual void swap(ArithVar basic, ArithVar nb, int nbSgn) = 0;
+ virtual bool canUseRow(RowIndex ridx) const = 0;
+};
+
+class NoEffectCCCB : public CoefficientChangeCallback {
+public:
+ void update(RowIndex ridx, ArithVar nb, int oldSgn, int currSgn);
+ void swap(ArithVar basic, ArithVar nb, int nbSgn);
+ bool canUseRow(RowIndex ridx) const;
+};
+
template<class T>
class MatrixEntry {
private:
@@ -334,6 +343,8 @@ public:
typedef typename SuperT::const_iterator const_iterator;
RowVector(MatrixEntryVector<T>* mev) : SuperT(mev){}
+ RowVector(EntryID head, uint32_t size, MatrixEntryVector<T>* mev)
+ : SuperT(head, size, mev){}
};/* class RowVector<T> */
template <class T>
@@ -345,6 +356,8 @@ public:
typedef typename SuperT::const_iterator const_iterator;
ColumnVector(MatrixEntryVector<T>* mev) : SuperT(mev){}
+ ColumnVector(EntryID head, uint32_t size, MatrixEntryVector<T>* mev)
+ : SuperT(head, size, mev){}
};/* class ColumnVector<T> */
template <class T>
@@ -406,6 +419,45 @@ public:
d_zero(zero)
{}
+ Matrix(const Matrix& m)
+ : d_rows(),
+ d_columns(),
+ d_mergeBuffer(m.d_mergeBuffer),
+ d_rowInMergeBuffer(m.d_rowInMergeBuffer),
+ d_entriesInUse(m.d_entriesInUse),
+ d_entries(m.d_entries),
+ d_zero(m.d_zero)
+ {
+ d_columns.clear();
+ for(typename ColumnTable::const_iterator c=m.d_columns.begin(), cend = m.d_columns.end(); c!=cend; ++c){
+ const ColumnVectorT& col = *c;
+ d_columns.push_back(ColumnVectorT(col.getHead(),col.getSize(),&d_entries));
+ }
+ d_rows.clear();
+ for(typename RowTable::const_iterator r=m.d_rows.begin(), rend = m.d_rows.end(); r!=rend; ++r){
+ const RowVectorT& row = *r;
+ d_rows.push_back(RowVectorT(row.getHead(),row.getSize(),&d_entries));
+ }
+ }
+
+ Matrix& operator=(const Matrix& m){
+ d_mergeBuffer = (m.d_mergeBuffer);
+ d_rowInMergeBuffer = (m.d_rowInMergeBuffer);
+ d_entriesInUse = (m.d_entriesInUse);
+ d_entries = (m.d_entries);
+ d_zero = (m.d_zero);
+ d_columns.clear();
+ for(typename ColumnTable::const_iterator c=m.d_columns.begin(), cend = m.d_columns.end(); c!=cend; ++c){
+ const ColumnVector<T>& col = *c;
+ d_columns.push_back(ColumnVector<T>(col.getHead(), col.getSize(), &d_entries));
+ }
+ d_rows.clear();
+ for(typename RowTable::const_iterator r=m.d_rows.begin(), rend = m.d_rows.end(); r!=rend; ++r){
+ const RowVector<T>& row = *r;
+ d_rows.push_back(RowVector<T>(row.getHead(), row.getSize(), &d_entries));
+ }
+ return *this;
+ }
protected:
@@ -511,12 +563,12 @@ public:
//RowIndex ridx = d_rows.size();
//d_rows.push_back(RowVectorT(&d_entries));
- std::vector<Rational>::const_iterator coeffIter = coeffs.begin();
+ typename std::vector<T>::const_iterator coeffIter = coeffs.begin();
std::vector<ArithVar>::const_iterator varsIter = variables.begin();
std::vector<ArithVar>::const_iterator varsEnd = variables.end();
for(; varsIter != varsEnd; ++coeffIter, ++varsIter){
- const Rational& coeff = *coeffIter;
+ const T& coeff = *coeffIter;
ArithVar var_i = *varsIter;
Assert(var_i < getNumColumns());
addEntry(ridx, var_i, coeff);
@@ -578,9 +630,10 @@ public:
d_mergeBuffer.get(colVar).second = true;
const Entry& other = d_entries.get(bufferEntry);
- entry.getCoefficient() += mult * other.getCoefficient();
+ T& coeff = entry.getCoefficient();
+ coeff += mult * other.getCoefficient();
- if(entry.getCoefficient() == d_zero){
+ if(coeff.sgn() == 0){
removeEntry(id);
}
}
@@ -607,6 +660,74 @@ public:
if(Debug.isOn("matrix")) { printMatrix(); }
}
+ /** to += mult * buffer. */
+ void rowPlusBufferTimesConstant(RowIndex to, const T& mult, CoefficientChangeCallback& cb){
+ Assert(d_rowInMergeBuffer != ROW_INDEX_SENTINEL);
+ Assert(to != ROW_INDEX_SENTINEL);
+
+ Debug("tableau") << "rowPlusRowTimesConstant("
+ << to << "," << mult << "," << d_rowInMergeBuffer << ")"
+ << std::endl;
+
+ Assert(debugNoZeroCoefficients(to));
+ Assert(debugNoZeroCoefficients(d_rowInMergeBuffer));
+
+ Assert(mult != 0);
+
+
+ RowIterator i = getRow(to).begin();
+ RowIterator i_end = getRow(to).end();
+ while(i != i_end){
+ EntryID id = i.getID();
+ Entry& entry = d_entries.get(id);
+ ArithVar colVar = entry.getColVar();
+
+ ++i;
+
+ if(d_mergeBuffer.isKey(colVar)){
+ EntryID bufferEntry = d_mergeBuffer[colVar].first;
+ Assert(!d_mergeBuffer[colVar].second);
+ d_mergeBuffer.get(colVar).second = true;
+
+ const Entry& other = d_entries.get(bufferEntry);
+ T& coeff = entry.getCoefficient();
+ int coeffOldSgn = coeff.sgn();
+ coeff += mult * other.getCoefficient();
+ int coeffNewSgn = coeff.sgn();
+
+ if(coeffOldSgn != coeffNewSgn){
+ cb.update(to, colVar, coeffOldSgn, coeffNewSgn);
+
+ if(coeffNewSgn == 0){
+ removeEntry(id);
+ }
+ }
+ }
+ }
+
+ i = getRow(d_rowInMergeBuffer).begin();
+ i_end = getRow(d_rowInMergeBuffer).end();
+
+ for(; i != i_end; ++i){
+ const Entry& entry = *i;
+ ArithVar colVar = entry.getColVar();
+
+ if(d_mergeBuffer[colVar].second){
+ d_mergeBuffer.get(colVar).second = false;
+ }else{
+ Assert(!(d_mergeBuffer[colVar]).second);
+ T newCoeff = mult * entry.getCoefficient();
+ addEntry(to, colVar, newCoeff);
+
+ cb.update(to, colVar, 0, newCoeff.sgn());
+ }
+ }
+
+ Assert(mergeBufferIsClear());
+
+ if(Debug.isOn("matrix")) { printMatrix(); }
+ }
+
bool mergeBufferIsClear() const{
RowToPosUsedPairMap::const_iterator i = d_mergeBuffer.begin();
RowToPosUsedPairMap::const_iterator i_end = d_mergeBuffer.end();
@@ -621,7 +742,7 @@ public:
protected:
- EntryID findOnRow(RowIndex rid, ArithVar column){
+ EntryID findOnRow(RowIndex rid, ArithVar column) const {
RowIterator i = d_rows[rid].begin(), i_end = d_rows[rid].end();
for(; i != i_end; ++i){
EntryID id = i.getID();
@@ -635,7 +756,7 @@ protected:
return ENTRYID_SENTINEL;
}
- EntryID findOnCol(RowIndex rid, ArithVar column){
+ EntryID findOnCol(RowIndex rid, ArithVar column) const{
ColIterator i = d_columns[column].begin(), i_end = d_columns[column].end();
for(; i != i_end; ++i){
EntryID id = i.getID();
@@ -649,63 +770,59 @@ protected:
return ENTRYID_SENTINEL;
}
+ EntryID findEntryID(RowIndex rid, ArithVar col) const{
+ bool colIsShorter = getColLength(col) < getRowLength(rid);
+ EntryID id = colIsShorter ? findOnCol(rid, col) : findOnRow(rid,col);
+ return id;
+ }
MatrixEntry<T> d_failedFind;
public:
/** If the find fails, isUnused is true on the entry. */
- const MatrixEntry<T>& findEntry(RowIndex rid, ArithVar col){
- bool colIsShorter = getColLength(col) < getRowLength(rid);
- EntryID id = colIsShorter ? findOnCol(rid, col) : findOnRow(rid,col);
+ const MatrixEntry<T>& findEntry(RowIndex rid, ArithVar col) const{
+ EntryID id = findEntryID(rid, col);
if(id == ENTRYID_SENTINEL){
return d_failedFind;
}else{
- return d_entries.get(id);
+ return d_entries[id];
}
}
/**
* Prints the contents of the Matrix to Debug("matrix")
*/
- void printMatrix() const {
- Debug("matrix") << "Matrix::printMatrix" << std::endl;
+ void printMatrix(std::ostream& out) const {
+ out << "Matrix::printMatrix" << std::endl;
for(RowIndex i = 0, N = d_rows.size(); i < N; ++i){
- printRow(i);
+ printRow(i, out);
}
}
+ void printMatrix() const {
+ printMatrix(Debug("matrix"));
+ }
- void printRow(RowIndex rid) const {
- Debug("matrix") << "{" << rid << ":";
+ void printRow(RowIndex rid, std::ostream& out) const {
+ out << "{" << rid << ":";
const RowVector<T>& row = getRow(rid);
RowIterator i = row.begin();
RowIterator i_end = row.end();
for(; i != i_end; ++i){
- printEntry(*i);
- Debug("matrix") << ",";
+ printEntry(*i, out);
+ out << ",";
}
- Debug("matrix") << "}" << std::endl;
+ out << "}" << std::endl;
+ }
+ void printRow(RowIndex rid) const {
+ printRow(rid, Debug("matrix"));
}
+ void printEntry(const MatrixEntry<T>& entry, std::ostream& out) const {
+ out << entry.getColVar() << "*" << entry.getCoefficient();
+ }
void printEntry(const MatrixEntry<T>& entry) const {
- Debug("matrix") << entry.getColVar() << "*" << entry.getCoefficient();
+ printEntry(entry, Debug("matrix"));
}
-
-
-protected:
-
- // static bool bufferPairIsNotEmpty(const PosUsedPair& p){
- // return !(p.first == ENTRYID_SENTINEL && p.second == false);
- // }
-
- // static bool bufferPairIsEmpty(const PosUsedPair& p){
- // return (p.first == ENTRYID_SENTINEL && p.second == false);
- // }
- // bool mergeBufferIsEmpty() const {
- // return d_mergeBuffer.end() == std::find_if(d_mergeBuffer.begin(),
- // d_mergeBuffer.end(),
- // bufferPairIsNotEmpty);
- // }
-
public:
uint32_t size() const {
return d_entriesInUse;
@@ -717,6 +834,31 @@ public:
return d_entries.capacity();
}
+ void manipulateRowEntry(RowIndex row, ArithVar col, const T& c, CoefficientChangeCallback& cb){
+ int coeffOldSgn;
+ int coeffNewSgn;
+
+ EntryID id = findEntryID(row, col);
+ if(id == ENTRYID_SENTINEL){
+ coeffOldSgn = 0;
+ addEntry(row, col, c);
+ coeffNewSgn = c.sgn();
+ }else{
+ Entry& e = d_entries.get(id);
+ T& t = e.getCoefficient();
+ coeffOldSgn = t.sgn();
+ t += c;
+ coeffNewSgn = t.sgn();
+ }
+
+ if(coeffOldSgn != coeffNewSgn){
+ cb.update(row, col, coeffOldSgn, coeffNewSgn);
+ }
+ if(coeffNewSgn == 0){
+ removeEntry(id);
+ }
+ }
+
void removeRow(RowIndex rid){
RowIterator i = getRow(rid).begin();
RowIterator i_end = getRow(rid).end();
@@ -822,102 +964,6 @@ protected:
};/* class Matrix<T> */
-
-/**
- * A Tableau is a Rational matrix that keeps its rows in solved form.
- * Each row has a basic variable with coefficient -1 that is solved.
- * Tableau is optimized for pivoting.
- * The tableau should only be updated via pivot calls.
- */
-class Tableau : public Matrix<Rational> {
-public:
-private:
- typedef DenseMap<RowIndex> BasicToRowMap;
- // Set of all of the basic variables in the tableau.
- // ArithVarMap<RowIndex> : ArithVar |-> RowIndex
- BasicToRowMap d_basic2RowIndex;
-
- // RowIndex |-> Basic Variable
- typedef DenseMap<ArithVar> RowIndexToBasicMap;
- RowIndexToBasicMap d_rowIndex2basic;
-
-public:
-
- Tableau() : Matrix<Rational>(Rational(0)) {}
-
- typedef Matrix<Rational>::ColIterator ColIterator;
- typedef Matrix<Rational>::RowIterator RowIterator;
- typedef BasicToRowMap::const_iterator BasicIterator;
-
- typedef MatrixEntry<Rational> Entry;
-
- bool isBasic(ArithVar v) const{
- return d_basic2RowIndex.isKey(v);
- }
-
- void debugPrintIsBasic(ArithVar v) const {
- if(isBasic(v)){
- Warning() << v << " is basic." << std::endl;
- }else{
- Warning() << v << " is non-basic." << std::endl;
- }
- }
-
- BasicIterator beginBasic() const {
- return d_basic2RowIndex.begin();
- }
- BasicIterator endBasic() const {
- return d_basic2RowIndex.end();
- }
-
- RowIndex basicToRowIndex(ArithVar x) const {
- return d_basic2RowIndex[x];
- }
-
- ArithVar rowIndexToBasic(RowIndex rid) const {
- Assert(rid < d_rowIndex2basic.size());
- return d_rowIndex2basic[rid];
- }
-
- ColIterator colIterator(ArithVar x) const {
- return getColumn(x).begin();
- }
-
- RowIterator basicRowIterator(ArithVar basic) const {
- return getRow(basicToRowIndex(basic)).begin();
- }
-
- /**
- * Adds a row to the tableau.
- * The new row is equivalent to:
- * basicVar = \f$\sum_i\f$ coeffs[i] * variables[i]
- * preconditions:
- * basicVar is already declared to be basic
- * basicVar does not have a row associated with it in the tableau.
- *
- * Note: each variables[i] does not have to be non-basic.
- * Pivoting will be mimicked if it is basic.
- */
- void addRow(ArithVar basicVar,
- const std::vector<Rational>& coeffs,
- const std::vector<ArithVar>& variables);
-
- /**
- * preconditions:
- * x_r is basic,
- * x_s is non-basic, and
- * a_rs != 0.
- */
- void pivot(ArithVar basicOld, ArithVar basicNew);
-
- void removeBasicRow(ArithVar basic);
-
-private:
- /* Changes the basic variable on the row for basicOld to basicNew. */
- void rowPivot(ArithVar basicOld, ArithVar basicNew);
-
-};/* class Tableau */
-
}/* CVC4::theory::arith namespace */
}/* CVC4::theory namespace */
}/* CVC4 namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback