From 7032064336dc62f5b3fc660f444f96b604b4a6bd Mon Sep 17 00:00:00 2001 From: ajreynol Date: Thu, 7 Jun 2018 15:39:37 -0500 Subject: ITE bitvector rewrites. --- src/theory/quantifiers/extended_rewrite.cpp | 181 +++++++++++++++++++++++++++- src/theory/quantifiers/extended_rewrite.h | 17 +++ 2 files changed, 197 insertions(+), 1 deletion(-) diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 1ea1aeba3..abaf608cc 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -1673,7 +1673,56 @@ Node ExtendedRewriter::extendedRewriteBv(Node ret) } else if (k == ITE ) { - + if( ret[0].getKind()==EQUAL && ret[0][0].getType().isBitVector() ) + { + for( unsigned i=0; i<2; i++ ) + { + Node ct = ret[0][i]; + Node cto = ret[0][1-i]; + if( ct.isConst() && bv::utils::getSize(ct)==1 ) + { + // do they differ by exactly one bit? + std::vector< Node > rcc; + int diff_index = spliceBvConstBit(ret[1],ret[2],rcc); + if( diff_index>=0 ) + { + Node rpl = rcc[diff_index]==ct ? cto : TermUtil::mkNegate(BITVECTOR_NOT,cto); + rcc[diff_index] = rpl; + new_ret = rcc.size()==1 ? rcc[0] : nm->mkNode( BITVECTOR_CONCAT, rcc ); + debugExtendedRewrite(ret, new_ret, "BV 1bit ITE"); + } + } + else if( ct.getKind()==BITVECTOR_EXTRACT ) + { + Node cte = ct[0]; + if( cte==ret[1] || cte==ret[2] ) + { + // get the other branch + Node ob = ret[cte==ret[1] ? 2 : 1]; + // get the extension of the extract + std::vector< Node > exs; + exs.push_back(ct); + Node ext = extendBv(cte,exs); + Assert(ext.getType()==ob.getType()); + // now, splice the other branch + std::vector< Node > extc; + std::vector< Node > obc; + spliceBv(ext,ob,extc,obc); + if( obc.size()==2 && ( cto==obc[0] || cto==obc[1] ) ) + { + unsigned cflip_index = cto==obc[0] ? 1 : 0; + if( obc[cflip_index].isConst() && bv::utils::getSize(obc[cflip_index])==1 ) + { + obc[cflip_index] = TermUtil::mkNegate(BITVECTOR_NOT,obc[cflip_index]); + Node new_eq = cte.eqNode( nm->mkNode( BITVECTOR_CONCAT, obc ) ); + new_ret = nm->mkNode( ITE, new_eq, ret[1], ret[2] ); + debugExtendedRewrite(ret, new_ret, "BV 1bit exrem ITE"); + } + } + } + } + } + } } else if (k == BITVECTOR_AND || k == BITVECTOR_OR) { @@ -3100,6 +3149,136 @@ void ExtendedRewriter::spliceBv(Node a, } } + +int ExtendedRewriter::spliceBvConstBit(Node n1, + Node n2, + std::vector& nv) +{ + if( n1==n2 ) + { + return -1; + } + Trace("q-ext-rewrite-debug") << "Splice constant bv bit " << n1 << " " << n2 << std::endl; + // splice the children + std::vector< Node > rc1; + std::vector< Node > rc2; + spliceBv(n1,n2,rc1,rc2); + Assert( rc1.size()==rc2.size() ); + int diff_index = -1; + for( unsigned r=0; r=0 ) + { + // differ at more than one index + Trace("q-ext-rewrite-debug") << "...more than one diff component." << std::endl; + return -1; + } + diff_index = r; + } + } + Assert( diff_index>=0 ); + if( !rc1[diff_index].isConst() || !rc2[diff_index].isConst() ) + { + Trace("q-ext-rewrite-debug") << "...non-constant diff components." << std::endl; + return -1; + } + // insert prefix + if( diff_index>0 ) + { + nv.insert(nv.end(),rc1.begin(),rc1.begin()+diff_index); + } + Assert( rc1[diff_index]!=rc2[diff_index] ); + Node c1 = rc1[diff_index]; + Node c2 = rc2[diff_index]; + // do they differ by exactly one bit? + int bit_diff_index = -1; + unsigned csize = bv::utils::getSize(c1); + for( unsigned i=0; i=0 ) + { + // differ by more than one bit + nv.clear(); + Trace("q-ext-rewrite-debug") << "...more than one bit diff." << std::endl; + return -1; + } + bit_diff_index = i; + } + } + if( bit_diff_index>=0 ) + { + std::vector< Node > split; + if( bit_diff_index+1(csize) ) + { + Node extract = bv::utils::mkExtract(c1, csize-1, bit_diff_index+1); + nv.push_back(Rewriter::rewrite(extract)); + } + Node bit = bv::utils::getBit(c1,bit_diff_index) ? bv::utils::mkOnes(1) : bv::utils::mkZero(1); + diff_index = nv.size(); + nv.push_back(bit); + // remainder + if( bit_diff_index>0 ) + { + Node extract = bv::utils::mkExtract(c1, bit_diff_index-1,0); + nv.push_back(Rewriter::rewrite(extract)); + } + // insert suffix + if( diff_index(rc1.size()) ) + { + nv.insert(nv.end(),rc1.begin()+diff_index+1,rc1.end()); + } + return diff_index; + } + return -1; +} + +Node ExtendedRewriter::extendBv(Node n, std::vector< Node >& exs) +{ + std::map< unsigned, Node > ex_map; + for( const Node& e : exs ) + { + ex_map[bv::utils::getExtractHigh(e)] = e; + } + return extendBv(n,ex_map); +} + +Node ExtendedRewriter::extendBv(Node n, std::map< unsigned, Node >& ex_map) +{ + Trace("q-ext-rewrite-debug") << "extendBv " << n << std::endl; + std::vector< Node > children; + int counter = bv::utils::getSize(n) - 1; + for( const std::pair< const unsigned, Node >& ep : ex_map ) + { + Trace("q-ext-rewrite-debug") << " process " << ep.first << " : " << ep.second << ", counter=" << counter << std::endl; + unsigned start = ep.first; + Assert( static_cast(start)<=counter ); + if( static_cast(start)=0 ); + } + if( counter>=0 ) + { + children.push_back( bv::utils::mkExtract(n, counter, 0) ); + } + Trace("q-ext-rewrite-debug") << "extendBv finish, children = " << children << std::endl; + if( children.empty() ) + { + return n; + } + return children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( BITVECTOR_CONCAT, children ); +} + void ExtendedRewriter::debugExtendedRewrite(Node n, Node ret, const char* c) const diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h index 450c3d3bb..27d030eaf 100644 --- a/src/theory/quantifiers/extended_rewrite.h +++ b/src/theory/quantifiers/extended_rewrite.h @@ -301,6 +301,23 @@ class ExtendedRewriter Node n2, std::vector& n1v, std::vector& n2v); + /** splice bv to constant bit + * + * If the return value of this method is a non-negative value i, it adds k + * terms to nv such that: + * n1 is equivalent to nv[0] ++ ... ++ nv[i] ++ ... ++ nv[k-1], + * n2 is equivalent to nv[0] ++ ... ++ (~)nv[i] ++ ... ++ nv[k-1], and + * nv[i] is a constant of bit-width one. + */ + int spliceBvConstBit(Node n1, + Node n2, + std::vector& nv); + /** extend + * + * This returns the concatentation node of the form + */ + Node extendBv(Node n, std::map< unsigned, Node >& ex_map); + Node extendBv(Node n, std::vector< Node >& exs); //--------------------------------------end bit-vectors }; -- cgit v1.2.3