summaryrefslogtreecommitdiff
path: root/src/theory/arith
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/arith')
-rw-r--r--src/theory/arith/arith_rewriter.cpp69
-rw-r--r--src/theory/arith/arith_rewriter.h74
-rw-r--r--src/theory/arith/kinds12
-rw-r--r--src/theory/arith/normal_form.cpp9
-rw-r--r--src/theory/arith/normal_form.h6
-rw-r--r--src/theory/arith/theory_arith.cpp32
-rw-r--r--src/theory/arith/theory_arith.h25
7 files changed, 122 insertions, 105 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
index 9f4388b54..75216dac6 100644
--- a/src/theory/arith/arith_rewriter.cpp
+++ b/src/theory/arith/arith_rewriter.cpp
@@ -27,11 +27,12 @@
#include <set>
#include <stack>
-
using namespace CVC4;
using namespace CVC4::theory;
using namespace CVC4::theory::arith;
+arith::ArithConstants* ArithRewriter::s_constants = NULL;
+
bool isVariable(TNode t){
return t.getMetaKind() == kind::metakind::VARIABLE;
}
@@ -40,25 +41,25 @@ RewriteResponse ArithRewriter::rewriteConstant(TNode t){
Assert(t.getMetaKind() == kind::metakind::CONSTANT);
Node val = coerceToRationalNode(t);
- return RewriteComplete(val);
+ return RewriteResponse(REWRITE_DONE, val);
}
RewriteResponse ArithRewriter::rewriteVariable(TNode t){
Assert(isVariable(t));
- return RewriteComplete(t);
+ return RewriteResponse(REWRITE_DONE, t);
}
RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
Assert(t.getKind()== kind::MINUS);
- if(t[0] == t[1]) return RewriteComplete(d_constants->d_ZERO_NODE);
+ if(t[0] == t[1]) return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
Node noMinus = makeSubtractionNode(t[0],t[1]);
if(pre){
- return RewriteComplete(noMinus);
+ return RewriteResponse(REWRITE_DONE, noMinus);
}else{
- return FullRewriteNeeded(noMinus);
+ return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
}
}
@@ -67,9 +68,9 @@ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
Node noUminus = makeUnaryMinusNode(t[0]);
if(pre)
- return RewriteComplete(noUminus);
+ return RewriteResponse(REWRITE_DONE, noUminus);
else
- return RewriteAgain(noUminus);
+ return RewriteResponse(REWRITE_AGAIN, noUminus);
}
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
@@ -85,7 +86,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
if(t[0].getKind()== kind::CONST_RATIONAL){
return rewriteDivByConstant(t, true);
}else{
- return RewriteComplete(t);
+ return RewriteResponse(REWRITE_DONE, t);
}
}else if(t.getKind() == kind::PLUS){
return preRewritePlus(t);
@@ -123,25 +124,25 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){
for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
if((*i).getKind() == kind::CONST_RATIONAL) {
- if((*i).getConst<Rational>() == d_constants->d_ZERO) {
- return RewriteComplete(d_constants->d_ZERO_NODE);
+ if((*i).getConst<Rational>() == s_constants->d_ZERO) {
+ return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
}
} else if((*i).getKind() == kind::CONST_INTEGER) {
if((*i).getConst<Integer>() == intZero) {
if(t.getType().isInteger()) {
- return RewriteComplete(NodeManager::currentNM()->mkConst(intZero));
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
} else {
- return RewriteComplete(d_constants->d_ZERO_NODE);
+ return RewriteResponse(REWRITE_DONE, s_constants->d_ZERO_NODE);
}
}
}
}
- return RewriteComplete(t);
+ return RewriteResponse(REWRITE_DONE, t);
}
RewriteResponse ArithRewriter::preRewritePlus(TNode t){
Assert(t.getKind()== kind::PLUS);
- return RewriteComplete(t);
+ return RewriteResponse(REWRITE_DONE, t);
}
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
@@ -156,7 +157,7 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){
res = res + currPoly;
}
- return RewriteComplete(res.getNode());
+ return RewriteResponse(REWRITE_DONE, res.getNode());
}
RewriteResponse ArithRewriter::postRewriteMult(TNode t){
@@ -171,7 +172,7 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
res = res * currPoly;
}
- return RewriteComplete(res.getNode());
+ return RewriteResponse(REWRITE_DONE, res.getNode());
}
RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
@@ -182,7 +183,7 @@ RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
if(cmp.isBoolean()){
- return RewriteComplete(cmp.getNode());
+ return RewriteResponse(REWRITE_DONE, cmp.getNode());
}
if(cmp.getLeft().containsConstant()){
@@ -209,7 +210,7 @@ RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
Assert(cmp.getLeft().getHead().coefficientIsOne());
Assert(cmp.isBoolean() || cmp.isNormalForm());
- return RewriteComplete(cmp.getNode());
+ return RewriteResponse(REWRITE_DONE, cmp.getNode());
}
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
@@ -222,8 +223,8 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
}else{
//Transform this to: (left - right) |><| 0
Node diff = makeSubtractionNode(left, right);
- Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
- return FullRewriteNeeded(reduction);
+ Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE);
+ return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
}
}
@@ -233,7 +234,7 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
if(atom.getKind() == kind::EQUAL) {
if(atom[0] == atom[1]) {
- return RewriteComplete(currNM->mkConst(true));
+ return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
}
}
@@ -246,7 +247,7 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
//Transform this to: (left - right) |><| 0
Node diff = makeSubtractionNode(left, right);
- reduction = currNM->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
+ reduction = currNM->mkNode(atom.getKind(), diff, s_constants->d_ZERO_NODE);
}
if(reduction.getKind() == kind::GT){
@@ -257,25 +258,25 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
reduction = currNM->mkNode(kind::NOT, geq);
}
- return RewriteComplete(reduction);
+ return RewriteResponse(REWRITE_DONE, reduction);
}
RewriteResponse ArithRewriter::postRewrite(TNode t){
if(isTerm(t)){
RewriteResponse response = postRewriteTerm(t);
- if(Debug.isOn("arith::rewriter") && response.isDone()) {
- Polynomial::parsePolynomial(response.getNode());
+ if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
+ Polynomial::parsePolynomial(response.node);
}
return response;
}else if(isAtom(t)){
RewriteResponse response = postRewriteAtom(t);
- if(Debug.isOn("arith::rewriter") && response.isDone()) {
- Comparison::parseNormalForm(response.getNode());
+ if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
+ Comparison::parseNormalForm(response.node);
}
return response;
}else{
Unreachable();
- return RewriteComplete(Node::null());
+ return RewriteResponse(REWRITE_DONE, Node::null());
}
}
@@ -286,12 +287,12 @@ RewriteResponse ArithRewriter::preRewrite(TNode t){
return preRewriteAtom(t);
}else{
Unreachable();
- return RewriteComplete(Node::null());
+ return RewriteResponse(REWRITE_DONE, Node::null());
}
}
Node ArithRewriter::makeUnaryMinusNode(TNode n){
- return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
+ return NodeManager::currentNM()->mkNode(kind::MULT,s_constants->d_NEGATIVE_ONE_NODE,n);
}
Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
@@ -311,7 +312,7 @@ RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
const Rational& den = right.getConst<Rational>();
- Assert(den != d_constants->d_ZERO);
+ Assert(den != s_constants->d_ZERO);
Rational div = den.inverse();
@@ -319,8 +320,8 @@ RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
if(pre){
- return RewriteComplete(mult);
+ return RewriteResponse(REWRITE_DONE, mult);
}else{
- return RewriteAgain(mult);
+ return RewriteResponse(REWRITE_AGAIN, mult);
}
}
diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h
index f7ef8c0c7..e161bd8d6 100644
--- a/src/theory/arith/arith_rewriter.h
+++ b/src/theory/arith/arith_rewriter.h
@@ -17,10 +17,13 @@
** \todo document this file
**/
-#include "theory/arith/arith_constants.h"
#include "theory/theory.h"
+#include "theory/arith/arith_constants.h"
+#include "theory/arith/arith_utilities.h"
#include "theory/arith/normal_form.h"
+#include "theory/rewriter.h"
+
#ifndef __CVC4__THEORY__ARITH__REWRITER_H
#define __CVC4__THEORY__ARITH__REWRITER_H
@@ -28,46 +31,67 @@ namespace CVC4 {
namespace theory {
namespace arith {
-class ArithRewriter{
+class ArithRewriter {
+
private:
- ArithConstants* d_constants;
- Node makeSubtractionNode(TNode l, TNode r);
- Node makeUnaryMinusNode(TNode n);
+ static arith::ArithConstants* s_constants;
+
+ static Node makeSubtractionNode(TNode l, TNode r);
+ static Node makeUnaryMinusNode(TNode n);
- RewriteResponse preRewriteTerm(TNode t);
- RewriteResponse postRewriteTerm(TNode t);
+ static RewriteResponse preRewriteTerm(TNode t);
+ static RewriteResponse postRewriteTerm(TNode t);
- RewriteResponse rewriteVariable(TNode t);
- RewriteResponse rewriteConstant(TNode t);
- RewriteResponse rewriteMinus(TNode t, bool pre);
- RewriteResponse rewriteUMinus(TNode t, bool pre);
- RewriteResponse rewriteDivByConstant(TNode t, bool pre);
+ static RewriteResponse rewriteVariable(TNode t);
+ static RewriteResponse rewriteConstant(TNode t);
+ static RewriteResponse rewriteMinus(TNode t, bool pre);
+ static RewriteResponse rewriteUMinus(TNode t, bool pre);
+ static RewriteResponse rewriteDivByConstant(TNode t, bool pre);
- RewriteResponse preRewritePlus(TNode t);
- RewriteResponse postRewritePlus(TNode t);
+ static RewriteResponse preRewritePlus(TNode t);
+ static RewriteResponse postRewritePlus(TNode t);
- RewriteResponse preRewriteMult(TNode t);
- RewriteResponse postRewriteMult(TNode t);
+ static RewriteResponse preRewriteMult(TNode t);
+ static RewriteResponse postRewriteMult(TNode t);
- RewriteResponse preRewriteAtom(TNode t);
- RewriteResponse postRewriteAtom(TNode t);
- RewriteResponse postRewriteAtomConstantRHS(TNode t);
+ static RewriteResponse preRewriteAtom(TNode t);
+ static RewriteResponse postRewriteAtom(TNode t);
+ static RewriteResponse postRewriteAtomConstantRHS(TNode t);
public:
- ArithRewriter(ArithConstants* ac) : d_constants(ac) {}
- RewriteResponse preRewrite(TNode n);
- RewriteResponse postRewrite(TNode n);
+ static RewriteResponse preRewrite(TNode n);
+ static RewriteResponse postRewrite(TNode n);
+
+ static void init() {
+ if (s_constants == NULL) {
+ s_constants = new arith::ArithConstants(NodeManager::currentNM());
+ }
+ }
+
+ static void shutdown() {
+ if (s_constants != NULL) {
+ delete s_constants;
+ s_constants = NULL;
+ }
+ }
private:
- bool isAtom(TNode n) const { return isRelationOperator(n.getKind()); }
- bool isTerm(TNode n) const { return !isAtom(n); }
+
+ static inline bool isAtom(TNode n) {
+ return arith::isRelationOperator(n.getKind());
+ }
+
+ static inline bool isTerm(TNode n) {
+ return !isAtom(n);
+ }
+
};
-}; /* namesapce arith */
+}; /* namesapce rewrite */
}; /* namespace theory */
}; /* namespace CVC4 */
diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds
index 6808e3d8f..9e2e3a3a7 100644
--- a/src/theory/arith/kinds
+++ b/src/theory/arith/kinds
@@ -4,7 +4,12 @@
# src/theory/builtin/kinds.
#
-theory ::CVC4::theory::arith::TheoryArith "theory_arith.h"
+theory THEORY_ARITH ::CVC4::theory::arith::TheoryArith "theory/arith/theory_arith.h"
+
+properties stable-infinite check propagate staticLearning presolve
+
+rewriter ::CVC4::theory::arith::ArithRewriter "theory/arith/arith_rewriter.h"
+
operator PLUS 2: "arithmetic addition"
operator MULT 2: "arithmetic multiplication"
@@ -12,6 +17,9 @@ operator MINUS 2 "arithmetic binary subtraction operator"
operator UMINUS 1 "arithmetic unary negation"
operator DIVISION 2 "arithmetic division"
+sort REAL_TYPE "Real type"
+sort INTEGER_TYPE "Integer type"
+
constant CONST_RATIONAL \
::CVC4::Rational \
::CVC4::RationalHashStrategy \
@@ -28,3 +36,5 @@ operator LT 2 "less than, x < y"
operator LEQ 2 "less than or equal, x <= y"
operator GT 2 "greater than, x > y"
operator GEQ 2 "greater than or equal, x >= y"
+
+endtheory \ No newline at end of file
diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp
index 766a8fc0a..2a8c1077e 100644
--- a/src/theory/arith/normal_form.cpp
+++ b/src/theory/arith/normal_form.cpp
@@ -30,9 +30,10 @@ bool VarList::isSorted(iterator start, iterator end) {
}
bool VarList::isMember(Node n) {
- if(n.getNumChildren() == 0) {
- return Variable::isMember(n);
- } else if(n.getKind() == kind::MULT) {
+ if(Variable::isMember(n)) {
+ return true;
+ }
+ if(n.getKind() == kind::MULT) {
Node::iterator curr = n.begin(), end = n.end();
Node prev = *curr;
if(!Variable::isMember(prev)) return false;
@@ -59,7 +60,7 @@ int VarList::cmp(const VarList& vl) const {
}
VarList VarList::parseVarList(Node n) {
- if(n.getNumChildren() == 0) {
+ if(Variable::isMember(n)) {
return VarList(Variable(n));
} else {
Assert(n.getKind() == kind::MULT);
diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h
index 1c9b2685d..29db6cdb9 100644
--- a/src/theory/arith/normal_form.h
+++ b/src/theory/arith/normal_form.h
@@ -25,6 +25,7 @@
#include "expr/node.h"
#include "expr/node_self_iterator.h"
#include "util/rational.h"
+#include "theory/theory.h"
#include "theory/arith/arith_constants.h"
#include "theory/arith/arith_utilities.h"
@@ -183,8 +184,11 @@ public:
Assert(isMember(getNode()));
}
+ // TODO: check if it's a theory leaf also
static bool isMember(Node n) {
- return n.getMetaKind() == kind::metakind::VARIABLE;
+ if (n.getKind() == kind::CONST_INTEGER) return false;
+ if (n.getKind() == kind::CONST_RATIONAL) return false;
+ return Theory::isLeafOf(n, theory::THEORY_ARITH);
}
bool isNormalForm() { return isMember(getNode()); }
diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp
index bf5f285a5..b9c983215 100644
--- a/src/theory/arith/theory_arith.cpp
+++ b/src/theory/arith/theory_arith.cpp
@@ -53,15 +53,14 @@ using namespace CVC4::theory::arith;
struct SlackAttrID;
typedef expr::Attribute<SlackAttrID, Node> Slack;
-TheoryArith::TheoryArith(int id, context::Context* c, OutputChannel& out) :
- Theory(id, c, out),
+TheoryArith::TheoryArith(context::Context* c, OutputChannel& out) :
+ Theory(THEORY_ARITH, c, out),
d_constants(NodeManager::currentNM()),
d_partialModel(c),
d_basicManager(),
d_activityMonitor(),
d_diseq(c),
d_tableau(d_activityMonitor, d_basicManager),
- d_rewriter(&d_constants),
d_propagator(c, out),
d_simplex(d_constants, d_partialModel, d_basicManager, d_out, d_activityMonitor, d_tableau),
d_statistics()
@@ -116,7 +115,7 @@ void TheoryArith::preRegisterTerm(TNode n) {
d_out->setIncomplete();
}
- if(isTheoryLeaf(n) || isStrictlyVarList){
+ if(isLeaf(n) || isStrictlyVarList){
++(d_statistics.d_statUserVariables);
ArithVar varN = requestArithVar(n,false);
setupInitialValue(varN);
@@ -144,13 +143,8 @@ void TheoryArith::preRegisterTerm(TNode n) {
}
-
-bool TheoryArith::isTheoryLeaf(TNode x) const{
- return x.getMetaKind() == kind::metakind::VARIABLE;
-}
-
ArithVar TheoryArith::requestArithVar(TNode x, bool basic){
- Assert(isTheoryLeaf(x));
+ Assert(isLeaf(x));
Assert(!hasArithVar(x));
ArithVar varX = d_variables.size();
@@ -179,7 +173,9 @@ void TheoryArith::asVectors(Polynomial& p, std::vector<Rational>& coeffs, std::v
Node n = variable.getNode();
- Assert(isTheoryLeaf(n));
+ Debug("rewriter") << "should be var: " << n << endl;
+
+ Assert(isLeaf(n));
Assert(hasArithVar(n));
ArithVar av = asArithVar(n);
@@ -191,8 +187,6 @@ void TheoryArith::asVectors(Polynomial& p, std::vector<Rational>& coeffs, std::v
void TheoryArith::setupSlack(TNode left){
-
-
++(d_statistics.d_statSlackVariables);
TypeNode real_type = NodeManager::currentNM()->realType();
Node slack = NodeManager::currentNM()->mkVar(real_type);
@@ -242,10 +236,6 @@ void TheoryArith::setupInitialValue(ArithVar x){
Debug("arithgc") << "setupVariable("<<x<<")"<<std::endl;
};
-RewriteResponse TheoryArith::preRewrite(TNode n, bool topLevel) {
- return d_rewriter.preRewrite(n);
-}
-
void TheoryArith::registerTerm(TNode tn){
Debug("arith") << "registerTerm(" << tn << ")" << endl;
}
@@ -270,7 +260,7 @@ TNode getSide(TNode assertion, Kind simpleKind){
ArithVar TheoryArith::determineLeftVariable(TNode assertion, Kind simpleKind){
TNode left = getSide<true>(assertion, simpleKind);
- if(isTheoryLeaf(left)){
+ if(isLeaf(left)){
return asArithVar(left);
}else{
Assert(left.hasAttribute(Slack()));
@@ -457,7 +447,7 @@ void TheoryArith::check(Effort effortLevel){
}
}
-void TheoryArith::explain(TNode n, Effort e) {
+void TheoryArith::explain(TNode n) {
// Node explanation = d_propagator.explain(n);
// Debug("arith") << "arith::explain("<<explanation<<")->"
// << explanation << endl;
@@ -552,3 +542,7 @@ Node TheoryArith::getValue(TNode n, TheoryEngine* engine) {
Unhandled(n.getKind());
}
}
+
+void TheoryArith::notifyEq(TNode lhs, TNode rhs) {
+
+}
diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h
index e9ff06adb..c95ca6cc4 100644
--- a/src/theory/arith/theory_arith.h
+++ b/src/theory/arith/theory_arith.h
@@ -94,31 +94,14 @@ private:
*/
Tableau d_tableau;
- /**
- * The rewriter module for arithmetic.
- */
- ArithRewriter d_rewriter;
-
ArithUnatePropagator d_propagator;
SimplexDecisionProcedure d_simplex;
public:
- TheoryArith(int id, context::Context* c, OutputChannel& out);
+ TheoryArith(context::Context* c, OutputChannel& out);
~TheoryArith();
/**
- * Rewriting optimizations.
- */
- RewriteResponse preRewrite(TNode n, bool topLevel);
-
- /**
- * Plug in old rewrite to the new (pre,post)rewrite interface.
- */
- RewriteResponse postRewrite(TNode n, bool topLevel) {
- return d_rewriter.postRewrite(n);
- }
-
- /**
* Does non-context dependent setup for a node connected to a theory.
*/
void preRegisterTerm(TNode n);
@@ -128,7 +111,9 @@ public:
void check(Effort e);
void propagate(Effort e);
- void explain(TNode n, Effort e);
+ void explain(TNode n);
+
+ void notifyEq(TNode lhs, TNode rhs);
Node getValue(TNode n, TheoryEngine* engine);
@@ -144,8 +129,6 @@ public:
private:
- bool isTheoryLeaf(TNode x) const;
-
ArithVar determineLeftVariable(TNode assertion, Kind simpleKind);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback