diff options
author | ajreynol <andrew.j.reynolds@gmail.com> | 2015-06-12 14:15:14 +0200 |
---|---|---|
committer | ajreynol <andrew.j.reynolds@gmail.com> | 2015-06-12 14:15:14 +0200 |
commit | ad0863ae8333c4dcd950153e0db8cd4565a250b3 (patch) | |
tree | 96d265ac59ce48f5bb90f0d041b1a4ffe57539a4 /src/theory/quantifiers/term_database.cpp | |
parent | df88bab0da253bb00056a25b4f7603d9ac6f3d66 (diff) |
Accelerate sygus solution reconstruction for constants and id functions. Minor changes to sygus type registration. Print sygus let solutions assuming fixed variable names.
Diffstat (limited to 'src/theory/quantifiers/term_database.cpp')
-rw-r--r-- | src/theory/quantifiers/term_database.cpp | 165 |
1 files changed, 153 insertions, 12 deletions
diff --git a/src/theory/quantifiers/term_database.cpp b/src/theory/quantifiers/term_database.cpp index 60573a7fc..646a1565e 100644 --- a/src/theory/quantifiers/term_database.cpp +++ b/src/theory/quantifiers/term_database.cpp @@ -1359,7 +1359,10 @@ int TermDb::getQAttrRewriteRulePriority( Node q ) { - +TermDbSygus::TermDbSygus(){ + d_true = NodeManager::currentNM()->mkConst( true ); + d_false = NodeManager::currentNM()->mkConst( false ); +} TNode TermDbSygus::getVar( TypeNode tn, int i ) { while( i>=(int)d_fv[tn].size() ){ @@ -1534,7 +1537,7 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int } Trace("sygus-db") << "mkGeneric " << dt.getName() << " " << op << " " << op.getKind() << "..." << std::endl; for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){ - TypeNode tna = TypeNode::fromType( ((SelectorType)dt[c][i].getType()).getRangeType() ); + TypeNode tna = getArgType( dt[c], i ); Node a; std::map< int, Node >::iterator it = pre.find( i ); if( it!=pre.end() ){ @@ -1589,14 +1592,18 @@ Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { } } -Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) { +//rcons_depth limits the number of recursive calls when doing accelerated constant reconstruction (currently limited to 1000) +//this is hacky : depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 vs. 999, 1001 +Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) { std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c ); if( it==d_builtin_const_to_sygus[tn].end() ){ + Node sc; + d_builtin_const_to_sygus[tn][c] = sc; Assert( c.isConst() ); Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ); const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl; Assert( dt.isSygus() ); - Node sc; // if we are not interested in reconstructing constants, or the grammar allows them, return a proxy if( !options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst() ){ Node k = NodeManager::currentNM()->mkSkolem( "sy", tn, "sygus proxy" ); @@ -1606,9 +1613,60 @@ Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) { }else{ int carg = getOpArg( tn, c ); if( carg!=-1 ){ - sc = Node::fromExpr( dt[carg].getSygusOp() ); + //sc = Node::fromExpr( dt[carg].getSygusOp() ); + sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); }else{ - //TODO + //identity functions + for( unsigned i=0; i<getNumIdFuncs( tn ); i++ ){ + unsigned ii = getIdFuncIndex( tn, i ); + Assert( dt[ii].getNumArgs()==1 ); + //try to directly reconstruct from single argument + TypeNode tnc = getArgType( dt[ii], 0 ); + Trace("csi-rcons-debug") << "Based on id function " << dt[ii].getSygusOp() << ", try reconstructing " << c << " instead in " << tnc << std::endl; + Node n = builtinToSygusConst( c, tnc, rcons_depth ); + if( !n.isNull() ){ + sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[ii].getConstructor() ), n ); + break; + } + } + if( sc.isNull() ){ + if( rcons_depth<1000 ){ + //accelerated, recursive reconstruction of constants + Kind pk = getPlusKind( TypeNode::fromType( dt.getSygusType() ) ); + if( pk!=UNDEFINED_KIND ){ + int arg = getKindArg( tn, pk ); + if( arg!=-1 ){ + Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); + Kind pkm = getPlusKind( TypeNode::fromType( dt.getSygusType() ), true ); + //get types + Assert( dt[arg].getNumArgs()==2 ); + TypeNode tn1 = getArgType( dt[arg], 0 ); + TypeNode tn2 = getArgType( dt[arg], 1 ); + //iterate over all positive constants, largest to smallest + int start = d_const_list[tn1].size()-1; + int end = d_const_list[tn1].size()-d_const_list_pos[tn1]; + for( int i=start; i>=end; --i ){ + Node c1 = d_const_list[tn1][i]; + //only consider if smaller than c, and + if( doCompare( c1, c, ck ) ){ + Node c2 = NodeManager::currentNM()->mkNode( pkm, c, c1 ); + c2 = Rewriter::rewrite( c2 ); + if( c2.isConst() ){ + //reconstruct constant on the other side + Node sc2 = builtinToSygusConst( c2, tn2, rcons_depth+1 ); + if( !sc2.isNull() ){ + Node sc1 = builtinToSygusConst( c1, tn1, rcons_depth ); + Assert( !sc1.isNull() ); + sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[arg].getConstructor() ), sc1, sc2 ); + break; + } + } + } + } + } + } + } + } } } d_builtin_const_to_sygus[tn][c] = sc; @@ -1798,7 +1856,7 @@ Node TermDbSygus::getTypeValue( TypeNode tn, int val ) { n = NodeManager::currentNM()->mkConst<BitVector>(bval); }else if( tn.isBoolean() ){ if( val==0 ){ - n = NodeManager::currentNM()->mkConst( false ); + n = d_false; } } d_type_value[tn][val] = n; @@ -1815,7 +1873,7 @@ Node TermDbSygus::getTypeMaxValue( TypeNode tn ) { if( tn.isBitVector() ){ n = bv::utils::mkOnes(tn.getConst<BitVectorSize>()); }else if( tn.isBoolean() ){ - n = NodeManager::currentNM()->mkConst( true ); + n = d_true; } d_type_max_value[tn] = n; return n; @@ -1847,6 +1905,18 @@ Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& st } } +struct sortConstants { + TermDbSygus * d_tds; + Kind d_comp_kind; + bool operator() (Node i, Node j) { + if( i!=j ){ + return d_tds->doCompare( i, j, d_comp_kind ); + }else{ + return false; + } + } +}; + void TermDbSygus::registerSygusType( TypeNode tn ){ if( d_register.find( tn )==d_register.end() ){ if( !datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){ @@ -1858,6 +1928,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ){ if( d_register[tn].isNull() ){ Trace("sygus-util") << "...not sygus." << std::endl; }else{ + //for constant reconstruction + Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); + Node z = getTypeValue( TypeNode::fromType( dt.getSygusType() ), 0 ); + d_const_list_pos[tn] = 0; + //iterate over constructors for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ Expr sop = dt[i].getSygusOp(); Assert( !sop.isNull() ); @@ -1872,11 +1947,39 @@ void TermDbSygus::registerSygusType( TypeNode tn ){ Trace("sygus-util") << ", constant"; d_consts[tn][n] = i; d_arg_const[tn][i] = n; + d_const_list[tn].push_back( n ); + if( ck!=UNDEFINED_KIND && doCompare( z, n, ck ) ){ + d_const_list_pos[tn]++; + } + } + if( dt[i].isSygusIdFunc() ){ + d_id_funcs[tn].push_back( i ); } d_ops[tn][n] = i; d_arg_ops[tn][i] = n; Trace("sygus-util") << std::endl; } + //sort the constant list + if( !d_const_list[tn].empty() ){ + if( ck!=UNDEFINED_KIND ){ + sortConstants sc; + sc.d_comp_kind = ck; + sc.d_tds = this; + std::sort( d_const_list[tn].begin(), d_const_list[tn].end(), sc ); + } + Trace("sygus-util") << "Type has " << d_const_list[tn].size() << " constants..." << std::endl << " "; + for( unsigned i=0; i<d_const_list[tn].size(); i++ ){ + Trace("sygus-util") << d_const_list[tn][i] << " "; + } + Trace("sygus-util") << std::endl; + Trace("sygus-util") << "Of these, " << d_const_list_pos[tn] << " are marked as positive." << std::endl; + } + //register connected types + for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ + for( unsigned j=0; j<dt[i].getNumArgs(); j++ ){ + registerSygusType( getArgType( dt[i], j ) ); + } + } } } } @@ -1979,6 +2082,14 @@ bool TermDbSygus::isConstArg( TypeNode tn, int i ) { } } +unsigned TermDbSygus::getNumIdFuncs( TypeNode tn ) { + return d_id_funcs[tn].size(); +} + +unsigned TermDbSygus::getIdFuncIndex( TypeNode tn, unsigned i ) { + return d_id_funcs[tn][i]; +} + TypeNode TermDbSygus::getArgType( const DatatypeConstructor& c, int i ) { Assert( i>=0 && i<(int)c.getNumArgs() ); return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() ); @@ -2041,7 +2152,33 @@ Node TermDbSygus::expandBuiltinTerm( Node t ){ } -void doReplace(std::string& str, const std::string& oldStr, const std::string& newStr){ +Kind TermDbSygus::getComparisonKind( TypeNode tn ) { + if( tn.isInteger() ){ + return LT; + }else if( tn.isBitVector() ){ + return BITVECTOR_ULT; + }else{ + return UNDEFINED_KIND; + } +} + +Kind TermDbSygus::getPlusKind( TypeNode tn, bool is_neg ) { + if( tn.isInteger() ){ + return is_neg ? MINUS : PLUS; + }else if( tn.isBitVector() ){ + return is_neg ? BITVECTOR_SUB : BITVECTOR_PLUS; + }else{ + return UNDEFINED_KIND; + } +} + +bool TermDbSygus::doCompare( Node a, Node b, Kind k ) { + Node com = NodeManager::currentNM()->mkNode( k, a, b ); + com = Rewriter::rewrite( com ); + return com==d_true; +} + +void doStrReplace(std::string& str, const std::string& oldStr, const std::string& newStr){ size_t pos = 0; while((pos = str.find(oldStr, pos)) != std::string::npos){ str.replace(pos, oldStr.length(), newStr); @@ -2101,12 +2238,16 @@ void TermDbSygus::printSygusTerm( std::ostream& out, Node n, std::vector< Node > std::stringstream body_out; printSygusTerm( body_out, let_body, new_lvs ); std::string body = body_out.str(); - for( unsigned i=dt[cIndex].getNumSygusLetInputArgs(); i<dt[cIndex].getNumSygusLetArgs(); i++ ){ + for( unsigned i=0; i<dt[cIndex].getNumSygusLetArgs(); i++ ){ std::stringstream old_str; old_str << new_lvs[i]; std::stringstream new_str; - printSygusTerm( new_str, n[i], lvs ); - doReplace( body, old_str.str().c_str(), new_str.str().c_str() ); + if( i>=dt[cIndex].getNumSygusLetInputArgs() ){ + printSygusTerm( new_str, n[i], lvs ); + }else{ + new_str << Node::fromExpr( dt[cIndex].getSygusLetArg( i ) ); + } + doStrReplace( body, old_str.str().c_str(), new_str.str().c_str() ); } out << body; if( dt[cIndex].getNumSygusLetInputArgs()>0 ){ |