summaryrefslogtreecommitdiff
path: root/src/theory/quantifiers/sygus_sampler.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/quantifiers/sygus_sampler.cpp')
-rw-r--r--src/theory/quantifiers/sygus_sampler.cpp187
1 files changed, 180 insertions, 7 deletions
diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp
index 0b8f390f3..82720dd5e 100644
--- a/src/theory/quantifiers/sygus_sampler.cpp
+++ b/src/theory/quantifiers/sygus_sampler.cpp
@@ -14,6 +14,7 @@
#include "theory/quantifiers/sygus_sampler.h"
+#include "options/quantifiers_options.h"
#include "util/bitvector.h"
#include "util/random.h"
@@ -80,6 +81,9 @@ void SygusSampler::initialize(TypeNode tn,
d_var_index[sv] = d_type_vars[svt].size();
d_type_vars[svt].push_back(sv);
}
+ d_rvalue_cindices.clear();
+ d_rvalue_null_cindices.clear();
+ d_var_sygus_types.clear();
initializeSamples(nsamples);
}
@@ -109,6 +113,10 @@ void SygusSampler::initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples)
d_type_vars[svt].push_back(sv);
}
}
+ d_rvalue_cindices.clear();
+ d_rvalue_null_cindices.clear();
+ d_var_sygus_types.clear();
+ registerSygusType(d_ftn);
initializeSamples(nsamples);
}
@@ -123,28 +131,90 @@ void SygusSampler::initializeSamples(unsigned nsamples)
Trace("sygus-sample") << " var #" << types.size() << " : " << v << " : "
<< vt << std::endl;
}
+ std::map<unsigned, std::map<Node, std::vector<TypeNode> >::iterator> sts;
+ if (options::sygusSampleGrammar())
+ {
+ for (unsigned j = 0, size = types.size(); j < size; j++)
+ {
+ sts[j] = d_var_sygus_types.find(d_vars[j]);
+ }
+ }
+
+ unsigned nduplicates = 0;
for (unsigned i = 0; i < nsamples; i++)
{
std::vector<Node> sample_pt;
- Trace("sygus-sample") << "Sample point #" << i << " : ";
for (unsigned j = 0, size = types.size(); j < size; j++)
{
- Node r = getRandomValue(types[j]);
+ Node v = d_vars[j];
+ Node r;
+ if (options::sygusSampleGrammar())
+ {
+ // choose a random start sygus type, if possible
+ if (sts[j] != d_var_sygus_types.end())
+ {
+ unsigned ntypes = sts[j]->second.size();
+ Assert(ntypes > 0);
+ unsigned index = Random::getRandom().pick(0, ntypes - 1);
+ if (index < ntypes)
+ {
+ // currently hard coded to 0.0, 0.5
+ r = getSygusRandomValue(sts[j]->second[index], 0.0, 0.5);
+ }
+ }
+ }
if (r.isNull())
{
- Trace("sygus-sample") << "INVALID";
- d_is_valid = false;
+ r = getRandomValue(types[j]);
+ if (r.isNull())
+ {
+ d_is_valid = false;
+ }
}
- Trace("sygus-sample") << r << " ";
sample_pt.push_back(r);
}
- Trace("sygus-sample") << std::endl;
- d_samples.push_back(sample_pt);
+ if (d_samples_trie.add(sample_pt))
+ {
+ if (Trace.isOn("sygus-sample"))
+ {
+ Trace("sygus-sample") << "Sample point #" << i << " : ";
+ for (const Node& r : sample_pt)
+ {
+ Trace("sygus-sample") << r << " ";
+ }
+ Trace("sygus-sample") << std::endl;
+ }
+ d_samples.push_back(sample_pt);
+ }
+ else
+ {
+ i--;
+ nduplicates++;
+ if (nduplicates == nsamples * 10)
+ {
+ Trace("sygus-sample")
+ << "...WARNING: excessive duplicates, cut off sampling at " << i
+ << "/" << nsamples << " points." << std::endl;
+ break;
+ }
+ }
}
d_trie.clear();
}
+bool SygusSampler::PtTrie::add(std::vector<Node>& pt)
+{
+ PtTrie* curr = this;
+ for (unsigned i = 0, size = pt.size(); i < size; i++)
+ {
+ curr = &(curr->d_children[pt[i]]);
+ }
+ bool retVal = curr->d_children.empty();
+ curr = &(curr->d_children[Node::null()]);
+ return retVal;
+}
+
Node SygusSampler::registerTerm(Node n, bool forceKeep)
{
if (d_is_valid)
@@ -389,6 +459,109 @@ Node SygusSampler::getRandomValue(TypeNode tn)
return Node::null();
}
+Node SygusSampler::getSygusRandomValue(TypeNode tn,
+ double rchance,
+ double rinc,
+ unsigned depth)
+{
+ Assert(tn.isDatatype());
+ const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
+ Assert(dt.isSygus());
+ Assert(d_rvalue_cindices.find(tn) != d_rvalue_cindices.end());
+ Trace("sygus-sample-grammar") << "Sygus random value " << tn
+ << ", depth = " << depth
+ << ", rchance = " << rchance << std::endl;
+ // check if we terminate on this call
+ // we refuse to enumerate terms of 10+ depth as a hard limit
+ bool terminate = Random::getRandom().pickWithProb(rchance) || depth >= 10;
+ // if we terminate, only nullary constructors can be chosen
+ std::vector<unsigned>& cindices =
+ terminate ? d_rvalue_null_cindices[tn] : d_rvalue_cindices[tn];
+ unsigned ncons = cindices.size();
+ // select a random constructor, or random value when index=ncons.
+ unsigned index = Random::getRandom().pick(0, ncons);
+ Trace("sygus-sample-grammar") << "Random index 0..." << ncons
+ << " was : " << index << std::endl;
+ if (index < ncons)
+ {
+ Trace("sygus-sample-grammar") << "Recurse constructor index #" << index
+ << std::endl;
+ unsigned cindex = cindices[index];
+ Assert(cindex < dt.getNumConstructors());
+ const DatatypeConstructor& dtc = dt[cindex];
+ // more likely to terminate in recursive calls
+ double rchance_new = rchance + (1.0 - rchance) * rinc;
+ std::map<int, Node> pre;
+ bool success = true;
+ // generate random values for all arguments
+ for (unsigned i = 0, nargs = dtc.getNumArgs(); i < nargs; i++)
+ {
+ TypeNode tnc = d_tds->getArgType(dtc, i);
+ Node c = getSygusRandomValue(tnc, rchance_new, rinc, depth + 1);
+ if (c.isNull())
+ {
+ success = false;
+ Trace("sygus-sample-grammar") << "...fail." << std::endl;
+ break;
+ }
+ Trace("sygus-sample-grammar") << " child #" << i << " : " << c
+ << std::endl;
+ pre[i] = c;
+ }
+ if (success)
+ {
+ Trace("sygus-sample-grammar") << "mkGeneric" << std::endl;
+ Node ret = d_tds->mkGeneric(dt, cindex, pre);
+ Trace("sygus-sample-grammar") << "...returned " << ret << std::endl;
+ ret = Rewriter::rewrite(ret);
+ Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl;
+ Assert(ret.isConst());
+ return ret;
+ }
+ }
+ Trace("sygus-sample-grammar") << "...resort to random value" << std::endl;
+ // if we did not generate based on the grammar, pick a random value
+ return getRandomValue(TypeNode::fromType(dt.getSygusType()));
+}
+
+// recursion depth bounded by number of types in grammar (small)
+void SygusSampler::registerSygusType(TypeNode tn)
+{
+ if (d_rvalue_cindices.find(tn) == d_rvalue_cindices.end())
+ {
+ d_rvalue_cindices[tn].clear();
+ Assert(tn.isDatatype());
+ const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
+ Assert(dt.isSygus());
+ for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++)
+ {
+ const DatatypeConstructor& dtc = dt[i];
+ Node sop = Node::fromExpr(dtc.getSygusOp());
+ bool isVar = std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end();
+ if (isVar)
+ {
+ // if it is a variable, add it to the list of sygus types for that var
+ d_var_sygus_types[sop].push_back(tn);
+ }
+ else
+ {
+ // otherwise, it is a constructor for sygus random value
+ d_rvalue_cindices[tn].push_back(i);
+ if (dtc.getNumArgs() == 0)
+ {
+ d_rvalue_null_cindices[tn].push_back(i);
+ }
+ }
+ // recurse on all subfields
+ for (unsigned j = 0, nargs = dtc.getNumArgs(); j < nargs; j++)
+ {
+ TypeNode tnc = d_tds->getArgType(dtc, j);
+ registerSygusType(tnc);
+ }
+ }
+ }
+}
+
} /* CVC4::theory::quantifiers namespace */
} /* CVC4::theory namespace */
} /* CVC4 namespace */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback