diff options
Diffstat (limited to 'src/theory/arith')
-rw-r--r-- | src/theory/arith/arith_rewriter.cpp | 69 | ||||
-rw-r--r-- | src/theory/arith/arith_rewriter.h | 74 | ||||
-rw-r--r-- | src/theory/arith/kinds | 12 | ||||
-rw-r--r-- | src/theory/arith/normal_form.cpp | 9 | ||||
-rw-r--r-- | src/theory/arith/normal_form.h | 6 | ||||
-rw-r--r-- | src/theory/arith/theory_arith.cpp | 32 | ||||
-rw-r--r-- | src/theory/arith/theory_arith.h | 25 |
7 files changed, 122 insertions, 105 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 9f4388b54..75216dac6 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -27,11 +27,12 @@ #include <set> #include <stack> - using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::arith; +arith::ArithConstants* ArithRewriter::s_constants = NULL; + bool isVariable(TNode t){ return t.getMetaKind() == kind::metakind::VARIABLE; } @@ -40,25 +41,25 @@ RewriteResponse ArithRewriter::rewriteConstant(TNode t){ Assert(t.getMetaKind() == kind::metakind::CONSTANT); Node val = coerceToRationalNode(t); - return RewriteComplete(val); + return RewriteResponse(REWRITE_DONE, val); } RewriteResponse ArithRewriter::rewriteVariable(TNode t){ Assert(isVariable(t)); - return RewriteComplete(t); + return RewriteResponse(REWRITE_DONE, t); } RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ Assert(t.getKind()== kind::MINUS); - if(t[0] == t[1]) return RewriteComplete(d_constants->d_ZERO_NODE); + if(t[0] == t[1]) return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE); Node noMinus = makeSubtractionNode(t[0],t[1]); if(pre){ - return RewriteComplete(noMinus); + return RewriteResponse(REWRITE_DONE, noMinus); }else{ - return FullRewriteNeeded(noMinus); + return RewriteResponse(REWRITE_AGAIN_FULL, noMinus); } } @@ -67,9 +68,9 @@ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){ Node noUminus = makeUnaryMinusNode(t[0]); if(pre) - return RewriteComplete(noUminus); + return RewriteResponse(REWRITE_DONE, noUminus); else - return RewriteAgain(noUminus); + return RewriteResponse(REWRITE_AGAIN, noUminus); } RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ @@ -85,7 +86,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ if(t[0].getKind()== kind::CONST_RATIONAL){ return rewriteDivByConstant(t, true); }else{ - return RewriteComplete(t); + return RewriteResponse(REWRITE_DONE, t); } }else if(t.getKind() == kind::PLUS){ return preRewritePlus(t); @@ -123,25 +124,25 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){ for(TNode::iterator i = t.begin(); i != t.end(); ++i) { if((*i).getKind() == kind::CONST_RATIONAL) { - if((*i).getConst<Rational>() == d_constants->d_ZERO) { - return RewriteComplete(d_constants->d_ZERO_NODE); + if((*i).getConst<Rational>() == s_constants->d_ZERO) { + return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE); } } else if((*i).getKind() == kind::CONST_INTEGER) { if((*i).getConst<Integer>() == intZero) { if(t.getType().isInteger()) { - return RewriteComplete(NodeManager::currentNM()->mkConst(intZero)); + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero)); } else { - return RewriteComplete(d_constants->d_ZERO_NODE); + return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE); } } } } - return RewriteComplete(t); + return RewriteResponse(REWRITE_DONE, t); } RewriteResponse ArithRewriter::preRewritePlus(TNode t){ Assert(t.getKind()== kind::PLUS); - return RewriteComplete(t); + return RewriteResponse(REWRITE_DONE, t); } RewriteResponse ArithRewriter::postRewritePlus(TNode t){ @@ -156,7 +157,7 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){ res = res + currPoly; } - return RewriteComplete(res.getNode()); + return RewriteResponse(REWRITE_DONE, res.getNode()); } RewriteResponse ArithRewriter::postRewriteMult(TNode t){ @@ -171,7 +172,7 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ res = res * currPoly; } - return RewriteComplete(res.getNode()); + return RewriteResponse(REWRITE_DONE, res.getNode()); } RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){ @@ -182,7 +183,7 @@ RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){ Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right)); if(cmp.isBoolean()){ - return RewriteComplete(cmp.getNode()); + return RewriteResponse(REWRITE_DONE, cmp.getNode()); } if(cmp.getLeft().containsConstant()){ @@ -209,7 +210,7 @@ RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){ Assert(cmp.getLeft().getHead().coefficientIsOne()); Assert(cmp.isBoolean() || cmp.isNormalForm()); - return RewriteComplete(cmp.getNode()); + return RewriteResponse(REWRITE_DONE, cmp.getNode()); } RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ @@ -222,8 +223,8 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ }else{ //Transform this to: (left - right) |><| 0 Node diff = makeSubtractionNode(left, right); - Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE); - return FullRewriteNeeded(reduction); + Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE); + return RewriteResponse(REWRITE_AGAIN_FULL, reduction); } } @@ -233,7 +234,7 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){ if(atom.getKind() == kind::EQUAL) { if(atom[0] == atom[1]) { - return RewriteComplete(currNM->mkConst(true)); + return RewriteResponse(REWRITE_DONE, currNM->mkConst(true)); } } @@ -246,7 +247,7 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){ //Transform this to: (left - right) |><| 0 Node diff = makeSubtractionNode(left, right); - reduction = currNM->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE); + reduction = currNM->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE); } if(reduction.getKind() == kind::GT){ @@ -257,25 +258,25 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){ reduction = currNM->mkNode(kind::NOT, geq); } - return RewriteComplete(reduction); + return RewriteResponse(REWRITE_DONE, reduction); } RewriteResponse ArithRewriter::postRewrite(TNode t){ if(isTerm(t)){ RewriteResponse response = postRewriteTerm(t); - if(Debug.isOn("arith::rewriter") && response.isDone()) { - Polynomial::parsePolynomial(response.getNode()); + if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) { + Polynomial::parsePolynomial(response.node); } return response; }else if(isAtom(t)){ RewriteResponse response = postRewriteAtom(t); - if(Debug.isOn("arith::rewriter") && response.isDone()) { - Comparison::parseNormalForm(response.getNode()); + if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) { + Comparison::parseNormalForm(response.node); } return response; }else{ Unreachable(); - return RewriteComplete(Node::null()); + return RewriteResponse(REWRITE_DONE, Node::null()); } } @@ -286,12 +287,12 @@ RewriteResponse ArithRewriter::preRewrite(TNode t){ return preRewriteAtom(t); }else{ Unreachable(); - return RewriteComplete(Node::null()); + return RewriteResponse(REWRITE_DONE, Node::null()); } } Node ArithRewriter::makeUnaryMinusNode(TNode n){ - return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n); + return NodeManager::currentNM()->mkNode(kind::MULT,s_constants->d_NEGATIVE_ONE_NODE,n); } Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ @@ -311,7 +312,7 @@ RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){ const Rational& den = right.getConst<Rational>(); - Assert(den != d_constants->d_ZERO); + Assert(den != s_constants->d_ZERO); Rational div = den.inverse(); @@ -319,8 +320,8 @@ RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){ Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); if(pre){ - return RewriteComplete(mult); + return RewriteResponse(REWRITE_DONE, mult); }else{ - return RewriteAgain(mult); + return RewriteResponse(REWRITE_AGAIN, mult); } } diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index f7ef8c0c7..e161bd8d6 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -17,10 +17,13 @@ ** \todo document this file **/ -#include "theory/arith/arith_constants.h" #include "theory/theory.h" +#include "theory/arith/arith_constants.h" +#include "theory/arith/arith_utilities.h" #include "theory/arith/normal_form.h" +#include "theory/rewriter.h" + #ifndef __CVC4__THEORY__ARITH__REWRITER_H #define __CVC4__THEORY__ARITH__REWRITER_H @@ -28,46 +31,67 @@ namespace CVC4 { namespace theory { namespace arith { -class ArithRewriter{ +class ArithRewriter { + private: - ArithConstants* d_constants; - Node makeSubtractionNode(TNode l, TNode r); - Node makeUnaryMinusNode(TNode n); + static arith::ArithConstants* s_constants; + + static Node makeSubtractionNode(TNode l, TNode r); + static Node makeUnaryMinusNode(TNode n); - RewriteResponse preRewriteTerm(TNode t); - RewriteResponse postRewriteTerm(TNode t); + static RewriteResponse preRewriteTerm(TNode t); + static RewriteResponse postRewriteTerm(TNode t); - RewriteResponse rewriteVariable(TNode t); - RewriteResponse rewriteConstant(TNode t); - RewriteResponse rewriteMinus(TNode t, bool pre); - RewriteResponse rewriteUMinus(TNode t, bool pre); - RewriteResponse rewriteDivByConstant(TNode t, bool pre); + static RewriteResponse rewriteVariable(TNode t); + static RewriteResponse rewriteConstant(TNode t); + static RewriteResponse rewriteMinus(TNode t, bool pre); + static RewriteResponse rewriteUMinus(TNode t, bool pre); + static RewriteResponse rewriteDivByConstant(TNode t, bool pre); - RewriteResponse preRewritePlus(TNode t); - RewriteResponse postRewritePlus(TNode t); + static RewriteResponse preRewritePlus(TNode t); + static RewriteResponse postRewritePlus(TNode t); - RewriteResponse preRewriteMult(TNode t); - RewriteResponse postRewriteMult(TNode t); + static RewriteResponse preRewriteMult(TNode t); + static RewriteResponse postRewriteMult(TNode t); - RewriteResponse preRewriteAtom(TNode t); - RewriteResponse postRewriteAtom(TNode t); - RewriteResponse postRewriteAtomConstantRHS(TNode t); + static RewriteResponse preRewriteAtom(TNode t); + static RewriteResponse postRewriteAtom(TNode t); + static RewriteResponse postRewriteAtomConstantRHS(TNode t); public: - ArithRewriter(ArithConstants* ac) : d_constants(ac) {} - RewriteResponse preRewrite(TNode n); - RewriteResponse postRewrite(TNode n); + static RewriteResponse preRewrite(TNode n); + static RewriteResponse postRewrite(TNode n); + + static void init() { + if (s_constants == NULL) { + s_constants = new arith::ArithConstants(NodeManager::currentNM()); + } + } + + static void shutdown() { + if (s_constants != NULL) { + delete s_constants; + s_constants = NULL; + } + } private: - bool isAtom(TNode n) const { return isRelationOperator(n.getKind()); } - bool isTerm(TNode n) const { return !isAtom(n); } + + static inline bool isAtom(TNode n) { + return arith::isRelationOperator(n.getKind()); + } + + static inline bool isTerm(TNode n) { + return !isAtom(n); + } + }; -}; /* namesapce arith */ +}; /* namesapce rewrite */ }; /* namespace theory */ }; /* namespace CVC4 */ diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds index 6808e3d8f..9e2e3a3a7 100644 --- a/src/theory/arith/kinds +++ b/src/theory/arith/kinds @@ -4,7 +4,12 @@ # src/theory/builtin/kinds. # -theory ::CVC4::theory::arith::TheoryArith "theory_arith.h" +theory THEORY_ARITH ::CVC4::theory::arith::TheoryArith "theory/arith/theory_arith.h" + +properties stable-infinite check propagate staticLearning presolve + +rewriter ::CVC4::theory::arith::ArithRewriter "theory/arith/arith_rewriter.h" + operator PLUS 2: "arithmetic addition" operator MULT 2: "arithmetic multiplication" @@ -12,6 +17,9 @@ operator MINUS 2 "arithmetic binary subtraction operator" operator UMINUS 1 "arithmetic unary negation" operator DIVISION 2 "arithmetic division" +sort REAL_TYPE "Real type" +sort INTEGER_TYPE "Integer type" + constant CONST_RATIONAL \ ::CVC4::Rational \ ::CVC4::RationalHashStrategy \ @@ -28,3 +36,5 @@ operator LT 2 "less than, x < y" operator LEQ 2 "less than or equal, x <= y" operator GT 2 "greater than, x > y" operator GEQ 2 "greater than or equal, x >= y" + +endtheory
\ No newline at end of file diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 766a8fc0a..2a8c1077e 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -30,9 +30,10 @@ bool VarList::isSorted(iterator start, iterator end) { } bool VarList::isMember(Node n) { - if(n.getNumChildren() == 0) { - return Variable::isMember(n); - } else if(n.getKind() == kind::MULT) { + if(Variable::isMember(n)) { + return true; + } + if(n.getKind() == kind::MULT) { Node::iterator curr = n.begin(), end = n.end(); Node prev = *curr; if(!Variable::isMember(prev)) return false; @@ -59,7 +60,7 @@ int VarList::cmp(const VarList& vl) const { } VarList VarList::parseVarList(Node n) { - if(n.getNumChildren() == 0) { + if(Variable::isMember(n)) { return VarList(Variable(n)); } else { Assert(n.getKind() == kind::MULT); diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h index 1c9b2685d..29db6cdb9 100644 --- a/src/theory/arith/normal_form.h +++ b/src/theory/arith/normal_form.h @@ -25,6 +25,7 @@ #include "expr/node.h" #include "expr/node_self_iterator.h" #include "util/rational.h" +#include "theory/theory.h" #include "theory/arith/arith_constants.h" #include "theory/arith/arith_utilities.h" @@ -183,8 +184,11 @@ public: Assert(isMember(getNode())); } + // TODO: check if it's a theory leaf also static bool isMember(Node n) { - return n.getMetaKind() == kind::metakind::VARIABLE; + if (n.getKind() == kind::CONST_INTEGER) return false; + if (n.getKind() == kind::CONST_RATIONAL) return false; + return Theory::isLeafOf(n, theory::THEORY_ARITH); } bool isNormalForm() { return isMember(getNode()); } diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index bf5f285a5..b9c983215 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -53,15 +53,14 @@ using namespace CVC4::theory::arith; struct SlackAttrID; typedef expr::Attribute<SlackAttrID, Node> Slack; -TheoryArith::TheoryArith(int id, context::Context* c, OutputChannel& out) : - Theory(id, c, out), +TheoryArith::TheoryArith(context::Context* c, OutputChannel& out) : + Theory(THEORY_ARITH, c, out), d_constants(NodeManager::currentNM()), d_partialModel(c), d_basicManager(), d_activityMonitor(), d_diseq(c), d_tableau(d_activityMonitor, d_basicManager), - d_rewriter(&d_constants), d_propagator(c, out), d_simplex(d_constants, d_partialModel, d_basicManager, d_out, d_activityMonitor, d_tableau), d_statistics() @@ -116,7 +115,7 @@ void TheoryArith::preRegisterTerm(TNode n) { d_out->setIncomplete(); } - if(isTheoryLeaf(n) || isStrictlyVarList){ + if(isLeaf(n) || isStrictlyVarList){ ++(d_statistics.d_statUserVariables); ArithVar varN = requestArithVar(n,false); setupInitialValue(varN); @@ -144,13 +143,8 @@ void TheoryArith::preRegisterTerm(TNode n) { } - -bool TheoryArith::isTheoryLeaf(TNode x) const{ - return x.getMetaKind() == kind::metakind::VARIABLE; -} - ArithVar TheoryArith::requestArithVar(TNode x, bool basic){ - Assert(isTheoryLeaf(x)); + Assert(isLeaf(x)); Assert(!hasArithVar(x)); ArithVar varX = d_variables.size(); @@ -179,7 +173,9 @@ void TheoryArith::asVectors(Polynomial& p, std::vector<Rational>& coeffs, std::v Node n = variable.getNode(); - Assert(isTheoryLeaf(n)); + Debug("rewriter") << "should be var: " << n << endl; + + Assert(isLeaf(n)); Assert(hasArithVar(n)); ArithVar av = asArithVar(n); @@ -191,8 +187,6 @@ void TheoryArith::asVectors(Polynomial& p, std::vector<Rational>& coeffs, std::v void TheoryArith::setupSlack(TNode left){ - - ++(d_statistics.d_statSlackVariables); TypeNode real_type = NodeManager::currentNM()->realType(); Node slack = NodeManager::currentNM()->mkVar(real_type); @@ -242,10 +236,6 @@ void TheoryArith::setupInitialValue(ArithVar x){ Debug("arithgc") << "setupVariable("<<x<<")"<<std::endl; }; -RewriteResponse TheoryArith::preRewrite(TNode n, bool topLevel) { - return d_rewriter.preRewrite(n); -} - void TheoryArith::registerTerm(TNode tn){ Debug("arith") << "registerTerm(" << tn << ")" << endl; } @@ -270,7 +260,7 @@ TNode getSide(TNode assertion, Kind simpleKind){ ArithVar TheoryArith::determineLeftVariable(TNode assertion, Kind simpleKind){ TNode left = getSide<true>(assertion, simpleKind); - if(isTheoryLeaf(left)){ + if(isLeaf(left)){ return asArithVar(left); }else{ Assert(left.hasAttribute(Slack())); @@ -457,7 +447,7 @@ void TheoryArith::check(Effort effortLevel){ } } -void TheoryArith::explain(TNode n, Effort e) { +void TheoryArith::explain(TNode n) { // Node explanation = d_propagator.explain(n); // Debug("arith") << "arith::explain("<<explanation<<")->" // << explanation << endl; @@ -552,3 +542,7 @@ Node TheoryArith::getValue(TNode n, TheoryEngine* engine) { Unhandled(n.getKind()); } } + +void TheoryArith::notifyEq(TNode lhs, TNode rhs) { + +} diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index e9ff06adb..c95ca6cc4 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -94,31 +94,14 @@ private: */ Tableau d_tableau; - /** - * The rewriter module for arithmetic. - */ - ArithRewriter d_rewriter; - ArithUnatePropagator d_propagator; SimplexDecisionProcedure d_simplex; public: - TheoryArith(int id, context::Context* c, OutputChannel& out); + TheoryArith(context::Context* c, OutputChannel& out); ~TheoryArith(); /** - * Rewriting optimizations. - */ - RewriteResponse preRewrite(TNode n, bool topLevel); - - /** - * Plug in old rewrite to the new (pre,post)rewrite interface. - */ - RewriteResponse postRewrite(TNode n, bool topLevel) { - return d_rewriter.postRewrite(n); - } - - /** * Does non-context dependent setup for a node connected to a theory. */ void preRegisterTerm(TNode n); @@ -128,7 +111,9 @@ public: void check(Effort e); void propagate(Effort e); - void explain(TNode n, Effort e); + void explain(TNode n); + + void notifyEq(TNode lhs, TNode rhs); Node getValue(TNode n, TheoryEngine* engine); @@ -144,8 +129,6 @@ public: private: - bool isTheoryLeaf(TNode x) const; - ArithVar determineLeftVariable(TNode assertion, Kind simpleKind); |