diff options
Diffstat (limited to 'src/theory/arith/arith_rewriter.cpp')
-rw-r--r-- | src/theory/arith/arith_rewriter.cpp | 64 |
1 files changed, 15 insertions, 49 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 689f231e6..a367b8599 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -68,6 +68,11 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){ Assert(t.getKind()== kind::UMINUS); + if(t[0].getKind() == kind::CONST_RATIONAL){ + Rational neg = -(t[0].getConst<Rational>()); + return RewriteResponse(REWRITE_DONE, mkRationalNode(neg)); + } + Node noUminus = makeUnaryMinusNode(t[0]); if(pre) return RewriteResponse(REWRITE_DONE, noUminus); @@ -87,9 +92,8 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ case kind::UMINUS: return rewriteUMinus(t, true); case kind::DIVISION: - return rewriteDiv(t,true); case kind::DIVISION_TOTAL: - return rewriteDivTotal(t,true); + return rewriteDiv(t,true); case kind::PLUS: return preRewritePlus(t); case kind::MULT: @@ -116,9 +120,8 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ case kind::UMINUS: return rewriteUMinus(t, false); case kind::DIVISION: - return rewriteDiv(t, false); case kind::DIVISION_TOTAL: - return rewriteDivTotal(t, false); + return rewriteDiv(t, false); case kind::PLUS: return postRewritePlus(t); case kind::MULT: @@ -260,51 +263,9 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ return diff; } -RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ - Assert(t.getKind()== kind::DIVISION); - - Node left = t[0]; - Node right = t[1]; - - if(right.getKind() == kind::CONST_RATIONAL && - left.getKind() != kind::CONST_RATIONAL){ - - const Rational& den = right.getConst<Rational>(); - - Assert(!den.isZero()); - - Rational div = den.inverse(); - Node result = mkRationalNode(div); - Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); - if(pre){ - return RewriteResponse(REWRITE_DONE, mult); - }else{ - return RewriteResponse(REWRITE_AGAIN, mult); - } - } - - if(pre){ - if(right.getKind() != kind::CONST_RATIONAL || - left.getKind() != kind::CONST_RATIONAL){ - return RewriteResponse(REWRITE_DONE, t); - } - } - - Assert(right.getKind() == kind::CONST_RATIONAL); - Assert(left.getKind() == kind::CONST_RATIONAL); - - const Rational& den = right.getConst<Rational>(); - - Assert(!den.isZero()); - - const Rational& num = left.getConst<Rational>(); - Rational div = num / den; - Node result = mkRationalNode(div); - return RewriteResponse(REWRITE_DONE, result); -} -RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){ - Assert(t.getKind() == kind::DIVISION_TOTAL); +RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ + Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION); Node left = t[0]; @@ -313,7 +274,12 @@ RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){ const Rational& den = right.getConst<Rational>(); if(den.isZero()){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); + if(t.getKind() == kind::DIVISION_TOTAL){ + return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); + }else{ + // This is unsupported, but this is not a good place to complain + return RewriteResponse(REWRITE_DONE, t); + } } Assert(den != Rational(0)); |