diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/smt/boolean_terms.cpp | 70 | ||||
-rw-r--r-- | src/smt/boolean_terms.h | 2 |
2 files changed, 66 insertions, 6 deletions
diff --git a/src/smt/boolean_terms.cpp b/src/smt/boolean_terms.cpp index f8f2330e1..0063035ff 100644 --- a/src/smt/boolean_terms.cpp +++ b/src/smt/boolean_terms.cpp @@ -111,6 +111,52 @@ static inline bool isBoolean(TNode top, unsigned i) { } } +// This function rewrites "in" as an "as"---this is actually expected +// to be for model-substitution, so the "in" is a Boolean-term-converted +// node, and "as" is the original. See how it's used in function +// handling, below. +Node BooleanTermConverter::rewriteAs(TNode in, TypeNode as) throw() { + if(in.getType() == as) { + return in; + } + if(in.getType().isBoolean()) { + Assert(d_tt.getType() == as); + return NodeManager::currentNM()->mkNode(kind::ITE, in, d_tt, d_ff); + } + if(as.isBoolean() && in.getType().isBitVector() && in.getType().getBitVectorSize() == 1) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkConst(BitVector(1u, 1u)), in); + } + if(in.getType().isDatatype()) { + if(as.isBoolean() && in.getType().hasAttribute(BooleanTermAttr())) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, d_ttDt, in); + } + Assert(as.isDatatype()); + const Datatype* dt2 = &as.getDatatype(); + const Datatype* dt1 = d_datatypeCache[dt2]; + Assert(dt1 != NULL, "expected datatype in cache"); + Assert(*dt1 == in.getType().getDatatype(), "improper rewriteAs() between datatypes"); + Node out; + for(size_t i = 0; i < dt1->getNumConstructors(); ++i) { + DatatypeConstructor ctor = (*dt1)[i]; + NodeBuilder<> appctorb(kind::APPLY_CONSTRUCTOR); + appctorb << (*dt2)[i].getConstructor(); + for(size_t j = 0; j < ctor.getNumArgs(); ++j) { + appctorb << rewriteAs(NodeManager::currentNM()->mkNode(kind::APPLY_SELECTOR, ctor[j].getSelector(), in), TypeNode::fromType(SelectorType((*dt2)[i][j].getSelector().getType()).getRangeType())); + } + Node appctor = appctorb; + if(i == 0) { + out = appctor; + } else { + Node newOut = NodeManager::currentNM()->mkNode(kind::ITE, ctor.getTester(), appctor, out); + out = newOut; + } + } + return out; + } + + Unhandled(in); +} + const Datatype& BooleanTermConverter::convertDatatype(const Datatype& dt) throw() { const Datatype*& out = d_datatypeCache[&dt]; if(out != NULL) { @@ -392,12 +438,20 @@ Node BooleanTermConverter::rewriteBooleanTermsRec(TNode top, theory::TheoryId pa } else if(mk == kind::metakind::VARIABLE) { TypeNode t = top.getType(); if(t.isFunction()) { - for(unsigned i = 0; i < t.getNumChildren() - 1; ++i) { + for(unsigned i = 0; i < t.getNumChildren(); ++i) { TypeNode newType = convertType(t[i], false); - if(newType != t[i]) { + // is this the return type? (allowed to be Bool) + bool returnType = (i == t.getNumChildren() - 1); + if(newType != t[i] && (!t[i].isBoolean() || !returnType)) { vector<TypeNode> argTypes = t.getArgTypes(); - replace(argTypes.begin(), argTypes.end(), t[i], d_tt.getType()); - TypeNode newType = nm->mkFunctionType(argTypes, t.getRangeType()); + for(unsigned j = 0; j < argTypes.size(); ++j) { + argTypes[j] = convertType(argTypes[j], false); + } + TypeNode rangeType = t.getRangeType(); + if(!rangeType.isBoolean()) { + rangeType = convertType(rangeType, false); + } + TypeNode newType = nm->mkFunctionType(argTypes, rangeType); Node n = nm->mkSkolem(top.getAttribute(expr::VarNameAttr()) + "__boolterm__", newType, "a function introduced by Boolean-term conversion", NodeManager::SKOLEM_EXACT_NAME); @@ -409,14 +463,18 @@ Node BooleanTermConverter::rewriteBooleanTermsRec(TNode top, theory::TheoryId pa for(unsigned j = 0; j < t.getNumChildren() - 1; ++j) { Node var = nm->mkBoundVar(t[j]); boundVarsBuilder << var; - if(t[j].isBoolean()) { - bodyBuilder << nm->mkNode(kind::ITE, var, d_tt, d_ff); + if(t[j] != argTypes[j]) { + bodyBuilder << rewriteAs(var, argTypes[j]); } else { bodyBuilder << var; } } Node boundVars = boundVarsBuilder; Node body = bodyBuilder; + if(t.getRangeType() != rangeType) { + Node convbody = rewriteAs(body, t.getRangeType()); + body = convbody; + } Node lam = nm->mkNode(kind::LAMBDA, boundVars, body); Debug("boolean-terms") << "substituting " << top << " ==> " << lam << endl; d_smt.d_theoryEngine->getModel()->addSubstitution(top, lam); diff --git a/src/smt/boolean_terms.h b/src/smt/boolean_terms.h index e5f18b68d..904d47b5f 100644 --- a/src/smt/boolean_terms.h +++ b/src/smt/boolean_terms.h @@ -65,6 +65,8 @@ class BooleanTermConverter { /** The cache used during Boolean term datatype conversion */ BooleanTermDatatypeCache d_datatypeCache; + Node rewriteAs(TNode in, TypeNode as) throw(); + /** * Scan a datatype for and convert as necessary. */ |