summaryrefslogtreecommitdiff
path: root/src/theory/evaluator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/evaluator.cpp')
-rw-r--r--src/theory/evaluator.cpp500
1 files changed, 497 insertions, 3 deletions
diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp
index ca2140ed5..925e6b052 100644
--- a/src/theory/evaluator.cpp
+++ b/src/theory/evaluator.cpp
@@ -17,6 +17,7 @@
#include "theory/evaluator.h"
#include "theory/bv/theory_bv_utils.h"
+#include "theory/rewriter.h"
#include "theory/theory.h"
#include "util/integer.h"
@@ -33,6 +34,11 @@ EvalResult::EvalResult(const EvalResult& other)
new (&d_bv) BitVector;
d_bv = other.d_bv;
break;
+ case ROUNDINGMODE:
+ new (&d_rm) RoundingMode;
+ d_rm = other.d_rm;
+ break;
+ case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break;
case RATIONAL:
new (&d_rat) Rational;
d_rat = other.d_rat;
@@ -57,6 +63,11 @@ EvalResult& EvalResult::operator=(const EvalResult& other)
new (&d_bv) BitVector;
d_bv = other.d_bv;
break;
+ case ROUNDINGMODE:
+ new (&d_rm) RoundingMode;
+ d_rm = other.d_rm;
+ break;
+ case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break;
case RATIONAL:
new (&d_rat) Rational;
d_rat = other.d_rat;
@@ -80,6 +91,16 @@ EvalResult::~EvalResult()
d_bv.~BitVector();
break;
}
+ case ROUNDINGMODE:
+ {
+ d_rm.~RoundingMode();
+ break;
+ }
+ case FLOATINGPOINT:
+ {
+ d_fp.~FloatingPoint();
+ break;
+ }
case RATIONAL:
{
d_rat.~Rational();
@@ -89,9 +110,9 @@ EvalResult::~EvalResult()
{
d_str.~String();
break;
-
- default: break;
}
+
+ default: break;
}
}
@@ -102,6 +123,8 @@ Node EvalResult::toNode() const
{
case EvalResult::BOOL: return nm->mkConst(d_bool);
case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
+ case EvalResult::ROUNDINGMODE: return nm->mkConst(d_rm);
+ case EvalResult::FLOATINGPOINT: return nm->mkConst(d_fp);
case EvalResult::RATIONAL: return nm->mkConst(d_rat);
case EvalResult::STRING: return nm->mkConst(d_str);
default:
@@ -121,7 +144,12 @@ Node Evaluator::eval(TNode n,
{
Trace("evaluator") << "Evaluating " << n << " under substitution " << args
<< " " << vals << std::endl;
- return evalInternal(n, args, vals).toNode();
+ Node res = n.isConst() ? Node(n) : evalInternal(n, args, vals).toNode();
+ Assert(res.isNull()
+ || res
+ == Rewriter::rewrite(n.substitute(
+ args.begin(), args.end(), vals.begin(), vals.end())));
+ return res;
}
EvalResult Evaluator::evalInternal(TNode n,
@@ -524,6 +552,460 @@ EvalResult Evaluator::evalInternal(TNode n,
break;
}
+ case kind::CONST_FLOATINGPOINT:
+ {
+ results[currNode] = EvalResult(currNodeVal.getConst<FloatingPoint>());
+ break;
+ }
+
+ case kind::CONST_ROUNDINGMODE:
+ {
+ results[currNode] = EvalResult(currNodeVal.getConst<RoundingMode>());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_FP:
+ {
+ const BitVector& sign = results[currNode[0]].d_bv;
+ const BitVector& exp = results[currNode[1]].d_bv;
+ const BitVector& sig = results[currNode[2]].d_bv;
+ Assert(sign.getSize() == 1);
+ unsigned e = exp.getSize();
+ unsigned s = sig.getSize() + 1;
+ results[currNode] =
+ EvalResult(FloatingPoint(e, s, sign.concat(exp.concat(sig))));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_EQ:
+ case kind::FLOATINGPOINT_GEQ:
+ case kind::FLOATINGPOINT_GT:
+ {
+ // These kinds should have been removed by earlier rewrites. In
+ // debug, we would like to know about this. In production, we can
+ // just return an invalid result and fall back to the full rewriter.
+ Assert(false);
+ return EvalResult();
+ }
+
+ case kind::FLOATINGPOINT_ABS:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.absolute());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_NEG:
+ {
+ results[currNode] = EvalResult(results[currNode[0]].d_fp.negate());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_PLUS:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.plus(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_SUB:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.sub(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MULT:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.mult(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_DIV:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.div(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_FMA:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const FloatingPoint& fac = results[currNode[2]].d_fp;
+ const FloatingPoint& add = results[currNode[3]].d_fp;
+ results[currNode] = EvalResult(val.fma(rm, fac, add));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_SQRT:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(val.sqrt(rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_REM:
+ {
+ const FloatingPoint& lhs = results.at(currNode[0]).d_fp;
+ const FloatingPoint& rhs = results.at(currNode[1]).d_fp;
+ FloatingPoint res = lhs.rem(rhs);
+ results[currNode] = EvalResult(lhs.rem(rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_RTI:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(val.rti(rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MIN:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialFloatingPoint res(lhs.min(rhs));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MAX:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialFloatingPoint res(lhs.max(rhs));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MIN_TOTAL:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ const BitVector& zeroCaseBv = results[currNode[2]].d_bv;
+ Assert(zeroCaseBv.getSize() == 1);
+ bool zeroCase = zeroCaseBv.isBitSet(0);
+ results[currNode] = EvalResult(lhs.minTotal(rhs, zeroCase));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MAX_TOTAL:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ const BitVector& zeroCaseBv = results[currNode[2]].d_bv;
+ Assert(zeroCaseBv.getSize() == 1);
+ bool zeroCase = zeroCaseBv.isBitSet(0);
+ results[currNode] = EvalResult(lhs.maxTotal(rhs, zeroCase));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_LT:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(lhs < rhs);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNormal());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISSN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNormal());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISZ:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isZero());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISINF:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isInfinite());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISNAN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNaN());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISNEG:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNegative());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISPOS:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isPositive());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPIEEEBitVector& param =
+ op.getConst<FloatingPointToFPIEEEBitVector>();
+ const BitVector& val = results[currNode[0]].d_bv;
+ results[currNode] = EvalResult(
+ FloatingPoint(param.t.exponent(), param.t.significand(), val));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPFloatingPoint>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ results[currNode] = EvalResult(val.convert(param.t, rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_REAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPFloatingPoint>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ results[currNode] = EvalResult(val.convert(param.t, rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPSignedBitVector>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const BitVector& val = results[currNode[1]].d_bv;
+
+ results[currNode] = EvalResult(FloatingPoint(param.t, rm, val, true));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPUnsignedBitVector>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const BitVector& val = results[currNode[1]].d_bv;
+
+ results[currNode] =
+ EvalResult(FloatingPoint(param.t, rm, val, false));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_UBV:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToUBV& param =
+ op.getOperator().getConst<FloatingPointToUBV>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialBitVector res(
+ val.convertToBV(param.bvs, rm, false));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_UBV_TOTAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToUBVTotal& param =
+ op.getOperator().getConst<FloatingPointToUBVTotal>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const BitVector& partial = results[currNode[2]].d_bv;
+
+ results[currNode] =
+ EvalResult(val.convertToBVTotal(param.bvs, rm, false, partial));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_SBV:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToSBV& param =
+ op.getOperator().getConst<FloatingPointToSBV>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialBitVector res(
+ val.convertToBV(param.bvs, rm, true));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_SBV_TOTAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToSBVTotal& param =
+ op.getConst<FloatingPointToSBVTotal>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const BitVector& partial = results[currNode[2]].d_bv;
+
+ results[currNode] =
+ EvalResult(val.convertToBVTotal(param.bvs, rm, true, partial));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_REAL:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+
+ FloatingPoint::PartialRational res(val.convertToRational());
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_REAL_TOTAL:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ const Rational& partial = results[currNode[1]].d_rat;
+
+ results[currNode] = EvalResult(val.convertToRationalTotal(partial));
+ break;
+ }
+
+#ifdef CVC4_USE_SYMFPU
+ case kind::FLOATINGPOINT_COMPONENT_NAN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().nan);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_INF:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().inf);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_ZERO:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().zero);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_SIGN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().sign);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_EXPONENT:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().exponent);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().significand);
+ break;
+ }
+#else
+ case kind::FLOATINGPOINT_COMPONENT_NAN:
+ case kind::FLOATINGPOINT_COMPONENT_INF:
+ case kind::FLOATINGPOINT_COMPONENT_ZERO:
+ case kind::FLOATINGPOINT_COMPONENT_SIGN:
+ case kind::FLOATINGPOINT_COMPONENT_EXPONENT:
+ case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND:
+ {
+ // symfpu is required for these operators
+ return EvalResult();
+ }
+#endif
+
case kind::EQUAL:
{
EvalResult lhs = results[currNode[0]];
@@ -543,6 +1025,18 @@ EvalResult Evaluator::evalInternal(TNode n,
break;
}
+ case EvalResult::ROUNDINGMODE:
+ {
+ results[currNode] = EvalResult(lhs.d_rm == rhs.d_rm);
+ break;
+ }
+
+ case EvalResult::FLOATINGPOINT:
+ {
+ results[currNode] = EvalResult(lhs.d_fp == rhs.d_fp);
+ break;
+ }
+
case EvalResult::RATIONAL:
{
results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback