summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2019-12-06 13:12:12 -0600
committerAndres Noetzli <andres.noetzli@gmail.com>2019-12-06 11:12:12 -0800
commitec865a83596fd1285e033426b80ddfc1c35085cd (patch)
tree07a73efec0083b288bd94c39b301019400b92f3e
parent30e5875e066e917b69d01189233aec26ce226cd6 (diff)
Optimize the rewriter for DT_SYGUS_EVAL (#3529)
This makes it so that we don't construct intermediate unfoldings of applications of DT_SYGUS_EVAL, which wastes time in node construction. It makes the sygusToBuiltin utility in TermDbSygus use this implementation.
-rw-r--r--src/theory/datatypes/datatypes_rewriter.cpp26
-rw-r--r--src/theory/datatypes/theory_datatypes_utils.cpp195
-rw-r--r--src/theory/datatypes/theory_datatypes_utils.h34
-rw-r--r--src/theory/quantifiers/sygus/term_database_sygus.cpp5
4 files changed, 233 insertions, 27 deletions
diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp
index be4226f69..080306d39 100644
--- a/src/theory/datatypes/datatypes_rewriter.cpp
+++ b/src/theory/datatypes/datatypes_rewriter.cpp
@@ -120,34 +120,16 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
if (ev.getKind() == APPLY_CONSTRUCTOR)
{
Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
- const Datatype& dt = ev.getType().getDatatype();
- unsigned i = utils::indexOf(ev.getOperator());
- Node op = Node::fromExpr(dt[i].getSygusOp());
- // if it is the "any constant" constructor, return its argument
- if (op.getAttribute(SygusAnyConstAttribute()))
- {
- Assert(ev.getNumChildren() == 1);
- Assert(ev[0].getType().isComparableTo(in.getType()));
- return RewriteResponse(REWRITE_AGAIN_FULL, ev[0]);
- }
+ Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
std::vector<Node> args;
for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
{
args.push_back(in[j]);
}
- Assert(!dt.isParametric());
- std::vector<Node> children;
- for (const Node& evc : ev)
- {
- std::vector<Node> cc;
- cc.push_back(evc);
- cc.insert(cc.end(), args.begin(), args.end());
- children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc));
- }
- Node ret = utils::mkSygusTerm(dt, i, children);
- // apply the appropriate substitution
- ret = utils::applySygusArgs(dt, op, ret, args);
+ Node ret = utils::sygusToBuiltinEval(ev, args);
Trace("dt-sygus-util") << "...got " << ret << "\n";
+ Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
+ Assert(in.getType().isComparableTo(ret.getType()));
return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
}
diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp
index 43d23b523..d2833a852 100644
--- a/src/theory/datatypes/theory_datatypes_utils.cpp
+++ b/src/theory/datatypes/theory_datatypes_utils.cpp
@@ -18,6 +18,7 @@
#include "expr/node_algorithm.h"
#include "expr/sygus_datatype.h"
+#include "theory/evaluator.h"
using namespace CVC4;
using namespace CVC4::kind;
@@ -384,6 +385,200 @@ bool checkClash(Node n1, Node n2, std::vector<Node>& rew)
return false;
}
+struct SygusToBuiltinTermAttributeId
+{
+};
+typedef expr::Attribute<SygusToBuiltinTermAttributeId, Node>
+ SygusToBuiltinTermAttribute;
+
+Node sygusToBuiltin(Node n)
+{
+ Assert(n.isConst());
+ std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+ std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+ std::vector<TNode> visit;
+ TNode cur;
+ unsigned index;
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ it = visited.find(cur);
+ if (it == visited.end())
+ {
+ if (cur.getKind() == APPLY_CONSTRUCTOR)
+ {
+ if (cur.hasAttribute(SygusToBuiltinTermAttribute()))
+ {
+ visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute());
+ }
+ else
+ {
+ visited[cur] = Node::null();
+ visit.push_back(cur);
+ for (const Node& cn : cur)
+ {
+ visit.push_back(cn);
+ }
+ }
+ }
+ else
+ {
+ // non-datatypes are themselves
+ visited[cur] = cur;
+ }
+ }
+ else if (it->second.isNull())
+ {
+ Node ret = cur;
+ Assert(cur.getKind() == APPLY_CONSTRUCTOR);
+ const Datatype& dt = cur.getType().getDatatype();
+ // Non sygus-datatype terms are also themselves. Notice we treat the
+ // case of non-sygus datatypes this way since it avoids computing
+ // the type / datatype of the node in the pre-traversal above. The
+ // case of non-sygus datatypes is very rare, so the extra addition to
+ // visited is justified performance-wise.
+ if (dt.isSygus())
+ {
+ std::vector<Node> children;
+ for (const Node& cn : cur)
+ {
+ it = visited.find(cn);
+ Assert(it != visited.end());
+ Assert(!it->second.isNull());
+ children.push_back(it->second);
+ }
+ index = indexOf(cur.getOperator());
+ ret = mkSygusTerm(dt, index, children);
+ }
+ visited[cur] = ret;
+ // cache
+ SygusToBuiltinTermAttribute stbt;
+ cur.setAttribute(stbt, ret);
+ }
+ } while (!visit.empty());
+ Assert(visited.find(n) != visited.end());
+ Assert(!visited.find(n)->second.isNull());
+ return visited[n];
+}
+
+Node sygusToBuiltinEval(Node n, const std::vector<Node>& args)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ Evaluator eval;
+ // constant arguments?
+ bool constArgs = true;
+ for (const Node& a : args)
+ {
+ if (!a.isConst())
+ {
+ constArgs = false;
+ break;
+ }
+ }
+ std::vector<Node> eargs;
+ bool svarsInit = false;
+ std::vector<Node> svars;
+ std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+ std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+ std::vector<TNode> visit;
+ TNode cur;
+ unsigned index;
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+ it = visited.find(cur);
+ if (it == visited.end())
+ {
+ TypeNode tn = cur.getType();
+ if (!tn.isDatatype() || !tn.getDatatype().isSygus())
+ {
+ visited[cur] = cur;
+ }
+ else if (cur.isConst())
+ {
+ // convert to builtin term
+ Node bt = sygusToBuiltin(cur);
+ // run the evaluator if possible
+ if (!svarsInit)
+ {
+ svarsInit = true;
+ TypeNode tn = cur.getType();
+ Node varList = Node::fromExpr(tn.getDatatype().getSygusVarList());
+ for (const Node& v : varList)
+ {
+ svars.push_back(v);
+ }
+ }
+ Assert(args.size() == svars.size());
+ // try evaluation if we have constant arguments
+ Node ret = constArgs ? eval.eval(bt, svars, args) : Node::null();
+ if (ret.isNull())
+ {
+ // if evaluation was not available, use a substitution
+ ret = bt.substitute(
+ svars.begin(), svars.end(), args.begin(), args.end());
+ }
+ visited[cur] = ret;
+ }
+ else
+ {
+ if (cur.getKind() == APPLY_CONSTRUCTOR)
+ {
+ visited[cur] = Node::null();
+ visit.push_back(cur);
+ for (const Node& cn : cur)
+ {
+ visit.push_back(cn);
+ }
+ }
+ else
+ {
+ // it is the evaluation of this term on the arguments
+ if (eargs.empty())
+ {
+ eargs.push_back(cur);
+ eargs.insert(eargs.end(), args.begin(), args.end());
+ }
+ else
+ {
+ eargs[0] = cur;
+ }
+ visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs);
+ }
+ }
+ }
+ else if (it->second.isNull())
+ {
+ Node ret = cur;
+ Assert(cur.getKind() == APPLY_CONSTRUCTOR);
+ const Datatype& dt = cur.getType().getDatatype();
+ // non sygus-datatype terms are also themselves
+ if (dt.isSygus())
+ {
+ std::vector<Node> children;
+ for (const Node& cn : cur)
+ {
+ it = visited.find(cn);
+ Assert(it != visited.end());
+ Assert(!it->second.isNull());
+ children.push_back(it->second);
+ }
+ index = indexOf(cur.getOperator());
+ // apply to arguments
+ ret = mkSygusTerm(dt, index, children);
+ }
+ visited[cur] = ret;
+ }
+ } while (!visit.empty());
+ Assert(visited.find(n) != visited.end());
+ Assert(!visited.find(n)->second.isNull());
+ return visited[n];
+}
+
} // namespace utils
} // namespace datatypes
} // namespace theory
diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h
index 5f74a4bee..46a6d56be 100644
--- a/src/theory/datatypes/theory_datatypes_utils.h
+++ b/src/theory/datatypes/theory_datatypes_utils.h
@@ -185,12 +185,36 @@ Node applySygusArgs(const Datatype& dt,
Node op,
Node n,
const std::vector<Node>& args);
-/**
- * Get the builtin sygus operator for constructor term n of sygus datatype
- * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus
- * constructor whose sygus op is the builtin operator +, this method returns +.
+/** Sygus to builtin
+ *
+ * This method converts a constant term of SyGuS datatype type to its builtin
+ * equivalent. For example, given input C_*( C_x(), C_y() ), this method returns
+ * x*y, assuming C_+, C_x, and C_y have sygus operators *, x, and y
+ * respectively.
+ */
+Node sygusToBuiltin(Node c);
+/** Sygus to builtin eval
+ *
+ * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that
+ * n does not necessarily need to be a constant.
+ *
+ * It does so by (1) converting constant subterms of n to builtin terms and
+ * evaluating them on the arguments args, (2) unfolding non-constant
+ * applications of sygus constructors in n with respect to args and (3)
+ * converting all other non-constant subterms of n to applications of
+ * DT_SYGUS_EVAL.
+ *
+ * For example, if
+ * n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 }
+ * where n' is a variable, then this method returns:
+ * 12 + (DT_SYGUS_EVAL n' 3 4)
+ * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin
+ * equivalent x*y and evaluated under the substition { x -> 3, x -> 4 } giving
+ * 12. The subterm n' is non-constant and thus we return its evaluation under
+ * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level constructor
+ * is C_+, these terms are added together to give the result.
*/
-Node getSygusOpForCTerm(Node n);
+Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
// ------------------------ end sygus utils
diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp
index d664a462d..c5ea0f9f3 100644
--- a/src/theory/quantifiers/sygus/term_database_sygus.cpp
+++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp
@@ -277,6 +277,11 @@ typedef expr::Attribute<SygusToBuiltinAttributeId, Node>
Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn)
{
+ if (n.isConst())
+ {
+ // if its a constant, we use the datatype utility version
+ return datatypes::utils::sygusToBuiltin(n);
+ }
Assert(n.getType().isComparableTo(tn));
if (!tn.isDatatype())
{
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback