diff options
Diffstat (limited to 'src/theory/quantifiers/sygus_inst.cpp')
-rw-r--r-- | src/theory/quantifiers/sygus_inst.cpp | 221 |
1 files changed, 213 insertions, 8 deletions
diff --git a/src/theory/quantifiers/sygus_inst.cpp b/src/theory/quantifiers/sygus_inst.cpp index f9a6456e1..4192ca746 100644 --- a/src/theory/quantifiers/sygus_inst.cpp +++ b/src/theory/quantifiers/sygus_inst.cpp @@ -29,10 +29,142 @@ namespace CVC4 { namespace theory { namespace quantifiers { +namespace { + +/** + * Collect maximal ground terms with type tn in node n. + * + * @param n: Node to traverse. + * @param tn: Collects only terms with type tn. + * @param terms: Collected terms. + * @param cache: Caches visited nodes. + * @param skip_quant: Do not traverse quantified formulas (skip quantifiers). + */ +void getMaxGroundTerms(TNode n, + TypeNode tn, + std::unordered_set<Node, NodeHashFunction>& terms, + std::unordered_set<TNode, TNodeHashFunction>& cache, + bool skip_quant = false) +{ + if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX + && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH) + { + return; + } + + Trace("sygus-inst-term") << "Find maximal terms with type " << tn + << " in: " << n << std::endl; + + Node cur; + std::vector<TNode> visit; + + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + + if (cache.find(cur) != cache.end()) + { + continue; + } + cache.insert(cur); + + if (expr::hasBoundVar(cur) || cur.getType() != tn) + { + if (!skip_quant || cur.getKind() != kind::FORALL) + { + visit.insert(visit.end(), cur.begin(), cur.end()); + } + } + else + { + terms.insert(cur); + Trace("sygus-inst-term") << " found: " << cur << std::endl; + } + } while (!visit.empty()); +} + +/* + * Collect minimal ground terms with type tn in node n. + * + * @param n: Node to traverse. + * @param tn: Collects only terms with type tn. + * @param terms: Collected terms. + * @param cache: Caches visited nodes and flags indicating whether a minimal + * term was already found in a subterm. + * @param skip_quant: Do not traverse quantified formulas (skip quantifiers). + */ +void getMinGroundTerms( + TNode n, + TypeNode tn, + std::unordered_set<Node, NodeHashFunction>& terms, + std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache, + bool skip_quant = false) +{ + if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN + && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH) + { + return; + } + + Trace("sygus-inst-term") << "Find minimal terms with type " << tn + << " in: " << n << std::endl; + + Node cur; + std::vector<TNode> visit; + + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + + auto it = cache.find(cur); + if (it == cache.end()) + { + cache.emplace(cur, std::make_pair(false, false)); + if (!skip_quant || cur.getKind() != kind::FORALL) + { + visit.push_back(cur); + visit.insert(visit.end(), cur.begin(), cur.end()); + } + } + /* up-traversal */ + else if (!it->second.first) + { + bool found_min_term = false; + + /* Check if we found a minimal term in one of the children. */ + for (const Node& c : cur) + { + found_min_term |= cache[c].second; + if (found_min_term) break; + } + + /* If we haven't found a minimal term yet, add this term if it has the + * right type. */ + if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term) + { + terms.insert(cur); + found_min_term = true; + Trace("sygus-inst-term") << " found: " << cur << std::endl; + } + + it->second.first = true; + it->second.second = found_min_term; + } + } while (!visit.empty()); +} + +} // namespace + SygusInst::SygusInst(QuantifiersEngine* qe) : QuantifiersModule(qe), d_lemma_cache(qe->getUserContext()), - d_ce_lemma_added(qe->getUserContext()) + d_ce_lemma_added(qe->getUserContext()), + d_global_terms(qe->getUserContext()), + d_notified_assertions(qe->getUserContext()) { } @@ -149,14 +281,79 @@ void SygusInst::registerQuantifier(Node q) std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons; std::unordered_set<Node, NodeHashFunction> term_irrelevant; - /* Collect extra symbols in 'q' to be used in the grammar. */ - std::unordered_set<Node, NodeHashFunction> syms; - expr::getSymbols(q, syms); - for (const TNode& var : syms) + /* Collect relevant local ground terms for each variable type. */ + if (options::sygusInstScope() == options::SygusInstScope::IN + || options::sygusInstScope() == options::SygusInstScope::BOTH) + { + std::unordered_map<TypeNode, + std::unordered_set<Node, NodeHashFunction>, + TypeNodeHashFunction> + relevant_terms; + for (const Node& var : q[0]) + { + TypeNode tn = var.getType(); + + /* Collect relevant ground terms for type tn. */ + if (relevant_terms.find(tn) == relevant_terms.end()) + { + std::unordered_set<Node, NodeHashFunction> terms; + std::unordered_set<TNode, TNodeHashFunction> cache_max; + std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction> + cache_min; + + getMinGroundTerms(q, tn, terms, cache_min); + getMaxGroundTerms(q, tn, terms, cache_max); + relevant_terms.emplace(tn, terms); + } + + /* Add relevant ground terms to grammar. */ + auto& terms = relevant_terms[tn]; + for (const auto& t : terms) + { + TypeNode ttn = t.getType(); + extra_cons[ttn].insert(t); + Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl; + } + } + } + + /* Collect relevant global ground terms for each variable type. */ + if (options::sygusInstScope() == options::SygusInstScope::OUT + || options::sygusInstScope() == options::SygusInstScope::BOTH) { - TypeNode tn = var.getType(); - extra_cons[tn].insert(var); - Trace("sygus-inst") << "Found symbol: " << var << std::endl; + for (const Node& var : q[0]) + { + TypeNode tn = var.getType(); + + /* Collect relevant ground terms for type tn. */ + if (d_global_terms.find(tn) == d_global_terms.end()) + { + std::unordered_set<Node, NodeHashFunction> terms; + std::unordered_set<TNode, TNodeHashFunction> cache_max; + std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction> + cache_min; + + for (const Node& a : d_notified_assertions) + { + getMinGroundTerms(a, tn, terms, cache_min, true); + getMaxGroundTerms(a, tn, terms, cache_max, true); + } + d_global_terms.insert(tn, terms); + } + + /* Add relevant ground terms to grammar. */ + auto it = d_global_terms.find(tn); + if (it != d_global_terms.end()) + { + for (const auto& t : (*it).second) + { + TypeNode ttn = t.getType(); + extra_cons[ttn].insert(t); + Trace("sygus-inst") + << "Adding (global) extra cons: " << t << std::endl; + } + } + } } /* Construct grammar for each bound variable of 'q'. */ @@ -190,6 +387,14 @@ void SygusInst::preRegisterQuantifier(Node q) addCeLemma(q); } +void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions) +{ + for (const Node& a : assertions) + { + d_notified_assertions.insert(a); + } +} + /*****************************************************************************/ /* private methods */ /*****************************************************************************/ |