diff options
author | Aina Niemetz <aina.niemetz@gmail.com> | 2021-05-03 13:27:02 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-03 20:27:02 +0000 |
commit | c8c7a075428e6193dee86e57a9ecb8af11af270c (patch) | |
tree | ac12ea111a8228c594495573bc5407a05d3b3131 /src/theory | |
parent | 439ab123cccdbf4f046b4e084ce996a1dc2aa758 (diff) |
FP: Rewrite to_fp conversion from signed bit-vector. (#6472)
SymFPU does not allow to_fp conversion from signed bv of size 1. This
adds rewrites for this case.
Rewrites for the constant and the non-constant cases were tested in
isolation.
Diffstat (limited to 'src/theory')
-rw-r--r-- | src/theory/fp/theory_fp_rewriter.cpp | 99 |
1 files changed, 66 insertions, 33 deletions
diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index e431ffa09..74e1ff526 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -32,11 +32,13 @@ * - Samuel Figuer results */ +#include "theory/fp/theory_fp_rewriter.h" + #include <algorithm> #include "base/check.h" +#include "theory/bv/theory_bv_utils.h" #include "theory/fp/fp_converter.h" -#include "theory/fp/theory_fp_rewriter.h" namespace cvc5 { namespace theory { @@ -333,6 +335,28 @@ namespace rewrite { return RewriteResponse(REWRITE_DONE, node); } + RewriteResponse toFPSignedBV(TNode node, bool isPreRewrite) + { + Assert(!isPreRewrite); + Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); + + /* symFPU does not allow conversions from signed bit-vector of size 1 */ + if (node[1].getType().getBitVectorSize() == 1) + { + NodeManager* nm = NodeManager::currentNM(); + Node op = nm->mkConst(FloatingPointToFPUnsignedBitVector( + node.getOperator().getConst<FloatingPointToFPSignedBitVector>())); + Node fromubv = nm->mkNode(op, node[0], node[1]); + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(kind::ITE, + node[1].eqNode(bv::utils::mkOne(1)), + nm->mkNode(kind::FLOATINGPOINT_NEG, fromubv), + fromubv)); + } + return RewriteResponse(REWRITE_DONE, node); + } + }; // namespace rewrite namespace constantFold { @@ -736,15 +760,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_REAL); TNode op = node.getOperator(); - const FloatingPointToFPReal ¶m = op.getConst<FloatingPointToFPReal>(); + const FloatingPointSize& size = + op.getConst<FloatingPointToFPReal>().getSize(); RoundingMode rm(node[0].getConst<RoundingMode>()); Rational arg(node[1].getConst<Rational>()); - FloatingPoint res(param.getSize(), rm, arg); + FloatingPoint res(size, rm, arg); Node lit = NodeManager::currentNM()->mkConst(res); - + return RewriteResponse(REWRITE_DONE, lit); } @@ -753,16 +778,27 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); TNode op = node.getOperator(); - const FloatingPointToFPSignedBitVector ¶m = op.getConst<FloatingPointToFPSignedBitVector>(); + const FloatingPointSize& size = + op.getConst<FloatingPointToFPSignedBitVector>().getSize(); RoundingMode rm(node[0].getConst<RoundingMode>()); - BitVector arg(node[1].getConst<BitVector>()); + BitVector sbv(node[1].getConst<BitVector>()); - FloatingPoint res(param.getSize(), rm, arg, true); + NodeManager* nm = NodeManager::currentNM(); - Node lit = NodeManager::currentNM()->mkConst(res); - - return RewriteResponse(REWRITE_DONE, lit); + /* symFPU does not allow conversions from signed bit-vector of size 1 */ + if (sbv.getSize() == 1) + { + FloatingPoint fromubv(size, rm, sbv, false); + if (sbv.isBitSet(0)) + { + return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv.negate())); + } + return RewriteResponse(REWRITE_DONE, nm->mkConst(fromubv)); + } + + return RewriteResponse(REWRITE_DONE, + nm->mkConst(FloatingPoint(size, rm, sbv, true))); } RewriteResponse convertFromUBV(TNode node, bool isPreRewrite) @@ -770,15 +806,16 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR); TNode op = node.getOperator(); - const FloatingPointToFPUnsignedBitVector ¶m = op.getConst<FloatingPointToFPUnsignedBitVector>(); + const FloatingPointSize& size = + op.getConst<FloatingPointToFPUnsignedBitVector>().getSize(); RoundingMode rm(node[0].getConst<RoundingMode>()); BitVector arg(node[1].getConst<BitVector>()); - FloatingPoint res(param.getSize(), rm, arg, false); + FloatingPoint res(size, rm, arg, false); Node lit = NodeManager::currentNM()->mkConst(res); - + return RewriteResponse(REWRITE_DONE, lit); } @@ -787,13 +824,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV); TNode op = node.getOperator(); - const FloatingPointToUBV ¶m = op.getConst<FloatingPointToUBV>(); + const BitVectorSize& size = op.getConst<FloatingPointToUBV>().d_bv_size; RoundingMode rm(node[0].getConst<RoundingMode>()); FloatingPoint arg(node[1].getConst<FloatingPoint>()); - FloatingPoint::PartialBitVector res( - arg.convertToBV(param.d_bv_size, rm, false)); + FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false)); if (res.second) { Node lit = NodeManager::currentNM()->mkConst(res.first); @@ -809,13 +845,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV); TNode op = node.getOperator(); - const FloatingPointToSBV ¶m = op.getConst<FloatingPointToSBV>(); + const BitVectorSize& size = op.getConst<FloatingPointToSBV>().d_bv_size; RoundingMode rm(node[0].getConst<RoundingMode>()); FloatingPoint arg(node[1].getConst<FloatingPoint>()); - FloatingPoint::PartialBitVector res( - arg.convertToBV(param.d_bv_size, rm, true)); + FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true)); if (res.second) { Node lit = NodeManager::currentNM()->mkConst(res.first); @@ -848,7 +883,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_UBV_TOTAL); TNode op = node.getOperator(); - const FloatingPointToUBVTotal ¶m = op.getConst<FloatingPointToUBVTotal>(); + const BitVectorSize& size = + op.getConst<FloatingPointToUBVTotal>().d_bv_size; RoundingMode rm(node[0].getConst<RoundingMode>()); FloatingPoint arg(node[1].getConst<FloatingPoint>()); @@ -857,14 +893,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) if (node[2].getMetaKind() == kind::metakind::CONSTANT) { BitVector partialValue(node[2].getConst<BitVector>()); - BitVector folded( - arg.convertToBVTotal(param.d_bv_size, rm, false, partialValue)); + BitVector folded(arg.convertToBVTotal(size, rm, false, partialValue)); Node lit = NodeManager::currentNM()->mkConst(folded); return RewriteResponse(REWRITE_DONE, lit); } else { - FloatingPoint::PartialBitVector res( - arg.convertToBV(param.d_bv_size, rm, false)); + FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, false)); if (res.second) { Node lit = NodeManager::currentNM()->mkConst(res.first); @@ -881,7 +915,8 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Assert(node.getKind() == kind::FLOATINGPOINT_TO_SBV_TOTAL); TNode op = node.getOperator(); - const FloatingPointToSBVTotal ¶m = op.getConst<FloatingPointToSBVTotal>(); + const BitVectorSize& size = + op.getConst<FloatingPointToSBVTotal>().d_bv_size; RoundingMode rm(node[0].getConst<RoundingMode>()); FloatingPoint arg(node[1].getConst<FloatingPoint>()); @@ -890,14 +925,12 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) if (node[2].getMetaKind() == kind::metakind::CONSTANT) { BitVector partialValue(node[2].getConst<BitVector>()); - BitVector folded( - arg.convertToBVTotal(param.d_bv_size, rm, true, partialValue)); + BitVector folded(arg.convertToBVTotal(size, rm, true, partialValue)); Node lit = NodeManager::currentNM()->mkConst(folded); return RewriteResponse(REWRITE_DONE, lit); } else { - FloatingPoint::PartialBitVector res( - arg.convertToBV(param.d_bv_size, rm, true)); + FloatingPoint::PartialBitVector res(arg.convertToBV(size, rm, true)); if (res.second) { Node lit = NodeManager::currentNM()->mkConst(res.first); @@ -1049,7 +1082,7 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u) { /* Set up the pre-rewrite dispatch table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) + for (uint32_t i = 0; i < kind::LAST_KIND; ++i) { d_preRewriteTable[i] = rewrite::notFP; } @@ -1140,7 +1173,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u) d_preRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; /* Set up the post-rewrite dispatch table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) + for (uint32_t i = 0; i < kind::LAST_KIND; ++i) { d_postRewriteTable[i] = rewrite::notFP; } @@ -1197,7 +1230,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u) rewrite::identity; d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_REAL] = rewrite::identity; d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR] = - rewrite::identity; + rewrite::toFPSignedBV; d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR] = rewrite::identity; d_postRewriteTable[kind::FLOATINGPOINT_TO_FP_GENERIC] = rewrite::identity; @@ -1228,7 +1261,7 @@ TheoryFpRewriter::TheoryFpRewriter(context::UserContext* u) : d_fpExpDef(u) d_postRewriteTable[kind::ROUNDINGMODE_BITBLAST] = rewrite::identity; /* Set up the post-rewrite constant fold table */ - for (unsigned i = 0; i < kind::LAST_KIND; ++i) + for (uint32_t i = 0; i < kind::LAST_KIND; ++i) { // Note that this is identity, not notFP // Constant folding is called after post-rewrite |