diff options
author | ajreynol <andrew.j.reynolds@gmail.com> | 2018-06-05 11:10:18 -0500 |
---|---|---|
committer | ajreynol <andrew.j.reynolds@gmail.com> | 2018-06-05 11:10:18 -0500 |
commit | 25236c1f36fc1615b8a0ac55a694242d0e6bf607 (patch) | |
tree | f64b06bca352c682ffd98c92b5e2ebb3a3a1391c | |
parent | 67249887644dcba14fa07be341a3dcd853774933 (diff) |
Enable factoring, unit resolution.
-rw-r--r-- | src/theory/quantifiers/extended_rewrite.cpp | 101 |
1 files changed, 83 insertions, 18 deletions
diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index dae0b04f6..643aae71f 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -459,14 +459,12 @@ Node ExtendedRewriter::extendedRewriteAndOr(Node n) return new_ret; } // factoring - /* new_ret = extendedRewriteFactoring(AND,OR,NOT,n); if(!new_ret.isNull()) { debugExtendedRewrite(n, new_ret, "Bool factoring"); return new_ret; } - */ // equality resolution new_ret = @@ -800,11 +798,15 @@ Node ExtendedRewriter::extendedRewriteBcp( Node ExtendedRewriter::extendedRewriteFactoring( Kind andk, Kind ork, Kind notk, Node n) { + Trace("ext-rew-factoring") << "Factoring: *** INPUT: " << n << std::endl; + NodeManager* nm = NodeManager::currentNM(); + Kind nk = n.getKind(); Assert( nk==andk || nk==ork ); Kind onk = nk==andk ? ork : andk; // count the number of times atoms occur std::map< Node, std::vector< Node > > lit_to_cl; + std::map< Node, std::vector< Node > > cl_to_lits; for( const Node& nc : n ) { Kind nck = nc.getKind(); @@ -813,27 +815,65 @@ Node ExtendedRewriter::extendedRewriteFactoring( for( const Node& ncl : nc ) { lit_to_cl[ncl].push_back(nc); + cl_to_lits[nc].push_back(ncl); } } else { lit_to_cl[nc].push_back(nc); + cl_to_lits[nc].push_back(nc); } } - // get the maximum shared literal + // get the maximum shared literal to factor unsigned max_size = 0; - Node max_lit; + Node flit; for( const std::pair< const Node, std::vector< Node > >& ltc : lit_to_cl ) { if( ltc.second.size()>max_size ) { max_size = ltc.second.size(); - max_lit = ltc.first; + flit = ltc.first; } } if( max_size>1 ) { - + // do the factoring + std::vector< Node > children; + std::vector< Node > fchildren; + std::map< Node, std::vector< Node > >::iterator itl = lit_to_cl.find(flit); + std::vector< Node >& cls = itl->second; + for( const Node& nc : n ) + { + if( std::find( cls.begin(), cls.end(), nc )==cls.end() ) + { + children.push_back( nc ); + } + else + { + // rebuild + std::vector< Node >& lits = cl_to_lits[nc]; + std::vector< Node >::iterator itlfl = std::find( lits.begin(), lits.end(), flit ); + Assert( itlfl!=lits.end() ); + lits.erase( itlfl ); + // rebuild + if( !lits.empty() ) + { + Node new_cl = lits.size()==1 ? lits[0] : nm->mkNode( onk, lits ); + fchildren.push_back(new_cl); + } + } + } + // rebuild the factored children + Assert( !fchildren.empty() ); + Node fcn = fchildren.size()==1 ? fchildren[0] : nm->mkNode(nk,fchildren); + children.push_back(nm->mkNode(onk,flit,fcn)); + Node ret = children.size()==1 ? children[0] : nm->mkNode(nk,children); + Trace("ext-rew-factoring") << "Factoring: *** OUTPUT: " << ret << std::endl; + return ret; + } + else + { + Trace("ext-rew-factoring") << "Factoring: no change" << std::endl; } return Node::null(); } @@ -936,33 +976,32 @@ class SimpSubsumeTrie public: std::map< Node, SimpSubsumeTrie > d_children; Node d_data; - Node addTerm( Node c, std::map< Node, bool >& atoms, std::vector< Node >& alist, unsigned index=0, bool doAdd = true ) + void addTerm( Node c, std::vector< Node >& alist, std::vector< Node >& subsumes, unsigned index=0, bool doAdd = true ) { if( !d_data.isNull() ) { - return d_data; + subsumes.push_back(d_data); } if( doAdd ) { if( index==alist.size() ) { d_data = c; - return c; + return; } } // try all children where we have this atom for( std::pair<const Node, SimpSubsumeTrie >& cp : d_children ) { - if( atoms.find(cp.first)!=atoms.end() ) + if(std::find(alist.begin(),alist.end(),cp.first)!=alist.end() ) { - Node cc = cp.second.addTerm(c,atoms,alist,0,false); + cp.second.addTerm(c,alist,subsumes,0,false); } } if( doAdd ) { - return d_children[alist[index]].addTerm(c,atoms,alist,index+1,doAdd); + d_children[alist[index]].addTerm(c,alist,subsumes,index+1,doAdd); } - return Node::null(); } }; @@ -1072,10 +1111,22 @@ Node ExtendedRewriter::extendedRewriteEqChain( SimpSubsumeTrie sst; for (std::pair<const Node, bool>& cp : cstatus) { + if( !cp.second ) + { + // already eliminated + continue; + } Node c = cp.first; - Node cc = sst.addTerm(c,atoms[c],alist[c]); - if( cc!=c ) + Trace("ext-rew-eqchain") << " - add term " << c << " with atom list " << alist[c] << "...\n"; + std::vector< Node > subsumes; + sst.addTerm(c,alist[c], subsumes); + for( const Node& cc : subsumes ) { + if( !cstatus[cc] ) + { + // subsumes a child that was already eliminated + continue; + } Trace("ext-rew-eqchain") << " eqchain-simplify: " << c << " subsumes " << cc << std::endl; // for each of the atoms in cc std::map< Node, std::map< Node, bool > >::iterator itc = atoms.find(c); @@ -1112,6 +1163,7 @@ Node ExtendedRewriter::extendedRewriteEqChain( rem_children.push_back(polc ? a : TermUtil::mkNegate( notk,a )); } } + Trace("ext-rew-eqchain") << " #common/diff/rem: " << common_children.size() << "/" << diff_children.size() << "/" << rem_children.size() << "\n"; bool do_rewrite = false; if( common_children.empty() && itc->second.size()==itcc->second.size() && itcc->second.size()==2 ) { @@ -1122,6 +1174,16 @@ Node ExtendedRewriter::extendedRewriteEqChain( gpol = !gpol; Trace("ext-rew-eqchain") << " apply 2-child all-diff\n"; } + else if( common_children.empty() && diff_children.size()==1 ) + { + do_rewrite = true; + // x = ( ~x | y ) ---> ~( ~x | ~y ) + Node remn = rem_children.size()==1 ? rem_children[0] : nm->mkNode( ork, rem_children ); + remn = TermUtil::mkNegate( notk, remn ); + children.push_back(nm->mkNode(ork,diff_children[0],remn)); + gpol = !gpol; + Trace("ext-rew-eqchain") << " apply unit resolution\n"; + } else if( diff_children.size()==1 && itc->second.size()==itcc->second.size() ) { // ( x | y | z ) = ( x | ~y | z ) ---> ( x | z ) @@ -1137,8 +1199,9 @@ Node ExtendedRewriter::extendedRewriteEqChain( if( rem_children.empty() ) { // x | y = x | y ---> true + // this can happen if we have ( ~x & ~y ) = ( x | y ) children.push_back(TermUtil::mkTypeMaxValue(tn)); - Trace("ext-rew-eqchain") << " apply deMorgan\n"; + Trace("ext-rew-eqchain") << " apply cancel\n"; } else { @@ -1161,14 +1224,16 @@ Node ExtendedRewriter::extendedRewriteEqChain( gpol = !gpol; } } + // + break; } } } Trace("ext-rew-eqchain") << "eqchain-simplify: finish" << std::endl; - - // cancel AND/OR children if possible + // this catches cases like ( x & y ) = ( ( x & y ) | z ), where ( x & y ) is + // considered a unit, whereas above it is expanded to ~( ~x | ~y ). for (std::pair<const Node, bool>& cp : cstatus) { if (cp.second) |