diff options
Diffstat (limited to 'src/theory')
23 files changed, 1380 insertions, 283 deletions
diff --git a/src/theory/.gitignore b/src/theory/.gitignore deleted file mode 100644 index 4d15f70c0..000000000 --- a/src/theory/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/theoryof_table.h diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index b52cb91a4..c6cd9eb1c 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -119,7 +119,7 @@ enum RewriteRuleId BitwiseIdemp, AndZero, AndOne, - AndOrConcatPullUp, + AndOrXorConcatPullUp, OrZero, OrOne, XorDuplicate, @@ -201,7 +201,7 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case ConcatFlatten: out << "ConcatFlatten"; return out; case ConcatExtractMerge: out << "ConcatExtractMerge"; return out; case ConcatConstantMerge: out << "ConcatConstantMerge"; return out; - case AndOrConcatPullUp: out << "AndOrConcatPullUp"; return out; + case AndOrXorConcatPullUp:out << "AndOrXorConcatPullUp";return out; case ExtractExtract: out << "ExtractExtract"; return out; case ExtractWhole: out << "ExtractWhole"; return out; case ExtractConcat: out << "ExtractConcat"; return out; @@ -581,7 +581,7 @@ struct AllRewriteRules { RewriteRule<BvIteMergeElseIf> rule136; RewriteRule<BvIteMergeThenElse> rule137; RewriteRule<BvIteMergeElseElse> rule138; - RewriteRule<AndOrConcatPullUp> rule139; + RewriteRule<AndOrXorConcatPullUp> rule139; }; template<> inline diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index 5e9d2b349..7efdc2c81 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -489,48 +489,22 @@ Node RewriteRule<AndOne>::apply(TNode node) { /* -------------------------------------------------------------------------- */ /** - * AndOrConcatPullUp + * AndOrXorConcatPullUp * - * And: - * ---------------------------------------------------------------- - * Match: x_m & 
concat(y_my, 0_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] & y, 0_n, x[mz-1:0] & z) + * Match: x_m <op> concat(y_my, <const>_n, z_mz) + * <const>_n in { 0_n, 1_n, ~0_n } * - * Match: x_m & concat(y_my, 1_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] & y, - * 0_[n-1], - * x[mz:mz], - * x[mz-1:0] & z) - * - * Match: x_m & concat(y_my, ~0_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] & y, - * x[m-my-1:mz], - * x[mz-1:0] & z) - * - * Or: - * ---------------------------------------------------------------- - * Match: x_m | concat(y_my, 0_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] | y, - * x[m-my-1:mz], - * x[mz-1:0] | z) - * - * Match: x_m | concat(y_my, 1_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] | y, - * x[m-my-1:mz+1], - * 1_1, - * x[mz-1:0] | z) - * - * Match: x_m | concat(y_my, ~0_n, z_mz) - * Rewrites to: concat(x[m-1:m-my] | y, - * ~0_n, - * x[mz-1:0] | z) + * Rewrites to: concat(x[m-1:m-my] <op> y, + * x[m-my-1:mz] <op> <const>_n, + * x[mz-1:0] <op> z) */ template <> -inline bool RewriteRule<AndOrConcatPullUp>::applies(TNode node) +inline bool RewriteRule<AndOrXorConcatPullUp>::applies(TNode node) { if (node.getKind() != kind::BITVECTOR_AND - && node.getKind() != kind::BITVECTOR_OR) + && node.getKind() != kind::BITVECTOR_OR + && node.getKind() != kind::BITVECTOR_XOR) { return false; } @@ -557,11 +531,10 @@ inline bool RewriteRule<AndOrConcatPullUp>::applies(TNode node) } template <> -inline Node RewriteRule<AndOrConcatPullUp>::apply(TNode node) +inline Node RewriteRule<AndOrXorConcatPullUp>::apply(TNode node) { - Debug("bv-rewrite") << "RewriteRule<AndOrConcatPullUp>(" << node << ")" + Debug("bv-rewrite") << "RewriteRule<AndOrXorConcatPullUp>(" << node << ")" << std::endl; - int32_t is_const; uint32_t m, my, mz, n; size_t nc; Kind kind = node.getKind(); @@ -586,24 +559,12 @@ inline Node RewriteRule<AndOrConcatPullUp>::apply(TNode node) } x = xb.getNumChildren() > 1 ? 
xb.constructNode() : xb[0]; - is_const = -2; for (const TNode& child : concat) { if (c.isNull()) { - if (utils::isZero(child)) - { - is_const = 0; - c = child; - } - else if (utils::isOne(child)) - { - is_const = 1; - c = child; - } - else if (utils::isOnes(child)) + if (utils::isZero(child) || utils::isOne(child) || utils::isOnes(child)) { - is_const = -1; c = child; } else @@ -638,49 +599,14 @@ inline Node RewriteRule<AndOrConcatPullUp>::apply(TNode node) { res << nm->mkNode(kind, utils::mkExtract(x, m - 1, m - my), y); } - if (is_const == 0) - { - if (kind == kind::BITVECTOR_AND) - { - res << c; - } - else - { - Assert(kind == kind::BITVECTOR_OR); - res << utils::mkExtract(x, m - my - 1, mz); - } - } - else if (is_const == 1) - { - if (kind == kind::BITVECTOR_AND) - { - if (n > 1) res << utils::mkZero(n - 1); - res << utils::mkExtract(x, mz, mz); - } - else - { - Assert(kind == kind::BITVECTOR_OR); - if (n > 1) res << utils::mkExtract(x, m - my - 1, mz + 1); - res << utils::mkOne(1); - } - } - else - { - Assert(is_const == -1); - if (kind == kind::BITVECTOR_AND) - { - res << utils::mkExtract(x, m - my - 1, mz); - } - else - { - Assert(kind == kind::BITVECTOR_OR); - res << c; - } - } + + res << nm->mkNode(kind, utils::mkExtract(x, m - my - 1, mz), c); + if (mz) { res << nm->mkNode(kind, utils::mkExtract(x, mz - 1, 0), z); } + return res; } diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 3f018f800..0c6f1d37a 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -253,7 +253,7 @@ RewriteResponse TheoryBVRewriter::RewriteAnd(TNode node, bool prerewrite) resultNode = LinearRewriteStrategy<RewriteRule<FlattenAssocCommutNoDuplicates>, RewriteRule<AndSimplify>, - RewriteRule<AndOrConcatPullUp>>::apply(node); + RewriteRule<AndOrXorConcatPullUp>>::apply(node); if (!prerewrite) { resultNode = @@ -274,7 +274,7 @@ RewriteResponse TheoryBVRewriter::RewriteOr(TNode node, bool prerewrite) 
resultNode = LinearRewriteStrategy<RewriteRule<FlattenAssocCommutNoDuplicates>, RewriteRule<OrSimplify>, - RewriteRule<AndOrConcatPullUp>>::apply(node); + RewriteRule<AndOrXorConcatPullUp>>::apply(node); if (!prerewrite) { @@ -290,27 +290,30 @@ RewriteResponse TheoryBVRewriter::RewriteOr(TNode node, bool prerewrite) return RewriteResponse(REWRITE_DONE, resultNode); } -RewriteResponse TheoryBVRewriter::RewriteXor(TNode node, bool prerewrite) { +RewriteResponse TheoryBVRewriter::RewriteXor(TNode node, bool prerewrite) +{ Node resultNode = node; - resultNode = LinearRewriteStrategy - < RewriteRule<FlattenAssocCommut>, // flatten the expression - RewriteRule<XorSimplify>, // simplify duplicates and constants - RewriteRule<XorZero>, // checks if the constant part is zero and eliminates it - RewriteRule<BitwiseSlicing> - >::apply(node); + resultNode = LinearRewriteStrategy< + RewriteRule<FlattenAssocCommut>, // flatten the expression + RewriteRule<XorSimplify>, // simplify duplicates and constants + RewriteRule<XorZero>, // checks if the constant part is zero and + // eliminates it + RewriteRule<AndOrXorConcatPullUp>, + RewriteRule<BitwiseSlicing>>::apply(node); - if (!prerewrite) { - resultNode = LinearRewriteStrategy - < RewriteRule<XorOne>, - RewriteRule <BitwiseSlicing> - >::apply(resultNode); - - if (resultNode.getKind() != node.getKind()) { - return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); + if (!prerewrite) + { + resultNode = + LinearRewriteStrategy<RewriteRule<XorOne>, + RewriteRule<BitwiseSlicing>>::apply(resultNode); + + if (resultNode.getKind() != node.getKind()) + { + return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); } } - return RewriteResponse(REWRITE_DONE, resultNode); + return RewriteResponse(REWRITE_DONE, resultNode); } RewriteResponse TheoryBVRewriter::RewriteXnor(TNode node, bool prerewrite) { diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index 17ef4f968..a7763bff1 100644 --- 
a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -1096,45 +1096,46 @@ Node SygusSymBreakNew::registerSearchValue(Node a, d_tds, nv, options::sygusSamples(), false); its = d_sampler[a].find(tn); } - - // register the rewritten node with the sampler - Node bvr_sample_ret = its->second.registerTerm(bvr); - // register the current node with the sampler - Node sample_ret = its->second.registerTerm(bv); - + // see if they evaluate to same thing on all sample points + bool ptDisequal = false; + unsigned pt_index = 0; + Node bve, bvre; + for (unsigned i = 0, npoints = its->second.getNumSamplePoints(); + i < npoints; + i++) + { + bve = its->second.evaluate(bv, i); + bvre = its->second.evaluate(bvr, i); + if (bve != bvre) + { + ptDisequal = true; + pt_index = i; + break; + } + } // bv and bvr should be equivalent under examples - if (sample_ret != bvr_sample_ret) + if (ptDisequal) { // we have detected unsoundness in the rewriter Options& nodeManagerOptions = NodeManager::currentNM()->getOptions(); std::ostream* out = nodeManagerOptions.getOut(); (*out) << "(unsound-rewrite " << bv << " " << bvr << ")" << std::endl; // debugging information - int pt_index = its->second.getDiffSamplePointIndex(bv, bvr); - if (pt_index >= 0) + (*out) << "; unsound: are not equivalent for : " << std::endl; + std::vector<Node> vars; + its->second.getVariables(vars); + std::vector<Node> pt; + its->second.getSamplePoint(pt_index, pt); + Assert(vars.size() == pt.size()); + for (unsigned i = 0, size = pt.size(); i < size; i++) { - (*out) << "; unsound: are not equivalent for : " << std::endl; - std::vector<Node> vars; - its->second.getVariables(vars); - std::vector<Node> pt; - its->second.getSamplePoint(pt_index, pt); - Assert(vars.size() == pt.size()); - for (unsigned i = 0, size = pt.size(); i < size; i++) - { - (*out) << "; unsound: " << vars[i] << " -> " << pt[i] - << std::endl; - } - Node bv_e = its->second.evaluate(bv, pt_index); - Node pbv_e = 
its->second.evaluate(bvr, pt_index); - Assert(bv_e != pbv_e); - (*out) << "; unsound: where they evaluate to " << bv_e << " and " - << pbv_e << std::endl; - } - else - { - // no witness point found? - Assert(false); + (*out) << "; unsound: " << vars[i] << " -> " << pt[i] + << std::endl; } + Assert(bve != bvre); + (*out) << "; unsound: where they evaluate to " << bve << " and " + << bvre << std::endl; + if (options::sygusRewVerifyAbort()) { AlwaysAssert( diff --git a/src/theory/quantifiers/expr_miner_manager.cpp b/src/theory/quantifiers/expr_miner_manager.cpp index 8c116781d..cc97888e3 100644 --- a/src/theory/quantifiers/expr_miner_manager.cpp +++ b/src/theory/quantifiers/expr_miner_manager.cpp @@ -13,6 +13,7 @@ **/ #include "theory/quantifiers/expr_miner_manager.h" +#include "theory/quantifiers_engine.h" namespace CVC4 { namespace theory { @@ -20,6 +21,8 @@ namespace quantifiers { ExpressionMinerManager::ExpressionMinerManager() : d_doRewSynth(false), + d_doQueryGen(false), + d_doFilterImplied(false), d_use_sygus_type(false), d_qe(nullptr), d_tds(nullptr) @@ -32,6 +35,8 @@ void ExpressionMinerManager::initialize(const std::vector<Node>& vars, bool unique_type_ids) { d_doRewSynth = false; + d_doQueryGen = false; + d_doFilterImplied = false; d_sygus_fun = Node::null(); d_use_sygus_type = false; d_qe = nullptr; @@ -46,6 +51,8 @@ void ExpressionMinerManager::initializeSygus(QuantifiersEngine* qe, bool useSygusType) { d_doRewSynth = false; + d_doQueryGen = false; + d_doFilterImplied = false; d_sygus_fun = f; d_use_sygus_type = useSygusType; d_qe = qe; @@ -78,11 +85,66 @@ void ExpressionMinerManager::enableRewriteRuleSynth() d_crd.setSilent(false); } +void ExpressionMinerManager::enableQueryGeneration(unsigned deqThresh) +{ + if (d_doQueryGen) + { + // already enabled + return; + } + d_doQueryGen = true; + std::vector<Node> vars; + d_sampler.getVariables(vars); + // must also enable rewrite rule synthesis + if (!d_doRewSynth) + { + // initialize the candidate rewrite 
database, in silent mode + enableRewriteRuleSynth(); + d_crd.setSilent(true); + } + // initialize the query generator + d_qg.initialize(vars, &d_sampler); + d_qg.setThreshold(deqThresh); +} + +void ExpressionMinerManager::enableFilterImpliedSolutions() +{ + d_doFilterImplied = true; + std::vector<Node> vars; + d_sampler.getVariables(vars); + d_solf.initialize(vars, &d_sampler); +} + bool ExpressionMinerManager::addTerm(Node sol, std::ostream& out, bool& rew_print) { - return d_crd.addTerm(sol, out, rew_print); + // set the builtin version + Node solb = sol; + if (d_use_sygus_type) + { + solb = d_tds->sygusToBuiltin(sol); + } + + // add to the candidate rewrite rule database + bool ret = true; + if (d_doRewSynth) + { + ret = d_crd.addTerm(sol, out, rew_print); + } + + // a unique term, let's try the query generator + if (ret && d_doQueryGen) + { + d_qg.addTerm(solb, out); + } + + // filter if it's implied + if (ret && d_doFilterImplied) + { + ret = d_solf.addTerm(solb, out); + } + return ret; } bool ExpressionMinerManager::addTerm(Node sol, std::ostream& out) diff --git a/src/theory/quantifiers/expr_miner_manager.h b/src/theory/quantifiers/expr_miner_manager.h index 668d04beb..d8e6ae651 100644 --- a/src/theory/quantifiers/expr_miner_manager.h +++ b/src/theory/quantifiers/expr_miner_manager.h @@ -20,6 +20,8 @@ #include "expr/node.h" #include "theory/quantifiers/candidate_rewrite_database.h" #include "theory/quantifiers/extended_rewrite.h" +#include "theory/quantifiers/query_generator.h" +#include "theory/quantifiers/solution_filter.h" #include "theory/quantifiers/sygus_sampler.h" #include "theory/quantifiers_engine.h" @@ -67,6 +69,10 @@ class ExpressionMinerManager bool useSygusType); /** enable rewrite rule synthesis (--sygus-rr-synth) */ void enableRewriteRuleSynth(); + /** enable query generation (--sygus-query-gen) */ + void enableQueryGeneration(unsigned deqThresh); + /** filter implied solutions (--sygus-sol-filter-implied) */ + void 
enableFilterImpliedSolutions(); /** add term * * Expression miners may print information on the output stream out, for @@ -84,6 +90,10 @@ class ExpressionMinerManager private: /** whether we are doing rewrite synthesis */ bool d_doRewSynth; + /** whether we are doing query generation */ + bool d_doQueryGen; + /** whether we are filtering implied candidates */ + bool d_doFilterImplied; /** the sygus function passed to initializeSygus, if any */ Node d_sygus_fun; /** whether we are using sygus types */ @@ -94,6 +104,10 @@ class ExpressionMinerManager TermDbSygus* d_tds; /** candidate rewrite database */ CandidateRewriteDatabase d_crd; + /** query generator */ + QueryGenerator d_qg; + /** solution filter */ + SolutionFilter d_solf; /** sygus sampler object */ SygusSampler d_sampler; /** extended rewriter object */ diff --git a/src/theory/quantifiers/query_generator.cpp b/src/theory/quantifiers/query_generator.cpp new file mode 100644 index 000000000..e62f3513c --- /dev/null +++ b/src/theory/quantifiers/query_generator.cpp @@ -0,0 +1,416 @@ +/********************* */ +/*! \file query_generator.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of a class for mining interesting satisfiability + ** queries from a stream of generated expressions. 
+ **/ + +#include "theory/quantifiers/query_generator.h" + +#include <fstream> +#include "options/quantifiers_options.h" +#include "smt/smt_engine.h" +#include "smt/smt_engine_scope.h" +#include "util/random.h" + +using namespace std; +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +QueryGenerator::QueryGenerator() : d_queryCount(0) {} +void QueryGenerator::initialize(const std::vector<Node>& vars, SygusSampler* ss) +{ + Assert(ss != nullptr); + d_queryCount = 0; + ExprMiner::initialize(vars, ss); +} + +void QueryGenerator::setThreshold(unsigned deqThresh) +{ + d_deqThresh = deqThresh; +} + +bool QueryGenerator::addTerm(Node n, std::ostream& out) +{ + Node nn = n.getKind() == NOT ? n[0] : n; + if (d_terms.find(nn) != d_terms.end()) + { + return false; + } + d_terms.insert(nn); + + Trace("sygus-qgen") << "QueryGenerator::addTerm : " << n << std::endl; + unsigned npts = d_sampler->getNumSamplePoints(); + TypeNode tn = n.getType(); + // TODO : as an optimization, use a shared lazy trie? + + // the queries we generate on this round + std::vector<Node> queries; + // For each query in the above vector, this stores the indices of the points + // for which that query evaluated to true on. 
+ std::vector<std::vector<unsigned>> queriesPtTrue; + // the sample point indices for which the above queries are true + std::unordered_set<unsigned> indices; + + // collect predicate queries (if n is Boolean) + if (tn.isBoolean()) + { + std::map<Node, std::vector<unsigned>> ev_to_pt; + unsigned index = 0; + unsigned threshCount = 0; + while (index < npts && threshCount < 2) + { + Node v = d_sampler->evaluate(nn, index); + ev_to_pt[v].push_back(index); + if (ev_to_pt[v].size() == d_deqThresh + 1) + { + threshCount++; + } + index++; + } + if (threshCount < 2) + { + for (const std::pair<Node, std::vector<unsigned>>& etp : ev_to_pt) + { + if (etp.second.size() < d_deqThresh) + { + indices.insert(etp.second.begin(), etp.second.end()); + Node qy = nn; + Assert(etp.first.isConst()); + if (!etp.first.getConst<bool>()) + { + qy = qy.negate(); + } + queries.push_back(qy); + queriesPtTrue.push_back(etp.second); + } + } + } + } + + // collect equality queries + findQueries(nn, queries, queriesPtTrue); + Assert(queries.size() == queriesPtTrue.size()); + if (queries.empty()) + { + return true; + } + Trace("sygus-qgen-debug") + << "query: Check " << queries.size() << " queries..." 
<< std::endl; + // literal queries + for (unsigned i = 0, nqueries = queries.size(); i < nqueries; i++) + { + Node qy = queries[i]; + std::vector<unsigned>& tIndices = queriesPtTrue[i]; + // we have an interesting query + out << "(query " << qy << ") ; " << tIndices.size() << "/" << npts + << std::endl; + AlwaysAssert(!tIndices.empty()); + checkQuery(qy, tIndices[0]); + // add information + for (unsigned& ti : tIndices) + { + d_ptToQueries[ti].push_back(qy); + d_qysToPoints[qy].push_back(ti); + indices.insert(ti); + } + } + // for each new index, we may have a new conjunctive query + NodeManager* nm = NodeManager::currentNM(); + for (const unsigned& i : indices) + { + std::vector<Node>& qsi = d_ptToQueries[i]; + if (qsi.size() > 1) + { + // take two random queries + std::shuffle(qsi.begin(), qsi.end(), Random::getRandom()); + Node qy = nm->mkNode(AND, qsi[0], qsi[1]); + checkQuery(qy, i); + } + } + Trace("sygus-qgen-check") << "...finished." << std::endl; + return true; +} + +void QueryGenerator::checkQuery(Node qy, unsigned spIndex) +{ + // external query + if (options::sygusQueryGenDumpFiles()) + { + // Print the query and the query + its model (commented) to queryN.smt2 + std::vector<Node> pt; + d_sampler->getSamplePoint(spIndex, pt); + unsigned nvars = d_vars.size(); + AlwaysAssert(pt.size() == d_vars.size()); + std::stringstream fname; + fname << "query" << d_queryCount << ".smt2"; + std::ofstream fs(fname.str(), std::ofstream::out); + fs << "(set-logic ALL)" << std::endl; + for (unsigned i = 0; i < 2; i++) + { + for (unsigned j = 0; j < nvars; j++) + { + Node x = d_vars[j]; + if (i == 0) + { + fs << "(declare-fun " << x << " () " << x.getType() << ")"; + } + else + { + fs << ";(define-fun " << x << " () " << x.getType() << " " << pt[j] + << ")"; + } + fs << std::endl; + } + } + fs << "(assert " << qy << ")" << std::endl; + fs << "(check-sat)" << std::endl; + fs.close(); + } + + if (options::sygusQueryGenCheck()) + { + Trace("sygus-qgen-check") << " query: 
check " << qy << "..." << std::endl; + NodeManager* nm = NodeManager::currentNM(); + // make the satisfiability query + bool needExport = false; + ExprManagerMapCollection varMap; + ExprManager em(nm->getOptions()); + std::unique_ptr<SmtEngine> queryChecker; + initializeChecker(queryChecker, em, varMap, qy, needExport); + Result r = queryChecker->checkSat(); + Trace("sygus-qgen-check") << " query: ...got : " << r << std::endl; + if (r.asSatisfiabilityResult().isSat() == Result::UNSAT) + { + std::stringstream ss; + ss << "--sygus-rr-query-gen detected unsoundness in CVC4 on input " << qy + << "!" << std::endl; + ss << "This query has a model : " << std::endl; + std::vector<Node> pt; + d_sampler->getSamplePoint(spIndex, pt); + Assert(pt.size() == d_vars.size()); + for (unsigned i = 0, size = pt.size(); i < size; i++) + { + ss << " " << d_vars[i] << " -> " << pt[i] << std::endl; + } + ss << "but CVC4 answered unsat!" << std::endl; + AlwaysAssert(false, ss.str().c_str()); + } + } + + d_queryCount++; +} + +void QueryGenerator::findQueries( + Node n, + std::vector<Node>& queries, + std::vector<std::vector<unsigned>>& queriesPtTrue) +{ + // At a high level, this method traverses the LazyTrie for the type of n + // and tries to find paths to leafs that contain terms n' such that n = n' + // or n != n' is an interesting query, i.e. satisfied for a small number of + // points. + TypeNode tn = n.getType(); + LazyTrie* lt = &d_qgtTrie[tn]; + // These vectors are the set of indices of sample points for which the current + // node we are considering are { equal, disequal } from n. 
+ std::vector<unsigned> eqIndex[2]; + Trace("sygus-qgen-debug") << "Compute queries for " << n << "...\n"; + + LazyTrieEvaluator* ev = d_sampler; + unsigned ntotal = d_sampler->getNumSamplePoints(); + unsigned index = 0; + bool exact = true; + bool pushEq[2] = {false, false}; + bool pre = true; + // The following parallel vectors describe the state of the locations in the + // trie we are currently visiting. + // Reference to the location in the trie + std::vector<LazyTrie*> visitTr; + // The index of the sample point we are testing + std::vector<unsigned> currIndex; + // Whether the path to this location exactly matches the evaluation of n + std::vector<bool> currExact; + // Whether we are adding to the points that are { equal, disequal } by + // traversing to this location. + std::vector<bool> pushIndex[2]; + // Whether we are in a pre-traversal for this location. + std::vector<bool> preVisit; + visitTr.push_back(lt); + currIndex.push_back(0); + currExact.push_back(true); + pushIndex[0].push_back(false); + pushIndex[1].push_back(false); + preVisit.push_back(true); + do + { + lt = visitTr.back(); + index = currIndex.back(); + exact = currExact.back(); + for (unsigned r = 0; r < 2; r++) + { + pushEq[r] = pushIndex[r].back(); + } + pre = preVisit.back(); + if (!pre) + { + visitTr.pop_back(); + currIndex.pop_back(); + currExact.pop_back(); + preVisit.pop_back(); + // clean up the indices of points that are { equal, disequal } + for (unsigned r = 0; r < 2; r++) + { + if (pushEq[r]) + { + eqIndex[r].pop_back(); + } + pushIndex[r].pop_back(); + } + } + else + { + preVisit[preVisit.size() - 1] = false; + // add to the indices of points that are { equal, disequal } + for (unsigned r = 0; r < 2; r++) + { + if (pushEq[r]) + { + eqIndex[r].push_back(index - 1); + } + } + int eqAllow = d_deqThresh - eqIndex[0].size(); + int deqAllow = d_deqThresh - eqIndex[1].size(); + Trace("sygus-qgen-debug") + << "Find queries " << n << " " << index << "/" << ntotal + << ", deq/eq allow = 
" << deqAllow << "/" << eqAllow + << ", exact = " << exact << std::endl; + if (index == ntotal) + { + if (exact) + { + // add to the trie + lt->d_lazy_child = n; + } + else + { + Node nAlmostEq = lt->d_lazy_child; + // if made it here, we still should have either an equality or + // a disequality that is allowed. + Assert(deqAllow >= 0 || eqAllow >= 0); + Node query = n.eqNode(nAlmostEq); + std::vector<unsigned> tIndices; + if (eqAllow >= 0) + { + tIndices.insert( + tIndices.end(), eqIndex[0].begin(), eqIndex[0].end()); + } + else if (deqAllow >= 0) + { + query = query.negate(); + tIndices.insert( + tIndices.end(), eqIndex[1].begin(), eqIndex[1].end()); + } + AlwaysAssert(tIndices.size() <= d_deqThresh); + if (!tIndices.empty()) + { + queries.push_back(query); + queriesPtTrue.push_back(tIndices); + } + } + } + else + { + if (!lt->d_lazy_child.isNull()) + { + // if there is a lazy child here, push + Node e_lc = ev->evaluate(lt->d_lazy_child, index); + // store at next level + lt->d_children[e_lc].d_lazy_child = lt->d_lazy_child; + // replace + lt->d_lazy_child = Node::null(); + } + // compute + Node e_this = ev->evaluate(n, index); + + if (deqAllow >= 0) + { + // recursing on disequal points + deqAllow--; + // if there is use continuing + if (deqAllow >= 0 || eqAllow >= 0) + { + for (std::pair<const Node, LazyTrie>& ltc : lt->d_children) + { + if (ltc.first != e_this) + { + visitTr.push_back(&ltc.second); + currIndex.push_back(index + 1); + currExact.push_back(false); + pushIndex[0].push_back(false); + pushIndex[1].push_back(true); + preVisit.push_back(true); + } + } + } + deqAllow++; + } + bool pushEqNext = false; + if (eqAllow >= 0) + { + // below, we try recursing (if at all) on equal nodes.
+ eqAllow--; + pushEqNext = true; + } + // if we are on the exact path of n + if (exact) + { + if (lt->d_children.empty()) + { + // if no one has been here before, we are done + lt->d_lazy_child = n; + } + else + { + // otherwise, we recurse on the equal point + visitTr.push_back(&(lt->d_children[e_this])); + currIndex.push_back(index + 1); + currExact.push_back(true); + pushIndex[0].push_back(pushEqNext); + pushIndex[1].push_back(false); + preVisit.push_back(true); + } + } + else if (deqAllow >= 0 || eqAllow >= 0) + { + // recurse on the equal point if it exists + std::map<Node, LazyTrie>::iterator iteq = lt->d_children.find(e_this); + if (iteq != lt->d_children.end()) + { + visitTr.push_back(&(iteq->second)); + currIndex.push_back(index + 1); + currExact.push_back(false); + pushIndex[0].push_back(pushEqNext); + pushIndex[1].push_back(false); + preVisit.push_back(true); + } + } + } + } + } while (!visitTr.empty()); +} + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/quantifiers/query_generator.h b/src/theory/quantifiers/query_generator.h new file mode 100644 index 000000000..f0b3fa565 --- /dev/null +++ b/src/theory/quantifiers/query_generator.h @@ -0,0 +1,116 @@ +/********************* */ +/*! \file query_generator.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A class for mining interesting satisfiability queries from a stream + ** of generated expressions. 
+ **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__QUERY_GENERATOR_H +#define __CVC4__THEORY__QUANTIFIERS__QUERY_GENERATOR_H + +#include <map> +#include <unordered_set> +#include "expr/node.h" +#include "theory/quantifiers/expr_miner.h" +#include "theory/quantifiers/lazy_trie.h" +#include "theory/quantifiers/sygus_sampler.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** QueryGenerator + * + * This module is used for finding satisfiable queries that are maximally + * likely to trigger an unsound response in an SMT solver. These queries are + * mined from a stream of enumerated expressions. We judge likelihood of + * triggering unsoundness by the frequency at which the query is satisfied. + * + * In detail, given a stream of expressions t_1, ..., t_{n-1}, upon generating + * term t_n, we consider a query (not) t_n = t_i to be an interesting query + * if it is satisfied by at most D points, where D is a predefined threshold + * given by options::sygusQueryGenThresh(). If t_n has type Bool, we + * additionally consider the case where t_n is satisfied (or not satisfied) by + * fewer than D points. + * + * In addition to generating single literal queries, this module also generates + * conjunctive queries, for instance, by remembering that literals L1 and L2 + * were both satisfied by the same point, and thus L1 ^ L2 is an interesting + * query as well. + */ +class QueryGenerator : public ExprMiner +{ + public: + QueryGenerator(); + ~QueryGenerator() {} + /** initialize */ + void initialize(const std::vector<Node>& vars, + SygusSampler* ss = nullptr) override; + /** + * Add term to this module. This may trigger the printing and/or checking of + * new queries. + */ + bool addTerm(Node n, std::ostream& out) override; + /** + * Set the threshold value. This is the maximal number of sample points that + * each query we generate is allowed to be satisfied by. 
+ */ + void setThreshold(unsigned deqThresh); + + private: + /** cache of all terms registered to this generator */ + std::unordered_set<Node, NodeHashFunction> d_terms; + /** the threshold used by this module for maximum number of sat points */ + unsigned d_deqThresh; + /** + * For each type, a lazy trie storing the evaluation of all added terms on + * sample points. + */ + std::map<TypeNode, LazyTrie> d_qgtTrie; + /** total number of queries generated by this class */ + unsigned d_queryCount; + /** find queries + * + * This function traverses the lazy trie for the type of n, finding equality + * and disequality queries between n and other terms in the trie. The argument + * queries collects the newly generated queries, and the argument + * queriesPtTrue collects the indices of points that each query was satisfied + * by (these indices are the indices of the points in the sampler used by this + * class). + */ + void findQueries(Node n, + std::vector<Node>& queries, + std::vector<std::vector<unsigned>>& queriesPtTrue); + /** + * Maps the index of each sample point to the list of queries that it + * satisfies, and that were generated by the above function. This map is used + * for generating conjunctive queries. + */ + std::map<unsigned, std::vector<Node>> d_ptToQueries; + /** + * Map from queries to the indices of the points that satisfy them. + */ + std::map<Node, std::vector<unsigned>> d_qysToPoints; + /** + * Check query qy, which is satisfied by (at least) sample point spIndex, + * using a separate copy of the SMT engine. Throws an exception if qy is + * reported to be unsatisfiable. 
+ */ + void checkQuery(Node qy, unsigned spIndex); +}; + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 + +#endif /* __CVC4__THEORY__QUANTIFIERS__QUERY_GENERATOR_H */ diff --git a/src/theory/quantifiers/solution_filter.cpp b/src/theory/quantifiers/solution_filter.cpp new file mode 100644 index 000000000..bea3356d1 --- /dev/null +++ b/src/theory/quantifiers/solution_filter.cpp @@ -0,0 +1,92 @@ +/********************* */ +/*! \file solution_filter.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Utilities for filtering solutions. + **/ + +#include "theory/quantifiers/solution_filter.h" + +#include <fstream> +#include "options/quantifiers_options.h" +#include "smt/smt_engine.h" +#include "smt/smt_engine_scope.h" +#include "util/random.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +SolutionFilter::SolutionFilter() {} +void SolutionFilter::initialize(const std::vector<Node>& vars, SygusSampler* ss) +{ + ExprMiner::initialize(vars, ss); +} + +bool SolutionFilter::addTerm(Node n, std::ostream& out) +{ + if (!n.getType().isBoolean()) + { + // currently, should not register non-Boolean terms here + Assert(false); + return true; + } + NodeManager* nm = NodeManager::currentNM(); + Node imp = d_conj.isNull() ? 
n.negate() : nm->mkNode(AND, d_conj, n.negate()); + imp = Rewriter::rewrite(imp); + bool success = false; + if (imp.isConst()) + { + if (!imp.getConst<bool>()) + { + // if the implication rewrites to false, we filter + Trace("sygus-sol-implied-filter") << "Filtered (by rewriting) : " << n + << std::endl; + return false; + } + else + { + // if the implication rewrites to true, it is trivial + success = true; + } + } + if (!success) + { + Trace("sygus-sol-implied") << " implies: check " << imp << "..." + << std::endl; + // make the satisfiability query + bool needExport = false; + ExprManagerMapCollection varMap; + ExprManager em(nm->getOptions()); + std::unique_ptr<SmtEngine> queryChecker; + initializeChecker(queryChecker, em, varMap, imp, needExport); + Result r = queryChecker->checkSat(); + Trace("sygus-sol-implied") << " implies: ...got : " << r << std::endl; + if (r.asSatisfiabilityResult().isSat() != Result::UNSAT) + { + success = true; + } + } + if (success) + { + d_conj = d_conj.isNull() ? n : nm->mkNode(AND, d_conj, n); + d_conj = Rewriter::rewrite(d_conj); + // note if d_conj is false, we could terminate here + return true; + } + Trace("sygus-sol-implied-filter") << "Filtered : " << n << std::endl; + return false; +} + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/quantifiers/solution_filter.h b/src/theory/quantifiers/solution_filter.h new file mode 100644 index 000000000..9f098cf69 --- /dev/null +++ b/src/theory/quantifiers/solution_filter.h @@ -0,0 +1,62 @@ +/********************* */ +/*! \file solution_filter.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. 
See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Utility for filtering sygus solutions. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__SOLUTION_FILTER_H +#define __CVC4__THEORY__QUANTIFIERS__SOLUTION_FILTER_H + +#include <map> +#include <unordered_set> +#include "expr/node.h" +#include "theory/quantifiers/expr_miner.h" +#include "theory/quantifiers/lazy_trie.h" +#include "theory/quantifiers/sygus_sampler.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** + * This class is used to filter solutions based on some criteria. + * + * Currently, it is used to filter predicate solutions that are collectively + * entailed by the previous predicate solutions. + */ +class SolutionFilter : public ExprMiner +{ + public: + SolutionFilter(); + ~SolutionFilter() {} + /** initialize */ + void initialize(const std::vector<Node>& vars, + SygusSampler* ss = nullptr) override; + /** + * Add term to this module. It is expected that n has Boolean type. + * If this method returns false, then the entailment n_1 ^ ... ^ n_m |= n + * holds, where n_1, ..., n_m are the terms previously registered to this + * class. 
+ */ + bool addTerm(Node n, std::ostream& out) override; + + private: + /** conjunction of all (non-implied) terms registered to this class */ + Node d_conj; +}; + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 + +#endif /* __CVC4__THEORY__QUANTIFIERS__SOLUTION_FILTER_H */ diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index 06f041d93..79bec60ee 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -562,7 +562,7 @@ bool Cegis::sampleAddRefinementLemma(const std::vector<Node>& candidates, Assert(vals.size() == candidates.size()); Node sbody = d_base_body.substitute( candidates.begin(), candidates.end(), vals.begin(), vals.end()); - Trace("cegis-sample-debug") << "Sample " << sbody << std::endl; + Trace("cegis-sample-debug2") << "Sample " << sbody << std::endl; // do eager unfolding std::map<Node, Node> visited_n; sbody = d_qe->getTermDatabaseSygus()->getEagerUnfold(sbody, visited_n); diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index 7955d59a8..b95af719e 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -458,11 +458,11 @@ bool SynthConjecture::doCheck(std::vector<Node>& lems) Node lem; // introduce the skolem variables std::vector<Node> sks; + std::vector<Node> vars; if (constructed_cand) { if (inst.getKind() == NOT && inst[0].getKind() == FORALL) { - std::vector<Node> vars; for (const Node& v : inst[0][0]) { Node sk = nm->mkSkolem("rsk", v.getType()); @@ -527,10 +527,11 @@ bool SynthConjecture::doCheck(std::vector<Node>& lems) { Trace("cegqi-engine") << " * Verification lemma failed for:\n "; // do not send out - for (const Node& v : d_ce_sk_vars) + for (unsigned i = 0, size = d_ce_sk_vars.size(); i < size; i++) { + Node v = d_ce_sk_vars[i]; Node mv = Node::fromExpr(verifySmt.getValue(v.toExpr())); - 
Trace("cegqi-engine") << v << " -> " << mv << " "; + Trace("cegqi-engine") << vars[i] << " -> " << mv << " "; d_ce_sk_var_mvs.push_back(mv); } Trace("cegqi-engine") << std::endl; @@ -955,7 +956,7 @@ void SynthConjecture::printAndContinueStream() void SynthConjecture::printSynthSolution(std::ostream& out) { - Trace("cegqi-debug") << "Printing synth solution..." << std::endl; + Trace("cegqi-sol-debug") << "Printing synth solution..." << std::endl; Assert(d_quant[0].getNumChildren() == d_embed_quant[0].getNumChildren()); std::vector<Node> sols; std::vector<int> statuses; @@ -981,8 +982,10 @@ void SynthConjecture::printSynthSolution(std::ostream& out) bool is_unique_term = true; - if (status != 0 && options::sygusRewSynth()) + if (status != 0 && (options::sygusRewSynth() || options::sygusQueryGen() + || options::sygusSolFilterImplied())) { + Trace("cegqi-sol-debug") << "Run expression mining..." << std::endl; std::map<Node, ExpressionMinerManager>::iterator its = d_exprm.find(prog); if (its == d_exprm.end()) @@ -993,6 +996,14 @@ void SynthConjecture::printSynthSolution(std::ostream& out) { d_exprm[prog].enableRewriteRuleSynth(); } + if (options::sygusQueryGen()) + { + d_exprm[prog].enableQueryGeneration(options::sygusQueryGenThresh()); + } + if (options::sygusSolFilterImplied()) + { + d_exprm[prog].enableFilterImpliedSolutions(); + } its = d_exprm.find(prog); } bool rew_print = false; @@ -1003,7 +1014,7 @@ void SynthConjecture::printSynthSolution(std::ostream& out) } if (!is_unique_term) { - ++(cei->d_statistics.d_candidate_rewrites); + ++(cei->d_statistics.d_filtered_solutions); } } if (is_unique_term) diff --git a/src/theory/quantifiers/sygus/synth_engine.cpp b/src/theory/quantifiers/sygus/synth_engine.cpp index ba227bc8f..479cfa535 100644 --- a/src/theory/quantifiers/sygus/synth_engine.cpp +++ b/src/theory/quantifiers/sygus/synth_engine.cpp @@ -426,17 +426,16 @@ SynthEngine::Statistics::Statistics() d_cegqi_lemmas_refine("SynthEngine::cegqi_lemmas_refine", 0), 
d_cegqi_si_lemmas("SynthEngine::cegqi_lemmas_si", 0), d_solutions("SynthConjecture::solutions", 0), - d_candidate_rewrites_print("SynthConjecture::candidate_rewrites_print", - 0), - d_candidate_rewrites("SynthConjecture::candidate_rewrites", 0) + d_filtered_solutions("SynthConjecture::filtered_solutions", 0), + d_candidate_rewrites_print("SynthConjecture::candidate_rewrites_print", 0) { smtStatisticsRegistry()->registerStat(&d_cegqi_lemmas_ce); smtStatisticsRegistry()->registerStat(&d_cegqi_lemmas_refine); smtStatisticsRegistry()->registerStat(&d_cegqi_si_lemmas); smtStatisticsRegistry()->registerStat(&d_solutions); + smtStatisticsRegistry()->registerStat(&d_filtered_solutions); smtStatisticsRegistry()->registerStat(&d_candidate_rewrites_print); - smtStatisticsRegistry()->registerStat(&d_candidate_rewrites); } SynthEngine::Statistics::~Statistics() @@ -445,8 +444,8 @@ SynthEngine::Statistics::~Statistics() smtStatisticsRegistry()->unregisterStat(&d_cegqi_lemmas_refine); smtStatisticsRegistry()->unregisterStat(&d_cegqi_si_lemmas); smtStatisticsRegistry()->unregisterStat(&d_solutions); + smtStatisticsRegistry()->unregisterStat(&d_filtered_solutions); smtStatisticsRegistry()->unregisterStat(&d_candidate_rewrites_print); - smtStatisticsRegistry()->unregisterStat(&d_candidate_rewrites); } } // namespace quantifiers diff --git a/src/theory/quantifiers/sygus/synth_engine.h b/src/theory/quantifiers/sygus/synth_engine.h index 8f0eea58f..a7346b888 100644 --- a/src/theory/quantifiers/sygus/synth_engine.h +++ b/src/theory/quantifiers/sygus/synth_engine.h @@ -100,8 +100,8 @@ class SynthEngine : public QuantifiersModule IntStat d_cegqi_lemmas_refine; IntStat d_cegqi_si_lemmas; IntStat d_solutions; + IntStat d_filtered_solutions; IntStat d_candidate_rewrites_print; - IntStat d_candidate_rewrites; Statistics(); ~Statistics(); }; /* class SynthEngine::Statistics */ diff --git a/src/theory/sets/.gitignore b/src/theory/sets/.gitignore deleted file mode 100644 index 
4c83ffd6f..000000000 --- a/src/theory/sets/.gitignore +++ /dev/null @@ -1 +0,0 @@ -README.WHATS-NEXT diff --git a/src/theory/strings/regexp_elim.cpp b/src/theory/strings/regexp_elim.cpp index 0310e4620..a0d806c52 100644 --- a/src/theory/strings/regexp_elim.cpp +++ b/src/theory/strings/regexp_elim.cpp @@ -50,11 +50,89 @@ Node RegExpElimination::eliminateConcat(Node atom) Node x = atom[0]; Node lenx = nm->mkNode(STRING_LENGTH, x); Node re = atom[1]; + std::vector<Node> children; + TheoryStringsRewriter::getConcat(re, children); + + // If it can be reduced to memberships in fixed length regular expressions. + // This includes concatenations where at most one child is of the form + // (re.* re.allchar), which we abbreviate _* below, and all other children + // have a fixed length. + // The intuition why this is a "non-aggressive" rewrite is that membership + // into fixed length regular expressions are easy to handle. + bool hasFixedLength = true; + // the index of _* in re + unsigned pivotIndex = 0; + bool hasPivotIndex = false; + std::vector<Node> childLengths; + std::vector<Node> childLengthsPostPivot; + for (unsigned i = 0, size = children.size(); i < size; i++) + { + Node c = children[i]; + Node fl = TheoryStringsRewriter::getFixedLengthForRegexp(c); + if (fl.isNull()) + { + if (!hasPivotIndex && c.getKind() == REGEXP_STAR + && c[0].getKind() == REGEXP_SIGMA) + { + hasPivotIndex = true; + pivotIndex = i; + // set to zero for the sum below + fl = d_zero; + } + else + { + hasFixedLength = false; + break; + } + } + childLengths.push_back(fl); + if (hasPivotIndex) + { + childLengthsPostPivot.push_back(fl); + } + } + if (hasFixedLength) + { + Assert(re.getNumChildren() == children.size()); + Node sum = nm->mkNode(PLUS, childLengths); + std::vector<Node> conc; + conc.push_back(nm->mkNode(hasPivotIndex ? 
GEQ : EQUAL, lenx, sum)); + Node currEnd = d_zero; + for (unsigned i = 0, size = childLengths.size(); i < size; i++) + { + if (hasPivotIndex && i == pivotIndex) + { + Node ppSum = childLengthsPostPivot.size() == 1 + ? childLengthsPostPivot[0] + : nm->mkNode(PLUS, childLengthsPostPivot); + currEnd = nm->mkNode(MINUS, lenx, ppSum); + } + else + { + Node curr = nm->mkNode(STRING_SUBSTR, x, currEnd, childLengths[i]); + Node currMem = nm->mkNode(STRING_IN_REGEXP, curr, re[i]); + conc.push_back(currMem); + currEnd = nm->mkNode(PLUS, currEnd, childLengths[i]); + currEnd = Rewriter::rewrite(currEnd); + } + } + Node res = nm->mkNode(AND, conc); + // For example: + // x in re.++(re.union(re.range("A", "J"), re.range("N", "Z")), "AB") --> + // len( x ) = 3 ^ + // substr(x,0,1) in re.union(re.range("A", "J"), re.range("N", "Z")) ^ + // substr(x,1,2) in "AB" + // An example with a pivot index: + // x in re.++( "AB" ++ _* ++ "C" ) --> + // len( x ) >= 3 ^ + // substr( x, 0, 2 ) in "AB" ^ + // substr( x, len( x ) - 1, 1 ) in "C" + return returnElim(atom, res, "concat-fixed-len"); + } + // memberships of the form x in re.++ * s1 * ... * sn *, where * are // any number of repetitions (exact or indefinite) of re.allchar. 
Trace("re-elim-debug") << "Try re concat with gaps " << atom << std::endl; - std::vector<Node> children; - TheoryStringsRewriter::getConcat(re, children); bool onlySigmasAndConsts = true; std::vector<Node> sep_children; std::vector<unsigned> gap_minsize; diff --git a/src/theory/strings/theory_strings_preprocess.cpp b/src/theory/strings/theory_strings_preprocess.cpp index bdb339324..fcb02d058 100644 --- a/src/theory/strings/theory_strings_preprocess.cpp +++ b/src/theory/strings/theory_strings_preprocess.cpp @@ -37,6 +37,7 @@ StringsPreprocess::StringsPreprocess(SkolemCache *sc, context::UserContext *u) //Constants d_zero = NodeManager::currentNM()->mkConst(Rational(0)); d_one = NodeManager::currentNM()->mkConst(Rational(1)); + d_neg_one = NodeManager::currentNM()->mkConst(Rational(-1)); d_empty_str = NodeManager::currentNM()->mkConst(String("")); } @@ -257,104 +258,90 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) { // enforces that int.to.str( n ) has no leading zeroes. 
retNode = itost; } else if( t.getKind() == kind::STRING_STOI ) { - Node str = t[0]; - Node pret = nm->mkSkolem("stoit", nm->integerType(), "created for stoi"); - Node lenp = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, str); - - Node negone = NodeManager::currentNM()->mkConst( ::CVC4::Rational(-1) ); - Node one = NodeManager::currentNM()->mkConst( ::CVC4::Rational(1) ); - Node nine = NodeManager::currentNM()->mkConst( ::CVC4::Rational(9) ); - Node ten = NodeManager::currentNM()->mkConst( ::CVC4::Rational(10) ); + Node s = t[0]; + Node stoit = nm->mkSkolem("stoit", nm->integerType(), "created for stoi"); + Node lens = nm->mkNode(STRING_LENGTH, s); + + std::vector<Node> conc1; + Node lem = stoit.eqNode(d_neg_one); + conc1.push_back(lem); + + Node sEmpty = s.eqNode(d_empty_str); + Node k = nm->mkSkolem("k", nm->integerType()); + Node kc1 = nm->mkNode(GEQ, k, d_zero); + Node kc2 = nm->mkNode(LT, k, lens); + Node c0 = nm->mkNode(STRING_CODE, nm->mkConst(String("0"))); + Node codeSk = nm->mkNode( + MINUS, + nm->mkNode(STRING_CODE, nm->mkNode(STRING_SUBSTR, s, k, d_one)), + c0); + Node ten = nm->mkConst(Rational(10)); + Node kc3 = nm->mkNode( + OR, nm->mkNode(LT, codeSk, d_zero), nm->mkNode(GEQ, codeSk, ten)); + conc1.push_back(nm->mkNode(OR, sEmpty, nm->mkNode(AND, kc1, kc2, kc3))); + + std::vector<Node> conc2; std::vector< TypeNode > argTypes; - argTypes.push_back(NodeManager::currentNM()->integerType()); - Node ufP = NodeManager::currentNM()->mkSkolem("ufP", - NodeManager::currentNM()->mkFunctionType( - argTypes, NodeManager::currentNM()->integerType()), - "uf type conv P"); - Node ufM = NodeManager::currentNM()->mkSkolem("ufM", - NodeManager::currentNM()->mkFunctionType( - argTypes, NodeManager::currentNM()->integerType()), - "uf type conv M"); - - //Node ufP0 = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufP, d_zero); - //new_nodes.push_back(pret.eqNode(ufP0)); - //lemma - Node lem = NodeManager::currentNM()->mkNode(kind::IMPLIES, - 
str.eqNode(NodeManager::currentNM()->mkConst(::CVC4::String(""))), - pret.eqNode(negone)); + argTypes.push_back(nm->integerType()); + Node u = nm->mkSkolem("U", nm->mkFunctionType(argTypes, nm->integerType())); + Node us = + nm->mkSkolem("Us", nm->mkFunctionType(argTypes, nm->stringType())); + Node ud = + nm->mkSkolem("Ud", nm->mkFunctionType(argTypes, nm->stringType())); + + lem = stoit.eqNode(nm->mkNode(APPLY_UF, u, lens)); + conc2.push_back(lem); + + lem = d_zero.eqNode(nm->mkNode(APPLY_UF, u, d_zero)); + conc2.push_back(lem); + + lem = d_empty_str.eqNode(nm->mkNode(APPLY_UF, us, lens)); + conc2.push_back(lem); + + lem = s.eqNode(nm->mkNode(APPLY_UF, us, d_zero)); + conc2.push_back(lem); + + Node x = nm->mkBoundVar(nm->integerType()); + Node xbv = nm->mkNode(BOUND_VAR_LIST, x); + Node g = + nm->mkNode(AND, nm->mkNode(GEQ, x, d_zero), nm->mkNode(LT, x, lens)); + Node udx = nm->mkNode(APPLY_UF, ud, x); + Node ux = nm->mkNode(APPLY_UF, u, x); + Node ux1 = nm->mkNode(APPLY_UF, u, nm->mkNode(PLUS, x, d_one)); + Node c = nm->mkNode(MINUS, nm->mkNode(STRING_CODE, udx), c0); + Node usx = nm->mkNode(APPLY_UF, us, x); + Node usx1 = nm->mkNode(APPLY_UF, us, nm->mkNode(PLUS, x, d_one)); + + Node eqs = usx.eqNode(nm->mkNode(STRING_CONCAT, udx, usx1)); + Node eq = ux1.eqNode(nm->mkNode(PLUS, c, nm->mkNode(MULT, ten, ux))); + Node cb = + nm->mkNode(AND, nm->mkNode(GEQ, c, d_zero), nm->mkNode(LT, c, ten)); + + lem = nm->mkNode(OR, g.negate(), nm->mkNode(AND, eqs, eq, cb)); + lem = nm->mkNode(FORALL, xbv, lem); + conc2.push_back(lem); + + Node sneg = nm->mkNode(LT, stoit, d_zero); + lem = nm->mkNode(ITE, sneg, nm->mkNode(AND, conc1), nm->mkNode(AND, conc2)); new_nodes.push_back(lem); - /*lem = NodeManager::currentNM()->mkNode(kind::EQUAL, - t[0].eqNode(NodeManager::currentNM()->mkConst(::CVC4::String("0"))), - t.eqNode(d_zero)); - new_nodes.push_back(lem);*/ - //cc1 - Node cc1 = str.eqNode(NodeManager::currentNM()->mkConst(::CVC4::String(""))); - //cc1 = 
NodeManager::currentNM()->mkNode(kind::AND, ufP0.eqNode(negone), cc1); - //cc2 - std::vector< Node > vec_n; - Node p = NodeManager::currentNM()->mkSkolem("p", NodeManager::currentNM()->integerType()); - Node g = NodeManager::currentNM()->mkNode(kind::GEQ, p, d_zero); - vec_n.push_back(g); - g = NodeManager::currentNM()->mkNode(kind::GT, lenp, p); - vec_n.push_back(g); - Node z2 = NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, str, p, one); - char chtmp[2]; - chtmp[1] = '\0'; - for(unsigned i=0; i<=9; i++) { - chtmp[0] = i + '0'; - std::string stmp(chtmp); - g = z2.eqNode( NodeManager::currentNM()->mkConst(::CVC4::String(stmp)) ).negate(); - vec_n.push_back(g); - } - Node cc2 = NodeManager::currentNM()->mkNode(kind::AND, vec_n); - //cc3 - Node b2 = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType()); - Node b2v = NodeManager::currentNM()->mkNode(kind::BOUND_VAR_LIST, b2); - Node g2 = NodeManager::currentNM()->mkNode(kind::AND, - NodeManager::currentNM()->mkNode(kind::GEQ, b2, d_zero), - NodeManager::currentNM()->mkNode(kind::GT, lenp, b2)); - Node ufx = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufP, b2); - Node ufx1 = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufP, NodeManager::currentNM()->mkNode(kind::MINUS,b2,one)); - Node ufMx = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufM, b2); - std::vector< Node > vec_c3; - std::vector< Node > vec_c3b; - //qx between 0 and 9 - Node c3cc = NodeManager::currentNM()->mkNode(kind::GEQ, ufMx, d_zero); - vec_c3b.push_back(c3cc); - c3cc = NodeManager::currentNM()->mkNode(kind::GEQ, nine, ufMx); - vec_c3b.push_back(c3cc); - Node sx = NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, str, b2, one); - for(unsigned i=0; i<=9; i++) { - chtmp[0] = i + '0'; - std::string stmp(chtmp); - c3cc = NodeManager::currentNM()->mkNode(kind::EQUAL, - ufMx.eqNode(NodeManager::currentNM()->mkConst(::CVC4::Rational(i))), - sx.eqNode(NodeManager::currentNM()->mkConst(::CVC4::String(stmp)))); - 
vec_c3b.push_back(c3cc); - } - //c312 - Node b2gtz = NodeManager::currentNM()->mkNode(kind::GT, b2, d_zero); - c3cc = NodeManager::currentNM()->mkNode(kind::IMPLIES, b2gtz, - ufx.eqNode(NodeManager::currentNM()->mkNode(kind::PLUS, - NodeManager::currentNM()->mkNode(kind::MULT, ufx1, ten), - ufMx))); - vec_c3b.push_back(c3cc); - c3cc = NodeManager::currentNM()->mkNode(kind::AND, vec_c3b); - c3cc = NodeManager::currentNM()->mkNode(kind::IMPLIES, g2, c3cc); - c3cc = NodeManager::currentNM()->mkNode(kind::FORALL, b2v, c3cc); - vec_c3.push_back(c3cc); - //unbound - c3cc = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufP, d_zero).eqNode(NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufM, d_zero)); - vec_c3.push_back(c3cc); - Node lstx = NodeManager::currentNM()->mkNode(kind::MINUS, lenp, one); - Node upflstx = NodeManager::currentNM()->mkNode(kind::APPLY_UF, ufP, lstx); - c3cc = upflstx.eqNode(pret); - vec_c3.push_back(c3cc); - Node cc3 = NodeManager::currentNM()->mkNode(kind::AND, vec_c3); - Node conc = NodeManager::currentNM()->mkNode(kind::ITE, pret.eqNode(negone), - NodeManager::currentNM()->mkNode(kind::OR, cc1, cc2), cc3); - new_nodes.push_back( conc ); - retNode = pret; + + // assert: + // IF stoit < 0 + // THEN: + // stoit = -1 ^ + // ( s = "" OR + // ( k>=0 ^ k<len( s ) ^ ( str.code( str.substr( s, k, 1 ) ) < 48 OR + // str.code( str.substr( s, k, 1 ) ) >= 58 ))) + // ELSE: + // stoit = U( len( s ) ) ^ U( 0 ) = 0 ^ + // "" = Us( len( s ) ) ^ s = Us( 0 ) ^ + // forall x. 
(x>=0 ^ x < str.len(s)) => + // Us( x ) = Ud( x ) ++ Us( x+1 ) ^ + // U( x+1 ) = ( str.code( Ud( x ) ) - 48 ) + 10*U( x ) + // 48 <= str.code( Ud( x ) ) < 58 + // Thus, str.to.int( s ) = stoit + + retNode = stoit; } else if (t.getKind() == kind::STRING_STRREPL) { diff --git a/src/theory/strings/theory_strings_preprocess.h b/src/theory/strings/theory_strings_preprocess.h index c670a5483..ff0195dc1 100644 --- a/src/theory/strings/theory_strings_preprocess.h +++ b/src/theory/strings/theory_strings_preprocess.h @@ -68,6 +68,7 @@ private: /** commonly used constants */ Node d_zero; Node d_one; + Node d_neg_one; Node d_empty_str; /** pointer to the skolem cache used by this class */ SkolemCache *d_sc; diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index e8a11e62e..5ba9d6e3f 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -507,9 +507,7 @@ Node TheoryStringsRewriter::rewriteStrEqualityExt(Node node) } // (= "" (str.replace x "A" "")) ---> (str.prefix x "A") - Node one = nm->mkConst(Rational(1)); - Node ylen = nm->mkNode(STRING_LENGTH, ne[1]); - if (checkEntailArithEq(ylen, one) && ne[2] == empty) + if (checkEntailLengthOne(ne[1]) && ne[2] == empty) { Node ret = nm->mkNode(STRING_PREFIX, ne[0], ne[1]); return returnRewrite(node, ret, "str-emp-repl-emp"); @@ -577,6 +575,21 @@ Node TheoryStringsRewriter::rewriteStrEqualityExt(Node node) return returnRewrite(node, ret, "str-eq-repl-not-ctn"); } } + + // (= (str.replace x y z) z) --> (or (= x y) (= x z)) + // if (str.len y) = (str.len z) + if (repl[2] == x) + { + Node lenY = nm->mkNode(STRING_LENGTH, repl[1]); + Node lenZ = nm->mkNode(STRING_LENGTH, repl[2]); + if (checkEntailArithEq(lenY, lenZ)) + { + Node ret = nm->mkNode(OR, + nm->mkNode(EQUAL, repl[0], repl[1]), + nm->mkNode(EQUAL, repl[0], repl[2])); + return returnRewrite(node, ret, "str-eq-repl-to-dis"); + } + } } } @@ -1658,11 +1671,8 @@ Node 
TheoryStringsRewriter::rewriteSubstr(Node node) // if (str.len y) = 1 and (str.len z) = 1 if (node[1] == zero) { - Node one = nm->mkConst(Rational(1)); - Node n1len = nm->mkNode(kind::STRING_LENGTH, node[0][1]); - Node n2len = nm->mkNode(kind::STRING_LENGTH, node[0][2]); - if (checkEntailArith(one, n1len) && checkEntailArith(one, n2len) - && checkEntailNonEmpty(node[0][1]) && checkEntailNonEmpty(node[0][2])) + if (checkEntailLengthOne(node[0][1], true) + && checkEntailLengthOne(node[0][2], true)) { Node ret = nm->mkNode( kind::STRING_STRREPL, @@ -1738,9 +1748,17 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node) return returnRewrite(node, ret, "ss-start-entails-zero-len"); } + // (str.substr s x y) --> "" if 0 < y |= x >= str.len(s) + Node non_zero_len = + Rewriter::rewrite(nm->mkNode(kind::LT, zero, node[2])); + if (checkEntailArithWithAssumption(non_zero_len, node[1], tot_len, false)) + { + Node ret = nm->mkConst(::CVC4::String("")); + return returnRewrite(node, ret, "ss-non-zero-len-entails-oob"); + } + // (str.substr s x x) ---> "" if (str.len s) <= 1 - Node one = nm->mkConst(CVC4::Rational(1)); - if (node[1] == node[2] && checkEntailArith(one, tot_len)) + if (node[1] == node[2] && checkEntailLengthOne(node[0])) { Node ret = nm->mkConst(::CVC4::String("")); return returnRewrite(node, ret, "ss-len-one-z-z"); @@ -2155,8 +2173,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) { // (str.contains (str.replace x y x) z) ---> (str.contains x z) // if (str.len z) <= 1 - Node one = nm->mkConst(Rational(1)); - if (checkEntailArith(one, len_n2)) + if (checkEntailLengthOne(node[1])) { Node ret = nm->mkNode(kind::STRING_STRCTN, node[0][0], node[1]); return returnRewrite(node, ret, "ctn-repl-len-one-to-ctn"); @@ -2172,6 +2189,24 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) { nm->mkNode(STRING_STRCTN, node[0][0], node[0][2])); return returnRewrite(node, ret, "ctn-repl-to-ctn-disj"); } + + // (str.contains (str.replace x y z) w) ---> + // 
(str.contains (str.replace x y "") w) + // if (str.contains z w) ---> false and (str.len w) = 1 + if (checkEntailLengthOne(node[1])) + { + Node ctn = Rewriter::rewrite( + nm->mkNode(kind::STRING_STRCTN, node[1], node[0][2])); + if (ctn.isConst() && !ctn.getConst<bool>()) + { + Node empty = nm->mkConst(String("")); + Node ret = nm->mkNode( + kind::STRING_STRCTN, + nm->mkNode(kind::STRING_STRREPL, node[0][0], node[0][1], empty), + node[1]); + return returnRewrite(node, ret, "ctn-repl-simp-repl"); + } + } } if (node[1].getKind() == kind::STRING_STRREPL) @@ -2474,7 +2509,7 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) { // (str.replace x y x) ---> (str.replace x (str.++ y1 ... yn) x) // if 1 >= (str.len x) and (= y "") ---> (= y1 "") ... (= yn "") - if (checkEntailArith(nm->mkConst(Rational(1)), l0)) + if (checkEntailLengthOne(node[0])) { Node empty = nm->mkConst(String("")); Node rn1 = Rewriter::rewrite( @@ -2845,6 +2880,53 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) { } } } + // miniscope based on components that do not contribute to contains + // for example, + // str.replace( x ++ y ++ x ++ y, "A", z ) --> + // str.replace( x ++ y, "A", z ) ++ x ++ y + // since if "A" occurs in x ++ y ++ x ++ y, then it must occur in x ++ y. 
+ if (checkEntailLengthOne(node[1])) + { + Node lastLhs; + unsigned lastCheckIndex = 0; + for (unsigned i = 1, iend = children0.size(); i < iend; i++) + { + unsigned checkIndex = children0.size() - i; + std::vector<Node> checkLhs; + checkLhs.insert( + checkLhs.end(), children0.begin(), children0.begin() + checkIndex); + Node lhs = mkConcat(STRING_CONCAT, checkLhs); + Node rhs = children0[checkIndex]; + Node ctn = nm->mkNode(STRING_STRCTN, lhs, rhs); + ctn = Rewriter::rewrite(ctn); + if (ctn.isConst() && ctn.getConst<bool>()) + { + lastLhs = lhs; + lastCheckIndex = checkIndex; + } + else + { + break; + } + } + if (!lastLhs.isNull()) + { + std::vector<Node> remc(children0.begin() + lastCheckIndex, + children0.end()); + Node rem = mkConcat(STRING_CONCAT, remc); + Node ret = + nm->mkNode(STRING_CONCAT, + nm->mkNode(STRING_STRREPL, lastLhs, node[1], node[2]), + rem); + // for example: + // str.replace( x ++ x, "A", y ) ---> str.replace( x, "A", y ) ++ x + // Since we know that the first occurrence of "A" cannot be in the + // second occurrence of x. Notice this is specific to single characters + // due to complications with finds that span multiple components for + // non-characters. + return returnRewrite(node, ret, "repl-char-ncontrib-find"); + } + } // TODO (#1180) incorporate these? 
// contains( t, s ) => @@ -3779,6 +3861,15 @@ bool TheoryStringsRewriter::checkEntailNonEmpty(Node a) return checkEntailArith(len, true); } +bool TheoryStringsRewriter::checkEntailLengthOne(Node s, bool strict) +{ + NodeManager* nm = NodeManager::currentNM(); + Node one = nm->mkConst(Rational(1)); + Node len = nm->mkNode(STRING_LENGTH, s); + len = Rewriter::rewrite(len); + return checkEntailArith(one, len) && (!strict || checkEntailArith(len, true)); +} + bool TheoryStringsRewriter::checkEntailArithEq(Node a, Node b) { if (a == b) @@ -4538,6 +4629,59 @@ Node TheoryStringsRewriter::getConstantArithBound(Node a, bool isLower) return ret; } +Node TheoryStringsRewriter::getFixedLengthForRegexp(Node n) +{ + NodeManager* nm = NodeManager::currentNM(); + if (n.getKind() == STRING_TO_REGEXP) + { + Node ret = nm->mkNode(STRING_LENGTH, n[0]); + ret = Rewriter::rewrite(ret); + if (ret.isConst()) + { + return ret; + } + } + else if (n.getKind() == REGEXP_SIGMA || n.getKind() == REGEXP_RANGE) + { + return nm->mkConst(Rational(1)); + } + else if (n.getKind() == REGEXP_UNION || n.getKind() == REGEXP_INTER) + { + Node ret; + for (const Node& nc : n) + { + Node flc = getFixedLengthForRegexp(nc); + if (flc.isNull() || (!ret.isNull() && ret != flc)) + { + return Node::null(); + } + else if (ret.isNull()) + { + // first time + ret = flc; + } + } + return ret; + } + else if (n.getKind() == REGEXP_CONCAT) + { + NodeBuilder<> nb(PLUS); + for (const Node& nc : n) + { + Node flc = getFixedLengthForRegexp(nc); + if (flc.isNull()) + { + return flc; + } + nb << flc; + } + Node ret = nb.constructNode(); + ret = Rewriter::rewrite(ret); + return ret; + } + return Node::null(); +} + bool TheoryStringsRewriter::checkEntailArithInternal(Node a) { Assert(Rewriter::rewrite(a) == a); @@ -4614,8 +4758,7 @@ Node TheoryStringsRewriter::getStringOrEmpty(Node n) break; } - Node strlen = Rewriter::rewrite(nm->mkNode(kind::STRING_LENGTH, n[0])); - if (strlen == nm->mkConst(Rational(1)) && n[2] == empty) + if 
(checkEntailLengthOne(n[0]) && n[2] == empty) { // (str.replace "A" x "") --> "A" res = n[0]; @@ -4627,8 +4770,7 @@ Node TheoryStringsRewriter::getStringOrEmpty(Node n) } case kind::STRING_SUBSTR: { - Node strlen = Rewriter::rewrite(nm->mkNode(kind::STRING_LENGTH, n[0])); - if (strlen == nm->mkConst(Rational(1))) + if (checkEntailLengthOne(n[0])) { // (str.substr "A" x y) --> "A" res = n[0]; diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index 2c38ce8dc..2e356f8f7 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -455,6 +455,19 @@ class TheoryStringsRewriter { * the call checkArithEntail( len( a ), true ). */ static bool checkEntailNonEmpty(Node a); + + /** + * Checks whether string has at most/exactly length one. Length one strings + * can be used for more aggressive rewriting because there is guaranteed that + * it cannot be overlap multiple components in a string concatenation. + * + * @param s The string to check + * @param strict If true, the string must have exactly length one, otherwise + * at most length one + * @return True if the string has at most/exactly length one, false otherwise + */ + static bool checkEntailLengthOne(Node s, bool strict = false); + /** check arithmetic entailment equal * Returns true if it is always the case that a = b. */ @@ -566,6 +579,12 @@ class TheoryStringsRewriter { * checkEntailArith( a, strict ) = true. */ static Node getConstantArithBound(Node a, bool isLower = true); + /** get length for regular expression + * + * Given regular expression n, if this method returns a non-null value c, then + * x in n entails len( x ) = c. + */ + static Node getFixedLengthForRegexp(Node n); /** decompose substr chain * * If s is substr( ... 
substr( base, x1, y1 ) ..., xn, yn ), then this diff --git a/src/theory/subs_minimize.cpp b/src/theory/subs_minimize.cpp index 03a55b3a4..58daf5c75 100644 --- a/src/theory/subs_minimize.cpp +++ b/src/theory/subs_minimize.cpp @@ -14,6 +14,7 @@ #include "theory/subs_minimize.h" +#include "expr/node_algorithm.h" #include "theory/bv/theory_bv_utils.h" #include "theory/rewriter.h" @@ -25,20 +26,157 @@ namespace theory { SubstitutionMinimize::SubstitutionMinimize() {} -bool SubstitutionMinimize::find(Node n, +bool SubstitutionMinimize::find(Node t, Node target, const std::vector<Node>& vars, const std::vector<Node>& subs, std::vector<Node>& reqVars) { + return findInternal(t, target, vars, subs, reqVars); +} + +void getConjuncts(Node n, std::vector<Node>& conj) +{ + if (n.getKind() == AND) + { + for (const Node& nc : n) + { + conj.push_back(nc); + } + } + else + { + conj.push_back(n); + } +} + +bool SubstitutionMinimize::findWithImplied(Node t, + const std::vector<Node>& vars, + const std::vector<Node>& subs, + std::vector<Node>& reqVars, + std::vector<Node>& impliedVars) +{ + NodeManager* nm = NodeManager::currentNM(); + Node truen = nm->mkConst(true); + if (!findInternal(t, truen, vars, subs, reqVars)) + { + return false; + } + if (reqVars.empty()) + { + return true; + } + + // map from conjuncts of t to whether they may be used to show an implied var + std::vector<Node> tconj; + getConjuncts(t, tconj); + // map from conjuncts to their free symbols + std::map<Node, std::unordered_set<Node, NodeHashFunction> > tcFv; + + std::unordered_set<Node, NodeHashFunction> reqSet; + std::vector<Node> reqSubs; + std::map<Node, unsigned> reqVarToIndex; + for (const Node& v : reqVars) + { + reqVarToIndex[v] = reqSubs.size(); + const std::vector<Node>::const_iterator& it = + std::find(vars.begin(), vars.end(), v); + Assert(it != vars.end()); + ptrdiff_t pos = std::distance(vars.begin(), it); + reqSubs.push_back(subs[pos]); + } + std::vector<Node> finalReqVars; + for (const Node& v : 
vars) + { + if (reqVarToIndex.find(v) == reqVarToIndex.end()) + { + // not a required variable, nothing to do + continue; + } + unsigned vindex = reqVarToIndex[v]; + Node prev = reqSubs[vindex]; + // make identity substitution + reqSubs[vindex] = v; + bool madeImplied = false; + // it is a required variable, can we make an implied variable? + for (const Node& tc : tconj) + { + // ensure we've computed its free symbols + std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator + itf = tcFv.find(tc); + if (itf == tcFv.end()) + { + expr::getSymbols(tc, tcFv[tc]); + itf = tcFv.find(tc); + } + // only have a chance if contains v + if (itf->second.find(v) == itf->second.end()) + { + continue; + } + // try the current substitution + Node tcs = tc.substitute( + reqVars.begin(), reqVars.end(), reqSubs.begin(), reqSubs.end()); + Node tcsr = Rewriter::rewrite(tcs); + std::vector<Node> tcsrConj; + getConjuncts(tcsr, tcsrConj); + for (const Node& tcc : tcsrConj) + { + if (tcc.getKind() == EQUAL) + { + for (unsigned r = 0; r < 2; r++) + { + if (tcc[r] == v) + { + Node res = tcc[1 - r]; + if (res.isConst()) + { + Assert(res == prev); + madeImplied = true; + break; + } + } + } + } + if (madeImplied) + { + break; + } + } + if (madeImplied) + { + break; + } + } + if (!madeImplied) + { + // revert the substitution + reqSubs[vindex] = prev; + finalReqVars.push_back(v); + } + else + { + impliedVars.push_back(v); + } + } + reqVars.clear(); + reqVars.insert(reqVars.end(), finalReqVars.begin(), finalReqVars.end()); + + return true; +} + +bool SubstitutionMinimize::findInternal(Node n, + Node target, + const std::vector<Node>& vars, + const std::vector<Node>& subs, + std::vector<Node>& reqVars) +{ Trace("subs-min") << "Substitution minimize : " << std::endl; Trace("subs-min") << " substitution : " << vars << " -> " << subs << std::endl; Trace("subs-min") << " node : " << n << std::endl; Trace("subs-min") << " target : " << target << std::endl; - std::map<Node, 
std::unordered_set<Node, NodeHashFunction> > fvDepend; - Trace("subs-min") << "--- Compute values for subterms..." << std::endl; // the value of each subterm in n under the substitution std::unordered_map<TNode, Node, TNodeHashFunction> value; @@ -124,8 +262,6 @@ bool SubstitutionMinimize::find(Node n, Trace("subs-min") << "--- Compute relevant variables..." << std::endl; std::unordered_set<Node, NodeHashFunction> rlvFv; // only variables that occur in assertions are relevant - std::map<Node, unsigned> iteBranch; - std::map<Node, std::vector<unsigned> > justifyArgs; visit.push_back(n); std::unordered_set<TNode, TNodeHashFunction> visited; diff --git a/src/theory/subs_minimize.h b/src/theory/subs_minimize.h index 55e57b921..bf6ccffae 100644 --- a/src/theory/subs_minimize.h +++ b/src/theory/subs_minimize.h @@ -36,21 +36,55 @@ class SubstitutionMinimize ~SubstitutionMinimize() {} /** find * - * If n { vars -> subs } rewrites to target, this method returns true, and - * vars[i1], ..., vars[in] are added to rewVars, such that - * n { vars[i_1] -> subs[i_1], ..., vars[i_n] -> subs[i_n] } also rewrites to - * target. + * If t { vars -> subs } rewrites to target, this method returns true, and + * vars[i_1], ..., vars[i_n] are added to reqVars, such that i_1, ..., i_n are + * distinct, and t { vars[i_1] -> subs[i_1], ..., vars[i_n] -> subs[i_n] } + * rewrites to target. * - * If n { vars -> subs } does not rewrite to target, this method returns + * If t { vars -> subs } does not rewrite to target, this method returns * false. */ - static bool find(Node n, + static bool find(Node t, Node target, const std::vector<Node>& vars, const std::vector<Node>& subs, std::vector<Node>& reqVars); + /** find with implied + * + * This method should be called on a formula t. 
+ * + * If t { vars -> subs } rewrites to true, this method returns true, + * vars[i_1], ..., vars[i_n] are added to reqVars, and + * vars[i_{n+1}], ..., vars[i_{n+m}] are added to impliedVars such that + * i_1...i_{n+m} are distinct, i_{n+1} < ... < i_{n+m}, and: + * + * (1) t { vars[i_1]->subs[i_1], ..., vars[i_{n+k}]->subs[i_{n+k}] } implies + * vars[i_{n+k+1}] = subs[i_{n+k+1}] for k = 0, ..., m-1. + * + * (2) t { vars[i_1] -> subs[i_1], ..., vars[i_{n+m}] -> subs[i_{n+m}] } + * rewrites to true. + * + * For example, given (x>0 ^ x = y ^ y = z){ x -> 1, y -> 1, z -> 1, w -> 0 }, + * this method may add { x } to reqVars, and { y, z } to impliedVars. + * + * Notice that the order of variables in vars matters. By the semantics above, + * variables that appear earlier in the variable list vars are more likely + * to appear in reqVars, whereas those later in the vars are more likely to + * appear in impliedVars. + */ + static bool findWithImplied(Node t, + const std::vector<Node>& vars, + const std::vector<Node>& subs, + std::vector<Node>& reqVars, + std::vector<Node>& impliedVars); private: + /** Common helper function for the above functions. */ + static bool findInternal(Node t, + Node target, + const std::vector<Node>& vars, + const std::vector<Node>& subs, + std::vector<Node>& reqVars); /** is singular arg * * Returns true if |