summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2013-09-27 09:27:19 -0500
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>2013-09-27 09:27:29 -0500
commite277b4d220a1d15ac32f6e4fc5f06e88f55b7f68 (patch)
tree2a56691dea81453e5f9ba42e859fdc6783fa1545
parentccd1ca4c32e8a3eac8b18911a7b2d32b55203707 (diff)
Add new symmetry breaking technique for finite model finding. Improvements to bounded integer quantifier instantiation.
-rw-r--r--src/smt/options2
-rw-r--r--src/smt/smt_engine.cpp3
-rw-r--r--src/theory/quantifiers/Makefile.am5
-rw-r--r--src/theory/quantifiers/bounded_integers.cpp6
-rw-r--r--src/theory/quantifiers/bounded_integers.h1
-rw-r--r--src/theory/quantifiers/full_model_check.cpp54
-rwxr-xr-xsrc/theory/quantifiers/symmetry_breaking.cpp296
-rwxr-xr-xsrc/theory/quantifiers/symmetry_breaking.h121
-rwxr-xr-x[-rw-r--r--]src/theory/quantifiers_engine.cpp2
-rw-r--r--src/theory/rep_set.cpp7
-rw-r--r--src/theory/uf/options2
-rw-r--r--src/theory/uf/theory_uf_strong_solver.cpp131
-rw-r--r--src/theory/uf/theory_uf_strong_solver.h35
-rw-r--r--src/util/sort_inference.cpp185
-rw-r--r--src/util/sort_inference.h33
15 files changed, 770 insertions, 113 deletions
diff --git a/src/smt/options b/src/smt/options
index f39662c10..7a72881b4 100644
--- a/src/smt/options
+++ b/src/smt/options
@@ -48,7 +48,7 @@ option unconstrainedSimp --unconstrained-simp bool :default false :read-write
option repeatSimp --repeat-simp bool :read-write
make multiple passes with nonclausal simplifier
-option sortInference --sort-inference bool :default false
+option sortInference --sort-inference bool :read-write :default false
apply sort inference to input problem
common-option incrementalSolving incremental -i --incremental bool
diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp
index e1dc3531e..7fadb477b 100644
--- a/src/smt/smt_engine.cpp
+++ b/src/smt/smt_engine.cpp
@@ -1028,6 +1028,9 @@ void SmtEngine::setLogicInternal() throw() {
options::fmfInstGen.set( false );
}
}
+ if( options::ufssSymBreak() ){
+ options::sortInference.set( true );
+ }
//until bugs 371,431 are fixed
if( ! options::minisatUseElim.wasSetByUser()){
diff --git a/src/theory/quantifiers/Makefile.am b/src/theory/quantifiers/Makefile.am
index 80011868b..be24d6c67 100644
--- a/src/theory/quantifiers/Makefile.am
+++ b/src/theory/quantifiers/Makefile.am
@@ -52,7 +52,10 @@ libquantifiers_la_SOURCES = \
rewrite_engine.h \
rewrite_engine.cpp \
relevant_domain.h \
- relevant_domain.cpp
+ relevant_domain.cpp \
+ symmetry_breaking.h \
+ symmetry_breaking.cpp
+
EXTRA_DIST = \
kinds \
diff --git a/src/theory/quantifiers/bounded_integers.cpp b/src/theory/quantifiers/bounded_integers.cpp
index e1e2f96c2..30ff5242b 100644
--- a/src/theory/quantifiers/bounded_integers.cpp
+++ b/src/theory/quantifiers/bounded_integers.cpp
@@ -321,7 +321,7 @@ Node BoundedIntegers::getNextDecisionRequest() {
return Node::null();
}
-void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
+void BoundedIntegers::getBounds( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
l = d_bounds[0][f][v];
u = d_bounds[1][f][v];
if( d_nground_range[f].find(v)!=d_nground_range[f].end() ){
@@ -356,6 +356,10 @@ void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node
l = l.substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
}
}
+}
+
+void BoundedIntegers::getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u ) {
+ getBounds( f, v, rsi, l, u );
Trace("bound-int-rsi") << "Get value in model for..." << l << " and " << u << std::endl;
l = d_quantEngine->getModel()->getCurrentModelValue( l );
u = d_quantEngine->getModel()->getCurrentModelValue( u );
diff --git a/src/theory/quantifiers/bounded_integers.h b/src/theory/quantifiers/bounded_integers.h
index 27d5b7569..3da938d31 100644
--- a/src/theory/quantifiers/bounded_integers.h
+++ b/src/theory/quantifiers/bounded_integers.h
@@ -115,6 +115,7 @@ public:
int getBoundVarNum( Node f, int i ) { return d_set_nums[f][i]; }
Node getLowerBound( Node f, Node v ){ return d_bounds[0][f][v]; }
Node getUpperBound( Node f, Node v ){ return d_bounds[1][f][v]; }
+ void getBounds( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u );
void getBoundValues( Node f, Node v, RepSetIterator * rsi, Node & l, Node & u );
bool isGroundRange(Node f, Node v);
};
diff --git a/src/theory/quantifiers/full_model_check.cpp b/src/theory/quantifiers/full_model_check.cpp
index cdf697675..bf10369e6 100644
--- a/src/theory/quantifiers/full_model_check.cpp
+++ b/src/theory/quantifiers/full_model_check.cpp
@@ -840,14 +840,15 @@ void FullModelChecker::doCheck(FirstOrderModelFmc * fm, Node f, Def & d, Node n
Node i = fm->getUsedRepresentative( r[1] );
Node e = fm->getUsedRepresentative( r[2] );
d.addEntry(fm, mkArrayCond(i), e );
- r = r[0];
+ r = fm->getRepresentative( r[0] );
}
Node defC = mkArrayCond(fm->getStar(n.getType().getArrayIndexType()));
bool success = false;
+ Node odefaultValue;
if( r.getKind() == kind::STORE_ALL ){
ArrayStoreAll storeAll = r.getConst<ArrayStoreAll>();
- Node defaultValue = Node::fromExpr(storeAll.getExpr());
- defaultValue = fm->getUsedRepresentative( defaultValue, true );
+ odefaultValue = Node::fromExpr(storeAll.getExpr());
+ Node defaultValue = fm->getUsedRepresentative( odefaultValue, true );
if( !defaultValue.isNull() ){
d.addEntry(fm, defC, defaultValue);
success = true;
@@ -855,6 +856,7 @@ void FullModelChecker::doCheck(FirstOrderModelFmc * fm, Node f, Def & d, Node n
}
if( !success ){
Trace("fmc-warn") << "WARNING : ARRAYS : Can't process base array " << r << std::endl;
+ Trace("fmc-warn") << " Default value was : " << odefaultValue << std::endl;
Trace("fmc-debug") << "Can't process base array " << r << std::endl;
//can't process this array
d.reset();
@@ -1191,29 +1193,35 @@ bool FullModelChecker::doMeet( FirstOrderModelFmc * fm, std::vector< Node > & co
}
Node FullModelChecker::doIntervalMeet( FirstOrderModelFmc * fm, Node i1, Node i2, bool mk ) {
- if( !fm->isInterval( i1 ) || !fm->isInterval( i2 ) ){
- std::cout << "Not interval during meet! " << i1 << " " << i2 << std::endl;
- exit( 0 );
- }
- Node b[2];
- for( unsigned j=0; j<2; j++ ){
- Node b1 = i1[j];
- Node b2 = i2[j];
- if( fm->isStar( b1 ) ){
- b[j] = b2;
- }else if( fm->isStar( b2 ) ){
- b[j] = b1;
- }else if( b1.getConst<Rational>() < b2.getConst<Rational>() ){
- b[j] = j==0 ? b2 : b1;
+ if( fm->isStar( i1 ) ){
+ return i2;
+ }else if( fm->isStar( i2 ) ){
+ return i1;
+ }else{
+ if( !fm->isInterval( i1 ) || !fm->isInterval( i2 ) ){
+ std::cout << "Not interval during meet! " << i1 << " " << i2 << std::endl;
+ exit( 0 );
+ }
+ Node b[2];
+ for( unsigned j=0; j<2; j++ ){
+ Node b1 = i1[j];
+ Node b2 = i2[j];
+ if( fm->isStar( b1 ) ){
+ b[j] = b2;
+ }else if( fm->isStar( b2 ) ){
+ b[j] = b1;
+ }else if( b1.getConst<Rational>() < b2.getConst<Rational>() ){
+ b[j] = j==0 ? b2 : b1;
+ }else{
+ b[j] = j==0 ? b1 : b2;
+ }
+ }
+ if( fm->isStar( b[0] ) || fm->isStar( b[1] ) || b[0].getConst<Rational>() < b[1].getConst<Rational>() ){
+ return mk ? fm->getInterval( b[0], b[1] ) : i1;
}else{
- b[j] = j==0 ? b1 : b2;
+ return Node::null();
}
}
- if( fm->isStar( b[0] ) || fm->isStar( b[1] ) || b[0].getConst<Rational>() < b[1].getConst<Rational>() ){
- return mk ? fm->getInterval( b[0], b[1] ) : i1;
- }else{
- return Node::null();
- }
}
Node FullModelChecker::mkCond( std::vector< Node > & cond ) {
diff --git a/src/theory/quantifiers/symmetry_breaking.cpp b/src/theory/quantifiers/symmetry_breaking.cpp
new file mode 100755
index 000000000..6a7baebb1
--- /dev/null
+++ b/src/theory/quantifiers/symmetry_breaking.cpp
@@ -0,0 +1,296 @@
+/********************* */
+/*! \file symmetry_breaking.cpp
+ ** \verbatim
+ ** Original author: ajreynol
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009-2012 New York University and The University of Iowa
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief symmetry breaking module
+ **
+ **/
+
+#include <vector>
+
+#include "theory/quantifiers/symmetry_breaking.h"
+#include "theory/rewriter.h"
+#include "theory/quantifiers_engine.h"
+#include "theory/theory_engine.h"
+#include "util/sort_inference.h"
+#include "theory/uf/theory_uf_strong_solver.h"
+
+using namespace CVC4;
+using namespace CVC4::kind;
+using namespace CVC4::theory;
+using namespace std;
+
+namespace CVC4 {
+
+eq::EqualityEngine * SubsortSymmetryBreaker::getEqualityEngine() {
+ return ((uf::TheoryUF*)d_qe->getTheoryEngine()->theoryOf( theory::THEORY_UF ))->getEqualityEngine();
+}
+
+bool SubsortSymmetryBreaker::areEqual( Node n1, Node n2 ) {
+ return getEqualityEngine()->hasTerm( n1 ) && getEqualityEngine()->hasTerm( n2 ) && getEqualityEngine()->areEqual( n1,n2 );
+}
+
+bool SubsortSymmetryBreaker::areDisequal( Node n1, Node n2 ) {
+ return getEqualityEngine()->hasTerm( n1 ) && getEqualityEngine()->hasTerm( n2 ) && getEqualityEngine()->areDisequal( n1,n2, false );
+}
+
+
+Node SubsortSymmetryBreaker::getRepresentative( Node n ) {
+ return getEqualityEngine()->getRepresentative( n );
+}
+
+uf::StrongSolverTheoryUF * SubsortSymmetryBreaker::getStrongSolver() {
+ return ((uf::TheoryUF*)d_qe->getTheoryEngine()->theoryOf( theory::THEORY_UF ))->getStrongSolver();
+}
+
+SubsortSymmetryBreaker::SubsortSymmetryBreaker(QuantifiersEngine* qe, context::Context* c) :
+d_qe(qe), d_conflict(c,false), d_max_dom_const_sort(c,0), d_has_dom_const_sort(c,false),
+d_fact_index(c,0), d_fact_list(c) {
+ d_true = NodeManager::currentNM()->mkConst( true );
+}
+
+SubsortSymmetryBreaker::TypeInfo::TypeInfo( SubsortSymmetryBreaker * ssb, context::Context * c ) :
+d_ssb( ssb ), d_dom_constants( c ), d_first_active( c, 0 ){
+ d_dc_nodes = 0;
+}
+
+unsigned SubsortSymmetryBreaker::TypeInfo::getNumDomainConstants() {
+ if( d_nodes.empty() ){
+ return 0;
+ }else{
+ return 1 + d_dom_constants.size();
+ }
+}
+
+Node SubsortSymmetryBreaker::TypeInfo::getDomainConstant( int i ) {
+ if( i==0 ){
+ return d_nodes[0];
+ }else{
+ Assert( i<=(int)d_dom_constants.size() );
+ return d_dom_constants[i-1];
+ }
+}
+
+Node SubsortSymmetryBreaker::TypeInfo::getFirstActive() {
+ if( d_first_active.get()<(int)d_nodes.size() ){
+ Node fa = d_nodes[d_first_active.get()];
+ return d_ssb->getEqualityEngine()->hasTerm( fa ) ? fa : Node::null();
+ }else{
+ return Node::null();
+ }
+}
+
+SubsortSymmetryBreaker::TypeInfo * SubsortSymmetryBreaker::getTypeInfo( TypeNode tn, int sid ) {
+ if( d_type_info.find( sid )==d_type_info.end() ){
+ d_type_info[sid] = new TypeInfo( this, d_qe->getSatContext() );
+ d_sub_sorts[tn].push_back( sid );
+ d_sid_to_type[sid] = tn;
+ }
+ return d_type_info[sid];
+}
+
+void SubsortSymmetryBreaker::newEqClass( Node n ) {
+ Trace("sym-break-temp") << "New eq class " << n << std::endl;
+ if( !d_conflict ){
+ TypeNode tn = n.getType();
+ SortInference * si = d_qe->getTheoryEngine()->getSortInference();
+ if( si->isWellSorted( n ) ){
+ int sid = si->getSortId( n );
+ Trace("sym-break-debug") << "SSB: New eq class " << n << " : " << n.getType() << " : " << sid << std::endl;
+ TypeInfo * ti = getTypeInfo( tn, sid );
+ if( std::find( ti->d_nodes.begin(), ti->d_nodes.end(), n )==ti->d_nodes.end() ){
+ if( ti->d_nodes.empty() ){
+ //for first subsort, we add unit equality
+ if( d_sub_sorts[tn][0]!=sid ){
+ Trace("sym-break-temp") << "Do sym break unit with " << d_type_info[d_sub_sorts[tn][0]]->getBaseConstant() << std::endl;
+ //add unit symmetry breaking lemma
+ Node eq = n.eqNode( d_type_info[d_sub_sorts[tn][0]]->getBaseConstant() );
+ eq = Rewriter::rewrite( eq );
+ d_unit_lemmas.push_back( eq );
+ Trace("sym-break-lemma") << "*** SymBreak : Unit lemma (" << sid << "==" << d_sub_sorts[tn][0] << ") : " << eq << std::endl;
+ d_pending_lemmas.push_back( eq );
+ }
+ Trace("sym-break-dc") << "* Set first domain constant : " << n << " for " << tn << " : " << sid << std::endl;
+ ti->d_dc_nodes++;
+ }
+ ti->d_node_to_id[n] = ti->d_nodes.size();
+ ti->d_nodes.push_back( n );
+ }
+ if( !d_has_dom_const_sort.get() ){
+ d_has_dom_const_sort.set( true );
+ d_max_dom_const_sort.set( sid );
+ }
+ }
+ }
+ Trace("sym-break-temp") << "Done new eq class" << std::endl;
+}
+
+
+
+void SubsortSymmetryBreaker::merge( Node a, Node b ) {
+
+}
+
+void SubsortSymmetryBreaker::assertDisequal( Node a, Node b ) {
+
+}
+
+void SubsortSymmetryBreaker::processFirstActive( TypeNode tn, int sid, int curr_card ){
+ TypeInfo * ti = getTypeInfo( tn, sid );
+ if( (int)ti->getNumDomainConstants()<curr_card ){
+ Trace("sym-break-dc-debug") << "Check for domain constants " << tn << " : " << sid << ", curr_card = " << curr_card << ", ";
+ Trace("sym-break-dc-debug") << "#domain constants = " << ti->getNumDomainConstants() << std::endl;
+ Node fa = ti->getFirstActive();
+ bool invalid = true;
+ while( invalid && !fa.isNull() && (int)ti->getNumDomainConstants()<curr_card ){
+ invalid = false;
+ unsigned deq = 0;
+ for( unsigned i=0; i<ti->getNumDomainConstants(); i++ ){
+ Node dc = ti->getDomainConstant( i );
+ if( areEqual( fa, dc ) ){
+ invalid = true;
+ break;
+ }else if( areDisequal( fa, dc ) ){
+ deq++;
+ }
+ }
+ if( deq==ti->getNumDomainConstants() ){
+ Trace("sym-break-dc") << "* Can infer domain constant #" << ti->getNumDomainConstants()+1;
+ Trace("sym-break-dc") << " : " << fa << " for " << tn << " : " << sid << std::endl;
+ //add to domain constants
+ ti->d_dom_constants.push_back( fa );
+ if( ti->d_node_to_id[fa]>ti->d_dc_nodes ){
+ Trace("sym-break-dc-debug") << "Swap nodes... " << ti->d_dc_nodes << " " << ti->d_node_to_id[fa] << " " << ti->d_nodes.size() << std::endl;
+ //swap
+ Node on = ti->d_nodes[ti->d_dc_nodes];
+ int id = ti->d_node_to_id[fa];
+
+ ti->d_nodes[ti->d_dc_nodes] = fa;
+ ti->d_nodes[id] = on;
+ ti->d_node_to_id[fa] = ti->d_dc_nodes;
+ ti->d_node_to_id[on] = id;
+ }
+ ti->d_dc_nodes++;
+ Trace("sym-break-dc-debug") << "Get max type info..." << std::endl;
+ Assert( d_has_dom_const_sort.get() );
+ int msid = d_max_dom_const_sort.get();
+ TypeInfo * max_ti = getTypeInfo( d_sid_to_type[msid], msid );
+ Trace("sym-break-dc-debug") << "Swap nodes..." << std::endl;
+ //now, check if we can apply symmetry breaking to another sort
+ if( ti->getNumDomainConstants()>max_ti->getNumDomainConstants() ){
+ Trace("sym-break-dc") << "Max domain constant subsort for " << tn << " becomes " << sid << std::endl;
+ d_max_dom_const_sort.set( sid );
+ }else if( ti!=max_ti ){
+ //construct symmetry breaking lemma
+ //current domain constant must be disequal from all current ones
+ Trace("sym-break-dc") << "Get domain constant " << ti->getNumDomainConstants()-1;
+ Trace("sym-break-dc") << " from max_ti, " << max_ti->getNumDomainConstants() << std::endl;
+ //apply a symmetry breaking lemma
+ Node m = max_ti->getDomainConstant(ti->getNumDomainConstants()-1);
+ //if fa and m are disequal from all previous domain constants in the other sort
+ std::vector< Node > cc;
+ for( unsigned r=0; r<2; r++ ){
+ Node n = ((r==0)==(msid>sid)) ? fa : m;
+ Node on = ((r==0)==(msid>sid)) ? m : fa;
+ TypeInfo * t = ((r==0)==(msid>sid)) ? max_ti : ti;
+ for( unsigned i=0; i<t->d_node_to_id[on]; i++ ){
+ cc.push_back( n.eqNode( t->d_nodes[i] ) );
+ }
+ }
+ //then, we can assume fa = m
+ cc.push_back( fa.eqNode( m ) );
+ Node lem = NodeManager::currentNM()->mkNode( kind::OR, cc );
+ lem = Rewriter::rewrite( lem );
+ if( std::find( d_lemmas.begin(), d_lemmas.end(), lem )==d_lemmas.end() ){
+ d_lemmas.push_back( lem );
+ Trace("sym-break-lemma") << "*** Symmetry break lemma for " << tn << " (" << sid << "==" << d_max_dom_const_sort.get() << ") : ";
+ Trace("sym-break-lemma") << lem << std::endl;
+ d_pending_lemmas.push_back( lem );
+ }
+ }
+ invalid = true;
+ }
+ if( invalid ){
+ ti->d_first_active.set( ti->d_first_active + 1 );
+ fa = ti->getFirstActive();
+ }
+ }
+ }
+}
+
+void SubsortSymmetryBreaker::printDebugTypeInfo( const char * c, TypeNode tn, int sid ) {
+ Trace(c) << "TypeInfo( " << tn << ", " << sid << " ) = " << std::endl;
+ Trace(c) << " Domain constants : ";
+ TypeInfo * ti = getTypeInfo( tn, sid );
+ for( NodeList::const_iterator it = ti->d_dom_constants.begin(); it != ti->d_dom_constants.end(); ++it ){
+ Node dc = *it;
+ Trace(c) << dc << " ";
+ }
+ Trace(c) << std::endl;
+ Trace(c) << " First active node : " << ti->getFirstActive() << std::endl;
+}
+
+
+void SubsortSymmetryBreaker::queueFact( Node n ) {
+ d_fact_list.push_back( n );
+ /*
+ if( n.getKind()==EQUAL ){
+ merge( n[0], n[1] );
+ }else if( n.getKind()==NOT && n[0].getKind()==EQUAL ){
+ assertDisequal( n[0][0], n[0][1] );
+ }else{
+ newEqClass( n );
+ }
+ */
+}
+
+bool SubsortSymmetryBreaker::check( Theory::Effort level ) {
+ d_pending_lemmas.clear();
+
+ Trace("sym-break-debug") << "SymBreak : check " << level << std::endl;
+ while( d_fact_index.get()<d_fact_list.size() ){
+ Node f = d_fact_list[d_fact_index.get()];
+ d_fact_index.set( d_fact_index.get() + 1 );
+ if( f.getKind()==EQUAL ){
+ merge( f[0], f[1] );
+ }else if( f.getKind()==NOT && f[0].getKind()==EQUAL ){
+ assertDisequal( f[0][0], f[0][1] );
+ }else{
+ newEqClass( f );
+ }
+ }
+ Trace("sym-break-debug") << "SymBreak : update first actives" << std::endl;
+ for( std::map< TypeNode, std::vector< int > >::iterator it = d_sub_sorts.begin(); it != d_sub_sorts.end(); ++it ){
+ int card = getStrongSolver()->getCardinality( it->first );
+ for( unsigned i=0; i<it->second.size(); i++ ){
+ //check if the first active is disequal from all domain constants
+ processFirstActive( it->first, it->second[i], card );
+ }
+ }
+
+
+ Trace("sym-break-debug") << "SymBreak : finished check, now flush lemmas... (#lemmas = " << d_pending_lemmas.size() << ")" << std::endl;
+ //flush pending lemmas
+ if( !d_pending_lemmas.empty() ){
+ for( unsigned i=0; i<d_pending_lemmas.size(); i++ ){
+ getStrongSolver()->getOutputChannel().lemma( d_pending_lemmas[i] );
+ ++( getStrongSolver()->d_statistics.d_sym_break_lemmas );
+ }
+ d_pending_lemmas.clear();
+ return true;
+ }else{
+ return false;
+ }
+}
+
+
+
+}
+
diff --git a/src/theory/quantifiers/symmetry_breaking.h b/src/theory/quantifiers/symmetry_breaking.h
new file mode 100755
index 000000000..3db9097f5
--- /dev/null
+++ b/src/theory/quantifiers/symmetry_breaking.h
@@ -0,0 +1,121 @@
+/********************* */
+/*! \file symmetry_breaking.h
+ ** \verbatim
+ ** Original author: ajreynol
+ ** Major contributors: none
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009-2012 New York University and The University of Iowa
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief Pre-process step for first-order reasoning
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__QUANT_SYMMETRY_BREAKING_H
+#define __CVC4__QUANT_SYMMETRY_BREAKING_H
+
+#include "theory/theory.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include <map>
+#include "expr/node.h"
+#include "expr/type_node.h"
+
+#include "util/sort_inference.h"
+#include "context/context.h"
+#include "context/context_mm.h"
+#include "context/cdchunk_list.h"
+
+namespace CVC4 {
+namespace theory {
+
+namespace uf {
+ class StrongSolverTheoryUF;
+}
+
+class SubsortSymmetryBreaker {
+ typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
+ typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
+ typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
+ //typedef context::CDChunkList<int> IntList;
+ typedef context::CDList<Node> NodeList;
+ typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
+private:
+ /** quantifiers engine */
+ QuantifiersEngine* d_qe;
+ eq::EqualityEngine * getEqualityEngine();
+ bool areDisequal( Node n1, Node n2 );
+ bool areEqual( Node n1, Node n2 );
+ Node getRepresentative( Node n );
+ uf::StrongSolverTheoryUF * getStrongSolver();
+ std::vector< Node > d_unit_lemmas;
+ Node d_true;
+ context::CDO< bool > d_conflict;
+public:
+ SubsortSymmetryBreaker( QuantifiersEngine* qe, context::Context* c );
+ ~SubsortSymmetryBreaker(){}
+
+private:
+ class TypeInfo {
+ private:
+ SubsortSymmetryBreaker * d_ssb;
+ //bool isActive( Node n, unsigned & deq );
+ public:
+ TypeInfo( SubsortSymmetryBreaker * ssb, context::Context* c );
+ //list of all nodes from this (sub)type
+ std::vector< Node > d_nodes;
+ //the current domain constants for this (sub)type
+ NodeList d_dom_constants;
+ //# nodes in d_nodes that have been domain constants, size of this distinct # of domain constants seen
+ unsigned d_dc_nodes;
+ //the node we are currently watching to become a domain constant
+ context::CDO< int > d_first_active;
+ //node to id
+ std::map< Node, unsigned > d_node_to_id;
+ Node getBaseConstant() { return d_nodes.empty() ? Node::null() : d_nodes[0]; }
+ bool hasDomainConstant( Node n );
+ unsigned getNumDomainConstants();
+ Node getDomainConstant( int i );
+ Node getFirstActive();
+ };
+ std::map< TypeNode, std::vector< int > > d_sub_sorts;
+ std::map< int, TypeNode > d_sid_to_type;
+ std::map< int, TypeInfo * > d_type_info;
+
+ //maximum domain constants sort
+ context::CDO< int > d_max_dom_const_sort;
+ context::CDO< bool > d_has_dom_const_sort;
+
+ TypeInfo * getTypeInfo( TypeNode tn, int sid );
+
+ void processFirstActive( TypeNode tn, int sid, int curr_card );
+private:
+ //void printDebugNodeInfo( const char * c, Node n );
+ void printDebugTypeInfo( const char * c, TypeNode tn, int sid );
+ /** new node */
+ void newEqClass( Node n );
+ /** merge */
+ void merge( Node a, Node b );
+ /** assert disequal */
+ void assertDisequal( Node a, Node b );
+ /** fact list */
+ context::CDO< unsigned > d_fact_index;
+ NodeList d_fact_list;
+ std::vector< Node > d_pending_lemmas;
+ std::vector< Node > d_lemmas;
+public:
+ /** queue fact */
+ void queueFact( Node n );
+ /** check */
+ bool check( Theory::Effort level );
+};
+
+}
+}
+
+#endif
diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp
index e5cc8a1fb..0fe50aad6 100644..100755
--- a/src/theory/quantifiers_engine.cpp
+++ b/src/theory/quantifiers_engine.cpp
@@ -29,6 +29,7 @@
#include "theory/rewriterules/rr_trigger.h"
#include "theory/quantifiers/bounded_integers.h"
#include "theory/quantifiers/rewrite_engine.h"
+#include "theory/uf/options.h"
using namespace std;
using namespace CVC4;
@@ -632,6 +633,7 @@ Node EqualityQueryQuantifiersEngine::getInternalRepresentative( Node a, Node f,
}else{
int sortId = 0;
if( optInternalRepSortInference() ){
+ //if( options::ufssSymBreak() ){
sortId = d_qe->getTheoryEngine()->getSortInference()->getSortId( f, f[0][index] );
}
if( d_int_rep[sortId].find( r )==d_int_rep[sortId].end() ){
diff --git a/src/theory/rep_set.cpp b/src/theory/rep_set.cpp
index 647ef965a..800e007f7 100644
--- a/src/theory/rep_set.cpp
+++ b/src/theory/rep_set.cpp
@@ -278,7 +278,12 @@ bool RepSetIterator::resetIndex( int i, bool initial ) {
Node range = Rewriter::rewrite( NodeManager::currentNM()->mkNode( MINUS, u, l ) );
Node ra = Rewriter::rewrite( NodeManager::currentNM()->mkNode( LEQ, range, NodeManager::currentNM()->mkConst( Rational( 9999 ) ) ) );
d_domain[ii].clear();
- d_lower_bounds[ii] = l;
+ Node tl = l;
+ Node tu = u;
+ if( d_qe->getBoundedIntegers() && d_qe->getBoundedIntegers()->isBoundVar( d_owner, d_owner[0][ii] ) ){
+ d_qe->getBoundedIntegers()->getBounds( d_owner, d_owner[0][ii], this, tl, tu );
+ }
+ d_lower_bounds[ii] = tl;
if( ra==NodeManager::currentNM()->mkConst(true) ){
long rr = range.getConst<Rational>().getNumerator().getLong()+1;
Trace("bound-int-rsi") << "Actual bound range is " << rr << std::endl;
diff --git a/src/theory/uf/options b/src/theory/uf/options
index 437e30e46..b9f60b83d 100644
--- a/src/theory/uf/options
+++ b/src/theory/uf/options
@@ -40,5 +40,7 @@ option ufssMinimalModel /--disable-uf-ss-min-model bool :default true
option ufssCliqueSplits --uf-ss-clique-splits bool :default false
use cliques instead of splitting on demand to shrink model
+option ufssSymBreak --uf-ss-sym-break bool :default false
+ finite model finding symmetry breaking techniques
endmodule
diff --git a/src/theory/uf/theory_uf_strong_solver.cpp b/src/theory/uf/theory_uf_strong_solver.cpp
index adcf78a86..82cd1f809 100644
--- a/src/theory/uf/theory_uf_strong_solver.cpp
+++ b/src/theory/uf/theory_uf_strong_solver.cpp
@@ -20,6 +20,8 @@
#include "theory/quantifiers/term_database.h"
#include "theory/uf/options.h"
#include "theory/model.h"
+#include "theory/quantifiers/symmetry_breaking.h"
+
//#define ONE_SPLIT_REGION
//#define DISABLE_QUICK_CLIQUE_CHECKS
@@ -117,6 +119,10 @@ void StrongSolverTheoryUF::SortModel::Region::setEqual( Node a, Node b ){
if( options::ufssDiseqPropagation() ){
d_cf->d_thss->getDisequalityPropagator()->assertDisequal(a, n, Node::null());
}
+ if( options::ufssSymBreak() ){
+ //d_cf->d_thss->getSymmetryBreaker()->assertDisequal( a, n );
+ d_cf->d_thss->getSymmetryBreaker()->queueFact( a.eqNode( n ).negate() );
+ }
}
setDisequal( b, n, t, false );
nr->setDisequal( n, b, t, false );
@@ -515,9 +521,15 @@ void StrongSolverTheoryUF::SortModel::merge( Node a, Node b ){
}
d_reps = d_reps - 1;
- if( options::ufssDiseqPropagation() && !d_conflict ){
- //notify the disequality propagator
- d_thss->getDisequalityPropagator()->merge(a, b);
+ if( !d_conflict ){
+ if( options::ufssDiseqPropagation() ){
+ //notify the disequality propagator
+ d_thss->getDisequalityPropagator()->merge(a, b);
+ }
+ if( options::ufssSymBreak() ){
+ //d_thss->getSymmetryBreaker()->merge(a, b);
+ d_thss->getSymmetryBreaker()->queueFact( a.eqNode( b ) );
+ }
}
}
}
@@ -565,9 +577,15 @@ void StrongSolverTheoryUF::SortModel::assertDisequal( Node a, Node b, Node reaso
checkRegion( bi );
}
- if( options::ufssDiseqPropagation() && !d_conflict ){
- //notify the disequality propagator
- d_thss->getDisequalityPropagator()->assertDisequal(a, b, Node::null());
+ if( !d_conflict ){
+ if( options::ufssDiseqPropagation() ){
+ //notify the disequality propagator
+ d_thss->getDisequalityPropagator()->assertDisequal(a, b, Node::null());
+ }
+ if( options::ufssSymBreak() ){
+ //d_thss->getSymmetryBreaker()->assertDisequal(a, b);
+ d_thss->getSymmetryBreaker()->queueFact( a.eqNode( b ).negate() );
+ }
}
}
}
@@ -670,7 +688,7 @@ void StrongSolverTheoryUF::SortModel::check( Theory::Effort level, OutputChannel
for( int i=0; i<(int)d_regions_index; i++ ){
if( d_regions[i]->d_valid ){
Node op = d_regions[i]->d_nodes.begin()->first;
- int sort_id = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId(op);
+ int sort_id = d_thss->getSortInference()->getSortId(op);
if( sortsFound.find( sort_id )!=sortsFound.end() ){
combineRegions( sortsFound[sort_id], i );
recheck = true;
@@ -979,17 +997,32 @@ void StrongSolverTheoryUF::SortModel::moveNode( Node n, int ri ){
void StrongSolverTheoryUF::SortModel::allocateCardinality( OutputChannel* out ){
if( d_aloc_cardinality>0 ){
Trace("uf-ss-fmf") << "No model of size " << d_aloc_cardinality << " exists for type " << d_type << " in this branch" << std::endl;
- if( Trace.isOn("uf-ss-cliques") ){
- Trace("uf-ss-cliques") << "Cliques of size " << (d_aloc_cardinality+1) << " : " << std::endl;
- for( size_t i=0; i<d_cliques[ d_aloc_cardinality ].size(); i++ ){
- Trace("uf-ss-cliques") << " ";
- for( size_t j=0; j<d_cliques[ d_aloc_cardinality ][i].size(); j++ ){
- Trace("uf-ss-cliques") << d_cliques[ d_aloc_cardinality ][i][j] << " ";
- }
- Trace("uf-ss-cliques") << std::endl;
+ }
+ if( Trace.isOn("uf-ss-cliques") ){
+ Trace("uf-ss-cliques") << "Cliques of size " << (d_aloc_cardinality+1) << " for " << d_type << " : " << std::endl;
+ for( size_t i=0; i<d_cliques[ d_aloc_cardinality ].size(); i++ ){
+ Trace("uf-ss-cliques") << " ";
+ for( size_t j=0; j<d_cliques[ d_aloc_cardinality ][i].size(); j++ ){
+ Trace("uf-ss-cliques") << d_cliques[ d_aloc_cardinality ][i][j] << " ";
}
+ Trace("uf-ss-cliques") << std::endl;
+ }
+ }
+ /*
+ if( options::ufssSymBreak() ){
+ std::vector< Node > reps;
+ getRepresentatives( reps );
+ if( d_aloc_cardinality>0 ){
+ d_thss->getSymmetryBreaker()->allocateCardinality( out, d_type, d_aloc_cardinality+1, d_cliques[ d_aloc_cardinality ], reps );
+ }else{
+ std::vector< Node > clique;
+ clique.push_back( d_cardinality_term );
+ std::vector< std::vector< Node > > cliques;
+ cliques.push_back( clique );
+ d_thss->getSymmetryBreaker()->allocateCardinality( out, d_type, 1, cliques, reps );
}
}
+ */
d_aloc_cardinality = d_aloc_cardinality + 1;
//check for abort case
@@ -1094,7 +1127,7 @@ bool StrongSolverTheoryUF::SortModel::addSplit( Region* r, OutputChannel* out ){
Trace("uf-ss-lemma") << "*** Split on " << s << std::endl;
if( options::sortInference()) {
for( int i=0; i<2; i++ ){
- int si = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId( s[i] );
+ int si = d_thss->getSortInference()->getSortId( s[i] );
Trace("uf-ss-split-si") << si << " ";
}
Trace("uf-ss-split-si") << std::endl;
@@ -1122,10 +1155,10 @@ void StrongSolverTheoryUF::SortModel::addCliqueLemma( std::vector< Node >& cliqu
clique.pop_back();
}
//debugging information
- if( Trace.isOn("uf-ss-cliques") ){
+ if( options::ufssSymBreak() ){
std::vector< Node > clique_vec;
clique_vec.insert( clique_vec.begin(), clique.begin(), clique.end() );
- d_cliques[ d_cardinality ].push_back( clique_vec );
+ addClique( d_cardinality, clique_vec );
}
if( options::ufssSimpleCliques() && !options::ufssExplainedCliques() ){
//add as lemma
@@ -1273,7 +1306,7 @@ void StrongSolverTheoryUF::SortModel::addTotalityAxiom( Node n, int cardinality,
Node cardLit = d_cardinality_literal[ cardinality ];
int sort_id = 0;
if( options::sortInference() ){
- sort_id = d_thss->getTheory()->getQuantifiersEngine()->getTheoryEngine()->getSortInference()->getSortId(n);
+ sort_id = d_thss->getSortInference()->getSortId(n);
}
Trace("uf-ss-totality") << "Add totality lemma for " << n << " " << cardinality << ", sort id is " << sort_id << std::endl;
int use_cardinality = cardinality;
@@ -1302,6 +1335,14 @@ void StrongSolverTheoryUF::SortModel::addTotalityAxiom( Node n, int cardinality,
}
}
+void StrongSolverTheoryUF::SortModel::addClique( int c, std::vector< Node >& clique ) {
+
+ if( d_clique_trie[c].add( clique ) ){
+ d_cliques[ c ].push_back( clique );
+ }
+}
+
+
/** apply totality */
bool StrongSolverTheoryUF::SortModel::applyTotality( int cardinality ){
return options::ufssTotality() || cardinality<=options::ufssTotalityLimited();
@@ -1379,22 +1420,16 @@ int StrongSolverTheoryUF::SortModel::getNumRegions(){
}
void StrongSolverTheoryUF::SortModel::getRepresentatives( std::vector< Node >& reps ){
- //if( !options::ufssColoringSat() ){
- bool foundRegion = false;
- for( int i=0; i<(int)d_regions_index; i++ ){
- //should not have multiple regions at this point
- if( foundRegion ){
- Assert( !d_regions[i]->d_valid );
- }
- if( d_regions[i]->d_valid ){
- //this is the only valid region
- d_regions[i]->getRepresentatives( reps );
- foundRegion = true;
- }
+ for( int i=0; i<(int)d_regions_index; i++ ){
+ //should not have multiple regions at this point
+ //if( foundRegion ){
+ // Assert( !d_regions[i]->d_valid );
+ //}
+ if( d_regions[i]->d_valid ){
+ //this is the only valid region
+ d_regions[i]->getRepresentatives( reps );
}
- //}else{
- // Unimplemented("Build representatives for fmf region sat is not implemented");
- //}
+ }
}
StrongSolverTheoryUF::StrongSolverTheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, TheoryUF* th) :
@@ -1415,6 +1450,15 @@ d_rep_model_init( c )
}else{
d_deq_prop = NULL;
}
+ if( options::ufssSymBreak() ){
+ d_sym_break = new SubsortSymmetryBreaker( th->getQuantifiersEngine(), c );
+ }else{
+ d_sym_break = NULL;
+ }
+}
+
+SortInference* StrongSolverTheoryUF::getSortInference() {
+ return d_th->getQuantifiersEngine()->getTheoryEngine()->getSortInference();
}
/** get default sat context */
@@ -1433,6 +1477,10 @@ void StrongSolverTheoryUF::newEqClass( Node n ){
if( c ){
Trace("uf-ss-solver") << "StrongSolverTheoryUF: New eq class " << n << " : " << n.getType() << std::endl;
c->newEqClass( n );
+ if( options::ufssSymBreak() ){
+ //d_sym_break->newEqClass( n );
+ d_sym_break->queueFact( n );
+ }
}
}
@@ -1539,6 +1587,10 @@ void StrongSolverTheoryUF::check( Theory::Effort level ){
break;
}
}
+ //check symmetry breaker
+ if( !d_conflict && options::ufssSymBreak() ){
+ d_sym_break->check( level );
+ }
//disambiguate terms if necessary
//if( !d_conflict && level==Theory::EFFORT_FULL && options::ufssColoringSat() ){
// Assert( d_term_amb!=NULL );
@@ -1644,6 +1696,14 @@ int StrongSolverTheoryUF::getCardinality( Node n ) {
}
}
+int StrongSolverTheoryUF::getCardinality( TypeNode tn ) {
+ std::map< TypeNode, SortModel* >::iterator it = d_rep_model.find( tn );
+ if( it!=d_rep_model.end() && it->second ){
+ return it->second->getCardinality();
+ }
+ return -1;
+}
+
void StrongSolverTheoryUF::getRepresentatives( Node n, std::vector< Node >& reps ){
SortModel* c = getSortModel( n );
if( c ){
@@ -1698,6 +1758,7 @@ StrongSolverTheoryUF::Statistics::Statistics():
d_clique_lemmas("StrongSolverTheoryUF::Clique_Lemmas", 0),
d_split_lemmas("StrongSolverTheoryUF::Split_Lemmas", 0),
d_disamb_term_lemmas("StrongSolverTheoryUF::Disambiguate_Term_Lemmas", 0),
+ d_sym_break_lemmas("StrongSolverTheoryUF::Symmetry_Breaking_Lemmas", 0),
d_totality_lemmas("StrongSolverTheoryUF::Totality_Lemmas", 0),
d_max_model_size("StrongSolverTheoryUF::Max_Model_Size", 1)
{
@@ -1705,6 +1766,7 @@ StrongSolverTheoryUF::Statistics::Statistics():
StatisticsRegistry::registerStat(&d_clique_lemmas);
StatisticsRegistry::registerStat(&d_split_lemmas);
StatisticsRegistry::registerStat(&d_disamb_term_lemmas);
+ StatisticsRegistry::registerStat(&d_sym_break_lemmas);
StatisticsRegistry::registerStat(&d_totality_lemmas);
StatisticsRegistry::registerStat(&d_max_model_size);
}
@@ -1714,6 +1776,7 @@ StrongSolverTheoryUF::Statistics::~Statistics(){
StatisticsRegistry::unregisterStat(&d_clique_lemmas);
StatisticsRegistry::unregisterStat(&d_split_lemmas);
StatisticsRegistry::unregisterStat(&d_disamb_term_lemmas);
+ StatisticsRegistry::unregisterStat(&d_sym_break_lemmas);
StatisticsRegistry::unregisterStat(&d_totality_lemmas);
StatisticsRegistry::unregisterStat(&d_max_model_size);
}
diff --git a/src/theory/uf/theory_uf_strong_solver.h b/src/theory/uf/theory_uf_strong_solver.h
index fa8d60b49..8e568444b 100644
--- a/src/theory/uf/theory_uf_strong_solver.h
+++ b/src/theory/uf/theory_uf_strong_solver.h
@@ -26,7 +26,13 @@
#include "util/statistics_registry.h"
namespace CVC4 {
+
+class SortInference;
+
namespace theory {
+
+class SubsortSymmetryBreaker;
+
namespace uf {
class TheoryUF;
@@ -40,7 +46,6 @@ protected:
typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
typedef context::CDChunkList<Node> NodeList;
typedef context::CDList<bool> BoolList;
- typedef context::CDList<bool> IntList;
typedef context::CDHashMap<TypeNode, bool, TypeNodeHashFunction> TypeNodeBoolMap;
public:
/** information for incremental conflict/clique finding for a particular sort */
@@ -202,6 +207,23 @@ public:
/** add totality axiom */
void addTotalityAxiom( Node n, int cardinality, OutputChannel* out );
private:
+ class NodeTrie {
+ std::map< Node, NodeTrie > d_children;
+ public:
+ bool add( std::vector< Node >& n, unsigned i = 0 ){
+ Assert( i<n.size() );
+ if( i==(n.size()-1) ){
+ bool ret = d_children.find( n[i] )==d_children.end();
+ d_children[n[i]].d_children.clear();
+ return ret;
+ }else{
+ return d_children[n[i]].add( n, i+1 );
+ }
+ }
+ };
+ std::map< int, NodeTrie > d_clique_trie;
+ void addClique( int c, std::vector< Node >& clique );
+ private:
/** Are we in conflict */
context::CDO<bool> d_conflict;
/** cardinality */
@@ -286,6 +308,8 @@ private:
TermDisambiguator* d_term_amb;
/** disequality propagator */
DisequalityPropagator* d_deq_prop;
+ /** symmetry breaking techniques */
+ SubsortSymmetryBreaker* d_sym_break;
public:
StrongSolverTheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, TheoryUF* th);
~StrongSolverTheoryUF() {}
@@ -295,6 +319,10 @@ public:
TermDisambiguator* getTermDisambiguator() { return d_term_amb; }
/** disequality propagator */
DisequalityPropagator* getDisequalityPropagator() { return d_deq_prop; }
+ /** symmetry breaker */
+ SubsortSymmetryBreaker* getSymmetryBreaker() { return d_sym_break; }
+ /** get sort inference module */
+ SortInference* getSortInference();
/** get default sat context */
context::Context* getSatContext();
/** get default output channel */
@@ -336,8 +364,10 @@ public:
TypeNode getCardinalityType( int i ) { return d_conf_types[i]; }
/** get is in conflict */
bool isConflict() { return d_conflict; }
- /** get cardinality for sort */
+ /** get cardinality for node */
int getCardinality( Node n );
+ /** get cardinality for type */
+ int getCardinality( TypeNode tn );
/** get representatives */
void getRepresentatives( Node n, std::vector< Node >& reps );
/** minimize */
@@ -349,6 +379,7 @@ public:
IntStat d_clique_lemmas;
IntStat d_split_lemmas;
IntStat d_disamb_term_lemmas;
+ IntStat d_sym_break_lemmas;
IntStat d_totality_lemmas;
IntStat d_max_model_size;
Statistics();
diff --git a/src/util/sort_inference.cpp b/src/util/sort_inference.cpp
index 13631e590..a4c34faec 100644
--- a/src/util/sort_inference.cpp
+++ b/src/util/sort_inference.cpp
@@ -27,8 +27,55 @@ using namespace std;
namespace CVC4 {
+void SortInference::UnionFind::print(const char * c){
+ for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
+ Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
+ }
+ for( unsigned i=0; i<d_deq.size(); i++ ){
+ Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
+ }
+ Trace(c) << std::endl;
+}
+void SortInference::UnionFind::set( UnionFind& c ) {
+ clear();
+ for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
+ d_eqc[ it->first ] = it->second;
+ }
+ d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
+}
+int SortInference::UnionFind::getRepresentative( int t ){
+ std::map< int, int >::iterator it = d_eqc.find( t );
+ if( it==d_eqc.end() || it->second==t ){
+ return t;
+ }else{
+ int rt = getRepresentative( it->second );
+ d_eqc[t] = rt;
+ return rt;
+ }
+}
+void SortInference::UnionFind::setEqual( int t1, int t2 ){
+ if( t1!=t2 ){
+ int rt1 = getRepresentative( t1 );
+ int rt2 = getRepresentative( t2 );
+ if( rt1>rt2 ){
+ d_eqc[rt1] = rt2;
+ }else{
+ d_eqc[rt2] = rt1;
+ }
+ }
+}
+bool SortInference::UnionFind::isValid() {
+ for( unsigned i=0; i<d_deq.size(); i++ ){
+ if( areEqual( d_deq[i].first, d_deq[i].second ) ){
+ return false;
+ }
+ }
+ return true;
+}
+
+
void SortInference::printSort( const char* c, int t ){
- int rt = getRepresentative( t );
+ int rt = d_type_union_find.getRepresentative( t );
if( d_type_types.find( rt )!=d_type_types.end() ){
Trace(c) << d_type_types[rt];
}else{
@@ -83,46 +130,19 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doRewrite ){
//add lemma enforcing introduced constants to be distinct?
}
}
-}
-
-int SortInference::getRepresentative( int t ){
- std::map< int, int >::iterator it = d_type_union_find.find( t );
- if( it!=d_type_union_find.end() ){
- if( it->second==t ){
- return t;
- }else{
- int rt = getRepresentative( it->second );
- d_type_union_find[t] = rt;
- return rt;
- }
- }else{
- return t;
- }
+ initialSortCount = sortCount;
}
void SortInference::setEqual( int t1, int t2 ){
if( t1!=t2 ){
- int rt1 = getRepresentative( t1 );
- int rt2 = getRepresentative( t2 );
+ int rt1 = d_type_union_find.getRepresentative( t1 );
+ int rt2 = d_type_union_find.getRepresentative( t2 );
if( rt1!=rt2 ){
Trace("sort-inference-debug") << "Set equal : ";
printSort( "sort-inference-debug", rt1 );
Trace("sort-inference-debug") << " ";
printSort( "sort-inference-debug", rt2 );
Trace("sort-inference-debug") << std::endl;
- //check if they must be a type
- std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
- std::map< int, TypeNode >::iterator it2 = d_type_types.find( rt2 );
- if( it2!=d_type_types.end() ){
- if( it1==d_type_types.end() ){
- //swap sides
- int swap = rt1;
- rt1 = rt2;
- rt2 = swap;
- }else{
- Assert( rt1==rt2 );
- }
- }
/*
d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
d_type_eq_class[rt2].clear();
@@ -132,7 +152,19 @@ void SortInference::setEqual( int t1, int t2 ){
}
Trace("sort-inference-debug") << "}" << std::endl;
*/
- d_type_union_find[rt2] = rt1;
+ if( rt2>rt1 ){
+ //swap
+ int swap = rt1;
+ rt1 = rt2;
+ rt2 = swap;
+ }
+ d_type_union_find.d_eqc[rt1] = rt2;
+ std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
+ if( it1!=d_type_types.end() ){
+ Assert( d_type_types.find( rt2 )==d_type_types.end() );
+ d_type_types[rt2] = it1->second;
+ d_type_types.erase( rt1 );
+ }
}
}
}
@@ -155,14 +187,17 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
Trace("sort-inference-debug") << "Process " << n << std::endl;
//add to variable bindings
if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
- for( size_t i=0; i<n[0].getNumChildren(); i++ ){
- //TODO: try applying sort inference to quantified variables
- d_var_types[n][ n[0][i] ] = sortCount;
- sortCount++;
+ if( d_var_types.find( n )!=d_var_types.end() ){
+ return getIdForType( n.getType() );
+ }else{
+ for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+ //apply sort inference to quantified variables
+ d_var_types[n][ n[0][i] ] = sortCount;
+ sortCount++;
- //type of the quantified variable must be the same
- //d_var_types[n][ n[0][i] ] = getIdForType( n[0][i].getType() );
- var_bound[ n[0][i] ] = n;
+ //type of the quantified variable must be the same
+ var_bound[ n[0][i] ] = n;
+ }
}
}
@@ -191,10 +226,10 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
int retType;
if( n.getKind()==kind::EQUAL ){
//we only require that the left and right hand side must be equal
- //setEqual( child_types[0], child_types[1] );
- int eqType = getIdForType( n[0].getType() );
- setEqual( child_types[0], eqType );
- setEqual( child_types[1], eqType );
+ setEqual( child_types[0], child_types[1] );
+ //int eqType = getIdForType( n[0].getType() );
+ //setEqual( child_types[0], eqType );
+ //setEqual( child_types[1], eqType );
retType = getIdForType( n.getType() );
}else if( n.getKind()==kind::APPLY_UF ){
Node op = n.getOperator();
@@ -256,7 +291,7 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
- int rt = getRepresentative( t );
+ int rt = d_type_union_find.getRepresentative( t );
if( d_type_types.find( rt )!=d_type_types.end() ){
return d_type_types[rt];
}else{
@@ -281,7 +316,7 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
}
TypeNode SortInference::getTypeForId( int t ){
- int rt = getRepresentative( t );
+ int rt = d_type_union_find.getRepresentative( t );
if( d_type_types.find( rt )!=d_type_types.end() ){
return d_type_types[rt];
}else{
@@ -417,15 +452,71 @@ Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
}
int SortInference::getSortId( Node n ) {
Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
- return getRepresentative( d_op_return_types[op] );
+ if( d_op_return_types.find( op )!=d_op_return_types.end() ){
+ return d_type_union_find.getRepresentative( d_op_return_types[op] );
+ }else{
+ return 0;
+ }
}
int SortInference::getSortId( Node f, Node v ) {
- return getRepresentative( d_var_types[f][v] );
+ if( d_var_types.find( f )!=d_var_types.end() ){
+ return d_type_union_find.getRepresentative( d_var_types[f][v] );
+ }else{
+ return 0;
+ }
}
void SortInference::setSkolemVar( Node f, Node v, Node sk ){
+ Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
+ if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
+ std::map< Node, Node > var_bound;
+ process( f, var_bound );
+ }
d_op_return_types[sk] = getSortId( f, v );
+ Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
+}
+
+bool SortInference::isWellSortedFormula( Node n ) {
+ if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
+ for( unsigned i=0; i<n.getNumChildren(); i++ ){
+ if( !isWellSortedFormula( n[i] ) ){
+ return false;
+ }
+ }
+ return true;
+ }else{
+ return isWellSorted( n );
+ }
+}
+
+bool SortInference::isWellSorted( Node n ) {
+ if( getSortId( n )==0 ){
+ return false;
+ }else{
+ if( n.getKind()==kind::APPLY_UF ){
+ for( unsigned i=0; i<n.getNumChildren(); i++ ){
+ int s1 = getSortId( n[i] );
+ int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
+ if( s1!=s2 ){
+ return false;
+ }
+ if( !isWellSorted( n[i] ) ){
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+}
+
+void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
+ if( n.getKind()==kind::APPLY_UF ){
+ for( unsigned i=0; i<n.getNumChildren(); i++ ){
+ getSortConstraints( n[i], uf );
+ uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
+ }
+ }
}
}/* CVC4 namespace */
diff --git a/src/util/sort_inference.h b/src/util/sort_inference.h
index 1bcb8a208..53dff823f 100644
--- a/src/util/sort_inference.h
+++ b/src/util/sort_inference.h
@@ -30,9 +30,29 @@ class SortInference{
private:
//for debugging
//std::map< int, std::vector< Node > > d_type_eq_class;
+public:
+ class UnionFind {
+ public:
+ UnionFind(){}
+ UnionFind( UnionFind& c ){
+ set( c );
+ }
+ std::map< int, int > d_eqc;
+ //pairs that must be disequal
+ std::vector< std::pair< int, int > > d_deq;
+ void print(const char * c);
+ void clear() { d_eqc.clear(); d_deq.clear(); }
+ void set( UnionFind& c );
+ int getRepresentative( int t );
+ void setEqual( int t1, int t2 );
+ void setDisequal( int t1, int t2 ){ d_deq.push_back( std::pair< int, int >( t1, t2 ) ); }
+ bool areEqual( int t1, int t2 ) { return getRepresentative( t1 )==getRepresentative( t2 ); }
+ bool isValid();
+ };
private:
int sortCount;
- std::map< int, int > d_type_union_find;
+ int initialSortCount;
+ UnionFind d_type_union_find;
std::map< int, TypeNode > d_type_types;
std::map< TypeNode, int > d_id_for_types;
//for apply uf operators
@@ -41,7 +61,6 @@ private:
//for bound variables
std::map< Node, std::map< Node, int > > d_var_types;
//get representative
- int getRepresentative( int t );
void setEqual( int t1, int t2 );
int getIdForType( TypeNode tn );
void printSort( const char* c, int t );
@@ -61,14 +80,22 @@ private:
//simplify
Node simplify( Node n, std::map< Node, Node >& var_bound );
public:
- SortInference() : sortCount( 0 ){}
+ SortInference() : sortCount( 1 ){}
~SortInference(){}
void simplify( std::vector< Node >& assertions, bool doRewrite = false );
+ //get sort id for term n
int getSortId( Node n );
+ //get sort id for variable of quantified formula f
int getSortId( Node f, Node v );
//set that sk is the skolem variable of v for quantifier f
void setSkolemVar( Node f, Node v, Node sk );
+public:
+ //is well sorted
+ bool isWellSortedFormula( Node n );
+ bool isWellSorted( Node n );
+ //get constraints for being well-typed according to computed sub-types
+ void getSortConstraints( Node n, SortInference::UnionFind& uf );
};
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback