diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2018-07-09 16:26:01 -0700 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2018-07-13 17:39:52 -0700 |
commit | 0631bbb1716dc7343cfb0c3a4b447c7e667dc5d2 (patch) | |
tree | 660f2fa6e70991d004c80c6eace2cd043a98bae3 | |
parent | 86d9ba4431108e1fd89639e23857631a7380a005 (diff) |
Add floating-point support in evaluatoreval_fp
Currently, the operations implemented by the FloatingPoint class are
fairly slow, so it is not always beneficial to use the Evaluator as
opposed to the rewriter because the rewriter does more aggressive
caching. Thus, the evaluator is disabled by default for floating-point
logics. Note that this commit also adds a check for the --sygus-eval-opt
flag in SygusSampler::evaluate() (previously, it was always used in that
method).
-rw-r--r-- | src/smt/smt_engine.cpp | 8 | ||||
-rw-r--r-- | src/theory/evaluator.cpp | 500 | ||||
-rw-r--r-- | src/theory/evaluator.h | 13 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus/term_database_sygus.cpp | 5 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus_sampler.cpp | 21 |
5 files changed, 531 insertions, 16 deletions
diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 22916e354..457e9a486 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -1312,6 +1312,14 @@ void SmtEngine::setDefaults() { options::cbqi.set(true); } } + if (d_logic.isTheoryEnabled(THEORY_FP) + && !options::sygusEvalOpt.wasSetByUser()) + { + Notice() << "SmtEngine: turning off evaluator for floating-point logics " + "for better performance" + << endl; + options::sygusEvalOpt.set(false); + } } if (options::bitblastMode() == theory::bv::BITBLAST_MODE_EAGER) diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index ca2140ed5..925e6b052 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -17,6 +17,7 @@ #include "theory/evaluator.h" #include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" #include "theory/theory.h" #include "util/integer.h" @@ -33,6 +34,11 @@ EvalResult::EvalResult(const EvalResult& other) new (&d_bv) BitVector; d_bv = other.d_bv; break; + case ROUNDINGMODE: + new (&d_rm) RoundingMode; + d_rm = other.d_rm; + break; + case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break; case RATIONAL: new (&d_rat) Rational; d_rat = other.d_rat; @@ -57,6 +63,11 @@ EvalResult& EvalResult::operator=(const EvalResult& other) new (&d_bv) BitVector; d_bv = other.d_bv; break; + case ROUNDINGMODE: + new (&d_rm) RoundingMode; + d_rm = other.d_rm; + break; + case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break; case RATIONAL: new (&d_rat) Rational; d_rat = other.d_rat; @@ -80,6 +91,16 @@ EvalResult::~EvalResult() d_bv.~BitVector(); break; } + case ROUNDINGMODE: + { + d_rm.~RoundingMode(); + break; + } + case FLOATINGPOINT: + { + d_fp.~FloatingPoint(); + break; + } case RATIONAL: { d_rat.~Rational(); @@ -89,9 +110,9 @@ EvalResult::~EvalResult() { d_str.~String(); break; - - default: break; } + + default: break; } } @@ -102,6 +123,8 @@ Node EvalResult::toNode() const { case EvalResult::BOOL: return nm->mkConst(d_bool); case EvalResult::BITVECTOR: return nm->mkConst(d_bv); + case EvalResult::ROUNDINGMODE: return nm->mkConst(d_rm); + case EvalResult::FLOATINGPOINT: return nm->mkConst(d_fp); case EvalResult::RATIONAL: return nm->mkConst(d_rat); case EvalResult::STRING: return nm->mkConst(d_str); default: @@ -121,7 +144,12 @@ Node Evaluator::eval(TNode n, { Trace("evaluator") << "Evaluating " << n << " under substitution " << args << " " << vals << std::endl; - return evalInternal(n, args, vals).toNode(); + Node res = n.isConst() ? Node(n) : evalInternal(n, args, vals).toNode(); + Assert(res.isNull() + || res + == Rewriter::rewrite(n.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); + return res; } EvalResult Evaluator::evalInternal(TNode n, @@ -524,6 +552,460 @@ EvalResult Evaluator::evalInternal(TNode n, break; } + case kind::CONST_FLOATINGPOINT: + { + results[currNode] = EvalResult(currNodeVal.getConst<FloatingPoint>()); + break; + } + + case kind::CONST_ROUNDINGMODE: + { + results[currNode] = EvalResult(currNodeVal.getConst<RoundingMode>()); + break; + } + + case kind::FLOATINGPOINT_FP: + { + const BitVector& sign = results[currNode[0]].d_bv; + const BitVector& exp = results[currNode[1]].d_bv; + const BitVector& sig = results[currNode[2]].d_bv; + Assert(sign.getSize() == 1); + unsigned e = exp.getSize(); + unsigned s = sig.getSize() + 1; + results[currNode] = + EvalResult(FloatingPoint(e, s, sign.concat(exp.concat(sig)))); + break; + } + + case kind::FLOATINGPOINT_EQ: + case kind::FLOATINGPOINT_GEQ: + case kind::FLOATINGPOINT_GT: + { + // These kinds should have been removed by earlier rewrites. In + // debug, we would like to know about this. In production, we can + // just return an invalid result and fall back to the full rewriter. + Assert(false); + return EvalResult(); + } + + case kind::FLOATINGPOINT_ABS: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.absolute()); + break; + } + + case kind::FLOATINGPOINT_NEG: + { + results[currNode] = EvalResult(results[currNode[0]].d_fp.negate()); + break; + } + + case kind::FLOATINGPOINT_PLUS: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.plus(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_SUB: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.sub(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_MULT: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.mult(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_DIV: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& lhs = results[currNode[1]].d_fp; + const FloatingPoint& rhs = results[currNode[2]].d_fp; + results[currNode] = EvalResult(lhs.div(rm, rhs)); + break; + } + + case kind::FLOATINGPOINT_FMA: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const FloatingPoint& fac = results[currNode[2]].d_fp; + const FloatingPoint& add = results[currNode[3]].d_fp; + results[currNode] = EvalResult(val.fma(rm, fac, add)); + break; + } + + case kind::FLOATINGPOINT_SQRT: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + results[currNode] = EvalResult(val.sqrt(rm)); + break; + } + + case kind::FLOATINGPOINT_REM: + { + const FloatingPoint& lhs = results.at(currNode[0]).d_fp; + const FloatingPoint& rhs = results.at(currNode[1]).d_fp; + FloatingPoint res = lhs.rem(rhs); + results[currNode] = EvalResult(lhs.rem(rhs)); + break; + } + + case kind::FLOATINGPOINT_RTI: + { + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + results[currNode] = EvalResult(val.rti(rm)); + break; + } + + case kind::FLOATINGPOINT_MIN: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + + FloatingPoint::PartialFloatingPoint res(lhs.min(rhs)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_MAX: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + + FloatingPoint::PartialFloatingPoint res(lhs.max(rhs)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_MIN_TOTAL: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + const BitVector& zeroCaseBv = results[currNode[2]].d_bv; + Assert(zeroCaseBv.getSize() == 1); + bool zeroCase = zeroCaseBv.isBitSet(0); + results[currNode] = EvalResult(lhs.minTotal(rhs, zeroCase)); + break; + } + + case kind::FLOATINGPOINT_MAX_TOTAL: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + const BitVector& zeroCaseBv = results[currNode[2]].d_bv; + Assert(zeroCaseBv.getSize() == 1); + bool zeroCase = zeroCaseBv.isBitSet(0); + results[currNode] = EvalResult(lhs.maxTotal(rhs, zeroCase)); + break; + } + + case kind::FLOATINGPOINT_LT: + { + const FloatingPoint& lhs = results[currNode[0]].d_fp; + const FloatingPoint& rhs = results[currNode[1]].d_fp; + results[currNode] = EvalResult(lhs < rhs); + break; + } + + case kind::FLOATINGPOINT_ISN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNormal()); + break; + } + + case kind::FLOATINGPOINT_ISSN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNormal()); + break; + } + + case kind::FLOATINGPOINT_ISZ: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isZero()); + break; + } + + case kind::FLOATINGPOINT_ISINF: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isInfinite()); + break; + } + + case kind::FLOATINGPOINT_ISNAN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNaN()); + break; + } + + case kind::FLOATINGPOINT_ISNEG: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isNegative()); + break; + } + + case kind::FLOATINGPOINT_ISPOS: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.isPositive()); + break; + } + + case kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPIEEEBitVector& param = + op.getConst<FloatingPointToFPIEEEBitVector>(); + const BitVector& val = results[currNode[0]].d_bv; + results[currNode] = EvalResult( + FloatingPoint(param.t.exponent(), param.t.significand(), val)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPFloatingPoint>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + results[currNode] = EvalResult(val.convert(param.t, rm)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_REAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPFloatingPoint>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + results[currNode] = EvalResult(val.convert(param.t, rm)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPSignedBitVector>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const BitVector& val = results[currNode[1]].d_bv; + + results[currNode] = EvalResult(FloatingPoint(param.t, rm, val, true)); + break; + } + + case kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR: + { + TNode op = currNode.getOperator(); + const FloatingPointToFPFloatingPoint& param = + op.getConst<FloatingPointToFPUnsignedBitVector>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const BitVector& val = results[currNode[1]].d_bv; + + results[currNode] = + EvalResult(FloatingPoint(param.t, rm, val, false)); + break; + } + + case kind::FLOATINGPOINT_TO_UBV: + { + TNode op = currNode.getOperator(); + const FloatingPointToUBV& param = + op.getOperator().getConst<FloatingPointToUBV>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + FloatingPoint::PartialBitVector res( + val.convertToBV(param.bvs, rm, false)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + + break; + } + + case kind::FLOATINGPOINT_TO_UBV_TOTAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToUBVTotal& param = + op.getOperator().getConst<FloatingPointToUBVTotal>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const BitVector& partial = results[currNode[2]].d_bv; + + results[currNode] = + EvalResult(val.convertToBVTotal(param.bvs, rm, false, partial)); + break; + } + + case kind::FLOATINGPOINT_TO_SBV: + { + TNode op = currNode.getOperator(); + const FloatingPointToSBV& param = + op.getOperator().getConst<FloatingPointToSBV>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + + FloatingPoint::PartialBitVector res( + val.convertToBV(param.bvs, rm, true)); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_TO_SBV_TOTAL: + { + TNode op = currNode.getOperator(); + const FloatingPointToSBVTotal& param = + op.getConst<FloatingPointToSBVTotal>(); + const RoundingMode& rm = results[currNode[0]].d_rm; + const FloatingPoint& val = results[currNode[1]].d_fp; + const BitVector& partial = results[currNode[2]].d_bv; + + results[currNode] = + EvalResult(val.convertToBVTotal(param.bvs, rm, true, partial)); + break; + } + + case kind::FLOATINGPOINT_TO_REAL: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + + FloatingPoint::PartialRational res(val.convertToRational()); + + if (res.second) + { + results[currNode] = EvalResult(res.first); + } + else + { + // Can't constant fold the underspecified case + return EvalResult(); + } + break; + } + + case kind::FLOATINGPOINT_TO_REAL_TOTAL: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + const Rational& partial = results[currNode[1]].d_rat; + + results[currNode] = EvalResult(val.convertToRationalTotal(partial)); + break; + } + +#ifdef CVC4_USE_SYMFPU + case kind::FLOATINGPOINT_COMPONENT_NAN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().nan); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_INF: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().inf); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_ZERO: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().zero); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_SIGN: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().sign); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_EXPONENT: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().exponent); + break; + } + + case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND: + { + const FloatingPoint& val = results[currNode[0]].d_fp; + results[currNode] = EvalResult(val.getLiteral().significand); + break; + } +#else + case kind::FLOATINGPOINT_COMPONENT_NAN: + case kind::FLOATINGPOINT_COMPONENT_INF: + case kind::FLOATINGPOINT_COMPONENT_ZERO: + case kind::FLOATINGPOINT_COMPONENT_SIGN: + case kind::FLOATINGPOINT_COMPONENT_EXPONENT: + case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND: + { + // symfpu is required for these operators + return EvalResult(); + } +#endif + case kind::EQUAL: { EvalResult lhs = results[currNode[0]]; @@ -543,6 +1025,18 @@ EvalResult Evaluator::evalInternal(TNode n, break; } + case EvalResult::ROUNDINGMODE: + { + results[currNode] = EvalResult(lhs.d_rm == rhs.d_rm); + break; + } + + case EvalResult::FLOATINGPOINT: + { + results[currNode] = EvalResult(lhs.d_fp == rhs.d_fp); + break; + } + case EvalResult::RATIONAL: { results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat); diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h index 0d7ddbec8..38030a52c 100644 --- a/src/theory/evaluator.h +++ b/src/theory/evaluator.h @@ -26,6 +26,8 @@ #include "base/output.h" #include "expr/node.h" #include "util/bitvector.h" +#include "util/floatingpoint.h" +#include "util/hash.h" #include "util/rational.h" #include "util/regexp.h" @@ -43,6 +45,8 @@ struct EvalResult { BOOL, BITVECTOR, + ROUNDINGMODE, + FLOATINGPOINT, RATIONAL, STRING, INVALID @@ -53,6 +57,8 @@ struct EvalResult { bool d_bool; BitVector d_bv; + RoundingMode d_rm; + FloatingPoint d_fp; Rational d_rat; String d_str; }; @@ -61,7 +67,9 @@ struct EvalResult EvalResult() : d_tag(INVALID) {} EvalResult(bool b) : d_tag(BOOL), d_bool(b) {} EvalResult(const BitVector& bv) : d_tag(BITVECTOR), d_bv(bv) {} - EvalResult(const Rational& i) : d_tag(RATIONAL), d_rat(i) {} + EvalResult(const RoundingMode& rm) : d_tag(ROUNDINGMODE), d_rm(rm) {} + EvalResult(const FloatingPoint& fp) : d_tag(FLOATINGPOINT), d_fp(fp) {} + EvalResult(const Rational& r) : d_tag(RATIONAL), d_rat(r) {} EvalResult(const String& str) : d_tag(STRING), d_str(str) {} EvalResult& operator=(const EvalResult& other); @@ -88,6 +96,9 @@ class Evaluator * `args` and the corresponding values `vals`. The function returns a null * node if there is a subterm that is not constant under the substitution or * if an operator is not supported by the evaluator. + * + * Note: The evaluator expects that `n` has been rewritten to the point where + * it does not contain non-essential operators (e.g. FLOATINGPOINT_EQ). */ Node eval(TNode n, const std::vector<Node>& args, diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index c6976ac62..3ab021315 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -1572,11 +1572,6 @@ Node TermDbSygus::evaluateBuiltin(TypeNode tn, } if (!res.isNull()) { - Assert(res - == Rewriter::rewrite(bn.substitute(it->second.begin(), - it->second.end(), - args.begin(), - args.end()))); return res; } else diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index b1b21a53e..0ed7cea12 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -451,16 +451,23 @@ void SygusSampler::addSamplePoint(std::vector<Node>& pt) Node SygusSampler::evaluate(Node n, unsigned index) { Assert(index < d_samples.size()); - // use efficient rewrite for substitution + rewrite - Node ev = d_eval.eval(n, d_vars, d_samples[index]); - Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> "; - if (!ev.isNull()) + + std::vector<Node>& pt = d_samples[index]; + + Node ev; + if (options::sygusEvalOpt()) { - Trace("sygus-sample-ev") << ev << std::endl; - return ev; + // use efficient rewrite for substitution + rewrite + ev = d_eval.eval(n, d_vars, pt); + Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> "; + if (!ev.isNull()) + { + Trace("sygus-sample-ev") << ev << std::endl; + return ev; + } } + // substitution + rewrite - std::vector<Node>& pt = d_samples[index]; ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end()); ev = Rewriter::rewrite(ev); Trace("sygus-sample-ev") << ev << std::endl; |