diff options
Diffstat (limited to 'src/theory/evaluator.cpp')
-rw-r--r-- | src/theory/evaluator.cpp | 500 |
1 files changed, 497 insertions, 3 deletions
diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index ca2140ed5..925e6b052 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -17,6 +17,7 @@ #include "theory/evaluator.h" #include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" #include "theory/theory.h" #include "util/integer.h" @@ -33,6 +34,11 @@ EvalResult::EvalResult(const EvalResult& other) new (&d_bv) BitVector; d_bv = other.d_bv; break; + case ROUNDINGMODE: + new (&d_rm) RoundingMode; + d_rm = other.d_rm; + break; + case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break; case RATIONAL: new (&d_rat) Rational; d_rat = other.d_rat; @@ -57,6 +63,11 @@ EvalResult& EvalResult::operator=(const EvalResult& other) new (&d_bv) BitVector; d_bv = other.d_bv; break; + case ROUNDINGMODE: + new (&d_rm) RoundingMode; + d_rm = other.d_rm; + break; + case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break; case RATIONAL: new (&d_rat) Rational; d_rat = other.d_rat; @@ -80,6 +91,16 @@ EvalResult::~EvalResult() d_bv.~BitVector(); break; } + case ROUNDINGMODE: + { + d_rm.~RoundingMode(); + break; + } + case FLOATINGPOINT: + { + d_fp.~FloatingPoint(); + break; + } case RATIONAL: { d_rat.~Rational(); @@ -89,9 +110,9 @@ EvalResult::~EvalResult() { d_str.~String(); break; - - default: break; } + + default: break; } } @@ -102,6 +123,8 @@ Node EvalResult::toNode() const { case EvalResult::BOOL: return nm->mkConst(d_bool); case EvalResult::BITVECTOR: return nm->mkConst(d_bv); + case EvalResult::ROUNDINGMODE: return nm->mkConst(d_rm); + case EvalResult::FLOATINGPOINT: return nm->mkConst(d_fp); case EvalResult::RATIONAL: return nm->mkConst(d_rat); case EvalResult::STRING: return nm->mkConst(d_str); default: @@ -121,7 +144,12 @@ Node Evaluator::eval(TNode n, { Trace("evaluator") << "Evaluating " << n << " under substitution " << args << " " << vals << std::endl; - return evalInternal(n, args, vals).toNode(); + Node res = n.isConst() ? Node(n) : evalInternal(n, args, vals).toNode(); + Assert(res.isNull() + || res + == Rewriter::rewrite(n.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); + return res; } EvalResult Evaluator::evalInternal(TNode n, @@ -524,6 +552,460 @@ EvalResult Evaluator::evalInternal(TNode n, break; } + case kind::CONST_FLOATINGPOINT: + { + results[currNode] = EvalResult(currNodeVal.getConst<FloatingPoint>()); + break; + } + + case kind::CONST_ROUNDINGMODE: + { + results[currNode] = EvalResult(currNodeVal.getConst<RoundingMode>()); + break; + } + + case kind::FLOATINGPOINT_FP: + { + const BitVector& sign = results[currNode[0]].d_bv; + const BitVector& exp = results[currNode[1]].d_bv; + const BitVector& sig = results[currNode[2]].d_bv; + Assert(sign.getSize() == 1); + unsigned e = exp.getSize(); + unsigned s = sig.getSize() + 1; + results[currNode] = + EvalResult(FloatingPoint(e, s, sign.concat(exp.concat(sig)))); + break; + } + + case kind::FLOATINGPOINT_EQ: + case kind::FLOATINGPOINT_GEQ: + case kind::FLOATINGPOINT_GT: + { + // These kinds should have been removed by earlier rewrites. In + // debug, we would like to know about this. In production, we can + // just return an invalid result and fall back to the full rewriter. + Assert(false); + return EvalResult(); + } + + case kind::FLOATINGPOINT_ABS: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.absolute()); + break; + } + + case kind::FLOATINGPOINT_NEG: + { + results[currNode] = EvalResult(results[currNode[0]].d_fp.negate()); + break; + } + + case kind::FLOATINGPOINT_PLUS: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.plus(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_SUB: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.sub(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_MULT: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.mult(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_DIV: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.div(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_FMA: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const FloatingPoint& fac = results[currNode[2]].d_fp; + const FloatingPoint& add = results[currNode[3]].d_fp; + results[currNode] = EvalResult(val.fma(rm, fac, add)); + break; + } + + case kind::FLOATINGPOINT_SQRT: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + results[currNode] = EvalResult(val.sqrt(rm)); + break; + } + + case kind::FLOATINGPOINT_REM: + { + const FloatingPoint& lhs = results.at(currNode[0]).d_fp; + const FloatingPoint& rhs = results.at(currNode[1]).d_fp; + FloatingPoint res = lhs.rem(rhs); + results[currNode] = EvalResult(lhs.rem(rhs)); + break; + } + + case kind::FLOATINGPOINT_RTI: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + results[currNode] = EvalResult(val.rti(rm)); + break; + } + + case kind::FLOATINGPOINT_MIN: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + + FloatingPoint::PartialFloatingPoint res(lhs.min(rhs)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_MAX: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + + FloatingPoint::PartialFloatingPoint res(lhs.max(rhs)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_MIN_TOTAL: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + const BitVector& zeroCaseBv = results[currNode[2]].d_bv; + Assert(zeroCaseBv.getSize() == 1); + bool zeroCase = zeroCaseBv.isBitSet(0); + results[currNode] = EvalResult(lhs.minTotal(rhs, zeroCase)); + break; + } + + case kind::FLOATINGPOINT_MAX_TOTAL: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + const BitVector& zeroCaseBv = results[currNode[2]].d_bv; + Assert(zeroCaseBv.getSize() == 1); + bool zeroCase = zeroCaseBv.isBitSet(0); + results[currNode] = EvalResult(lhs.maxTotal(rhs, zeroCase)); + break; + } + + case kind::FLOATINGPOINT_LT: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + results[currNode] = EvalResult(lhs < rhs); + break; + } + + case kind::FLOATINGPOINT_ISN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNormal()); + break; + } + + case kind::FLOATINGPOINT_ISSN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNormal()); + break; + } + + case kind::FLOATINGPOINT_ISZ: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isZero()); + break; + } + + case kind::FLOATINGPOINT_ISINF: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isInfinite()); + break; + } + + case kind::FLOATINGPOINT_ISNAN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNaN()); + break; + } + + case kind::FLOATINGPOINT_ISNEG: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNegative()); + break; + } + + case kind::FLOATINGPOINT_ISPOS: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isPositive()); + break; + } + + case kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPIEEEBitVector& param = + op.getConst<FloatingPointToFPIEEEBitVector>(); + const BitVector& val = results[currNode[0]].d_bv; + results[currNode] = EvalResult( + FloatingPoint(param.t.exponent(), param.t.significand(), val)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPFloatingPoint>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + results[currNode] = EvalResult(val.convert(param.t, rm)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_REAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPFloatingPoint>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + results[currNode] = EvalResult(val.convert(param.t, rm)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPSignedBitVector>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const BitVector& val = results[currNode[1]].d_bv; + + results[currNode] = EvalResult(FloatingPoint(param.t, rm, val, true)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPUnsignedBitVector>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const BitVector& val = results[currNode[1]].d_bv; + + results[currNode] = + EvalResult(FloatingPoint(param.t, rm, val, false)); + break; + } + + case kind::FLOATINGPOINT_TO_UBV: + { + TNode op = currNode.getOperator(); + const FloatingPointToUBV& param = + op.getOperator().getConst<FloatingPointToUBV>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + FloatingPoint::PartialBitVector res( + val.convertToBV(param.bvs, rm, false)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + + break; + } + + case kind::FLOATINGPOINT_TO_UBV_TOTAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToUBVTotal& param = + op.getOperator().getConst<FloatingPointToUBVTotal>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const BitVector& partial = results[currNode[2]].d_bv; + + results[currNode] = + EvalResult(val.convertToBVTotal(param.bvs, rm, false, partial)); + break; + } + + case kind::FLOATINGPOINT_TO_SBV: + { + TNode op = currNode.getOperator(); + const FloatingPointToSBV& param = + op.getOperator().getConst<FloatingPointToSBV>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + FloatingPoint::PartialBitVector res( + val.convertToBV(param.bvs, rm, true)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_TO_SBV_TOTAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToSBVTotal& param = + op.getConst<FloatingPointToSBVTotal>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const BitVector& partial = results[currNode[2]].d_bv; + + results[currNode] = + EvalResult(val.convertToBVTotal(param.bvs, rm, true, partial)); + break; + } + + case kind::FLOATINGPOINT_TO_REAL: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + + FloatingPoint::PartialRational res(val.convertToRational()); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_TO_REAL_TOTAL: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + const Rational& partial = results[currNode[1]].d_rat; + + results[currNode] = EvalResult(val.convertToRationalTotal(partial)); + break; + } + +#ifdef CVC4_USE_SYMFPU + case kind::FLOATINGPOINT_COMPONENT_NAN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().nan); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_INF: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().inf); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_ZERO: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().zero); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_SIGN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().sign); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_EXPONENT: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().exponent); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().significand); + break; + } +#else + case kind::FLOATINGPOINT_COMPONENT_NAN: + case kind::FLOATINGPOINT_COMPONENT_INF: + case kind::FLOATINGPOINT_COMPONENT_ZERO: + case kind::FLOATINGPOINT_COMPONENT_SIGN: + case kind::FLOATINGPOINT_COMPONENT_EXPONENT: + case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND: + { + // symfpu is required for these operators + return EvalResult(); + } +#endif + case kind::EQUAL: { EvalResult lhs = results[currNode[0]]; @@ -543,6 +1025,18 @@ EvalResult Evaluator::evalInternal(TNode n, break; } + case EvalResult::ROUNDINGMODE: + { + results[currNode] = EvalResult(lhs.d_rm == rhs.d_rm); + break; + } + + case EvalResult::FLOATINGPOINT: + { + results[currNode] = EvalResult(lhs.d_fp == rhs.d_fp); + break; + } + case EvalResult::RATIONAL: { results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat); |