summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2018-07-09 16:26:01 -0700
committerAndres Noetzli <andres.noetzli@gmail.com>2018-07-13 17:39:52 -0700
commit0631bbb1716dc7343cfb0c3a4b447c7e667dc5d2 (patch)
tree660f2fa6e70991d004c80c6eace2cd043a98bae3
parent86d9ba4431108e1fd89639e23857631a7380a005 (diff)
Add floating-point support in evaluatoreval_fp
Currently, the operations implemented by the FloatingPoint class are fairly slow, so it is not always beneficial to use the Evaluator as opposed to the rewriter because the rewriter does more aggressive caching. Thus, the evaluator is disabled by default for floating-point logics. Note that this commit also adds a check for the --sygus-eval-opt flag in SygusSampler::evaluate() (previously, it was always used in that method).
-rw-r--r--src/smt/smt_engine.cpp8
-rw-r--r--src/theory/evaluator.cpp500
-rw-r--r--src/theory/evaluator.h13
-rw-r--r--src/theory/quantifiers/sygus/term_database_sygus.cpp5
-rw-r--r--src/theory/quantifiers/sygus_sampler.cpp21
5 files changed, 531 insertions, 16 deletions
diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp
index 22916e354..457e9a486 100644
--- a/src/smt/smt_engine.cpp
+++ b/src/smt/smt_engine.cpp
@@ -1312,6 +1312,14 @@ void SmtEngine::setDefaults() {
options::cbqi.set(true);
}
}
+ if (d_logic.isTheoryEnabled(THEORY_FP)
+ && !options::sygusEvalOpt.wasSetByUser())
+ {
+ Notice() << "SmtEngine: turning off evaluator for floating-point logics "
+ "for better performance"
+ << endl;
+ options::sygusEvalOpt.set(false);
+ }
}
if (options::bitblastMode() == theory::bv::BITBLAST_MODE_EAGER)
diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp
index ca2140ed5..925e6b052 100644
--- a/src/theory/evaluator.cpp
+++ b/src/theory/evaluator.cpp
@@ -17,6 +17,7 @@
#include "theory/evaluator.h"
#include "theory/bv/theory_bv_utils.h"
+#include "theory/rewriter.h"
#include "theory/theory.h"
#include "util/integer.h"
@@ -33,6 +34,11 @@ EvalResult::EvalResult(const EvalResult& other)
new (&d_bv) BitVector;
d_bv = other.d_bv;
break;
+ case ROUNDINGMODE:
+ new (&d_rm) RoundingMode;
+ d_rm = other.d_rm;
+ break;
+ case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break;
case RATIONAL:
new (&d_rat) Rational;
d_rat = other.d_rat;
@@ -57,6 +63,11 @@ EvalResult& EvalResult::operator=(const EvalResult& other)
new (&d_bv) BitVector;
d_bv = other.d_bv;
break;
+ case ROUNDINGMODE:
+ new (&d_rm) RoundingMode;
+ d_rm = other.d_rm;
+ break;
+ case FLOATINGPOINT: new (&d_fp) FloatingPoint(other.d_fp); break;
case RATIONAL:
new (&d_rat) Rational;
d_rat = other.d_rat;
@@ -80,6 +91,16 @@ EvalResult::~EvalResult()
d_bv.~BitVector();
break;
}
+ case ROUNDINGMODE:
+ {
+ d_rm.~RoundingMode();
+ break;
+ }
+ case FLOATINGPOINT:
+ {
+ d_fp.~FloatingPoint();
+ break;
+ }
case RATIONAL:
{
d_rat.~Rational();
@@ -89,9 +110,9 @@ EvalResult::~EvalResult()
{
d_str.~String();
break;
-
- default: break;
}
+
+ default: break;
}
}
@@ -102,6 +123,8 @@ Node EvalResult::toNode() const
{
case EvalResult::BOOL: return nm->mkConst(d_bool);
case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
+ case EvalResult::ROUNDINGMODE: return nm->mkConst(d_rm);
+ case EvalResult::FLOATINGPOINT: return nm->mkConst(d_fp);
case EvalResult::RATIONAL: return nm->mkConst(d_rat);
case EvalResult::STRING: return nm->mkConst(d_str);
default:
@@ -121,7 +144,12 @@ Node Evaluator::eval(TNode n,
{
Trace("evaluator") << "Evaluating " << n << " under substitution " << args
<< " " << vals << std::endl;
- return evalInternal(n, args, vals).toNode();
+ Node res = n.isConst() ? Node(n) : evalInternal(n, args, vals).toNode();
+ Assert(res.isNull()
+ || res
+ == Rewriter::rewrite(n.substitute(
+ args.begin(), args.end(), vals.begin(), vals.end())));
+ return res;
}
EvalResult Evaluator::evalInternal(TNode n,
@@ -524,6 +552,460 @@ EvalResult Evaluator::evalInternal(TNode n,
break;
}
+ case kind::CONST_FLOATINGPOINT:
+ {
+ results[currNode] = EvalResult(currNodeVal.getConst<FloatingPoint>());
+ break;
+ }
+
+ case kind::CONST_ROUNDINGMODE:
+ {
+ results[currNode] = EvalResult(currNodeVal.getConst<RoundingMode>());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_FP:
+ {
+ const BitVector& sign = results[currNode[0]].d_bv;
+ const BitVector& exp = results[currNode[1]].d_bv;
+ const BitVector& sig = results[currNode[2]].d_bv;
+ Assert(sign.getSize() == 1);
+ unsigned e = exp.getSize();
+ unsigned s = sig.getSize() + 1;
+ results[currNode] =
+ EvalResult(FloatingPoint(e, s, sign.concat(exp.concat(sig))));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_EQ:
+ case kind::FLOATINGPOINT_GEQ:
+ case kind::FLOATINGPOINT_GT:
+ {
+ // These kinds should have been removed by earlier rewrites. In
+ // debug, we would like to know about this. In production, we can
+ // just return an invalid result and fall back to the full rewriter.
+ Assert(false);
+ return EvalResult();
+ }
+
+ case kind::FLOATINGPOINT_ABS:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.absolute());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_NEG:
+ {
+ results[currNode] = EvalResult(results[currNode[0]].d_fp.negate());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_PLUS:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.plus(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_SUB:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.sub(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MULT:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.mult(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_DIV:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& lhs = results[currNode[1]].d_fp;
+ const FloatingPoint& rhs = results[currNode[2]].d_fp;
+ results[currNode] = EvalResult(lhs.div(rm, rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_FMA:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const FloatingPoint& fac = results[currNode[2]].d_fp;
+ const FloatingPoint& add = results[currNode[3]].d_fp;
+ results[currNode] = EvalResult(val.fma(rm, fac, add));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_SQRT:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(val.sqrt(rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_REM:
+ {
+ const FloatingPoint& lhs = results.at(currNode[0]).d_fp;
+ const FloatingPoint& rhs = results.at(currNode[1]).d_fp;
+ FloatingPoint res = lhs.rem(rhs);
+ results[currNode] = EvalResult(lhs.rem(rhs));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_RTI:
+ {
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(val.rti(rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MIN:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialFloatingPoint res(lhs.min(rhs));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MAX:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialFloatingPoint res(lhs.max(rhs));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MIN_TOTAL:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ const BitVector& zeroCaseBv = results[currNode[2]].d_bv;
+ Assert(zeroCaseBv.getSize() == 1);
+ bool zeroCase = zeroCaseBv.isBitSet(0);
+ results[currNode] = EvalResult(lhs.minTotal(rhs, zeroCase));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_MAX_TOTAL:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ const BitVector& zeroCaseBv = results[currNode[2]].d_bv;
+ Assert(zeroCaseBv.getSize() == 1);
+ bool zeroCase = zeroCaseBv.isBitSet(0);
+ results[currNode] = EvalResult(lhs.maxTotal(rhs, zeroCase));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_LT:
+ {
+ const FloatingPoint& lhs = results[currNode[0]].d_fp;
+ const FloatingPoint& rhs = results[currNode[1]].d_fp;
+ results[currNode] = EvalResult(lhs < rhs);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNormal());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISSN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNormal());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISZ:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isZero());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISINF:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isInfinite());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISNAN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNaN());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISNEG:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isNegative());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_ISPOS:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.isPositive());
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPIEEEBitVector& param =
+ op.getConst<FloatingPointToFPIEEEBitVector>();
+ const BitVector& val = results[currNode[0]].d_bv;
+ results[currNode] = EvalResult(
+ FloatingPoint(param.t.exponent(), param.t.significand(), val));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPFloatingPoint>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ results[currNode] = EvalResult(val.convert(param.t, rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_REAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPFloatingPoint>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ results[currNode] = EvalResult(val.convert(param.t, rm));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPSignedBitVector>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const BitVector& val = results[currNode[1]].d_bv;
+
+ results[currNode] = EvalResult(FloatingPoint(param.t, rm, val, true));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToFPFloatingPoint& param =
+ op.getConst<FloatingPointToFPUnsignedBitVector>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const BitVector& val = results[currNode[1]].d_bv;
+
+ results[currNode] =
+ EvalResult(FloatingPoint(param.t, rm, val, false));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_UBV:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToUBV& param =
+ op.getOperator().getConst<FloatingPointToUBV>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialBitVector res(
+ val.convertToBV(param.bvs, rm, false));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_UBV_TOTAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToUBVTotal& param =
+ op.getOperator().getConst<FloatingPointToUBVTotal>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const BitVector& partial = results[currNode[2]].d_bv;
+
+ results[currNode] =
+ EvalResult(val.convertToBVTotal(param.bvs, rm, false, partial));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_SBV:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToSBV& param =
+ op.getOperator().getConst<FloatingPointToSBV>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+
+ FloatingPoint::PartialBitVector res(
+ val.convertToBV(param.bvs, rm, true));
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_SBV_TOTAL:
+ {
+ TNode op = currNode.getOperator();
+ const FloatingPointToSBVTotal& param =
+ op.getConst<FloatingPointToSBVTotal>();
+ const RoundingMode& rm = results[currNode[0]].d_rm;
+ const FloatingPoint& val = results[currNode[1]].d_fp;
+ const BitVector& partial = results[currNode[2]].d_bv;
+
+ results[currNode] =
+ EvalResult(val.convertToBVTotal(param.bvs, rm, true, partial));
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_REAL:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+
+ FloatingPoint::PartialRational res(val.convertToRational());
+
+ if (res.second)
+ {
+ results[currNode] = EvalResult(res.first);
+ }
+ else
+ {
+ // Can't constant fold the underspecified case
+ return EvalResult();
+ }
+ break;
+ }
+
+ case kind::FLOATINGPOINT_TO_REAL_TOTAL:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ const Rational& partial = results[currNode[1]].d_rat;
+
+ results[currNode] = EvalResult(val.convertToRationalTotal(partial));
+ break;
+ }
+
+#ifdef CVC4_USE_SYMFPU
+ case kind::FLOATINGPOINT_COMPONENT_NAN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().nan);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_INF:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().inf);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_ZERO:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().zero);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_SIGN:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().sign);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_EXPONENT:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().exponent);
+ break;
+ }
+
+ case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND:
+ {
+ const FloatingPoint& val = results[currNode[0]].d_fp;
+ results[currNode] = EvalResult(val.getLiteral().significand);
+ break;
+ }
+#else
+ case kind::FLOATINGPOINT_COMPONENT_NAN:
+ case kind::FLOATINGPOINT_COMPONENT_INF:
+ case kind::FLOATINGPOINT_COMPONENT_ZERO:
+ case kind::FLOATINGPOINT_COMPONENT_SIGN:
+ case kind::FLOATINGPOINT_COMPONENT_EXPONENT:
+ case kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND:
+ {
+ // symfpu is required for these operators
+ return EvalResult();
+ }
+#endif
+
case kind::EQUAL:
{
EvalResult lhs = results[currNode[0]];
@@ -543,6 +1025,18 @@ EvalResult Evaluator::evalInternal(TNode n,
break;
}
+ case EvalResult::ROUNDINGMODE:
+ {
+ results[currNode] = EvalResult(lhs.d_rm == rhs.d_rm);
+ break;
+ }
+
+ case EvalResult::FLOATINGPOINT:
+ {
+ results[currNode] = EvalResult(lhs.d_fp == rhs.d_fp);
+ break;
+ }
+
case EvalResult::RATIONAL:
{
results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h
index 0d7ddbec8..38030a52c 100644
--- a/src/theory/evaluator.h
+++ b/src/theory/evaluator.h
@@ -26,6 +26,8 @@
#include "base/output.h"
#include "expr/node.h"
#include "util/bitvector.h"
+#include "util/floatingpoint.h"
+#include "util/hash.h"
#include "util/rational.h"
#include "util/regexp.h"
@@ -43,6 +45,8 @@ struct EvalResult
{
BOOL,
BITVECTOR,
+ ROUNDINGMODE,
+ FLOATINGPOINT,
RATIONAL,
STRING,
INVALID
@@ -53,6 +57,8 @@ struct EvalResult
{
bool d_bool;
BitVector d_bv;
+ RoundingMode d_rm;
+ FloatingPoint d_fp;
Rational d_rat;
String d_str;
};
@@ -61,7 +67,9 @@ struct EvalResult
EvalResult() : d_tag(INVALID) {}
EvalResult(bool b) : d_tag(BOOL), d_bool(b) {}
EvalResult(const BitVector& bv) : d_tag(BITVECTOR), d_bv(bv) {}
- EvalResult(const Rational& i) : d_tag(RATIONAL), d_rat(i) {}
+ EvalResult(const RoundingMode& rm) : d_tag(ROUNDINGMODE), d_rm(rm) {}
+ EvalResult(const FloatingPoint& fp) : d_tag(FLOATINGPOINT), d_fp(fp) {}
+ EvalResult(const Rational& r) : d_tag(RATIONAL), d_rat(r) {}
EvalResult(const String& str) : d_tag(STRING), d_str(str) {}
EvalResult& operator=(const EvalResult& other);
@@ -88,6 +96,9 @@ class Evaluator
* `args` and the corresponding values `vals`. The function returns a null
* node if there is a subterm that is not constant under the substitution or
* if an operator is not supported by the evaluator.
+ *
+ * Note: The evaluator expects that `n` has been rewritten to the point where
+ * it does not contain non-essential operators (e.g. FLOATINGPOINT_EQ).
*/
Node eval(TNode n,
const std::vector<Node>& args,
diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp
index c6976ac62..3ab021315 100644
--- a/src/theory/quantifiers/sygus/term_database_sygus.cpp
+++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp
@@ -1572,11 +1572,6 @@ Node TermDbSygus::evaluateBuiltin(TypeNode tn,
}
if (!res.isNull())
{
- Assert(res
- == Rewriter::rewrite(bn.substitute(it->second.begin(),
- it->second.end(),
- args.begin(),
- args.end())));
return res;
}
else
diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp
index b1b21a53e..0ed7cea12 100644
--- a/src/theory/quantifiers/sygus_sampler.cpp
+++ b/src/theory/quantifiers/sygus_sampler.cpp
@@ -451,16 +451,23 @@ void SygusSampler::addSamplePoint(std::vector<Node>& pt)
Node SygusSampler::evaluate(Node n, unsigned index)
{
Assert(index < d_samples.size());
- // use efficient rewrite for substitution + rewrite
- Node ev = d_eval.eval(n, d_vars, d_samples[index]);
- Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> ";
- if (!ev.isNull())
+
+ std::vector<Node>& pt = d_samples[index];
+
+ Node ev;
+ if (options::sygusEvalOpt())
{
- Trace("sygus-sample-ev") << ev << std::endl;
- return ev;
+ // use efficient rewrite for substitution + rewrite
+ ev = d_eval.eval(n, d_vars, pt);
+ Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> ";
+ if (!ev.isNull())
+ {
+ Trace("sygus-sample-ev") << ev << std::endl;
+ return ev;
+ }
}
+
// substitution + rewrite
- std::vector<Node>& pt = d_samples[index];
ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end());
ev = Rewriter::rewrite(ev);
Trace("sygus-sample-ev") << ev << std::endl;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback