From ce6d8fde786eb6b4bb658ba83afd384d02853948 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 2 Jan 2018 16:12:45 -0600 Subject: Rewrites for BitVector multiplication (#1465) --- .../bv/theory_bv_rewrite_rules_normalization.h | 65 ++++++----- .../bv/theory_bv_rewrite_rules_simplification.h | 125 +++++++++++++++------ src/theory/bv/theory_bv_utils.h | 23 +++- 3 files changed, 146 insertions(+), 67 deletions(-) (limited to 'src/theory/bv') diff --git a/src/theory/bv/theory_bv_rewrite_rules_normalization.h b/src/theory/bv/theory_bv_rewrite_rules_normalization.h index 61f072643..3ad733f99 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_normalization.h +++ b/src/theory/bv/theory_bv_rewrite_rules_normalization.h @@ -381,25 +381,7 @@ Node RewriteRule::apply(TNode node) { template<> inline bool RewriteRule::applies(TNode node) { - if (node.getKind() != kind::BITVECTOR_MULT) { - return false; - } - TNode::iterator child_it = node.begin(); - TNode::iterator child_next = child_it + 1; - for(; child_next != node.end(); ++child_it, ++child_next) { - if ((*child_it).isConst() || - !((*child_it) < (*child_next))) { - return true; - } - } - if ((*child_it).isConst()) { - BitVector bv = (*child_it).getConst(); - if (bv == BitVector(utils::getSize(node), (unsigned) 0) || - bv == BitVector(utils::getSize(node), (unsigned) 1)) { - return true; - } - } - return false; + return node.getKind() == kind::BITVECTOR_MULT; } template<> inline @@ -408,31 +390,58 @@ Node RewriteRule::apply(TNode node) { unsigned size = utils::getSize(node); BitVector constant(size, Integer(1)); - std::vector children; - for(unsigned i = 0; i < node.getNumChildren(); ++i) { - TNode current = node[i]; + bool isNeg = false; + std::vector children; + for (const TNode& current : node) + { if (current.getKind() == kind::CONST_BITVECTOR) { BitVector value = current.getConst(); constant = constant * value; if(constant == BitVector(size, (unsigned) 0)) { return utils::mkConst(size, 0); } + } + else if (current.getKind() == kind::BITVECTOR_NEG) + { + isNeg = !isNeg; + children.push_back(current[0]); } else { children.push_back(current); } } + BitVector oValue = BitVector(size, static_cast(1)); + BitVector noValue = utils::mkBitVectorOnes(size); + + if (children.empty()) + { + Assert(!isNeg); + return utils::mkConst(constant); + } std::sort(children.begin(), children.end()); - if(constant != BitVector(size, (unsigned)1)) { - children.push_back(utils::mkConst(constant)); + if (constant == noValue) + { + isNeg = !isNeg; } - - if(children.size() == 0) { - return utils::mkConst(size, (unsigned)1); + else if (constant != oValue) + { + if (isNeg) + { + isNeg = !isNeg; + constant = -constant; + } + children.push_back(utils::mkConst(constant)); } - return utils::mkNode(kind::BITVECTOR_MULT, children); + Node ret = utils::mkNode(kind::BITVECTOR_MULT, children); + + // if negative, negate entire node + if (isNeg && size > 1) + { + ret = utils::mkNode(kind::BITVECTOR_NEG, ret); + } + return ret; } diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index 98a311890..9d44d3be5 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -751,38 +751,56 @@ Node RewriteRule::apply(TNode node) { * (a * 2^k) ==> a[n-k-1:0] 0_k */ -template<> inline -bool RewriteRule::applies(TNode node) { +template <> +inline bool RewriteRule::applies(TNode node) +{ if (node.getKind() != kind::BITVECTOR_MULT) return false; - for(unsigned i = 0; i < node.getNumChildren(); ++i) { - if (utils::isPow2Const(node[i])) { + for (const Node& cn : node) + { + bool cIsNeg = false; + if (utils::isPow2Const(cn, cIsNeg)) + { return true; } } return false; } -template<> inline -Node RewriteRule::apply(TNode node) { +template <> +inline Node RewriteRule::apply(TNode node) +{ Debug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned size = utils::getSize(node); std::vector children; - unsigned exponent = 0; - for(unsigned i = 0; i < node.getNumChildren(); ++i) { - unsigned exp = utils::isPow2Const(node[i]); + unsigned exponent = 0; + bool isNeg = false; + for (const Node& cn : node) + { + bool cIsNeg = false; + unsigned exp = utils::isPow2Const(cn, cIsNeg); if (exp) { exponent += exp - 1; + if (cIsNeg) + { + isNeg = !isNeg; + } } else { - children.push_back(node[i]); + children.push_back(cn); } } - Node a = utils::mkNode(kind::BITVECTOR_MULT, children); + Node a = utils::mkNode(kind::BITVECTOR_MULT, children); - Node extract = utils::mkExtract(a, utils::getSize(node) - exponent - 1, 0); + if (isNeg && size > 1) + { + a = utils::mkNode(kind::BITVECTOR_NEG, a); + } + + Node extract = utils::mkExtract(a, size - exponent - 1, 0); Node zeros = utils::mkConst(exponent, 0); return utils::mkConcat(extract, zeros); } @@ -888,24 +906,43 @@ Node RewriteRule::apply(TNode node) { * (a udiv 2^k) ==> 0_k a[n-1: k] */ -template<> inline -bool RewriteRule::applies(TNode node) { - return (node.getKind() == kind::BITVECTOR_UDIV_TOTAL && - utils::isPow2Const(node[1])); +template <> +inline bool RewriteRule::applies(TNode node) +{ + bool isNeg = false; + if (node.getKind() == kind::BITVECTOR_UDIV_TOTAL + && utils::isPow2Const(node[1], isNeg)) + { + return !isNeg; + } + return false; } -template<> inline -Node RewriteRule::apply(TNode node) { +template <> +inline Node RewriteRule::apply(TNode node) +{ Debug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; + unsigned size = utils::getSize(node); Node a = node[0]; - unsigned power = utils::isPow2Const(node[1]) -1; - if (power == 0) { - return a; + bool isNeg = false; + unsigned power = utils::isPow2Const(node[1], isNeg) - 1; + Node ret; + if (power == 0) + { + ret = a; } - Node extract = utils::mkExtract(a, utils::getSize(node) - 1, power); - Node zeros = utils::mkConst(power, 0); - - return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); + else + { + Node extract = utils::mkExtract(a, size - 1, power); + Node zeros = utils::mkConst(power, 0); + + ret = utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); + } + if (isNeg && size > 1) + { + ret = utils::mkNode(kind::BITVECTOR_NEG, ret); + } + return ret; } /** @@ -950,23 +987,37 @@ inline Node RewriteRule::apply(TNode node) { * (a urem 2^k) ==> 0_(n-k) a[k-1:0] */ -template<> inline -bool RewriteRule::applies(TNode node) { - return (node.getKind() == kind::BITVECTOR_UREM_TOTAL && - utils::isPow2Const(node[1])); +template <> +inline bool RewriteRule::applies(TNode node) +{ + bool isNeg; + if (node.getKind() == kind::BITVECTOR_UREM_TOTAL + && utils::isPow2Const(node[1], isNeg)) + { + return !isNeg; + } + return false; } -template<> inline -Node RewriteRule::apply(TNode node) { +template <> +inline Node RewriteRule::apply(TNode node) +{ Debug("bv-rewrite") << "RewriteRule(" << node << ")" << std::endl; TNode a = node[0]; - unsigned power = utils::isPow2Const(node[1]) - 1; - if (power == 0) { - return utils::mkConst(utils::getSize(node), 0); + bool isNeg = false; + unsigned power = utils::isPow2Const(node[1], isNeg) - 1; + Node ret; + if (power == 0) + { + ret = utils::mkZero(utils::getSize(node)); + } + else + { + Node extract = utils::mkExtract(a, power - 1, 0); + Node zeros = utils::mkZero(utils::getSize(node) - power); + ret = utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); } - Node extract = utils::mkExtract(a, power - 1, 0); - Node zeros = utils::mkConst(utils::getSize(node) - power, 0); - return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); + return ret; } /** diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index d9d1183af..ed772b7c4 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -272,13 +272,32 @@ inline Node mkConjunction(const std::set nodes) { return conjunction; } -inline unsigned isPow2Const(TNode node) { +/** + * If node is a constant of the form 2^c or -2^c, then this function returns + * c+1. Otherwise, this function returns 0. The flag isNeg is updated to + * indicate whether node is negative. + */ +inline unsigned isPow2Const(TNode node, bool& isNeg) +{ if (node.getKind() != kind::CONST_BITVECTOR) { return false; } BitVector bv = node.getConst(); - return bv.isPow2(); + unsigned p = bv.isPow2(); + if (p != 0) + { + isNeg = false; + return p; + } + BitVector nbv = -bv; + p = nbv.isPow2(); + if (p != 0) + { + isNeg = true; + return p; + } + return false; } inline Node mkOr(const std::vector& nodes) { -- cgit v1.2.3