summaryrefslogtreecommitdiff
path: root/src/theory/fp
diff options
context:
space:
mode:
authorMartin Brain <martin.brain@cs.ox.ac.uk>2017-09-26 17:21:51 -0700
committerAndres Noetzli <noetzli@stanford.edu>2017-09-26 17:24:55 -0700
commite23377411d993e126403eb186c80f664419d512c (patch)
treed80a837f2e9a465ddc7f1bdcb2b821af1453e522 /src/theory/fp
parent0f9f1fee128c86f3a1210134f1f22a0343793d4a (diff)
Improve FP rewriter: const folding, other (#1126)
Diffstat (limited to 'src/theory/fp')
-rw-r--r--src/theory/fp/theory_fp_rewriter.cpp643
-rw-r--r--src/theory/fp/theory_fp_rewriter.h2
2 files changed, 566 insertions, 79 deletions
diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp
index 747aaeac6..ec42099c2 100644
--- a/src/theory/fp/theory_fp_rewriter.cpp
+++ b/src/theory/fp/theory_fp_rewriter.cpp
@@ -12,17 +12,22 @@
**
** \brief [[ Rewrite rules for floating point theories. ]]
**
- ** \todo [[ Constant folding
- ** Push negations up through arithmetic operators (include max and min? maybe not due to +0/-0)
+ ** \todo [[ Single argument constant propagate / simplify
+ Push negations through arithmetic operators (include max and min? maybe not due to +0/-0)
** classifications to normal tests (maybe)
** (= x (fp.neg x)) --> (isNaN x)
** (fp.eq x (fp.neg x)) --> (isZero x) (previous and reorganise should be sufficient)
- ** (fp.eq x const) --> various = depending on const
+ ** (fp.eq x const) --> various = depending on const
** (fp.abs (fp.neg x)) --> (fp.abs x)
** (fp.isPositive (fp.neg x)) --> (fp.isNegative x)
** (fp.isNegative (fp.neg x)) --> (fp.isPositive x)
** (fp.isPositive (fp.abs x)) --> (not (isNaN x))
** (fp.isNegative (fp.abs x)) --> false
+ ** A -> castA --> A
+ ** A -> castB -> castC --> A -> castC if A <= B <= C
+ ** A -> castB -> castA --> A if A <= B
+ ** promotion converts can ignore rounding mode
+ ** Samuel Figuer results
** ]]
**/
@@ -137,7 +142,7 @@ namespace rewrite {
}
RewriteResponse removed (TNode node, bool) {
- Unreachable("kind (%d) should have been removed?",node.getKind());
+ Unreachable("kind (%s) should have been removed?",kindToString(node.getKind()).c_str());
return RewriteResponse(REWRITE_DONE, node);
}
@@ -150,10 +155,10 @@ namespace rewrite {
return RewriteResponse(REWRITE_DONE, node);
}
- RewriteResponse equal (TNode node, bool isPreRewrite) {
- // We should only get equalities of floating point or rounding mode types.
+ RewriteResponse equal (TNode node, bool isPreRewrite) {
Assert(node.getKind() == kind::EQUAL);
-
+
+ // We should only get equalities of floating point or rounding mode types.
TypeNode tn = node[0].getType(true);
Assert(tn.isFloatingPoint() || tn.isRoundingMode());
@@ -169,74 +174,6 @@ namespace rewrite {
}
}
- RewriteResponse convertFromRealLiteral (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
-
- // \todo Honour the rounding mode and work for something other than doubles!
-
- if (node[1].getKind() == kind::CONST_RATIONAL) {
- TNode op = node.getOperator();
- const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
- Node lit =
- NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
- param.t.significand(),
- node[1].getConst<Rational>().getDouble()));
-
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- return RewriteResponse(REWRITE_DONE, node);
- }
- }
-
- RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
-
- // \todo Handle arbitrary length bit vectors without using strings!
-
- if (node[0].getKind() == kind::CONST_BITVECTOR) {
- TNode op = node.getOperator();
- const FloatingPointToFPIEEEBitVector &param = op.getConst<FloatingPointToFPIEEEBitVector>();
- const BitVector &bv = node[0].getConst<BitVector>();
- std::string bitString(bv.toString());
-
- Node lit =
- NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
- param.t.significand(),
- bitString));
-
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- return RewriteResponse(REWRITE_DONE, node);
- }
- }
-
- RewriteResponse convertFromLiteral (TNode node, bool) {
- Assert(node.getKind() == kind::FLOATINGPOINT_FP);
-
- // \todo Handle arbitrary length bit vectors without using strings!
-
- if ((node[0].getKind() == kind::CONST_BITVECTOR) &&
- (node[1].getKind() == kind::CONST_BITVECTOR) &&
- (node[2].getKind() == kind::CONST_BITVECTOR)) {
-
- BitVector bv(node[0].getConst<BitVector>());
- bv = bv.concat(node[1].getConst<BitVector>());
- bv = bv.concat(node[2].getConst<BitVector>());
-
- std::string bitString(bv.toString());
- std::reverse(bitString.begin(), bitString.end());
-
- // +1 to support the hidden bit
- Node lit =
- NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
- node[2].getConst<BitVector>().getSize() + 1,
- bitString));
-
- return RewriteResponse(REWRITE_DONE, lit);
- } else {
- return RewriteResponse(REWRITE_DONE, node);
- }
- }
// Note these cannot be assumed to be symmetric for +0/-0, thus no symmetry reorder
RewriteResponse compactMinMax (TNode node, bool isPreRewrite) {
@@ -309,11 +246,410 @@ namespace rewrite {
}
}
-
}; /* CVC4::theory::fp::rewrite */
+
+namespace constantFold {
+
+
+ RewriteResponse fpLiteral (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_FP);
+
+ BitVector bv(node[0].getConst<BitVector>());
+ bv = bv.concat(node[1].getConst<BitVector>());
+ bv = bv.concat(node[2].getConst<BitVector>());
+
+ // +1 to support the hidden bit
+ Node lit =
+ NodeManager::currentNM()->mkConst(FloatingPoint(node[1].getConst<BitVector>().getSize(),
+ node[2].getConst<BitVector>().getSize() + 1,
+ bv));
+
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+
+ RewriteResponse abs (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ABS);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().absolute()));
+ }
+
+
+ RewriteResponse neg (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_NEG);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().negate()));
+ }
+
+
+ RewriteResponse plus (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_PLUS);
+ Assert(node.getNumChildren() == 3);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.plus(rm, arg2)));
+ }
+
+ RewriteResponse mult (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_MULT);
+ Assert(node.getNumChildren() == 3);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.mult(rm, arg2)));
+ }
+
+ RewriteResponse fma (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_FMA);
+ Assert(node.getNumChildren() == 4);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+ FloatingPoint arg3(node[3].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+ Assert(arg1.t == arg3.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.fma(rm, arg2, arg3)));
+ }
+
+ RewriteResponse div (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_DIV);
+ Assert(node.getNumChildren() == 3);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[2].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.div(rm, arg2)));
+ }
+
+ RewriteResponse sqrt (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_SQRT);
+ Assert(node.getNumChildren() == 2);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.sqrt(rm)));
+ }
+
+ RewriteResponse rti (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_RTI);
+ Assert(node.getNumChildren() == 2);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg.rti(rm)));
+ }
+
+ RewriteResponse rem (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_REM);
+ Assert(node.getNumChildren() == 2);
+
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.rem(arg2)));
+ }
+
+ RewriteResponse min (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_MIN);
+ Assert(node.getNumChildren() == 2);
+
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ FloatingPoint::PartialFloatingPoint res(arg1.min(arg2));
+
+ if (res.second) {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ } else {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ }
+
+ RewriteResponse max (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_MAX);
+ Assert(node.getNumChildren() == 2);
+
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ FloatingPoint::PartialFloatingPoint res(arg1.max(arg2));
+
+ if (res.second) {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ } else {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ }
+
+
+ RewriteResponse equal (TNode node, bool isPreRewrite) {
+ Assert(node.getKind() == kind::EQUAL);
+
+ // We should only get equalities of floating point or rounding mode types.
+ TypeNode tn = node[0].getType(true);
+
+ if (tn.isFloatingPoint()) {
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2));
+
+ } else if (tn.isRoundingMode()) {
+ RoundingMode arg1(node[0].getConst<RoundingMode>());
+ RoundingMode arg2(node[1].getConst<RoundingMode>());
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2));
+
+ } else {
+ Unreachable("Equality of unknown type");
+ }
+
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+
+
+ RewriteResponse leq (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_LEQ);
+ Assert(node.getNumChildren() == 2);
+
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 <= arg2));
+ }
+
+
+ RewriteResponse lt (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_LT);
+ Assert(node.getNumChildren() == 2);
+
+ FloatingPoint arg1(node[0].getConst<FloatingPoint>());
+ FloatingPoint arg2(node[1].getConst<FloatingPoint>());
+
+ Assert(arg1.t == arg2.t);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 < arg2));
+ }
+
+
+ RewriteResponse isNormal (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISN);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNormal()));
+ }
+
+ RewriteResponse isSubnormal (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISSN);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isSubnormal()));
+ }
+
+ RewriteResponse isZero (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISZ);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isZero()));
+ }
+
+ RewriteResponse isInfinite (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISINF);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isInfinite()));
+ }
+
+ RewriteResponse isNaN (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISNAN);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNaN()));
+ }
+
+ RewriteResponse isNegative (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISNEG);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isNegative()));
+ }
+
+ RewriteResponse isPositive (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_ISPOS);
+ Assert(node.getNumChildren() == 1);
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(node[0].getConst<FloatingPoint>().isPositive()));
+ }
+
+ RewriteResponse convertFromIEEEBitVectorLiteral (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR);
+
+ TNode op = node.getOperator();
+ const FloatingPointToFPIEEEBitVector &param = op.getConst<FloatingPointToFPIEEEBitVector>();
+ const BitVector &bv = node[0].getConst<BitVector>();
+
+ Node lit =
+ NodeManager::currentNM()->mkConst(FloatingPoint(param.t.exponent(),
+ param.t.significand(),
+ bv));
+
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+
+ RewriteResponse constantConvert (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT);
+ Assert(node.getNumChildren() == 2);
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg1(node[1].getConst<FloatingPoint>());
+ FloatingPointToFPFloatingPoint info = node.getOperator().getConst<FloatingPointToFPFloatingPoint>();
+
+ return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1.convert(info.t,rm)));
+ }
+
+ RewriteResponse convertFromRealLiteral (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
+
+ TNode op = node.getOperator();
+ const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ Rational arg(node[1].getConst<Rational>());
+
+ FloatingPoint res(param.t, rm, arg);
+
+ Node lit = NodeManager::currentNM()->mkConst(res);
+
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+
+ RewriteResponse convertFromSBV (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
+
+ TNode op = node.getOperator();
+ const FloatingPointToFPSignedBitVector &param = op.getConst<FloatingPointToFPSignedBitVector>();
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ BitVector arg(node[1].getConst<BitVector>());
+
+ FloatingPoint res(param.t, rm, arg, true);
+
+ Node lit = NodeManager::currentNM()->mkConst(res);
+
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+
+ RewriteResponse convertFromUBV (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
+
+ TNode op = node.getOperator();
+ const FloatingPointToFPUnsignedBitVector &param = op.getConst<FloatingPointToFPUnsignedBitVector>();
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ BitVector arg(node[1].getConst<BitVector>());
+
+ FloatingPoint res(param.t, rm, arg, false);
+
+ Node lit = NodeManager::currentNM()->mkConst(res);
+
+ return RewriteResponse(REWRITE_DONE, lit);
+ }
+
+ RewriteResponse convertToUBV (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
+
+ TNode op = node.getOperator();
+ const FloatingPointToUBV &param = op.getConst<FloatingPointToUBV>();
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+ FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, false));
+
+ if (res.second) {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ } else {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ }
+
+ RewriteResponse convertToSBV (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
+
+ TNode op = node.getOperator();
+ const FloatingPointToSBV &param = op.getConst<FloatingPointToSBV>();
+
+ RoundingMode rm(node[0].getConst<RoundingMode>());
+ FloatingPoint arg(node[1].getConst<FloatingPoint>());
+
+ FloatingPoint::PartialBitVector res(arg.convertToBV(param.bvs, rm, true));
+
+ if (res.second) {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ } else {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ }
+
+ RewriteResponse convertToReal (TNode node, bool) {
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_REAL);
+
+ FloatingPoint arg(node[0].getConst<FloatingPoint>());
+
+ FloatingPoint::PartialRational res(arg.convertToRational());
+
+ if (res.second) {
+ Node lit = NodeManager::currentNM()->mkConst(res.first);
+ return RewriteResponse(REWRITE_DONE, lit);
+ } else {
+ // Can't constant fold the underspecified case
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+ }
+
+}; /* CVC4::theory::fp::constantFold */
+
+
RewriteFunction TheoryFpRewriter::preRewriteTable[kind::LAST_KIND];
RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
+RewriteFunction TheoryFpRewriter::constantFoldTable[kind::LAST_KIND];
/**
@@ -381,6 +717,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
/******** Variables ********/
preRewriteTable[kind::VARIABLE] = rewrite::variable;
preRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
+ preRewriteTable[kind::SKOLEM] = rewrite::variable;
preRewriteTable[kind::EQUAL] = rewrite::equal;
@@ -403,7 +740,7 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
postRewriteTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
/******** Operations ********/
- postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::convertFromLiteral;
+ postRewriteTable[kind::FLOATINGPOINT_FP] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_ABS] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_NEG] = rewrite::removeDoubleNegation;
postRewriteTable[kind::FLOATINGPOINT_PLUS] = rewrite::reorderBinaryOperation;
@@ -434,9 +771,9 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
postRewriteTable[kind::FLOATINGPOINT_ISPOS] = rewrite::identity;
/******** Conversions ********/
- postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::convertFromIEEEBitVectorLiteral;
+ postRewriteTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = rewrite::identity;
- postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::convertFromRealLiteral;
+ postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity;
postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed;
@@ -447,10 +784,81 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
/******** Variables ********/
postRewriteTable[kind::VARIABLE] = rewrite::variable;
postRewriteTable[kind::BOUND_VARIABLE] = rewrite::variable;
+ postRewriteTable[kind::SKOLEM] = rewrite::variable;
postRewriteTable[kind::EQUAL] = rewrite::equal;
+
+
+ /* Set up the post-rewrite constant fold table */
+ for (unsigned i = 0; i < kind::LAST_KIND; ++i) {
+ // Note that this is identity, not notFP
+ // Constant folding is called after post-rewrite
+ // So may have to deal with cases of things being
+ // re-written to non-floating-point sorts (i.e. true).
+ constantFoldTable[i] = rewrite::identity;
+ }
+
+ /******** Constants ********/
+ /* Already folded! */
+ constantFoldTable[kind::CONST_FLOATINGPOINT] = rewrite::identity;
+ constantFoldTable[kind::CONST_ROUNDINGMODE] = rewrite::identity;
+
+ /******** Sorts(?) ********/
+ /* These kinds should only appear in types */
+ constantFoldTable[kind::FLOATINGPOINT_TYPE] = rewrite::type;
+
+ /******** Operations ********/
+ constantFoldTable[kind::FLOATINGPOINT_FP] = constantFold::fpLiteral;
+ constantFoldTable[kind::FLOATINGPOINT_ABS] = constantFold::abs;
+ constantFoldTable[kind::FLOATINGPOINT_NEG] = constantFold::neg;
+ constantFoldTable[kind::FLOATINGPOINT_PLUS] = constantFold::plus;
+ constantFoldTable[kind::FLOATINGPOINT_SUB] = rewrite::removed;
+ constantFoldTable[kind::FLOATINGPOINT_MULT] = constantFold::mult;
+ constantFoldTable[kind::FLOATINGPOINT_DIV] = constantFold::div;
+ constantFoldTable[kind::FLOATINGPOINT_FMA] = constantFold::fma;
+ constantFoldTable[kind::FLOATINGPOINT_SQRT] = constantFold::sqrt;
+ constantFoldTable[kind::FLOATINGPOINT_REM] = constantFold::rem;
+ constantFoldTable[kind::FLOATINGPOINT_RTI] = constantFold::rti;
+ constantFoldTable[kind::FLOATINGPOINT_MIN] = constantFold::min;
+ constantFoldTable[kind::FLOATINGPOINT_MAX] = constantFold::max;
+
+ /******** Comparisons ********/
+ constantFoldTable[kind::FLOATINGPOINT_EQ] = rewrite::removed;
+ constantFoldTable[kind::FLOATINGPOINT_LEQ] = constantFold::leq;
+ constantFoldTable[kind::FLOATINGPOINT_LT] = constantFold::lt;
+ constantFoldTable[kind::FLOATINGPOINT_GEQ] = rewrite::removed;
+ constantFoldTable[kind::FLOATINGPOINT_GT] = rewrite::removed;
+
+ /******** Classifications ********/
+ constantFoldTable[kind::FLOATINGPOINT_ISN] = constantFold::isNormal;
+ constantFoldTable[kind::FLOATINGPOINT_ISSN] = constantFold::isSubnormal;
+ constantFoldTable[kind::FLOATINGPOINT_ISZ] = constantFold::isZero;
+ constantFoldTable[kind::FLOATINGPOINT_ISINF] = constantFold::isInfinite;
+ constantFoldTable[kind::FLOATINGPOINT_ISNAN] = constantFold::isNaN;
+ constantFoldTable[kind::FLOATINGPOINT_ISNEG] = constantFold::isNegative;
+ constantFoldTable[kind::FLOATINGPOINT_ISPOS] = constantFold::isPositive;
+
+ /******** Conversions ********/
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR] = constantFold::convertFromIEEEBitVectorLiteral;
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT] = constantFold::constantConvert;
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_REAL] = constantFold::convertFromRealLiteral;
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = constantFold::convertFromSBV;
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = constantFold::convertFromUBV;
+ constantFoldTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::removed;
+ constantFoldTable[kind::FLOATINGPOINT_TO_UBV] = constantFold::convertToUBV;
+ constantFoldTable[kind::FLOATINGPOINT_TO_SBV] = constantFold::convertToSBV;
+ constantFoldTable[kind::FLOATINGPOINT_TO_REAL] = constantFold::convertToReal;
+
+ /******** Variables ********/
+ constantFoldTable[kind::VARIABLE] = rewrite::variable;
+ constantFoldTable[kind::BOUND_VARIABLE] = rewrite::variable;
+
+ constantFoldTable[kind::EQUAL] = constantFold::equal;
+
+
+
}
@@ -506,6 +914,83 @@ RewriteFunction TheoryFpRewriter::postRewriteTable[kind::LAST_KIND];
Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before " << node << std::endl;
Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after " << res.node << std::endl;
}
+
+ if (res.status == REWRITE_DONE) {
+ bool allChildrenConst = true;
+ bool apartFromRoundingMode = false;
+ for (Node::const_iterator i = res.node.begin();
+ i != res.node.end();
+ ++i) {
+
+ if ((*i).getMetaKind() != kind::metakind::CONSTANT) {
+ if ((*i).getType().isRoundingMode() && !apartFromRoundingMode) {
+ apartFromRoundingMode = true;
+ } else {
+ allChildrenConst = false;
+ break;
+ }
+ }
+ }
+
+ if (allChildrenConst) {
+ RewriteStatus rs = REWRITE_DONE; // This is a bit messy because
+ Node rn = res.node; // RewriteResponse is too functional..
+
+ if (apartFromRoundingMode) {
+ if (!(res.node.getKind() == kind::EQUAL)) { // Avoid infinite recursion...
+ // We are close to being able to constant fold this
+ // and in many cases the rounding mode really doesn't matter.
+ // So we can try brute forcing our way through them.
+
+ NodeManager *nm = NodeManager::currentNM();
+
+ Node RNE(nm->mkConst(roundNearestTiesToEven));
+ Node RNA(nm->mkConst(roundNearestTiesToAway));
+ Node RTZ(nm->mkConst(roundTowardPositive));
+ Node RTN(nm->mkConst(roundTowardNegative));
+ Node RTP(nm->mkConst(roundTowardZero));
+
+ TNode RM(res.node[0]);
+
+ Node wRNE(res.node.substitute(RM, TNode(RNE)));
+ Node wRNA(res.node.substitute(RM, TNode(RNA)));
+ Node wRTZ(res.node.substitute(RM, TNode(RTZ)));
+ Node wRTN(res.node.substitute(RM, TNode(RTN)));
+ Node wRTP(res.node.substitute(RM, TNode(RTP)));
+
+
+ rs = REWRITE_AGAIN_FULL;
+ rn = nm->mkNode(kind::ITE,
+ nm->mkNode(kind::EQUAL, RM, RNE),
+ wRNE,
+ nm->mkNode(kind::ITE,
+ nm->mkNode(kind::EQUAL, RM, RNA),
+ wRNA,
+ nm->mkNode(kind::ITE,
+ nm->mkNode(kind::EQUAL, RM, RTZ),
+ wRTZ,
+ nm->mkNode(kind::ITE,
+ nm->mkNode(kind::EQUAL, RM, RTN),
+ wRTN,
+ wRTP))));
+ }
+ } else {
+ RewriteResponse tmp = constantFoldTable [res.node.getKind()] (res.node, false);
+ rs = tmp.status;
+ rn = tmp.node;
+ }
+
+ RewriteResponse constRes(rs,rn);
+
+ if (constRes.node != res.node) {
+ Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): before constant fold " << res.node << std::endl;
+ Debug("fp-rewrite") << "TheoryFpRewriter::postRewrite(): after constant fold " << constRes.node << std::endl;
+ }
+
+ return constRes;
+ }
+ }
+
return res;
}
diff --git a/src/theory/fp/theory_fp_rewriter.h b/src/theory/fp/theory_fp_rewriter.h
index d2a9a0466..56492f921 100644
--- a/src/theory/fp/theory_fp_rewriter.h
+++ b/src/theory/fp/theory_fp_rewriter.h
@@ -32,6 +32,8 @@ class TheoryFpRewriter {
protected :
static RewriteFunction preRewriteTable[kind::LAST_KIND];
static RewriteFunction postRewriteTable[kind::LAST_KIND];
+ static RewriteFunction constantFoldTable[kind::LAST_KIND];
+
public:
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback