summaryrefslogtreecommitdiff
path: root/src/theory/arith/arith_rewriter.cpp
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2012-11-11 00:28:05 +0000
committerTim King <taking@cs.nyu.edu>2012-11-11 00:28:05 +0000
commit341794b1cbd5693010c78b9f5bfe232ee90404b0 (patch)
treeb03c9a0d39050cf0fb5dbfe7393435adc7c5de19 /src/theory/arith/arith_rewriter.cpp
parenta4bebb3ec1e27b433b63dcb2b82f6385e0c40561 (diff)
Fixes for the arithmetic normal form and rewriter to handle arbitrary constants for total functions.
Diffstat (limited to 'src/theory/arith/arith_rewriter.cpp')
-rw-r--r--src/theory/arith/arith_rewriter.cpp232
1 files changed, 138 insertions, 94 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
index b6275ba24..689f231e6 100644
--- a/src/theory/arith/arith_rewriter.cpp
+++ b/src/theory/arith/arith_rewriter.cpp
@@ -80,58 +80,28 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
return rewriteConstant(t);
}else if(t.isVar()){
return rewriteVariable(t);
- }else if(t.getKind() == kind::MINUS){
- return rewriteMinus(t, true);
- }else if(t.getKind() == kind::UMINUS){
- return rewriteUMinus(t, true);
- }else if(t.getKind() == kind::DIVISION){
- return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
- }else if(t.getKind() == kind::DIVISION_TOTAL){
- if(t[1].getKind()== kind::CONST_RATIONAL &&
- t[1].getConst<Rational>().isZero()){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
- }
- }else if(t.getKind() == kind::PLUS){
- return preRewritePlus(t);
- }else if(t.getKind() == kind::MULT){
- return preRewriteMult(t);
- }else if(t.getKind() == kind::INTS_DIVISION){
- Rational intOne(1);
- if(t[1].getKind()== kind::CONST_RATIONAL &&
- t[1].getConst<Rational>().isOne()){
- return RewriteResponse(REWRITE_AGAIN, t[0]);
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
- }else if(t.getKind() == kind::INTS_DIVISION_TOTAL){
- if(t[1].getKind()== kind::CONST_RATIONAL){
- Rational intOne(1), intZero(0);
- if(t[1].getConst<Rational>().isOne()){
- return RewriteResponse(REWRITE_AGAIN, t[0]);
- } else if(t[1].getConst<Rational>().isZero()){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }
- }
- return RewriteResponse(REWRITE_DONE, t);
- }else if(t.getKind() == kind::INTS_MODULUS){
- Rational intOne(1);
- if(t[1].getKind()== kind::CONST_RATIONAL &&
- t[1].getConst<Rational>().isOne()){
- return RewriteResponse(REWRITE_AGAIN, mkRationalNode(0));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
- }else if(t.getKind() == kind::INTS_MODULUS_TOTAL){
- if(t[1].getKind()== kind::CONST_RATIONAL){
- if(t[1].getConst<Rational>().isOne() || t[1].getConst<Rational>().isZero()){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }
- }
- return RewriteResponse(REWRITE_DONE, t);
}else{
- Unreachable();
+ switch(t.getKind()){
+ case kind::MINUS:
+ return rewriteMinus(t, true);
+ case kind::UMINUS:
+ return rewriteUMinus(t, true);
+ case kind::DIVISION:
+ return rewriteDiv(t,true);
+ case kind::DIVISION_TOTAL:
+ return rewriteDivTotal(t,true);
+ case kind::PLUS:
+ return preRewritePlus(t);
+ case kind::MULT:
+ return preRewriteMult(t);
+ //case kind::INTS_DIVISION:
+ //case kind::INTS_MODULUS:
+ case kind::INTS_DIVISION_TOTAL:
+ case kind::INTS_MODULUS_TOTAL:
+ return rewriteIntsDivModTotal(t,true);
+ default:
+ Unreachable();
+ }
}
}
RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
@@ -139,33 +109,32 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
return rewriteConstant(t);
}else if(t.isVar()){
return rewriteVariable(t);
- }else if(t.getKind() == kind::MINUS){
- return rewriteMinus(t, false);
- }else if(t.getKind() == kind::UMINUS){
- return rewriteUMinus(t, false);
- }else if(t.getKind() == kind::DIVISION ||
- t.getKind() == kind::DIVISION_TOTAL){
- return rewriteDiv(t, false);
- }else if(t.getKind() == kind::PLUS){
- return postRewritePlus(t);
- }else if(t.getKind() == kind::MULT){
- return postRewriteMult(t);
- }else if(t.getKind() == kind::INTS_DIVISION ||
- t.getKind() == kind::INTS_MODULUS){
- return RewriteResponse(REWRITE_DONE, t);
- }else if(t.getKind() == kind::INTS_DIVISION_TOTAL ||
- t.getKind() == kind::INTS_MODULUS_TOTAL){
- if(t[1].getKind() == kind::CONST_RATIONAL &&
- t[1].getConst<Rational>().isZero()){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
}else{
- Unreachable();
+ switch(t.getKind()){
+ case kind::MINUS:
+ return rewriteMinus(t, false);
+ case kind::UMINUS:
+ return rewriteUMinus(t, false);
+ case kind::DIVISION:
+ return rewriteDiv(t, false);
+ case kind::DIVISION_TOTAL:
+ return rewriteDivTotal(t, false);
+ case kind::PLUS:
+ return postRewritePlus(t);
+ case kind::MULT:
+ return postRewriteMult(t);
+ //case kind::INTS_DIVISION:
+ //case kind::INTS_MODULUS:
+ case kind::INTS_DIVISION_TOTAL:
+ case kind::INTS_MODULUS_TOTAL:
+ return rewriteIntsDivModTotal(t, false);
+ default:
+ Unreachable();
+ }
}
}
+
RewriteResponse ArithRewriter::preRewriteMult(TNode t){
Assert(t.getKind()== kind::MULT);
@@ -217,19 +186,6 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
return RewriteResponse(REWRITE_DONE, res.getNode());
}
-// RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
-// TNode left = t[0];
-// TNode right = t[1];
-
-// Polynomial pLeft = Polynomial::parsePolynomial(left);
-
-
-// Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
-
-// Assert(cmp.isNormalForm());
-// return RewriteResponse(REWRITE_DONE, cmp.getNode());
-// }
-
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
// left |><| right
TNode left = atom[0];
@@ -304,9 +260,52 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
return diff;
}
-
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
- Assert(t.getKind()== kind::DIVISION || t.getKind() == kind::DIVISION_TOTAL);
+ Assert(t.getKind()== kind::DIVISION);
+
+ Node left = t[0];
+ Node right = t[1];
+
+ if(right.getKind() == kind::CONST_RATIONAL &&
+ left.getKind() != kind::CONST_RATIONAL){
+
+ const Rational& den = right.getConst<Rational>();
+
+ Assert(!den.isZero());
+
+ Rational div = den.inverse();
+ Node result = mkRationalNode(div);
+ Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+ if(pre){
+ return RewriteResponse(REWRITE_DONE, mult);
+ }else{
+ return RewriteResponse(REWRITE_AGAIN, mult);
+ }
+ }
+
+ if(pre){
+ if(right.getKind() != kind::CONST_RATIONAL ||
+ left.getKind() != kind::CONST_RATIONAL){
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ }
+
+ Assert(right.getKind() == kind::CONST_RATIONAL);
+ Assert(left.getKind() == kind::CONST_RATIONAL);
+
+ const Rational& den = right.getConst<Rational>();
+
+ Assert(!den.isZero());
+
+ const Rational& num = left.getConst<Rational>();
+ Rational div = num / den;
+ Node result = mkRationalNode(div);
+ return RewriteResponse(REWRITE_DONE, result);
+}
+
+RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){
+ Assert(t.getKind() == kind::DIVISION_TOTAL);
+
Node left = t[0];
Node right = t[1];
@@ -314,14 +313,17 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
const Rational& den = right.getConst<Rational>();
if(den.isZero()){
- if(t.getKind() == kind::DIVISION_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
+ return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
}
Assert(den != Rational(0));
+ if(left.getKind() == kind::CONST_RATIONAL){
+ const Rational& num = left.getConst<Rational>();
+ Rational div = num / den;
+ Node result = mkRationalNode(div);
+ return RewriteResponse(REWRITE_DONE, result);
+ }
+
Rational div = den.inverse();
Node result = mkRationalNode(div);
@@ -337,6 +339,48 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
}
}
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+ Kind k = t.getKind();
+ // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
+ // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+ //Leaving the function as before (INTS_MODULUS can be handled),
+ // but restricting its use here
+ Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
+ TNode n = t[0], d = t[1];
+ bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
+ if(dIsConstant && d.getConst<Rational>().isZero()){
+ if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
+ return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+ }else{
+ // Do nothing for k == INTS_MODULUS
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ }else if(dIsConstant && d.getConst<Rational>().isOne()){
+ if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+ return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+ }else{
+ Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+ return RewriteResponse(REWRITE_AGAIN, n);
+ }
+ }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
+ Assert(d.getConst<Rational>().isIntegral());
+ Assert(n.getConst<Rational>().isIntegral());
+ Assert(!d.getConst<Rational>().isZero());
+ Integer di = d.getConst<Rational>().getNumerator();
+ Integer ni = n.getConst<Rational>().getNumerator();
+
+ bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+ Integer result = isDiv ? ni.floorDivideQuotient(di) : ni.floorDivideRemainder(di);
+
+ Node resultNode = mkRationalNode(Rational(result));
+ return RewriteResponse(REWRITE_DONE, resultNode);
+ }else{
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+}
+
}/* CVC4::theory::arith namespace */
}/* CVC4::theory namespace */
}/* CVC4 namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback