summaryrefslogtreecommitdiff
path: root/src/theory/quantifiers/term_database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/quantifiers/term_database.cpp')
-rw-r--r--src/theory/quantifiers/term_database.cpp165
1 files changed, 153 insertions, 12 deletions
diff --git a/src/theory/quantifiers/term_database.cpp b/src/theory/quantifiers/term_database.cpp
index 60573a7fc..646a1565e 100644
--- a/src/theory/quantifiers/term_database.cpp
+++ b/src/theory/quantifiers/term_database.cpp
@@ -1359,7 +1359,10 @@ int TermDb::getQAttrRewriteRulePriority( Node q ) {
-
+TermDbSygus::TermDbSygus(){
+ d_true = NodeManager::currentNM()->mkConst( true );
+ d_false = NodeManager::currentNM()->mkConst( false );
+}
TNode TermDbSygus::getVar( TypeNode tn, int i ) {
while( i>=(int)d_fv[tn].size() ){
@@ -1534,7 +1537,7 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int
}
Trace("sygus-db") << "mkGeneric " << dt.getName() << " " << op << " " << op.getKind() << "..." << std::endl;
for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){
- TypeNode tna = TypeNode::fromType( ((SelectorType)dt[c][i].getType()).getRangeType() );
+ TypeNode tna = getArgType( dt[c], i );
Node a;
std::map< int, Node >::iterator it = pre.find( i );
if( it!=pre.end() ){
@@ -1589,14 +1592,18 @@ Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) {
}
}
-Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) {
+//rcons_depth limits the number of recursive calls when doing accelerated constant reconstruction (currently limited to 1000)
+//this is hacky : depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 vs. 999, 1001
+Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) {
std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c );
if( it==d_builtin_const_to_sygus[tn].end() ){
+ Node sc;
+ d_builtin_const_to_sygus[tn][c] = sc;
Assert( c.isConst() );
Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) );
const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
+ Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl;
Assert( dt.isSygus() );
- Node sc;
// if we are not interested in reconstructing constants, or the grammar allows them, return a proxy
if( !options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst() ){
Node k = NodeManager::currentNM()->mkSkolem( "sy", tn, "sygus proxy" );
@@ -1606,9 +1613,60 @@ Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) {
}else{
int carg = getOpArg( tn, c );
if( carg!=-1 ){
- sc = Node::fromExpr( dt[carg].getSygusOp() );
+ //sc = Node::fromExpr( dt[carg].getSygusOp() );
+ sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) );
}else{
- //TODO
+ //identity functions
+ for( unsigned i=0; i<getNumIdFuncs( tn ); i++ ){
+ unsigned ii = getIdFuncIndex( tn, i );
+ Assert( dt[ii].getNumArgs()==1 );
+ //try to directly reconstruct from single argument
+ TypeNode tnc = getArgType( dt[ii], 0 );
+ Trace("csi-rcons-debug") << "Based on id function " << dt[ii].getSygusOp() << ", try reconstructing " << c << " instead in " << tnc << std::endl;
+ Node n = builtinToSygusConst( c, tnc, rcons_depth );
+ if( !n.isNull() ){
+ sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[ii].getConstructor() ), n );
+ break;
+ }
+ }
+ if( sc.isNull() ){
+ if( rcons_depth<1000 ){
+ //accelerated, recursive reconstruction of constants
+ Kind pk = getPlusKind( TypeNode::fromType( dt.getSygusType() ) );
+ if( pk!=UNDEFINED_KIND ){
+ int arg = getKindArg( tn, pk );
+ if( arg!=-1 ){
+ Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) );
+ Kind pkm = getPlusKind( TypeNode::fromType( dt.getSygusType() ), true );
+ //get types
+ Assert( dt[arg].getNumArgs()==2 );
+ TypeNode tn1 = getArgType( dt[arg], 0 );
+ TypeNode tn2 = getArgType( dt[arg], 1 );
+ //iterate over all positive constants, largest to smallest
+ int start = d_const_list[tn1].size()-1;
+ int end = d_const_list[tn1].size()-d_const_list_pos[tn1];
+ for( int i=start; i>=end; --i ){
+ Node c1 = d_const_list[tn1][i];
+ //only consider if smaller than c, and
+ if( doCompare( c1, c, ck ) ){
+ Node c2 = NodeManager::currentNM()->mkNode( pkm, c, c1 );
+ c2 = Rewriter::rewrite( c2 );
+ if( c2.isConst() ){
+ //reconstruct constant on the other side
+ Node sc2 = builtinToSygusConst( c2, tn2, rcons_depth+1 );
+ if( !sc2.isNull() ){
+ Node sc1 = builtinToSygusConst( c1, tn1, rcons_depth );
+ Assert( !sc1.isNull() );
+ sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[arg].getConstructor() ), sc1, sc2 );
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
}
}
d_builtin_const_to_sygus[tn][c] = sc;
@@ -1798,7 +1856,7 @@ Node TermDbSygus::getTypeValue( TypeNode tn, int val ) {
n = NodeManager::currentNM()->mkConst<BitVector>(bval);
}else if( tn.isBoolean() ){
if( val==0 ){
- n = NodeManager::currentNM()->mkConst( false );
+ n = d_false;
}
}
d_type_value[tn][val] = n;
@@ -1815,7 +1873,7 @@ Node TermDbSygus::getTypeMaxValue( TypeNode tn ) {
if( tn.isBitVector() ){
n = bv::utils::mkOnes(tn.getConst<BitVectorSize>());
}else if( tn.isBoolean() ){
- n = NodeManager::currentNM()->mkConst( true );
+ n = d_true;
}
d_type_max_value[tn] = n;
return n;
@@ -1847,6 +1905,18 @@ Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& st
}
}
+struct sortConstants {
+ TermDbSygus * d_tds;
+ Kind d_comp_kind;
+ bool operator() (Node i, Node j) {
+ if( i!=j ){
+ return d_tds->doCompare( i, j, d_comp_kind );
+ }else{
+ return false;
+ }
+ }
+};
+
void TermDbSygus::registerSygusType( TypeNode tn ){
if( d_register.find( tn )==d_register.end() ){
if( !datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){
@@ -1858,6 +1928,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ){
if( d_register[tn].isNull() ){
Trace("sygus-util") << "...not sygus." << std::endl;
}else{
+ //for constant reconstruction
+ Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) );
+ Node z = getTypeValue( TypeNode::fromType( dt.getSygusType() ), 0 );
+ d_const_list_pos[tn] = 0;
+ //iterate over constructors
for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
Expr sop = dt[i].getSygusOp();
Assert( !sop.isNull() );
@@ -1872,11 +1947,39 @@ void TermDbSygus::registerSygusType( TypeNode tn ){
Trace("sygus-util") << ", constant";
d_consts[tn][n] = i;
d_arg_const[tn][i] = n;
+ d_const_list[tn].push_back( n );
+ if( ck!=UNDEFINED_KIND && doCompare( z, n, ck ) ){
+ d_const_list_pos[tn]++;
+ }
+ }
+ if( dt[i].isSygusIdFunc() ){
+ d_id_funcs[tn].push_back( i );
}
d_ops[tn][n] = i;
d_arg_ops[tn][i] = n;
Trace("sygus-util") << std::endl;
}
+ //sort the constant list
+ if( !d_const_list[tn].empty() ){
+ if( ck!=UNDEFINED_KIND ){
+ sortConstants sc;
+ sc.d_comp_kind = ck;
+ sc.d_tds = this;
+ std::sort( d_const_list[tn].begin(), d_const_list[tn].end(), sc );
+ }
+ Trace("sygus-util") << "Type has " << d_const_list[tn].size() << " constants..." << std::endl << " ";
+ for( unsigned i=0; i<d_const_list[tn].size(); i++ ){
+ Trace("sygus-util") << d_const_list[tn][i] << " ";
+ }
+ Trace("sygus-util") << std::endl;
+ Trace("sygus-util") << "Of these, " << d_const_list_pos[tn] << " are marked as positive." << std::endl;
+ }
+ //register connected types
+ for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
+ for( unsigned j=0; j<dt[i].getNumArgs(); j++ ){
+ registerSygusType( getArgType( dt[i], j ) );
+ }
+ }
}
}
}
@@ -1979,6 +2082,14 @@ bool TermDbSygus::isConstArg( TypeNode tn, int i ) {
}
}
+unsigned TermDbSygus::getNumIdFuncs( TypeNode tn ) {
+ return d_id_funcs[tn].size();
+}
+
+unsigned TermDbSygus::getIdFuncIndex( TypeNode tn, unsigned i ) {
+ return d_id_funcs[tn][i];
+}
+
TypeNode TermDbSygus::getArgType( const DatatypeConstructor& c, int i ) {
Assert( i>=0 && i<(int)c.getNumArgs() );
return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() );
@@ -2041,7 +2152,33 @@ Node TermDbSygus::expandBuiltinTerm( Node t ){
}
-void doReplace(std::string& str, const std::string& oldStr, const std::string& newStr){
+Kind TermDbSygus::getComparisonKind( TypeNode tn ) {
+ if( tn.isInteger() ){
+ return LT;
+ }else if( tn.isBitVector() ){
+ return BITVECTOR_ULT;
+ }else{
+ return UNDEFINED_KIND;
+ }
+}
+
+Kind TermDbSygus::getPlusKind( TypeNode tn, bool is_neg ) {
+ if( tn.isInteger() ){
+ return is_neg ? MINUS : PLUS;
+ }else if( tn.isBitVector() ){
+ return is_neg ? BITVECTOR_SUB : BITVECTOR_PLUS;
+ }else{
+ return UNDEFINED_KIND;
+ }
+}
+
+bool TermDbSygus::doCompare( Node a, Node b, Kind k ) {
+ Node com = NodeManager::currentNM()->mkNode( k, a, b );
+ com = Rewriter::rewrite( com );
+ return com==d_true;
+}
+
+void doStrReplace(std::string& str, const std::string& oldStr, const std::string& newStr){
size_t pos = 0;
while((pos = str.find(oldStr, pos)) != std::string::npos){
str.replace(pos, oldStr.length(), newStr);
@@ -2101,12 +2238,16 @@ void TermDbSygus::printSygusTerm( std::ostream& out, Node n, std::vector< Node >
std::stringstream body_out;
printSygusTerm( body_out, let_body, new_lvs );
std::string body = body_out.str();
- for( unsigned i=dt[cIndex].getNumSygusLetInputArgs(); i<dt[cIndex].getNumSygusLetArgs(); i++ ){
+ for( unsigned i=0; i<dt[cIndex].getNumSygusLetArgs(); i++ ){
std::stringstream old_str;
old_str << new_lvs[i];
std::stringstream new_str;
- printSygusTerm( new_str, n[i], lvs );
- doReplace( body, old_str.str().c_str(), new_str.str().c_str() );
+ if( i>=dt[cIndex].getNumSygusLetInputArgs() ){
+ printSygusTerm( new_str, n[i], lvs );
+ }else{
+ new_str << Node::fromExpr( dt[cIndex].getSygusLetArg( i ) );
+ }
+ doStrReplace( body, old_str.str().c_str(), new_str.str().c_str() );
}
out << body;
if( dt[cIndex].getNumSygusLetInputArgs()>0 ){
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback