diff options
Diffstat (limited to 'src/theory/quantifiers/sygus_sampler.cpp')
-rw-r--r-- | src/theory/quantifiers/sygus_sampler.cpp | 187 |
1 files changed, 180 insertions, 7 deletions
diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index 0b8f390f3..82720dd5e 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -14,6 +14,7 @@ #include "theory/quantifiers/sygus_sampler.h" +#include "options/quantifiers_options.h" #include "util/bitvector.h" #include "util/random.h" @@ -80,6 +81,9 @@ void SygusSampler::initialize(TypeNode tn, d_var_index[sv] = d_type_vars[svt].size(); d_type_vars[svt].push_back(sv); } + d_rvalue_cindices.clear(); + d_rvalue_null_cindices.clear(); + d_var_sygus_types.clear(); initializeSamples(nsamples); } @@ -109,6 +113,10 @@ void SygusSampler::initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples) d_type_vars[svt].push_back(sv); } } + d_rvalue_cindices.clear(); + d_rvalue_null_cindices.clear(); + d_var_sygus_types.clear(); + registerSygusType(d_ftn); initializeSamples(nsamples); } @@ -123,28 +131,90 @@ void SygusSampler::initializeSamples(unsigned nsamples) Trace("sygus-sample") << " var #" << types.size() << " : " << v << " : " << vt << std::endl; } + std::map<unsigned, std::map<Node, std::vector<TypeNode> >::iterator> sts; + if (options::sygusSampleGrammar()) + { + for (unsigned j = 0, size = types.size(); j < size; j++) + { + sts[j] = d_var_sygus_types.find(d_vars[j]); + } + } + + unsigned nduplicates = 0; for (unsigned i = 0; i < nsamples; i++) { std::vector<Node> sample_pt; - Trace("sygus-sample") << "Sample point #" << i << " : "; for (unsigned j = 0, size = types.size(); j < size; j++) { - Node r = getRandomValue(types[j]); + Node v = d_vars[j]; + Node r; + if (options::sygusSampleGrammar()) + { + // choose a random start sygus type, if possible + if (sts[j] != d_var_sygus_types.end()) + { + unsigned ntypes = sts[j]->second.size(); + Assert(ntypes > 0); + unsigned index = Random::getRandom().pick(0, ntypes - 1); + if (index < ntypes) + { + // currently hard coded to 0.0, 0.5 + r = getSygusRandomValue(sts[j]->second[index], 0.0, 0.5); + } + } + } if (r.isNull()) { - Trace("sygus-sample") << "INVALID"; - d_is_valid = false; + r = getRandomValue(types[j]); + if (r.isNull()) + { + d_is_valid = false; + } } - Trace("sygus-sample") << r << " "; sample_pt.push_back(r); } - Trace("sygus-sample") << std::endl; - d_samples.push_back(sample_pt); + if (d_samples_trie.add(sample_pt)) + { + if (Trace.isOn("sygus-sample")) + { + Trace("sygus-sample") << "Sample point #" << i << " : "; + for (const Node& r : sample_pt) + { + Trace("sygus-sample") << r << " "; + } + Trace("sygus-sample") << std::endl; + } + d_samples.push_back(sample_pt); + } + else + { + i--; + nduplicates++; + if (nduplicates == nsamples * 10) + { + Trace("sygus-sample") + << "...WARNING: excessive duplicates, cut off sampling at " << i + << "/" << nsamples << " points." << std::endl; + break; + } + } } d_trie.clear(); } +bool SygusSampler::PtTrie::add(std::vector<Node>& pt) +{ + PtTrie* curr = this; + for (unsigned i = 0, size = pt.size(); i < size; i++) + { + curr = &(curr->d_children[pt[i]]); + } + bool retVal = curr->d_children.empty(); + curr = &(curr->d_children[Node::null()]); + return retVal; +} + Node SygusSampler::registerTerm(Node n, bool forceKeep) { if (d_is_valid) @@ -389,6 +459,109 @@ Node SygusSampler::getRandomValue(TypeNode tn) return Node::null(); } +Node SygusSampler::getSygusRandomValue(TypeNode tn, + double rchance, + double rinc, + unsigned depth) +{ + Assert(tn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Assert(dt.isSygus()); + Assert(d_rvalue_cindices.find(tn) != d_rvalue_cindices.end()); + Trace("sygus-sample-grammar") << "Sygus random value " << tn + << ", depth = " << depth + << ", rchance = " << rchance << std::endl; + // check if we terminate on this call + // we refuse to enumerate terms of 10+ depth as a hard limit + bool terminate = Random::getRandom().pickWithProb(rchance) || depth >= 10; + // if we terminate, only nullary constructors can be chosen + std::vector<unsigned>& cindices = + terminate ? d_rvalue_null_cindices[tn] : d_rvalue_cindices[tn]; + unsigned ncons = cindices.size(); + // select a random constructor, or random value when index=ncons. + unsigned index = Random::getRandom().pick(0, ncons); + Trace("sygus-sample-grammar") << "Random index 0..." << ncons + << " was : " << index << std::endl; + if (index < ncons) + { + Trace("sygus-sample-grammar") << "Recurse constructor index #" << index + << std::endl; + unsigned cindex = cindices[index]; + Assert(cindex < dt.getNumConstructors()); + const DatatypeConstructor& dtc = dt[cindex]; + // more likely to terminate in recursive calls + double rchance_new = rchance + (1.0 - rchance) * rinc; + std::map<int, Node> pre; + bool success = true; + // generate random values for all arguments + for (unsigned i = 0, nargs = dtc.getNumArgs(); i < nargs; i++) + { + TypeNode tnc = d_tds->getArgType(dtc, i); + Node c = getSygusRandomValue(tnc, rchance_new, rinc, depth + 1); + if (c.isNull()) + { + success = false; + Trace("sygus-sample-grammar") << "...fail." << std::endl; + break; + } + Trace("sygus-sample-grammar") << " child #" << i << " : " << c + << std::endl; + pre[i] = c; + } + if (success) + { + Trace("sygus-sample-grammar") << "mkGeneric" << std::endl; + Node ret = d_tds->mkGeneric(dt, cindex, pre); + Trace("sygus-sample-grammar") << "...returned " << ret << std::endl; + ret = Rewriter::rewrite(ret); + Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl; + Assert(ret.isConst()); + return ret; + } + } + Trace("sygus-sample-grammar") << "...resort to random value" << std::endl; + // if we did not generate based on the grammar, pick a random value + return getRandomValue(TypeNode::fromType(dt.getSygusType())); +} + +// recursion depth bounded by number of types in grammar (small) +void SygusSampler::registerSygusType(TypeNode tn) +{ + if (d_rvalue_cindices.find(tn) == d_rvalue_cindices.end()) + { + d_rvalue_cindices[tn].clear(); + Assert(tn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Assert(dt.isSygus()); + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + const DatatypeConstructor& dtc = dt[i]; + Node sop = Node::fromExpr(dtc.getSygusOp()); + bool isVar = std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end(); + if (isVar) + { + // if it is a variable, add it to the list of sygus types for that var + d_var_sygus_types[sop].push_back(tn); + } + else + { + // otherwise, it is a constructor for sygus random value + d_rvalue_cindices[tn].push_back(i); + if (dtc.getNumArgs() == 0) + { + d_rvalue_null_cindices[tn].push_back(i); + } + } + // recurse on all subfields + for (unsigned j = 0, nargs = dtc.getNumArgs(); j < nargs; j++) + { + TypeNode tnc = d_tds->getArgType(dtc, j); + registerSygusType(tnc); + } + } + } +} + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ |