summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAina Niemetz <aina.niemetz@gmail.com>2021-04-30 15:06:30 -0700
committerGitHub <noreply@github.com>2021-04-30 22:06:30 +0000
commitdf3ffb33b8173f252b7720d27aa0204e8ff3632e (patch)
treed8dc390f10bbd6a81fd0522c7e05e5908162f6e9
parent67a1510b8e6306993d7efb7671b8f0aa53a45deb (diff)
Add parameter name for argument `isPreRewrite` for FP rewrites. (#6469)
-rw-r--r--src/theory/fp/fp_expand_defs.cpp17
-rw-r--r--src/theory/fp/theory_fp_rewriter.cpp462
2 files changed, 270 insertions, 209 deletions
diff --git a/src/theory/fp/fp_expand_defs.cpp b/src/theory/fp/fp_expand_defs.cpp
index 4e9803bf7..34cc1ed5d 100644
--- a/src/theory/fp/fp_expand_defs.cpp
+++ b/src/theory/fp/fp_expand_defs.cpp
@@ -260,22 +260,23 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
<< "FpExpandDefs::expandDefinition(): " << node << std::endl;
Node res = node;
+ Kind kind = node.getKind();
- if (node.getKind() == kind::FLOATINGPOINT_TO_FP_GENERIC)
+ if (kind == kind::FLOATINGPOINT_TO_FP_GENERIC)
{
res = removeToFPGeneric::removeToFPGeneric(node);
}
- else if (node.getKind() == kind::FLOATINGPOINT_MIN)
+ else if (kind == kind::FLOATINGPOINT_MIN)
{
res = NodeManager::currentNM()->mkNode(
kind::FLOATINGPOINT_MIN_TOTAL, node[0], node[1], minUF(node));
}
- else if (node.getKind() == kind::FLOATINGPOINT_MAX)
+ else if (kind == kind::FLOATINGPOINT_MAX)
{
res = NodeManager::currentNM()->mkNode(
kind::FLOATINGPOINT_MAX_TOTAL, node[0], node[1], maxUF(node));
}
- else if (node.getKind() == kind::FLOATINGPOINT_TO_UBV)
+ else if (kind == kind::FLOATINGPOINT_TO_UBV)
{
FloatingPointToUBV info = node.getOperator().getConst<FloatingPointToUBV>();
FloatingPointToUBVTotal newInfo(info);
@@ -287,7 +288,7 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
node[1],
toUBVUF(node));
}
- else if (node.getKind() == kind::FLOATINGPOINT_TO_SBV)
+ else if (kind == kind::FLOATINGPOINT_TO_SBV)
{
FloatingPointToSBV info = node.getOperator().getConst<FloatingPointToSBV>();
FloatingPointToSBVTotal newInfo(info);
@@ -299,15 +300,11 @@ TrustNode FpExpandDefs::expandDefinition(Node node)
node[1],
toSBVUF(node));
}
- else if (node.getKind() == kind::FLOATINGPOINT_TO_REAL)
+ else if (kind == kind::FLOATINGPOINT_TO_REAL)
{
res = NodeManager::currentNM()->mkNode(
kind::FLOATINGPOINT_TO_REAL_TOTAL, node[0], toRealUF(node));
}
- else
- {
- // Do nothing
- }
if (res != node)
{
diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp
index 07fde6a88..e431ffa09 100644
--- a/src/theory/fp/theory_fp_rewriter.cpp
+++ b/src/theory/fp/theory_fp_rewriter.cpp
@@ -58,21 +58,25 @@ namespace rewrite {
}
}
- RewriteResponse notFP (TNode node, bool) {
+ RewriteResponse notFP(TNode node, bool isPreRewrite)
+ {
Unreachable() << "non floating-point kind (" << node.getKind()
<< ") in floating point rewrite?";
}
- RewriteResponse identity (TNode node, bool) {
+ RewriteResponse identity(TNode node, bool isPreRewrite)
+ {
return RewriteResponse(REWRITE_DONE, node);
}
- RewriteResponse type (TNode node, bool) {
+ RewriteResponse type(TNode node, bool isPreRewrite)
+ {
Unreachable() << "sort kind (" << node.getKind()
<< ") found in expression?";
}
- RewriteResponse removeDoubleNegation (TNode node, bool) {
+ RewriteResponse removeDoubleNegation(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
if (node[0].getKind() == kind::FLOATINGPOINT_NEG) {
return RewriteResponse(REWRITE_AGAIN, node[0][0]);
@@ -81,7 +85,8 @@ namespace rewrite {
return RewriteResponse(REWRITE_DONE, node);
}
- RewriteResponse compactAbs (TNode node, bool) {
+ RewriteResponse compactAbs(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
if (node[0].getKind() == kind::FLOATINGPOINT_NEG
|| node[0].getKind() == kind::FLOATINGPOINT_ABS)
@@ -94,7 +99,8 @@ namespace rewrite {
return RewriteResponse(REWRITE_DONE, node);
}
- RewriteResponse convertSubtractionToAddition (TNode node, bool) {
+ RewriteResponse convertSubtractionToAddition(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_SUB);
Node negation = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_NEG,node[2]);
Node addition = NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_PLUS,node[0],node[1],negation);
@@ -129,7 +135,8 @@ namespace rewrite {
/* Implies (fp.eq x x) --> (not (isNaN x))
*/
- RewriteResponse ieeeEqToEq (TNode node, bool) {
+ RewriteResponse ieeeEqToEq(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_EQ);
NodeManager *nm = NodeManager::currentNM();
@@ -145,24 +152,26 @@ namespace rewrite {
nm->mkNode(kind::FLOATINGPOINT_ISZ, node[1])))));
}
-
- RewriteResponse geqToleq (TNode node, bool) {
+ RewriteResponse geqToleq(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_GEQ);
return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LEQ,node[1],node[0]));
}
- RewriteResponse gtTolt (TNode node, bool) {
+ RewriteResponse gtTolt(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_GT);
return RewriteResponse(REWRITE_DONE,NodeManager::currentNM()->mkNode(kind::FLOATINGPOINT_LT,node[1],node[0]));
}
- RewriteResponse removed(TNode node, bool)
+ RewriteResponse removed(TNode node, bool isPreRewrite)
{
Unreachable() << "kind (" << node.getKind()
<< ") should have been removed?";
}
- RewriteResponse variable (TNode node, bool) {
+ RewriteResponse variable(TNode node, bool isPreRewrite)
+ {
// We should only get floating point and rounding mode variables to rewrite.
TypeNode tn = node.getType(true);
Assert(tn.isFloatingPoint() || tn.isRoundingMode());
@@ -328,226 +337,264 @@ namespace rewrite {
namespace constantFold {
+RewriteResponse fpLiteral(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_FP);
- RewriteResponse fpLiteral (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_FP);
+ BitVector bv(node[0].getConst<BitVector>());
+ bv = bv.concat(node[1].getConst<BitVector>());
+ bv = bv.concat(node[2].getConst<BitVector>());
- BitVector bv(node[0].getConst<BitVector>());
- bv = bv.concat(node[1].getConst<BitVector>());
- bv = bv.concat(node[2].getConst<BitVector>());
-
- // +1 to support the hidden bit
- Node lit =
- NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
- node[2].getConst<BitVector>().getSize() + 1,
- bv));
-
- return RewriteResponse(REWRITE_DONE, lit);
- }
+ // +1 to support the hidden bit
+ Node lit = NodeManager::currentNM()->mkConst(
+ FloatingPoint(node[1].getConst<BitVector>().getSize(),
+ node[2].getConst<BitVector>().getSize() + 1,
+ bv));
- RewriteResponse abs (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
- Assert(node.getNumChildren() == 1);
+ return RewriteResponse(REWRITE_DONE, lit);
+}
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().absolute()));
- }
+RewriteResponse abs(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
+ Assert(node.getNumChildren() == 1);
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(
+ node[0].getConst<FloatingPoint>().absolute()));
+}
- RewriteResponse neg (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
- Assert(node.getNumChildren() == 1);
+RewriteResponse neg(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
+ Assert(node.getNumChildren() == 1);
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().negate()));
- }
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(
+ node[0].getConst<FloatingPoint>().negate()));
+}
+RewriteResponse plus(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
+ Assert(node.getNumChildren() == 3);
- RewriteResponse plus (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
- Assert(node.getNumChildren() == 3);
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg1(node[1].getConst<FloatingPoint>());
- FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+ Assert(arg1.getSize() == arg2.getSize());
- Assert(arg1.getSize() == arg2.getSize());
+ return RewriteResponse(
+ REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
+}
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
- }
+RewriteResponse mult(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
+ Assert(node.getNumChildren() == 3);
- RewriteResponse mult (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
- Assert(node.getNumChildren() == 3);
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg1(node[1].getConst<FloatingPoint>());
- FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+ Assert(arg1.getSize() == arg2.getSize());
- Assert(arg1.getSize() == arg2.getSize());
+ return RewriteResponse(
+ REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
+}
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
- }
+RewriteResponse fma(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
+ Assert(node.getNumChildren() == 4);
- RewriteResponse fma (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
- Assert(node.getNumChildren() == 4);
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+ FloatingPoint arg3(node[3].getConst<FloatingPoint>());
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg1(node[1].getConst<FloatingPoint>());
- FloatingPoint arg2(node[2].getConst<FloatingPoint>());
- FloatingPoint arg3(node[3].getConst<FloatingPoint>());
+ Assert(arg1.getSize() == arg2.getSize());
+ Assert(arg1.getSize() == arg3.getSize());
- Assert(arg1.getSize() == arg2.getSize());
- Assert(arg1.getSize() == arg3.getSize());
+ return RewriteResponse(
+ REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
+}
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
- }
+RewriteResponse div(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
+ Assert(node.getNumChildren() == 3);
- RewriteResponse div (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
- Assert(node.getNumChildren() == 3);
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg1(node[1].getConst<FloatingPoint>());
- FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+ Assert(arg1.getSize() == arg2.getSize());
- Assert(arg1.getSize() == arg2.getSize());
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
+}
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
- }
-
- RewriteResponse sqrt (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
- Assert(node.getNumChildren() == 2);
+RewriteResponse sqrt(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
+ Assert(node.getNumChildren() == 2);
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg(node[1].getConst<FloatingPoint>());
-
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
- }
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
- RewriteResponse rti (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
- Assert(node.getNumChildren() == 2);
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
+}
- RoundingMode rm(node[0].getConst<RoundingMode>());
- FloatingPoint arg(node[1].getConst<FloatingPoint>());
-
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.rti(rm)));
- }
+RewriteResponse rti(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
+ Assert(node.getNumChildren() == 2);
- RewriteResponse rem (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_REM);
- Assert(node.getNumChildren() == 2);
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
- FloatingPoint arg1(node[0].getConst<FloatingPoint>());
- FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(arg.rti(rm)));
+}
- Assert(arg1.getSize() == arg2.getSize());
+RewriteResponse rem(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_REM);
+ Assert(node.getNumChildren() == 2);
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
- }
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
- RewriteResponse min (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
- Assert(node.getNumChildren() == 2);
+ Assert(arg1.getSize() == arg2.getSize());
- FloatingPoint arg1(node[0].getConst<FloatingPoint>());
- FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
+}
- Assert(arg1.getSize() == arg2.getSize());
+RewriteResponse min(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
+ Assert(node.getNumChildren() == 2);
- FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
- if (res.second) {
- Node lit = NodeManager::currentNM()->mkConst(res.first);
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- // Can't constant fold the underspecified case
- return RewriteResponse(REWRITE_DONE, node);
- }
+ Assert(arg1.getSize() == arg2.getSize());
+
+ FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+
+ if (res.second)
+ {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
}
+}
- RewriteResponse max (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
- Assert(node.getNumChildren() == 2);
+RewriteResponse max(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
+ Assert(node.getNumChildren() == 2);
- FloatingPoint arg1(node[0].getConst<FloatingPoint>());
- FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
- Assert(arg1.getSize() == arg2.getSize());
+ Assert(arg1.getSize() == arg2.getSize());
- FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+ FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
- if (res.second) {
- Node lit = NodeManager::currentNM()->mkConst(res.first);
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- // Can't constant fold the underspecified case
- return RewriteResponse(REWRITE_DONE, node);
- }
+ if (res.second)
+ {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
}
+ else
+ {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+}
- RewriteResponse minTotal (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL);
- Assert(node.getNumChildren() == 3);
-
- FloatingPoint arg1(node[0].getConst<FloatingPoint>());
- FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+RewriteResponse minTotal(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_MIN_TOTAL);
+ Assert(node.getNumChildren() == 3);
- Assert(arg1.getSize() == arg2.getSize());
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
- // Can be called with the third argument non-constant
- if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
- BitVector arg3(node[2].getConst<BitVector>());
+ Assert(arg1.getSize() == arg2.getSize());
- FloatingPoint folded(arg1.minTotal(arg2, arg3.isBitSet(0)));
- Node lit = NodeManager::currentNM()->mkConst(folded);
- return RewriteResponse(REWRITE_DONE, lit);
+ // Can be called with the third argument non-constant
+ if (node[2].getMetaKind() == kind::metakind::CONSTANT)
+ {
+ BitVector arg3(node[2].getConst<BitVector>());
- } else {
- FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+ FloatingPoint folded(arg1.minTotal(arg2, arg3.isBitSet(0)));
+ Node lit = NodeManager::currentNM()->mkConst(folded);
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+ else
+ {
+ FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
- if (res.second) {
- Node lit = NodeManager::currentNM()->mkConst(res.first);
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- // Can't constant fold the underspecified case
- return RewriteResponse(REWRITE_DONE, node);
- }
+ if (res.second)
+ {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
}
}
+}
- RewriteResponse maxTotal (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL);
- Assert(node.getNumChildren() == 3);
+RewriteResponse maxTotal(TNode node, bool isPreRewrite)
+{
+ Assert(node.getKind() == kind::FLOATINGPOINT_MAX_TOTAL);
+ Assert(node.getNumChildren() == 3);
- FloatingPoint arg1(node[0].getConst<FloatingPoint>());
- FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
- Assert(arg1.getSize() == arg2.getSize());
+ Assert(arg1.getSize() == arg2.getSize());
- // Can be called with the third argument non-constant
- if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
- BitVector arg3(node[2].getConst<BitVector>());
-
- FloatingPoint folded(arg1.maxTotal(arg2, arg3.isBitSet(0)));
- Node lit = NodeManager::currentNM()->mkConst(folded);
- return RewriteResponse(REWRITE_DONE, lit);
+ // Can be called with the third argument non-constant
+ if (node[2].getMetaKind() == kind::metakind::CONSTANT)
+ {
+ BitVector arg3(node[2].getConst<BitVector>());
- } else {
- FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+ FloatingPoint folded(arg1.maxTotal(arg2, arg3.isBitSet(0)));
+ Node lit = NodeManager::currentNM()->mkConst(folded);
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+ else
+ {
+ FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
- if (res.second) {
- Node lit = NodeManager::currentNM()->mkConst(res.first);
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- // Can't constant fold the underspecified case
- return RewriteResponse(REWRITE_DONE, node);
- }
+ if (res.second)
+ {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
}
}
+}
-
RewriteResponse equal (TNode node, bool isPreRewrite) {
Assert(node.getKind() == kind::EQUAL);
@@ -572,8 +619,8 @@ namespace constantFold {
Unreachable() << "Equality of unknown type";
}
-
- RewriteResponse leq (TNode node, bool) {
+ RewriteResponse leq(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
Assert(node.getNumChildren() == 2);
@@ -585,8 +632,8 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 <= arg2));
}
-
- RewriteResponse lt (TNode node, bool) {
+ RewriteResponse lt(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_LT);
Assert(node.getNumChildren() == 2);
@@ -598,57 +645,64 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 < arg2));
}
-
- RewriteResponse isNormal (TNode node, bool) {
+ RewriteResponse isNormal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISN);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNormal()));
}
- RewriteResponse isSubnormal (TNode node, bool) {
+ RewriteResponse isSubnormal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISSN);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isSubnormal()));
}
- RewriteResponse isZero (TNode node, bool) {
+ RewriteResponse isZero(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISZ);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isZero()));
}
- RewriteResponse isInfinite (TNode node, bool) {
+ RewriteResponse isInfinite(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISINF);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isInfinite()));
}
- RewriteResponse isNaN (TNode node, bool) {
+ RewriteResponse isNaN(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNaN()));
}
- RewriteResponse isNegative (TNode node, bool) {
+ RewriteResponse isNegative(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNegative()));
}
- RewriteResponse isPositive (TNode node, bool) {
+ RewriteResponse isPositive(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS);
Assert(node.getNumChildren() == 1);
return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isPositive()));
}
- RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
+ RewriteResponse convertFromIEEEBitVectorLiteral(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
TNode op = node.getOperator();
@@ -663,7 +717,8 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, lit);
}
- RewriteResponse constantConvert (TNode node, bool) {
+ RewriteResponse constantConvert(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT);
Assert(node.getNumChildren() == 2);
@@ -676,7 +731,8 @@ namespace constantFold {
NodeManager::currentNM()->mkConst(arg1.convert(info.getSize(), rm)));
}
- RewriteResponse convertFromRealLiteral (TNode node, bool) {
+ RewriteResponse convertFromRealLiteral(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
TNode op = node.getOperator();
@@ -692,7 +748,8 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, lit);
}
- RewriteResponse convertFromSBV (TNode node, bool) {
+ RewriteResponse convertFromSBV(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
TNode op = node.getOperator();
@@ -708,7 +765,8 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, lit);
}
- RewriteResponse convertFromUBV (TNode node, bool) {
+ RewriteResponse convertFromUBV(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
TNode op = node.getOperator();
@@ -724,7 +782,8 @@ namespace constantFold {
return RewriteResponse(REWRITE_DONE, lit);
}
- RewriteResponse convertToUBV (TNode node, bool) {
+ RewriteResponse convertToUBV(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
TNode op = node.getOperator();
@@ -745,7 +804,8 @@ namespace constantFold {
}
}
- RewriteResponse convertToSBV (TNode node, bool) {
+ RewriteResponse convertToSBV(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
TNode op = node.getOperator();
@@ -766,7 +826,8 @@ namespace constantFold {
}
}
- RewriteResponse convertToReal (TNode node, bool) {
+ RewriteResponse convertToReal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL);
FloatingPoint arg(node[0].getConst<FloatingPoint>());
@@ -782,7 +843,8 @@ namespace constantFold {
}
}
- RewriteResponse convertToUBVTotal (TNode node, bool) {
+ RewriteResponse convertToUBVTotal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL);
TNode op = node.getOperator();
@@ -814,7 +876,8 @@ namespace constantFold {
}
}
- RewriteResponse convertToSBVTotal (TNode node, bool) {
+ RewriteResponse convertToSBVTotal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL);
TNode op = node.getOperator();
@@ -846,7 +909,8 @@ namespace constantFold {
}
}
- RewriteResponse convertToRealTotal (TNode node, bool) {
+ RewriteResponse convertToRealTotal(TNode node, bool isPreRewrite)
+ {
Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL_TOTAL);
FloatingPoint arg(node[0].getConst<FloatingPoint>());
@@ -872,7 +936,7 @@ namespace constantFold {
}
}
- RewriteResponse componentFlag(TNode node, bool)
+ RewriteResponse componentFlag(TNode node, bool isPreRewrite)
{
Kind k = node.getKind();
@@ -901,7 +965,7 @@ namespace constantFold {
NodeManager::currentNM()->mkConst(res));
}
- RewriteResponse componentExponent(TNode node, bool)
+ RewriteResponse componentExponent(TNode node, bool isPreRewrite)
{
Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_EXPONENT);
@@ -918,7 +982,7 @@ namespace constantFold {
);
}
- RewriteResponse componentSignificand(TNode node, bool)
+ RewriteResponse componentSignificand(TNode node, bool isPreRewrite)
{
Assert(node.getKind() == kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND);
@@ -934,7 +998,7 @@ namespace constantFold {
);
}
- RewriteResponse roundingModeBitBlast(TNode node, bool)
+ RewriteResponse roundingModeBitBlast(TNode node, bool isPreRewrite)
{
Assert(node.getKind() == kind::ROUNDINGMODE_BITBLAST);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback