diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2020-08-18 14:13:17 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-18 14:13:17 -0700 |
commit | 005b6458d3340dd805279eb4c442d2871d75c844 (patch) | |
tree | 137249805212631349ae0f28b77193bbdd321af5 | |
parent | 8e66e8dfb1a7f54c8982141bdbbc1dfe914e5900 (diff) | |
parent | aa8da1ff4e7f119408dbf14074b9a5efcb06618b (diff) |
Merge branch 'master' into regDisableProofs
33 files changed, 2637 insertions, 672 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca873c294..2d3586483 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,7 +62,9 @@ jobs: libcln-dev \ libgmp-dev \ libedit-dev \ - swig3.0 + flex \ + libfl-dev \ + flexc++ python3 -m pip install toml python3 -m pip install setuptools python3 -m pip install pexpect @@ -78,7 +80,7 @@ jobs: cln \ gmp \ pkgconfig \ - swig + flex python3 -m pip install toml python3 -m pip install setuptools python3 -m pip install pexpect diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1d54573e9..48bd99f44 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -271,6 +271,8 @@ libcvc4_add_sources( smt/smt_solver.h smt/smt_statistics_registry.cpp smt/smt_statistics_registry.h + smt/sygus_solver.cpp + smt/sygus_solver.h smt/term_formula_removal.cpp smt/term_formula_removal.h smt/update_ostream.h @@ -695,6 +697,8 @@ libcvc4_add_sources( theory/quantifiers/theory_quantifiers_type_rules.h theory/quantifiers_engine.cpp theory/quantifiers_engine.h + theory/relevance_manager.cpp + theory/relevance_manager.h theory/rep_set.cpp theory/rep_set.h theory/rewriter.cpp @@ -814,6 +818,8 @@ libcvc4_add_sources( theory/theory_rewriter.cpp theory/theory_rewriter.h theory/theory_registrar.h + theory/theory_state.cpp + theory/theory_state.h theory/theory_test_utils.h theory/trust_node.cpp theory/trust_node.h diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 5ccb4c6c1..150f84301 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -5157,7 +5157,7 @@ Term Solver::mkSygusVar(Sort sort, const std::string& symbol) const Expr res = d_exprMgr->mkBoundVar(symbol, *sort.d_type); (void)res.getType(true); /* kick off type checking */ - d_smtEngine->declareSygusVar(symbol, res, *sort.d_type); + d_smtEngine->declareSygusVar(symbol, res, TypeNode::fromType(*sort.d_type)); return Term(this, res); @@ -5279,14 +5279,21 @@ Term Solver::synthFunHelper(const std::string& symbol, ? *sort.d_type : d_exprMgr->mkFunctionType(varTypes, *sort.d_type); - Expr fun = d_exprMgr->mkBoundVar(symbol, funType); + Node fun = getNodeManager()->mkBoundVar(symbol, TypeNode::fromType(funType)); (void)fun.getType(true); /* kick off type checking */ - d_smtEngine->declareSynthFun(symbol, - fun, - g == nullptr ? funType : *g->resolve().d_type, - isInv, - termVectorToExprs(boundVars)); + std::vector<Node> bvns; + for (const Term& t : boundVars) + { + bvns.push_back(*t.d_node); + } + + d_smtEngine->declareSynthFun( + symbol, + fun, + TypeNode::fromType(g == nullptr ? funType : *g->resolve().d_type), + isInv, + bvns); return Term(this, fun); @@ -5373,13 +5380,12 @@ Term Solver::getSynthSolution(Term term) const CVC4_API_ARG_CHECK_NOT_NULL(term); CVC4_API_SOLVER_CHECK_TERM(term); - std::map<CVC4::Expr, CVC4::Expr> map; + std::map<CVC4::Node, CVC4::Node> map; CVC4_API_CHECK(d_smtEngine->getSynthSolutions(map)) << "The solver is not in a state immediately preceeded by a " "successful call to checkSynth"; - std::map<CVC4::Expr, CVC4::Expr>::const_iterator it = - map.find(term.d_node->toExpr()); + std::map<CVC4::Node, CVC4::Node>::const_iterator it = map.find(*term.d_node); CVC4_API_CHECK(it != map.cend()) << "Synth solution not found for given term"; @@ -5403,7 +5409,7 @@ std::vector<Term> Solver::getSynthSolutions( << "non-null term"; } - std::map<CVC4::Expr, CVC4::Expr> map; + std::map<CVC4::Node, CVC4::Node> map; CVC4_API_CHECK(d_smtEngine->getSynthSolutions(map)) << "The solver is not in a state immediately preceeded by a " "successful call to checkSynth"; @@ -5413,8 +5419,8 @@ std::vector<Term> Solver::getSynthSolutions( for (size_t i = 0, n = terms.size(); i < n; ++i) { - std::map<CVC4::Expr, CVC4::Expr>::const_iterator it = - map.find(terms[i].d_node->toExpr()); + std::map<CVC4::Node, CVC4::Node>::const_iterator it = + map.find(*terms[i].d_node); CVC4_API_CHECK(it != map.cend()) << "Synth solution not found for term at index " << i; diff --git a/src/expr/CMakeLists.txt b/src/expr/CMakeLists.txt index 993df2594..0092b78c6 100644 --- a/src/expr/CMakeLists.txt +++ b/src/expr/CMakeLists.txt @@ -61,6 +61,8 @@ libcvc4_add_sources( symbol_table.h term_canonize.cpp term_canonize.h + term_context.cpp + term_context.h term_conversion_proof_generator.cpp term_conversion_proof_generator.h type.cpp diff --git a/src/expr/term_context.cpp b/src/expr/term_context.cpp new file mode 100644 index 000000000..6dcdc25ee --- /dev/null +++ b/src/expr/term_context.cpp @@ -0,0 +1,113 @@ +/********************* */ +/*! \file term_context.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of term context utilities. + **/ + +#include "expr/term_context.h" + +namespace CVC4 { + +uint32_t RtfTermContext::initialValue() const +{ + // by default, not in a term context or a quantifier + return 0; +} + +uint32_t RtfTermContext::computeValue(TNode t, + uint32_t tval, + size_t child) const +{ + if (t.isClosure()) + { + if (tval % 2 == 0) + { + return tval + 1; + } + } + else if (hasNestedTermChildren(t)) + { + if (tval < 2) + { + return tval + 2; + } + } + return tval; +} + +uint32_t RtfTermContext::getValue(bool inQuant, bool inTerm) +{ + return (inQuant ? 1 : 0) + 2 * (inTerm ? 1 : 0); +} + +void RtfTermContext::getFlags(uint32_t val, bool& inQuant, bool& inTerm) +{ + inQuant = val % 2 == 1; + inTerm = val >= 2; +} + +bool RtfTermContext::hasNestedTermChildren(TNode t) +{ + Kind k = t.getKind(); + // dont' worry about FORALL or EXISTS, these are part of inQuant. + return theory::kindToTheoryId(k) != theory::THEORY_BOOL && k != kind::EQUAL + && k != kind::SEP_STAR && k != kind::SEP_WAND && k != kind::SEP_LABEL + && k != kind::BITVECTOR_EAGER_ATOM; +} + +uint32_t PolarityTermContext::initialValue() const +{ + // by default, we have true polarity + return 2; +} + +uint32_t PolarityTermContext::computeValue(TNode t, + uint32_t tval, + size_t index) const +{ + switch (t.getKind()) + { + case kind::AND: + case kind::OR: + case kind::SEP_STAR: + // polarity preserved + return tval; + case kind::IMPLIES: + // first child reverses, otherwise we preserve + return index == 0 ? (tval == 0 ? 0 : (3 - tval)) : tval; + case kind::NOT: + // polarity reversed + return tval == 0 ? 0 : (3 - tval); + case kind::ITE: + // head has no polarity, branches preserve + return index == 0 ? 0 : tval; + case kind::FORALL: + // body preserves, others have no polarity. + return index == 1 ? tval : 0; + default: + // no polarity + break; + } + return 0; +} + +uint32_t PolarityTermContext::getValue(bool hasPol, bool pol) +{ + return hasPol ? (pol ? 2 : 1) : 0; +} + +void PolarityTermContext::getFlags(uint32_t val, bool& hasPol, bool& pol) +{ + hasPol = val == 0; + pol = val == 2; +} + +} // namespace CVC4 diff --git a/src/expr/term_context.h b/src/expr/term_context.h new file mode 100644 index 000000000..87f91f2df --- /dev/null +++ b/src/expr/term_context.h @@ -0,0 +1,145 @@ +/********************* */ +/*! \file term_context.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Term context utilities. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__EXPR__TERM_CONTEXT_H +#define CVC4__EXPR__TERM_CONTEXT_H + +#include "expr/node.h" + +namespace CVC4 { + +/** + * This is an abstract class for computing "term context identifiers". A term + * context identifier is a hash value that identifies some property of the + * context in which a term occurs. Common examples of the implementation of + * such a mapping are implemented in the subclasses below. + * + * A term context identifier is intended to be information that can be locally + * computed from the parent's hash, and hence does not rely on maintaining + * paths. + * + * In the below documentation, we write t @ [p] to a term at a given position, + * where p is a list of indices. For example, the atomic subterms of: + * (and P (not Q)) + * are P @ [0] and Q @ [1,0]. + */ +class TermContext +{ + public: + TermContext() {} + virtual ~TermContext() {} + /** The default initial value of root terms. */ + virtual uint32_t initialValue() const = 0; + /** + * Returns the term context identifier of the index^th child of t, where tval + * is the term context identifier of t. + */ + virtual uint32_t computeValue(TNode t, uint32_t tval, size_t index) const = 0; +}; + +/** + * Remove term formulas (rtf) term context. + * + * Computes whether we are inside a term (as opposed to being part of Boolean + * skeleton) and whether we are inside a quantifier. For example, for: + * (and (= a b) (forall ((x Int)) (P x))) + * we have the following mappings (term -> inTerm,inQuant) + * (= a b) @ [0] -> false, false + * a @ [0,1] -> true, false + * (P x) @ [1,1] -> false, true + * x @ [1,1,0] -> true, true + * Notice that the hash of a child can be computed from the parent's hash only, + * and hence this can be implemented as an instance of the abstract class. + */ +class RtfTermContext : public TermContext +{ + public: + RtfTermContext() {} + /** The initial value: not in a term context or beneath a quantifier. */ + uint32_t initialValue() const override; + /** Compute the value of the index^th child of t whose hash is tval */ + uint32_t computeValue(TNode t, uint32_t tval, size_t index) const override; + /** get hash value from the flags */ + static uint32_t getValue(bool inQuant, bool inTerm); + /** get flags from the hash value */ + static void getFlags(uint32_t val, bool& inQuant, bool& inTerm); + + private: + /** + * Returns true if the children of t should be considered in a "term" context, + * which is any context beneath a symbol that does not belong to the Boolean + * theory as well as other exceptions like equality, separation logic + * connectives and bit-vector eager atoms. + */ + static bool hasNestedTermChildren(TNode t); +}; + +/** + * Polarity term context. + * + * This class computes the polarity of a term-context-sensitive term, which is + * one of {true, false, none}. This corresponds to the value that can be + * assigned to that term while preservering satisfiability of the overall + * formula, or none if such a value does not exist. If not "none", this + * typically corresponds to whether the number of NOT the formula is beneath is + * even, although special cases exist (e.g. the first child of IMPLIES). + * + * For example, given the formula: + * (and P (not (= (f x) 0))) + * assuming the root of this formula has true polarity, we have that: + * P @ [0] -> true + * (not (= (f x) 0)) @ [1] -> true + * (= (f x) 0) @ [1,0] -> false + * (f x) @ [1,0,0]), x @ [1,0,0,0]), 0 @ [1,0,1] -> none + * + * Notice that a term-context-sensitive Node is not one-to-one with Node. + * In particular, given the formula: + * (and P (not P)) + * We have that the P at path [0] has polarity true and the P at path [1,0] has + * polarity false. + * + * Finally, notice that polarity does not correspond to a value that the + * formula entails. Thus, for the formula: + * (or P Q) + * we have that + * P @ [0] -> true + * Q @ [1] -> true + * although neither is entailed. + * + * Notice that the hash of a child can be computed from the parent's hash only. + */ +class PolarityTermContext : public TermContext +{ + public: + PolarityTermContext() {} + /** The initial value: true polarity. */ + uint32_t initialValue() const override; + /** Compute the value of the index^th child of t whose hash is tval */ + uint32_t computeValue(TNode t, uint32_t tval, size_t index) const override; + /** + * Get hash value from the flags, where hasPol false means no polarity. + */ + static uint32_t getValue(bool hasPol, bool pol); + /** + * get flags from the hash value. If we have no polarity, both hasPol and pol + * are set to false. + */ + static void getFlags(uint32_t val, bool& hasPol, bool& pol); +}; + +} // namespace CVC4 + +#endif /* CVC4__EXPR__TERM_CONVERSION_PROOF_GENERATOR_H */ diff --git a/src/preprocessing/passes/sygus_inference.cpp b/src/preprocessing/passes/sygus_inference.cpp index 7336ac159..d44321a35 100644 --- a/src/preprocessing/passes/sygus_inference.cpp +++ b/src/preprocessing/passes/sygus_inference.cpp @@ -313,25 +313,25 @@ bool SygusInference::solveSygus(std::vector<Node>& assertions, return false; } // get the synthesis solutions - std::map<Expr, Expr> synth_sols; + std::map<Node, Node> synth_sols; rrSygus->getSynthSolutions(synth_sols); std::vector<Node> final_ff; std::vector<Node> final_ff_sol; - for (std::map<Expr, Expr>::iterator it = synth_sols.begin(); + for (std::map<Node, Node>::iterator it = synth_sols.begin(); it != synth_sols.end(); ++it) { Trace("sygus-infer") << " synth sol : " << it->first << " -> " << it->second << std::endl; - Node ffv = Node::fromExpr(it->first); + Node ffv = it->first; std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv); // all synthesis solutions should correspond to a variable we introduced Assert(itffv != ff_var_to_ff.end()); if (itffv != ff_var_to_ff.end()) { Node ff = itffv->second; - Node body2 = Node::fromExpr(it->second); + Node body2 = it->second; Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl; funs.push_back(ff); sols.push_back(body2); diff --git a/src/smt/abduction_solver.cpp b/src/smt/abduction_solver.cpp index 01e2a4f0f..2a6346c18 100644 --- a/src/smt/abduction_solver.cpp +++ b/src/smt/abduction_solver.cpp @@ -89,16 +89,15 @@ bool AbductionSolver::getAbductInternal(Node& abd) if (r.asSatisfiabilityResult().isSat() == Result::UNSAT) { // get the synthesis solution - std::map<Expr, Expr> sols; + std::map<Node, Node> sols; d_subsolver->getSynthSolutions(sols); Assert(sols.size() == 1); - Expr essf = d_sssf.toExpr(); - std::map<Expr, Expr>::iterator its = sols.find(essf); + std::map<Node, Node>::iterator its = sols.find(d_sssf); if (its != sols.end()) { Trace("sygus-abduct") << "SmtEngine::getAbduct: solution is " << its->second << std::endl; - abd = Node::fromExpr(its->second); + abd = its->second; if (abd.getKind() == kind::LAMBDA) { abd = abd[1]; diff --git a/src/smt/command.cpp b/src/smt/command.cpp index f5c997318..2383167a6 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -577,7 +577,8 @@ void DeclareSygusVarCommand::invoke(SmtEngine* smtEngine) { try { - smtEngine->declareSygusVar(d_symbol, d_var, d_type); + smtEngine->declareSygusVar( + d_symbol, Node::fromExpr(d_var), TypeNode::fromType(d_type)); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -622,7 +623,8 @@ void DeclareSygusFunctionCommand::invoke(SmtEngine* smtEngine) { try { - smtEngine->declareSygusFunctionVar(d_symbol, d_func, d_type); + smtEngine->declareSygusVar( + d_symbol, Node::fromExpr(d_func), TypeNode::fromType(d_type)); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) @@ -687,13 +689,19 @@ void SynthFunCommand::invoke(SmtEngine* smtEngine) { try { - smtEngine->declareSynthFun(d_symbol, - d_fun.getExpr(), - d_grammar == nullptr - ? d_sort.getType() - : d_grammar->resolve().getType(), - d_isInv, - api::termVectorToExprs(d_vars)); + std::vector<Node> vns; + for (const api::Term& t : d_vars) + { + vns.push_back(Node::fromExpr(t.getExpr())); + } + smtEngine->declareSynthFun( + d_symbol, + Node::fromExpr(d_fun.getExpr()), + TypeNode::fromType(d_grammar == nullptr + ? d_sort.getType() + : d_grammar->resolve().getType()), + d_isInv, + vns); d_commandStatus = CommandSuccess::instance(); } catch (exception& e) diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp new file mode 100644 index 000000000..5046dee92 --- /dev/null +++ b/src/smt/proof_post_processor.cpp @@ -0,0 +1,572 @@ +/********************* */ +/*! \file proof_post_processor.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of module for processing proof nodes + **/ + +#include "smt/proof_post_processor.h" + +#include "expr/skolem_manager.h" +#include "options/smt_options.h" +#include "preprocessing/assertion_pipeline.h" +#include "smt/smt_engine.h" +#include "smt/smt_statistics_registry.h" +#include "theory/builtin/proof_checker.h" +#include "theory/rewriter.h" + +using namespace CVC4::kind; +using namespace CVC4::theory; + +namespace CVC4 { +namespace smt { + +ProofPostprocessCallback::ProofPostprocessCallback(ProofNodeManager* pnm, + SmtEngine* smte, + ProofGenerator* pppg) + : d_pnm(pnm), d_smte(smte), d_pppg(pppg), d_wfpm(pnm) +{ + d_true = NodeManager::currentNM()->mkConst(true); + // always check whether to update ASSUME + d_elimRules.insert(PfRule::ASSUME); +} + +void ProofPostprocessCallback::initializeUpdate() +{ + d_assumpToProof.clear(); + d_wfAssumptions.clear(); +} + +void ProofPostprocessCallback::setEliminateRule(PfRule rule) +{ + d_elimRules.insert(rule); +} + +bool ProofPostprocessCallback::shouldUpdate(ProofNode* pn) +{ + return d_elimRules.find(pn->getRule()) != d_elimRules.end(); +} + +bool ProofPostprocessCallback::update(Node res, + PfRule id, + const std::vector<Node>& children, + const std::vector<Node>& args, + CDProof* cdp) +{ + Trace("smt-proof-pp-debug") << "- Post process " << id << " " << children + << " / " << args << std::endl; + + if (id == PfRule::ASSUME) + { + // we cache based on the assumption node, not the proof node, since there + // may be multiple occurrences of the same node. + Node f = args[0]; + std::shared_ptr<ProofNode> pfn; + std::map<Node, std::shared_ptr<ProofNode>>::iterator it = + d_assumpToProof.find(f); + if (it != d_assumpToProof.end()) + { + Trace("smt-proof-pp-debug") << "...already computed" << std::endl; + pfn = it->second; + } + else + { + Assert(d_pppg != nullptr); + // get proof from preprocess proof generator + pfn = d_pppg->getProofFor(f); + // print for debugging + if (pfn == nullptr) + { + Trace("smt-proof-pp-debug") + << "...no proof, possibly an input assumption" << std::endl; + } + else + { + Assert(pfn->getResult() == f); + if (Trace.isOn("smt-proof-pp")) + { + Trace("smt-proof-pp") + << "=== Connect proof for preprocessing: " << f << std::endl; + Trace("smt-proof-pp") << *pfn.get() << std::endl; + } + } + d_assumpToProof[f] = pfn; + } + if (pfn == nullptr) + { + // no update + return false; + } + // connect the proof + cdp->addProof(pfn); + return true; + } + Node ret = expandMacros(id, children, args, cdp); + Trace("smt-proof-pp-debug") << "...expanded = " << !ret.isNull() << std::endl; + return !ret.isNull(); +} + +Node ProofPostprocessCallback::expandMacros(PfRule id, + const std::vector<Node>& children, + const std::vector<Node>& args, + CDProof* cdp) +{ + if (d_elimRules.find(id) == d_elimRules.end()) + { + // not eliminated + return Node::null(); + } + // macro elimination + if (id == PfRule::MACRO_SR_EQ_INTRO) + { + // (TRANS + // (SUBS <children> :args args[0:1]) + // (REWRITE :args <t.substitute(x1,t1). ... .substitute(xn,tn)> args[2])) + std::vector<Node> tchildren; + Node t = args[0]; + Node ts; + if (!children.empty()) + { + std::vector<Node> sargs; + sargs.push_back(t); + MethodId sid = MethodId::SB_DEFAULT; + if (args.size() >= 2) + { + if (builtin::BuiltinProofRuleChecker::getMethodId(args[1], sid)) + { + sargs.push_back(args[1]); + } + } + ts = + builtin::BuiltinProofRuleChecker::applySubstitution(t, children, sid); + if (ts != t) + { + Node eq = t.eqNode(ts); + // apply SUBS proof rule if necessary + if (!update(eq, PfRule::SUBS, children, sargs, cdp)) + { + // if not elimianted, add as step + cdp->addStep(eq, PfRule::SUBS, children, sargs); + } + tchildren.push_back(eq); + } + } + else + { + // no substitute + ts = t; + } + std::vector<Node> rargs; + rargs.push_back(ts); + MethodId rid = MethodId::RW_REWRITE; + if (args.size() >= 3) + { + if (builtin::BuiltinProofRuleChecker::getMethodId(args[2], rid)) + { + rargs.push_back(args[2]); + } + } + builtin::BuiltinProofRuleChecker* builtinPfC = + static_cast<builtin::BuiltinProofRuleChecker*>( + d_pnm->getChecker()->getCheckerFor(PfRule::MACRO_SR_EQ_INTRO)); + Node tr = builtinPfC->applyRewrite(ts, rid); + if (ts != tr) + { + Node eq = ts.eqNode(tr); + // apply REWRITE proof rule + if (!update(eq, PfRule::REWRITE, {}, rargs, cdp)) + { + // if not elimianted, add as step + cdp->addStep(eq, PfRule::REWRITE, {}, rargs); + } + tchildren.push_back(eq); + } + if (t == tr) + { + // typically not necessary, but done to be robust + cdp->addStep(t.eqNode(tr), PfRule::REFL, {}, {t}); + return t.eqNode(tr); + } + // must add TRANS if two step + return addProofForTrans(tchildren, cdp); + } + else if (id == PfRule::MACRO_SR_PRED_INTRO) + { + std::vector<Node> tchildren; + std::vector<Node> sargs = args; + // take into account witness form, if necessary + if (d_wfpm.requiresWitnessFormIntro(args[0])) + { + Node weq = addProofForWitnessForm(args[0], cdp); + tchildren.push_back(weq); + // replace the first argument + sargs[0] = weq[1]; + } + // (TRUE_ELIM + // (TRANS + // ... proof of t = toWitness(t) ... + // (MACRO_SR_EQ_INTRO <children> :args (toWitness(t) args[1:])))) + // We call the expandMacros method on MACRO_SR_EQ_INTRO, where notice + // that this rule application is immediately expanded in the recursive + // call and not added to the proof. + Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, children, sargs, cdp); + tchildren.push_back(conc); + Assert(!conc.isNull() && conc.getKind() == EQUAL && conc[0] == sargs[0] + && conc[1] == d_true); + // transitivity if necessary + Node eq = addProofForTrans(tchildren, cdp); + + cdp->addStep(eq[0], PfRule::TRUE_ELIM, {eq}, {}); + Assert(eq[0] == args[0]); + return eq[0]; + } + else if (id == PfRule::MACRO_SR_PRED_ELIM) + { + // (TRUE_ELIM + // (TRANS + // (SYMM (MACRO_SR_EQ_INTRO children[1:] :args children[0] ++ args)) + // (TRUE_INTRO children[0]))) + std::vector<Node> schildren(children.begin() + 1, children.end()); + std::vector<Node> srargs; + srargs.push_back(children[0]); + srargs.insert(srargs.end(), args.begin(), args.end()); + Node conc = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, srargs, cdp); + Assert(!conc.isNull() && conc.getKind() == EQUAL && conc[0] == children[0]); + + Node eq1 = children[0].eqNode(d_true); + cdp->addStep(eq1, PfRule::TRUE_INTRO, {children[0]}, {}); + + Node concSym = conc[1].eqNode(conc[0]); + Node eq2 = conc[1].eqNode(d_true); + cdp->addStep(eq2, PfRule::TRANS, {concSym, eq1}, {}); + + cdp->addStep(conc[1], PfRule::TRUE_ELIM, {eq2}, {}); + return conc[1]; + } + else if (id == PfRule::MACRO_SR_PRED_TRANSFORM) + { + // (TRUE_ELIM + // (TRANS + // (MACRO_SR_EQ_INTRO children[1:] :args <args>) + // ... proof of a = wa + // (MACRO_SR_EQ_INTRO {} wa) + // (SYMM + // (MACRO_SR_EQ_INTRO children[1:] :args (children[0] args[1:])) + // ... proof of c = wc + // (MACRO_SR_EQ_INTRO {} wc)) + // (TRUE_INTRO children[0]))) + // where + // wa = toWitness(apply_SR(args[0])) and + // wc = toWitness(apply_SR(children[0])). + Trace("smt-proof-pp-debug") + << "Transform " << children[0] << " == " << args[0] << std::endl; + if (CDProof::isSame(children[0], args[0])) + { + // nothing to do + return children[0]; + } + std::vector<Node> tchildren; + std::vector<Node> schildren(children.begin() + 1, children.end()); + std::vector<Node> sargs = args; + // first, compute if we need + bool reqWitness = d_wfpm.requiresWitnessFormTransform(children[0], args[0]); + // convert both sides, in three steps, take symmetry of second chain + for (unsigned r = 0; r < 2; r++) + { + std::vector<Node> tchildrenr; + // first rewrite args[0], then children[0] + sargs[0] = r == 0 ? args[0] : children[0]; + // t = apply_SR(t) + Node eq = expandMacros(PfRule::MACRO_SR_EQ_INTRO, schildren, sargs, cdp); + Trace("smt-proof-pp-debug") + << "transform subs_rewrite (" << r << "): " << eq << std::endl; + Assert(!eq.isNull() && eq.getKind() == EQUAL && eq[0] == sargs[0]); + addToTransChildren(eq, tchildrenr); + // apply_SR(t) = toWitness(apply_SR(t)) + if (reqWitness) + { + Node weq = addProofForWitnessForm(eq[1], cdp); + Trace("smt-proof-pp-debug") + << "transform toWitness (" << r << "): " << weq << std::endl; + if (addToTransChildren(weq, tchildrenr)) + { + sargs[0] = weq[1]; + // toWitness(apply_SR(t)) = apply_SR(toWitness(apply_SR(t))) + // rewrite again, don't need substitution + Node weqr = expandMacros(PfRule::MACRO_SR_EQ_INTRO, {}, sargs, cdp); + Trace("smt-proof-pp-debug") << "transform rewrite_witness (" << r + << "): " << weqr << std::endl; + addToTransChildren(weqr, tchildrenr); + } + } + Trace("smt-proof-pp-debug") + << "transform connect (" << r << ")" << std::endl; + // add to overall chain + if (r == 0) + { + // add the current chain to the overall chain + tchildren.insert(tchildren.end(), tchildrenr.begin(), tchildrenr.end()); + } + else + { + // add the current chain to cdp + Node eqr = addProofForTrans(tchildrenr, cdp); + if (!eqr.isNull()) + { + // take symmetry of above and add it to the overall chain + addToTransChildren(eqr, tchildren, true); + } + } + Trace("smt-proof-pp-debug") + << "transform finish (" << r << ")" << std::endl; + } + + // children[0] = true + Node eq3 = children[0].eqNode(d_true); + Trace("smt-proof-pp-debug") << "transform true_intro: " << eq3 << std::endl; + cdp->addStep(eq3, PfRule::TRUE_INTRO, {children[0]}, {}); + addToTransChildren(eq3, tchildren); + + // apply transitivity if necessary + Node eq = addProofForTrans(tchildren, cdp); + + cdp->addStep(args[0], PfRule::TRUE_ELIM, {eq}, {}); + return args[0]; + } + else if (id == PfRule::SUBS) + { + // Notice that a naive way to reconstruct SUBS is to do a term conversion + // proof for each substitution. + // The proof of f(a) * { a -> g(b) } * { b -> c } = f(g(c)) is: + // TRANS( CONG{f}( a=g(b) ), CONG{f}( CONG{g}( b=c ) ) ) + // Notice that more optimal proofs are possible that do a single traversal + // over t. This is done by applying later substitutions to the range of + // previous substitutions, until a final simultaneous substitution is + // applied to t. For instance, in the above example, we first prove: + // CONG{g}( b = c ) + // by applying the second substitution { b -> c } to the range of the first, + // giving us a proof of g(b)=g(c). We then construct the updated proof + // by tranitivity: + // TRANS( a=g(b), CONG{g}( b=c ) ) + // We then apply the substitution { a -> g(c), b -> c } to f(a), to obtain: + // CONG{f}( TRANS( a=g(b), CONG{g}( b=c ) ) ) + // which notice is more compact than the proof above. + Node t = args[0]; + // get the kind of substitution + MethodId ids = MethodId::SB_DEFAULT; + if (args.size() >= 2) + { + builtin::BuiltinProofRuleChecker::getMethodId(args[1], ids); + } + std::vector<std::shared_ptr<CDProof>> pfs; + std::vector<Node> vvec; + std::vector<Node> svec; + std::vector<ProofGenerator*> pgs; + for (size_t i = 0, nchild = children.size(); i < nchild; i++) + { + // process in reverse order + size_t index = nchild - (i + 1); + // get the substitution + TNode var, subs; + builtin::BuiltinProofRuleChecker::getSubstitution( + children[index], var, subs, ids); + // apply the current substitution to the range + if (!vvec.empty()) + { + Node ss = + subs.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end()); + if (ss != subs) + { + // make the proof for the tranitivity step + std::shared_ptr<CDProof> pf = std::make_shared<CDProof>(d_pnm); + pfs.push_back(pf); + // prove the updated substitution + TConvProofGenerator tcg(d_pnm, nullptr, TConvPolicy::ONCE); + // add previous rewrite steps + for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++) + { + tcg.addRewriteStep(vvec[j], svec[j], pgs[j]); + } + // get the proof for the update to the current substitution + Node seqss = subs.eqNode(ss); + std::shared_ptr<ProofNode> pfn = tcg.getProofFor(seqss); + Assert(pfn != nullptr); + // add the proof + pf->addProof(pfn); + // get proof for children[i] from cdp + pfn = cdp->getProofFor(children[i]); + pf->addProof(pfn); + // ensure we have a proof of var = subs + Node veqs = var.eqNode(subs); + if (veqs != children[index]) + { + // should be true intro or false intro + Assert(subs.isConst()); + pf->addStep(veqs, + subs.getConst<bool>() ? PfRule::TRUE_INTRO + : PfRule::FALSE_INTRO, + {children[index]}, + {}); + } + pf->addStep(var.eqNode(ss), PfRule::TRANS, {veqs, seqss}, {}); + // add to the substitution + vvec.push_back(var); + svec.push_back(ss); + pgs.push_back(pf.get()); + continue; + } + } + // just use equality from CDProof + vvec.push_back(var); + svec.push_back(subs); + pgs.push_back(cdp); + } + Node ts = t.substitute(vvec.begin(), vvec.end(), svec.begin(), svec.end()); + Node eq = t.eqNode(ts); + if (ts != t) + { + // should be implied by the substitution now + TConvProofGenerator tcpg(d_pnm, nullptr, TConvPolicy::ONCE); + for (unsigned j = 0, nvars = vvec.size(); j < nvars; j++) + { + tcpg.addRewriteStep(vvec[j], svec[j], pgs[j]); + } + // add the proof constructed by the term conversion utility + std::shared_ptr<ProofNode> pfn = tcpg.getProofFor(eq); + // should give a proof, if not, then tcpg does not agree with the + // substitution. + Assert(pfn != nullptr); + if (pfn != nullptr) + { + cdp->addProof(pfn); + } + } + else + { + // should not be necessary typically + cdp->addStep(eq, PfRule::REFL, {}, {t}); + } + return eq; + } + else if (id == PfRule::REWRITE) + { + // get the kind of rewrite + MethodId idr = MethodId::RW_REWRITE; + if (args.size() >= 2) + { + builtin::BuiltinProofRuleChecker::getMethodId(args[1], idr); + } + builtin::BuiltinProofRuleChecker* builtinPfC = + static_cast<builtin::BuiltinProofRuleChecker*>( + d_pnm->getChecker()->getCheckerFor(PfRule::REWRITE)); + Node ret = builtinPfC->applyRewrite(args[0], idr); + Node eq = args[0].eqNode(ret); + if (idr == MethodId::RW_REWRITE || idr == MethodId::RW_REWRITE_EQ_EXT) + { + // rewrites from theory::Rewriter + // automatically expand THEORY_REWRITE as well here if set + bool elimTR = + (d_elimRules.find(PfRule::THEORY_REWRITE) != d_elimRules.end()); + bool isExtEq = (idr == MethodId::RW_REWRITE_EQ_EXT); + // use rewrite with proof interface + Rewriter* rr = d_smte->getRewriter(); + TrustNode trn = rr->rewriteWithProof(args[0], elimTR, isExtEq); + std::shared_ptr<ProofNode> pfn = + trn.getGenerator()->getProofFor(trn.getProven()); + cdp->addProof(pfn); + Assert(trn.getNode() == ret); + } + else if (idr == MethodId::RW_EVALUATE) + { + // change to evaluate, which is never eliminated + cdp->addStep(eq, PfRule::EVALUATE, {}, {args[0]}); + } + else + { + // don't know how to eliminate + return Node::null(); + } + if (args[0] == ret) + { + // should not be necessary typically + cdp->addStep(eq, PfRule::REFL, {}, {args[0]}); + } + return eq; + } + + // TRUST, PREPROCESS, THEORY_LEMMA, THEORY_PREPROCESS? + + return Node::null(); +} + +Node ProofPostprocessCallback::addProofForWitnessForm(Node t, CDProof* cdp) +{ + Node tw = SkolemManager::getWitnessForm(t); + Node eq = t.eqNode(tw); + if (t == tw) + { + // not necessary, add REFL step + cdp->addStep(eq, PfRule::REFL, {}, {t}); + return eq; + } + std::shared_ptr<ProofNode> pn = d_wfpm.getProofFor(eq); + if (pn != nullptr) + { + // add the proof + cdp->addProof(pn); + } + else + { + Assert(false) << "ProofPostprocessCallback::addProofForWitnessForm: failed " + "to add proof for witness form of " + << t; + } + return eq; +} + +Node ProofPostprocessCallback::addProofForTrans( + const std::vector<Node>& tchildren, CDProof* cdp) +{ + size_t tsize = tchildren.size(); + if (tsize > 1) + { + Node lhs = tchildren[0][0]; + Node rhs = tchildren[tsize - 1][1]; + Node eq = lhs.eqNode(rhs); + cdp->addStep(eq, PfRule::TRANS, tchildren, {}); + return eq; + } + else if (tsize == 1) + { + return tchildren[0]; + } + return Node::null(); +} + +bool ProofPostprocessCallback::addToTransChildren(Node eq, + std::vector<Node>& tchildren, + bool isSymm) +{ + Assert(!eq.isNull()); + Assert(eq.getKind() == kind::EQUAL); + if (eq[0] == eq[1]) + { + return false; + } + Node equ = isSymm ? eq[1].eqNode(eq[0]) : eq; + Assert(tchildren.empty() + || (tchildren[tchildren.size() - 1].getKind() == kind::EQUAL + && tchildren[tchildren.size() - 1][1] == equ[0])); + tchildren.push_back(equ); + return true; +} + +} // namespace smt +} // namespace CVC4 diff --git a/src/smt/proof_post_processor.h b/src/smt/proof_post_processor.h new file mode 100644 index 000000000..8dc540701 --- /dev/null +++ b/src/smt/proof_post_processor.h @@ -0,0 +1,131 @@ +/********************* */ +/*! \file proof_post_processor.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The module for processing proof nodes + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__SMT__PROOF_POST_PROCESSOR_H +#define CVC4__SMT__PROOF_POST_PROCESSOR_H + +#include <map> +#include <unordered_set> + +#include "expr/proof_node_updater.h" +#include "smt/witness_form.h" + +namespace CVC4 { + +class SmtEngine; + +namespace smt { + +/** + * A callback class used by SmtEngine for post-processing proof nodes by + * connecting proofs of preprocessing, and expanding macro PfRule applications. + */ +class ProofPostprocessCallback : public ProofNodeUpdaterCallback +{ + public: + ProofPostprocessCallback(ProofNodeManager* pnm, + SmtEngine* smte, + ProofGenerator* pppg); + ~ProofPostprocessCallback() {} + /** + * Initialize, called once for each new ProofNode to process. This initializes + * static information to be used by successive calls to update. + */ + void initializeUpdate(); + /** + * Set eliminate rule, which adds rule to the list of rules we will eliminate + * during update. This adds rule to d_elimRules. Supported rules for + * elimination include MACRO_*, SUBS and REWRITE. Otherwise, this method + * has no effect. + */ + void setEliminateRule(PfRule rule); + /** Should proof pn be updated? */ + bool shouldUpdate(ProofNode* pn) override; + /** Update the proof rule application. */ + bool update(Node res, + PfRule id, + const std::vector<Node>& children, + const std::vector<Node>& args, + CDProof* cdp) override; + + private: + /** Common constants */ + Node d_true; + /** The proof node manager */ + ProofNodeManager* d_pnm; + /** Pointer to the SmtEngine, which should have proofs enabled */ + SmtEngine* d_smte; + /** The preprocessing proof generator */ + ProofGenerator* d_pppg; + /** The witness form proof generator */ + WitnessFormGenerator d_wfpm; + /** The witness form assumptions used in the proof */ + std::vector<Node> d_wfAssumptions; + /** Kinds of proof rules we are eliminating */ + std::unordered_set<PfRule, PfRuleHashFunction> d_elimRules; + //---------------------------------reset at the begining of each update + /** Mapping assumptions to their proof from preprocessing */ + std::map<Node, std::shared_ptr<ProofNode> > d_assumpToProof; + //---------------------------------end reset at the begining of each update + /** + * Expand rules in the given application, add the expanded proof to cdp. + * The set of rules we expand is configured by calls to setEliminateRule + * above. This method calls update to perform possible post-processing in the + * rules it introduces as a result of the expansion. + * + * @param id The rule of the application + * @param children The children of the application + * @param args The arguments of the application + * @param cdp The proof to add to + * @return The conclusion of the rule, or null if this rule is not eliminated. + */ + Node expandMacros(PfRule id, + const std::vector<Node>& children, + const std::vector<Node>& args, + CDProof* cdp); + /** + * Add proof for witness form. This returns the equality t = toWitness(t) + * and ensures that the proof of this equality has been added to cdp. + * Notice the proof of this fact may have open assumptions of the form: + * k = toWitness(k) + * where k is a skolem. Furthermore, note that all open assumptions of this + * form are available via d_wfpm.getWitnessFormEqs() in the remainder of + * the lifetime of this class. + */ + Node addProofForWitnessForm(Node t, CDProof* cdp); + /** + * Apply transivity if necessary for the arguments. The nodes in + * tchildren have been ordered such that they are legal arguments to TRANS. + * + * Returns the conclusion of the transitivity step, which is null if + * tchildren is empty. Also note if tchildren contains a single element, + * then no TRANS step is necessary to add to cdp. + * + * @param tchildren The children of a TRANS step + * @param cdp The proof to add the TRANS step to + * @return The conclusion of the TRANS step. + */ + Node addProofForTrans(const std::vector<Node>& tchildren, CDProof* cdp); + /** Add eq (or its symmetry) to transivity children, if not reflexive */ + bool addToTransChildren(Node eq, + std::vector<Node>& tchildren, + bool isSymm = false); +}; + +} // namespace smt +} // namespace CVC4 + +#endif diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index e709406d8..83f3cb5e0 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -96,6 +96,7 @@ #include "smt/smt_engine_state.h" #include "smt/smt_engine_stats.h" #include "smt/smt_solver.h" +#include "smt/sygus_solver.h" #include "smt/term_formula_removal.h" #include "smt/update_ostream.h" #include "smt_util/boolean_simplification.h" @@ -140,56 +141,6 @@ extern const char* const plf_signatures; namespace smt { -/** - * This is an inelegant solution, but for the present, it will work. - * The point of this is to separate the public and private portions of - * the SmtEngine class, so that smt_engine.h doesn't - * include "expr/node.h", which is a private CVC4 header (and can lead - * to linking errors due to the improper inlining of non-visible symbols - * into user code!). - * - * The "real" solution (that which is usually implemented) is to move - * ALL the implementation to SmtEnginePrivate and maintain a - * heap-allocated instance of it in SmtEngine. SmtEngine (the public - * one) becomes an "interface shell" which simply acts as a forwarder - * of method calls. - */ -class SmtEnginePrivate -{ - public: - - /* Finishes the initialization of the private portion of SMTEngine. */ - void finishInit(); - - /*------------------- sygus utils ------------------*/ - /** - * sygus variables declared (from "declare-var" and "declare-fun" commands) - * - * The SyGuS semantics for declared variables is that they are implicitly - * universally quantified in the constraints. - */ - std::vector<Node> d_sygusVars; - /** sygus constraints */ - std::vector<Node> d_sygusConstraints; - /** functions-to-synthesize */ - std::vector<Node> d_sygusFunSymbols; - /** - * Whether we need to reconstruct the sygus conjecture. - */ - CDO<bool> d_sygusConjectureStale; - /*------------------- end of sygus utils ------------------*/ - - public: - SmtEnginePrivate(SmtEngine& smt) - : d_sygusConjectureStale(smt.getUserContext(), true) - { - } - - ~SmtEnginePrivate() - { - } -};/* class SmtEnginePrivate */ - }/* namespace CVC4::smt */ SmtEngine::SmtEngine(ExprManager* em, Options* optr) @@ -206,13 +157,13 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr) d_proofManager(nullptr), d_rewriter(new theory::Rewriter()), d_definedFunctions(nullptr), + d_sygusSolver(nullptr), d_abductSolver(nullptr), d_assignments(nullptr), d_defineCommands(), d_logic(), d_originalOptions(), d_isInternalSubsolver(false), - d_private(nullptr), d_statisticsRegistry(nullptr), d_stats(nullptr), d_resourceManager(nullptr), @@ -245,7 +196,6 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr) d_optm.reset(new smt::OptionsManager(&d_options, d_resourceManager.get())); d_pp.reset( new smt::Preprocessor(*this, getUserContext(), *d_absValues.get())); - d_private.reset(new smt::SmtEnginePrivate(*this)); // listen to node manager events d_nodeManager->subscribeEvents(d_snmListener.get()); // listen to resource out @@ -255,6 +205,8 @@ SmtEngine::SmtEngine(ExprManager* em, Options* optr) // make the SMT solver d_smtSolver.reset( new SmtSolver(*this, *d_state, d_resourceManager.get(), *d_pp, *d_stats)); + // make the SyGuS solver + d_sygusSolver.reset(new SygusSolver(*d_smtSolver, *d_pp, getUserContext())); // The ProofManager is constructed before any other proof objects such as // SatProof and TheoryProofs. The TheoryProofEngine and the SatProof are @@ -421,10 +373,11 @@ SmtEngine::~SmtEngine() d_exprNames.reset(nullptr); d_dumpm.reset(nullptr); + d_sygusSolver.reset(nullptr); + d_smtSolver.reset(nullptr); d_stats.reset(nullptr); - d_private.reset(nullptr); d_nodeManager->unsubscribeEvents(d_snmListener.get()); d_snmListener.reset(nullptr); d_routListener.reset(nullptr); @@ -1145,57 +1098,26 @@ Result SmtEngine::assertFormula(const Node& formula, bool inUnsatCore) -------------------------------------------------------------------------- */ -void SmtEngine::declareSygusVar(const std::string& id, Expr var, Type type) +void SmtEngine::declareSygusVar(const std::string& id, Node var, TypeNode type) { SmtScope smts(this); finishInit(); - d_private->d_sygusVars.push_back(Node::fromExpr(var)); - Trace("smt") << "SmtEngine::declareSygusVar: " << var << "\n"; - Dump("raw-benchmark") << DeclareSygusVarCommand(id, var, type); - // don't need to set that the conjecture is stale -} - -void SmtEngine::declareSygusFunctionVar(const std::string& id, - Expr var, - Type type) -{ - SmtScope smts(this); - finishInit(); - d_private->d_sygusVars.push_back(Node::fromExpr(var)); - Trace("smt") << "SmtEngine::declareSygusFunctionVar: " << var << "\n"; - Dump("raw-benchmark") << DeclareSygusVarCommand(id, var, type); - + d_sygusSolver->declareSygusVar(id, var, type); + Dump("raw-benchmark") << DeclareSygusVarCommand( + id, var.toExpr(), type.toType()); // don't need to set that the conjecture is stale } void SmtEngine::declareSynthFun(const std::string& id, - Expr func, - Type sygusType, + Node func, + TypeNode sygusType, bool isInv, - const std::vector<Expr>& vars) + const std::vector<Node>& vars) { SmtScope smts(this); finishInit(); d_state->doPendingPops(); - Node fn = Node::fromExpr(func); - d_private->d_sygusFunSymbols.push_back(fn); - if (!vars.empty()) - { - Expr bvl = d_exprManager->mkExpr(kind::BOUND_VAR_LIST, vars); - std::vector<Expr> attr_val_bvl; - attr_val_bvl.push_back(bvl); - setUserAttribute("sygus-synth-fun-var-list", func, attr_val_bvl, ""); - } - // whether sygus type encodes syntax restrictions - TypeNode stn = TypeNode::fromType(sygusType); - if (sygusType.isDatatype() && stn.getDType().isSygus()) - { - Node sym = d_nodeManager->mkBoundVar("sfproxy", stn); - std::vector<Expr> attr_value; - attr_value.push_back(sym.toExpr()); - setUserAttribute("sygus-synth-grammar", func, attr_value, ""); - } - Trace("smt") << "SmtEngine::declareSynthFun: " << func << "\n"; + d_sygusSolver->declareSynthFun(id, func, sygusType, isInv, vars); // !!! TEMPORARY: We cannot construct a SynthFunCommand since we cannot // construct a Term-level Grammar from a Node-level sygus TypeNode. Thus we @@ -1203,187 +1125,50 @@ void SmtEngine::declareSynthFun(const std::string& id, if (Dump.isOn("raw-benchmark")) { - std::vector<Node> nodeVars; - nodeVars.reserve(vars.size()); - for (const Expr& var : vars) - { - nodeVars.push_back(Node::fromExpr(var)); - } - std::stringstream ss; Printer::getPrinter(options::outputLanguage()) - ->toStreamCmdSynthFun( - ss, - id, - nodeVars, - func.getType().isFunction() - ? TypeNode::fromType(func.getType()).getRangeType() - : TypeNode::fromType(func.getType()), - isInv, - TypeNode::fromType(sygusType)); - + ->toStreamCmdSynthFun(ss, + id, + vars, + func.getType().isFunction() + ? func.getType().getRangeType() + : func.getType(), + isInv, + sygusType); + // must print it on the standard output channel since it is not possible // to print anything except for commands with Dump. std::ostream& out = *d_options.getOut(); out << ss.str() << std::endl; } - - // sygus conjecture is now stale - setSygusConjectureStale(); } -void SmtEngine::assertSygusConstraint(const Node& constraint) +void SmtEngine::assertSygusConstraint(Node constraint) { SmtScope smts(this); finishInit(); - d_private->d_sygusConstraints.push_back(constraint); - - Trace("smt") << "SmtEngine::assertSygusConstrant: " << constraint << "\n"; + d_sygusSolver->assertSygusConstraint(constraint); Dump("raw-benchmark") << SygusConstraintCommand(constraint.toExpr()); - // sygus conjecture is now stale - setSygusConjectureStale(); } -void SmtEngine::assertSygusInvConstraint(const Expr& inv, - const Expr& pre, - const Expr& trans, - const Expr& post) +void SmtEngine::assertSygusInvConstraint(Node inv, + Node pre, + Node trans, + Node post) { SmtScope smts(this); finishInit(); - // build invariant constraint - - // get variables (regular and their respective primed versions) - std::vector<Node> terms, vars, primed_vars; - terms.push_back(Node::fromExpr(inv)); - terms.push_back(Node::fromExpr(pre)); - terms.push_back(Node::fromExpr(trans)); - terms.push_back(Node::fromExpr(post)); - // variables are built based on the invariant type - FunctionType t = static_cast<FunctionType>(inv.getType()); - std::vector<Type> argTypes = t.getArgTypes(); - for (const Type& ti : argTypes) - { - TypeNode tn = TypeNode::fromType(ti); - vars.push_back(d_nodeManager->mkBoundVar(tn)); - d_private->d_sygusVars.push_back(vars.back()); - std::stringstream ss; - ss << vars.back() << "'"; - primed_vars.push_back(d_nodeManager->mkBoundVar(ss.str(), tn)); - d_private->d_sygusVars.push_back(primed_vars.back()); - } - - // make relevant terms; 0 -> Inv, 1 -> Pre, 2 -> Trans, 3 -> Post - for (unsigned i = 0; i < 4; ++i) - { - Node op = terms[i]; - Trace("smt-debug") << "Make inv-constraint term #" << i << " : " << op - << " with type " << op.getType() << "...\n"; - std::vector<Node> children; - children.push_back(op); - // transition relation applied over both variable lists - if (i == 2) - { - children.insert(children.end(), vars.begin(), vars.end()); - children.insert(children.end(), primed_vars.begin(), primed_vars.end()); - } - else - { - children.insert(children.end(), vars.begin(), vars.end()); - } - terms[i] = d_nodeManager->mkNode(kind::APPLY_UF, children); - // make application of Inv on primed variables - if (i == 0) - { - children.clear(); - children.push_back(op); - children.insert(children.end(), primed_vars.begin(), primed_vars.end()); - terms.push_back(d_nodeManager->mkNode(kind::APPLY_UF, children)); - } - } - // make constraints - std::vector<Node> conj; - conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, terms[1], terms[0])); - Node term0_and_2 = d_nodeManager->mkNode(kind::AND, terms[0], terms[2]); - conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, term0_and_2, terms[4])); - conj.push_back(d_nodeManager->mkNode(kind::IMPLIES, terms[0], terms[3])); - Node constraint = d_nodeManager->mkNode(kind::AND, conj); - - d_private->d_sygusConstraints.push_back(constraint); - - Trace("smt") << "SmtEngine::assertSygusInvConstrant: " << constraint << "\n"; - Dump("raw-benchmark") << SygusInvConstraintCommand(inv, pre, trans, post); - // sygus conjecture is now stale - setSygusConjectureStale(); + d_sygusSolver->assertSygusInvConstraint(inv, pre, trans, post); + Dump("raw-benchmark") << SygusInvConstraintCommand( + inv.toExpr(), pre.toExpr(), trans.toExpr(), post.toExpr()); } Result SmtEngine::checkSynth() { SmtScope smts(this); - - if (options::incrementalSolving()) - { - // TODO (project #7) - throw ModalException( - "Cannot make check-synth commands when incremental solving is enabled"); - } - std::vector<Node> query; - if (d_private->d_sygusConjectureStale) - { - // build synthesis conjecture from asserted constraints and declared - // variables/functions - Node sygusVar = - d_nodeManager->mkSkolem("sygus", d_nodeManager->booleanType()); - Node inst_attr = d_nodeManager->mkNode(kind::INST_ATTRIBUTE, sygusVar); - Node sygusAttr = d_nodeManager->mkNode(kind::INST_PATTERN_LIST, inst_attr); - std::vector<Node> bodyv; - Trace("smt") << "Sygus : Constructing sygus constraint...\n"; - unsigned n_constraints = d_private->d_sygusConstraints.size(); - Node body = n_constraints == 0 - ? d_nodeManager->mkConst(true) - : (n_constraints == 1 - ? d_private->d_sygusConstraints[0] - : d_nodeManager->mkNode( - kind::AND, d_private->d_sygusConstraints)); - body = body.notNode(); - Trace("smt") << "...constructed sygus constraint " << body << std::endl; - if (!d_private->d_sygusVars.empty()) - { - Node boundVars = - d_nodeManager->mkNode(kind::BOUND_VAR_LIST, d_private->d_sygusVars); - body = d_nodeManager->mkNode(kind::EXISTS, boundVars, body); - Trace("smt") << "...constructed exists " << body << std::endl; - } - if (!d_private->d_sygusFunSymbols.empty()) - { - Node boundVars = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, - d_private->d_sygusFunSymbols); - body = d_nodeManager->mkNode(kind::FORALL, boundVars, body, sygusAttr); - } - Trace("smt") << "...constructed forall " << body << std::endl; - - // set attribute for synthesis conjecture - setUserAttribute("sygus", sygusVar.toExpr(), {}, ""); - - Trace("smt") << "Check synthesis conjecture: " << body << std::endl; - Dump("raw-benchmark") << CheckSynthCommand(); - - d_private->d_sygusConjectureStale = false; - - // TODO (project #7): if incremental, we should push a context and assert - query.push_back(body); - } - - Result r = checkSatInternal(query, true, false); - - // Check that synthesis solutions satisfy the conjecture - if (options::checkSynthSol() - && r.asSatisfiabilityResult().isSat() == Result::UNSAT) - { - checkSynthSolution(); - } - return r; + finishInit(); + return d_sygusSolver->checkSynth(*d_asserts); } /* @@ -2032,160 +1817,12 @@ void SmtEngine::checkModel(bool hardFailure) { Notice() << "SmtEngine::checkModel(): all assertions checked out OK !" << endl; } -void SmtEngine::checkSynthSolution() -{ - NodeManager* nm = NodeManager::currentNM(); - Notice() << "SmtEngine::checkSynthSolution(): checking synthesis solution" << endl; - std::map<Node, std::map<Node, Node>> sol_map; - TheoryEngine* te = getTheoryEngine(); - Assert(te != nullptr); - /* Get solutions and build auxiliary vectors for substituting */ - if (!te->getSynthSolutions(sol_map)) - { - InternalError() << "SmtEngine::checkSynthSolution(): No solution to check!"; - return; - } - if (sol_map.empty()) - { - InternalError() << "SmtEngine::checkSynthSolution(): Got empty solution!"; - return; - } - Trace("check-synth-sol") << "Got solution map:\n"; - // the set of synthesis conjectures in our assertions - std::unordered_set<Node, NodeHashFunction> conjs; - // For each of the above conjectures, the functions-to-synthesis and their - // solutions. This is used as a substitution below. - std::map<Node, std::vector<Node>> fvarMap; - std::map<Node, std::vector<Node>> fsolMap; - for (const std::pair<const Node, std::map<Node, Node>>& cmap : sol_map) - { - Trace("check-synth-sol") << "For conjecture " << cmap.first << ":\n"; - conjs.insert(cmap.first); - std::vector<Node>& fvars = fvarMap[cmap.first]; - std::vector<Node>& fsols = fsolMap[cmap.first]; - for (const std::pair<const Node, Node>& pair : cmap.second) - { - Trace("check-synth-sol") - << " " << pair.first << " --> " << pair.second << "\n"; - fvars.push_back(pair.first); - fsols.push_back(pair.second); - } - } - Trace("check-synth-sol") << "Starting new SMT Engine\n"; - /* Start new SMT engine to check solutions */ - SmtEngine solChecker(d_exprManager, &d_options); - solChecker.setIsInternalSubsolver(); - solChecker.setLogic(getLogicInfo()); - solChecker.getOptions().set(options::checkSynthSol, false); - solChecker.getOptions().set(options::sygusRecFun, false); - - Trace("check-synth-sol") << "Retrieving assertions\n"; - // Build conjecture from original assertions - context::CDList<Node>* al = d_asserts->getAssertionList(); - if (al == nullptr) - { - Trace("check-synth-sol") << "No assertions to check\n"; - return; - } - // auxiliary assertions - std::vector<Node> auxAssertions; - // expand definitions cache - std::unordered_map<Node, Node, NodeHashFunction> cache; - for (const Node& assertion : *al) - { - Notice() << "SmtEngine::checkSynthSolution(): checking assertion " - << assertion << endl; - Trace("check-synth-sol") << "Retrieving assertion " << assertion << "\n"; - // Apply any define-funs from the problem. - Node n = d_pp->expandDefinitions(assertion, cache); - Notice() << "SmtEngine::checkSynthSolution(): -- expands to " << n << endl; - Trace("check-synth-sol") << "Expanded assertion " << n << "\n"; - if (conjs.find(n) == conjs.end()) - { - Trace("check-synth-sol") << "It is an auxiliary assertion\n"; - auxAssertions.push_back(n); - } - else - { - Trace("check-synth-sol") << "It is a synthesis conjecture\n"; - } - } - // check all conjectures - for (const Node& conj : conjs) - { - // get the solution for this conjecture - std::vector<Node>& fvars = fvarMap[conj]; - std::vector<Node>& fsols = fsolMap[conj]; - // Apply solution map to conjecture body - Node conjBody; - /* Whether property is quantifier free */ - if (conj[1].getKind() != kind::EXISTS) - { - conjBody = conj[1].substitute( - fvars.begin(), fvars.end(), fsols.begin(), fsols.end()); - } - else - { - conjBody = conj[1][1].substitute( - fvars.begin(), fvars.end(), fsols.begin(), fsols.end()); - - /* Skolemize property */ - std::vector<Node> vars, skos; - for (unsigned j = 0, size = conj[1][0].getNumChildren(); j < size; ++j) - { - vars.push_back(conj[1][0][j]); - std::stringstream ss; - ss << "sk_" << j; - skos.push_back(nm->mkSkolem(ss.str(), conj[1][0][j].getType())); - Trace("check-synth-sol") << "\tSkolemizing " << conj[1][0][j] << " to " - << skos.back() << "\n"; - } - conjBody = conjBody.substitute( - vars.begin(), vars.end(), skos.begin(), skos.end()); - } - Notice() << "SmtEngine::checkSynthSolution(): -- body substitutes to " - << conjBody << endl; - Trace("check-synth-sol") << "Substituted body of assertion to " << conjBody - << "\n"; - solChecker.assertFormula(conjBody); - // Assert all auxiliary assertions. This may include recursive function - // definitions that were added as assertions to the sygus problem. - for (const Node& a : auxAssertions) - { - solChecker.assertFormula(a); - } - Result r = solChecker.checkSat(); - Notice() << "SmtEngine::checkSynthSolution(): result is " << r << endl; - Trace("check-synth-sol") << "Satsifiability check: " << r << "\n"; - if (r.asSatisfiabilityResult().isUnknown()) - { - InternalError() << "SmtEngine::checkSynthSolution(): could not check " - "solution, result " - "unknown."; - } - else if (r.asSatisfiabilityResult().isSat()) - { - InternalError() - << "SmtEngine::checkSynthSolution(): produced solution leads to " - "satisfiable negated conjecture."; - } - solChecker.resetAssertions(); - } -} - void SmtEngine::checkInterpol(Expr interpol, const std::vector<Expr>& easserts, const Node& conj) { } -void SmtEngine::checkAbduct(Node a) -{ - Assert(a.getType().isBoolean()); - // check it with the abduction solver - return d_abductSolver->checkAbduct(a); -} - // TODO(#1108): Simplify the error reporting of this method. UnsatCore SmtEngine::getUnsatCore() { Trace("smt") << "SMT getUnsatCore()" << endl; @@ -2248,26 +1885,11 @@ void SmtEngine::printSynthSolution( std::ostream& out ) { te->printSynthSolution(out); } -bool SmtEngine::getSynthSolutions(std::map<Expr, Expr>& sol_map) +bool SmtEngine::getSynthSolutions(std::map<Node, Node>& solMap) { SmtScope smts(this); finishInit(); - std::map<Node, std::map<Node, Node>> sol_mapn; - TheoryEngine* te = getTheoryEngine(); - Assert(te != nullptr); - // fail if the theory engine does not have synthesis solutions - if (!te->getSynthSolutions(sol_mapn)) - { - return false; - } - for (std::pair<const Node, std::map<Node, Node>>& cs : sol_mapn) - { - for (std::pair<const Node, Node>& s : cs.second) - { - sol_map[s.first.toExpr()] = s.second.toExpr(); - } - } - return true; + return d_sygusSolver->getSynthSolutions(solMap); } Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict) @@ -2701,15 +2323,4 @@ ResourceManager* SmtEngine::getResourceManager() DumpManager* SmtEngine::getDumpManager() { return d_dumpm.get(); } -void SmtEngine::setSygusConjectureStale() -{ - if (d_private->d_sygusConjectureStale) - { - // already stale - return; - } - d_private->d_sygusConjectureStale = true; - // TODO (project #7): if incremental, we should pop a context -} - }/* CVC4 namespace */ diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 1c71e371e..a6688578d 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -104,6 +104,7 @@ class OptionsManager; class Preprocessor; /** Subsolvers */ class SmtSolver; +class SygusSolver; class AbductionSolver; /** * Representation of a defined function. We keep these around in @@ -114,7 +115,6 @@ class AbductionSolver; class DefinedFunction; struct SmtEngineStatistics; -class SmtEnginePrivate; class SmtScope; class ProcessAssertions; @@ -146,7 +146,6 @@ class CVC4_PUBLIC SmtEngine friend class ::CVC4::api::Solver; // TODO (Issue #1096): Remove this friend relationship. friend class ::CVC4::preprocessing::PreprocessingPassContext; - friend class ::CVC4::smt::SmtEnginePrivate; friend class ::CVC4::smt::SmtEngineState; friend class ::CVC4::smt::SmtScope; friend class ::CVC4::smt::ProcessAssertions; @@ -417,22 +416,18 @@ class CVC4_PUBLIC SmtEngine /*---------------------------- sygus commands ---------------------------*/ /** - * Add variable declaration. + * Add sygus variable declaration. * * Declared SyGuS variables may be used in SyGuS constraints, in which they * are assumed to be universally quantified. - */ - void declareSygusVar(const std::string& id, Expr var, Type type); - - /** - * Add a function variable declaration. * - * Is SyGuS semantics declared functions are treated in the same manner as + * In SyGuS semantics, declared functions are treated in the same manner as * declared variables, i.e. as universally quantified (function) variables * which can occur in the SyGuS constraints that compose the conjecture to - * which a function is being synthesized. + * which a function is being synthesized. Thus declared functions should use + * this method as well. */ - void declareSygusFunctionVar(const std::string& id, Expr var, Type type); + void declareSygusVar(const std::string& id, Node var, TypeNode type); /** * Add a function-to-synthesize declaration. @@ -450,13 +445,13 @@ class CVC4_PUBLIC SmtEngine * corresponding to this declaration, so that it can be properly printed. */ void declareSynthFun(const std::string& id, - Expr func, - Type type, + Node func, + TypeNode type, bool isInv, - const std::vector<Expr>& vars); + const std::vector<Node>& vars); /** Add a regular sygus constraint.*/ - void assertSygusConstraint(const Node& constraint); + void assertSygusConstraint(Node constraint); /** * Add an invariant constraint. @@ -473,10 +468,7 @@ class CVC4_PUBLIC SmtEngine * The regular and primed variables are retrieved from the declaration of the * invariant-to-synthesize. */ - void assertSygusInvConstraint(const Expr& inv, - const Expr& pre, - const Expr& trans, - const Expr& post); + void assertSygusInvConstraint(Node inv, Node pre, Node trans, Node post); /** * Assert a synthesis conjecture to the current context and call * check(). Returns sat, unsat, or unknown result. @@ -576,18 +568,18 @@ class CVC4_PUBLIC SmtEngine * This method returns true if we are in a state immediately preceeded by * a successful call to checkSynth. * - * This method adds entries to sol_map that map functions-to-synthesize with + * This method adds entries to solMap that map functions-to-synthesize with * their solutions, for all active conjectures. This should be called * immediately after the solver answers unsat for sygus input. * * Specifically, given a sygus conjecture of the form * exists x1...xn. forall y1...yn. P( x1...xn, y1...yn ) * where x1...xn are second order bound variables, we map each xi to - * lambda term in sol_map such that - * forall y1...yn. P( sol_map[x1]...sol_map[xn], y1...yn ) + * lambda term in solMap such that + * forall y1...yn. P( solMap[x1]...solMap[xn], y1...yn ) * is a valid formula. */ - bool getSynthSolutions(std::map<Expr, Expr>& sol_map); + bool getSynthSolutions(std::map<Node, Node>& solMap); /** * Do quantifier elimination. @@ -952,16 +944,6 @@ class CVC4_PUBLIC SmtEngine void checkModel(bool hardFailure = true); /** - * Check that a solution to a synthesis conjecture is indeed a solution. - * - * The check is made by determining if the negation of the synthesis - * conjecture in which the functions-to-synthesize have been replaced by the - * synthesized solutions, which is a quantifier-free formula, is - * unsatisfiable. If not, then the found solutions are wrong. - */ - void checkSynthSolution(); - - /** * Check that a solution to an interpolation problem is indeed a solution. * * The check is made by determining that the assertions imply the solution of @@ -973,16 +955,6 @@ class CVC4_PUBLIC SmtEngine const Node& conj); /** - * Check that a solution to an abduction conjecture is indeed a solution. - * - * The check is made by determining that the assertions conjoined with the - * solution to the abduction problem (a) is SAT, and the assertions conjoined - * with the abduct and the goal is UNSAT. If these criteria are not met, an - * internal error is thrown. - */ - void checkAbduct(Node a); - - /** * This is called by the destructor, just before destroying the * PropEngine, TheoryEngine, and DecisionEngine (in that order). It * is important because there are destruction ordering issues @@ -1129,6 +1101,9 @@ class CVC4_PUBLIC SmtEngine /** An index of our defined functions */ DefinedFunctionMap* d_definedFunctions; + /** The solver for sygus queries */ + std::unique_ptr<smt::SygusSolver> d_sygusSolver; + /** The solver for abduction queries */ std::unique_ptr<smt::AbductionSolver> d_abductSolver; /** @@ -1164,11 +1139,6 @@ class CVC4_PUBLIC SmtEngine */ std::map<std::string, Integer> d_commandVerbosity; - /** - * A private utility class to SmtEngine. - */ - std::unique_ptr<smt::SmtEnginePrivate> d_private; - std::unique_ptr<StatisticsRegistry> d_statisticsRegistry; std::unique_ptr<smt::SmtEngineStatistics> d_stats; @@ -1195,23 +1165,6 @@ class CVC4_PUBLIC SmtEngine * or another SmtEngine is created. */ std::unique_ptr<smt::SmtScope> d_scope; - /*---------------------------- sygus commands ---------------------------*/ - - /** - * Set sygus conjecture is stale. The sygus conjecture is stale if either: - * (1) no sygus conjecture has been added as an assertion to this SMT engine, - * (2) there is a sygus conjecture that has been added as an assertion - * internally to this SMT engine, and there have been further calls such that - * the asserted conjecture is no longer up-to-date. - * - * This method should be called when new sygus constraints are asserted and - * when functions-to-synthesize are declared. This function pops a user - * context if we are in incremental mode and the sygus conjecture was - * previously not stale. - */ - void setSygusConjectureStale(); - - /*------------------------- end of sygus commands ------------------------*/ }; /* class SmtEngine */ /* -------------------------------------------------------------------------- */ diff --git a/src/smt/sygus_solver.cpp b/src/smt/sygus_solver.cpp new file mode 100644 index 000000000..0fc63d198 --- /dev/null +++ b/src/smt/sygus_solver.cpp @@ -0,0 +1,402 @@ +/********************* */ +/*! \file sygus_solver.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The solver for sygus queries + **/ + +#include "smt/sygus_solver.h" + +#include "expr/dtype.h" +#include "options/quantifiers_options.h" +#include "options/smt_options.h" +#include "smt/preprocessor.h" +#include "smt/smt_solver.h" +#include "theory/smt_engine_subsolver.h" +#include "theory/theory_engine.h" + +using namespace CVC4::theory; +using namespace CVC4::kind; + +namespace CVC4 { +namespace smt { + +SygusSolver::SygusSolver(SmtSolver& sms, + Preprocessor& pp, + context::UserContext* u) + : d_smtSolver(sms), d_pp(pp), d_sygusConjectureStale(u, true) +{ +} + +SygusSolver::~SygusSolver() {} + +void SygusSolver::declareSygusVar(const std::string& id, + Node var, + TypeNode type) +{ + Trace("smt") << "SygusSolver::declareSygusVar: " << id << " " << var << " " + << type << "\n"; + Assert(var.getType() == type); + d_sygusVars.push_back(var); + // don't need to set that the conjecture is stale +} + +void SygusSolver::declareSynthFun(const std::string& id, + Node fn, + TypeNode sygusType, + bool isInv, + const std::vector<Node>& vars) +{ + Trace("smt") << "SygusSolver::declareSynthFun: " << id << "\n"; + NodeManager* nm = NodeManager::currentNM(); + TheoryEngine* te = d_smtSolver.getTheoryEngine(); + Assert(te != nullptr); + d_sygusFunSymbols.push_back(fn); + if (!vars.empty()) + { + Node bvl = nm->mkNode(BOUND_VAR_LIST, vars); + std::vector<Node> attr_val_bvl; + attr_val_bvl.push_back(bvl); + te->setUserAttribute("sygus-synth-fun-var-list", fn, attr_val_bvl, ""); + } + // whether sygus type encodes syntax restrictions + if (sygusType.isDatatype() && sygusType.getDType().isSygus()) + { + Node sym = nm->mkBoundVar("sfproxy", sygusType); + std::vector<Node> attr_value; + attr_value.push_back(sym); + te->setUserAttribute("sygus-synth-grammar", fn, attr_value, ""); + } + + // sygus conjecture is now stale + setSygusConjectureStale(); +} + +void SygusSolver::assertSygusConstraint(Node constraint) +{ + Trace("smt") << "SygusSolver::assertSygusConstrant: " << constraint << "\n"; + d_sygusConstraints.push_back(constraint); + + // sygus conjecture is now stale + setSygusConjectureStale(); +} + +void SygusSolver::assertSygusInvConstraint(Node inv, + Node pre, + Node trans, + Node post) +{ + Trace("smt") << "SygusSolver::assertSygusInvConstrant: " << inv << " " << pre + << " " << trans << " " << post << "\n"; + // build invariant constraint + + // get variables (regular and their respective primed versions) + std::vector<Node> terms; + std::vector<Node> vars; + std::vector<Node> primed_vars; + terms.push_back(inv); + terms.push_back(pre); + terms.push_back(trans); + terms.push_back(post); + // variables are built based on the invariant type + NodeManager* nm = NodeManager::currentNM(); + std::vector<TypeNode> argTypes = inv.getType().getArgTypes(); + for (const TypeNode& tn : argTypes) + { + vars.push_back(nm->mkBoundVar(tn)); + d_sygusVars.push_back(vars.back()); + std::stringstream ss; + ss << vars.back() << "'"; + primed_vars.push_back(nm->mkBoundVar(ss.str(), tn)); + d_sygusVars.push_back(primed_vars.back()); + } + + // make relevant terms; 0 -> Inv, 1 -> Pre, 2 -> Trans, 3 -> Post + for (unsigned i = 0; i < 4; ++i) + { + Node op = terms[i]; + Trace("smt-debug") << "Make inv-constraint term #" << i << " : " << op + << " with type " << op.getType() << "...\n"; + std::vector<Node> children; + children.push_back(op); + // transition relation applied over both variable lists + if (i == 2) + { + children.insert(children.end(), vars.begin(), vars.end()); + children.insert(children.end(), primed_vars.begin(), primed_vars.end()); + } + else + { + children.insert(children.end(), vars.begin(), vars.end()); + } + terms[i] = nm->mkNode(APPLY_UF, children); + // make application of Inv on primed variables + if (i == 0) + { + children.clear(); + children.push_back(op); + children.insert(children.end(), primed_vars.begin(), primed_vars.end()); + terms.push_back(nm->mkNode(APPLY_UF, children)); + } + } + // make constraints + std::vector<Node> conj; + conj.push_back(nm->mkNode(IMPLIES, terms[1], terms[0])); + Node term0_and_2 = nm->mkNode(AND, terms[0], terms[2]); + conj.push_back(nm->mkNode(IMPLIES, term0_and_2, terms[4])); + conj.push_back(nm->mkNode(IMPLIES, terms[0], terms[3])); + Node constraint = nm->mkNode(AND, conj); + + d_sygusConstraints.push_back(constraint); + + // sygus conjecture is now stale + setSygusConjectureStale(); +} + +Result SygusSolver::checkSynth(Assertions& as) +{ + if (options::incrementalSolving()) + { + // TODO (project #7) + throw ModalException( + "Cannot make check-synth commands when incremental solving is enabled"); + } + Trace("smt") << "SygusSolver::checkSynth" << std::endl; + std::vector<Node> query; + if (d_sygusConjectureStale) + { + NodeManager* nm = NodeManager::currentNM(); + // build synthesis conjecture from asserted constraints and declared + // variables/functions + Node sygusVar = nm->mkSkolem("sygus", nm->booleanType()); + Node inst_attr = nm->mkNode(INST_ATTRIBUTE, sygusVar); + Node sygusAttr = nm->mkNode(INST_PATTERN_LIST, inst_attr); + std::vector<Node> bodyv; + Trace("smt") << "Sygus : Constructing sygus constraint...\n"; + size_t nconstraints = d_sygusConstraints.size(); + Node body = nconstraints == 0 + ? nm->mkConst(true) + : (nconstraints == 1 ? d_sygusConstraints[0] + : nm->mkNode(AND, d_sygusConstraints)); + body = body.notNode(); + Trace("smt") << "...constructed sygus constraint " << body << std::endl; + if (!d_sygusVars.empty()) + { + Node boundVars = nm->mkNode(BOUND_VAR_LIST, d_sygusVars); + body = nm->mkNode(EXISTS, boundVars, body); + Trace("smt") << "...constructed exists " << body << std::endl; + } + if (!d_sygusFunSymbols.empty()) + { + Node boundVars = nm->mkNode(BOUND_VAR_LIST, d_sygusFunSymbols); + body = nm->mkNode(FORALL, boundVars, body, sygusAttr); + } + Trace("smt") << "...constructed forall " << body << std::endl; + + // set attribute for synthesis conjecture + TheoryEngine* te = d_smtSolver.getTheoryEngine(); + te->setUserAttribute("sygus", sygusVar, {}, ""); + + Trace("smt") << "Check synthesis conjecture: " << body << std::endl; + Dump("raw-benchmark") << CheckSynthCommand(); + + d_sygusConjectureStale = false; + + // TODO (project #7): if incremental, we should push a context and assert + query.push_back(body); + } + + Result r = d_smtSolver.checkSatisfiability(as, query, false, false); + + // Check that synthesis solutions satisfy the conjecture + if (options::checkSynthSol() + && r.asSatisfiabilityResult().isSat() == Result::UNSAT) + { + checkSynthSolution(as); + } + return r; +} + +bool SygusSolver::getSynthSolutions(std::map<Node, Node>& sol_map) +{ + Trace("smt") << "SygusSolver::getSynthSolutions" << std::endl; + std::map<Node, std::map<Node, Node>> sol_mapn; + // fail if the theory engine does not have synthesis solutions + TheoryEngine* te = d_smtSolver.getTheoryEngine(); + Assert(te != nullptr); + if (!te->getSynthSolutions(sol_mapn)) + { + return false; + } + for (std::pair<const Node, std::map<Node, Node>>& cs : sol_mapn) + { + for (std::pair<const Node, Node>& s : cs.second) + { + sol_map[s.first] = s.second; + } + } + return true; +} + +void SygusSolver::checkSynthSolution(Assertions& as) +{ + NodeManager* nm = NodeManager::currentNM(); + Notice() << "SygusSolver::checkSynthSolution(): checking synthesis solution" + << std::endl; + std::map<Node, std::map<Node, Node>> sol_map; + // Get solutions and build auxiliary vectors for substituting + TheoryEngine* te = d_smtSolver.getTheoryEngine(); + if (!te->getSynthSolutions(sol_map)) + { + InternalError() + << "SygusSolver::checkSynthSolution(): No solution to check!"; + return; + } + if (sol_map.empty()) + { + InternalError() << "SygusSolver::checkSynthSolution(): Got empty solution!"; + return; + } + Trace("check-synth-sol") << "Got solution map:\n"; + // the set of synthesis conjectures in our assertions + std::unordered_set<Node, NodeHashFunction> conjs; + // For each of the above conjectures, the functions-to-synthesis and their + // solutions. This is used as a substitution below. + std::map<Node, std::vector<Node>> fvarMap; + std::map<Node, std::vector<Node>> fsolMap; + for (const std::pair<const Node, std::map<Node, Node>>& cmap : sol_map) + { + Trace("check-synth-sol") << "For conjecture " << cmap.first << ":\n"; + conjs.insert(cmap.first); + std::vector<Node>& fvars = fvarMap[cmap.first]; + std::vector<Node>& fsols = fsolMap[cmap.first]; + for (const std::pair<const Node, Node>& pair : cmap.second) + { + Trace("check-synth-sol") + << " " << pair.first << " --> " << pair.second << "\n"; + fvars.push_back(pair.first); + fsols.push_back(pair.second); + } + } + Trace("check-synth-sol") << "Starting new SMT Engine\n"; + + Trace("check-synth-sol") << "Retrieving assertions\n"; + // Build conjecture from original assertions + context::CDList<Node>* alist = as.getAssertionList(); + if (alist == nullptr) + { + Trace("check-synth-sol") << "No assertions to check\n"; + return; + } + // auxiliary assertions + std::vector<Node> auxAssertions; + // expand definitions cache + std::unordered_map<Node, Node, NodeHashFunction> cache; + for (Node assertion : *alist) + { + Notice() << "SygusSolver::checkSynthSolution(): checking assertion " + << assertion << std::endl; + Trace("check-synth-sol") << "Retrieving assertion " << assertion << "\n"; + // Apply any define-funs from the problem. + Node n = d_pp.expandDefinitions(assertion, cache); + Notice() << "SygusSolver::checkSynthSolution(): -- expands to " << n + << std::endl; + Trace("check-synth-sol") << "Expanded assertion " << n << "\n"; + if (conjs.find(n) == conjs.end()) + { + Trace("check-synth-sol") << "It is an auxiliary assertion\n"; + auxAssertions.push_back(n); + } + else + { + Trace("check-synth-sol") << "It is a synthesis conjecture\n"; + } + } + // check all conjectures + for (Node conj : conjs) + { + // Start new SMT engine to check solutions + std::unique_ptr<SmtEngine> solChecker; + initializeSubsolver(solChecker); + solChecker->getOptions().set(options::checkSynthSol, false); + solChecker->getOptions().set(options::sygusRecFun, false); + // get the solution for this conjecture + std::vector<Node>& fvars = fvarMap[conj]; + std::vector<Node>& fsols = fsolMap[conj]; + // Apply solution map to conjecture body + Node conjBody; + /* Whether property is quantifier free */ + if (conj[1].getKind() != EXISTS) + { + conjBody = conj[1].substitute( + fvars.begin(), fvars.end(), fsols.begin(), fsols.end()); + } + else + { + conjBody = conj[1][1].substitute( + fvars.begin(), fvars.end(), fsols.begin(), fsols.end()); + + /* Skolemize property */ + std::vector<Node> vars, skos; + for (unsigned j = 0, size = conj[1][0].getNumChildren(); j < size; ++j) + { + vars.push_back(conj[1][0][j]); + std::stringstream ss; + ss << "sk_" << j; + skos.push_back(nm->mkSkolem(ss.str(), conj[1][0][j].getType())); + Trace("check-synth-sol") << "\tSkolemizing " << conj[1][0][j] << " to " + << skos.back() << "\n"; + } + conjBody = conjBody.substitute( + vars.begin(), vars.end(), skos.begin(), skos.end()); + } + Notice() << "SygusSolver::checkSynthSolution(): -- body substitutes to " + << conjBody << std::endl; + Trace("check-synth-sol") + << "Substituted body of assertion to " << conjBody << "\n"; + solChecker->assertFormula(conjBody); + // Assert all auxiliary assertions. This may include recursive function + // definitions that were added as assertions to the sygus problem. + for (Node a : auxAssertions) + { + solChecker->assertFormula(a); + } + Result r = solChecker->checkSat(); + Notice() << "SygusSolver::checkSynthSolution(): result is " << r + << std::endl; + Trace("check-synth-sol") << "Satsifiability check: " << r << "\n"; + if (r.asSatisfiabilityResult().isUnknown()) + { + InternalError() << "SygusSolver::checkSynthSolution(): could not check " + "solution, result " + "unknown."; + } + else if (r.asSatisfiabilityResult().isSat()) + { + InternalError() + << "SygusSolver::checkSynthSolution(): produced solution leads to " + "satisfiable negated conjecture."; + } + } +} + +void SygusSolver::setSygusConjectureStale() +{ + if (d_sygusConjectureStale) + { + // already stale + return; + } + d_sygusConjectureStale = true; + // TODO (project #7): if incremental, we should pop a context +} + +} // namespace smt +} // namespace CVC4 diff --git a/src/smt/sygus_solver.h b/src/smt/sygus_solver.h new file mode 100644 index 000000000..468535da1 --- /dev/null +++ b/src/smt/sygus_solver.h @@ -0,0 +1,182 @@ +/********************* */ +/*! \file sygus_solver.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The solver for sygus queries + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__SMT__SYGUS_SOLVER_H +#define CVC4__SMT__SYGUS_SOLVER_H + +#include "context/cdo.h" +#include "expr/node.h" +#include "expr/type_node.h" +#include "smt/assertions.h" +#include "util/result.h" + +namespace CVC4 { +namespace smt { + +class Preprocessor; +class SmtSolver; + +/** + * A solver for sygus queries. + * + * This class is responsible for responding to check-synth commands. It calls + * check satisfiability using an underlying SmtSolver object. + * + * It also maintains a reference to a preprocessor for implementing + * checkSynthSolution. + */ +class SygusSolver +{ + public: + SygusSolver(SmtSolver& sms, Preprocessor& pp, context::UserContext* u); + ~SygusSolver(); + + /** + * Add variable declaration. + * + * Declared SyGuS variables may be used in SyGuS constraints, in which they + * are assumed to be universally quantified. + * + * In SyGuS semantics, declared functions are treated in the same manner as + * declared variables, i.e. as universally quantified (function) variables + * which can occur in the SyGuS constraints that compose the conjecture to + * which a function is being synthesized. Thus declared functions should use + * this method as well. + */ + void declareSygusVar(const std::string& id, Node var, TypeNode type); + + /** + * Add a function-to-synthesize declaration. + * + * The given type may not correspond to the actual function type but to a + * datatype encoding the syntax restrictions for the + * function-to-synthesize. In this case this information is stored to be used + * during solving. + * + * vars contains the arguments of the function-to-synthesize. These variables + * are also stored to be used during solving. + * + * isInv determines whether the function-to-synthesize is actually an + * invariant. This information is necessary if we are dumping a command + * corresponding to this declaration, so that it can be properly printed. + */ + void declareSynthFun(const std::string& id, + Node func, + TypeNode type, + bool isInv, + const std::vector<Node>& vars); + + /** Add a regular sygus constraint.*/ + void assertSygusConstraint(Node constraint); + + /** + * Add an invariant constraint. + * + * Invariant constraints are not explicitly declared: they are given in terms + * of the invariant-to-synthesize, the pre condition, transition relation and + * post condition. The actual constraint is built based on the inputs of these + * place holder predicates : + * + * PRE(x) -> INV(x) + * INV() ^ TRANS(x, x') -> INV(x') + * INV(x) -> POST(x) + * + * The regular and primed variables are retrieved from the declaration of the + * invariant-to-synthesize. + */ + void assertSygusInvConstraint(Node inv, Node pre, Node trans, Node post); + /** + * Assert a synthesis conjecture to the current context and call + * check(). Returns sat, unsat, or unknown result. + * + * The actual synthesis conjecture is built based on the previously + * communicated information to this module (universal variables, defined + * functions, functions-to-synthesize, and which constraints compose it). The + * built conjecture is a higher-order formula of the form + * + * exists f1...fn . forall v1...vm . F + * + * in which f1...fn are the functions-to-synthesize, v1...vm are the declared + * universal variables and F is the set of declared constraints. + */ + Result checkSynth(Assertions& as); + /** + * Get synth solution. + * + * This method returns true if we are in a state immediately preceeded by + * a successful call to checkSynth. + * + * This method adds entries to sol_map that map functions-to-synthesize with + * their solutions, for all active conjectures. This should be called + * immediately after the solver answers unsat for sygus input. + * + * Specifically, given a sygus conjecture of the form + * exists x1...xn. forall y1...yn. P( x1...xn, y1...yn ) + * where x1...xn are second order bound variables, we map each xi to + * lambda term in sol_map such that + * forall y1...yn. P( sol_map[x1]...sol_map[xn], y1...yn ) + * is a valid formula. + */ + bool getSynthSolutions(std::map<Node, Node>& sol_map); + + private: + /** + * Check that a solution to a synthesis conjecture is indeed a solution. + * + * The check is made by determining if the negation of the synthesis + * conjecture in which the functions-to-synthesize have been replaced by the + * synthesized solutions, which is a quantifier-free formula, is + * unsatisfiable. If not, then the found solutions are wrong. + */ + void checkSynthSolution(Assertions& as); + /** + * Set sygus conjecture is stale. The sygus conjecture is stale if either: + * (1) no sygus conjecture has been added as an assertion to this SMT engine, + * (2) there is a sygus conjecture that has been added as an assertion + * internally to this SMT engine, and there have been further calls such that + * the asserted conjecture is no longer up-to-date. + * + * This method should be called when new sygus constraints are asserted and + * when functions-to-synthesize are declared. This function pops a user + * context if we are in incremental mode and the sygus conjecture was + * previously not stale. + */ + void setSygusConjectureStale(); + /** The SMT solver, which is used during checkSynth. */ + SmtSolver& d_smtSolver; + /** The preprocessor, used for checkSynthSolution. */ + Preprocessor& d_pp; + /** + * sygus variables declared (from "declare-var" and "declare-fun" commands) + * + * The SyGuS semantics for declared variables is that they are implicitly + * universally quantified in the constraints. + */ + std::vector<Node> d_sygusVars; + /** sygus constraints */ + std::vector<Node> d_sygusConstraints; + /** functions-to-synthesize */ + std::vector<Node> d_sygusFunSymbols; + /** + * Whether we need to reconstruct the sygus conjecture. + */ + context::CDO<bool> d_sygusConjectureStale; +}; + +} // namespace smt +} // namespace CVC4 + +#endif /* CVC4__SMT__SYGUS_SOLVER_H */ diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 85759b75f..51e1b367c 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -1090,119 +1090,41 @@ void TheoryArrays::computeCareGraph() bool TheoryArrays::collectModelInfo(TheoryModel* m) { - set<Node> termSet; - - // Compute terms appearing in assertions and shared terms + // Compute terms appearing in assertions and shared terms, and also + // include additional reads due to the RIntro1 and RIntro2 rules. + std::set<Node> termSet; computeRelevantTerms(termSet); - // Compute arrays that we need to produce representatives for and also make sure RIntro1 reads are included in the relevant set of reads + // Send the equality engine information to the model + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) + { + return false; + } + NodeManager* nm = NodeManager::currentNM(); + // Compute arrays that we need to produce representatives for std::vector<Node> arrays; - bool computeRep, isArray; eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); for (; !eqcs_i.isFinished(); ++eqcs_i) { Node eqc = (*eqcs_i); - isArray = eqc.getType().isArray(); - if (!isArray) { + if (!eqc.getType().isArray()) + { + // not an array, skip continue; } - computeRep = false; eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); for (; !eqc_i.isFinished(); ++eqc_i) { Node n = *eqc_i; // If this EC is an array type and it contains something other than STORE nodes, we have to compute a representative explicitly - if (isArray && termSet.find(n) != termSet.end()) { - if (n.getKind() == kind::STORE) { - // Make sure RIntro1 reads are included - Node r = nm->mkNode(kind::SELECT, n, n[1]); - Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro1 read: " << r << endl; - termSet.insert(r); - } - else if (!computeRep) { + if (termSet.find(n) != termSet.end()) + { + if (n.getKind() != kind::STORE) + { arrays.push_back(n); - computeRep = true; - } - } - } - } - - // Now do a fixed-point iteration to get all reads that need to be included because of RIntro2 rule - bool changed; - do { - changed = false; - eqcs_i = eq::EqClassesIterator(d_equalityEngine); - for (; !eqcs_i.isFinished(); ++eqcs_i) { - Node eqc = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); - for (; !eqc_i.isFinished(); ++eqc_i) { - Node n = *eqc_i; - if (n.getKind() == kind::SELECT && termSet.find(n) != termSet.end()) { - - // Find all terms equivalent to n[0] and get corresponding read terms - Node array_eqc = d_equalityEngine->getRepresentative(n[0]); - eq::EqClassIterator array_eqc_i = - eq::EqClassIterator(array_eqc, d_equalityEngine); - for (; !array_eqc_i.isFinished(); ++array_eqc_i) { - Node arr = *array_eqc_i; - if (arr.getKind() == kind::STORE - && termSet.find(arr) != termSet.end() - && !d_equalityEngine->areEqual(arr[1], n[1])) - { - Node r = nm->mkNode(kind::SELECT, arr, n[1]); - if (termSet.find(r) == termSet.end() - && d_equalityEngine->hasTerm(r)) - { - Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(a) read: " << r << endl; - termSet.insert(r); - changed = true; - } - r = nm->mkNode(kind::SELECT, arr[0], n[1]); - if (termSet.find(r) == termSet.end() - && d_equalityEngine->hasTerm(r)) - { - Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(b) read: " << r << endl; - termSet.insert(r); - changed = true; - } - } - } - - // Find all stores in which n[0] appears and get corresponding read terms - const CTNodeList* instores = d_infoMap.getInStores(array_eqc); - size_t it = 0; - for(; it < instores->size(); ++it) { - TNode instore = (*instores)[it]; - Assert(instore.getKind() == kind::STORE); - if (termSet.find(instore) != termSet.end() - && !d_equalityEngine->areEqual(instore[1], n[1])) - { - Node r = nm->mkNode(kind::SELECT, instore, n[1]); - if (termSet.find(r) == termSet.end() - && d_equalityEngine->hasTerm(r)) - { - Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(c) read: " << r << endl; - termSet.insert(r); - changed = true; - } - r = nm->mkNode(kind::SELECT, instore[0], n[1]); - if (termSet.find(r) == termSet.end() - && d_equalityEngine->hasTerm(r)) - { - Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(d) read: " << r << endl; - termSet.insert(r); - changed = true; - } - } - } + break; } } } - } while (changed); - - // Send the equality engine information to the model - if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) - { - return false; } // Build a list of all the relevant reads, indexed by the store representative @@ -2339,6 +2261,136 @@ TrustNode TheoryArrays::expandDefinition(Node node) return TrustNode::null(); } +void TheoryArrays::computeRelevantTerms(std::set<Node>& termSet, + bool includeShared) +{ + // include all standard terms + std::set<Kind> irrKinds; + computeRelevantTermsInternal(termSet, irrKinds, includeShared); + + NodeManager* nm = NodeManager::currentNM(); + // make sure RIntro1 reads are included in the relevant set of reads + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); + for (; !eqcs_i.isFinished(); ++eqcs_i) + { + Node eqc = (*eqcs_i); + if (!eqc.getType().isArray()) + { + // not an array, skip + continue; + } + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); + for (; !eqc_i.isFinished(); ++eqc_i) + { + Node n = *eqc_i; + if (termSet.find(n) != termSet.end()) + { + if (n.getKind() == kind::STORE) + { + // Make sure RIntro1 reads are included + Node r = nm->mkNode(kind::SELECT, n, n[1]); + Trace("arrays::collectModelInfo") + << "TheoryArrays::collectModelInfo, adding RIntro1 read: " << r + << endl; + termSet.insert(r); + } + } + } + } + + // Now do a fixed-point iteration to get all reads that need to be included + // because of RIntro2 rule + bool changed; + do + { + changed = false; + eqcs_i = eq::EqClassesIterator(d_equalityEngine); + for (; !eqcs_i.isFinished(); ++eqcs_i) + { + Node eqc = (*eqcs_i); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); + for (; !eqc_i.isFinished(); ++eqc_i) + { + Node n = *eqc_i; + if (n.getKind() == kind::SELECT && termSet.find(n) != termSet.end()) + { + // Find all terms equivalent to n[0] and get corresponding read terms + Node array_eqc = d_equalityEngine->getRepresentative(n[0]); + eq::EqClassIterator array_eqc_i = + eq::EqClassIterator(array_eqc, d_equalityEngine); + for (; !array_eqc_i.isFinished(); ++array_eqc_i) + { + Node arr = *array_eqc_i; + if (arr.getKind() == kind::STORE + && termSet.find(arr) != termSet.end() + && !d_equalityEngine->areEqual(arr[1], n[1])) + { + Node r = nm->mkNode(kind::SELECT, arr, n[1]); + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { + Trace("arrays::collectModelInfo") + << "TheoryArrays::collectModelInfo, adding RIntro2(a) " + "read: " + << r << endl; + termSet.insert(r); + changed = true; + } + r = nm->mkNode(kind::SELECT, arr[0], n[1]); + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { + Trace("arrays::collectModelInfo") + << "TheoryArrays::collectModelInfo, adding RIntro2(b) " + "read: " + << r << endl; + termSet.insert(r); + changed = true; + } + } + } + + // Find all stores in which n[0] appears and get corresponding read + // terms + const CTNodeList* instores = d_infoMap.getInStores(array_eqc); + size_t it = 0; + for (; it < instores->size(); ++it) + { + TNode instore = (*instores)[it]; + Assert(instore.getKind() == kind::STORE); + if (termSet.find(instore) != termSet.end() + && !d_equalityEngine->areEqual(instore[1], n[1])) + { + Node r = nm->mkNode(kind::SELECT, instore, n[1]); + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { + Trace("arrays::collectModelInfo") + << "TheoryArrays::collectModelInfo, adding RIntro2(c) " + "read: " + << r << endl; + termSet.insert(r); + changed = true; + } + r = nm->mkNode(kind::SELECT, instore[0], n[1]); + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { + Trace("arrays::collectModelInfo") + << "TheoryArrays::collectModelInfo, adding RIntro2(d) " + "read: " + << r << endl; + termSet.insert(r); + changed = true; + } + } + } + } + } + } + } while (changed); +} + }/* CVC4::theory::arrays namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index f1cd2ea14..530f8e0e1 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -499,6 +499,12 @@ class TheoryArrays : public Theory { */ Node getNextDecisionRequest(); + /** + * Compute relevant terms. This includes additional select nodes for the + * RIntro1 and RIntro2 rules. + */ + void computeRelevantTerms(std::set<Node>& termSet, + bool includeShared = true) override; };/* class TheoryArrays */ }/* CVC4::theory::arrays namespace */ diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 4b38ad6bd..e625f57eb 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -1522,11 +1522,11 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) Trace("dt-model") << std::endl; printModelDebug( "dt-model" ); Trace("dt-model") << std::endl; - - set<Node> termSet; - + + std::set<Node> termSet; + // Compute terms appearing in assertions and shared terms, and in inferred equalities - getRelevantTerms(termSet); + computeRelevantTerms(termSet); //combine the equality engine if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) @@ -2250,12 +2250,14 @@ Node TheoryDatatypes::mkAnd( std::vector< TNode >& assumptions ) { } } -void TheoryDatatypes::getRelevantTerms( std::set<Node>& termSet ) { +void TheoryDatatypes::computeRelevantTerms(std::set<Node>& termSet, + bool includeShared) +{ // Compute terms appearing in assertions and shared terms - std::set<Kind> irr_kinds; + std::set<Kind> irrKinds; // testers are not relevant for model construction - irr_kinds.insert(APPLY_TESTER); - computeRelevantTerms(termSet, irr_kinds); + irrKinds.insert(APPLY_TESTER); + computeRelevantTermsInternal(termSet, irrKinds, includeShared); Trace("dt-cmi") << "Have " << termSet.size() << " relevant terms..." << std::endl; diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index a68caca94..bdc13b5e5 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -357,8 +357,6 @@ private: void instantiate( EqcInfo* eqc, Node n ); /** must communicate fact */ bool mustCommunicateFact( Node n, Node exp ); - /** get relevant terms */ - void getRelevantTerms( std::set<Node>& termSet ); private: //equality queries bool hasTerm( TNode a ); @@ -367,7 +365,13 @@ private: bool areCareDisequal( TNode x, TNode y ); TNode getRepresentative( TNode a ); - private: + /** + * Compute relevant terms. In addition to all terms in assertions and shared + * terms, this includes datatypes in non-singleton equivalence classes. + */ + void computeRelevantTerms(std::set<Node>& termSet, + bool includeShared = true) override; + /** sygus symmetry breaking utility */ std::unique_ptr<SygusExtension> d_sygusExtension; diff --git a/src/theory/quantifiers/sygus/sygus_interpol.cpp b/src/theory/quantifiers/sygus/sygus_interpol.cpp index 0d08140d3..0ecd888e0 100644 --- a/src/theory/quantifiers/sygus/sygus_interpol.cpp +++ b/src/theory/quantifiers/sygus/sygus_interpol.cpp @@ -274,10 +274,10 @@ void SygusInterpol::mkSygusConjecture(Node itp, bool SygusInterpol::findInterpol(Expr& interpol, Node itp) { // get the synthesis solution - std::map<Expr, Expr> sols; + std::map<Node, Node> sols; d_subSolver->getSynthSolutions(sols); Assert(sols.size() == 1); - std::map<Expr, Expr>::iterator its = sols.find(itp.toExpr()); + std::map<Node, Node>::iterator its = sols.find(itp); if (its == sols.end()) { Trace("sygus-interpol") @@ -288,7 +288,7 @@ bool SygusInterpol::findInterpol(Expr& interpol, Node itp) } Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is " << its->second << std::endl; - Node interpoln = Node::fromExpr(its->second); + Node interpoln = its->second; // replace back the created variables to original symbols. Node interpoln_reduced; if (interpoln.getKind() == kind::LAMBDA) @@ -336,18 +336,17 @@ bool SygusInterpol::SolveInterpolation(const std::string& name, createVariables(itpGType.isNull()); for (Node var : d_vars) { - d_subSolver->declareSygusVar(name, var.toExpr(), var.getType().toType()); + d_subSolver->declareSygusVar(name, var, var.getType()); } - std::vector<Expr> vars_empty; + std::vector<Node> vars_empty; TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj); Node itp = mkPredicate(name); - d_subSolver->declareSynthFun( - name, itp.toExpr(), grammarType.toType(), false, vars_empty); + d_subSolver->declareSynthFun(name, itp, grammarType, false, vars_empty); mkSygusConjecture(itp, axioms, conj); Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : " << d_sygusConj << ", solving for " - << d_sygusConj[0][0].toExpr() << std::endl; - d_subSolver->assertSygusConstraint(d_sygusConj.toExpr()); + << d_sygusConj[0][0] << std::endl; + d_subSolver->assertSygusConstraint(d_sygusConj); Trace("sygus-interpol") << " SmtEngine::getInterpol check sat..." << std::endl; diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index eafcc1e85..9fdf7e7aa 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -176,6 +176,7 @@ QuantifiersEngine::QuantifiersEngine(context::Context* c, context::UserContext* u, TheoryEngine* te) : d_te(te), + d_masterEqualityEngine(nullptr), d_eq_query(new quantifiers::EqualityQueryQuantifiersEngine(c, this)), d_tr_trie(new inst::TriggerTrie), d_model(nullptr), @@ -274,6 +275,11 @@ QuantifiersEngine::QuantifiersEngine(context::Context* c, QuantifiersEngine::~QuantifiersEngine() {} +void QuantifiersEngine::setMasterEqualityEngine(eq::EqualityEngine* mee) +{ + d_masterEqualityEngine = mee; +} + context::Context* QuantifiersEngine::getSatContext() { return d_te->theoryOf(THEORY_QUANTIFIERS)->getSatContext(); @@ -1258,7 +1264,7 @@ QuantifiersEngine::Statistics::~Statistics(){ eq::EqualityEngine* QuantifiersEngine::getMasterEqualityEngine() const { - return d_te->getMasterEqualityEngine(); + return d_masterEqualityEngine; } Node QuantifiersEngine::getInternalRepresentative( Node a, Node q, int index ){ diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index dd86c0db9..eca108587 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -49,6 +49,7 @@ class QuantifiersEnginePrivate; // TODO: organize this more/review this, github issue #1163 class QuantifiersEngine { + friend class ::CVC4::TheoryEngine; typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap; typedef context::CDList<Node> NodeList; typedef context::CDList<bool> BoolList; @@ -102,6 +103,10 @@ public: inst::TriggerTrie* getTriggerDatabase() const; //---------------------- end utilities private: + //---------------------- private initialization + /** Set the master equality engine */ + void setMasterEqualityEngine(eq::EqualityEngine* mee); + //---------------------- end private initialization /** * Maps quantified formulas to the module that owns them, if any module has * specifically taken ownership of it. @@ -316,6 +321,8 @@ public: private: /** reference to theory engine object */ TheoryEngine* d_te; + /** Pointer to the master equality engine */ + eq::EqualityEngine* d_masterEqualityEngine; /** vector of utilities for quantifiers */ std::vector<QuantifiersUtil*> d_util; /** vector of modules for quantifiers */ diff --git a/src/theory/relevance_manager.cpp b/src/theory/relevance_manager.cpp new file mode 100644 index 000000000..71962ee07 --- /dev/null +++ b/src/theory/relevance_manager.cpp @@ -0,0 +1,315 @@ +/********************* */ +/*! \file relevance_manager.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of relevance manager. + **/ + +#include "theory/relevance_manager.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { + +RelevanceManager::RelevanceManager(context::UserContext* userContext, + Valuation val) + : d_val(val), d_input(userContext), d_computed(false), d_success(false) +{ +} + +void RelevanceManager::notifyPreprocessedAssertions( + const std::vector<Node>& assertions) +{ + // add to input list, which is user-context dependent + std::vector<Node> toProcess; + for (const Node& a : assertions) + { + if (a.getKind() == AND) + { + // split top-level AND + for (const Node& ac : a) + { + toProcess.push_back(ac); + } + } + else + { + d_input.push_back(a); + } + } + addAssertionsInternal(toProcess); +} + +void RelevanceManager::notifyPreprocessedAssertion(Node n) +{ + std::vector<Node> toProcess; + toProcess.push_back(n); + addAssertionsInternal(toProcess); +} + +void RelevanceManager::addAssertionsInternal(std::vector<Node>& toProcess) +{ + size_t i = 0; + while (i < toProcess.size()) + { + Node a = toProcess[i]; + if (a.getKind() == AND) + { + // split AND + for (const Node& ac : a) + { + toProcess.push_back(ac); + } + } + else + { + // note that a could be a literal, in which case we could add it to + // an "always relevant" set here. + d_input.push_back(a); + } + i++; + } +} + +void RelevanceManager::resetRound() +{ + d_computed = false; + d_rset.clear(); +} + +void RelevanceManager::computeRelevance() +{ + d_computed = true; + Trace("rel-manager") << "RelevanceManager::computeRelevance..." << std::endl; + std::unordered_map<TNode, int, TNodeHashFunction> cache; + for (const Node& node: d_input) + { + TNode n = node; + int val = justify(n, cache); + if (val != 1) + { + std::stringstream serr; + serr << "RelevanceManager::computeRelevance: WARNING: failed to justify " + << n; + Trace("rel-manager") << serr.str() << std::endl; + Assert(false) << serr.str(); + d_success = false; + return; + } + } + Trace("rel-manager") << "...success, size = " << d_rset.size() << std::endl; + d_success = true; +} + +bool RelevanceManager::isBooleanConnective(TNode cur) +{ + Kind k = cur.getKind(); + return k == NOT || k == IMPLIES || k == AND || k == OR || k == ITE || k == XOR + || (k == EQUAL && cur[0].getType().isBoolean()); +} + +bool RelevanceManager::updateJustifyLastChild( + TNode cur, + std::vector<int>& childrenJustify, + std::unordered_map<TNode, int, TNodeHashFunction>& cache) +{ + // This method is run when we are informed that child index of cur + // has justify status lastChildJustify. We return true if we would like to + // compute the next child, in this case we push the status of the current + // child to childrenJustify. + size_t nchildren = cur.getNumChildren(); + Assert(isBooleanConnective(cur)); + size_t index = childrenJustify.size(); + Assert(index < nchildren); + Assert(cache.find(cur[index]) != cache.end()); + Kind k = cur.getKind(); + // Lookup the last child's value in the overall cache, we may choose to + // add this to childrenJustify if we return true. + int lastChildJustify = cache[cur[index]]; + if (k == NOT) + { + cache[cur] = -lastChildJustify; + } + else if (k == IMPLIES || k == AND || k == OR) + { + if (lastChildJustify != 0) + { + // See if we short circuited? The value for short circuiting is false if + // we are AND or the first child of IMPLIES. + if (lastChildJustify + == ((k == AND || (k == IMPLIES && index == 0)) ? -1 : 1)) + { + cache[cur] = k == AND ? -1 : 1; + return false; + } + } + if (index + 1 == nchildren) + { + // finished all children, compute the overall value + int ret = k == AND ? 1 : -1; + for (int cv : childrenJustify) + { + if (cv == 0) + { + ret = 0; + break; + } + } + cache[cur] = ret; + } + else + { + // continue + childrenJustify.push_back(lastChildJustify); + return true; + } + } + else if (lastChildJustify == 0) + { + // all other cases, an unknown child implies we are unknown + cache[cur] = 0; + } + else if (k == ITE) + { + if (index == 0) + { + Assert(lastChildJustify != 0); + // continue with branch + childrenJustify.push_back(lastChildJustify); + if (lastChildJustify == -1) + { + // also mark first branch as don't care + childrenJustify.push_back(0); + } + return true; + } + else + { + // should be in proper branch + Assert(childrenJustify[0] == (index == 1 ? 1 : -1)); + // we are the value of the branch + cache[cur] = lastChildJustify; + } + } + else + { + Assert(k == XOR || k == EQUAL); + Assert(nchildren == 2); + Assert(lastChildJustify != 0); + if (index == 0) + { + // must compute the other child + childrenJustify.push_back(lastChildJustify); + return true; + } + else + { + // both children known, compute value + Assert(childrenJustify.size() == 1 && childrenJustify[0] != 0); + cache[cur] = + ((k == XOR ? -1 : 1) * lastChildJustify == childrenJustify[0]) ? 1 + : -1; + } + } + return false; +} + +int RelevanceManager::justify( + TNode n, std::unordered_map<TNode, int, TNodeHashFunction>& cache) +{ + // the vector of values of children + std::unordered_map<TNode, std::vector<int>, TNodeHashFunction> childJustify; + std::unordered_map<TNode, int, TNodeHashFunction>::iterator it; + std::unordered_map<TNode, std::vector<int>, TNodeHashFunction>::iterator itc; + std::vector<TNode> visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + // should always have Boolean type + Assert(cur.getType().isBoolean()); + it = cache.find(cur); + if (it != cache.end()) + { + visit.pop_back(); + // already computed value + continue; + } + itc = childJustify.find(cur); + // have we traversed to children yet? + if (itc == childJustify.end()) + { + // are we not a Boolean connective (including NOT)? + if (isBooleanConnective(cur)) + { + // initialize its children justify vector as empty + childJustify[cur].clear(); + // start with the first child + visit.push_back(cur[0]); + } + else + { + visit.pop_back(); + // The atom case, lookup the value in the valuation class to + // see its current value in the SAT solver, if it has one. + int ret = 0; + // otherwise we look up the value + bool value; + if (d_val.hasSatValue(cur, value)) + { + ret = value ? 1 : -1; + d_rset.insert(cur); + } + cache[cur] = ret; + } + } + else + { + // this processes the impact of the current child on the value of cur, + // and possibly requests that a new child is computed. + if (updateJustifyLastChild(cur, itc->second, cache)) + { + Assert(itc->second.size() < cur.getNumChildren()); + TNode nextChild = cur[itc->second.size()]; + visit.push_back(nextChild); + } + else + { + visit.pop_back(); + } + } + } while (!visit.empty()); + Assert(cache.find(n) != cache.end()); + return cache[n]; +} + +bool RelevanceManager::isRelevant(Node lit) +{ + if (!d_computed) + { + computeRelevance(); + } + if (!d_success) + { + // always relevant if we failed to compute + return true; + } + // agnostic to negation + while (lit.getKind() == NOT) + { + lit = lit[0]; + } + return d_rset.find(lit) != d_rset.end(); +} + +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/relevance_manager.h b/src/theory/relevance_manager.h new file mode 100644 index 000000000..bbb094fc0 --- /dev/null +++ b/src/theory/relevance_manager.h @@ -0,0 +1,154 @@ +/********************* */ +/*! \file relevance_manager.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Relevance manager. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__RELEVANCE_MANAGER__H +#define CVC4__THEORY__RELEVANCE_MANAGER__H + +#include <unordered_map> +#include <unordered_set> + +#include "context/cdlist.h" +#include "expr/node.h" +#include "theory/valuation.h" + +namespace CVC4 { +namespace theory { + +/** + * This class manages queries related to relevance of asserted literals. + * In particular, note the following definition: + * + * Let F be a formula, and let L = { l_1, ..., l_n } be a set of + * literals that propositionally entail it. A "relevant selection of L with + * respect to F" is a subset of L that also propositionally entails F. + * + * This class computes a relevant selection of the current assertion stack + * at FULL effort with respect to the input formula + theory lemmas that are + * critical to justify (see LemmaProperty::NEEDS_JUSTIFY). By default, theory + * lemmas are not critical to justify; in fact, all T-valid theory lemmas + * are not critical to justify, since they are guaranteed to be satisfied in + * all inputs. However, some theory lemmas that introduce skolems need + * justification. + * + * As an example of such a lemma, take the example input formula: + * (and (exists ((x Int)) (P x)) (not (P 0))) + * A skolemization lemma like the following needs justification: + * (=> (exists ((x Int)) (P x)) (P k)) + * Intuitively, this is because the satisfiability of the existential above is + * being deferred to the satisfiability of (P k) where k is fresh. Thus, + * a relevant selection must include both (exists ((x Int)) (P x)) and (P k) + * in this example. + * + * Theories are responsible for marking such lemmas using the NEEDS_JUSTIFY + * property when calling OutputChannel::lemma. + * + * Notice that this class has some relation to the justification decision + * heuristic (--decision=justification), which constructs a relevant selection + * of the input formula by construction. This class is orthogonal to this + * method, since it computes relevant selection *after* a full assignment. Thus + * its main advantage with respect to decision=justification is that it can be + * used in combination with any SAT decision heuristic. + * + * Internally, this class stores the input assertions and can be asked if an + * asserted literal is part of the current relevant selection. The relevant + * selection is computed lazily, i.e. only when someone asks if a literal is + * relevant, and only at most once per FULL effort check. + */ +class RelevanceManager +{ + typedef context::CDList<Node> NodeList; + + public: + RelevanceManager(context::UserContext* userContext, Valuation val); + /** + * Notify (preprocessed) assertions. This is called for input formulas or + * lemmas that need justification that have been fully processed, just before + * adding them to the PropEngine. + */ + void notifyPreprocessedAssertions(const std::vector<Node>& assertions); + /** Singleton version of above */ + void notifyPreprocessedAssertion(Node n); + /** + * Reset round, called at the beginning of a full effort check in + * TheoryEngine. + */ + void resetRound(); + /** + * Is lit part of the current relevant selection? This call is valid during + * full effort check in TheoryEngine. This means that theories can query this + * during FULL or LAST_CALL efforts, through the Valuation class. + */ + bool isRelevant(Node lit); + + private: + /** + * Add the set of assertions to the formulas known to this class. This + * method handles optimizations such as breaking apart top-level applications + * of and. + */ + void addAssertionsInternal(std::vector<Node>& toProcess); + /** compute the relevant selection */ + void computeRelevance(); + /** + * Justify formula n. To "justify" means we have added literals to our + * relevant selection set (d_rset) whose current values ensure that n + * evaluates to true or false. + * + * This method returns 1 if we justified n to be true, -1 means + * justified n to be false, 0 means n could not be justified. + */ + int justify(TNode n, + std::unordered_map<TNode, int, TNodeHashFunction>& cache); + /** Is the top symbol of cur a Boolean connective? */ + bool isBooleanConnective(TNode cur); + /** + * Update justify last child. This method is a helper function for justify, + * which is called at the moment that Boolean connective formula cur + * has a new child that has been computed in the justify cache. + * + * @param cur The Boolean connective formula + * @param childrenJustify The values of the previous children (not including + * the current one) + * @param cache The justify cache + * @return True if we wish to visit the next child. If this is the case, then + * the justify value of the current child is added to childrenJustify. + */ + bool updateJustifyLastChild( + TNode cur, + std::vector<int>& childrenJustify, + std::unordered_map<TNode, int, TNodeHashFunction>& cache); + /** The valuation object, used to query current value of theory literals */ + Valuation d_val; + /** The input assertions */ + NodeList d_input; + /** The current relevant selection. */ + std::unordered_set<TNode, TNodeHashFunction> d_rset; + /** Have we computed the relevant selection this round? */ + bool d_computed; + /** + * Did we succeed in computing the relevant selection? If this is false, there + * was a syncronization issue between the input formula and the satisfying + * assignment since this class found that the input formula was not satisfied + * by the assignment. This should never happen, but if it does, this class + * aborts and indicates that all literals are relevant. + */ + bool d_success; +}; + +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__RELEVANCE_MANAGER__H */ diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index 4f0cbdb6a..9669d97e0 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -82,6 +82,7 @@ Theory::Theory(TheoryId id, d_valuation(valuation), d_equalityEngine(nullptr), d_allocEqualityEngine(nullptr), + d_theoryState(nullptr), d_proofsEnabled(false) { smtStatisticsRegistry()->registerStat(&d_checkTime); @@ -343,6 +344,23 @@ std::unordered_set<TNode, TNodeHashFunction> Theory::currentlySharedTerms() cons return currentlyShared; } +bool Theory::collectModelInfo(TheoryModel* m) +{ + std::set<Node> termSet; + // Compute terms appearing in assertions and shared terms + computeRelevantTerms(termSet); + // if we are using an equality engine, assert it to the model + if (d_equalityEngine != nullptr) + { + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) + { + return false; + } + } + // now, collect theory-specific value assigments + return collectModelValues(m, termSet); +} + void Theory::collectTerms(TNode n, set<Kind>& irrKinds, set<Node>& termSet) const @@ -365,16 +383,9 @@ void Theory::collectTerms(TNode n, } } - -void Theory::computeRelevantTerms(set<Node>& termSet, bool includeShared) const -{ - set<Kind> irrKinds; - computeRelevantTerms(termSet, irrKinds, includeShared); -} - -void Theory::computeRelevantTerms(set<Node>& termSet, - set<Kind>& irrKinds, - bool includeShared) const +void Theory::computeRelevantTermsInternal(std::set<Node>& termSet, + std::set<Kind>& irrKinds, + bool includeShared) const { // Collect all terms appearing in assertions irrKinds.insert(kind::EQUAL); @@ -394,6 +405,17 @@ void Theory::computeRelevantTerms(set<Node>& termSet, } } +void Theory::computeRelevantTerms(std::set<Node>& termSet, bool includeShared) +{ + std::set<Kind> irrKinds; + computeRelevantTermsInternal(termSet, irrKinds, includeShared); +} + +bool Theory::collectModelValues(TheoryModel* m, std::set<Node>& termSet) +{ + return true; +} + Theory::PPAssertStatus Theory::ppAssert(TNode in, SubstitutionMap& outSubstitutions) { diff --git a/src/theory/theory.h b/src/theory/theory.h index 4feeac394..349f36a57 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -44,6 +44,7 @@ #include "theory/output_channel.h" #include "theory/theory_id.h" #include "theory/theory_rewriter.h" +#include "theory/theory_state.h" #include "theory/trust_node.h" #include "theory/valuation.h" #include "util/statistics_registry.h" @@ -183,13 +184,7 @@ class Theory { */ context::CDList<TNode> d_sharedTerms; - /** - * Helper function for computeRelevantTerms - */ - void collectTerms(TNode n, - std::set<Kind>& irrKinds, - std::set<Node>& termSet) const; - + //---------------------------------- collect model info /** * Scans the current set of assertions and shared terms top-down * until a theory-leaf is reached, and adds all terms found to @@ -203,11 +198,30 @@ class Theory { * includeShared: Whether to include shared terms in termSet. Notice that * shared terms are not influenced by irrKinds. */ - void computeRelevantTerms(std::set<Node>& termSet, - std::set<Kind>& irrKinds, - bool includeShared = true) const; - /** same as above, but with empty irrKinds */ - void computeRelevantTerms(std::set<Node>& termSet, bool includeShared = true) const; + void computeRelevantTermsInternal(std::set<Node>& termSet, + std::set<Kind>& irrKinds, + bool includeShared = true) const; + /** + * Helper function for computeRelevantTerms + */ + void collectTerms(TNode n, + std::set<Kind>& irrKinds, + std::set<Node>& termSet) const; + /** + * Same as above, but with empty irrKinds. This version can be overridden + * by the theory, e.g. by restricting or extending the set of terms returned + * by computeRelevantTermsInternal, which is called by default with no + * irrKinds. + */ + virtual void computeRelevantTerms(std::set<Node>& termSet, + bool includeShared = true); + /** + * Collect model values, after equality information is added to the model. + * The argument termSet is the set of relevant terms returned by + * computeRelevantTerms. + */ + virtual bool collectModelValues(TheoryModel* m, std::set<Node>& termSet); + //---------------------------------- end collect model info /** * Construct a Theory. @@ -255,6 +269,11 @@ class Theory { */ std::unique_ptr<eq::EqualityEngine> d_allocEqualityEngine; /** + * The theory state, which contains contexts, valuation, and equality engine. + * Notice the theory is responsible for memory management of this class. + */ + TheoryState* d_theoryState; + /** * Whether proofs are enabled * */ @@ -619,7 +638,7 @@ class Theory { * This method returns true if and only if the equality engine of m is * consistent as a result of this call. */ - virtual bool collectModelInfo(TheoryModel* m) { return true; } + virtual bool collectModelInfo(TheoryModel* m); /** if theories want to do something with model after building, do it here */ virtual void postProcessModel( TheoryModel* m ){ } /** diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 07c160058..e86a09112 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -159,6 +159,13 @@ void TheoryEngine::finishInit() { d_aloc_curr_model_builder = true; } + // set the core equality engine on quantifiers engine + if (d_logicInfo.isQuantified()) + { + d_quantEngine->setMasterEqualityEngine( + d_eeDistributed->getMasterEqualityEngine()); + } + // finish initializing the theories for(TheoryId theoryId = theory::THEORY_FIRST; theoryId != theory::THEORY_LAST; ++ theoryId) { Theory* t = d_theoryTable[theoryId]; @@ -545,7 +552,7 @@ void TheoryEngine::check(Theory::Effort effort) { if( Theory::fullEffort(effort) && !d_inConflict && !needCheck()) { // case where we are about to answer SAT, the master equality engine, // if it exists, must be consistent. - eq::EqualityEngine* mee = getMasterEqualityEngine(); + eq::EqualityEngine* mee = d_eeDistributed->getMasterEqualityEngine(); if (mee != NULL) { AlwaysAssert(mee->consistent()); @@ -1807,12 +1814,6 @@ SharedTermsDatabase* TheoryEngine::getSharedTermsDatabase() return &d_sharedTerms; } -theory::eq::EqualityEngine* TheoryEngine::getMasterEqualityEngine() -{ - Assert(d_eeDistributed != nullptr); - return d_eeDistributed->getMasterEqualityEngine(); -} - void TheoryEngine::getExplanation(std::vector<NodeTheoryPair>& explanationVector, LemmaProofRecipe* proofRecipe) { Assert(explanationVector.size() > 0); diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index aa23aa29b..bedd54130 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -738,8 +738,6 @@ public: SharedTermsDatabase* getSharedTermsDatabase(); - theory::eq::EqualityEngine* getMasterEqualityEngine(); - SortInference* getSortInference() { return &d_sortInfer; } /** Prints the assertions to the debug stream */ diff --git a/src/theory/theory_state.cpp b/src/theory/theory_state.cpp new file mode 100644 index 000000000..bc8e53245 --- /dev/null +++ b/src/theory/theory_state.cpp @@ -0,0 +1,129 @@ +/********************* */ +/*! \file theory_state.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A theory state for Theory + **/ + +#include "theory/theory_state.h" + +#include "theory/uf/equality_engine.h" + +namespace CVC4 { +namespace theory { + +TheoryState::TheoryState(context::Context* c, + context::UserContext* u, + Valuation val) + : d_context(c), + d_ucontext(u), + d_valuation(val), + d_ee(nullptr), + d_conflict(c, false) +{ +} + +void TheoryState::finishInit(eq::EqualityEngine* ee) { d_ee = ee; } + +context::Context* TheoryState::getSatContext() const { return d_context; } + +context::UserContext* TheoryState::getUserContext() const { return d_ucontext; } + +bool TheoryState::hasTerm(TNode a) const +{ + Assert(d_ee != nullptr); + return d_ee->hasTerm(a); +} + +TNode TheoryState::getRepresentative(TNode t) const +{ + Assert(d_ee != nullptr); + if (d_ee->hasTerm(t)) + { + return d_ee->getRepresentative(t); + } + return t; +} + +bool TheoryState::areEqual(TNode a, TNode b) const +{ + Assert(d_ee != nullptr); + if (a == b) + { + return true; + } + else if (hasTerm(a) && hasTerm(b)) + { + return d_ee->areEqual(a, b); + } + return false; +} + +bool TheoryState::areDisequal(TNode a, TNode b) const +{ + Assert(d_ee != nullptr); + if (a == b) + { + return false; + } + + bool isConst = true; + bool hasTerms = true; + if (hasTerm(a)) + { + a = d_ee->getRepresentative(a); + isConst = a.isConst(); + } + else if (!a.isConst()) + { + // if not constant and not a term in the ee, it cannot be disequal + return false; + } + else + { + hasTerms = false; + } + + if (hasTerm(b)) + { + b = d_ee->getRepresentative(b); + isConst = isConst && b.isConst(); + } + else if (!b.isConst()) + { + // same as above, it cannot be disequal + return false; + } + else + { + hasTerms = false; + } + + if (isConst) + { + // distinct constants are disequal + return a != b; + } + else if (!hasTerms) + { + return false; + } + // otherwise there may be an explicit disequality in the equality engine + return d_ee->areDisequal(a, b, false); +} + +eq::EqualityEngine* TheoryState::getEqualityEngine() const { return d_ee; } + +void TheoryState::notifyInConflict() { d_conflict = true; } + +bool TheoryState::isInConflict() const { return d_conflict; } + +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/theory_state.h b/src/theory/theory_state.h new file mode 100644 index 000000000..71197dddc --- /dev/null +++ b/src/theory/theory_state.h @@ -0,0 +1,94 @@ +/********************* */ +/*! \file theory_state.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A theory state for Theory + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__THEORY_STATE_H +#define CVC4__THEORY__THEORY_STATE_H + +#include "context/cdo.h" +#include "expr/node.h" +#include "theory/valuation.h" + +namespace CVC4 { +namespace theory { + +namespace eq { +class EqualityEngine; +} + +class TheoryState +{ + public: + TheoryState(context::Context* c, context::UserContext* u, Valuation val); + virtual ~TheoryState() {} + /** + * Finish initialize, ee is a pointer to the official equality engine + * of theory. + */ + virtual void finishInit(eq::EqualityEngine* ee); + /** Get the SAT context */ + context::Context* getSatContext() const; + /** Get the user context */ + context::UserContext* getUserContext() const; + //-------------------------------------- equality information + /** Is t registered as a term in the equality engine of this class? */ + virtual bool hasTerm(TNode a) const; + /** + * Get the representative of t in the equality engine of this class, or t + * itself if it is not registered as a term. + */ + virtual TNode getRepresentative(TNode t) const; + /** + * Are a and b equal according to the equality engine of this class? Also + * returns true if a and b are identical. + */ + virtual bool areEqual(TNode a, TNode b) const; + /** + * Are a and b disequal according to the equality engine of this class? Also + * returns true if the representative of a and b are distinct constants. + */ + virtual bool areDisequal(TNode a, TNode b) const; + /** get equality engine */ + eq::EqualityEngine* getEqualityEngine() const; + //-------------------------------------- end equality information + /** + * Set that the current state of the solver is in conflict. This should be + * called immediately after a call to conflict(...) on the output channel of + * the theory. + */ + virtual void notifyInConflict(); + /** Are we currently in conflict? */ + virtual bool isInConflict() const; + + protected: + /** Pointer to the SAT context object used by the theory. */ + context::Context* d_context; + /** Pointer to the user context object used by the theory. */ + context::UserContext* d_ucontext; + /** + * The valuation proxy for the Theory to communicate back with the + * theory engine (and other theories). + */ + Valuation d_valuation; + /** Pointer to equality engine of the theory. */ + eq::EqualityEngine* d_ee; + /** Are we in conflict? */ + context::CDO<bool> d_conflict; +}; + +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__SOLVER_STATE_H */ diff --git a/src/theory/trust_node.cpp b/src/theory/trust_node.cpp index 25aef5a72..041d04d75 100644 --- a/src/theory/trust_node.cpp +++ b/src/theory/trust_node.cpp @@ -121,6 +121,22 @@ Node TrustNode::getPropExpProven(TNode lit, Node exp) Node TrustNode::getRewriteProven(TNode n, Node nr) { return n.eqNode(nr); } +void TrustNode::debugCheckClosed(const char* c, + const char* ctx, + bool reqNullGen) +{ + pfgEnsureClosed(d_proven, d_gen, c, ctx, reqNullGen); +} + +std::string TrustNode::identifyGenerator() const +{ + if (d_gen == nullptr) + { + return "null"; + } + return d_gen->identify(); +} + std::ostream& operator<<(std::ostream& out, TrustNode n) { out << "(" << n.getKind() << " " << n.getProven() << ")"; diff --git a/src/theory/trust_node.h b/src/theory/trust_node.h index ff174b63e..b7be0e4e5 100644 --- a/src/theory/trust_node.h +++ b/src/theory/trust_node.h @@ -142,6 +142,15 @@ class TrustNode static Node getPropExpProven(TNode lit, Node exp); /** Get the proven formula corresponding to a rewrite */ static Node getRewriteProven(TNode n, Node nr); + /** For debugging */ + std::string identifyGenerator() const; + + /** + * debug check closed on Trace c, context ctx is string for debugging + * + * @param reqNullGen Whether we consider a null generator to be a failure. + */ + void debugCheckClosed(const char* c, const char* ctx, bool reqNullGen = true); private: TrustNode(TrustNodeKind tnk, Node p, ProofGenerator* g = nullptr); diff --git a/test/signatures/drat_test.plf b/test/signatures/drat_test.plf index e5335a6bb..d66e48f8d 100644 --- a/test/signatures/drat_test.plf +++ b/test/signatures/drat_test.plf @@ -6,7 +6,7 @@ (! a clause (! b clause (! result bool - (! (^ + (! sc (^ (bool_and (bool_eq (clause_eq a b) result) (bool_and |