diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2019-12-11 21:24:43 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-11 21:24:43 -0600 |
commit | a9cfdf6710e6a1bc4dd49bf09263fd8bce1af6b5 (patch) | |
tree | 3f2bef344eea37a871f31599df7170b89373d2f6 /src | |
parent | d803e7fcf60f9bb847853fe6ccf7589b94b76922 (diff) |
Fix CEGIS refinement for recursive functions evaluation (#3555)
Diffstat (limited to 'src')
-rw-r--r-- | src/theory/quantifiers/sygus/term_database_sygus.cpp | 40 |
1 files changed, 13 insertions, 27 deletions
diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index c5ea0f9f3..08fb58e40 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -735,6 +735,11 @@ SygusTypeInfo& TermDbSygus::getTypeInfo(TypeNode tn) Node TermDbSygus::rewriteNode(Node n) const { Node res = Rewriter::rewrite(n); + if (res.isConst()) + { + // constant, we are done + return res; + } if (options::sygusRecFun()) { if (d_funDefEval->hasDefinitions()) @@ -1006,34 +1011,13 @@ Node TermDbSygus::evaluateWithUnfolding( { if (ret == n && ret[0].isConst()) { - Trace("dt-eval-unfold-debug") - << "Optimize: evaluate constant head " << ret << std::endl; - // can just do direct evaluation here - // notice we prefer this code to the rewriter since it may use - // the evaluator - std::vector<Node> args; - bool success = true; - for (unsigned i = 1, nchild = ret.getNumChildren(); i < nchild; i++) - { - if (!ret[i].isConst()) - { - success = false; - break; - } - args.push_back(ret[i]); - } - if (success) - { - TypeNode rt = ret[0].getType(); - Node bret = sygusToBuiltin(ret[0], rt); - Node rete = evaluateBuiltin(rt, bret, args); - visited[n] = rete; - Trace("dt-eval-unfold-debug") - << "Return " << rete << " for " << n << std::endl; - return rete; - } + // use rewriting, possibly involving recursive functions + ret = rewriteNode(ret); + } + else + { + ret = d_eval_unfold->unfold(ret); } - ret = d_eval_unfold->unfold(ret); } if( ret.getNumChildren()>0 ){ std::vector< Node > children; @@ -1050,6 +1034,8 @@ Node TermDbSygus::evaluateWithUnfolding( ret = NodeManager::currentNM()->mkNode( ret.getKind(), children ); } ret = getExtRewriter()->extendedRewrite(ret); + // use rewriting, possibly involving recursive functions + ret = rewriteNode(ret); } visited[n] = ret; return ret; |