summaryrefslogtreecommitdiff
path: root/src/theory/quantifiers/sygus_inst.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/quantifiers/sygus_inst.cpp')
-rw-r--r--src/theory/quantifiers/sygus_inst.cpp221
1 files changed, 213 insertions, 8 deletions
diff --git a/src/theory/quantifiers/sygus_inst.cpp b/src/theory/quantifiers/sygus_inst.cpp
index f9a6456e1..4192ca746 100644
--- a/src/theory/quantifiers/sygus_inst.cpp
+++ b/src/theory/quantifiers/sygus_inst.cpp
@@ -29,10 +29,142 @@ namespace CVC4 {
namespace theory {
namespace quantifiers {
+namespace {
+
+/**
+ * Collect maximal ground terms with type tn in node n.
+ *
+ * @param n: Node to traverse.
+ * @param tn: Collects only terms with type tn.
+ * @param terms: Collected terms.
+ * @param cache: Caches visited nodes.
+ * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
+ */
+void getMaxGroundTerms(TNode n,
+ TypeNode tn,
+ std::unordered_set<Node, NodeHashFunction>& terms,
+ std::unordered_set<TNode, TNodeHashFunction>& cache,
+ bool skip_quant = false)
+{
+ if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
+ && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
+ {
+ return;
+ }
+
+ Trace("sygus-inst-term") << "Find maximal terms with type " << tn
+ << " in: " << n << std::endl;
+
+ Node cur;
+ std::vector<TNode> visit;
+
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+
+ if (cache.find(cur) != cache.end())
+ {
+ continue;
+ }
+ cache.insert(cur);
+
+ if (expr::hasBoundVar(cur) || cur.getType() != tn)
+ {
+ if (!skip_quant || cur.getKind() != kind::FORALL)
+ {
+ visit.insert(visit.end(), cur.begin(), cur.end());
+ }
+ }
+ else
+ {
+ terms.insert(cur);
+ Trace("sygus-inst-term") << " found: " << cur << std::endl;
+ }
+ } while (!visit.empty());
+}
+
+/*
+ * Collect minimal ground terms with type tn in node n.
+ *
+ * @param n: Node to traverse.
+ * @param tn: Collects only terms with type tn.
+ * @param terms: Collected terms.
+ * @param cache: Caches visited nodes and flags indicating whether a minimal
+ * term was already found in a subterm.
+ * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
+ */
+void getMinGroundTerms(
+ TNode n,
+ TypeNode tn,
+ std::unordered_set<Node, NodeHashFunction>& terms,
+ std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
+ bool skip_quant = false)
+{
+ if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
+ && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
+ {
+ return;
+ }
+
+ Trace("sygus-inst-term") << "Find minimal terms with type " << tn
+ << " in: " << n << std::endl;
+
+ Node cur;
+ std::vector<TNode> visit;
+
+ visit.push_back(n);
+ do
+ {
+ cur = visit.back();
+ visit.pop_back();
+
+ auto it = cache.find(cur);
+ if (it == cache.end())
+ {
+ cache.emplace(cur, std::make_pair(false, false));
+ if (!skip_quant || cur.getKind() != kind::FORALL)
+ {
+ visit.push_back(cur);
+ visit.insert(visit.end(), cur.begin(), cur.end());
+ }
+ }
+ /* up-traversal */
+ else if (!it->second.first)
+ {
+ bool found_min_term = false;
+
+ /* Check if we found a minimal term in one of the children. */
+ for (const Node& c : cur)
+ {
+ found_min_term |= cache[c].second;
+ if (found_min_term) break;
+ }
+
+ /* If we haven't found a minimal term yet, add this term if it has the
+ * right type. */
+ if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
+ {
+ terms.insert(cur);
+ found_min_term = true;
+ Trace("sygus-inst-term") << " found: " << cur << std::endl;
+ }
+
+ it->second.first = true;
+ it->second.second = found_min_term;
+ }
+ } while (!visit.empty());
+}
+
+} // namespace
+
SygusInst::SygusInst(QuantifiersEngine* qe)
: QuantifiersModule(qe),
d_lemma_cache(qe->getUserContext()),
- d_ce_lemma_added(qe->getUserContext())
+ d_ce_lemma_added(qe->getUserContext()),
+ d_global_terms(qe->getUserContext()),
+ d_notified_assertions(qe->getUserContext())
{
}
@@ -149,14 +281,79 @@ void SygusInst::registerQuantifier(Node q)
std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
std::unordered_set<Node, NodeHashFunction> term_irrelevant;
- /* Collect extra symbols in 'q' to be used in the grammar. */
- std::unordered_set<Node, NodeHashFunction> syms;
- expr::getSymbols(q, syms);
- for (const TNode& var : syms)
+ /* Collect relevant local ground terms for each variable type. */
+ if (options::sygusInstScope() == options::SygusInstScope::IN
+ || options::sygusInstScope() == options::SygusInstScope::BOTH)
+ {
+ std::unordered_map<TypeNode,
+ std::unordered_set<Node, NodeHashFunction>,
+ TypeNodeHashFunction>
+ relevant_terms;
+ for (const Node& var : q[0])
+ {
+ TypeNode tn = var.getType();
+
+ /* Collect relevant ground terms for type tn. */
+ if (relevant_terms.find(tn) == relevant_terms.end())
+ {
+ std::unordered_set<Node, NodeHashFunction> terms;
+ std::unordered_set<TNode, TNodeHashFunction> cache_max;
+ std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
+ cache_min;
+
+ getMinGroundTerms(q, tn, terms, cache_min);
+ getMaxGroundTerms(q, tn, terms, cache_max);
+ relevant_terms.emplace(tn, terms);
+ }
+
+ /* Add relevant ground terms to grammar. */
+ auto& terms = relevant_terms[tn];
+ for (const auto& t : terms)
+ {
+ TypeNode ttn = t.getType();
+ extra_cons[ttn].insert(t);
+ Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
+ }
+ }
+ }
+
+ /* Collect relevant global ground terms for each variable type. */
+ if (options::sygusInstScope() == options::SygusInstScope::OUT
+ || options::sygusInstScope() == options::SygusInstScope::BOTH)
{
- TypeNode tn = var.getType();
- extra_cons[tn].insert(var);
- Trace("sygus-inst") << "Found symbol: " << var << std::endl;
+ for (const Node& var : q[0])
+ {
+ TypeNode tn = var.getType();
+
+ /* Collect relevant ground terms for type tn. */
+ if (d_global_terms.find(tn) == d_global_terms.end())
+ {
+ std::unordered_set<Node, NodeHashFunction> terms;
+ std::unordered_set<TNode, TNodeHashFunction> cache_max;
+ std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
+ cache_min;
+
+ for (const Node& a : d_notified_assertions)
+ {
+ getMinGroundTerms(a, tn, terms, cache_min, true);
+ getMaxGroundTerms(a, tn, terms, cache_max, true);
+ }
+ d_global_terms.insert(tn, terms);
+ }
+
+ /* Add relevant ground terms to grammar. */
+ auto it = d_global_terms.find(tn);
+ if (it != d_global_terms.end())
+ {
+ for (const auto& t : (*it).second)
+ {
+ TypeNode ttn = t.getType();
+ extra_cons[ttn].insert(t);
+ Trace("sygus-inst")
+ << "Adding (global) extra cons: " << t << std::endl;
+ }
+ }
+ }
}
/* Construct grammar for each bound variable of 'q'. */
@@ -190,6 +387,14 @@ void SygusInst::preRegisterQuantifier(Node q)
addCeLemma(q);
}
+void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
+{
+ for (const Node& a : assertions)
+ {
+ d_notified_assertions.insert(a);
+ }
+}
+
/*****************************************************************************/
/* private methods */
/*****************************************************************************/
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback