diff options
Diffstat (limited to 'src/theory/arrays/theory_arrays.cpp')
-rw-r--r-- | src/theory/arrays/theory_arrays.cpp | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 37c49b341..dab78c17a 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -21,6 +21,7 @@ #include "theory/valuation.h" #include "expr/kind.h" #include <map> +#include "theory/rewriter.h" using namespace std; using namespace CVC4; @@ -184,6 +185,208 @@ Node TheoryArrays::getValue(TNode n) { } } +Theory::SolveStatus TheoryArrays::solve(TNode in, SubstitutionMap& outSubstitutions) { + switch(in.getKind()) { + case kind::EQUAL: + { + d_staticFactManager.addEq(in); + if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) { + outSubstitutions.addSubstitution(in[0], in[1]); + return SOLVE_STATUS_SOLVED; + } + if (in[1].getMetaKind() == kind::metakind::VARIABLE && !in[0].hasSubterm(in[1])) { + outSubstitutions.addSubstitution(in[1], in[0]); + return SOLVE_STATUS_SOLVED; + } + break; + } + case kind::NOT: + { + Assert(in[0].getKind() == kind::EQUAL || + in[0].getKind() == kind::IFF ); + Node a = in[0][0]; + Node b = in[0][1]; + d_staticFactManager.addDiseq(in[0]); + break; + } + default: + break; + } + return SOLVE_STATUS_UNSOLVED; +} + +Node TheoryArrays::preprocessTerm(TNode term) { + switch (term.getKind()) { + case kind::SELECT: { + // select(store(a,i,v),j) = select(a,j) + // IF i != j + if (term[0].getKind() == kind::STORE && + d_staticFactManager.areDiseq(term[0][1], term[1])) { + return NodeBuilder<2>(kind::SELECT) << term[0][0] << term[1]; + } + break; + } + case kind::STORE: { + // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) + // IF i != j and j comes before i in the ordering + if (term[0].getKind() == kind::STORE && + (term[1] < term[0][1]) && + d_staticFactManager.areDiseq(term[1], term[0][1])) { + Node inner = NodeBuilder<3>(kind::STORE) << term[0][0] << term[1] << term[2]; + Node outer = NodeBuilder<3>(kind::STORE) << inner << term[0][1] << term[0][2]; + return outer; + } + break; + } + case kind::EQUAL: { + if (term[0].getKind() == kind::STORE || + term[1].getKind() == kind::STORE) { + TNode left = term[0]; + TNode right = term[1]; + int leftWrites = 0, rightWrites = 0; + + // Count nested writes + TNode e1 = left; + while (e1.getKind() == kind::STORE) { + ++leftWrites; + e1 = e1[0]; + } + + TNode e2 = right; + while (e2.getKind() == kind::STORE) { + ++rightWrites; + e2 = e2[0]; + } + + if (rightWrites > leftWrites) { + TNode tmp = left; + left = right; + right = tmp; + int tmpWrites = leftWrites; + leftWrites = rightWrites; + rightWrites = tmpWrites; + } + + NodeManager* nm = NodeManager::currentNM(); + if (rightWrites == 0) { + if (e1 == e2) { + // write(store, index_0, v_0, index_1, v_1, ..., index_n, v_n) = store IFF + // + // read(store, index_n) = v_n & + // index_{n-1} != index_n -> read(store, index_{n-1}) = v_{n-1} & + // (index_{n-2} != index_{n-1} & index_{n-2} != index_n) -> read(store, index_{n-2}) = v_{n-2} & + // ... + // (index_1 != index_2 & ... & index_1 != index_n) -> read(store, index_1) = v_1 + // (index_0 != index_1 & index_0 != index_2 & ... & index_0 != index_n) -> read(store, index_0) = v_0 + TNode write_i, write_j, index_i, index_j; + Node conc; + NodeBuilder<> result(kind::AND); + int i, j; + write_i = left; + for (i = leftWrites-1; i >= 0; --i) { + index_i = write_i[1]; + + // build: [index_i /= index_n && index_i /= index_(n-1) && + // ... && index_i /= index_(i+1)] -> read(store, index_i) = v_i + write_j = left; + { + NodeBuilder<> hyp(kind::AND); + for (j = leftWrites - 1; j > i; --j) { + index_j = write_j[1]; + if (d_staticFactManager.areDiseq(index_i, index_j)) { + continue; + } + Node hyp2(index_i.getType() == nm->booleanType()? + index_i.iffNode(index_j) : index_i.eqNode(index_j)); + hyp << hyp2.notNode(); + write_j = write_j[0]; + } + + Node r1 = nm->mkNode(kind::SELECT, e1, index_i); + conc = (r1.getType() == nm->booleanType())? + r1.iffNode(write_i[2]) : r1.eqNode(write_i[2]); + if (hyp.getNumChildren() != 0) { + if (hyp.getNumChildren() == 1) { + conc = hyp.getChild(0).impNode(conc); + } + else { + r1 = hyp; + conc = r1.impNode(conc); + } + } + + // And into result + result << conc; + + // Prepare for next iteration + write_i = write_i[0]; + } + } + Assert(result.getNumChildren() > 0); + if (result.getNumChildren() == 1) { + return result.getChild(0); + } + return result; + } + break; + } + else { + // store(...) = store(a,i,v) ==> + // store(store(...),i,select(a,i)) = a && select(store(...),i)=v + Node l = left; + Node tmp; + NodeBuilder<> nb(kind::AND); + while (right.getKind() == STORE) { + tmp = nm->mkNode(kind::SELECT, l, right[1]); + nb << tmp.eqNode(right[2]); + tmp = nm->mkNode(kind::SELECT, right[0], right[1]); + l = nm->mkNode(kind::STORE, l, right[1], tmp); + right = right[0]; + } + nb << l.eqNode(right); + return nb; + } + } + break; + } + default: + break; + } + return term; +} + +Node TheoryArrays::recursivePreprocessTerm(TNode term) { + unsigned nc = term.getNumChildren(); + if (nc == 0 || + (theoryOf(term) != theory::THEORY_ARRAY && + term.getType() != NodeManager::currentNM()->booleanType())) { + return term; + } + NodeMap::iterator find = d_ppCache.find(term); + if (find != d_ppCache.end()) { + return (*find).second; + } + NodeBuilder<> newNode(term.getKind()); + unsigned i; + for (i = 0; i < nc; ++i) { + newNode << recursivePreprocessTerm(term[i]); + } + Node newTerm = Rewriter::rewrite(newNode); + Node newTerm2 = preprocessTerm(newTerm); + if (newTerm != newTerm2) { + newTerm = recursivePreprocessTerm(Rewriter::rewrite(newTerm2)); + } + d_ppCache[term] = newTerm; + return newTerm; +} + +Node TheoryArrays::preprocess(TNode atom) { + if (d_donePreregister) return atom; + Assert(atom.getKind() == kind::EQUAL); + return recursivePreprocessTerm(atom); +} + + void TheoryArrays::merge(TNode a, TNode b) { Assert(d_conflict.isNull()); @@ -508,7 +711,48 @@ bool TheoryArrays::isRedundantInContext(TNode a, TNode b, TNode i, TNode j) { checkRowForIndex(j,b); // why am i doing this? checkRowForIndex(i,a); return true; + } + Node literal1 = Rewriter::rewrite(i.eqNode(j)); + bool hasValue1, satValue1; + Node ff = nm->mkConst<bool>(false); + Node tt = nm->mkConst<bool>(true); + if (literal1 == ff) { + hasValue1 = true; + satValue1 = false; + } + else if (literal1 == tt) { + hasValue1 = true; + satValue1 = true; + } + else hasValue1 = (d_valuation.isSatLiteral(literal1) && d_valuation.hasSatValue(literal1, satValue1)); + if (hasValue1) { + if (satValue1) return true; + Node literal2 = Rewriter::rewrite(aj.eqNode(bj)); + bool hasValue2, satValue2; + if (literal2 == ff) { + hasValue2 = true; + satValue2 = false; } + else if (literal2 == tt) { + hasValue2 = true; + satValue2 = true; + } + else hasValue2 = (d_valuation.isSatLiteral(literal2) && d_valuation.hasSatValue(literal2, satValue2)); + if (hasValue2) { + if (satValue2) return true; + // conflict + Assert(!satValue1 && !satValue2); + Assert(literal1.getKind() == kind::EQUAL && literal2.getKind() == kind::EQUAL); + NodeBuilder<2> nb(kind::AND); + literal1 = areDisequal(literal1[0],literal1[1]); + literal2 = areDisequal(literal2[0],literal2[1]); + Assert(!literal1.isNull() && !literal2.isNull()); + nb << literal1.notNode() << literal2.notNode(); + literal1 = nb; + d_out->conflict(literal1, false); + return true; + } + } if(alreadyAddedRow(a,b,i,j)) { // Debug("arrays-lem")<<"isRedundantInContext already added "<<a<<" "<<b<<" "<<i<<" "<<j<<"\n"; return true; |