summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2018-03-27 11:53:49 -0500
committerGitHub <noreply@github.com>2018-03-27 11:53:49 -0500
commit9dcaaeba4880a8f145df00289ff1b092a7e3dd47 (patch)
tree97c6ffc45fd906e8f1b84726653cdac52bdc2f26
parent6a656809c353776c9de9580b19a6de79ef5a76d4 (diff)
Filter candidate rewrites based on matching (#1682)
-rw-r--r--src/theory/quantifiers/dynamic_rewrite.cpp14
-rw-r--r--src/theory/quantifiers/dynamic_rewrite.h4
-rw-r--r--src/theory/quantifiers/sygus_sampler.cpp280
-rw-r--r--src/theory/quantifiers/sygus_sampler.h133
4 files changed, 383 insertions, 48 deletions
diff --git a/src/theory/quantifiers/dynamic_rewrite.cpp b/src/theory/quantifiers/dynamic_rewrite.cpp
index 3462a4d10..cb7379910 100644
--- a/src/theory/quantifiers/dynamic_rewrite.cpp
+++ b/src/theory/quantifiers/dynamic_rewrite.cpp
@@ -66,6 +66,20 @@ bool DynamicRewriter::addRewrite(Node a, Node b)
return true;
}
+bool DynamicRewriter::areEqual(Node a, Node b)
+{
+ if (a == b)
+ {
+ return true;
+ }
+ // add to the equality engine
+ Node ai = toInternal(a);
+ Node bi = toInternal(b);
+ d_equalityEngine.addTerm(ai);
+ d_equalityEngine.addTerm(bi);
+ return d_equalityEngine.areEqual(ai, bi);
+}
+
Node DynamicRewriter::toInternal(Node a)
{
std::map<Node, Node>::iterator it = d_term_to_internal.find(a);
diff --git a/src/theory/quantifiers/dynamic_rewrite.h b/src/theory/quantifiers/dynamic_rewrite.h
index 2b5464151..388173829 100644
--- a/src/theory/quantifiers/dynamic_rewrite.h
+++ b/src/theory/quantifiers/dynamic_rewrite.h
@@ -63,6 +63,10 @@ class DynamicRewriter
* a = b based on the previous equalities it has seen.
*/
bool addRewrite(Node a, Node b);
+ /**
+ * Check whether this class knows that the equality a = b holds.
+ */
+ bool areEqual(Node a, Node b);
private:
/** pointer to the quantifiers engine */
diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp
index afbdc42e1..99494657f 100644
--- a/src/theory/quantifiers/sygus_sampler.cpp
+++ b/src/theory/quantifiers/sygus_sampler.cpp
@@ -678,6 +678,8 @@ void SygusSampler::registerSygusType(TypeNode tn)
}
}
+SygusSamplerExt::SygusSamplerExt() : d_ssenm(*this) {}
+
void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
Node f,
unsigned nsamples,
@@ -691,6 +693,8 @@ void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
ss << f;
d_drewrite =
std::unique_ptr<DynamicRewriter>(new DynamicRewriter(ss.str(), qe));
+ d_pairs.clear();
+ d_match_trie.clear();
}
Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
@@ -700,6 +704,7 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
<< std::endl;
if (eq_n == n)
{
+ // this is a unique term
return n;
}
Node bn = n;
@@ -709,63 +714,268 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
bn = d_tds->sygusToBuiltin(n);
beq_n = d_tds->sygusToBuiltin(eq_n);
}
- // one of eq_n or n must be ordered
- bool eqor = isOrdered(beq_n);
- bool nor = isOrdered(bn);
- Trace("sygus-synth-rr-debug")
- << "Ordered? : " << nor << " " << eqor << std::endl;
- bool isUnique = false;
- if (eqor || nor)
+ // whether we will keep this pair
+ bool keep = true;
+
+ // ----- check matchable
+ // check whether the pair is matchable with a previous one
+ d_curr_pair_rhs = beq_n;
+ Trace("sse-match") << "SSE check matches : " << n << " [rhs = " << eq_n
+ << "]..." << std::endl;
+ if (!d_match_trie.getMatches(bn, &d_ssenm))
{
- isUnique = true;
- // if only one is ordered, then the ordered one must contain the
- // free variables of the other
- if (!eqor)
- {
- isUnique = containsFreeVariables(bn, beq_n);
- }
- else if (!nor)
- {
- isUnique = containsFreeVariables(beq_n, bn);
- }
+ keep = false;
+ Trace("sygus-synth-rr-debug") << "...redundant (matchable)" << std::endl;
}
- Trace("sygus-synth-rr-debug") << "AlphaEq unique: " << isUnique << std::endl;
- bool rewRedundant = false;
+
+ // ----- check rewriting redundancy
if (d_drewrite != nullptr)
{
- Trace("sygus-synth-rr-debug") << "Add rewrite..." << std::endl;
+ Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl;
if (!d_drewrite->addRewrite(bn, beq_n))
{
- rewRedundant = isUnique;
// must be unique according to the dynamic rewriter
- isUnique = false;
+ keep = false;
+ Trace("sygus-synth-rr-debug") << "...redundant (rewritable)" << std::endl;
}
}
- Trace("sygus-synth-rr-debug") << "Rewrite unique: " << isUnique << std::endl;
- if (isUnique)
+ if (keep)
{
- // if the previous value stored was unordered, but this is
- // ordered, we prefer this one. Thus, we force its addition to the
- // sampler database.
- if (!eqor)
+ // add to match information
+ for (unsigned r = 0; r < 2; r++)
{
- SygusSampler::registerTerm(n, true);
+ Node t = r == 0 ? bn : beq_n;
+ Node to = r == 0 ? beq_n : bn;
+ // insert in match trie if first time
+ if (d_pairs.find(t) == d_pairs.end())
+ {
+ Trace("sse-match") << "SSE add term : " << t << std::endl;
+ d_match_trie.addTerm(t);
+ }
+ d_pairs[t].insert(to);
}
return eq_n;
}
else if (Trace.isOn("sygus-synth-rr"))
{
- Trace("sygus-synth-rr") << "Redundant rewrite : " << eq_n << " " << n;
- if (rewRedundant)
- {
- Trace("sygus-synth-rr") << " (by rewriting)";
- }
+ Trace("sygus-synth-rr") << "Redundant pair : " << eq_n << " " << n;
Trace("sygus-synth-rr") << std::endl;
}
return Node::null();
}
+bool SygusSamplerExt::notify(Node s,
+ Node n,
+ std::vector<Node>& vars,
+ std::vector<Node>& subs)
+{
+ Assert(!d_curr_pair_rhs.isNull());
+ std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
+ d_pairs.find(n);
+ if (Trace.isOn("sse-match"))
+ {
+ Trace("sse-match") << " " << s << " matches " << n
+ << " under:" << std::endl;
+ for (unsigned i = 0, size = vars.size(); i < size; i++)
+ {
+ Trace("sse-match") << " " << vars[i] << " -> " << subs[i] << std::endl;
+ }
+ }
+ Assert(it != d_pairs.end());
+ for (const Node& nr : it->second)
+ {
+ Node nrs =
+ nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+ bool areEqual = (nrs == d_curr_pair_rhs);
+ if (!areEqual && d_drewrite != nullptr)
+ {
+ // if dynamic rewriter is available, consult it
+ areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
+ }
+ if (areEqual)
+ {
+ Trace("sse-match") << "*** Match, current pair: " << std::endl;
+ Trace("sse-match") << " (" << s << ", " << d_curr_pair_rhs << ")"
+ << std::endl;
+ Trace("sse-match") << "is an instance of previous pair:" << std::endl;
+ Trace("sse-match") << " (" << n << ", " << nr << ")" << std::endl;
+ return false;
+ }
+ }
+ return true;
+}
+
+bool MatchTrie::getMatches(Node n, NotifyMatch* ntm)
+{
+ std::vector<Node> vars;
+ std::vector<Node> subs;
+ std::map<Node, Node> smap;
+
+ std::vector<std::vector<Node> > visit;
+ std::vector<MatchTrie*> visit_trie;
+ std::vector<int> visit_var_index;
+ std::vector<bool> visit_bound_var;
+
+ visit.push_back(std::vector<Node>{n});
+ visit_trie.push_back(this);
+ visit_var_index.push_back(-1);
+ visit_bound_var.push_back(false);
+ while (!visit.empty())
+ {
+ std::vector<Node> cvisit = visit.back();
+ MatchTrie* curr = visit_trie.back();
+ if (cvisit.empty())
+ {
+ Assert(n
+ == curr->d_data.substitute(
+ vars.begin(), vars.end(), subs.begin(), subs.end()));
+ Trace("sse-match-debug") << "notify : " << curr->d_data << std::endl;
+ if (!ntm->notify(n, curr->d_data, vars, subs))
+ {
+ return false;
+ }
+ visit.pop_back();
+ visit_trie.pop_back();
+ visit_var_index.pop_back();
+ visit_bound_var.pop_back();
+ }
+ else
+ {
+ Node cn = cvisit.back();
+ Trace("sse-match-debug")
+ << "traverse : " << cn << " at depth " << visit.size() << std::endl;
+ unsigned index = visit.size() - 1;
+ int vindex = visit_var_index[index];
+ if (vindex == -1)
+ {
+ if (!cn.isVar())
+ {
+ Node op = cn.hasOperator() ? cn.getOperator() : cn;
+ unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0;
+ std::map<unsigned, MatchTrie>::iterator itu =
+ curr->d_children[op].find(nchild);
+ if (itu != curr->d_children[op].end())
+ {
+ // recurse on the operator or self
+ cvisit.pop_back();
+ if (cn.hasOperator())
+ {
+ for (const Node& cnc : cn)
+ {
+ cvisit.push_back(cnc);
+ }
+ }
+ Trace("sse-match-debug") << "recurse op : " << op << std::endl;
+ visit.push_back(cvisit);
+ visit_trie.push_back(&itu->second);
+ visit_var_index.push_back(-1);
+ visit_bound_var.push_back(false);
+ }
+ }
+ visit_var_index[index]++;
+ }
+ else
+ {
+ // clean up if we previously bound a variable
+ if (visit_bound_var[index])
+ {
+ Assert(!vars.empty());
+ smap.erase(vars.back());
+ vars.pop_back();
+ subs.pop_back();
+ }
+
+ if (vindex == static_cast<int>(curr->d_vars.size()))
+ {
+ Trace("sse-match-debug")
+ << "finished checking " << curr->d_vars.size()
+ << " variables at depth " << visit.size() << std::endl;
+ // finished
+ visit.pop_back();
+ visit_trie.pop_back();
+ visit_var_index.pop_back();
+ visit_bound_var.pop_back();
+ }
+ else
+ {
+ Trace("sse-match-debug") << "check variable #" << vindex
+ << " at depth " << visit.size() << std::endl;
+ Assert(vindex < static_cast<int>(curr->d_vars.size()));
+ // recurse on variable?
+ Node var = curr->d_vars[vindex];
+ bool recurse = true;
+ // check if it is already bound
+ std::map<Node, Node>::iterator its = smap.find(var);
+ if (its != smap.end())
+ {
+ if (its->second != cn)
+ {
+ recurse = false;
+ }
+ }
+ else
+ {
+ vars.push_back(var);
+ subs.push_back(cn);
+ smap[var] = cn;
+ visit_bound_var[index] = true;
+ }
+ if (recurse)
+ {
+ Trace("sse-match-debug") << "recurse var : " << var << std::endl;
+ cvisit.pop_back();
+ visit.push_back(cvisit);
+ visit_trie.push_back(&curr->d_children[var][0]);
+ visit_var_index.push_back(-1);
+ visit_bound_var.push_back(false);
+ }
+ visit_var_index[index]++;
+ }
+ }
+ }
+ }
+ return true;
+}
+
+void MatchTrie::addTerm(Node n)
+{
+ std::vector<Node> visit;
+ visit.push_back(n);
+ MatchTrie* curr = this;
+ while (!visit.empty())
+ {
+ Node cn = visit.back();
+ visit.pop_back();
+ if (cn.hasOperator())
+ {
+ curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]);
+ for (const Node& cnc : cn)
+ {
+ visit.push_back(cnc);
+ }
+ }
+ else
+ {
+ if (cn.isVar()
+ && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
+ == curr->d_vars.end())
+ {
+ curr->d_vars.push_back(cn);
+ }
+ curr = &(curr->d_children[cn][0]);
+ }
+ }
+ curr->d_data = n;
+}
+
+void MatchTrie::clear()
+{
+ d_children.clear();
+ d_vars.clear();
+ d_data = Node::null();
+}
+
} /* CVC4::theory::quantifiers namespace */
} /* CVC4::theory namespace */
} /* CVC4 namespace */
diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h
index 4bc10075d..fa0d670d2 100644
--- a/src/theory/quantifiers/sygus_sampler.h
+++ b/src/theory/quantifiers/sygus_sampler.h
@@ -340,10 +340,60 @@ class SygusSampler : public LazyTrieEvaluator
void registerSygusType(TypeNode tn);
};
+/** A virtual class for notifications regarding matches. */
+class NotifyMatch
+{
+ public:
+ /**
+ * A notification that s is equal to n * { vars -> subs }. This function
+ * should return false if we do not wish to be notified of further matches.
+ */
+ virtual bool notify(Node s,
+ Node n,
+ std::vector<Node>& vars,
+ std::vector<Node>& subs) = 0;
+};
+
+/**
+ * A trie (discrimination tree) storing a set of terms S, that can be used to
+ * query, for a given term t, all terms from S that are matchable with t.
+ */
+class MatchTrie
+{
+ public:
+ /** Get matches
+ *
+ * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this
+ * trie that is matchable with n where s = n * { vars -> subs } for some
+ * vars, subs. This function returns false if one of these calls to notify
+ * returns false.
+ */
+ bool getMatches(Node n, NotifyMatch* ntm);
+ /** Adds node n to this trie */
+ void addTerm(Node n);
+ /** Clear this trie */
+ void clear();
+
+ private:
+ /**
+ * The children of this node in the trie. Terms t are indexed by a
+ * depth-first (right to left) traversal on its subterms, where the
+ * top-symbol of t is indexed by:
+ * - (operator, #children) if t has an operator, or
+ * - (t, 0) if t does not have an operator.
+ */
+ std::map<Node, std::map<unsigned, MatchTrie> > d_children;
+ /** The set of variables in the domain of d_children */
+ std::vector<Node> d_vars;
+ /** The data of this node in the trie */
+ Node d_data;
+};
+
/** Version of the above class with some additional features */
class SygusSamplerExt : public SygusSampler
{
public:
+ SygusSamplerExt();
/** initialize extended */
void initializeSygusExt(QuantifiersEngine* qe,
Node f,
@@ -351,31 +401,88 @@ class SygusSamplerExt : public SygusSampler
bool useSygusType);
/** register term n with this sampler database
*
+ * For each call to registerTerm( t, ... ) that returns s, we say that
+ * (t,s) and (s,t) are "relevant pairs".
+ *
* This returns either null, or a term ret with the same guarantees as
* SygusSampler::registerTerm with the additional guarantee
- * that for all ret' returned by a previous call to registerTerm( n' ),
- * we have that n = ret is not alpha-equivalent to n' = ret'
+ * that for all previous relevant pairs ( n', nret' ),
+ * we have that n = ret is not an instance of n' = ret'
* modulo symmetry of equality, nor is n = ret derivable from the set of
- * all previous input/output pairs based on the d_drewrite utility.
- * For example,
- * (t+0), t and (s+0), s
- * will not both be input/output pairs of this function since t+0=t is
- * alpha-equivalent to s+0=s.
- * s, t and s+0, t+0
- * will not both be input/output pairs of this function since s+0=t+0 is
+ * all previous relevant pairs. The latter is determined by the d_drewrite
+ * utility. For example:
+ * [1] ( t+0, t ) and ( x+0, x )
+ * will not both be relevant pairs of this function since t+0=t is
+ * an instance of x+0=x.
+ * [2] ( s, t ) and ( s+0, t+0 )
+ * will not both be relevant pairs of this function since s+0=t+0 is
* derivable from s=t.
+ * These two criteria may be combined, for example:
+ * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s )
+ * are relevant pairs, since t+0 is an instance of x+0 where
+ * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair
+ * above (t+s = s).
*
* If this function returns null, then n is equivalent to a previously
- * registered term ret, and the equality n = ret is either alpha-equivalent
- * to a previous input/output pair n' = ret', or n = ret is derivable
- * from the set of all previous input/output pairs based on the
- * d_drewrite utility.
+ * registered term ret, and the equality ( n, ret ) is either an instance
+ * of a previous relevant pair ( n', ret' ), or n = ret is derivable
+ * from the set of all previous relevant pairs based on the
+ * d_drewrite utility, or is an instance of a previous pair
*/
Node registerTerm(Node n, bool forceKeep = false) override;
private:
/** dynamic rewriter class */
std::unique_ptr<DynamicRewriter> d_drewrite;
+
+ //----------------------------match filtering
+ /**
+ * Stores all relevant pairs returned by this sampler (see registerTerm). In
+ * detail, if (t,s) is a relevant pair, then t in d_pairs[s].
+ */
+ std::map<Node, std::unordered_set<Node, NodeHashFunction> > d_pairs;
+ /** Match trie storing all terms in the domain of d_pairs. */
+ MatchTrie d_match_trie;
+ /** Notify class */
+ class SygusSamplerExtNotifyMatch : public NotifyMatch
+ {
+ SygusSamplerExt& d_sse;
+
+ public:
+ SygusSamplerExtNotifyMatch(SygusSamplerExt& sse) : d_sse(sse) {}
+ /** notify match */
+ bool notify(Node n,
+ Node s,
+ std::vector<Node>& vars,
+ std::vector<Node>& subs) override
+ {
+ return d_sse.notify(n, s, vars, subs);
+ }
+ };
+ /** Notify object used for reporting matches from d_match_trie */
+ SygusSamplerExtNotifyMatch d_ssenm;
+ /**
+ * Stores the current right hand side of a pair we are considering.
+ *
+ * In more detail, in registerTerm, we are interested in whether a pair (s,t)
+ * is a relevant pair. We do this by:
+ * (1) Setting the node d_curr_pair_rhs to t,
+ * (2) Using d_match_trie, compute all terms s1...sn that match s.
+ * For each si, where s = si * sigma for some substitution sigma, we check
+ * whether t = ti * sigma for some previously relevant pair (si,ti), in
+ * which case (s,t) is an instance of (si,ti).
+ */
+ Node d_curr_pair_rhs;
+ /**
+ * Called by the above class during d_match_trie.getMatches( s ), when we
+ * find that si = s * sigma, where si is a term that is stored in
+ * d_match_trie.
+ *
+ * This function returns false if ( s, d_curr_pair_rhs ) is an instance of
+ * previously relevant pair.
+ */
+ bool notify(Node s, Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+ //----------------------------end match filtering
};
} /* CVC4::theory::quantifiers namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback