summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2010-09-13 16:08:21 +0000
committerTim King <taking@cs.nyu.edu>2010-09-13 16:08:21 +0000
commit0e18d60841c2a7cd5c079b6c0dacf5d61afb4835 (patch)
tree470e4868ca9576dc20d491afa7462d6e9f1f8c56
parent8d74ddb6380f39034e5cae5d4b094a283e14ffb3 (diff)
* New normal form for arithmetic is in place.
* src/theory/arith/normal_form.{h,cpp} contains the description for the new normal form as well as utilities for dealing with the normal form. * src/theory/arith/next_arith_rewriter.{h,cpp} contains the new rewriter. The new rewriter implements preRewrite() and postRewrite() for arithmetic. * src/theory/arith/arith_rewriter.{h,cpp} have been removed. * TheoryArith::rewrite() has been removed. * Arithmetic with the new normal form outperforms the trunk where the branch occurred (-r797) on 46% of the examples in QF_LRA. (33% have no noticeable difference.) Some important optimizations are stilling pending to the code for handling the new normal form. (Bug 196.)
-rw-r--r--src/expr/node_manager.cpp3
-rw-r--r--src/theory/arith/Makefile.am6
-rw-r--r--src/theory/arith/arith_rewriter.cpp557
-rw-r--r--src/theory/arith/arith_rewriter.h123
-rw-r--r--src/theory/arith/arith_utilities.h17
-rw-r--r--src/theory/arith/kinds2
-rw-r--r--src/theory/arith/next_arith_rewriter.cpp326
-rw-r--r--src/theory/arith/next_arith_rewriter.h74
-rw-r--r--src/theory/arith/normal_form.cpp250
-rw-r--r--src/theory/arith/normal_form.h613
-rw-r--r--src/theory/arith/tableau.h41
-rw-r--r--src/theory/arith/theory_arith.cpp79
-rw-r--r--src/theory/arith/theory_arith.h11
-rw-r--r--test/unit/theory/theory_arith_white.h171
14 files changed, 1389 insertions, 884 deletions
diff --git a/src/expr/node_manager.cpp b/src/expr/node_manager.cpp
index d017ad799..37ed4fe20 100644
--- a/src/expr/node_manager.cpp
+++ b/src/expr/node_manager.cpp
@@ -235,6 +235,9 @@ TypeNode NodeManager::getType(TNode n, bool check)
case kind::APPLY_UF:
typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check);
break;
+ case kind::IDENTITY:
+ typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+ break;
case kind::PLUS:
typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
break;
diff --git a/src/theory/arith/Makefile.am b/src/theory/arith/Makefile.am
index e500f5cf8..ead39082c 100644
--- a/src/theory/arith/Makefile.am
+++ b/src/theory/arith/Makefile.am
@@ -7,8 +7,10 @@ noinst_LTLIBRARIES = libarith.la
libarith_la_SOURCES = \
theory_arith_type_rules.h \
- arith_rewriter.h \
- arith_rewriter.cpp \
+ next_arith_rewriter.h \
+ next_arith_rewriter.cpp \
+ normal_form.h\
+ normal_form.cpp \
arith_utilities.h \
arith_constants.h \
arith_activity.h \
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
deleted file mode 100644
index ba1445df8..000000000
--- a/src/theory/arith/arith_rewriter.cpp
+++ /dev/null
@@ -1,557 +0,0 @@
-/********************* */
-/*! \file arith_rewriter.cpp
- ** \verbatim
- ** Original author: taking
- ** Major contributors: none
- ** Minor contributors (to current version): mdeters
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-
-#include "theory/arith/arith_rewriter.h"
-#include "theory/arith/arith_utilities.h"
-
-#include <vector>
-#include <set>
-#include <stack>
-
-
-using namespace CVC4;
-using namespace CVC4::theory;
-using namespace CVC4::theory::arith;
-
-
-
-
-
-Kind multKind(Kind k, int sgn);
-
-/**
- * Performs a quick check to see if it is easy to rewrite to
- * this normal form
- * v |><| b
- * Also writes relations with constants on both sides to TRUE or FALSE.
- * If it can, it returns true and sets res to this value.
- *
- * This is for optimizing rewriteAtom() to avoid the more computationally
- * expensive general rewriting procedure.
- *
- * If simplification is not done, it returns Node::null()
- */
-Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){
- Assert(atom.getKind() == k);
- Assert(isRelationOperator(k));
- Assert(atom[0] == left);
- Assert(atom[1] == right);
- bool leftIsConst = left.getMetaKind() == kind::metakind::CONSTANT;
- bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT;
-
- bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE;
- bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE;
-
- if(leftIsConst && rightIsConst){
- Rational lc = coerceToRational(left);
- Rational rc = coerceToRational(right);
- bool res = evaluateConstantPredicate(k,lc, rc);
- return mkBoolNode(res);
- }else if(leftIsVar && rightIsConst){
- if(right.getKind() == kind::CONST_RATIONAL){
- return atom;
- }else{
- return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right));
- }
- }else if(leftIsConst && rightIsVar){
- if(left.getKind() == kind::CONST_RATIONAL){
- return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left);
- }else{
- Node q_left = coerceToRationalNode(left);
- return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left);
- }
- }
-
- return Node::null();
-}
-
-Node ArithRewriter::rewriteAtomCore(TNode atom){
-
- Kind k = atom.getKind();
- Assert(isRelationOperator(k));
-
- // left |><| right
- TNode left = atom[0];
- TNode right = atom[1];
-
- Node nf = almostVarOrConstEqn(atom, k,left,right);
- if(nf != Node::null() ){
- return nf;
- }
-
-
- //Transform this to: (left- right) |><| 0
- Node diff = makeSubtractionNode(left, right);
-
- Node rewritten = rewrite(diff);
- // rewritten =_{Reals} left - right => rewritten |><| 0
-
- if(rewritten.getMetaKind() == kind::metakind::CONSTANT){
- // Case 1 rewritten : c
- Rational c = rewritten.getConst<Rational>();
- bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO);
- nf = mkBoolNode(res);
- }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){
- // Case 2 rewritten : v
- nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE);
- }else{
- // Case 3 rewritten : (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
- Rational c = rewritten[0].getConst<Rational>();
- c = -c;
- TNode p_1 = rewritten[1];
- Rational d = p_1[0].getConst<Rational>();
- d = d.inverse();
- c = c * d;
- Node newRight = mkRationalNode(c);
- Kind newKind = multKind(k, d.sgn());
- int N = rewritten.getNumChildren() - 1;
-
- if(N==1){
- int M = p_1.getNumChildren()-1;
- if(M == 1){ // v |><| b
- TNode v = p_1[1];
- nf = NodeManager::currentNM()->mkNode(newKind, v, newRight);
- }else{ // p |><| b
- Node newLeft = multPnfByNonZero(p_1, d);
- nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
- }
- }else{ //(+ p_1 .. p_N) |><| b
- NodeBuilder<> plus(kind::PLUS);
- for(int i=1; i<=N; ++i){
- TNode p_i = rewritten[i];
- plus << multPnfByNonZero(p_i, d);
- }
- Node newLeft = plus;
- nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
- }
- }
-
- return nf;
-}
-
-Node ArithRewriter::rewriteAtom(TNode atom){
- Node rewritten = rewriteAtomCore(atom);
- if(rewritten.getKind() == kind::LT){
- Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]);
- return NodeManager::currentNM()->mkNode(kind::NOT, geq);
- }else if(rewritten.getKind() == kind::GT){
- Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]);
- return NodeManager::currentNM()->mkNode(kind::NOT, leq);
- }else{
- return rewritten;
- }
-}
-
-
-/* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
- * if(M == M'):
- * then tupleCompare(v_i, v'_i)
- * else M -M'
- */
-struct pnfLessThan {
- bool operator()(Node p0, Node p1) {
- int p0_M = p0.getNumChildren() -1;
- int p1_M = p1.getNumChildren() -1;
- if(p0_M == p1_M){
- for(int i=1; i<= p0_M; ++i){
- if(p0[i] != p1[i]){
- return p0[i] < p1[i];
- }
- }
- return false; //p0 == p1 in this order
- }else{
- return p0_M < p1_M;
- }
- }
-};
-
-//Two pnfs are equal up to their coefficients
-bool pnfsMatch(TNode p0, TNode p1){
-
- unsigned M = p0.getNumChildren()-1;
- if (M+1 != p1.getNumChildren()){
- return false;
- }
-
- for(unsigned i=1; i <= M; ++i){
- if(p0[i] != p1[i])
- return false;
- }
- return true;
-}
-
-Node addMatchingPnfs(TNode p0, TNode p1){
- Assert(pnfsMatch(p0,p1));
-
- unsigned M = p0.getNumChildren()-1;
-
- Rational c0 = p0[0].getConst<Rational>();
- Rational c1 = p1[0].getConst<Rational>();
-
- Rational addedC = c0 + c1;
- Node newC = mkRationalNode(addedC);
- NodeBuilder<> nb(kind::MULT);
- nb << newC;
- for(unsigned i=1; i <= M; ++i){
- nb << p0[i];
- }
- Node newPnf = nb;
- return newPnf;
-}
-
-void ArithRewriter::sortAndCombineCoefficients(std::vector<Node>& pnfs){
- using namespace std;
-
- /* combined contains exactly 1 representative per for each pnf.
- * This is maintained by combining the coefficients for pnfs.
- * that is equal according to pnfLessThan.
- */
- typedef set<Node, pnfLessThan> PnfSet;
- PnfSet combined;
-
- for(vector<Node>::iterator i=pnfs.begin(); i != pnfs.end(); ++i){
- Node pnf = *i;
- PnfSet::iterator pos = combined.find(pnf);
-
- if(pos == combined.end()){
- combined.insert(pnf);
- }else{
- Node current = *pos;
- Node sum = addMatchingPnfs(pnf, current);
- combined.erase(pos);
- combined.insert(sum);
- }
- }
- pnfs.clear();
- for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){
- Node pnf = *i;
- if(pnf[0].getConst<Rational>() != d_constants->d_ZERO){
- //after combination the coefficient may be zero
- pnfs.push_back(pnf);
- }
- }
-}
-
-Node ArithRewriter::var2pnf(TNode variable){
- return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable);
-}
-
-Node ArithRewriter::rewritePlus(TNode t){
- using namespace std;
-
- Rational accumulator;
- vector<Node> pnfs;
-
- for(TNode::iterator i = t.begin(); i!= t.end(); ++i){
- TNode child = *i;
- Node rewrittenChild = rewrite(child);
-
- if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
- Rational c = rewrittenChild.getConst<Rational>();
- accumulator = accumulator + c;
- }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
- Node pnf = var2pnf(rewrittenChild);
- pnfs.push_back(pnf);
- }else{ //(+ c p_1 p_2 ... p_N)
- Rational c = rewrittenChild[0].getConst<Rational>();
- accumulator = accumulator + c;
- int N = rewrittenChild.getNumChildren() - 1;
- for(int i=1; i<=N; ++i){
- TNode pnf = rewrittenChild[i];
- pnfs.push_back(pnf);
- }
- }
- }
- sortAndCombineCoefficients(pnfs);
- if(pnfs.size() == 0){
- return mkRationalNode(accumulator);
- }
-
- // pnfs.size() >= 1
-
- //Enforce not(N=1 and c=0 and p_1.d=1)
- if(pnfs.size() == 1){
- Node p_1 = *(pnfs.begin());
- if(p_1[0].getConst<Rational>() == d_constants->d_ONE){
- if(accumulator == d_constants->d_ZERO){ // 0 + (* 1 var) |-> var
- Node var = p_1[1];
- return var;
- }
- }
- }
-
- //We must be in this case
- //(+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
-
- NodeBuilder<> nb(kind::PLUS);
- nb << mkRationalNode(accumulator);
- Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl;
- for(vector<Node>::iterator i = pnfs.begin(); i != pnfs.end(); ++i){
- nb << *i;
- Debug("arithrewrite") << (*i) << std::endl;
-
- }
-
- Node normalForm = nb;
- return normalForm;
-}
-
-//Does not enforce
-//5) v_i are of metakind VARIABLE,
-//6) v_i are in increasing (not strict) nodeOrder,
-Node toPnf(Rational& c, std::set<Node>& variables){
- NodeBuilder<> nb(kind::MULT);
- nb << mkRationalNode(c);
-
- for(std::set<Node>::iterator i = variables.begin(); i != variables.end(); ++i){
- nb << *i;
- }
- Node pnf = nb;
- return pnf;
-}
-
-Node distribute(TNode n, TNode sum){
- NodeBuilder<> nb(kind::PLUS);
- for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
- Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i);
- nb << prod;
- }
- return nb;
-}
-Node distributeSum(TNode sum, TNode distribSum){
- NodeBuilder<> nb(kind::PLUS);
- for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
- Node dist = distribute(*i, distribSum);
- for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){
- nb << *j;
- }
- }
- return nb;
-}
-
-Node ArithRewriter::rewriteMult(TNode t){
-
- using namespace std;
-
- Rational accumulator(1,1);
- set<Node> variables;
- vector<Node> sums;
-
- //These stacks need to be kept in lock step
- stack<TNode> mult_iterators_nodes;
- stack<TNode::const_iterator> mult_iterators_iters;
-
- mult_iterators_nodes.push(t);
- mult_iterators_iters.push(t.begin());
-
- while(!mult_iterators_nodes.empty()){
- TNode mult = mult_iterators_nodes.top();
- TNode::const_iterator i = mult_iterators_iters.top();
-
- mult_iterators_nodes.pop();
- mult_iterators_iters.pop();
-
- for(; i != mult.end(); ++i){
- TNode child = *i;
- if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks
- ++i;
- mult_iterators_nodes.push(mult);
- mult_iterators_iters.push(i);
-
- mult_iterators_nodes.push(child);
- mult_iterators_iters.push(child.begin());
- break;
- }
- Node rewrittenChild = rewrite(child);
-
- if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
- Rational c = rewrittenChild.getConst<Rational>();
- accumulator = accumulator * c;
- if(accumulator == d_constants->d_ZERO){
- return d_constants->d_ZERO_NODE;
- }
- }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
- variables.insert(rewrittenChild);
- }else{ //(+ c p_1 p_2 ... p_N)
- sums.push_back(rewrittenChild);
- }
- }
- }
- // accumulator * (\prod var_i) *(\prod sum_j)
-
- if(sums.size() == 0){ //accumulator * (\prod var_i)
- if(variables.size() == 0){ //accumulator
- return mkRationalNode(accumulator);
- }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1
- Node var = *(variables.begin());
- return var;
- }else{
- //We need to return (+ c p_1 p_2 ... p_N)
- //To accomplish this:
- // let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf)
- Node pnf = toPnf(accumulator, variables);
- Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf);
- return normalForm;
- }
- }else{
- vector<Node>::iterator sum_iter = sums.begin();
- // \sum t
- // t \in Q \cup A
- // where A = lfp {\prod s | s \in Q \cup Variables \cup A}
- Node distributed = *sum_iter;
- ++sum_iter;
- while(sum_iter != sums.end()){
- Node curr = *sum_iter;
- distributed = distributeSum(curr, distributed);
- ++sum_iter;
- }
- if(variables.size() >= 1){
- Node pnf = toPnf(accumulator, variables);
- distributed = distribute(pnf, distributed);
- }else{
- Node constant = mkRationalNode(accumulator);
- distributed = distribute(constant, distributed);
- }
-
- Node nf_distributed = rewrite(distributed);
- return nf_distributed;
- }
-}
-
-Node ArithRewriter::rewriteDivByConstant(TNode t){
- Assert(t.getKind()== kind::DIVISION);
-
- Node left = t[0];
- Node reRight = rewrite(t[1]);
- Assert(reRight.getKind()== kind::CONST_RATIONAL);
-
-
- Rational den = reRight.getConst<Rational>();
-
- Assert(den != d_constants->d_ZERO);
-
- Rational div = den.inverse();
-
- Node result = mkRationalNode(div);
-
- Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-
- Node reMult = rewrite(mult);
-
- return reMult;
-}
-
-Node ArithRewriter::rewriteTerm(TNode t){
- if(t.getMetaKind() == kind::metakind::CONSTANT){
- return coerceToRationalNode(t);
- }else if(t.getMetaKind() == kind::metakind::VARIABLE){
- return t;
- }else if(t.getKind() == kind::MULT){
- return rewriteMult(t);
- }else if(t.getKind() == kind::PLUS){
- return rewritePlus(t);
- }else if(t.getKind() == kind::DIVISION){
- return rewriteDivByConstant(t);
- }else if(t.getKind() == kind::MINUS){
- Node sub = makeSubtractionNode(t[0],t[1]);
- return rewrite(sub);
- }else if(t.getKind() == kind::UMINUS){
- Node sub = makeUnaryMinusNode(t[0]);
- return rewrite(sub);
- }else{
- Unhandled(t);
- }
-}
-
-
-/**
- * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0
- * constuct a node equal to q * pnf that is in pnf.
- *
- * The claim is that this is always okay:
- * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf.
- */
-Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){
- Assert(q != d_constants->d_ZERO);
- //TODO Assert(isPNF(pnf) );
-
- int M = pnf.getNumChildren()-1;
- Rational d = pnf[0].getConst<Rational>();
- Rational new_d = d*q;
-
-
- NodeBuilder<> mult(kind::MULT);
- mult << mkRationalNode(new_d);
- for(int i=1; i<=M; ++i){
- mult << pnf[i];
- }
-
- Node result = mult;
- return result;
-}
-
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
- Node tmp = NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
- return tmp;
-}
-
-Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
- Node negR = makeUnaryMinusNode(r);
- Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
-
- return diff;
-}
-
-
-Kind multKind(Kind k, int sgn){
- using namespace kind;
-
- if(sgn < 0){
-
- switch(k){
- case LT: return GT;
- case LEQ: return GEQ;
- case EQUAL: return EQUAL;
- case GEQ: return LEQ;
- case GT: return LT;
- default:
- Unhandled(k);
- }
- return NULL_EXPR;
- }else{
- return k;
- }
-}
-
-Node ArithRewriter::rewrite(TNode n){
- Debug("arithrewriter") << "Trace rewrite:" << n << std::endl;
-
- Node res;
-
- if(isRelationOperator(n.getKind())){
- res = rewriteAtom(n);
- }else{
- res = rewriteTerm(n);
- }
-
- Debug("arithrewriter") << "Trace rewrite:" << n << "|->"<< res << std::endl;
-
- return res;
-}
diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h
deleted file mode 100644
index a76ee6e61..000000000
--- a/src/theory/arith/arith_rewriter.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/********************* */
-/*! \file arith_rewriter.h
- ** \verbatim
- ** Original author: taking
- ** Major contributors: mdeters
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-
-#include "expr/node.h"
-#include "util/rational.h"
-#include "theory/arith/arith_constants.h"
-
-#ifndef __CVC4__THEORY__ARITH__REWRITER_H
-#define __CVC4__THEORY__ARITH__REWRITER_H
-
-namespace CVC4 {
-namespace theory {
-namespace arith {
-
-
-/***********************************************/
-/***************** Normal Form *****************/
-/***********************************************/
-/***********************************************/
-
-/**
- * Normal form for predicates:
- * TRUE
- * FALSE
- * v |><| b
- * p |><| b
- * (+ p_1 .. p_N) |><| b
- * where
- * 1) b is of type CONST_RATIONAL
- * 2) |><| is of kind <, <=, =, >= or >
- * 3) p, p_i is in PNF,
- * 4) p.M >= 2
- * 5) p_i's are in strictly ascending <p,
- * 6) N >= 2,
- * 7) the kind of (+ p_1 .. p_N) is an N arity PLUS,
- * 8) p.d, p_1.d are 1,
- * 9) v has metakind variable, and
- *
- * PNF(t):
- * (* d v_1 v_2 ... v_M)
- * where
- * 1) d is of type CONST_RATIONAL,
- * 2) d != 0,
- * 4) M>=1,
- * 5) v_i are of metakind VARIABLE,
- * 6) v_i are in increasing (not strict) nodeOrder, and
- * 7) the kind of t is an M+1 arity MULT.
- *
- * <p is defined over PNF as follows (skipping some symmetry):
- * cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
- * if(M == M'):
- * then tupleCompare(v_i, v'_i)
- * else M -M'
- *
- * Rewrite Normal Form for Terms:
- * b
- * v
- * (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1)
- * where
- * 1) b,c is of type CONST_RATIONAL,
- * 3) p_i is in PNF,
- * 4) N >= 1
- * 5) the kind of (+ c p_1 p_2 ... p_N) is an N+1 arity PLUS,
- * 6) and p_i's are in strictly <p.
- *
- */
-
-class ArithRewriter{
-private:
- ArithConstants* d_constants;
-
- //This is where the core of the work is done for rewriteAtom
- //With a few additional checks done by rewriteAtom
- Node rewriteAtomCore(TNode atom);
- Node rewriteAtom(TNode atom);
-
- Node rewriteTerm(TNode t);
- Node rewriteMult(TNode t);
- Node rewritePlus(TNode t);
- Node rewriteMinus(TNode t);
- Node makeSubtractionNode(TNode l, TNode r);
- Node makeUnaryMinusNode(TNode n);
-
-
- Node var2pnf(TNode variable);
-
- Node multPnfByNonZero(TNode pnf, Rational& q);
-
- Node rewriteDivByConstant(TNode t);
- void sortAndCombineCoefficients(std::vector<Node>& pnfs);
-
-
-public:
- ArithRewriter(ArithConstants* ac) :
- d_constants(ac)
- {}
- Node rewrite(TNode t);
-
-};
-
-
-}; /* namesapce arith */
-}; /* namespace theory */
-}; /* namespace CVC4 */
-
-#endif /* __CVC4__THEORY__ARITH__REWRITER_H */
diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h
index fa3356c60..6706ad76a 100644
--- a/src/theory/arith/arith_utilities.h
+++ b/src/theory/arith/arith_utilities.h
@@ -27,7 +27,7 @@ namespace CVC4 {
namespace theory {
namespace arith {
-inline Node mkRationalNode(Rational& q){
+inline Node mkRationalNode(const Rational& q){
return NodeManager::currentNM()->mkConst<Rational>(q);
}
@@ -87,6 +87,21 @@ inline bool isRelationOperator(Kind k){
}
}
+/** is k \in {LT, LEQ, EQ, GEQ, GT} */
+inline Kind negateRelationKind(Kind k){
+ using namespace kind;
+
+ switch(k){
+ case LT: return GT;
+ case LEQ: return GEQ;
+ case EQUAL: return EQUAL;
+ case GEQ: return LEQ;
+ case GT: return LT;
+
+ default:
+ Unreachable();
+ }
+}
inline bool evaluateConstantPredicate(Kind k, const Rational& left, const Rational& right){
using namespace kind;
diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds
index 99f7258da..07d48b1f6 100644
--- a/src/theory/arith/kinds
+++ b/src/theory/arith/kinds
@@ -12,6 +12,8 @@ operator MINUS 2 "arithmetic binary subtraction operator"
operator UMINUS 1 "arithmetic unary negation"
operator DIVISION 2 "arithmetic division"
+operator IDENTITY 1 "identity function"
+
constant CONST_RATIONAL \
::CVC4::Rational \
::CVC4::RationalHashStrategy \
diff --git a/src/theory/arith/next_arith_rewriter.cpp b/src/theory/arith/next_arith_rewriter.cpp
new file mode 100644
index 000000000..c14f806c9
--- /dev/null
+++ b/src/theory/arith/next_arith_rewriter.cpp
@@ -0,0 +1,326 @@
+/********************* */
+/*! \file arith_rewriter.cpp
+ ** \verbatim
+ ** Original author: taking
+ ** Major contributors: none
+ ** Minor contributors (to current version): mdeters
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+
+#include "theory/theory.h"
+#include "theory/arith/normal_form.h"
+#include "theory/arith/next_arith_rewriter.h"
+#include "theory/arith/arith_utilities.h"
+
+#include <vector>
+#include <set>
+#include <stack>
+
+
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+
+bool isVariable(TNode t){
+ return t.getMetaKind() == kind::metakind::VARIABLE;
+}
+
+RewriteResponse NextArithRewriter::rewriteConstant(TNode t){
+ Assert(t.getMetaKind() == kind::metakind::CONSTANT);
+ Node val = coerceToRationalNode(t);
+
+ return RewriteComplete(val);
+}
+
+RewriteResponse NextArithRewriter::rewriteVariable(TNode t){
+ Assert(isVariable(t));
+
+ return RewriteComplete(t);
+}
+
+RewriteResponse NextArithRewriter::rewriteMinus(TNode t, bool pre){
+ Assert(t.getKind()== kind::MINUS);
+
+ if(t[0] == t[1]) return RewriteComplete(d_constants->d_ZERO_NODE);
+
+ Node noMinus = makeSubtractionNode(t[0],t[1]);
+ if(pre){
+ return RewriteComplete(noMinus);
+ }else{
+ return FullRewriteNeeded(noMinus);
+ }
+}
+
+RewriteResponse NextArithRewriter::rewriteUMinus(TNode t, bool pre){
+ Assert(t.getKind()== kind::UMINUS);
+
+ Node noUminus = makeUnaryMinusNode(t[0]);
+ if(pre)
+ return RewriteComplete(noUminus);
+ else
+ return RewriteAgain(noUminus);
+}
+
+RewriteResponse NextArithRewriter::preRewriteTerm(TNode t){
+ if(t.getMetaKind() == kind::metakind::CONSTANT){
+ return rewriteConstant(t);
+ }else if(isVariable(t)){
+ return rewriteVariable(t);
+ }else if(t.getKind() == kind::MINUS){
+ return rewriteMinus(t, true);
+ }else if(t.getKind() == kind::UMINUS){
+ return rewriteUMinus(t, true);
+ }else if(t.getKind() == kind::DIVISION){
+ if(t[0].getKind()== kind::CONST_RATIONAL){
+ return rewriteDivByConstant(t, true);
+ }else{
+ return RewriteComplete(t);
+ }
+ }else if(t.getKind() == kind::PLUS){
+ return preRewritePlus(t);
+ }else if(t.getKind() == kind::MULT){
+ return preRewriteMult(t);
+ }else{
+ Unreachable();
+ }
+}
+RewriteResponse NextArithRewriter::postRewriteTerm(TNode t){
+ if(t.getMetaKind() == kind::metakind::CONSTANT){
+ return rewriteConstant(t);
+ }else if(isVariable(t)){
+ return rewriteVariable(t);
+ }else if(t.getKind() == kind::MINUS){
+ return rewriteMinus(t, false);
+ }else if(t.getKind() == kind::UMINUS){
+ return rewriteUMinus(t, false);
+ }else if(t.getKind() == kind::DIVISION){
+ return rewriteDivByConstant(t, false);
+ }else if(t.getKind() == kind::PLUS){
+ return postRewritePlus(t);
+ }else if(t.getKind() == kind::MULT){
+ return postRewriteMult(t);
+ }else{
+ Unreachable();
+ }
+}
+
+RewriteResponse NextArithRewriter::preRewriteMult(TNode t){
+ Assert(t.getKind()== kind::MULT);
+
+ // Rewrite multiplications with a 0 argument and to 0
+ Integer intZero;
+
+ for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
+ if((*i).getKind() == kind::CONST_RATIONAL) {
+ if((*i).getConst<Rational>() == d_constants->d_ZERO) {
+ return RewriteComplete(d_constants->d_ZERO_NODE);
+ }
+ } else if((*i).getKind() == kind::CONST_INTEGER) {
+ if((*i).getConst<Integer>() == intZero) {
+ if(t.getType().isInteger()) {
+ return RewriteComplete(NodeManager::currentNM()->mkConst(intZero));
+ } else {
+ return RewriteComplete(d_constants->d_ZERO_NODE);
+ }
+ }
+ }
+ }
+ return RewriteComplete(t);
+}
+RewriteResponse NextArithRewriter::preRewritePlus(TNode t){
+ Assert(t.getKind()== kind::PLUS);
+
+ return RewriteComplete(t);
+}
+
+RewriteResponse NextArithRewriter::postRewritePlus(TNode t){
+ Assert(t.getKind()== kind::PLUS);
+
+ Polynomial res = Polynomial::mkZero();
+
+ for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
+ Node curr = *i;
+ Polynomial currPoly = Polynomial::parsePolynomial(curr);
+
+ res = res + currPoly;
+ }
+
+ return RewriteComplete(res.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteMult(TNode t){
+ Assert(t.getKind()== kind::MULT);
+
+ Polynomial res = Polynomial::mkOne();
+
+ for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
+ Node curr = *i;
+ Polynomial currPoly = Polynomial::parsePolynomial(curr);
+
+ res = res * currPoly;
+ }
+
+ return RewriteComplete(res.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteAtomConstantRHS(TNode t){
+ TNode left = t[0];
+ TNode right = t[1];
+
+
+ Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
+
+ if(cmp.isBoolean()){
+ return RewriteComplete(cmp.getNode());
+ }
+
+ if(cmp.getLeft().containsConstant()){
+ Monomial constantHead = cmp.getLeft().getHead();
+ Assert(constantHead.isConstant());
+
+ Constant constant = constantHead.getConstant();
+
+ Constant negativeConstantHead = -constant;
+
+ cmp = cmp.addConstant(negativeConstantHead);
+ }
+ Assert(!cmp.getLeft().containsConstant());
+
+ if(!cmp.getLeft().getHead().coefficientIsOne()){
+ Monomial constantHead = cmp.getLeft().getHead();
+ Assert(!constantHead.isConstant());
+ Constant constant = constantHead.getConstant();
+
+ Constant inverse = Constant::mkConstant(constant.getValue().inverse());
+
+ cmp = cmp.multiplyConstant(inverse);
+ }
+ Assert(cmp.getLeft().getHead().coefficientIsOne());
+
+ Assert(cmp.isBoolean() || cmp.isNormalForm());
+ return RewriteComplete(cmp.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteAtom(TNode atom){
+ // left |><| right
+ TNode left = atom[0];
+ TNode right = atom[1];
+
+ if(right.getMetaKind() == kind::metakind::CONSTANT){
+ return postRewriteAtomConstantRHS(atom);
+ }else{
+ //Transform this to: (left - right) |><| 0
+ Node diff = makeSubtractionNode(left, right);
+ Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
+ return FullRewriteNeeded(reduction);
+ }
+}
+
+RewriteResponse NextArithRewriter::preRewriteAtom(TNode atom){
+ Assert(isAtom(atom));
+ NodeManager* currNM = NodeManager::currentNM();
+
+ if(atom.getKind() == kind::EQUAL) {
+ if(atom[0] == atom[1]) {
+ return RewriteComplete(currNM->mkConst(true));
+ }
+ }
+
+ Node reduction = atom;
+
+ if(atom[1].getMetaKind() != kind::metakind::CONSTANT){
+ // left |><| right
+ TNode left = atom[0];
+ TNode right = atom[1];
+
+ //Transform this to: (left - right) |><| 0
+ Node diff = makeSubtractionNode(left, right);
+ reduction = currNM->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
+ }
+
+ if(reduction.getKind() == kind::GT){
+ Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
+ reduction = currNM->mkNode(kind::NOT, leq);
+ }else if(reduction.getKind() == kind::LT){
+ Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
+ reduction = currNM->mkNode(kind::NOT, geq);
+ }
+
+ return RewriteComplete(reduction);
+}
+
+RewriteResponse NextArithRewriter::postRewrite(TNode t){
+ if(isTerm(t)){
+ RewriteResponse response = postRewriteTerm(t);
+ if(Debug.isOn("arith::rewriter") && response.isDone()) {
+ Polynomial::parsePolynomial(response.getNode());
+ }
+ return response;
+ }else if(isAtom(t)){
+ RewriteResponse response = postRewriteAtom(t);
+ if(Debug.isOn("arith::rewriter") && response.isDone()) {
+ Comparison::parseNormalForm(response.getNode());
+ }
+ return response;
+ }else{
+ Unreachable();
+ return RewriteComplete(Node::null());
+ }
+}
+
+RewriteResponse NextArithRewriter::preRewrite(TNode t){
+ if(isTerm(t)){
+ return preRewriteTerm(t);
+ }else if(isAtom(t)){
+ return preRewriteAtom(t);
+ }else{
+ Unreachable();
+ return RewriteComplete(Node::null());
+ }
+}
+
+Node NextArithRewriter::makeUnaryMinusNode(TNode n){
+ return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
+}
+
+Node NextArithRewriter::makeSubtractionNode(TNode l, TNode r){
+ Node negR = makeUnaryMinusNode(r);
+ Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
+
+ return diff;
+}
+
+RewriteResponse NextArithRewriter::rewriteDivByConstant(TNode t, bool pre){
+ Assert(t.getKind()== kind::DIVISION);
+
+ Node left = t[0];
+ Node right = t[1];
+ Assert(right.getKind()== kind::CONST_RATIONAL);
+
+
+ const Rational& den = right.getConst<Rational>();
+
+ Assert(den != d_constants->d_ZERO);
+
+ Rational div = den.inverse();
+
+ Node result = mkRationalNode(div);
+
+ Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+ if(pre){
+ return RewriteComplete(mult);
+ }else{
+ return RewriteAgain(mult);
+ }
+}
diff --git a/src/theory/arith/next_arith_rewriter.h b/src/theory/arith/next_arith_rewriter.h
new file mode 100644
index 000000000..7f1ec0fbd
--- /dev/null
+++ b/src/theory/arith/next_arith_rewriter.h
@@ -0,0 +1,74 @@
+/********************* */
+/*! \file arith_rewriter.h
+ ** \verbatim
+ ** Original author: taking
+ ** Major contributors: mdeters
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#include "theory/arith/arith_constants.h"
+#include "theory/theory.h"
+#include "theory/arith/normal_form.h"
+
+#ifndef __CVC4__THEORY__ARITH__REWRITER_NEXT_H
+#define __CVC4__THEORY__ARITH__REWRITER_NEXT_H
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+class NextArithRewriter{
+private:
+ ArithConstants* d_constants;
+
+ Node makeSubtractionNode(TNode l, TNode r);
+ Node makeUnaryMinusNode(TNode n);
+
+ RewriteResponse preRewriteTerm(TNode t);
+ RewriteResponse postRewriteTerm(TNode t);
+
+ RewriteResponse rewriteVariable(TNode t);
+ RewriteResponse rewriteConstant(TNode t);
+ RewriteResponse rewriteMinus(TNode t, bool pre);
+ RewriteResponse rewriteUMinus(TNode t, bool pre);
+ RewriteResponse rewriteDivByConstant(TNode t, bool pre);
+
+ RewriteResponse preRewritePlus(TNode t);
+ RewriteResponse postRewritePlus(TNode t);
+
+ RewriteResponse preRewriteMult(TNode t);
+ RewriteResponse postRewriteMult(TNode t);
+
+
+ RewriteResponse preRewriteAtom(TNode t);
+ RewriteResponse postRewriteAtom(TNode t);
+ RewriteResponse postRewriteAtomConstantRHS(TNode t);
+
+public:
+ NextArithRewriter(ArithConstants* ac) : d_constants(ac) {}
+
+ RewriteResponse preRewrite(TNode n);
+ RewriteResponse postRewrite(TNode n);
+
+private:
+ bool isAtom(TNode n) const { return isRelationOperator(n.getKind()); }
+ bool isTerm(TNode n) const { return !isAtom(n); }
+};
+
+
+}; /* namesapce arith */
+}; /* namespace theory */
+}; /* namespace CVC4 */
+
+#endif /* __CVC4__THEORY__ARITH__REWRITER_NEXT_H */
diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp
new file mode 100644
index 000000000..18e31848b
--- /dev/null
+++ b/src/theory/arith/normal_form.cpp
@@ -0,0 +1,250 @@
+
+#include "theory/arith/normal_form.h"
+#include <list>
+
+using namespace std;
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+
+bool VarList::isSorted(iterator start, iterator end){
+ return __gnu_cxx::is_sorted(start, end);
+}
+
+bool VarList::isMember(Node n){
+ if(n.getNumChildren() == 0){
+ return Variable::isMember(n);
+ }else if(n.getKind() == kind::MULT){
+ Node::iterator curr = n.begin(), end = n.end();
+ Node prev = *curr;
+ if(!Variable::isMember(prev)) return false;
+
+ while( (++curr) != end){
+ if(!Variable::isMember(*curr)) return false;
+ if(!(prev <= *curr)) return false;
+ prev = *curr;
+ }
+ return true;
+ }else{
+ return false;
+ }
+}
+int VarList::cmp(const VarList& vl) const{
+ int dif = this->size() - vl.size();
+ if (dif == 0){
+ return this->getNode().getId() - vl.getNode().getId();
+ }else if(dif < 0){
+ return -1;
+ }else{
+ return 1;
+ }
+}
+
+VarList VarList::parseVarList(Node n){
+ if(n.getNumChildren() == 0){
+ return VarList(Variable(n));
+ }else{
+ Assert(n.getKind() == kind::MULT);
+ for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i){
+ Assert(Variable::isMember(*i));
+ }
+ return VarList(n);
+ }
+}
+
+VarList VarList::operator*(const VarList& vl) const{
+ if(this->empty()){
+ return vl;
+ }else if(vl.empty()){
+ return *this;
+ }else{
+ vector<Variable> result;
+ vector<Variable> thisAsVec = this->asList();
+ vector<Variable> vlAsVec = vl.asList();
+ back_insert_iterator<vector<Variable> > bii(result);
+
+ merge(thisAsVec.begin(), thisAsVec.end(), vlAsVec.begin(), vlAsVec.end(), bii);
+
+ return VarList::mkVarList(result);
+ }
+}
+
+std::vector<Variable> VarList::asList() const {
+ vector<Variable> res;
+ for(iterator i = begin(), e = end(); i != e; ++i){
+ res.push_back(*i);
+ }
+ return res;
+}
+
+Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl){
+ if(c.isZero() || vl.empty() ){
+ return Monomial(c);
+ }else if(c.isOne()){
+ return Monomial(vl);
+ }else{
+ return Monomial(c, vl);
+ }
+}
+Monomial Monomial::parseMonomial(Node n){
+ if(n.getKind() == kind::CONST_RATIONAL){
+ return Monomial(Constant(n));
+ }else if(multStructured(n)){
+ return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1]));
+ }else{
+ return Monomial(VarList::parseVarList(n));
+ }
+}
+
+Monomial Monomial::operator*(const Monomial& mono) const {
+ Constant newConstant = this->getConstant() * mono.getConstant();
+ VarList newVL = this->getVarList() * mono.getVarList();
+
+ return Monomial::mkMonomial(newConstant, newVL);
+}
+
+vector<Monomial> Monomial::sumLikeTerms(const vector<Monomial> & monos){
+ Assert(isSorted(monos));
+
+ Debug("blah") << "start sumLikeTerms" << std::endl;
+ printList(monos);
+ vector<Monomial> outMonomials;
+ typedef vector<Monomial>::const_iterator iterator;
+ for(iterator rangeIter = monos.begin(), end=monos.end(); rangeIter != end;){
+ Rational constant = (*rangeIter).getConstant().getValue();
+ VarList varList = (*rangeIter).getVarList();
+ ++rangeIter;
+ while(rangeIter != end && varList == (*rangeIter).getVarList()){
+ constant += (*rangeIter).getConstant().getValue();
+ ++rangeIter;
+ }
+ if(constant != 0){
+ Constant asConstant = Constant::mkConstant(constant);
+ Monomial nonZero = Monomial::mkMonomial(asConstant, varList);
+ outMonomials.push_back(nonZero);
+ }
+ }
+ Debug("blah") << "outmonomials" << std::endl;
+ printList(monos);
+ Debug("blah") << "end sumLikeTerms" << std::endl;
+
+ Assert(isStrictlySorted(outMonomials));
+ return outMonomials;
+}
+
+void Monomial::printList(const std::vector<Monomial>& monos){
+ typedef std::vector<Monomial>::const_iterator iterator;
+ for(iterator i = monos.begin(), end = monos.end(); i != end; ++i){
+ Debug("blah") << ((*i).getNode()) << std::endl;
+ }
+}
+
+Polynomial Polynomial::operator+(const Polynomial& vl) const{
+ this->printList();
+ vl.printList();
+
+ std::vector<Monomial> sortedMonos;
+ std::back_insert_iterator<std::vector<Monomial> > bii(sortedMonos);
+ std::merge(begin(), end(), vl.begin(), vl.end(), bii);
+
+ std::vector<Monomial> combined = Monomial::sumLikeTerms(sortedMonos);
+
+ Polynomial result = mkPolynomial(combined);
+ result.printList();
+ return result;
+}
+
+Polynomial Polynomial::operator*(const Monomial& mono) const{
+ if(mono.isZero()){
+ return Polynomial(mono); //Don't multiply by zero
+ }else{
+ std::vector<Monomial> newMonos;
+ for(iterator i = this->begin(), end = this->end(); i != end; ++i){
+ newMonos.push_back(mono * (*i));
+ }
+ return Polynomial::mkPolynomial(newMonos);
+ }
+}
+
+Polynomial Polynomial::operator*(const Polynomial& poly) const{
+
+ Polynomial res = Polynomial::mkZero();
+ for(iterator i = this->begin(), end = this->end(); i != end; ++i){
+ Monomial curr = *i;
+ Polynomial prod = poly * curr;
+ Polynomial sum = res + prod;
+ res = sum;
+ }
+ return res;
+}
+
+
+Node Comparison::toNode(Kind k, const Polynomial& l, const Constant& r){
+ Assert(!l.isConstant());
+ Assert(isRelationOperator(k));
+ switch(k){
+ case kind::GEQ:
+ case kind::EQUAL:
+ case kind::LEQ:
+ return NodeManager::currentNM()->mkNode(k, l.getNode(),r.getNode());
+ case kind::LT:
+ return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::GEQ,l,r));
+ case kind::GT:
+ return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::LEQ,l,r));
+ default:
+ Unreachable();
+ }
+}
+
+Comparison Comparison::parseNormalForm(TNode n){
+ if(n.getKind() == kind::CONST_BOOLEAN){
+ return Comparison(n.getConst<bool>());
+ }else{
+ bool negated = n.getKind() == kind::NOT;
+ Node relation = negated ? n[0] : n;
+ Assert( !negated ||
+ relation.getKind() == kind::LEQ ||
+ relation.getKind() == kind::GEQ);
+
+ Polynomial left = Polynomial::parsePolynomial(relation[0]);
+ Constant right(relation[1]);
+
+ Kind newOperator = relation.getKind();
+ if(negated){
+ if(newOperator == kind::LEQ){
+ newOperator = kind::GT;
+ }else{
+ newOperator = kind::LT;
+ }
+ }
+ return Comparison(n, newOperator, left, right);
+ }
+}
+
+Comparison Comparison::mkComparison(Kind k, const Polynomial& left, const Constant& right){
+ Assert(isRelationOperator(k));
+ if(left.isConstant()){
+ const Rational& rConst = left.getNode().getConst<Rational>();
+ const Rational& lConst = right.getNode().getConst<Rational>();
+ bool res = evaluateConstantPredicate(k, lConst, rConst);
+ return Comparison(res);
+ }else{
+ return Comparison(toNode(k, left, right), k, left, right);
+ }
+}
+
+Comparison Comparison::addConstant(const Constant& constant) const{
+ Assert(!isBoolean());
+ Monomial mono(constant);
+ Polynomial constAsPoly( mono );
+ Polynomial newLeft = getLeft() + constAsPoly;
+ Constant newRight = getRight() + constant;
+ return mkComparison(oper, newLeft, newRight);
+}
+
+Comparison Comparison::multiplyConstant(const Constant& constant) const{
+ Assert(!isBoolean());
+ Kind newOper = (constant.getValue() < 0) ? negateRelationKind(oper) : oper;
+
+ return mkComparison(newOper, left*Monomial(constant), right*constant);
+}
diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h
new file mode 100644
index 000000000..1f7bc6be3
--- /dev/null
+++ b/src/theory/arith/normal_form.h
@@ -0,0 +1,613 @@
+
+#include "expr/node.h"
+#include "util/rational.h"
+#include "theory/arith/arith_constants.h"
+#include "theory/arith/arith_utilities.h"
+
+#include <list>
+#include <algorithm>
+#include <ext/algorithm>
+
+#ifndef __CVC4__THEORY__ARITH__NORMAL_FORM_H
+#define __CVC4__THEORY__ARITH__NORMAL_FORM_H
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+/***********************************************/
+/***************** Normal Form *****************/
+/***********************************************/
+/***********************************************/
+
+/**
+ * Section 1: Languages
+ * The normal form for arithmetic nodes is defined by the language
+ * accepted by the following BNFs with some guard conditions.
+ * (The guard conditions are in Section 3 for completeness.)
+ *
+ * variable := n
+ * where
+ * n.getMetaKind() == metakind::VARIABLE
+
+ * constant := n
+ * where
+ * n.getKind() == kind::CONST_RATIONAL
+
+ * var_list := variable | (* [variable])
+ * where
+ * len [variable] >= 2
+ * isSorted varOrder [variable]
+
+ * monomial := constant | var_list | (* constant' var_list')
+ * where
+ * constant' \not\in {0,1}
+
+ * polynomial := monomial | (+ [monomial])
+ * where
+ * len [monomial] >= 2
+ * isStrictlySorted monoOrder [monomial]
+
+ * restricted_cmp := (|><| polynomial constant)
+ * where
+ * |><| is GEQ, EQ, or EQ
+ * not (exists constantMonomial (monomialList polynomial))
+ * monomialCoefficient (head (monomialList polynomial)) == 1
+
+ * comparison := TRUE | FALSE | restricted_cmp | (not restricted_cmp)
+
+ * Normal Form for terms := polynomial
+ * Normal Form for atoms := comparison
+ */
+
+/**
+ * Section 2: Helper Classes
+ * The langauges accepted by each of these defintions
+ * roughly corresponds to one of the following helper classes:
+ * Variable
+ * Constant
+ * VarList
+ * Monomial
+ * Polynomial
+ * Comparison
+ *
+ * Each of the classes obeys the following contracts/design decisions:
+ * -Calling isMember(Node node) on a node returns true iff that node is a
+ * a member of the language. Note: isMember is O(n).
+ * -Calling isNormalForm() on a helper class object returns true iff that
+ * helper class currently represents a normal form object.
+ * -If isNormalForm() is false, then this object must have been made
+ * using a mk*() factory function.
+ * -If isNormalForm() is true, calling getNode() on all of these classes
+ * returns a node that would be accepted by the corresponding language.
+ * And if isNormalForm() is false, returns Node::null().
+ * -Each of the classes is immutable.
+ * -Public facing constuctors have a 1-to-1 correspondence with one of
+ * production rules in the above grammar.
+ * -Public facing constuctors are required to fail in debug mode when the
+ * guards of the production rule are not strictly met.
+ * For example: Monomial(Constant(1),VarList(Variable(x))) must fail.
+ * -When a class has a Class parseClass(Node node) function,
+ * if isMember(node) is true, the function is required to return an instance
+ * of the helper class, instance, s.t. instance.getNode() == node.
+ * And if isMember(node) is false, this throws an assertion failure in debug
+ * mode and has undefined behaviour if not in debug mode.
+ * -Only public facing constructors, parseClass(node), and mk*() functions are
+ * considered privledged functions for the helper class.
+ * -Only privledged functions may use private constructors, and access
+ * private data members.
+ * -All non-privledges functions are considered utility functions and
+ * must use a privledged function in order to create an instance of the class.
+ */
+
+/**
+ * Section 3: Guard Conditions Misc.
+ *
+ *
+ * var_list_len vl =
+ * match vl with
+ * variable -> 1
+ * | (* [variable]) -> len [variable]
+ *
+ * order res =
+ * match res with
+ * Empty -> (0,Node::null())
+ * | NonEmpty(vl) -> (var_list_len vl, vl)
+ *
+ * var_listOrder a b = tuple_cmp (order a) (order b)
+ *
+ * monomialVarList monomial =
+ * match monomial with
+ * constant -> Empty
+ * | var_list -> NonEmpty(var_list)
+ * | (* constant' var_list') -> NonEmpty(var_list')
+ *
+ * monoOrder m0 m1 = var_listOrder (monomialVarList m0) (monomialVarList m1)
+ *
+ * constantMonomial monomial =
+ * match monomial with
+ * constant -> true
+ * | var_list -> false
+ * | (* constant' var_list') -> false
+ *
+ * monomialCoefficient monomial =
+ * match monomial with
+ * constant -> constant
+ * | var_list -> Constant(1)
+ * | (* constant' var_list') -> constant'
+ *
+ * monomialList polynomial =
+ * match polynomial with
+ * monomial -> monomial::[]
+ * | (+ [monomial]) -> [monomial]
+ */
+
+/**
+ * A NodeWrapper is a class that is a thinly veiled container of a Node object.
+ */
+class NodeWrapper {
+private:
+ Node node;
+public:
+ NodeWrapper(Node n) : node(n) {}
+ const Node& getNode() const { return node; }
+};
+
+class Variable : public NodeWrapper {
+public:
+ Variable(Node n) : NodeWrapper(n) {
+ Assert(isMember(getNode()));
+ }
+
+ static bool isMember(Node n) {
+ return n.getMetaKind() == kind::metakind::VARIABLE;
+ }
+
+ bool isNormalForm() { return isMember(getNode()); }
+
+ bool operator<(const Variable& v) const{ return getNode() < v.getNode();}
+ bool operator==(const Variable& v) const{ return getNode() == v.getNode();}
+
+};
+
+class Constant : public NodeWrapper {
+public:
+ Constant(Node n) : NodeWrapper(n) {
+ Assert(isMember(getNode()));
+ }
+
+ static bool isMember(Node n) {
+ return n.getKind() == kind::CONST_RATIONAL;
+ }
+
+ bool isNormalForm() { return isMember(getNode()); }
+
+ static Constant mkConstant(Node n) {
+ return Constant(coerceToRationalNode(n));
+ }
+
+ static Constant mkConstant(const Rational& rat){
+ return Constant(mkRationalNode(rat));
+ }
+
+ const Rational& getValue() const {
+ return getNode().getConst<Rational>();
+ }
+
+ bool isZero() const{ return getValue() == 0; }
+ bool isOne() const{ return getValue() == 1; }
+
+ Constant operator*(const Constant& other) const{
+ return mkConstant(getValue() * other.getValue());
+ }
+ Constant operator+(const Constant& other) const{
+ return mkConstant(getValue() + other.getValue());
+ }
+ Constant operator-() const{
+ return mkConstant(-getValue());
+ }
+};
+
+template <class GetNodeIterator>
+inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end){
+ NodeBuilder<> nb(k);
+
+ while(start != end){
+ nb << (*start).getNode();
+ ++start;
+ }
+ return Node(nb);
+}
+
+/**
+ * A VarList is a sorted list of variables representing a product.
+ * If the VarList is empty, it represents an empty product or 1.
+ * If the VarList has size 1, it represents a single variable.
+ *
+ * A non-sorted VarList can never be successfully made in debug mode.
+ */
+class VarList {
+private:
+ Node backingNode;
+
+ static Node multList(const std::vector<Variable>& list){
+ Assert(list.size() >= 2);
+
+ return makeNode(kind::MULT, list.begin(), list.end());
+ }
+ static Node makeTuple(Node n){
+ return NodeManager::currentNM()->mkNode(kind::IDENTITY, n);
+ }
+
+ VarList() : backingNode(Node::null()){}
+
+ VarList(Node n){
+ backingNode = (Variable::isMember(n)) ? makeTuple(n) : n;
+
+ Assert(isSorted(begin(), end()));
+ }
+
+public:
+ class iterator {
+ private:
+ Node::iterator d_iter;
+ public:
+ explicit iterator(Node::iterator i) : d_iter(i) {}
+
+ inline Variable operator*(){
+ return Variable(*d_iter);
+ }
+
+ bool operator==(const iterator& i){
+ return d_iter == i.d_iter;
+ }
+
+ bool operator!=(const iterator& i){
+ return d_iter != i.d_iter;
+ }
+
+ iterator operator++() {
+ ++d_iter;
+ return *this;
+ }
+
+ iterator operator++(int) {
+ return iterator(d_iter++);
+ }
+ };
+
+ Node getNode() const{
+ if(singleton()){
+ return backingNode[0];
+ }else{
+ return backingNode;
+ }
+ }
+
+ iterator begin() const{
+ return iterator(backingNode.begin());
+ }
+ iterator end() const{
+ return iterator(backingNode.end());
+ }
+
+ VarList(Variable v) : backingNode(makeTuple(v.getNode())){
+ Assert(isSorted(begin(), end()));
+ }
+ VarList(const std::vector<Variable>& l) : backingNode(multList(l)){
+ Assert(l.size() >= 2);
+ Assert(isSorted(begin(), end()));
+ }
+
+ static bool isMember(Node n);
+
+ bool isNormalForm() const{
+ return !empty();
+ }
+
+ static VarList mkEmptyVarList(){
+ return VarList();
+ }
+
+
+ /** There are no restrictions on the size of l */
+ static VarList mkVarList(const std::vector<Variable>& l){
+ if(l.size() == 0){
+ return mkEmptyVarList();
+ }else if(l.size() == 1){
+ return VarList((*l.begin()).getNode());
+ }else{
+ return VarList(l);
+ }
+ }
+
+ int size() const{ return backingNode.getNumChildren(); }
+ bool empty() const { return size() == 0; }
+ bool singleton() const { return backingNode.getKind() == kind::IDENTITY; }
+
+ static VarList parseVarList(Node n);
+
+ VarList operator*(const VarList& vl) const;
+
+ int cmp(const VarList& vl) const;
+
+ bool operator<(const VarList& vl) const{ return cmp(vl) < 0; }
+
+ bool operator==(const VarList& vl) const{ return cmp(vl) == 0; }
+
+ std::vector<Variable> asList() const;
+
+private:
+ bool isSorted(iterator start, iterator end);
+};
+
+class Monomial : public NodeWrapper {
+private:
+ Constant constant;
+ VarList varList;
+ Monomial(Node n, const Constant& c, const VarList& vl):
+ NodeWrapper(n), constant(c), varList(vl)
+ {
+ Assert(!c.isZero() || vl.empty() );
+ Assert( c.isZero() || !vl.empty() );
+
+ Assert(!c.isOne() || !multStructured(n));
+ }
+
+ static Node makeMultNode(const Constant& c, const VarList& vl){
+ Assert(!c.isZero());
+ Assert(!c.isOne());
+ Assert(!vl.empty());
+ return NodeManager::currentNM()->mkNode(kind::MULT, c.getNode(), vl.getNode());
+ }
+
+ static bool multStructured(Node n){
+ return n.getKind() == kind::MULT &&
+ n[0].getKind() == kind::CONST_RATIONAL &&
+ n.getNumChildren() == 2;
+ }
+
+public:
+
+ Monomial(const Constant& c):
+ NodeWrapper(c.getNode()), constant(c), varList(VarList::mkEmptyVarList())
+ { }
+
+ Monomial(const VarList& vl):
+ NodeWrapper(vl.getNode()), constant(Constant::mkConstant(1)), varList(vl)
+ {
+ Assert( !varList.empty() );
+ }
+
+ Monomial(const Constant& c, const VarList& vl):
+ NodeWrapper(makeMultNode(c,vl)), constant(c), varList(vl)
+ {
+ Assert( !c.isZero() );
+ Assert( !c.isOne() );
+ Assert( !varList.empty() );
+
+ Assert(multStructured(getNode()));
+ }
+
+ /** Makes a monomial with no restrictions on c and vl. */
+ static Monomial mkMonomial(const Constant& c, const VarList& vl);
+
+
+ static Monomial parseMonomial(Node n);
+
+ static Monomial mkZero(){
+ return Monomial(Constant::mkConstant(0));
+ }
+ static Monomial mkOne(){
+ return Monomial(Constant::mkConstant(1));
+ }
+ const Constant& getConstant() const{ return constant; }
+ const VarList& getVarList() const{ return varList; }
+
+ bool isConstant() const{
+ return varList.empty();
+ }
+
+ bool isZero() const{
+ return constant.isZero();
+ }
+
+ bool coefficientIsOne() const {
+ return constant.isOne();
+ }
+
+ Monomial operator*(const Monomial& mono) const;
+
+
+ int cmp(const Monomial& mono) const{
+ return getVarList().cmp(mono.getVarList());
+ }
+
+ bool operator<(const Monomial& vl) const{
+ return cmp(vl) < 0;
+ }
+
+ bool operator==(const Monomial& vl) const{
+ return cmp(vl) == 0;
+ }
+
+ static bool isSorted(const std::vector<Monomial>& m){
+ return __gnu_cxx::is_sorted(m.begin(), m.end());
+ }
+
+ static bool isStrictlySorted(const std::vector<Monomial>& m){
+ return isSorted(m) && std::adjacent_find(m.begin(),m.end()) == m.end();
+ }
+
+ /**
+ * Given a sorted list of monomials, this function transforms this
+ * into a strictly sorted list of monomials that does not contain zero.
+ */
+ static std::vector<Monomial> sumLikeTerms(const std::vector<Monomial>& monos);
+
+ static void printList(const std::vector<Monomial>& monos);
+};
+
+class Polynomial : public NodeWrapper {
+private:
+ std::vector<Monomial> monos;
+
+ Polynomial(Node n, const std::vector<Monomial>& m):
+ NodeWrapper(n), monos(m)
+ {
+ Assert( !monos.empty() );
+ Assert( Monomial::isStrictlySorted(monos) );
+ }
+
+ static Node makePlusNode(const std::vector<Monomial>& m){
+ Assert(m.size() >= 2);
+
+ return makeNode(kind::PLUS, m.begin(), m.end());
+ }
+
+public:
+ typedef std::vector<Monomial>::const_iterator iterator;
+
+ iterator begin() const{ return monos.begin(); }
+ iterator end() const{ return monos.end(); }
+
+ Polynomial(const Monomial& m):
+ NodeWrapper(m.getNode()), monos()
+ {
+ monos.push_back(m);
+ }
+ Polynomial(const std::vector<Monomial>& m):
+ NodeWrapper(makePlusNode(m)), monos(m)
+ {
+ Assert( monos.size() >= 2);
+ Assert( Monomial::isStrictlySorted(monos) );
+ }
+
+
+ static Polynomial mkPolynomial(const std::vector<Monomial>& m){
+ if(m.size() == 0){
+ return Polynomial(Monomial::mkZero());
+ }else if(m.size() == 1){
+ return Polynomial((*m.begin()));
+ }else{
+ return Polynomial(m);
+ }
+ }
+
+ static Polynomial parsePolynomial(Node n){
+ std::vector<Monomial> monos;
+ if(n.getKind() == kind::PLUS){
+ for(Node::iterator i=n.begin(), end=n.end(); i != end; ++i){
+ monos.push_back(Monomial::parseMonomial(*i));
+ }
+ }else{
+ monos.push_back(Monomial::parseMonomial(n));
+ }
+ return Polynomial(n,monos);
+ }
+
+ static Polynomial mkZero(){
+ return Polynomial(Monomial::mkZero());
+ }
+ static Polynomial mkOne(){
+ return Polynomial(Monomial::mkOne());
+ }
+ bool isZero() const{
+ return (monos.size() == 1) && (getHead().isZero());
+ }
+
+ bool isConstant() const{
+ return (monos.size() == 1) && (getHead().isConstant());
+ }
+
+ bool containsConstant() const{
+ return getHead().isConstant();
+ }
+
+ Monomial getHead() const{
+ return *(begin());
+ }
+
+ Polynomial getTail() const{
+ Assert(monos.size() >= 1);
+
+ iterator start = begin()+1;
+ std::vector<Monomial> subrange(start, end());
+ return mkPolynomial(subrange);
+ }
+
+ void printList() const{
+ Debug("blah") << "start list" << std::endl;
+ Monomial::printList(monos);
+ Debug("blah") << "end list" << std::endl;
+ }
+
+ Polynomial operator+(const Polynomial& vl) const;
+
+ Polynomial operator*(const Monomial& mono) const;
+
+ Polynomial operator*(const Polynomial& poly) const;
+
+};
+
+class Comparison : public NodeWrapper {
+private:
+ Kind oper;
+ Polynomial left;
+ Constant right;
+
+ static Node toNode(Kind k, const Polynomial& l, const Constant& r);
+
+ Comparison(TNode n, Kind k, const Polynomial& l, const Constant& r):
+ NodeWrapper(n), oper(k), left(l), right(r)
+ { }
+public:
+ Comparison(bool val) :
+ NodeWrapper(NodeManager::currentNM()->mkConst(val)),
+ oper(kind::CONST_BOOLEAN),
+ left(Polynomial::mkZero()),
+ right(Constant::mkConstant(0))
+ { }
+
+ Comparison(Kind k, const Polynomial& l, const Constant& r):
+ NodeWrapper(toNode(k, l, r)), oper(k), left(l), right(r)
+ {
+ Assert(isRelationOperator(oper));
+ Assert(!left.containsConstant());
+ Assert(left.getHead().getConstant().isOne());
+ }
+
+ static Comparison mkComparison(Kind k, const Polynomial& left, const Constant& right);
+
+ bool isBoolean() const{
+ return (oper == kind::CONST_BOOLEAN);
+ }
+
+ bool isNormalForm() const{
+ if(isBoolean()){
+ return true;
+ }else if(left.containsConstant()){
+ return false;
+ }else if(left.getHead().getConstant().isOne()){
+ return true;
+ }else{
+ return false;
+ }
+ }
+
+ const Polynomial& getLeft() const { return left; }
+ const Constant& getRight() const { return right; }
+
+ Comparison addConstant(const Constant& constant) const;
+ Comparison multiplyConstant(const Constant& constant) const;
+
+ static Comparison parseNormalForm(TNode n);
+};
+
+
+
+}; /* namesapce arith */
+}; /* namespace theory */
+}; /* namespace CVC4 */
+
+#endif /* __CVC4__THEORY__ARITH__NORMAL_FORM_H */
diff --git a/src/theory/arith/tableau.h b/src/theory/arith/tableau.h
index 12d93d9fe..603eb5278 100644
--- a/src/theory/arith/tableau.h
+++ b/src/theory/arith/tableau.h
@@ -23,7 +23,7 @@
#include "theory/arith/basic.h"
#include "theory/arith/arith_activity.h"
-
+#include "theory/arith/normal_form.h"
#include <ext/hash_map>
#include <map>
@@ -52,21 +52,21 @@ public:
* Construct a row equal to:
* basic = \sum_{x_i} c_i * x_i
*/
- Row(TNode basic, TNode sum):
+ Row(TNode basic, const Polynomial& sum):
d_x_i(basic),
d_coeffs(){
Assert(d_x_i.getMetaKind() == kind::metakind::VARIABLE);
- Assert(sum.getKind() == kind::PLUS);
-
- for(TNode::iterator iter=sum.begin(); iter != sum.end(); ++iter){
- TNode pair = *iter;
- Assert(pair.getKind() == kind::MULT);
- Assert(pair.getNumChildren() == 2);
- TNode coeff = pair[0];
- TNode var_i = pair[1];
+
+ for(Polynomial::iterator iter=sum.begin(), end = sum.end(); iter != end; ++iter){
+ const Monomial& mono = *iter;
+
+ Assert(!mono.isConstant());
+
+ TNode coeff = mono.getConstant().getNode();
+ TNode var_i = mono.getVarList().getNode();
+
Assert(coeff.getKind() == kind::CONST_RATIONAL);
- Assert(var_i.getKind() == kind::VARIABLE);
Assert(!has(var_i));
d_coeffs[var_i] = coeff.getConst<Rational>();
@@ -192,14 +192,13 @@ private:
public:
void addRow(TNode eq){
- Assert(eq.getKind() == kind::EQUAL);
- Assert(eq.getNumChildren() == 2);
-
TNode var = eq[0];
- TNode sum = eq[1];
+ TNode sumNode = eq[1];
Assert(var.getAttribute(IsBasic()));
+ Polynomial sum = Polynomial::parsePolynomial(sumNode);
+
//The new basic variable cannot already be a basic variable
Assert(!isActiveBasicVariable(var));
d_activeBasicVars.insert(var);
@@ -208,13 +207,11 @@ public:
//A variable in the row may have been made non-basic already.
//If this is the case we fake pivoting this variable
- for(TNode::iterator sumIter = sum.begin(); sumIter!=sum.end(); ++sumIter){
- TNode child = *sumIter;
- Assert(child.getKind() == kind::MULT);
- Assert(child.getNumChildren() == 2);
- Assert(child[0].getKind() == kind::CONST_RATIONAL);
- TNode c = child[1];
- Assert(var.getMetaKind() == kind::metakind::VARIABLE);
+ for(Polynomial::iterator sumIter = sum.begin(); sumIter!= sum.end(); ++sumIter){
+ const Monomial& child = *sumIter;
+
+ Assert(!child.isConstant());
+ TNode c = child.getVarList().getNode();
if(isActiveBasicVariable(c)){
Row* row_c = lookup(c);
row_var->subsitute(*row_c);
diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp
index 157c45160..8f17b01a9 100644
--- a/src/theory/arith/theory_arith.cpp
+++ b/src/theory/arith/theory_arith.cpp
@@ -33,10 +33,12 @@
#include "theory/arith/basic.h"
#include "theory/arith/arith_activity.h"
-#include "theory/arith/arith_rewriter.h"
+#include "theory/arith/next_arith_rewriter.h"
#include "theory/arith/arith_propagator.h"
#include "theory/arith/theory_arith.h"
+#include "theory/arith/normal_form.h"
+
#include <map>
#include <stdint.h>
@@ -55,7 +57,7 @@ TheoryArith::TheoryArith(int id, context::Context* c, OutputChannel& out) :
d_constants(NodeManager::currentNM()),
d_partialModel(c),
d_diseq(c),
- d_rewriter(&d_constants),
+ d_nextRewriter(&d_constants),
d_propagator(c),
d_statistics()
{
@@ -109,22 +111,9 @@ bool isBasicSum(TNode n){
bool isNormalAtom(TNode n){
- if(!(n.getKind() == LEQ|| n.getKind() == GEQ || n.getKind() == EQUAL)){
- return false;
- }
- TNode left = n[0];
- TNode right = n[1];
- if(right.getKind() != CONST_RATIONAL){
- return false;
- }
- if(left.getMetaKind() == metakind::VARIABLE){
- return true;
- }else if(isBasicSum(left)){
- return true;
- }else{
- return false;
- }
+ Comparison parse = Comparison::parseNormalForm(n);
+ return parse.isNormalForm();
}
@@ -213,7 +202,6 @@ void TheoryArith::preRegisterTerm(TNode n) {
if(left.getKind() == PLUS){
//We may need to introduce a slack variable.
Assert(left.getNumChildren() >= 2);
- Assert(isBasicSum(left));
if(!left.hasAttribute(Slack())){
setupSlack(left);
}
@@ -229,11 +217,9 @@ void TheoryArith::setupSlack(TNode left){
left.setAttribute(Slack(), slack);
makeBasic(slack);
- Node slackEqLeft = NodeManager::currentNM()->mkNode(EQUAL,slack,left);
-
- Debug("slack") << "slack " << slackEqLeft << endl;
+ Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, slack, left);
- d_tableau.addRow(slackEqLeft);
+ d_tableau.addRow(eq);
setupVariable(slack);
}
@@ -316,56 +302,9 @@ DeltaRational TheoryArith::computeRowValueUsingSavedAssignment(TNode x){
}
RewriteResponse TheoryArith::preRewrite(TNode n, bool topLevel) {
- // ensure a hard link to the node we're returning
- Node out;
-
- // Look for multiplications with a 0 argument and rewrite the whole
- // thing as 0
- if(n.getKind() == MULT) {
- Rational ratZero;
- Integer intZero;
- for(TNode::iterator i = n.begin(); i != n.end(); ++i) {
- if((*i).getKind() == CONST_RATIONAL) {
- if((*i).getConst<Rational>() == ratZero) {
- out = NodeManager::currentNM()->mkConst(ratZero);
- break;
- }
- } else if((*i).getKind() == CONST_INTEGER) {
- if((*i).getConst<Integer>() == intZero) {
- if(n.getType().isInteger()) {
- out = NodeManager::currentNM()->mkConst(intZero);
- break;
- } else {
- out = NodeManager::currentNM()->mkConst(ratZero);
- break;
- }
- }
- }
- }
- } else if(n.getKind() == EQUAL) {
- if(n[0] == n[1]) {
- out = NodeManager::currentNM()->mkConst(true);
- }
- }
-
- if(out.isNull()) {
- // no preRewrite to perform
- return RewriteComplete(Node(n));
- } else {
- // out is always a constant, so doesn't need to be rewritten again
- return RewriteComplete(out);
- }
+ return d_nextRewriter.preRewrite(n);
}
-Node TheoryArith::rewrite(TNode n){
- Debug("arith") << "rewrite(" << n << ")" << endl;
-
- Node result = d_rewriter.rewrite(n);
- Debug("arith-rewrite") << "rewrite(" << n << ") -> " << result << endl;
- return result;
-}
-
-
void TheoryArith::registerTerm(TNode tn){
Debug("arith") << "registerTerm(" << tn << ")" << endl;
}
diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h
index 7367f5726..03be7a77b 100644
--- a/src/theory/arith/theory_arith.h
+++ b/src/theory/arith/theory_arith.h
@@ -28,7 +28,7 @@
#include "theory/arith/delta_rational.h"
#include "theory/arith/tableau.h"
-#include "theory/arith/arith_rewriter.h"
+#include "theory/arith/next_arith_rewriter.h"
#include "theory/arith/partial_model.h"
#include "theory/arith/arith_propagator.h"
@@ -94,7 +94,7 @@ private:
/**
* The rewriter module for arithmetic.
*/
- ArithRewriter d_rewriter;
+ NextArithRewriter d_nextRewriter;
ArithUnatePropagator d_propagator;
@@ -103,11 +103,6 @@ public:
~TheoryArith();
/**
- * Rewrites a node to a unique normal form given in normal_form_notes.txt
- */
- Node rewrite(TNode n);
-
- /**
* Rewriting optimizations.
*/
RewriteResponse preRewrite(TNode n, bool topLevel);
@@ -116,7 +111,7 @@ public:
* Plug in old rewrite to the new (pre,post)rewrite interface.
*/
RewriteResponse postRewrite(TNode n, bool topLevel) {
- return RewriteComplete(topLevel ? rewrite(n) : Node(n));
+ return d_nextRewriter.postRewrite(n);
}
/**
diff --git a/test/unit/theory/theory_arith_white.h b/test/unit/theory/theory_arith_white.h
index ea1ee698f..763e03fdb 100644
--- a/test/unit/theory/theory_arith_white.h
+++ b/test/unit/theory/theory_arith_white.h
@@ -65,6 +65,32 @@ public:
TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {}
+ void fakeTheoryEnginePreprocess(TNode inp){
+ Node rewrite = inp; //FIXME this needs to enforce that inp is fully rewritten already!
+
+ if(debug) cout << rewrite << inp << endl;
+
+ std::list<Node> toPreregister;
+
+ toPreregister.push_back(rewrite);
+ for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
+ Node n = *i;
+ preregistered->insert(n);
+
+ for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
+ Node c = *citer;
+ if(preregistered->find(c) == preregistered->end()){
+ toPreregister.push_back(c);
+ }
+ }
+ }
+ for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
+ Node n = *i;
+ if(debug) cout << n.getId() << " "<< n << endl;
+ d_arith->preRegisterTerm(n);
+ }
+ }
+
void setUp() {
d_ctxt = new Context;
d_nm = new NodeManager(d_ctxt);
@@ -92,42 +118,14 @@ public:
delete d_ctxt;
}
- Node fakeTheoryEnginePreprocess(TNode inp){
- Node rewrite = d_arith->rewrite(inp);
-
- if(debug) cout << rewrite << inp << endl;
-
- std::list<Node> toPreregister;
-
- toPreregister.push_back(rewrite);
- for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
- Node n = *i;
- preregistered->insert(n);
-
- for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
- Node c = *citer;
- if(preregistered->find(c) == preregistered->end()){
- toPreregister.push_back(c);
- }
- }
- }
- for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
- Node n = *i;
- if(debug) cout << n.getId() << " "<< n << endl;
- d_arith->preRegisterTerm(n);
- }
-
- return rewrite;
- }
-
void testAssert() {
Node x = d_nm->mkVar(*d_realType);
Node c = d_nm->mkConst<Rational>(d_zero);
Node leq = d_nm->mkNode(LEQ, x, c);
- Node rLeq = fakeTheoryEnginePreprocess(leq);
+ fakeTheoryEnginePreprocess(leq);
- d_arith->assertFact(rLeq);
+ d_arith->assertFact(leq);
d_arith->check(d_level);
@@ -143,51 +141,19 @@ public:
return dis;
}
- void testAssertEqualityEagerSplit() {
- Node x = d_nm->mkVar(*d_realType);
- Node c = d_nm->mkConst<Rational>(d_zero);
-
- Node eq = d_nm->mkNode(EQUAL, x, c);
- Node expectedDisjunct = simulateSplit(x,c);
-
- Node rEq = fakeTheoryEnginePreprocess(eq);
-
- d_arith->assertFact(rEq);
-
- d_arith->check(d_level);
-
- TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u);
-
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
- TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
-
- }
- void testLtRewrite() {
- Node x = d_nm->mkVar(*d_realType);
- Node c = d_nm->mkConst<Rational>(d_zero);
-
- Node lt = d_nm->mkNode(LT, x, c);
- Node geq = d_nm->mkNode(GEQ, x, c);
- Node expectedRewrite = d_nm->mkNode(NOT, geq);
-
- Node rewrite = d_arith->rewrite(lt);
-
- TS_ASSERT_EQUALS(expectedRewrite, rewrite);
- }
-
void testBasicConflict() {
Node x = d_nm->mkVar(*d_realType);
Node c = d_nm->mkConst<Rational>(d_zero);
Node eq = d_nm->mkNode(EQUAL, x, c);
- Node lt = d_nm->mkNode(LT, x, c);
+ Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c));
Node expectedDisjunct = simulateSplit(x,c);
- Node rEq = fakeTheoryEnginePreprocess(eq);
- Node rLt = fakeTheoryEnginePreprocess(lt);
+ fakeTheoryEnginePreprocess(eq);
+ fakeTheoryEnginePreprocess(lt);
- d_arith->assertFact(rEq);
- d_arith->assertFact(rLt);
+ d_arith->assertFact(eq);
+ d_arith->assertFact(lt);
d_arith->check(d_level);
@@ -198,7 +164,7 @@ public:
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT);
- Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt);
+ Node expectedClonflict = d_nm->mkNode(AND, eq, lt);
TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict);
}
@@ -208,13 +174,13 @@ public:
Node c = d_nm->mkConst<Rational>(d_zero);
Node eq = d_nm->mkNode(EQUAL, x, c);
- Node lt = d_nm->mkNode(LT, x, c);
+ Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c));
Node expectedDisjunct = simulateSplit(x,c);
- Node rEq = fakeTheoryEnginePreprocess(eq);
- Node rLt = fakeTheoryEnginePreprocess(lt);
+ fakeTheoryEnginePreprocess(eq);
+ fakeTheoryEnginePreprocess(lt);
- d_arith->assertFact(rEq);
+ d_arith->assertFact(eq);
d_arith->check(d_level);
@@ -236,29 +202,30 @@ public:
Node leq0 = d_nm->mkNode(LEQ, x, c0);
Node leq1 = d_nm->mkNode(LEQ, x, c1);
- Node lt1 = d_nm->mkNode(LT, x, c1);
+ Node geq1 = d_nm->mkNode(GEQ, x, c1);
+ Node lt1 = d_nm->mkNode(NOT, geq1);
- Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
- Node rLt1 = fakeTheoryEnginePreprocess(lt1);
- Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(leq0);
+ fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(geq1);
- d_arith->assertFact(rLt1);
+ d_arith->assertFact(lt1);
d_arith->check(d_level);
d_arith->propagate(d_level);
#ifdef CVC4_ASSERTIONS
- TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
- TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(lt1, d_level), AssertionException );
#endif
- d_arith->explain(rLeq1, d_level);
+ d_arith->explain(leq1, d_level);
TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION);
TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
+ TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1);
}
@@ -269,24 +236,25 @@ public:
Node leq0 = d_nm->mkNode(LEQ, x, c0);
Node leq1 = d_nm->mkNode(LEQ, x, c1);
- Node lt1 = d_nm->mkNode(LT, x, c1);
+ Node geq1 = d_nm->mkNode(GEQ, x, c1);
+ Node lt1 = d_nm->mkNode(NOT, geq1);
- Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
- Node rLt1 = fakeTheoryEnginePreprocess(lt1);
- Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(leq0);
+ fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(geq1);
- d_arith->assertFact(rLeq0);
+ d_arith->assertFact(leq0);
d_arith->check(d_level);
d_arith->propagate(d_level);
- d_arith->explain(rLt1, d_level);
+ d_arith->explain(lt1, d_level);
#ifdef CVC4_ASSERTIONS
- TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException );
#endif
- d_arith->explain(rLeq1, d_level);
+ d_arith->explain(leq1, d_level);
TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u);
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
@@ -294,12 +262,12 @@ public:
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION);
TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION);
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1);
+ TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1);
+ TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0);
- TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0);
+ TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), leq0);
+ TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), leq0);
}
void testTPLeq1() {
Node x = d_nm->mkVar(*d_realType);
@@ -308,22 +276,23 @@ public:
Node leq0 = d_nm->mkNode(LEQ, x, c0);
Node leq1 = d_nm->mkNode(LEQ, x, c1);
- Node lt1 = d_nm->mkNode(LT, x, c1);
+ Node geq1 = d_nm->mkNode(GEQ, x, c1);
+ Node lt1 = d_nm->mkNode(NOT, geq1);
- Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
- Node rLt1 = fakeTheoryEnginePreprocess(lt1);
- Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(leq0);
+ fakeTheoryEnginePreprocess(leq1);
+ fakeTheoryEnginePreprocess(geq1);
- d_arith->assertFact(rLeq1);
+ d_arith->assertFact(leq1);
d_arith->check(d_level);
d_arith->propagate(d_level);
#ifdef CVC4_ASSERTIONS
- TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException );
- TS_ASSERT_THROWS( d_arith->explain(rLeq1, d_level), AssertionException );
- TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(leq1, d_level), AssertionException );
+ TS_ASSERT_THROWS( d_arith->explain(lt1, d_level), AssertionException );
#endif
TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback