diff options
Diffstat (limited to 'src/theory')
-rw-r--r-- | src/theory/bv/theory_bv.cpp | 4 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_rewrite_rules.h | 12 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_rewrite_rules_simplification.h | 196 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_rewriter.cpp | 8 |
4 files changed, 216 insertions, 4 deletions
diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 116903ea6..e03cecdd9 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -798,6 +798,10 @@ Node TheoryBV::ppRewrite(TNode t) } else { res = t; } + } else if (RewriteRule<SignExtendEqConst>::applies(t)) { + res = RewriteRule<SignExtendEqConst>::run<false>(t); + } else if (RewriteRule<ZeroExtendEqConst>::applies(t)) { + res = RewriteRule<ZeroExtendEqConst>::run<false>(t); } diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index 9f148d823..4d0f8033e 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -144,6 +144,10 @@ enum RewriteRuleId { SltZero, ZeroUlt, MergeSignExtend, + SignExtendEqConst, + ZeroExtendEqConst, + SignExtendUltConst, + ZeroExtendUltConst, /// normalization rules ExtractBitwise, @@ -303,6 +307,10 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case SltZero : out << "SltZero"; return out; case ZeroUlt : out << "ZeroUlt"; return out; case MergeSignExtend : out << "MergeSignExtend"; return out; + case SignExtendEqConst: out << "SignExtendEqConst"; return out; + case ZeroExtendEqConst: out << "ZeroExtendEqConst"; return out; + case SignExtendUltConst: out << "SignExtendUltConst"; return out; + case ZeroExtendUltConst: out << "ZeroExtendUltConst"; return out; case UleEliminate : out << "UleEliminate"; return out; case BitwiseSlicing : out << "BitwiseSlicing"; return out; @@ -533,6 +541,10 @@ struct AllRewriteRules { RewriteRule<IsPowerOfTwo> rule121; RewriteRule<RedorEliminate> rule122; RewriteRule<RedandEliminate> rule123; + RewriteRule<SignExtendEqConst> rule124; + RewriteRule<ZeroExtendEqConst> rule125; + RewriteRule<SignExtendUltConst> rule126; + RewriteRule<ZeroExtendUltConst> rule127; }; template<> inline diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index c7247f260..24e5fb5e8 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -21,6 +21,7 @@ #include "theory/bv/theory_bv_rewrite_rules.h" #include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" namespace CVC4 { namespace theory { @@ -1105,6 +1106,201 @@ Node RewriteRule<MergeSignExtend>::apply(TNode node) { return res; } +/** + * ZeroExtendEqConst + * + * Rewrite zero_extend(x^n, m) = c^n+m to + * + * false if c[n+m-1:n] != 0 + * x = c[n-1:0] otherwise. + */ +template <> +inline bool RewriteRule<ZeroExtendEqConst>::applies(TNode node) { + return node.getKind() == kind::EQUAL && + ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[0].isConst())); +} + +template <> +inline Node RewriteRule<ZeroExtendEqConst>::apply(TNode node) { + TNode t, c; + if (node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector c_hi = + c.getConst<BitVector>().extract(utils::getSize(c) - 1, utils::getSize(t)); + BitVector c_lo = c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + if (c_hi == zero) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, t, + utils::mkConst(c_lo)); + } + return utils::mkFalse(); +} + +/** + * SignExtendEqConst + * + * Rewrite sign_extend(x^n, m) = c^n+m to + * + * x = c[n-1:0] if (c[n-1:n-1] == 0 && c[n+m-1:n] == 0) || + * (c[n-1:n-1] == 1 && c[n+m-1:n] == ~0) + * false otherwise. + */ +template <> +inline bool RewriteRule<SignExtendEqConst>::applies(TNode node) { + return node.getKind() == kind::EQUAL && + ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[0].isConst())); +} + +template <> +inline Node RewriteRule<SignExtendEqConst>::apply(TNode node) { + TNode t, c; + if (node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + unsigned pos_msb_t = utils::getSize(t) - 1; + BitVector c_hi = + c.getConst<BitVector>().extract(utils::getSize(c) - 1, pos_msb_t); + BitVector c_lo = c.getConst<BitVector>().extract(pos_msb_t, 0); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + if (c_hi == zero || c_hi == ~zero) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, t, + utils::mkConst(c_lo)); + } + return utils::mkFalse(); +} + +/** + * ZeroExtendUltConst + * + * Rewrite zero_extend(x^n,m) < c^n+m to + * + * x < c[n-1:0] if c[n+m-1:n] == 0. + * + * Rewrite c^n+m < Rewrite zero_extend(x^n,m) to + * + * c[n-1:0] < x if c[n+m-1:n] == 0. + */ +template <> +inline bool RewriteRule<ZeroExtendUltConst>::applies(TNode node) { + if (node.getKind() == kind::BITVECTOR_ULT && + ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[0].isConst()))) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector bv_c = c.getConst<BitVector>(); + BitVector bv_max = + BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1); + + BitVector c_hi = c.getConst<BitVector>().extract(utils::getSize(c) - 1, + utils::getSize(t)); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + return c_hi == zero; + } + return false; +} + +template <> +inline Node RewriteRule<ZeroExtendUltConst>::apply(TNode node) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + Node c_lo = + utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0)); + + if (is_lhs) { + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo); + } + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t); +} + +/** + * SignExtendUltConst + * + * Rewrite sign_extend(x^n,m) < c^n+m to + * + * x < c[n-1:0] if c <= (1 << (n - 1)). + * + * Rewrite c^n+m < sign_extend(x^n,m) to + * + * c[n-1:0] < x if c < (1 << (n - 1)). + */ +template <> +inline bool RewriteRule<SignExtendUltConst>::applies(TNode node) { + if (node.getKind() == kind::BITVECTOR_ULT && + ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[0].isConst()))) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector bv_c = c.getConst<BitVector>(); + BitVector bv_max = + BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1); + + return (is_lhs && bv_c <= bv_max) || (!is_lhs && bv_c < bv_max); + } + return false; +} + +template <> +inline Node RewriteRule<SignExtendUltConst>::apply(TNode node) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + Node c_lo = + utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0)); + + if (is_lhs) { + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo); + } + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t); +} template<> inline bool RewriteRule<MultSlice>::applies(TNode node) { diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index f3185bc13..bfaf517cc 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -68,10 +68,10 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) { // reduce common subexpressions on both sides Node resultNode = LinearRewriteStrategy - < RewriteRule<EvalUlt>, - // if both arguments are constants evaluates - RewriteRule<UltZero> - // a < 0 rewrites to false + < RewriteRule<EvalUlt>, // if both arguments are constants evaluates + RewriteRule<UltZero>, // a < 0 rewrites to false, + RewriteRule<SignExtendUltConst>, + RewriteRule<ZeroExtendUltConst> >::apply(node); return RewriteResponse(REWRITE_DONE, resultNode); |