diff options
Diffstat (limited to 'src/theory/arith/arith_rewriter.cpp')
-rw-r--r-- | src/theory/arith/arith_rewriter.cpp | 135 |
1 files changed, 124 insertions, 11 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 4684ec4a3..57428d209 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -101,7 +101,12 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return preRewritePlus(t); case kind::MULT: case kind::NONLINEAR_MULT: - return preRewriteMult(t); + return preRewriteMult(t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + return preRewriteTranscendental(t); case kind::INTS_DIVISION: case kind::INTS_MODULUS: return RewriteResponse(REWRITE_DONE, t); @@ -126,6 +131,8 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return RewriteResponse(REWRITE_DONE, t[0]); case kind::POW: return RewriteResponse(REWRITE_DONE, t); + case kind::PI: + return RewriteResponse(REWRITE_DONE, t); default: Unhandled(k); } @@ -150,7 +157,12 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ return postRewritePlus(t); case kind::MULT: case kind::NONLINEAR_MULT: - return postRewriteMult(t); + return postRewriteMult(t); + case kind::EXPONENTIAL: + case kind::SINE: + case kind::COSINE: + case kind::TANGENT: + return postRewriteTranscendental(t); case kind::INTS_DIVISION: case kind::INTS_MODULUS: return RewriteResponse(REWRITE_DONE, t); @@ -197,15 +209,21 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ if(exp.sgn() == 0){ return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1))); }else if(exp.sgn() > 0 && exp.isIntegral()){ - Integer num = exp.getNumerator(); - NodeBuilder<> nb(kind::MULT); - Integer one(1); - for(Integer i(0); i < num; i = i + one){ - nb << base; + CVC4::Rational r(INT_MAX); + if( exp<r ){ + unsigned num = exp.getNumerator().toUnsignedInt(); + if( num==1 ){ + return RewriteResponse(REWRITE_AGAIN, base); + }else{ + NodeBuilder<> nb(kind::MULT); + for(unsigned i=0; i < num; ++i){ + nb << base; + } + Assert(nb.getNumChildren() > 0); + Node mult = nb; + return RewriteResponse(REWRITE_AGAIN, mult); + } } - Assert(nb.getNumChildren() > 0); - Node mult = nb; - return RewriteResponse(REWRITE_AGAIN, mult); } } @@ -216,6 +234,8 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ ss << " " << t; throw LogicException(ss.str()); } + case kind::PI: + return RewriteResponse(REWRITE_DONE, t); default: Unreachable(); } @@ -332,6 +352,100 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ return RewriteResponse(REWRITE_DONE, res.getNode()); } + +RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) { + return RewriteResponse(REWRITE_DONE, t); +} + +RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { + Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl; + switch( t.getKind() ){ + case kind::EXPONENTIAL: { + if(t[0].getKind() == kind::CONST_RATIONAL){ + Node one = NodeManager::currentNM()->mkConst(Rational(1)); + if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){ + return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::POW, NodeManager::currentNM()->mkNode( kind::EXPONENTIAL, one ), t[0])); + }else{ + return RewriteResponse(REWRITE_DONE, t); + } + }else if(t[0].getKind() == kind::PLUS ){ + std::vector<Node> product; + for( unsigned i=0; i<t[0].getNumChildren(); i++ ){ + product.push_back( NodeManager::currentNM()->mkNode( kind::EXPONENTIAL, t[0][i] ) ); + } + return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::MULT, product)); + } + } + break; + case kind::SINE: + if(t[0].getKind() == kind::CONST_RATIONAL){ + const Rational& rat = t[0].getConst<Rational>(); + if(rat.sgn() == 0){ + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(0))); + } + }else{ + Node pi_factor; + Node pi; + if( t[0].getKind()==kind::PI ){ + pi_factor = NodeManager::currentNM()->mkConst(Rational(1)); + pi = t[0]; + }else if( t[0].getKind()==kind::MULT && t[0][0].isConst() && t[0][1].getKind()==kind::PI ){ + pi_factor = t[0][0]; + pi = t[0][1]; + } + if( !pi_factor.isNull() ){ + Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl; + Rational r = pi_factor.getConst<Rational>(); + Rational ra = r.abs(); + Rational rone = Rational(1); + Node ntwo = NodeManager::currentNM()->mkConst( Rational(2) ); + if( ra > rone ){ + //add/substract 2*pi beyond scope + Node ra_div_two = NodeManager::currentNM()->mkNode( kind::INTS_DIVISION, NodeManager::currentNM()->mkConst( ra + rone ), ntwo ); + Node new_pi_factor; + if( r.sgn()==1 ){ + new_pi_factor = NodeManager::currentNM()->mkNode( kind::MINUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) ); + }else{ + Assert( r.sgn()==-1 ); + new_pi_factor = NodeManager::currentNM()->mkNode( kind::PLUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) ); + } + return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode( kind::SINE, + NodeManager::currentNM()->mkNode( kind::MULT, new_pi_factor, pi ) ) ); + }else if( ra == rone ){ + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(0))); + }else{ + Integer one = Integer(1); + Integer two = Integer(2); + Integer six = Integer(6); + if( ra.getDenominator()==two ){ + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst( Rational( r.sgn() ) ) ); + }else if( ra.getDenominator()==six ){ + Integer five = Integer(5); + if( ra.getNumerator()==one || ra.getNumerator()==five ){ + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst( Rational( r.sgn() )/Rational(2) ) ); + } + } + } + } + } + break; + case kind::COSINE: { + return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode( kind::SINE, + NodeManager::currentNM()->mkNode( kind::MINUS, + NodeManager::currentNM()->mkNode( kind::MULT, + NodeManager::currentNM()->mkConst( Rational(1)/Rational(2) ), + NodeManager::currentNM()->mkNullaryOperator( NodeManager::currentNM()->realType(), kind::PI ) ), + t[0] ) ) ); + } break; + case kind::TANGENT: + return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode(kind::DIVISION, NodeManager::currentNM()->mkNode( kind::SINE, t[0] ), + NodeManager::currentNM()->mkNode( kind::COSINE, t[0] ) )); + default: + break; + } + return RewriteResponse(REWRITE_DONE, t); +} + RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ if(atom.getKind() == kind::IS_INTEGER) { if(atom[0].isConst()) { @@ -440,7 +554,6 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION); - Node left = t[0]; Node right = t[1]; if(right.getKind() == kind::CONST_RATIONAL){ |