summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAina Niemetz <aina.niemetz@gmail.com>2021-05-03 13:27:02 -0700
committerGitHub <noreply@github.com>2021-05-03 20:27:02 +0000
commitc8c7a075428e6193dee86e57a9ecb8af11af270c (patch)
treeac12ea111a8228c594495573bc5407a05d3b3131 /src
parent439ab123cccdbf4f046b4e084ce996a1dc2aa758 (diff)
FP: Rewrite to_fp conversion from signed bit-vector. (#6472)
SymFPU does not allow to_fp conversion from signed bv of size 1. This adds rewrites for this case. Rewrites for the constant and the non-constant cases were tested in isolation.
Diffstat (limited to 'src')
-rw-r--r--src/theory/fp/theory_fp_rewriter.cpp99
1 files changed, 66 insertions, 33 deletions
diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp
index e431ffa09..74e1ff526 100644
--- a/src/theory/fp/theory_fp_rewriter.cpp
+++ b/src/theory/fp/theory_fp_rewriter.cpp
@@ -32,11 +32,13 @@
* - Samuel Figuer results
*/
+#include "theory/fp/theory_fp_rewriter.h"
+
#include <algorithm>
#include "base/check.h"
+#include "theory/bv/theory_bv_utils.h"
#include "theory/fp/fp_converter.h"
-#include "theory/fp/theory_fp_rewriter.h"
namespace cvc5 {
namespace theory {
@@ -333,6 +335,28 @@ namespace rewrite {
return RewriteResponse(REWRITE_DONE, node);
}
+ RewriteResponse toFPSignedBV(TNode node, bool isPreRewrite)
+ {
+ Assert(!isPreRewrite);
+ Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
+
+ /* symFPU does not allow conversions from signed bit-vector of size 1 */
+ if (node[1].getType().getBitVectorSize() == 1)
+ {
+ NodeManager* nm = NodeManager::currentNM();
+ Node op = nm->mkConst(FloatingPointToFPUnsignedBitVector(
+ node.getOperator().getConst<FloatingPointToFPSignedBitVector>()));
+ Node fromubv = nm->mkNode(op, node[0], node[1]);
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::ITE,
+ node[1].eqNode(bv::utils::mkOne(1)),
+ nm->mkNode(kind::FLOATINGPOINT_NEG, fromubv),
+ fromubv));
+ }
+ return RewriteResponse(REWRITE_DONE, node);
+ }
+
}; // namespace rewrite
namespace constantFold {
@@ -736,15 +760,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL);
TNode op = node.getOperator();
- const FloatingPointToFPReal &param = op.getConst<FloatingPointToFPReal>();
+ const FloatingPointSize& size =
+ op.getConst<FloatingPointToFPReal>().getSize();
RoundingMode rm(node[0].getConst<RoundingMode>());
Rational arg(node[1].getConst<Rational>());
- FloatingPoint res(param.getSize(), rm, arg);
+ FloatingPoint res(size, rm, arg);
Node lit = NodeManager::currentNM()->mkConst(res);
-
+
return RewriteResponse(REWRITE_DONE, lit);
}
@@ -753,16 +778,27 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR);
TNode op = node.getOperator();
- const FloatingPointToFPSignedBitVector &param = op.getConst<FloatingPointToFPSignedBitVector>();
+ const FloatingPointSize& size =
+ op.getConst<FloatingPointToFPSignedBitVector>().getSize();
RoundingMode rm(node[0].getConst<RoundingMode>());
- BitVector arg(node[1].getConst<BitVector>());
+ BitVector sbv(node[1].getConst<BitVector>());
- FloatingPoint res(param.getSize(), rm, arg, true);
+ NodeManager* nm = NodeManager::currentNM();
- Node lit = NodeManager::currentNM()->mkConst(res);
-
- return RewriteResponse(REWRITE_DONE, lit);
+ /* symFPU does not allow conversions from signed bit-vector of size 1 */
+ if (sbv.getSize() == 1)
+ {
+ FloatingPoint fromubv(size, rm, sbv, false);
+ if (sbv.isBitSet(0))
+ {
+ return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv.negate()));
+ }
+ return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv));
+ }
+
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConst(FloatingPoint(size, rm, sbv, true)));
}
RewriteResponse convertFromUBV(TNode node, bool isPreRewrite)
@@ -770,15 +806,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR);
TNode op = node.getOperator();
- const FloatingPointToFPUnsignedBitVector &param = op.getConst<FloatingPointToFPUnsignedBitVector>();
+ const FloatingPointSize& size =
+ op.getConst<FloatingPointToFPUnsignedBitVector>().getSize();
RoundingMode rm(node[0].getConst<RoundingMode>());
BitVector arg(node[1].getConst<BitVector>());
- FloatingPoint res(param.getSize(), rm, arg, false);
+ FloatingPoint res(size, rm, arg, false);
Node lit = NodeManager::currentNM()->mkConst(res);
-
+
return RewriteResponse(REWRITE_DONE, lit);
}
@@ -787,13 +824,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV);
TNode op = node.getOperator();
- const FloatingPointToUBV &param = op.getConst<FloatingPointToUBV>();
+ const BitVectorSize& size = op.getConst<FloatingPointToUBV>().d_bv_size;
RoundingMode rm(node[0].getConst<RoundingMode>());
FloatingPoint arg(node[1].getConst<FloatingPoint>());
- FloatingPoint::PartialBitVector res(
- arg.convertToBV(param.d_bv_size, rm, false));
+ FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false));
if (res.second) {
Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -809,13 +845,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV);
TNode op = node.getOperator();
- const FloatingPointToSBV &param = op.getConst<FloatingPointToSBV>();
+ const BitVectorSize& size = op.getConst<FloatingPointToSBV>().d_bv_size;
RoundingMode rm(node[0].getConst<RoundingMode>());
FloatingPoint arg(node[1].getConst<FloatingPoint>());
- FloatingPoint::PartialBitVector res(
- arg.convertToBV(param.d_bv_size, rm, true));
+ FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true));
if (res.second) {
Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -848,7 +883,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL);
TNode op = node.getOperator();
- const FloatingPointToUBVTotal &param = op.getConst<FloatingPointToUBVTotal>();
+ const BitVectorSize& size =
+ op.getConst<FloatingPointToUBVTotal>().d_bv_size;
RoundingMode rm(node[0].getConst<RoundingMode>());
FloatingPoint arg(node[1].getConst<FloatingPoint>());
@@ -857,14 +893,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
BitVector partialValue(node[2].getConst<BitVector>());
- BitVector folded(
- arg.convertToBVTotal(param.d_bv_size, rm, false, partialValue));
+ BitVector folded(arg.convertToBVTotal(size, rm, false, partialValue));
Node lit = NodeManager::currentNM()->mkConst(folded);
return RewriteResponse(REWRITE_DONE, lit);
} else {
- FloatingPoint::PartialBitVector res(
- arg.convertToBV(param.d_bv_size, rm, false));
+ FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false));
if (res.second) {
Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -881,7 +915,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL);
TNode op = node.getOperator();
- const FloatingPointToSBVTotal &param = op.getConst<FloatingPointToSBVTotal>();
+ const BitVectorSize& size =
+ op.getConst<FloatingPointToSBVTotal>().d_bv_size;
RoundingMode rm(node[0].getConst<RoundingMode>());
FloatingPoint arg(node[1].getConst<FloatingPoint>());
@@ -890,14 +925,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
if (node[2].getMetaKind() == kind::metakind::CONSTANT) {
BitVector partialValue(node[2].getConst<BitVector>());
- BitVector folded(
- arg.convertToBVTotal(param.d_bv_size, rm, true, partialValue));
+ BitVector folded(arg.convertToBVTotal(size, rm, true, partialValue));
Node lit = NodeManager::currentNM()->mkConst(folded);
return RewriteResponse(REWRITE_DONE, lit);
} else {
- FloatingPoint::PartialBitVector res(
- arg.convertToBV(param.d_bv_size, rm, true));
+ FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true));
if (res.second) {
Node lit = NodeManager::currentNM()->mkConst(res.first);
@@ -1049,7 +1082,7 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
{
/* Set up the pre-rewrite dispatch table */
- for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+ for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
{
d_preRewriteTable[i] = rewrite::notFP;
}
@@ -1140,7 +1173,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
d_preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
/* Set up the post-rewrite dispatch table */
- for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+ for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
{
d_postRewriteTable[i] = rewrite::notFP;
}
@@ -1197,7 +1230,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
rewrite::identity;
d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity;
d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] =
- rewrite::identity;
+ rewrite::toFPSignedBV;
d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] =
rewrite::identity;
d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::identity;
@@ -1228,7 +1261,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u)
d_postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity;
/* Set up the post-rewrite constant fold table */
- for (unsigned i = 0; i < kind::LAST_KIND; ++i)
+ for (uint32_t i = 0; i < kind::LAST_KIND; ++i)
{
// Note that this is identity, not notFP
// Constant folding is called after post-rewrite
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback