diff options
Diffstat (limited to 'src/util/sort_inference.cpp')
-rw-r--r-- | src/util/sort_inference.cpp | 308 |
1 files changed, 243 insertions, 65 deletions
diff --git a/src/util/sort_inference.cpp b/src/util/sort_inference.cpp index d44499fa8..b66d1cbe4 100644 --- a/src/util/sort_inference.cpp +++ b/src/util/sort_inference.cpp @@ -20,15 +20,70 @@ #include <vector> #include "util/sort_inference.h" +#include "theory/uf/options.h" +//#include "theory/rewriter.h" using namespace CVC4; using namespace std; namespace CVC4 { +void SortInference::UnionFind::print(const char * c){ + for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){ + Trace(c) << "s_" << it->first << " = s_" << it->second << ", "; + } + for( unsigned i=0; i<d_deq.size(); i++ ){ + Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", "; + } + Trace(c) << std::endl; +} +void SortInference::UnionFind::set( UnionFind& c ) { + clear(); + for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){ + d_eqc[ it->first ] = it->second; + } + d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() ); +} +int SortInference::UnionFind::getRepresentative( int t ){ + std::map< int, int >::iterator it = d_eqc.find( t ); + if( it==d_eqc.end() || it->second==t ){ + return t; + }else{ + int rt = getRepresentative( it->second ); + d_eqc[t] = rt; + return rt; + } +} +void SortInference::UnionFind::setEqual( int t1, int t2 ){ + if( t1!=t2 ){ + int rt1 = getRepresentative( t1 ); + int rt2 = getRepresentative( t2 ); + if( rt1>rt2 ){ + d_eqc[rt1] = rt2; + }else{ + d_eqc[rt2] = rt1; + } + } +} +bool SortInference::UnionFind::isValid() { + for( unsigned i=0; i<d_deq.size(); i++ ){ + if( areEqual( d_deq[i].first, d_deq[i].second ) ){ + return false; + } + } + return true; +} + + +void SortInference::recordSubsort( int s ){ + s = d_type_union_find.getRepresentative( s ); + if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){ + d_sub_sorts.push_back( s ); + } +} void SortInference::printSort( const char* c, int t ){ - int rt = getRepresentative( t ); + int rt = d_type_union_find.getRepresentative( t ); if( d_type_types.find( rt )!=d_type_types.end() ){ Trace(c) << d_type_types[rt]; }else{ @@ -43,30 +98,49 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doRewrite ){ std::map< Node, Node > var_bound; process( assertions[i], var_bound ); } - //print debug - if( Trace.isOn("sort-inference") ){ - for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ - Trace("sort-inference") << it->first << " : "; - if( !d_op_arg_types[ it->first ].empty() ){ - Trace("sort-inference") << "( "; - for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){ - printSort( "sort-inference", d_op_arg_types[ it->first ][i] ); - Trace("sort-inference") << " "; - } - Trace("sort-inference") << ") -> "; + for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ + Trace("sort-inference") << it->first << " : "; + if( !d_op_arg_types[ it->first ].empty() ){ + Trace("sort-inference") << "( "; + for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){ + recordSubsort( d_op_arg_types[ it->first ][i] ); + printSort( "sort-inference", d_op_arg_types[ it->first ][i] ); + Trace("sort-inference") << " "; } - printSort( "sort-inference", it->second ); - Trace("sort-inference") << std::endl; + Trace("sort-inference") << ") -> "; } - for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){ - Trace("sort-inference") << "Quantified formula " << it->first << " : " << std::endl; - for( std::map< Node, int >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){ - printSort( "sort-inference", it2->second ); - Trace("sort-inference") << std::endl; - } + recordSubsort( it->second ); + printSort( "sort-inference", it->second ); + Trace("sort-inference") << std::endl; + } + for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){ + Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl; + for( std::map< Node, int >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){ + printSort( "sort-inference", it2->second ); Trace("sort-inference") << std::endl; } + Trace("sort-inference") << std::endl; } + + //determine monotonicity of sorts + for( unsigned i=0; i<assertions.size(); i++ ){ + Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl; + std::map< Node, Node > var_bound; + processMonotonic( assertions[i], true, true, var_bound ); + } + + Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl; + for( unsigned i=0; i<d_sub_sorts.size(); i++ ){ + printSort( "sort-inference", d_sub_sorts[i] ); + if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){ + Trace("sort-inference") << " is interpreted." << std::endl; + }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){ + Trace("sort-inference") << " is monotonic." << std::endl; + }else{ + Trace("sort-inference") << " is not monotonic." << std::endl; + } + } + if( doRewrite ){ //simplify all assertions by introducing new symbols wherever necessary (NOTE: this is unsound for quantifiers) for( unsigned i=0; i<assertions.size(); i++ ){ @@ -82,47 +156,43 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doRewrite ){ } //add lemma enforcing introduced constants to be distinct? } - } -} - -int SortInference::getRepresentative( int t ){ - std::map< int, int >::iterator it = d_type_union_find.find( t ); - if( it!=d_type_union_find.end() ){ - if( it->second==t ){ - return t; - }else{ - int rt = getRepresentative( it->second ); - d_type_union_find[t] = rt; - return rt; + }else if( !options::ufssSymBreak() ){ + std::map< int, Node > constants; + //just add a bunch of unit lemmas + for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){ + int rt = d_type_union_find.getRepresentative( it->second ); + if( d_op_arg_types[ it->first ].empty() && constants.find( rt )==constants.end() ){ + constants[ rt ] = it->first; + } } - }else{ - return t; + //add unit lemmas for each constant + Node first_const; + for( std::map< int, Node >::iterator it = constants.begin(); it != constants.end(); ++it ){ + if( first_const.isNull() ){ + first_const = it->second; + }else{ + Node eq = first_const.eqNode( it->second ); + //eq = Rewriter::rewrite( eq ); + Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl; + assertions.push_back( eq ); + } + } + + } + initialSortCount = sortCount; } void SortInference::setEqual( int t1, int t2 ){ if( t1!=t2 ){ - int rt1 = getRepresentative( t1 ); - int rt2 = getRepresentative( t2 ); + int rt1 = d_type_union_find.getRepresentative( t1 ); + int rt2 = d_type_union_find.getRepresentative( t2 ); if( rt1!=rt2 ){ Trace("sort-inference-debug") << "Set equal : "; printSort( "sort-inference-debug", rt1 ); Trace("sort-inference-debug") << " "; printSort( "sort-inference-debug", rt2 ); Trace("sort-inference-debug") << std::endl; - //check if they must be a type - std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 ); - std::map< int, TypeNode >::iterator it2 = d_type_types.find( rt2 ); - if( it2!=d_type_types.end() ){ - if( it1==d_type_types.end() ){ - //swap sides - int swap = rt1; - rt1 = rt2; - rt2 = swap; - }else{ - Assert( rt1==rt2 ); - } - } /* d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() ); d_type_eq_class[rt2].clear(); @@ -132,7 +202,19 @@ void SortInference::setEqual( int t1, int t2 ){ } Trace("sort-inference-debug") << "}" << std::endl; */ - d_type_union_find[rt2] = rt1; + if( rt2>rt1 ){ + //swap + int swap = rt1; + rt1 = rt2; + rt2 = swap; + } + d_type_union_find.d_eqc[rt1] = rt2; + std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 ); + if( it1!=d_type_types.end() ){ + Assert( d_type_types.find( rt2 )==d_type_types.end() ); + d_type_types[rt2] = it1->second; + d_type_types.erase( rt1 ); + } } } } @@ -155,14 +237,17 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){ Trace("sort-inference-debug") << "Process " << n << std::endl; //add to variable bindings if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){ - for( size_t i=0; i<n[0].getNumChildren(); i++ ){ - //TODO: try applying sort inference to quantified variables - d_var_types[n][ n[0][i] ] = sortCount; - sortCount++; + if( d_var_types.find( n )!=d_var_types.end() ){ + return getIdForType( n.getType() ); + }else{ + for( size_t i=0; i<n[0].getNumChildren(); i++ ){ + //apply sort inference to quantified variables + d_var_types[n][ n[0][i] ] = sortCount; + sortCount++; - //type of the quantified variable must be the same - //d_var_types[n][ n[0][i] ] = getIdForType( n[0][i].getType() ); - var_bound[ n[0][i] ] = n; + //type of the quantified variable must be the same + var_bound[ n[0][i] ] = n; + } } } @@ -192,6 +277,9 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){ if( n.getKind()==kind::EQUAL ){ //we only require that the left and right hand side must be equal setEqual( child_types[0], child_types[1] ); + //int eqType = getIdForType( n[0].getType() ); + //setEqual( child_types[0], eqType ); + //setEqual( child_types[1], eqType ); retType = getIdForType( n.getType() ); }else if( n.getKind()==kind::APPLY_UF ){ Node op = n.getOperator(); @@ -227,11 +315,11 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){ //d_type_eq_class[sortCount].push_back( n ); } retType = d_op_return_types[n]; - }else if( n.isConst() ){ - Trace("sort-inference-debug") << n << " is a constant." << std::endl; + //}else if( n.isConst() ){ + // Trace("sort-inference-debug") << n << " is a constant." << std::endl; //can be any type we want - retType = sortCount; - sortCount++; + // retType = sortCount; + // sortCount++; }else{ Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl; //it is an interpretted term @@ -251,9 +339,43 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound ){ return retType; } +void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound ) { + if( n.getKind()==kind::FORALL ){ + for( unsigned i=0; i<n[0].getNumChildren(); i++ ){ + var_bound[n[0][i]] = n; + } + processMonotonic( n[1], pol, hasPol, var_bound ); + for( unsigned i=0; i<n[0].getNumChildren(); i++ ){ + var_bound.erase( n[0][i] ); + } + }else if( n.getKind()==kind::EQUAL ){ + if( !hasPol || pol ){ + for( unsigned i=0; i<2; i++ ){ + if( var_bound.find( n[i] )==var_bound.end() ){ + int sid = getSortId( var_bound[n[i]], n[i] ); + d_non_monotonic_sorts[sid] = true; + break; + } + } + } + }else{ + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + bool npol = pol; + bool nhasPol = hasPol; + if( n.getKind()==kind::NOT || ( n.getKind()==kind::IMPLIES && i==0 ) ){ + npol = !npol; + } + if( ( n.getKind()==kind::ITE && i==0 ) || n.getKind()==kind::XOR || n.getKind()==kind::IFF ){ + nhasPol = false; + } + processMonotonic( n[i], npol, nhasPol, var_bound ); + } + } +} + TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){ - int rt = getRepresentative( t ); + int rt = d_type_union_find.getRepresentative( t ); if( d_type_types.find( rt )!=d_type_types.end() ){ return d_type_types[rt]; }else{ @@ -278,7 +400,7 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){ } TypeNode SortInference::getTypeForId( int t ){ - int rt = getRepresentative( t ); + int rt = d_type_union_find.getRepresentative( t ); if( d_type_types.find( rt )!=d_type_types.end() ){ return d_type_types[rt]; }else{ @@ -414,15 +536,71 @@ Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){ } int SortInference::getSortId( Node n ) { Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n; - return getRepresentative( d_op_return_types[op] ); + if( d_op_return_types.find( op )!=d_op_return_types.end() ){ + return d_type_union_find.getRepresentative( d_op_return_types[op] ); + }else{ + return 0; + } } int SortInference::getSortId( Node f, Node v ) { - return getRepresentative( d_var_types[f][v] ); + if( d_var_types.find( f )!=d_var_types.end() ){ + return d_type_union_find.getRepresentative( d_var_types[f][v] ); + }else{ + return 0; + } } void SortInference::setSkolemVar( Node f, Node v, Node sk ){ + Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl; + if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){ + std::map< Node, Node > var_bound; + process( f, var_bound ); + } d_op_return_types[sk] = getSortId( f, v ); + Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl; +} + +bool SortInference::isWellSortedFormula( Node n ) { + if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){ + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + if( !isWellSortedFormula( n[i] ) ){ + return false; + } + } + return true; + }else{ + return isWellSorted( n ); + } +} + +bool SortInference::isWellSorted( Node n ) { + if( getSortId( n )==0 ){ + return false; + }else{ + if( n.getKind()==kind::APPLY_UF ){ + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + int s1 = getSortId( n[i] ); + int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ); + if( s1!=s2 ){ + return false; + } + if( !isWellSorted( n[i] ) ){ + return false; + } + } + } + return true; + } +} + +void SortInference::getSortConstraints( Node n, UnionFind& uf ) { + if( n.getKind()==kind::APPLY_UF ){ + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + getSortConstraints( n[i], uf ); + uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) ); + } + } } }/* CVC4 namespace */ |