summaryrefslogtreecommitdiff
path: root/src/theory/arith/arith_rewriter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/arith/arith_rewriter.cpp')
-rw-r--r--src/theory/arith/arith_rewriter.cpp76
1 files changed, 67 insertions, 9 deletions
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
index e1cab0356..5aa904aed 100644
--- a/src/theory/arith/arith_rewriter.cpp
+++ b/src/theory/arith/arith_rewriter.cpp
@@ -222,36 +222,94 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
RewriteResponse ArithRewriter::preRewriteMult(TNode t){
Assert(t.getKind()== kind::MULT);
- // Rewrite multiplications with a 0 argument and to 0
- Rational qZero(0);
+ if(t.getNumChildren() == 2){
+ if(t[0].getKind() == kind::CONST_RATIONAL
+ && t[0].getConst<Rational>().isOne()){
+ return RewriteResponse(REWRITE_DONE, t[1]);
+ }
+ if(t[1].getKind() == kind::CONST_RATIONAL
+ && t[1].getConst<Rational>().isOne()){
+ return RewriteResponse(REWRITE_DONE, t[0]);
+ }
+ }
+ // Rewrite multiplications with a 0 argument and to 0
for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
if((*i).getKind() == kind::CONST_RATIONAL) {
- if((*i).getConst<Rational>() == qZero) {
- return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
+ if((*i).getConst<Rational>().isZero()) {
+ TNode zero = (*i);
+ return RewriteResponse(REWRITE_DONE, zero);
}
}
}
return RewriteResponse(REWRITE_DONE, t);
}
+
+static bool canFlatten(Kind k, TNode t){
+ for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
+ TNode child = *i;
+ if(child.getKind() == k){
+ return true;
+ }
+ }
+ return false;
+}
+
+static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
+ if(t.getKind() == k){
+ for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
+ TNode child = *i;
+ if(child.getKind() == k){
+ flatten(pb, k, child);
+ }else{
+ pb.push_back(child);
+ }
+ }
+ }else{
+ pb.push_back(t);
+ }
+}
+
+static Node flatten(Kind k, TNode t){
+ std::vector<TNode> pb;
+ flatten(pb, k, t);
+ Assert(pb.size() >= 2);
+ return NodeManager::currentNM()->mkNode(k, pb);
+}
+
RewriteResponse ArithRewriter::preRewritePlus(TNode t){
Assert(t.getKind()== kind::PLUS);
- return RewriteResponse(REWRITE_DONE, t);
+ if(canFlatten(kind::PLUS, t)){
+ return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
+ }else{
+ return RewriteResponse(REWRITE_DONE, t);
+ }
}
RewriteResponse ArithRewriter::postRewritePlus(TNode t){
Assert(t.getKind()== kind::PLUS);
- Polynomial res = Polynomial::mkZero();
+ std::vector<Monomial> monomials;
+ std::vector<Polynomial> polynomials;
for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
- Node curr = *i;
- Polynomial currPoly = Polynomial::parsePolynomial(curr);
+ TNode curr = *i;
+ if(Monomial::isMember(curr)){
+ monomials.push_back(Monomial::parseMonomial(curr));
+ }else{
+ polynomials.push_back(Polynomial::parsePolynomial(curr));
+ }
+ }
- res = res + currPoly;
+ if(!monomials.empty()){
+ Monomial::sort(monomials);
+ Monomial::combineAdjacentMonomials(monomials);
+ polynomials.push_back(Polynomial::mkPolynomial(monomials));
}
+ Polynomial res = Polynomial::sumPolynomials(polynomials);
+
return RewriteResponse(REWRITE_DONE, res.getNode());
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback