diff options
Diffstat (limited to 'src/theory')
-rw-r--r-- | src/theory/datatypes/theory_datatypes.cpp | 102 | ||||
-rw-r--r-- | src/theory/datatypes/theory_datatypes.h | 9 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus/sygus_interpol.cpp | 57 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus/sygus_interpol.h | 10 | ||||
-rw-r--r-- | src/theory/sep/theory_sep.cpp | 20 | ||||
-rw-r--r-- | src/theory/sets/normal_form.h | 109 | ||||
-rw-r--r-- | src/theory/sets/theory_sets_private.cpp | 2 | ||||
-rw-r--r-- | src/theory/sets/theory_sets_rewriter.cpp | 70 | ||||
-rw-r--r-- | src/theory/sets/theory_sets_rewriter.h | 6 | ||||
-rw-r--r-- | src/theory/theory_inference_manager.cpp | 1 | ||||
-rw-r--r-- | src/theory/uf/proof_equality_engine.cpp | 2 |
11 files changed, 189 insertions, 199 deletions
diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 585f13d82..afa640650 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -602,21 +602,7 @@ bool TheoryDatatypes::propagateLit(TNode literal) { Debug("dt::propagate") << "TheoryDatatypes::propagateLit(" << literal << ")" << std::endl; - // If already in conflict, no more propagation - if (d_state.isInConflict()) - { - Debug("dt::propagate") << "TheoryDatatypes::propagateLit(" << literal - << "): already in conflict" << std::endl; - return false; - } - Trace("dt-prop") << "dtPropagate " << literal << std::endl; - // Propagate out - bool ok = d_out->propagate(literal); - if (!ok) { - Trace("dt-conflict") << "CONFLICT: Eq engine propagate conflict " << std::endl; - d_state.notifyInConflict(); - } - return ok; + return d_im.propagateLit(literal); } void TheoryDatatypes::addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions ) { @@ -671,8 +657,7 @@ void TheoryDatatypes::explain(TNode literal, std::vector<TNode>& assumptions){ TrustNode TheoryDatatypes::explain(TNode literal) { - Node exp = explainLit(literal); - return TrustNode::mkTrustPropExp(literal, exp, nullptr); + return d_im.explainLit(literal); } Node TheoryDatatypes::explainLit(TNode literal) @@ -692,11 +677,9 @@ Node TheoryDatatypes::explain( std::vector< Node >& lits ) { /** Conflict when merging two constants */ void TheoryDatatypes::conflict(TNode a, TNode b){ - Node eq = a.eqNode(b); - d_conflictNode = explainLit(eq); - Trace("dt-conflict") << "CONFLICT: Eq engine conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); - d_state.notifyInConflict(); + Trace("dt-conflict") << "CONFLICT: Eq engine conflict merge : " << a + << " == " << b << std::endl; + d_im.conflictEqConstantMerge(a, b); } /** called when a new equivalance class is created */ @@ -744,10 +727,11 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ std::vector< Node > rew; if (utils::checkClash(cons1, cons2, rew)) { - d_conflictNode = explainLit(unifEq); - Trace("dt-conflict") << "CONFLICT: Clash conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); - d_state.notifyInConflict(); + std::vector<Node> conf; + conf.push_back(unifEq); + Trace("dt-conflict") + << "CONFLICT: Clash conflict : " << conf << std::endl; + d_im.conflictExp(conf, nullptr); return; } else @@ -946,13 +930,12 @@ void TheoryDatatypes::addTester( { if( !eqc->d_constructor.get().isNull() ){ //conflict because equivalence class contains a constructor - std::vector< TNode > assumptions; - explain( t, assumptions ); - explainEquality( eqc->d_constructor.get(), t_arg, true, assumptions ); - d_conflictNode = mkAnd( assumptions ); - Trace("dt-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); - d_state.notifyInConflict(); + std::vector<Node> conf; + conf.push_back(t); + conf.push_back(eqc->d_constructor.get().eqNode(t_arg)); + Trace("dt-conflict") + << "CONFLICT: Tester eq conflict " << conf << std::endl; + d_im.conflictExp(conf, nullptr); return; }else{ makeConflict = true; @@ -1051,15 +1034,13 @@ void TheoryDatatypes::addTester( } } if( makeConflict ){ - d_state.notifyInConflict(); Debug("datatypes-labels") << "Explain " << j << " " << t << std::endl; - std::vector< TNode > assumptions; - explain( j, assumptions ); - explain( t, assumptions ); - explainEquality( jt[0], t_arg, true, assumptions ); - d_conflictNode = mkAnd( assumptions ); - Trace("dt-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); + std::vector<Node> conf; + conf.push_back(j); + conf.push_back(t); + conf.push_back(jt[0].eqNode(t_arg)); + Trace("dt-conflict") << "CONFLICT: Tester conflict : " << conf << std::endl; + d_im.conflictExp(conf, nullptr); } } @@ -1112,13 +1093,12 @@ void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){ unsigned tindex = d_labels_tindex[n][i]; if (tindex == constructorIndex) { - std::vector< TNode > assumptions; - explain( t, assumptions ); - explainEquality( c, t[0][0], true, assumptions ); - d_conflictNode = mkAnd( assumptions ); - Trace("dt-conflict") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); - d_state.notifyInConflict(); + std::vector<Node> conf; + conf.push_back(t); + conf.push_back(c.eqNode(t[0][0])); + Trace("dt-conflict") + << "CONFLICT: Tester merge eq conflict : " << conf << std::endl; + d_im.conflictExp(conf, nullptr); return; } } @@ -1671,7 +1651,7 @@ void TheoryDatatypes::checkCycles() { //do cycle checks std::map< TNode, bool > visited; std::map< TNode, bool > proc; - std::vector< TNode > expl; + std::vector<Node> expl; Trace("datatypes-cycle-check") << "...search for cycle starting at " << eqc << std::endl; Node cn = searchForCycle( eqc, eqc, visited, proc, expl ); Trace("datatypes-cycle-check") << "...finish." << std::endl; @@ -1687,10 +1667,9 @@ void TheoryDatatypes::checkCycles() { if( !cn.isNull() ) { Assert(expl.size() > 0); - d_conflictNode = mkAnd( expl ); - Trace("dt-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl; - d_out->conflict( d_conflictNode ); - d_state.notifyInConflict(); + Trace("dt-conflict") + << "CONFLICT: Cycle conflict : " << expl << std::endl; + d_im.conflictExp(expl, nullptr); return; } } @@ -1860,16 +1839,23 @@ void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector< } //postcondition: if cycle detected, explanation is why n is a subterm of on -Node TheoryDatatypes::searchForCycle( TNode n, TNode on, - std::map< TNode, bool >& visited, std::map< TNode, bool >& proc, - std::vector< TNode >& explanation, bool firstTime ) { +Node TheoryDatatypes::searchForCycle(TNode n, + TNode on, + std::map<TNode, bool>& visited, + std::map<TNode, bool>& proc, + std::vector<Node>& explanation, + bool firstTime) +{ Trace("datatypes-cycle-check2") << "Search for cycle " << n << " " << on << endl; TNode ncons; TNode nn; if( !firstTime ){ nn = getRepresentative( n ); if( nn==on ){ - explainEquality( n, nn, true, explanation ); + if (n != nn) + { + explanation.push_back(n.eqNode(nn)); + } return on; } }else{ @@ -1893,7 +1879,7 @@ Node TheoryDatatypes::searchForCycle( TNode n, TNode on, //add explanation for why the constructor is connected if (n != nncons) { - explainEquality(n, nncons, true, explanation); + explanation.push_back(n.eqNode(nncons)); } return on; }else if( !cn.isNull() ){ diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index bf5d33177..d34390a5f 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -293,9 +293,12 @@ private: Node removeUninterpretedConstants( Node n, std::map< Node, Node >& visited ); /** for checking if cycles exist */ void checkCycles(); - Node searchForCycle( TNode n, TNode on, - std::map< TNode, bool >& visited, std::map< TNode, bool >& proc, - std::vector< TNode >& explanation, bool firstTime = true ); + Node searchForCycle(TNode n, + TNode on, + std::map<TNode, bool>& visited, + std::map<TNode, bool>& proc, + std::vector<Node>& explanation, + bool firstTime = true); /** for checking whether two codatatype terms must be equal */ void separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out, std::vector< TNode >& exp, diff --git a/src/theory/quantifiers/sygus/sygus_interpol.cpp b/src/theory/quantifiers/sygus/sygus_interpol.cpp index c2ca83e41..4d18c850b 100644 --- a/src/theory/quantifiers/sygus/sygus_interpol.cpp +++ b/src/theory/quantifiers/sygus/sygus_interpol.cpp @@ -23,6 +23,7 @@ #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/sygus/sygus_grammar_cons.h" #include "theory/rewriter.h" +#include "theory/smt_engine_subsolver.h" namespace CVC4 { namespace theory { @@ -30,8 +31,6 @@ namespace quantifiers { SygusInterpol::SygusInterpol() {} -SygusInterpol::SygusInterpol(LogicInfo logic) : d_logic(logic) {} - void SygusInterpol::collectSymbols(const std::vector<Node>& axioms, const Node& conj) { @@ -75,6 +74,9 @@ void SygusInterpol::createVariables(bool needsShared) Node var = nm->mkBoundVar(tn); d_vars.push_back(var); Node vlv = nm->mkBoundVar(ss.str(), tn); + // set that this variable encodes the term s + SygusVarToTermAttribute sta; + vlv.setAttribute(sta, s); d_vlvs.push_back(vlv); if (!needsShared || d_symSetShared.find(s) != d_symSetShared.end()) { @@ -266,7 +268,7 @@ void SygusInterpol::mkSygusConjecture(Node itp, Trace("sygus-interpol") << "Generate: " << d_sygusConj << std::endl; } -bool SygusInterpol::findInterpol(Expr& interpol, Node itp) +bool SygusInterpol::findInterpol(Node& interpol, Node itp) { // get the synthesis solution std::map<Node, Node> sols; @@ -283,31 +285,31 @@ bool SygusInterpol::findInterpol(Expr& interpol, Node itp) } Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is " << its->second << std::endl; - Node interpoln = its->second; + interpol = its->second; // replace back the created variables to original symbols. - Node interpoln_reduced; - if (interpoln.getKind() == kind::LAMBDA) + if (interpol.getKind() == kind::LAMBDA) { - interpoln_reduced = interpoln[1]; + interpol = interpol[1]; } - else + + // get the grammar type for the interpolant + Node igdtbv = itp.getAttribute(SygusSynthFunVarListAttribute()); + Assert(!igdtbv.isNull()); + Assert(igdtbv.getKind() == kind::BOUND_VAR_LIST); + // convert back to original + // must replace formal arguments of itp with the free variables in the + // input problem that they correspond to. + std::vector<Node> vars; + std::vector<Node> syms; + SygusVarToTermAttribute sta; + for (const Node& bv : igdtbv) { - interpoln_reduced = interpoln; + vars.push_back(bv); + syms.push_back(bv.hasAttribute(sta) ? bv.getAttribute(sta) : bv); } - if (interpoln.getNumChildren() != 0 && interpoln[0].getNumChildren() != 0) - { - std::vector<Node> formals; - for (const Node& n : interpoln[0]) - { - formals.push_back(n); - } - interpoln_reduced = interpoln_reduced.substitute(formals.begin(), - formals.end(), - d_symSetShared.begin(), - d_symSetShared.end()); - } - // convert to expression - interpol = interpoln_reduced.toExpr(); + interpol = + interpol.substitute(vars.begin(), vars.end(), syms.begin(), syms.end()); + return true; } @@ -315,14 +317,11 @@ bool SygusInterpol::SolveInterpolation(const std::string& name, const std::vector<Node>& axioms, const Node& conj, const TypeNode& itpGType, - Expr& interpol) + Node& interpol) { - NodeManager* nm = NodeManager::currentNM(); - // we generate a new smt engine to do the interpolation query - d_subSolver.reset(new SmtEngine(nm->toExprManager())); - d_subSolver->setIsInternalSubsolver(); + initializeSubsolver(d_subSolver); // get the logic - LogicInfo l = d_logic.getUnlockedCopy(); + LogicInfo l = d_subSolver->getLogicInfo().getUnlockedCopy(); // enable everything needed for sygus l.enableSygus(); d_subSolver->setLogic(l); diff --git a/src/theory/quantifiers/sygus/sygus_interpol.h b/src/theory/quantifiers/sygus/sygus_interpol.h index 0fe66694f..4abe94f15 100644 --- a/src/theory/quantifiers/sygus/sygus_interpol.h +++ b/src/theory/quantifiers/sygus/sygus_interpol.h @@ -46,8 +46,6 @@ class SygusInterpol public: SygusInterpol(); - SygusInterpol(LogicInfo logic); - /** * Returns the sygus conjecture in interpol corresponding to the interpolation * problem for input problem (F above) given by axioms (Fa above), and conj @@ -65,7 +63,7 @@ class SygusInterpol const std::vector<Node>& axioms, const Node& conj, const TypeNode& itpGType, - Expr& interpol); + Node& interpol); private: /** @@ -158,7 +156,7 @@ class SygusInterpol * @param interpol the solution to the sygus conjecture. * @param itp the interpolation predicate. */ - bool findInterpol(Expr& interpol, Node itp); + bool findInterpol(Node& interpol, Node itp); /** The SMT engine subSolver * @@ -179,10 +177,6 @@ class SygusInterpol std::unique_ptr<SmtEngine> d_subSolver; /** - * The logic for the local copy of SMT engine (d_subSolver). - */ - LogicInfo d_logic; - /** * symbols from axioms and conjecture. */ std::vector<Node> d_syms; diff --git a/src/theory/sep/theory_sep.cpp b/src/theory/sep/theory_sep.cpp index c9b6a9d89..573449287 100644 --- a/src/theory/sep/theory_sep.cpp +++ b/src/theory/sep/theory_sep.cpp @@ -1646,9 +1646,9 @@ void TheorySep::computeLabelModel( Node lbl ) { Trace("sep-process") << "Model value (from valuation) for " << lbl << " : " << v_val << std::endl; if( v_val.getKind()!=kind::EMPTYSET ){ while( v_val.getKind()==kind::UNION ){ - Assert(v_val[1].getKind() == kind::SINGLETON); - d_label_model[lbl].d_heap_locs_model.push_back( v_val[1] ); - v_val = v_val[0]; + Assert(v_val[0].getKind() == kind::SINGLETON); + d_label_model[lbl].d_heap_locs_model.push_back(v_val[0]); + v_val = v_val[1]; } if( v_val.getKind()==kind::SINGLETON ){ d_label_model[lbl].d_heap_locs_model.push_back( v_val ); @@ -1916,15 +1916,13 @@ Node TheorySep::HeapInfo::getValue( TypeNode tn ) { Assert(d_heap_locs.size() == d_heap_locs_model.size()); if( d_heap_locs.empty() ){ return NodeManager::currentNM()->mkConst(EmptySet(tn)); - }else if( d_heap_locs.size()==1 ){ - return d_heap_locs[0]; - }else{ - Node curr = NodeManager::currentNM()->mkNode( kind::UNION, d_heap_locs[0], d_heap_locs[1] ); - for( unsigned j=2; j<d_heap_locs.size(); j++ ){ - curr = NodeManager::currentNM()->mkNode( kind::UNION, curr, d_heap_locs[j] ); - } - return curr; } + Node curr = d_heap_locs[0]; + for (unsigned j = 1; j < d_heap_locs.size(); j++) + { + curr = NodeManager::currentNM()->mkNode(kind::UNION, d_heap_locs[j], curr); + } + return curr; } }/* CVC4::theory::sep namespace */ diff --git a/src/theory/sets/normal_form.h b/src/theory/sets/normal_form.h index 0607a0e6c..b53a1c03d 100644 --- a/src/theory/sets/normal_form.h +++ b/src/theory/sets/normal_form.h @@ -25,6 +25,12 @@ namespace sets { class NormalForm { public: + /** + * Constructs a set of the form: + * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n)))) + * from the set { c1 ... cn }, also handles empty set case, which is why + * setType is passed to this method. + */ template <bool ref_count> static Node elementsToSet(const std::set<NodeTemplate<ref_count> >& elements, TypeNode setType) @@ -42,12 +48,21 @@ class NormalForm { Node cur = nm->mkNode(kind::SINGLETON, *it); while (++it != elements.end()) { - cur = nm->mkNode(kind::UNION, cur, nm->mkNode(kind::SINGLETON, *it)); + cur = nm->mkNode(kind::UNION, nm->mkNode(kind::SINGLETON, *it), cur); } return cur; } } + /** + * Returns true if n is considered a to be a (canonical) constant set value. + * A canonical set value is one whose AST is: + * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n)))) + * where c1 ... cn are constants and the node identifier of these constants + * are such that: + * c1 > ... > cn. + * Also handles the corner cases of empty set and singleton set. + */ static bool checkNormalConstant(TNode n) { Debug("sets-checknormal") << "[sets-checknormal] checkNormal " << n << " :" << std::endl; @@ -56,46 +71,62 @@ class NormalForm { } else if (n.getKind() == kind::SINGLETON) { return n[0].isConst(); } else if (n.getKind() == kind::UNION) { - // assuming (union ... (union {SmallestNodeID} {BiggerNodeId}) ... - // {BiggestNodeId}) - - // store BiggestNodeId in prvs - if (n[1].getKind() != kind::SINGLETON) return false; - if (!n[1][0].isConst()) return false; - Debug("sets-checknormal") - << "[sets-checknormal] frst element = " << n[1][0] << " " - << n[1][0].getId() << std::endl; - TNode prvs = n[1][0]; - n = n[0]; + // assuming (union {SmallestNodeID} ... (union {BiggerNodeId} ... + Node orig = n; + TNode prvs; // check intermediate nodes - while (n.getKind() == kind::UNION) { - if (n[1].getKind() != kind::SINGLETON) return false; - if (!n[1].isConst()) return false; + while (n.getKind() == kind::UNION) + { + if (n[0].getKind() != kind::SINGLETON || !n[0][0].isConst()) + { + // not a constant + Trace("sets-isconst") << "sets::isConst: " << orig << " not due to " + << n[0] << std::endl; + return false; + } Debug("sets-checknormal") - << "[sets-checknormal] element = " << n[1][0] << " " - << n[1][0].getId() << std::endl; - if (n[1][0] >= prvs) return false; - prvs = n[1][0]; - n = n[0]; + << "[sets-checknormal] element = " << n[0][0] << " " + << n[0][0].getId() << std::endl; + if (!prvs.isNull() && n[0][0] >= prvs) + { + Trace("sets-isconst") + << "sets::isConst: " << orig << " not due to compare " << n[0][0] + << std::endl; + return false; + } + prvs = n[0][0]; + n = n[1]; } // check SmallestNodeID is smallest - if (n.getKind() != kind::SINGLETON) return false; - if (!n[0].isConst()) return false; + if (n.getKind() != kind::SINGLETON || !n[0].isConst()) + { + Trace("sets-isconst") << "sets::isConst: " << orig + << " not due to final " << n << std::endl; + return false; + } Debug("sets-checknormal") << "[sets-checknormal] lst element = " << n[0] << " " << n[0].getId() << std::endl; - if (n[0] >= prvs) return false; - - // we made it - return true; - - } else { - return false; + // compare last ID + if (n[0] < prvs) + { + return true; + } + Trace("sets-isconst") + << "sets::isConst: " << orig << " not due to compare final " << n[0] + << std::endl; } + return false; } + /** + * Converts a set term to a std::set of its elements. This expects a set of + * the form: + * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n)))) + * Also handles the corner cases of empty set and singleton set. + */ static std::set<Node> getElementsFromNormalConstant(TNode n) { Assert(n.isConst()); std::set<Node> ret; @@ -103,29 +134,15 @@ class NormalForm { return ret; } while (n.getKind() == kind::UNION) { - Assert(n[1].getKind() == kind::SINGLETON); - ret.insert(ret.begin(), n[1][0]); - n = n[0]; + Assert(n[0].getKind() == kind::SINGLETON); + ret.insert(ret.begin(), n[0][0]); + n = n[1]; } Assert(n.getKind() == kind::SINGLETON); ret.insert(n[0]); return ret; } - - //AJR - - static void getElementsFromBop( Kind k, Node n, std::vector< Node >& els ){ - if( n.getKind()==k ){ - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - getElementsFromBop( k, n[i], els ); - } - }else{ - if( std::find( els.begin(), els.end(), n )==els.end() ){ - els.push_back( n ); - } - } - } static Node mkBop( Kind k, std::vector< Node >& els, TypeNode tn, unsigned index = 0 ){ if( index>=els.size() ){ return NodeManager::currentNM()->mkConst(EmptySet(tn)); diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index 741f45dd8..b1831f261 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -320,7 +320,7 @@ void TheorySetsPrivate::fullEffortCheck() Node n = (*eqc_i); if (n != eqc) { - Trace("sets-eqc") << n << " "; + Trace("sets-eqc") << n << " (" << n.isConst() << ") "; } TypeNode tnn = n.getType(); if (isSet) diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index eb168c6ed..50aa89cc8 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -27,7 +27,7 @@ namespace CVC4 { namespace theory { namespace sets { -bool checkConstantMembership(TNode elementTerm, TNode setTerm) +bool TheorySetsRewriter::checkConstantMembership(TNode elementTerm, TNode setTerm) { if(setTerm.getKind() == kind::EMPTYSET) { return false; @@ -38,12 +38,11 @@ bool checkConstantMembership(TNode elementTerm, TNode setTerm) } Assert(setTerm.getKind() == kind::UNION - && setTerm[1].getKind() == kind::SINGLETON) + && setTerm[0].getKind() == kind::SINGLETON) << "kind was " << setTerm.getKind() << ", term: " << setTerm; - return - elementTerm == setTerm[1][0] || - checkConstantMembership(elementTerm, setTerm[0]); + return elementTerm == setTerm[0][0] + || checkConstantMembership(elementTerm, setTerm[1]); } // static @@ -53,6 +52,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { Trace("sets-postrewrite") << "Process: " << node << std::endl; if(node.isConst()) { + Trace("sets-rewrite-nf") + << "Sets::rewrite: no rewrite (constant) " << node << std::endl; // Dare you touch the const and mangle it to something else. return RewriteResponse(REWRITE_DONE, node); } @@ -163,23 +164,13 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { Assert(newNode.isConst()); Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl; return RewriteResponse(REWRITE_DONE, newNode); - } else { - std::vector< Node > els; - NormalForm::getElementsFromBop( kind::INTERSECTION, node, els ); - std::sort( els.begin(), els.end() ); - Node rew = NormalForm::mkBop( kind::INTERSECTION, els, node.getType() ); - if( rew!=node ){ - Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl; - } - return RewriteResponse(REWRITE_DONE, rew); } - /* - } else if (node[0] > node[1]) { + else if (node[0] > node[1]) + { Node newNode = nm->mkNode(node.getKind(), node[1], node[0]); - Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl; return RewriteResponse(REWRITE_DONE, newNode); } - */ + // we don't merge non-constant intersections break; }//kind::INTERSECION @@ -200,19 +191,16 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { std::inserter(newSet, newSet.begin())); Node newNode = NormalForm::elementsToSet(newSet, node.getType()); Assert(newNode.isConst()); - Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl; + Trace("sets-rewrite") + << "Sets::rewrite: UNION_CONSTANT_MERGE: " << newNode << std::endl; return RewriteResponse(REWRITE_DONE, newNode); - } else { - std::vector< Node > els; - NormalForm::getElementsFromBop( kind::UNION, node, els ); - std::sort( els.begin(), els.end() ); - Node rew = NormalForm::mkBop( kind::UNION, els, node.getType() ); - if( rew!=node ){ - Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl; - } - Trace("sets-rewrite") << "...no rewrite." << std::endl; - return RewriteResponse(REWRITE_DONE, rew); } + else if (node[0] > node[1]) + { + Node newNode = nm->mkNode(node.getKind(), node[1], node[0]); + return RewriteResponse(REWRITE_DONE, newNode); + } + // we don't merge non-constant unions break; }//kind::UNION case kind::COMPLEMENT: { @@ -491,16 +479,15 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { // static RewriteResponse TheorySetsRewriter::preRewrite(TNode node) { NodeManager* nm = NodeManager::currentNM(); - - if(node.getKind() == kind::EQUAL) { - + Kind k = node.getKind(); + if (k == kind::EQUAL) + { if(node[0] == node[1]) { return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); } - - }//kind::EQUAL - else if(node.getKind() == kind::INSERT) { - + } + else if (k == kind::INSERT) + { Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]); size_t setNodeIndex = node.getNumChildren()-1; for(size_t i = 1; i < setNodeIndex; ++i) { @@ -512,17 +499,16 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) { nm->mkNode(kind::UNION, insertedElements, node[setNodeIndex])); - - }//kind::INSERT - else if(node.getKind() == kind::SUBSET) { - + } + else if (k == kind::SUBSET) + { // rewrite (A subset-or-equal B) as (A union B = B) return RewriteResponse(REWRITE_AGAIN, nm->mkNode(kind::EQUAL, nm->mkNode(kind::UNION, node[0], node[1]), node[1]) ); - - }//kind::SUBSET + } + // could have an efficient normalizer for union here return RewriteResponse(REWRITE_DONE, node); } diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h index 7d1a6c188..fdc9caefb 100644 --- a/src/theory/sets/theory_sets_rewriter.h +++ b/src/theory/sets/theory_sets_rewriter.h @@ -70,7 +70,11 @@ class TheorySetsRewriter : public TheoryRewriter // often this will suffice return postRewrite(equality).d_node; } - +private: + /** + * Returns true if elementTerm is in setTerm, where both terms are constants. + */ + bool checkConstantMembership(TNode elementTerm, TNode setTerm); }; /* class TheorySetsRewriter */ }/* CVC4::theory::sets namespace */ diff --git a/src/theory/theory_inference_manager.cpp b/src/theory/theory_inference_manager.cpp index 81f5c45e6..980763040 100644 --- a/src/theory/theory_inference_manager.cpp +++ b/src/theory/theory_inference_manager.cpp @@ -141,6 +141,7 @@ TrustNode TheoryInferenceManager::mkConflictExp(const std::vector<Node>& exp, { if (d_pfee != nullptr) { + Assert(pg != nullptr); // use proof equality engine to construct the trust node return d_pfee->assertConflict(exp, pg); } diff --git a/src/theory/uf/proof_equality_engine.cpp b/src/theory/uf/proof_equality_engine.cpp index 00e4662f9..fa9482094 100644 --- a/src/theory/uf/proof_equality_engine.cpp +++ b/src/theory/uf/proof_equality_engine.cpp @@ -221,6 +221,7 @@ TrustNode ProofEqEngine::assertConflict(const std::vector<Node>& exp, TrustNode ProofEqEngine::assertConflict(const std::vector<Node>& exp, ProofGenerator* pg) { + Assert(pg != nullptr); Trace("pfee") << "pfee::assertConflict " << exp << " via generator" << std::endl; return assertLemma(d_false, exp, {}, pg); @@ -306,6 +307,7 @@ TrustNode ProofEqEngine::assertLemma(Node conc, const std::vector<Node>& noExplain, ProofGenerator* pg) { + Assert(pg != nullptr); Trace("pfee") << "pfee::assertLemma " << conc << ", exp = " << exp << ", noExplain = " << noExplain << " via buffer with generator" << std::endl; |