From b9a903cc9a13c7bcdd334eb38730e62858321f07 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Sun, 12 Apr 2020 15:10:31 -0500 Subject: Fixes for extended rewriter (#4278) Fixes #4273 and fixes #4274 . This also removes a spurious assertion from the Node::substitute method that the result node is not equal to the domain. This is violated for f(f(x)) { f(x) -> x }. --- src/expr/node.h | 4 +- src/theory/quantifiers/extended_rewrite.cpp | 54 ++++++++++++++++++---- src/theory/quantifiers/extended_rewrite.h | 2 + .../quantifiers/sygus/term_database_sygus.cpp | 5 +- 4 files changed, 52 insertions(+), 13 deletions(-) (limited to 'src') diff --git a/src/expr/node.h b/src/expr/node.h index e07499b69..8ded28f07 100644 --- a/src/expr/node.h +++ b/src/expr/node.h @@ -1348,7 +1348,8 @@ NodeTemplate::substitute(TNode node, TNode replacement, std::unordered_map& cache) const { Assert(node != *this); - if (getNumChildren() == 0) { + if (getNumChildren() == 0 || node == replacement) + { return *this; } @@ -1382,7 +1383,6 @@ NodeTemplate::substitute(TNode node, TNode replacement, // put in cache Node n = nb; - Assert(node != n); cache[*this] = n; return n; } diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index b0a474c56..1f42c384f 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -14,7 +14,6 @@ #include "theory/quantifiers/extended_rewrite.h" -#include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" #include "theory/bv/theory_bv_utils.h" #include "theory/datatypes/datatypes_rewriter.h" @@ -34,6 +33,11 @@ struct ExtRewriteAttributeId }; typedef expr::Attribute ExtRewriteAttribute; +struct ExtRewriteAggAttributeId +{ +}; +typedef expr::Attribute ExtRewriteAggAttribute; + ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) { d_true = NodeManager::currentNM()->mkConst(true); @@ -42,8 +46,35 @@ ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr) void ExtendedRewriter::setCache(Node n, Node ret) { - ExtRewriteAttribute era; - n.setAttribute(era, ret); + if (d_aggr) + { + ExtRewriteAggAttribute erga; + n.setAttribute(erga, ret); + } + else + { + ExtRewriteAttribute era; + n.setAttribute(era, ret); + } +} + +Node ExtendedRewriter::getCache(Node n) +{ + if (d_aggr) + { + if (n.hasAttribute(ExtRewriteAggAttribute())) + { + return n.getAttribute(ExtRewriteAggAttribute()); + } + } + else + { + if (n.hasAttribute(ExtRewriteAttribute())) + { + return n.getAttribute(ExtRewriteAttribute()); + } + } + return Node::null(); } bool ExtendedRewriter::addToChildren(Node nc, @@ -63,15 +94,12 @@ bool ExtendedRewriter::addToChildren(Node nc, Node ExtendedRewriter::extendedRewrite(Node n) { n = Rewriter::rewrite(n); - if (!options::sygusExtRew()) - { - return n; - } // has it already been computed? - if (n.hasAttribute(ExtRewriteAttribute())) + Node ncache = getCache(n); + if (!ncache.isNull()) { - return n.getAttribute(ExtRewriteAttribute()); + return ncache; } Node ret = n; @@ -1226,6 +1254,7 @@ Node ExtendedRewriter::extendedRewriteEqChain( } Node c = cp.first; Kind ck = c.getKind(); + Trace("ext-rew-eqchain") << " process c = " << c << std::endl; if (ck == andk || ck == ork) { for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++) @@ -1233,9 +1262,12 @@ Node ExtendedRewriter::extendedRewriteEqChain( Node cl = c[j]; bool pol = cl.getKind() != notk; Node ca = pol ? cl : cl[0]; + bool newVal = (ck == andk ? !pol : pol); + Trace("ext-rew-eqchain") + << " atoms(" << c << ", " << ca << ") = " << newVal << std::endl; Assert(atoms[c].find(ca) == atoms[c].end()); // polarity is flipped when we are AND - atoms[c][ca] = (ck == andk ? !pol : pol); + atoms[c][ca] = newVal; alist[c].push_back(ca); // if this already exists as a child of the equality chain, eliminate. @@ -1284,6 +1316,8 @@ Node ExtendedRewriter::extendedRewriteEqChain( bool pol = ck != notk; Node ca = pol ? c : c[0]; atoms[c][ca] = pol; + Trace("ext-rew-eqchain") + << " atoms(" << c << ", " << ca << ") = " << pol << std::endl; alist[c].push_back(ca); } atom_count.push_back(std::pair(c, alist[c].size())); diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index 836e15b7b..9a0ab6382 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -69,6 +69,8 @@ class ExtendedRewriter Node d_false; /** cache that the extended rewritten form of n is ret */ void setCache(Node n, Node ret); + /** get the cache for n */ + Node getCache(Node n); /** add to children * * Adds nc to the vector of children, if dropDup is true, we do not add diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index a1b250142..ee028bff0 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -1018,7 +1018,10 @@ Node TermDbSygus::evaluateWithUnfolding( if( childChanged ){ ret = NodeManager::currentNM()->mkNode( ret.getKind(), children ); } - ret = getExtRewriter()->extendedRewrite(ret); + if (options::sygusExtRew()) + { + ret = getExtRewriter()->extendedRewrite(ret); + } // use rewriting, possibly involving recursive functions ret = rewriteNode(ret); } -- cgit v1.2.3