diff options
Diffstat (limited to 'src/theory/quantifiers/term_database.cpp')
-rw-r--r-- | src/theory/quantifiers/term_database.cpp | 455 |
1 files changed, 454 insertions, 1 deletions
diff --git a/src/theory/quantifiers/term_database.cpp b/src/theory/quantifiers/term_database.cpp index 95214cfc6..24d7cbb5c 100644 --- a/src/theory/quantifiers/term_database.cpp +++ b/src/theory/quantifiers/term_database.cpp @@ -24,6 +24,11 @@ #include "theory/quantifiers/ce_guided_instantiation.h" #include "theory/quantifiers/rewrite_engine.h" +//for sygus +#include "theory/bv/theory_bv_utils.h" +#include "util/bitvector.h" +#include "smt/smt_engine_scope.h" + using namespace std; using namespace CVC4; using namespace CVC4::kind; @@ -75,6 +80,11 @@ void TermArgTrie::debugPrint( const char * c, Node n, unsigned depth ) { TermDb::TermDb( context::Context* c, context::UserContext* u, QuantifiersEngine* qe ) : d_quantEngine( qe ), d_op_ccount( u ) { d_true = NodeManager::currentNM()->mkConst( true ); d_false = NodeManager::currentNM()->mkConst( false ); + if( options::ceGuidedInst() ){ + d_sygus_tdb = new TermDbSygus; + }else{ + d_sygus_tdb = NULL; + } } /** ground terms */ @@ -1152,7 +1162,6 @@ bool TermDb::isInductionTerm( Node n ) { return false; } - bool TermDb::isRewriteRule( Node q ) { return !getRewriteRule( q ).isNull(); } @@ -1309,3 +1318,447 @@ int TermDb::getQAttrRewriteRulePriority( Node q ) { return it->second; } } + + + + + +TNode TermDbSygus::getVar( TypeNode tn, int i ) { + while( i>=(int)d_fv[tn].size() ){ + std::stringstream ss; + TypeNode vtn = tn; + if( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){ + const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + ss << "fv_" << dt.getName() << "_" << i; + if( !dt.getSygusType().isNull() ){ + vtn = TypeNode::fromType( dt.getSygusType() ); + } + }else{ + ss << "fv_" << tn << "_" << i; + } + Assert( !vtn.isNull() ); + Node v = NodeManager::currentNM()->mkSkolem( ss.str(), vtn, "for sygus normal form testing" ); + d_fv_stype[v] = tn; + d_fv[tn].push_back( v ); + } + return d_fv[tn][i]; +} + +TNode TermDbSygus::getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count ) { + std::map< TypeNode, int >::iterator it = var_count.find( tn ); + if( it==var_count.end() ){ + var_count[tn] = 1; + return getVar( tn, 0 ); + }else{ + int index = it->second; + var_count[tn]++; + return getVar( tn, index ); + } +} + +TypeNode TermDbSygus::getSygusType( Node v ) { + Assert( d_fv_stype.find( v )!=d_fv_stype.end() ); + return d_fv_stype[v]; +} + +Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ) { + Assert( c>=0 && c<(int)dt.getNumConstructors() ); + Assert( dt.isSygus() ); + Assert( !dt[c].getSygusOp().isNull() ); + std::vector< Node > children; + Node op = Node::fromExpr( dt[c].getSygusOp() ); + if( op.getKind()!=BUILTIN ){ + children.push_back( op ); + } + for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){ + TypeNode tna = TypeNode::fromType( ((SelectorType)dt[c][i].getType()).getRangeType() ); + Node a; + std::map< int, Node >::iterator it = pre.find( i ); + if( it!=pre.end() ){ + a = it->second; + }else{ + a = getVarInc( tna, var_count ); + } + Assert( !a.isNull() ); + children.push_back( a ); + } + if( op.getKind()==BUILTIN ){ + return NodeManager::currentNM()->mkNode( op, children ); + }else{ + if( children.size()==1 ){ + return children[0]; + }else{ + return NodeManager::currentNM()->mkNode( APPLY, children ); + /* + Node n = NodeManager::currentNM()->mkNode( APPLY, children ); + //must expand definitions + Node ne = Node::fromExpr( smt::currentSmtEngine()->expandDefinitions( n.toExpr() ) ); + Trace("sygus-util-debug") << "Expanded definitions in " << n << " to " << ne << std::endl; + return ne; + */ + } + } +} + +Node TermDbSygus::getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ) { + return n; + if( n.getKind()==SKOLEM ){ + std::map< Node, Node >::iterator its = subs.find( n ); + if( its!=subs.end() ){ + return its->second; + }else{ + std::map< Node, TypeNode >::iterator it = d_fv_stype.find( n ); + if( it!=d_fv_stype.end() ){ + Node v = getVarInc( it->second, var_count ); + subs[n] = v; + return v; + }else{ + return n; + } + } + }else{ + if( n.getNumChildren()>0 ){ + std::vector< Node > children; + if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ + children.push_back( n.getOperator() ); + } + bool childChanged = false; + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + Node nc = getSygusNormalized( n[i], var_count, subs ); + childChanged = childChanged || nc!=n[i]; + children.push_back( nc ); + } + if( childChanged ){ + return NodeManager::currentNM()->mkNode( n.getKind(), children ); + } + } + return n; + } +} + +Node TermDbSygus::getNormalized( TypeNode t, Node prog, bool do_pre_norm ) { + if( do_pre_norm ){ + std::map< TypeNode, int > var_count; + std::map< Node, Node > subs; + prog = getSygusNormalized( prog, var_count, subs ); + } + std::map< Node, Node >::iterator itn = d_normalized[t].find( prog ); + if( itn==d_normalized[t].end() ){ + Node progr = Node::fromExpr( smt::currentSmtEngine()->expandDefinitions( prog.toExpr() ) ); + progr = Rewriter::rewrite( progr ); + std::map< TypeNode, int > var_count; + std::map< Node, Node > subs; + progr = getSygusNormalized( progr, var_count, subs ); + Trace("sygus-sym-break2") << "...rewrites to " << progr << std::endl; + d_normalized[t][prog] = progr; + return progr; + }else{ + return itn->second; + } +} + +int TermDbSygus::getTermSize( Node n ){ + if( isVar( n ) ){ + return 0; + }else{ + int sum = 0; + for( unsigned i=0; i<n.getNumChildren(); i++ ){ + sum += getTermSize( n[i] ); + } + return 1+sum; + } + +} + +bool TermDbSygus::isAssoc( Kind k ) { + return k==PLUS || k==MULT || k==AND || k==OR || k==XOR || k==IFF || + k==BITVECTOR_PLUS || k==BITVECTOR_MULT || k==BITVECTOR_AND || k==BITVECTOR_OR || k==BITVECTOR_XOR || k==BITVECTOR_XNOR || k==BITVECTOR_CONCAT; +} + +bool TermDbSygus::isComm( Kind k ) { + return k==PLUS || k==MULT || k==AND || k==OR || k==XOR || k==IFF || + k==BITVECTOR_PLUS || k==BITVECTOR_MULT || k==BITVECTOR_AND || k==BITVECTOR_OR || k==BITVECTOR_XOR || k==BITVECTOR_XNOR; +} + +bool TermDbSygus::isAntisymmetric( Kind k, Kind& dk ) { + if( k==GT ){ + dk = LT; + return true; + }else if( k==GEQ ){ + dk = LEQ; + return true; + }else if( k==BITVECTOR_UGT ){ + dk = BITVECTOR_ULT; + return true; + }else if( k==BITVECTOR_UGE ){ + dk = BITVECTOR_ULE; + return true; + }else if( k==BITVECTOR_SGT ){ + dk = BITVECTOR_SLT; + return true; + }else if( k==BITVECTOR_SGE ){ + dk = BITVECTOR_SLE; + return true; + }else{ + return false; + } +} + +bool TermDbSygus::isIdempotentArg( Node n, Kind ik, int arg ) { + TypeNode tn = n.getType(); + if( n==getTypeValue( tn, 0 ) ){ + if( ik==PLUS || ik==OR || ik==XOR || ik==BITVECTOR_PLUS || ik==BITVECTOR_OR || ik==BITVECTOR_XOR ){ + return true; + }else if( ik==MINUS || ik==BITVECTOR_SHL || ik==BITVECTOR_LSHR || ik==BITVECTOR_SUB ){ + return arg==1; + } + }else if( n==getTypeValue( tn, 1 ) ){ + if( ik==MULT || ik==BITVECTOR_MULT ){ + return true; + }else if( ik==DIVISION || ik==BITVECTOR_UDIV || ik==BITVECTOR_SDIV ){ + return arg==1; + } + }else if( n==getTypeMaxValue( tn ) ){ + if( ik==IFF || ik==BITVECTOR_AND || ik==BITVECTOR_XNOR ){ + return true; + } + } + return false; +} + + +bool TermDbSygus::isSingularArg( Node n, Kind ik, int arg ) { + TypeNode tn = n.getType(); + if( n==getTypeValue( tn, 0 ) ){ + if( ik==AND || ik==MULT || ik==BITVECTOR_AND || ik==BITVECTOR_MULT ){ + return true; + }else if( ik==DIVISION || ik==BITVECTOR_UDIV || ik==BITVECTOR_SDIV ){ + return arg==0; + } + }else if( n==getTypeMaxValue( tn ) ){ + if( ik==OR || ik==BITVECTOR_OR ){ + return true; + } + } + return false; +} + +bool TermDbSygus::hasOffsetArg( Kind ik, int arg, int& offset, Kind& ok ) { + if( ik==LT ){ + Assert( arg==0 || arg==1 ); + offset = arg==0 ? 1 : -1; + ok = LEQ; + return true; + }else if( ik==BITVECTOR_ULT ){ + Assert( arg==0 || arg==1 ); + offset = arg==0 ? 1 : -1; + ok = BITVECTOR_ULE; + return true; + }else if( ik==BITVECTOR_SLT ){ + Assert( arg==0 || arg==1 ); + offset = arg==0 ? 1 : -1; + ok = BITVECTOR_SLE; + return true; + } + return false; +} + + +Node TermDbSygus::getTypeValue( TypeNode tn, int val ) { + std::map< int, Node >::iterator it = d_type_value[tn].find( val ); + if( it==d_type_value[tn].end() ){ + Node n; + if( tn.isInteger() || tn.isReal() ){ + Rational c(val); + n = NodeManager::currentNM()->mkConst( c ); + }else if( tn.isBitVector() ){ + unsigned int uv = val; + BitVector bval(tn.getConst<BitVectorSize>(), uv); + n = NodeManager::currentNM()->mkConst<BitVector>(bval); + }else if( tn.isBoolean() ){ + if( val==0 ){ + n = NodeManager::currentNM()->mkConst( false ); + } + } + d_type_value[tn][val] = n; + return n; + }else{ + return it->second; + } +} + +Node TermDbSygus::getTypeMaxValue( TypeNode tn ) { + std::map< TypeNode, Node >::iterator it = d_type_max_value.find( tn ); + if( it==d_type_max_value.end() ){ + Node n; + if( tn.isBitVector() ){ + n = bv::utils::mkOnes(tn.getConst<BitVectorSize>()); + }else if( tn.isBoolean() ){ + n = NodeManager::currentNM()->mkConst( true ); + } + d_type_max_value[tn] = n; + return n; + }else{ + return it->second; + } +} + +Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& status ) { + std::map< int, Node >::iterator it = d_type_value_offset[tn][val].find( offset ); + if( it==d_type_value_offset[tn][val].end() ){ + Node val_o; + Node offset_val = getTypeValue( tn, offset ); + status = -1; + if( !offset_val.isNull() ){ + if( tn.isInteger() || tn.isReal() ){ + val_o = Rewriter::rewrite( NodeManager::currentNM()->mkNode( PLUS, val, offset_val ) ); + status = 0; + }else if( tn.isBitVector() ){ + val_o = Rewriter::rewrite( NodeManager::currentNM()->mkNode( BITVECTOR_PLUS, val, offset_val ) ); + } + } + d_type_value_offset[tn][val][offset] = val_o; + d_type_value_offset_status[tn][val][offset] = status; + return val_o; + }else{ + status = d_type_value_offset_status[tn][val][offset]; + return it->second; + } +} + +void TermDbSygus::registerSygusType( TypeNode tn ){ + if( d_register.find( tn )==d_register.end() ){ + if( !datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){ + d_register[tn] = TypeNode::null(); + }else{ + const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + Trace("sygus-util") << "Register type " << dt.getName() << "..." << std::endl; + d_register[tn] = TypeNode::fromType( dt.getSygusType() ); + if( d_register[tn].isNull() ){ + Trace("sygus-util") << "...not sygus." << std::endl; + }else{ + for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ + Expr sop = dt[i].getSygusOp(); + Assert( !sop.isNull() ); + Node n = Node::fromExpr( sop ); + Trace("sygus-util") << " Operator #" << i << " : " << sop; + if( sop.getKind() == kind::BUILTIN ){ + Kind sk = NodeManager::operatorToKind( n ); + Trace("sygus-util") << ", kind = " << sk; + d_kinds[tn][sk] = i; + d_arg_kind[tn][i] = sk; + }else if( sop.isConst() ){ + Trace("sygus-util") << ", constant"; + d_consts[tn][n] = i; + d_arg_const[tn][i] = n; + } + d_ops[tn][n] = i; + d_arg_ops[tn][i] = n; + Trace("sygus-util") << std::endl; + } + } + } + } +} + +bool TermDbSygus::isRegistered( TypeNode tn ) { + return d_register.find( tn )!=d_register.end(); +} + +int TermDbSygus::getKindArg( TypeNode tn, Kind k ) { + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< Kind, int > >::iterator itt = d_kinds.find( tn ); + if( itt!=d_kinds.end() ){ + std::map< Kind, int >::iterator it = itt->second.find( k ); + if( it!=itt->second.end() ){ + return it->second; + } + } + return -1; +} + +int TermDbSygus::getConstArg( TypeNode tn, Node n ){ + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< Node, int > >::iterator itt = d_consts.find( tn ); + if( itt!=d_consts.end() ){ + std::map< Node, int >::iterator it = itt->second.find( n ); + if( it!=itt->second.end() ){ + return it->second; + } + } + return -1; +} + +int TermDbSygus::getOpArg( TypeNode tn, Node n ) { + std::map< Node, int >::iterator it = d_ops[tn].find( n ); + if( it!=d_ops[tn].end() ){ + return it->second; + }else{ + return -1; + } +} + +bool TermDbSygus::hasKind( TypeNode tn, Kind k ) { + return getKindArg( tn, k )!=-1; +} +bool TermDbSygus::hasConst( TypeNode tn, Node n ) { + return getConstArg( tn, n )!=-1; +} +bool TermDbSygus::hasOp( TypeNode tn, Node n ) { + return getOpArg( tn, n )!=-1; +} + +Node TermDbSygus::getArgOp( TypeNode tn, int i ) { + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< int, Node > >::iterator itt = d_arg_ops.find( tn ); + if( itt!=d_arg_ops.end() ){ + std::map< int, Node >::iterator itn = itt->second.find( i ); + if( itn!=itt->second.end() ){ + return itn->second; + } + } + return Node::null(); +} + +Node TermDbSygus::getArgConst( TypeNode tn, int i ) { + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< int, Node > >::iterator itt = d_arg_const.find( tn ); + if( itt!=d_arg_const.end() ){ + std::map< int, Node >::iterator itn = itt->second.find( i ); + if( itn!=itt->second.end() ){ + return itn->second; + } + } + return Node::null(); +} + +Kind TermDbSygus::getArgKind( TypeNode tn, int i ) { + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< int, Kind > >::iterator itt = d_arg_kind.find( tn ); + if( itt!=d_arg_kind.end() ){ + std::map< int, Kind >::iterator itk = itt->second.find( i ); + if( itk!=itt->second.end() ){ + return itk->second; + } + } + return UNDEFINED_KIND; +} + +bool TermDbSygus::isKindArg( TypeNode tn, int i ) { + return getArgKind( tn, i )!=UNDEFINED_KIND; +} + +bool TermDbSygus::isConstArg( TypeNode tn, int i ) { + Assert( isRegistered( tn ) ); + std::map< TypeNode, std::map< int, Node > >::iterator itt = d_arg_const.find( tn ); + if( itt!=d_arg_const.end() ){ + return itt->second.find( i )!=itt->second.end(); + }else{ + return false; + } +} + +TypeNode TermDbSygus::getArgType( const DatatypeConstructor& c, int i ) { + Assert( i>=0 && i<(int)c.getNumArgs() ); + return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() ); +} |