summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaulMeng <baolmeng@gmail.com>2016-04-12 10:02:26 -0500
committerPaulMeng <baolmeng@gmail.com>2016-04-12 10:02:26 -0500
commitb8cce053839961e89ce71d7862f60b5c745258ee (patch)
treeea22330d1092e329f4ff7e470d16b329cd9104a8
parent2e5e6efa0163b6e4316133007a394856c8c02ddd (diff)
fixed explanation for transitive closure inferences
-rw-r--r--src/theory/sets/theory_sets_rels.cpp184
-rw-r--r--src/theory/sets/theory_sets_rels.h1
2 files changed, 120 insertions, 65 deletions
diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp
index 5df44d9f8..0e20b9bfa 100644
--- a/src/theory/sets/theory_sets_rels.cpp
+++ b/src/theory/sets/theory_sets_rels.cpp
@@ -59,7 +59,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
MEM_IT m_it = d_membership_constraints_cache.begin();
while(m_it != d_membership_constraints_cache.end()) {
Node rel_rep = m_it->first;
- Trace("rels-debug") << "[sets-rels] Processing rel_rep = " << rel_rep << std::endl;
// No relational terms found with rel_rep as its representative
// But TRANSPOSE(rel_rep) may occur in the context
@@ -201,7 +200,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
* -----------------------------------------------------------
* x <= TRANSCLOSURE(x) && (x JOIN x) <= TRANSCLOSURE(x) ....
*
- * TC(x) = TC(y) => x = y
+ * TC(x) = TC(y) => x = y ?
*
*/
@@ -237,10 +236,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
Node tc_r_rep = getRepresentative(tc_term[0]);
// build the TC graph for tc_rep if it was not created before
- if( d_membership_tc_cache.find(tc_rep) == d_membership_tc_cache.end() ) {
+ if( d_tc_nodes.find(tc_rep) == d_tc_nodes.end() ) {
+ Trace("rels-debug") << "[sets-rels] Start building the TC graph!" << std::endl;
buildTCGraph(tc_r_rep, tc_rep, tc_term);
+ d_tc_nodes.insert(tc_rep);
}
- // insert atom[0] in the tc_graph
+ // insert atom[0] in the tc_graph if it is not in the graph already
TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
if(polarity) {
if(tc_graph_it != d_membership_tc_cache.end()) {
@@ -268,7 +269,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
d_membership_tc_exp_cache[tc_rep] = reason;
}
}
- // check if atom[0] exists in TC graph for conflict
+ // check if atom[0] already exists in TC graph for conflict
} else {
if(tc_graph_it != d_membership_tc_cache.end()) {
checkTCGraphForConflict(atom, tc_rep, d_trueNode, nthElementOfTuple(atom[0], 0),
@@ -284,11 +285,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
if(pair_set_it->second.find(b) != pair_set_it->second.end()) {
Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b)));
if(atom[1] != tc_rep) {
- reason = AND(exp, EQUAL(atom[1], tc_rep));
+ reason = AND(exp, explain(EQUAL(atom[1], tc_rep)));
}
Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
- << NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom) << std::endl;
- d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom));
+ << NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom) << std::endl;
+ d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom));
// Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
// << AND(reason.negate(), atom) << std::endl;
// d_sets_theory.d_out->conflict(AND(reason.negate(), atom));
@@ -319,53 +320,67 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
Node atom = polarity ? exp : exp[0];
Node r1_rep = getRepresentative(product_term[0]);
Node r2_rep = getRepresentative(product_term[1]);
+ Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term
+ << " with explanation: " << exp << std::endl;
+ std::vector<Node> r1_element;
+ std::vector<Node> r2_element;
- if(polarity) {
- Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term
- << " with explanation: " << exp << std::endl;
- std::vector<Node> r1_element;
- std::vector<Node> r2_element;
-
- NodeManager *nm = NodeManager::currentNM();
- Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
- unsigned int i = 0;
- unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength();
- unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength();
+ NodeManager *nm = NodeManager::currentNM();
+ Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
+ unsigned int i = 0;
+ unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength();
+ unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength();
- r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
- for(; i < s1_len; ++i) {
- r1_element.push_back(nthElementOfTuple(atom[0], i));
- }
+ r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+ for(; i < s1_len; ++i) {
+ r1_element.push_back(nthElementOfTuple(atom[0], i));
+ }
- dt = r2_rep.getType().getSetElementType().getDatatype();
- r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
- for(; i < tup_len; ++i) {
- r2_element.push_back(nthElementOfTuple(atom[0], i));
- }
+ dt = r2_rep.getType().getSetElementType().getDatatype();
+ r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+ for(; i < tup_len; ++i) {
+ r2_element.push_back(nthElementOfTuple(atom[0], i));
+ }
- Node fact;
- Node reason = exp;
- Node t1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element));
- Node t2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element));
-
- if(!hasMember(r1_rep, t1)) {
- fact = MEMBER( t1, r1_rep );
- if(r1_rep != product_term[0])
- reason = Rewriter::rewrite(AND(reason, EQUAL(r1_rep, product_term[0])));
- addToMap(d_membership_db, r1_rep, t1);
- addToMap(d_membership_exp_db, r1_rep, reason);
- sendInfer(fact, reason, "product-split");
+ Node fact_1;
+ Node fact_2;
+ Node reason_1 = exp;
+ Node reason_2 = exp;
+ Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
+ Node t1_rep = getRepresentative(t1);
+ Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
+ Node t2_rep = getRepresentative(t2);
+
+ fact_1 = MEMBER( t1, r1_rep );
+ fact_2 = MEMBER( t2, r2_rep );
+ if(r1_rep != product_term[0]) {
+ reason_1 = AND(reason_1, explain(EQUAL(r1_rep, product_term[0])));
+ }
+ if(t1 != t1_rep) {
+ reason_1 = Rewriter::rewrite(AND(reason_1, explain(EQUAL(t1, t1_rep))));
+ }
+ if(r2_rep != product_term[1]) {
+ reason_2 = AND(reason_2, explain(EQUAL(r2_rep, product_term[1])));
+ }
+ if(t2 != t2_rep) {
+ reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep))));
+ }
+ if(polarity) {
+ if(!hasMember(r1_rep, t1_rep)) {
+ addToMap(d_membership_db, r1_rep, t1_rep);
+ addToMap(d_membership_exp_db, r1_rep, reason_1);
+ sendInfer(fact_1, reason_1, "product-split");
}
-
if(!hasMember(r2_rep, t2)) {
- fact = MEMBER( t2, r2_rep );
- if(r2_rep != product_term[1])
- reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, product_term[1])));
addToMap(d_membership_db, r2_rep, t2);
- addToMap(d_membership_exp_db, r2_rep, reason);
- sendInfer(fact, reason, "product-split");
+ addToMap(d_membership_exp_db, r2_rep, reason_2);
+ sendInfer(fact_2, reason_2, "product-split");
}
+
} else {
+// sendInfer(fact_1.negate(), reason_1, "product-split");
+// sendInfer(fact_2.negate(), reason_2, "product-split");
+
// ONLY need to explicitly compute joins if there are negative literals involving PRODUCT
Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-COMPOSE rule on term: " << product_term
<< " with explanation: " << exp << std::endl;
@@ -528,15 +543,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
}
}
- // Todo: need to add equality between two pair's left and right elements as explanation
+
void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph,
Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) {
Node pair = constructPair(tc_rep, start_node, cur_node);
if(safeAddToMap(d_membership_db, tc_rep, pair)) {
- addToMap(d_membership_exp_db, tc_rep, exp);
- sendLemma( MEMBER(pair, tc_rep), exp, "Transitivity" );
+ addToMap(d_membership_exp_cache, tc_rep, Rewriter::rewrite(exp));
+ sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" );
}
+ // check if cur_node has been traversed or not
if(!first_round) {
std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin();
while(ele_it != elements.end()) {
@@ -547,8 +563,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
}
}
std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+ Node reason = exp;
while(pair_set_it != tc_graph.end()) {
if(areEqual(pair_set_it->first, cur_node)) {
+ reason = AND(exp, EQUAL(pair_set_it->first, cur_node));
break;
}
pair_set_it++;
@@ -557,10 +575,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
set_it != pair_set_it->second.end(); set_it++) {
Node p = constructPair( tc_rep, cur_node, *set_it );
- Node reason = AND( findMemExp(tc_rep, p), exp );
Assert(!reason.isNull());
elements.insert(*set_it);
- inferTC( reason, tc_rep, tc_graph, start_node, *set_it, elements, false );
+ inferTC( AND( findMemExp(tc_rep, p), reason ), tc_rep, tc_graph, start_node, *set_it, elements, false );
}
}
}
@@ -574,7 +591,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
std::hash_set<Node, NodeHashFunction> elements;
Node pair = constructPair(tc_rep, pair_set_it->first, *set_it);
Node exp = findMemExp(tc_rep, pair);
- Trace("rels-debug") << "[sets-rels] pair = " << pair << std::endl;
if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) {
exp = AND(d_membership_tc_exp_cache[tc_rep], exp);
}
@@ -753,7 +769,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
void TheorySetsRels::doPendingLemmas() {
if( !(*d_conflict) && (!d_lemma_cache.empty() || !d_pending_facts.empty())){
for( unsigned i=0; i < d_lemma_cache.size(); i++ ){
- if(holds( d_lemma_cache[i] )) {
+ Assert(d_lemma_cache[i].getKind() == kind::IMPLIES);
+ if(holds( d_lemma_cache[i][1] )) {
Trace("rels-lemma") << "[sets-rels-lemma-skip] Skip the already held lemma: "
<< d_lemma_cache[i]<< std::endl;
continue;
@@ -775,6 +792,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
}
}
+ d_tc_nodes.clear();
d_pending_facts.clear();
d_membership_constraints_cache.clear();
d_membership_tc_cache.clear();
@@ -890,7 +908,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
}
bool TheorySetsRels::areEqual( Node a, Node b ){
- Trace("rels-debug") << "[sets-rels] areEqual( a = " << a << ", b = " << b << ")" << std::endl;
if(a == b) {
return true;
} else if( hasTerm( a ) && hasTerm( b ) ){
@@ -936,28 +953,49 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
}
inline Node TheorySetsRels::getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r) {
+ Trace("rels-reason") << "[sets-rels] getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl;
if(tc_term != tc_rep) {
Node reason = explain(EQUAL(tc_term, tc_rep));
if(tc_term[0] != tc_r_rep) {
return AND(reason, explain(EQUAL(tc_term[0], tc_r_rep)));
}
}
+ Trace("rels-reason") << "[sets-rels] done getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl;
return Node::null();
}
- // tuple might be a member of tc_rep; or it might be a member of tc_terms
+ // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that
+ // tc_terms are transitive closure of rels and are modulo equal to tc_rep
Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) {
Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl;
std::vector<Node> tc_terms = d_terms_cache.find(tc_rep)->second[kind::TRANSCLOSURE];
Assert(tc_terms.size() > 0);
for(unsigned int i = 0; i < tc_terms.size(); i++) {
- Node r_rep = getRepresentative(tc_terms[i][0]);
- Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << r_rep << ", tuple = " << tuple << ")" << std::endl;
- std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(r_rep);
+ Node tc_term = tc_terms[i];
+ Node tc_r_rep = getRepresentative(tc_term[0]);
+
+ Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", tuple = " << tuple << ")" << std::endl;
+ std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep);
if(tc_r_mems != d_membership_db.end()) {
for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) {
if(areEqual(tc_r_mems->second[i], tuple)) {
- return explain(d_membership_exp_db[r_rep][i]);
+ Node exp = d_trueNode;
+ if(tc_r_rep != tc_term[0]) {
+ exp = explain(EQUAL(tc_r_rep, tc_term[0]));
+ }
+ if(tc_rep != tc_term) {
+ exp = AND(exp, explain(EQUAL(tc_rep, tc_term)));
+ }
+ if(tc_r_mems->second[i] != tuple) {
+ if(nthElementOfTuple(tc_r_mems->second[i], 0) != nthElementOfTuple(tuple, 0)) {
+ exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 0), nthElementOfTuple(tuple, 0))));
+ }
+ if(nthElementOfTuple(tc_r_mems->second[i], 1) != nthElementOfTuple(tuple, 1)) {
+ exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 1), nthElementOfTuple(tuple, 1))));
+ }
+ exp = AND(exp, EQUAL(tc_r_mems->second[i], tuple));
+ }
+ return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_r_rep][i])));
}
}
}
@@ -966,9 +1004,25 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep);
Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl;
if(tc_t_mems != d_membership_db.end()) {
- for(unsigned int i = 0; i < tc_t_mems->second.size(); i++) {
- if(areEqual(tc_t_mems->second[i], tuple)) {
- return explain(d_membership_exp_db[tc_term_rep][i]);
+ for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) {
+ if(areEqual(tc_t_mems->second[j], tuple)) {
+ Node exp = d_trueNode;
+ if(tc_rep != tc_terms[i]) {
+ exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i])));
+ }
+ if(tc_term_rep != tc_terms[i]) {
+ exp = AND(exp, explain(EQUAL(tc_term_rep, tc_terms[i])));
+ }
+ if(tc_t_mems->second[j] != tuple) {
+ if(nthElementOfTuple(tc_t_mems->second[j], 0) != nthElementOfTuple(tuple, 0)) {
+ exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 0), nthElementOfTuple(tuple, 0))));
+ }
+ if(nthElementOfTuple(tc_t_mems->second[j], 1) != nthElementOfTuple(tuple, 1)) {
+ exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 1), nthElementOfTuple(tuple, 1))));
+ }
+ exp = AND(exp, EQUAL(tc_t_mems->second[j], tuple));
+ }
+ return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_term_rep][j])));
}
}
}
@@ -1155,7 +1209,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
Node TheorySetsRels::explain(Node literal)
{
- Trace("rels-debug") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl;
+ Trace("rels-exp") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl;
bool polarity = literal.getKind() != kind::NOT;
TNode atom = polarity ? literal : literal[0];
@@ -1169,11 +1223,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
}
d_eqEngine->explainPredicate(atom, polarity, assumptions);
} else {
- Trace("rels-debug") << "unhandled: " << literal << "; (" << atom << ", "
+ Trace("rels-exp") << "unhandled: " << literal << "; (" << atom << ", "
<< polarity << "); kind" << atom.getKind() << std::endl;
Unhandled();
}
- Trace("rels-debug") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl;
+ Trace("rels-exp") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl;
return mkAnd(assumptions);
}
diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h
index 8fc107a82..0876cc5b3 100644
--- a/src/theory/sets/theory_sets_rels.h
+++ b/src/theory/sets/theory_sets_rels.h
@@ -100,6 +100,7 @@ private:
NodeSet d_lemma;
NodeSet d_shared_terms;
+ std::hash_set< Node, NodeHashFunction > d_tc_nodes;
std::map< Node, std::vector<Node> > d_tuple_reps;
std::map< Node, TupleTrie > d_membership_trie;
std::hash_set< Node, NodeHashFunction > d_symbolic_tuples;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback