summaryrefslogtreecommitdiff
path: root/src/theory
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory')
-rw-r--r--src/theory/datatypes/theory_datatypes.cpp102
-rw-r--r--src/theory/datatypes/theory_datatypes.h9
-rw-r--r--src/theory/quantifiers/sygus/sygus_interpol.cpp57
-rw-r--r--src/theory/quantifiers/sygus/sygus_interpol.h10
-rw-r--r--src/theory/sep/theory_sep.cpp20
-rw-r--r--src/theory/sets/normal_form.h109
-rw-r--r--src/theory/sets/theory_sets_private.cpp2
-rw-r--r--src/theory/sets/theory_sets_rewriter.cpp70
-rw-r--r--src/theory/sets/theory_sets_rewriter.h6
-rw-r--r--src/theory/theory_inference_manager.cpp1
-rw-r--r--src/theory/uf/proof_equality_engine.cpp2
11 files changed, 189 insertions, 199 deletions
diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp
index 585f13d82..afa640650 100644
--- a/src/theory/datatypes/theory_datatypes.cpp
+++ b/src/theory/datatypes/theory_datatypes.cpp
@@ -602,21 +602,7 @@ bool TheoryDatatypes::propagateLit(TNode literal)
{
Debug("dt::propagate") << "TheoryDatatypes::propagateLit(" << literal << ")"
<< std::endl;
- // If already in conflict, no more propagation
- if (d_state.isInConflict())
- {
- Debug("dt::propagate") << "TheoryDatatypes::propagateLit(" << literal
- << "): already in conflict" << std::endl;
- return false;
- }
- Trace("dt-prop") << "dtPropagate " << literal << std::endl;
- // Propagate out
- bool ok = d_out->propagate(literal);
- if (!ok) {
- Trace("dt-conflict") << "CONFLICT: Eq engine propagate conflict " << std::endl;
- d_state.notifyInConflict();
- }
- return ok;
+ return d_im.propagateLit(literal);
}
void TheoryDatatypes::addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions ) {
@@ -671,8 +657,7 @@ void TheoryDatatypes::explain(TNode literal, std::vector<TNode>& assumptions){
TrustNode TheoryDatatypes::explain(TNode literal)
{
- Node exp = explainLit(literal);
- return TrustNode::mkTrustPropExp(literal, exp, nullptr);
+ return d_im.explainLit(literal);
}
Node TheoryDatatypes::explainLit(TNode literal)
@@ -692,11 +677,9 @@ Node TheoryDatatypes::explain( std::vector< Node >& lits ) {
/** Conflict when merging two constants */
void TheoryDatatypes::conflict(TNode a, TNode b){
- Node eq = a.eqNode(b);
- d_conflictNode = explainLit(eq);
- Trace("dt-conflict") << "CONFLICT: Eq engine conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
- d_state.notifyInConflict();
+ Trace("dt-conflict") << "CONFLICT: Eq engine conflict merge : " << a
+ << " == " << b << std::endl;
+ d_im.conflictEqConstantMerge(a, b);
}
/** called when a new equivalance class is created */
@@ -744,10 +727,11 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
std::vector< Node > rew;
if (utils::checkClash(cons1, cons2, rew))
{
- d_conflictNode = explainLit(unifEq);
- Trace("dt-conflict") << "CONFLICT: Clash conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
- d_state.notifyInConflict();
+ std::vector<Node> conf;
+ conf.push_back(unifEq);
+ Trace("dt-conflict")
+ << "CONFLICT: Clash conflict : " << conf << std::endl;
+ d_im.conflictExp(conf, nullptr);
return;
}
else
@@ -946,13 +930,12 @@ void TheoryDatatypes::addTester(
{
if( !eqc->d_constructor.get().isNull() ){
//conflict because equivalence class contains a constructor
- std::vector< TNode > assumptions;
- explain( t, assumptions );
- explainEquality( eqc->d_constructor.get(), t_arg, true, assumptions );
- d_conflictNode = mkAnd( assumptions );
- Trace("dt-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
- d_state.notifyInConflict();
+ std::vector<Node> conf;
+ conf.push_back(t);
+ conf.push_back(eqc->d_constructor.get().eqNode(t_arg));
+ Trace("dt-conflict")
+ << "CONFLICT: Tester eq conflict " << conf << std::endl;
+ d_im.conflictExp(conf, nullptr);
return;
}else{
makeConflict = true;
@@ -1051,15 +1034,13 @@ void TheoryDatatypes::addTester(
}
}
if( makeConflict ){
- d_state.notifyInConflict();
Debug("datatypes-labels") << "Explain " << j << " " << t << std::endl;
- std::vector< TNode > assumptions;
- explain( j, assumptions );
- explain( t, assumptions );
- explainEquality( jt[0], t_arg, true, assumptions );
- d_conflictNode = mkAnd( assumptions );
- Trace("dt-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
+ std::vector<Node> conf;
+ conf.push_back(j);
+ conf.push_back(t);
+ conf.push_back(jt[0].eqNode(t_arg));
+ Trace("dt-conflict") << "CONFLICT: Tester conflict : " << conf << std::endl;
+ d_im.conflictExp(conf, nullptr);
}
}
@@ -1112,13 +1093,12 @@ void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){
unsigned tindex = d_labels_tindex[n][i];
if (tindex == constructorIndex)
{
- std::vector< TNode > assumptions;
- explain( t, assumptions );
- explainEquality( c, t[0][0], true, assumptions );
- d_conflictNode = mkAnd( assumptions );
- Trace("dt-conflict") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
- d_state.notifyInConflict();
+ std::vector<Node> conf;
+ conf.push_back(t);
+ conf.push_back(c.eqNode(t[0][0]));
+ Trace("dt-conflict")
+ << "CONFLICT: Tester merge eq conflict : " << conf << std::endl;
+ d_im.conflictExp(conf, nullptr);
return;
}
}
@@ -1671,7 +1651,7 @@ void TheoryDatatypes::checkCycles() {
//do cycle checks
std::map< TNode, bool > visited;
std::map< TNode, bool > proc;
- std::vector< TNode > expl;
+ std::vector<Node> expl;
Trace("datatypes-cycle-check") << "...search for cycle starting at " << eqc << std::endl;
Node cn = searchForCycle( eqc, eqc, visited, proc, expl );
Trace("datatypes-cycle-check") << "...finish." << std::endl;
@@ -1687,10 +1667,9 @@ void TheoryDatatypes::checkCycles() {
if( !cn.isNull() ) {
Assert(expl.size() > 0);
- d_conflictNode = mkAnd( expl );
- Trace("dt-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl;
- d_out->conflict( d_conflictNode );
- d_state.notifyInConflict();
+ Trace("dt-conflict")
+ << "CONFLICT: Cycle conflict : " << expl << std::endl;
+ d_im.conflictExp(expl, nullptr);
return;
}
}
@@ -1860,16 +1839,23 @@ void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector<
}
//postcondition: if cycle detected, explanation is why n is a subterm of on
-Node TheoryDatatypes::searchForCycle( TNode n, TNode on,
- std::map< TNode, bool >& visited, std::map< TNode, bool >& proc,
- std::vector< TNode >& explanation, bool firstTime ) {
+Node TheoryDatatypes::searchForCycle(TNode n,
+ TNode on,
+ std::map<TNode, bool>& visited,
+ std::map<TNode, bool>& proc,
+ std::vector<Node>& explanation,
+ bool firstTime)
+{
Trace("datatypes-cycle-check2") << "Search for cycle " << n << " " << on << endl;
TNode ncons;
TNode nn;
if( !firstTime ){
nn = getRepresentative( n );
if( nn==on ){
- explainEquality( n, nn, true, explanation );
+ if (n != nn)
+ {
+ explanation.push_back(n.eqNode(nn));
+ }
return on;
}
}else{
@@ -1893,7 +1879,7 @@ Node TheoryDatatypes::searchForCycle( TNode n, TNode on,
//add explanation for why the constructor is connected
if (n != nncons)
{
- explainEquality(n, nncons, true, explanation);
+ explanation.push_back(n.eqNode(nncons));
}
return on;
}else if( !cn.isNull() ){
diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h
index bf5d33177..d34390a5f 100644
--- a/src/theory/datatypes/theory_datatypes.h
+++ b/src/theory/datatypes/theory_datatypes.h
@@ -293,9 +293,12 @@ private:
Node removeUninterpretedConstants( Node n, std::map< Node, Node >& visited );
/** for checking if cycles exist */
void checkCycles();
- Node searchForCycle( TNode n, TNode on,
- std::map< TNode, bool >& visited, std::map< TNode, bool >& proc,
- std::vector< TNode >& explanation, bool firstTime = true );
+ Node searchForCycle(TNode n,
+ TNode on,
+ std::map<TNode, bool>& visited,
+ std::map<TNode, bool>& proc,
+ std::vector<Node>& explanation,
+ bool firstTime = true);
/** for checking whether two codatatype terms must be equal */
void separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out,
std::vector< TNode >& exp,
diff --git a/src/theory/quantifiers/sygus/sygus_interpol.cpp b/src/theory/quantifiers/sygus/sygus_interpol.cpp
index c2ca83e41..4d18c850b 100644
--- a/src/theory/quantifiers/sygus/sygus_interpol.cpp
+++ b/src/theory/quantifiers/sygus/sygus_interpol.cpp
@@ -23,6 +23,7 @@
#include "theory/quantifiers/quantifiers_attributes.h"
#include "theory/quantifiers/sygus/sygus_grammar_cons.h"
#include "theory/rewriter.h"
+#include "theory/smt_engine_subsolver.h"
namespace CVC4 {
namespace theory {
@@ -30,8 +31,6 @@ namespace quantifiers {
SygusInterpol::SygusInterpol() {}
-SygusInterpol::SygusInterpol(LogicInfo logic) : d_logic(logic) {}
-
void SygusInterpol::collectSymbols(const std::vector<Node>& axioms,
const Node& conj)
{
@@ -75,6 +74,9 @@ void SygusInterpol::createVariables(bool needsShared)
Node var = nm->mkBoundVar(tn);
d_vars.push_back(var);
Node vlv = nm->mkBoundVar(ss.str(), tn);
+ // set that this variable encodes the term s
+ SygusVarToTermAttribute sta;
+ vlv.setAttribute(sta, s);
d_vlvs.push_back(vlv);
if (!needsShared || d_symSetShared.find(s) != d_symSetShared.end())
{
@@ -266,7 +268,7 @@ void SygusInterpol::mkSygusConjecture(Node itp,
Trace("sygus-interpol") << "Generate: " << d_sygusConj << std::endl;
}
-bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
+bool SygusInterpol::findInterpol(Node& interpol, Node itp)
{
// get the synthesis solution
std::map<Node, Node> sols;
@@ -283,31 +285,31 @@ bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
}
Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is "
<< its->second << std::endl;
- Node interpoln = its->second;
+ interpol = its->second;
// replace back the created variables to original symbols.
- Node interpoln_reduced;
- if (interpoln.getKind() == kind::LAMBDA)
+ if (interpol.getKind() == kind::LAMBDA)
{
- interpoln_reduced = interpoln[1];
+ interpol = interpol[1];
}
- else
+
+ // get the grammar type for the interpolant
+ Node igdtbv = itp.getAttribute(SygusSynthFunVarListAttribute());
+ Assert(!igdtbv.isNull());
+ Assert(igdtbv.getKind() == kind::BOUND_VAR_LIST);
+ // convert back to original
+ // must replace formal arguments of itp with the free variables in the
+ // input problem that they correspond to.
+ std::vector<Node> vars;
+ std::vector<Node> syms;
+ SygusVarToTermAttribute sta;
+ for (const Node& bv : igdtbv)
{
- interpoln_reduced = interpoln;
+ vars.push_back(bv);
+ syms.push_back(bv.hasAttribute(sta) ? bv.getAttribute(sta) : bv);
}
- if (interpoln.getNumChildren() != 0 && interpoln[0].getNumChildren() != 0)
- {
- std::vector<Node> formals;
- for (const Node& n : interpoln[0])
- {
- formals.push_back(n);
- }
- interpoln_reduced = interpoln_reduced.substitute(formals.begin(),
- formals.end(),
- d_symSetShared.begin(),
- d_symSetShared.end());
- }
- // convert to expression
- interpol = interpoln_reduced.toExpr();
+ interpol =
+ interpol.substitute(vars.begin(), vars.end(), syms.begin(), syms.end());
+
return true;
}
@@ -315,14 +317,11 @@ bool SygusInterpol::SolveInterpolation(const std::string& name,
const std::vector<Node>& axioms,
const Node& conj,
const TypeNode& itpGType,
- Expr& interpol)
+ Node& interpol)
{
- NodeManager* nm = NodeManager::currentNM();
- // we generate a new smt engine to do the interpolation query
- d_subSolver.reset(new SmtEngine(nm->toExprManager()));
- d_subSolver->setIsInternalSubsolver();
+ initializeSubsolver(d_subSolver);
// get the logic
- LogicInfo l = d_logic.getUnlockedCopy();
+ LogicInfo l = d_subSolver->getLogicInfo().getUnlockedCopy();
// enable everything needed for sygus
l.enableSygus();
d_subSolver->setLogic(l);
diff --git a/src/theory/quantifiers/sygus/sygus_interpol.h b/src/theory/quantifiers/sygus/sygus_interpol.h
index 0fe66694f..4abe94f15 100644
--- a/src/theory/quantifiers/sygus/sygus_interpol.h
+++ b/src/theory/quantifiers/sygus/sygus_interpol.h
@@ -46,8 +46,6 @@ class SygusInterpol
public:
SygusInterpol();
- SygusInterpol(LogicInfo logic);
-
/**
* Returns the sygus conjecture in interpol corresponding to the interpolation
* problem for input problem (F above) given by axioms (Fa above), and conj
@@ -65,7 +63,7 @@ class SygusInterpol
const std::vector<Node>& axioms,
const Node& conj,
const TypeNode& itpGType,
- Expr& interpol);
+ Node& interpol);
private:
/**
@@ -158,7 +156,7 @@ class SygusInterpol
* @param interpol the solution to the sygus conjecture.
* @param itp the interpolation predicate.
*/
- bool findInterpol(Expr& interpol, Node itp);
+ bool findInterpol(Node& interpol, Node itp);
/** The SMT engine subSolver
*
@@ -179,10 +177,6 @@ class SygusInterpol
std::unique_ptr<SmtEngine> d_subSolver;
/**
- * The logic for the local copy of SMT engine (d_subSolver).
- */
- LogicInfo d_logic;
- /**
* symbols from axioms and conjecture.
*/
std::vector<Node> d_syms;
diff --git a/src/theory/sep/theory_sep.cpp b/src/theory/sep/theory_sep.cpp
index c9b6a9d89..573449287 100644
--- a/src/theory/sep/theory_sep.cpp
+++ b/src/theory/sep/theory_sep.cpp
@@ -1646,9 +1646,9 @@ void TheorySep::computeLabelModel( Node lbl ) {
Trace("sep-process") << "Model value (from valuation) for " << lbl << " : " << v_val << std::endl;
if( v_val.getKind()!=kind::EMPTYSET ){
while( v_val.getKind()==kind::UNION ){
- Assert(v_val[1].getKind() == kind::SINGLETON);
- d_label_model[lbl].d_heap_locs_model.push_back( v_val[1] );
- v_val = v_val[0];
+ Assert(v_val[0].getKind() == kind::SINGLETON);
+ d_label_model[lbl].d_heap_locs_model.push_back(v_val[0]);
+ v_val = v_val[1];
}
if( v_val.getKind()==kind::SINGLETON ){
d_label_model[lbl].d_heap_locs_model.push_back( v_val );
@@ -1916,15 +1916,13 @@ Node TheorySep::HeapInfo::getValue( TypeNode tn ) {
Assert(d_heap_locs.size() == d_heap_locs_model.size());
if( d_heap_locs.empty() ){
return NodeManager::currentNM()->mkConst(EmptySet(tn));
- }else if( d_heap_locs.size()==1 ){
- return d_heap_locs[0];
- }else{
- Node curr = NodeManager::currentNM()->mkNode( kind::UNION, d_heap_locs[0], d_heap_locs[1] );
- for( unsigned j=2; j<d_heap_locs.size(); j++ ){
- curr = NodeManager::currentNM()->mkNode( kind::UNION, curr, d_heap_locs[j] );
- }
- return curr;
}
+ Node curr = d_heap_locs[0];
+ for (unsigned j = 1; j < d_heap_locs.size(); j++)
+ {
+ curr = NodeManager::currentNM()->mkNode(kind::UNION, d_heap_locs[j], curr);
+ }
+ return curr;
}
}/* CVC4::theory::sep namespace */
diff --git a/src/theory/sets/normal_form.h b/src/theory/sets/normal_form.h
index 0607a0e6c..b53a1c03d 100644
--- a/src/theory/sets/normal_form.h
+++ b/src/theory/sets/normal_form.h
@@ -25,6 +25,12 @@ namespace sets {
class NormalForm {
public:
+ /**
+ * Constructs a set of the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * from the set { c1 ... cn }, also handles empty set case, which is why
+ * setType is passed to this method.
+ */
template <bool ref_count>
static Node elementsToSet(const std::set<NodeTemplate<ref_count> >& elements,
TypeNode setType)
@@ -42,12 +48,21 @@ class NormalForm {
Node cur = nm->mkNode(kind::SINGLETON, *it);
while (++it != elements.end())
{
- cur = nm->mkNode(kind::UNION, cur, nm->mkNode(kind::SINGLETON, *it));
+ cur = nm->mkNode(kind::UNION, nm->mkNode(kind::SINGLETON, *it), cur);
}
return cur;
}
}
+ /**
+ * Returns true if n is considered a to be a (canonical) constant set value.
+ * A canonical set value is one whose AST is:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * where c1 ... cn are constants and the node identifier of these constants
+ * are such that:
+ * c1 > ... > cn.
+ * Also handles the corner cases of empty set and singleton set.
+ */
static bool checkNormalConstant(TNode n) {
Debug("sets-checknormal") << "[sets-checknormal] checkNormal " << n << " :"
<< std::endl;
@@ -56,46 +71,62 @@ class NormalForm {
} else if (n.getKind() == kind::SINGLETON) {
return n[0].isConst();
} else if (n.getKind() == kind::UNION) {
- // assuming (union ... (union {SmallestNodeID} {BiggerNodeId}) ...
- // {BiggestNodeId})
-
- // store BiggestNodeId in prvs
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1][0].isConst()) return false;
- Debug("sets-checknormal")
- << "[sets-checknormal] frst element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- TNode prvs = n[1][0];
- n = n[0];
+ // assuming (union {SmallestNodeID} ... (union {BiggerNodeId} ...
+ Node orig = n;
+ TNode prvs;
// check intermediate nodes
- while (n.getKind() == kind::UNION) {
- if (n[1].getKind() != kind::SINGLETON) return false;
- if (!n[1].isConst()) return false;
+ while (n.getKind() == kind::UNION)
+ {
+ if (n[0].getKind() != kind::SINGLETON || !n[0][0].isConst())
+ {
+ // not a constant
+ Trace("sets-isconst") << "sets::isConst: " << orig << " not due to "
+ << n[0] << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
- << "[sets-checknormal] element = " << n[1][0] << " "
- << n[1][0].getId() << std::endl;
- if (n[1][0] >= prvs) return false;
- prvs = n[1][0];
- n = n[0];
+ << "[sets-checknormal] element = " << n[0][0] << " "
+ << n[0][0].getId() << std::endl;
+ if (!prvs.isNull() && n[0][0] >= prvs)
+ {
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare " << n[0][0]
+ << std::endl;
+ return false;
+ }
+ prvs = n[0][0];
+ n = n[1];
}
// check SmallestNodeID is smallest
- if (n.getKind() != kind::SINGLETON) return false;
- if (!n[0].isConst()) return false;
+ if (n.getKind() != kind::SINGLETON || !n[0].isConst())
+ {
+ Trace("sets-isconst") << "sets::isConst: " << orig
+ << " not due to final " << n << std::endl;
+ return false;
+ }
Debug("sets-checknormal")
<< "[sets-checknormal] lst element = " << n[0] << " "
<< n[0].getId() << std::endl;
- if (n[0] >= prvs) return false;
-
- // we made it
- return true;
-
- } else {
- return false;
+ // compare last ID
+ if (n[0] < prvs)
+ {
+ return true;
+ }
+ Trace("sets-isconst")
+ << "sets::isConst: " << orig << " not due to compare final " << n[0]
+ << std::endl;
}
+ return false;
}
+ /**
+ * Converts a set term to a std::set of its elements. This expects a set of
+ * the form:
+ * (union (singleton c1) ... (union (singleton c_{n-1}) (singleton c_n))))
+ * Also handles the corner cases of empty set and singleton set.
+ */
static std::set<Node> getElementsFromNormalConstant(TNode n) {
Assert(n.isConst());
std::set<Node> ret;
@@ -103,29 +134,15 @@ class NormalForm {
return ret;
}
while (n.getKind() == kind::UNION) {
- Assert(n[1].getKind() == kind::SINGLETON);
- ret.insert(ret.begin(), n[1][0]);
- n = n[0];
+ Assert(n[0].getKind() == kind::SINGLETON);
+ ret.insert(ret.begin(), n[0][0]);
+ n = n[1];
}
Assert(n.getKind() == kind::SINGLETON);
ret.insert(n[0]);
return ret;
}
-
- //AJR
-
- static void getElementsFromBop( Kind k, Node n, std::vector< Node >& els ){
- if( n.getKind()==k ){
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- getElementsFromBop( k, n[i], els );
- }
- }else{
- if( std::find( els.begin(), els.end(), n )==els.end() ){
- els.push_back( n );
- }
- }
- }
static Node mkBop( Kind k, std::vector< Node >& els, TypeNode tn, unsigned index = 0 ){
if( index>=els.size() ){
return NodeManager::currentNM()->mkConst(EmptySet(tn));
diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp
index 741f45dd8..b1831f261 100644
--- a/src/theory/sets/theory_sets_private.cpp
+++ b/src/theory/sets/theory_sets_private.cpp
@@ -320,7 +320,7 @@ void TheorySetsPrivate::fullEffortCheck()
Node n = (*eqc_i);
if (n != eqc)
{
- Trace("sets-eqc") << n << " ";
+ Trace("sets-eqc") << n << " (" << n.isConst() << ") ";
}
TypeNode tnn = n.getType();
if (isSet)
diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp
index eb168c6ed..50aa89cc8 100644
--- a/src/theory/sets/theory_sets_rewriter.cpp
+++ b/src/theory/sets/theory_sets_rewriter.cpp
@@ -27,7 +27,7 @@ namespace CVC4 {
namespace theory {
namespace sets {
-bool checkConstantMembership(TNode elementTerm, TNode setTerm)
+bool TheorySetsRewriter::checkConstantMembership(TNode elementTerm, TNode setTerm)
{
if(setTerm.getKind() == kind::EMPTYSET) {
return false;
@@ -38,12 +38,11 @@ bool checkConstantMembership(TNode elementTerm, TNode setTerm)
}
Assert(setTerm.getKind() == kind::UNION
- && setTerm[1].getKind() == kind::SINGLETON)
+ && setTerm[0].getKind() == kind::SINGLETON)
<< "kind was " << setTerm.getKind() << ", term: " << setTerm;
- return
- elementTerm == setTerm[1][0] ||
- checkConstantMembership(elementTerm, setTerm[0]);
+ return elementTerm == setTerm[0][0]
+ || checkConstantMembership(elementTerm, setTerm[1]);
}
// static
@@ -53,6 +52,8 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
Trace("sets-postrewrite") << "Process: " << node << std::endl;
if(node.isConst()) {
+ Trace("sets-rewrite-nf")
+ << "Sets::rewrite: no rewrite (constant) " << node << std::endl;
// Dare you touch the const and mangle it to something else.
return RewriteResponse(REWRITE_DONE, node);
}
@@ -163,23 +164,13 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
Assert(newNode.isConst());
Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::INTERSECTION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::INTERSECTION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- return RewriteResponse(REWRITE_DONE, rew);
}
- /*
- } else if (node[0] > node[1]) {
+ else if (node[0] > node[1])
+ {
Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
}
- */
+ // we don't merge non-constant intersections
break;
}//kind::INTERSECION
@@ -200,19 +191,16 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
std::inserter(newSet, newSet.begin()));
Node newNode = NormalForm::elementsToSet(newSet, node.getType());
Assert(newNode.isConst());
- Trace("sets-postrewrite") << "Sets::postRewrite returning " << newNode << std::endl;
+ Trace("sets-rewrite")
+ << "Sets::rewrite: UNION_CONSTANT_MERGE: " << newNode << std::endl;
return RewriteResponse(REWRITE_DONE, newNode);
- } else {
- std::vector< Node > els;
- NormalForm::getElementsFromBop( kind::UNION, node, els );
- std::sort( els.begin(), els.end() );
- Node rew = NormalForm::mkBop( kind::UNION, els, node.getType() );
- if( rew!=node ){
- Trace("sets-rewrite") << "Sets::rewrite " << node << " -> " << rew << std::endl;
- }
- Trace("sets-rewrite") << "...no rewrite." << std::endl;
- return RewriteResponse(REWRITE_DONE, rew);
}
+ else if (node[0] > node[1])
+ {
+ Node newNode = nm->mkNode(node.getKind(), node[1], node[0]);
+ return RewriteResponse(REWRITE_DONE, newNode);
+ }
+ // we don't merge non-constant unions
break;
}//kind::UNION
case kind::COMPLEMENT: {
@@ -491,16 +479,15 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
// static
RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
NodeManager* nm = NodeManager::currentNM();
-
- if(node.getKind() == kind::EQUAL) {
-
+ Kind k = node.getKind();
+ if (k == kind::EQUAL)
+ {
if(node[0] == node[1]) {
return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
}
-
- }//kind::EQUAL
- else if(node.getKind() == kind::INSERT) {
-
+ }
+ else if (k == kind::INSERT)
+ {
Node insertedElements = nm->mkNode(kind::SINGLETON, node[0]);
size_t setNodeIndex = node.getNumChildren()-1;
for(size_t i = 1; i < setNodeIndex; ++i) {
@@ -512,17 +499,16 @@ RewriteResponse TheorySetsRewriter::preRewrite(TNode node) {
nm->mkNode(kind::UNION,
insertedElements,
node[setNodeIndex]));
-
- }//kind::INSERT
- else if(node.getKind() == kind::SUBSET) {
-
+ }
+ else if (k == kind::SUBSET)
+ {
// rewrite (A subset-or-equal B) as (A union B = B)
return RewriteResponse(REWRITE_AGAIN,
nm->mkNode(kind::EQUAL,
nm->mkNode(kind::UNION, node[0], node[1]),
node[1]) );
-
- }//kind::SUBSET
+ }
+ // could have an efficient normalizer for union here
return RewriteResponse(REWRITE_DONE, node);
}
diff --git a/src/theory/sets/theory_sets_rewriter.h b/src/theory/sets/theory_sets_rewriter.h
index 7d1a6c188..fdc9caefb 100644
--- a/src/theory/sets/theory_sets_rewriter.h
+++ b/src/theory/sets/theory_sets_rewriter.h
@@ -70,7 +70,11 @@ class TheorySetsRewriter : public TheoryRewriter
// often this will suffice
return postRewrite(equality).d_node;
}
-
+private:
+ /**
+ * Returns true if elementTerm is in setTerm, where both terms are constants.
+ */
+ bool checkConstantMembership(TNode elementTerm, TNode setTerm);
}; /* class TheorySetsRewriter */
}/* CVC4::theory::sets namespace */
diff --git a/src/theory/theory_inference_manager.cpp b/src/theory/theory_inference_manager.cpp
index 81f5c45e6..980763040 100644
--- a/src/theory/theory_inference_manager.cpp
+++ b/src/theory/theory_inference_manager.cpp
@@ -141,6 +141,7 @@ TrustNode TheoryInferenceManager::mkConflictExp(const std::vector<Node>& exp,
{
if (d_pfee != nullptr)
{
+ Assert(pg != nullptr);
// use proof equality engine to construct the trust node
return d_pfee->assertConflict(exp, pg);
}
diff --git a/src/theory/uf/proof_equality_engine.cpp b/src/theory/uf/proof_equality_engine.cpp
index 00e4662f9..fa9482094 100644
--- a/src/theory/uf/proof_equality_engine.cpp
+++ b/src/theory/uf/proof_equality_engine.cpp
@@ -221,6 +221,7 @@ TrustNode ProofEqEngine::assertConflict(const std::vector<Node>& exp,
TrustNode ProofEqEngine::assertConflict(const std::vector<Node>& exp,
ProofGenerator* pg)
{
+ Assert(pg != nullptr);
Trace("pfee") << "pfee::assertConflict " << exp << " via generator"
<< std::endl;
return assertLemma(d_false, exp, {}, pg);
@@ -306,6 +307,7 @@ TrustNode ProofEqEngine::assertLemma(Node conc,
const std::vector<Node>& noExplain,
ProofGenerator* pg)
{
+ Assert(pg != nullptr);
Trace("pfee") << "pfee::assertLemma " << conc << ", exp = " << exp
<< ", noExplain = " << noExplain << " via buffer with generator"
<< std::endl;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback