summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMathias Preiner <mathias.preiner@gmail.com>2017-10-20 21:03:04 -0700
committerGitHub <noreply@github.com>2017-10-20 21:03:04 -0700
commit6b5c27d7f634eb5985ce455989fcda36e1261929 (patch)
treeb80be5a7c5099f4517b912850f1e91a72117d18e
parent7908fd9c901c056628f5f3846049d078d48bc396 (diff)
Add rewriting rules for Eq/Ult with sign_extend and constants. (#1258)
-rw-r--r--src/theory/bv/theory_bv.cpp4
-rw-r--r--src/theory/bv/theory_bv_rewrite_rules.h12
-rw-r--r--src/theory/bv/theory_bv_rewrite_rules_simplification.h196
-rw-r--r--src/theory/bv/theory_bv_rewriter.cpp8
4 files changed, 216 insertions, 4 deletions
diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp
index 116903ea6..e03cecdd9 100644
--- a/src/theory/bv/theory_bv.cpp
+++ b/src/theory/bv/theory_bv.cpp
@@ -798,6 +798,10 @@ Node TheoryBV::ppRewrite(TNode t)
} else {
res = t;
}
+ } else if (RewriteRule<SignExtendEqConst>::applies(t)) {
+ res = RewriteRule<SignExtendEqConst>::run<false>(t);
+ } else if (RewriteRule<ZeroExtendEqConst>::applies(t)) {
+ res = RewriteRule<ZeroExtendEqConst>::run<false>(t);
}
diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h
index 9f148d823..4d0f8033e 100644
--- a/src/theory/bv/theory_bv_rewrite_rules.h
+++ b/src/theory/bv/theory_bv_rewrite_rules.h
@@ -144,6 +144,10 @@ enum RewriteRuleId {
SltZero,
ZeroUlt,
MergeSignExtend,
+ SignExtendEqConst,
+ ZeroExtendEqConst,
+ SignExtendUltConst,
+ ZeroExtendUltConst,
/// normalization rules
ExtractBitwise,
@@ -303,6 +307,10 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
case SltZero : out << "SltZero"; return out;
case ZeroUlt : out << "ZeroUlt"; return out;
case MergeSignExtend : out << "MergeSignExtend"; return out;
+ case SignExtendEqConst: out << "SignExtendEqConst"; return out;
+ case ZeroExtendEqConst: out << "ZeroExtendEqConst"; return out;
+ case SignExtendUltConst: out << "SignExtendUltConst"; return out;
+ case ZeroExtendUltConst: out << "ZeroExtendUltConst"; return out;
case UleEliminate : out << "UleEliminate"; return out;
case BitwiseSlicing : out << "BitwiseSlicing"; return out;
@@ -533,6 +541,10 @@ struct AllRewriteRules {
RewriteRule<IsPowerOfTwo> rule121;
RewriteRule<RedorEliminate> rule122;
RewriteRule<RedandEliminate> rule123;
+ RewriteRule<SignExtendEqConst> rule124;
+ RewriteRule<ZeroExtendEqConst> rule125;
+ RewriteRule<SignExtendUltConst> rule126;
+ RewriteRule<ZeroExtendUltConst> rule127;
};
template<> inline
diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h
index c7247f260..24e5fb5e8 100644
--- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h
+++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h
@@ -21,6 +21,7 @@
#include "theory/bv/theory_bv_rewrite_rules.h"
#include "theory/bv/theory_bv_utils.h"
+#include "theory/rewriter.h"
namespace CVC4 {
namespace theory {
@@ -1105,6 +1106,201 @@ Node RewriteRule<MergeSignExtend>::apply(TNode node) {
return res;
}
+/**
+ * ZeroExtendEqConst
+ *
+ * Rewrite zero_extend(x^n, m) = c^n+m to
+ *
+ * false if c[n+m-1:n] != 0
+ * x = c[n-1:0] otherwise.
+ */
+template <>
+inline bool RewriteRule<ZeroExtendEqConst>::applies(TNode node) {
+ return node.getKind() == kind::EQUAL &&
+ ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+ node[1].isConst()) ||
+ (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+ node[0].isConst()));
+}
+
+template <>
+inline Node RewriteRule<ZeroExtendEqConst>::apply(TNode node) {
+ TNode t, c;
+ if (node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ BitVector c_hi =
+ c.getConst<BitVector>().extract(utils::getSize(c) - 1, utils::getSize(t));
+ BitVector c_lo = c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0);
+ BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+ if (c_hi == zero) {
+ return NodeManager::currentNM()->mkNode(kind::EQUAL, t,
+ utils::mkConst(c_lo));
+ }
+ return utils::mkFalse();
+}
+
+/**
+ * SignExtendEqConst
+ *
+ * Rewrite sign_extend(x^n, m) = c^n+m to
+ *
+ * x = c[n-1:0] if (c[n-1:n-1] == 0 && c[n+m-1:n] == 0) ||
+ * (c[n-1:n-1] == 1 && c[n+m-1:n] == ~0)
+ * false otherwise.
+ */
+template <>
+inline bool RewriteRule<SignExtendEqConst>::applies(TNode node) {
+ return node.getKind() == kind::EQUAL &&
+ ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+ node[1].isConst()) ||
+ (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+ node[0].isConst()));
+}
+
+template <>
+inline Node RewriteRule<SignExtendEqConst>::apply(TNode node) {
+ TNode t, c;
+ if (node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ unsigned pos_msb_t = utils::getSize(t) - 1;
+ BitVector c_hi =
+ c.getConst<BitVector>().extract(utils::getSize(c) - 1, pos_msb_t);
+ BitVector c_lo = c.getConst<BitVector>().extract(pos_msb_t, 0);
+ BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+ if (c_hi == zero || c_hi == ~zero) {
+ return NodeManager::currentNM()->mkNode(kind::EQUAL, t,
+ utils::mkConst(c_lo));
+ }
+ return utils::mkFalse();
+}
+
+/**
+ * ZeroExtendUltConst
+ *
+ * Rewrite zero_extend(x^n,m) < c^n+m to
+ *
+ * x < c[n-1:0] if c[n+m-1:n] == 0.
+ *
+ * Rewrite c^n+m < Rewrite zero_extend(x^n,m) to
+ *
+ * c[n-1:0] < x if c[n+m-1:n] == 0.
+ */
+template <>
+inline bool RewriteRule<ZeroExtendUltConst>::applies(TNode node) {
+ if (node.getKind() == kind::BITVECTOR_ULT &&
+ ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+ node[1].isConst()) ||
+ (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+ node[0].isConst()))) {
+ TNode t, c;
+ bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND;
+ if (is_lhs) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ BitVector bv_c = c.getConst<BitVector>();
+ BitVector bv_max =
+ BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1);
+
+ BitVector c_hi = c.getConst<BitVector>().extract(utils::getSize(c) - 1,
+ utils::getSize(t));
+ BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+ return c_hi == zero;
+ }
+ return false;
+}
+
+template <>
+inline Node RewriteRule<ZeroExtendUltConst>::apply(TNode node) {
+ TNode t, c;
+ bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND;
+ if (is_lhs) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ Node c_lo =
+ utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0));
+
+ if (is_lhs) {
+ return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo);
+ }
+ return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t);
+}
+
+/**
+ * SignExtendUltConst
+ *
+ * Rewrite sign_extend(x^n,m) < c^n+m to
+ *
+ * x < c[n-1:0] if c <= (1 << (n - 1)).
+ *
+ * Rewrite c^n+m < sign_extend(x^n,m) to
+ *
+ * c[n-1:0] < x if c < (1 << (n - 1)).
+ */
+template <>
+inline bool RewriteRule<SignExtendUltConst>::applies(TNode node) {
+ if (node.getKind() == kind::BITVECTOR_ULT &&
+ ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+ node[1].isConst()) ||
+ (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+ node[0].isConst()))) {
+ TNode t, c;
+ bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
+ if (is_lhs) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ BitVector bv_c = c.getConst<BitVector>();
+ BitVector bv_max =
+ BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1);
+
+ return (is_lhs && bv_c <= bv_max) || (!is_lhs && bv_c < bv_max);
+ }
+ return false;
+}
+
+template <>
+inline Node RewriteRule<SignExtendUltConst>::apply(TNode node) {
+ TNode t, c;
+ bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
+ if (is_lhs) {
+ t = node[0][0];
+ c = node[1];
+ } else {
+ t = node[1][0];
+ c = node[0];
+ }
+ Node c_lo =
+ utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0));
+
+ if (is_lhs) {
+ return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo);
+ }
+ return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t);
+}
template<> inline
bool RewriteRule<MultSlice>::applies(TNode node) {
diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp
index f3185bc13..bfaf517cc 100644
--- a/src/theory/bv/theory_bv_rewriter.cpp
+++ b/src/theory/bv/theory_bv_rewriter.cpp
@@ -68,10 +68,10 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) {
// reduce common subexpressions on both sides
Node resultNode = LinearRewriteStrategy
- < RewriteRule<EvalUlt>,
- // if both arguments are constants evaluates
- RewriteRule<UltZero>
- // a < 0 rewrites to false
+ < RewriteRule<EvalUlt>, // if both arguments are constants evaluates
+ RewriteRule<UltZero>, // a < 0 rewrites to false,
+ RewriteRule<SignExtendUltConst>,
+ RewriteRule<ZeroExtendUltConst>
>::apply(node);
return RewriteResponse(REWRITE_DONE, resultNode);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback