diff options
Diffstat (limited to 'src/theory/uf/ho_extension.cpp')
-rw-r--r-- | src/theory/uf/ho_extension.cpp | 274 |
1 files changed, 260 insertions, 14 deletions
diff --git a/src/theory/uf/ho_extension.cpp b/src/theory/uf/ho_extension.cpp index 96029eab8..aec223d66 100644 --- a/src/theory/uf/ho_extension.cpp +++ b/src/theory/uf/ho_extension.cpp @@ -19,6 +19,7 @@ #include "expr/skolem_manager.h" #include "options/uf_options.h" #include "theory/theory_model.h" +#include "theory/uf/lambda_lift.h" #include "theory/uf/theory_uf_rewriter.h" using namespace std; @@ -30,31 +31,83 @@ namespace uf { HoExtension::HoExtension(Env& env, TheoryState& state, - TheoryInferenceManager& im) + TheoryInferenceManager& im, + LambdaLift& ll) : EnvObj(env), d_state(state), d_im(im), + d_ll(ll), d_extensionality(userContext()), + d_cachedLemmas(userContext()), d_uf_std_skolem(userContext()) { d_true = NodeManager::currentNM()->mkConst(true); } -Node HoExtension::ppRewrite(Node node) +TrustNode HoExtension::ppRewrite(Node node, std::vector<SkolemLemma>& lems) { - // convert HO_APPLY to APPLY_UF if fully applied - if (node.getKind() == HO_APPLY) + Kind k = node.getKind(); + if (k == HO_APPLY) { + // convert HO_APPLY to APPLY_UF if fully applied if (node[0].getType().getNumChildren() == 2) { Trace("uf-ho") << "uf-ho : expanding definition : " << node << std::endl; Node ret = getApplyUfForHoApply(node); Trace("uf-ho") << "uf-ho : ppRewrite : " << node << " to " << ret << std::endl; - return ret; + return TrustNode::mkTrustRewrite(node, ret); } + // partial beta reduction + // f ---> (lambda ((x Int) (y Int)) s[x, y]) then (@ f t) is preprocessed + // to (lambda ((y Int)) s[t, y]). + if (options().uf.ufHoLazyLambdaLift) + { + Node op = node[0]; + Node opl = d_ll.getLambdaFor(op); + if (!opl.isNull()) + { + NodeManager* nm = NodeManager::currentNM(); + Node app = nm->mkNode(HO_APPLY, opl, node[1]); + app = rewrite(app); + Trace("uf-lazy-ll") + << "Partial beta reduce: " << node << " -> " << app << std::endl; + return TrustNode::mkTrustRewrite(node, app, nullptr); + } + } + } + else if (k == APPLY_UF) + { + // Say (lambda ((x Int)) t[x]) occurs in the input. We replace this + // by k during ppRewrite. In the following, if we see (k s), we replace + // it by t[s]. This maintains the invariant that the *only* occurences + // of k are as arguments to other functions; k is not applied + // in any preprocessed constraints. + if (options().uf.ufHoLazyLambdaLift) + { + // if an application of the lambda lifted function, do beta reduction + // immediately + Node op = node.getOperator(); + Node opl = d_ll.getLambdaFor(op); + if (!opl.isNull()) + { + Assert(opl.getKind() == LAMBDA); + std::vector<Node> args(node.begin(), node.end()); + Node app = d_ll.betaReduce(opl, args); + Trace("uf-lazy-ll") + << "Beta reduce: " << node << " -> " << app << std::endl; + return TrustNode::mkTrustRewrite(node, app, nullptr); + } + } + } + else if (k == kind::LAMBDA) + { + Trace("uf-lazy-ll") << "Preprocess lambda: " << node << std::endl; + TrustNode skTrn = d_ll.ppRewrite(node, lems); + Trace("uf-lazy-ll") << "...return " << skTrn.getNode() << std::endl; + return skTrn; } - return node; + return TrustNode::null(); } Node HoExtension::getExtensionalityDeq(TNode deq, bool isCached) @@ -217,7 +270,7 @@ unsigned HoExtension::checkExtensionality(TheoryModel* m) { Node eqc = (*eqcs_i); TypeNode tn = eqc.getType(); - if (tn.isFunction()) + if (tn.isFunction() && d_lambdaEqc.find(eqc) == d_lambdaEqc.end()) { hasFunctions = true; // if during collect model, must have an infinite type @@ -413,6 +466,168 @@ unsigned HoExtension::checkAppCompletion() return 0; } +unsigned HoExtension::checkLazyLambda() +{ + if (!options().uf.ufHoLazyLambdaLift) + { + // no lambdas are lazily lifted + return 0; + } + Trace("uf-ho") << "HoExtension::checkLazyLambda..." << std::endl; + NodeManager* nm = NodeManager::currentNM(); + unsigned numLemmas = 0; + d_lambdaEqc.clear(); + eq::EqualityEngine* ee = d_state.getEqualityEngine(); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee); + // normal functions equated to lambda functions + std::unordered_set<Node> normalEqFuns; + // mapping from functions to terms + while (!eqcs_i.isFinished()) + { + Node eqc = (*eqcs_i); + ++eqcs_i; + if (!eqc.getType().isFunction()) + { + continue; + } + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee); + Node lamRep; // the first lambda function we encounter in the equivalence + // class + Node lamRepLam; + std::unordered_set<Node> normalEqFunWait; + while (!eqc_i.isFinished()) + { + Node n = *eqc_i; + ++eqc_i; + Node lam = d_ll.getLambdaFor(n); + if (lam.isNull()) + { + if (!lamRep.isNull()) + { + // if we are equal to a lambda function, we must beta-reduce + // applications of this + normalEqFuns.insert(n); + } + else + { + // waiting to see if there is a lambda function in this equivalence + // class + normalEqFunWait.insert(n); + } + } + else if (lamRep.isNull()) + { + // there is a lambda function in this equivalence class + lamRep = n; + lamRepLam = lam; + // must consider all normal functions we've seen so far + normalEqFuns.insert(normalEqFunWait.begin(), normalEqFunWait.end()); + normalEqFunWait.clear(); + } + else + { + // two lambda functions are in same equivalence class + Node f = lamRep < n ? lamRep : n; + Node g = lamRep < n ? n : lamRep; + Trace("uf-ho-debug") << " found equivalent lambda functions " << f + << " and " << g << std::endl; + Node flam = lamRep < n ? lamRepLam : lam; + Assert(!flam.isNull() && flam.getKind() == LAMBDA); + Node lhs = flam[1]; + Node glam = lamRep < n ? lam : lamRepLam; + Trace("uf-ho-debug") + << " lambda are " << flam << " and " << glam << std::endl; + std::vector<Node> args(flam[0].begin(), flam[0].end()); + Node rhs = d_ll.betaReduce(glam, args); + Node univ = nm->mkNode(FORALL, flam[0], lhs.eqNode(rhs)); + // f = g => forall x. reduce(lambda(f)(x)) = reduce(lambda(g)(x)) + // + // For example, if f -> lambda z. z+1, g -> lambda y. y+3, this + // will infer: f = g => forall x. x+1 = x+3, which simplifies to + // f != g. + Node lem = nm->mkNode(IMPLIES, f.eqNode(g), univ); + if (cacheLemma(lem)) + { + d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_UNIV_EQ); + numLemmas++; + } + } + } + if (!lamRep.isNull()) + { + d_lambdaEqc[eqc] = lamRep; + } + } + Trace("uf-ho-debug") + << " found " << normalEqFuns.size() + << " ordinary functions that are equal to lambda functions" << std::endl; + if (normalEqFuns.empty()) + { + return numLemmas; + } + // if we have normal functions that are equal to lambda functions, go back + // and ensure they are mapped properly + // mapping from functions to terms + eq::EqClassesIterator eqcs_i2 = eq::EqClassesIterator(ee); + while (!eqcs_i2.isFinished()) + { + Node eqc = (*eqcs_i2); + ++eqcs_i2; + Trace("uf-ho-debug") << "Check equivalence class " << eqc << std::endl; + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee); + while (!eqc_i.isFinished()) + { + Node n = *eqc_i; + ++eqc_i; + Trace("uf-ho-debug") << "Check term " << n << std::endl; + Node op; + Kind k = n.getKind(); + std::vector<Node> args; + if (k == APPLY_UF) + { + op = n.getOperator(); + args.insert(args.end(), n.begin(), n.end()); + } + else if (k == HO_APPLY) + { + op = n[0]; + args.push_back(n[1]); + } + else + { + continue; + } + if (normalEqFuns.find(op) == normalEqFuns.end()) + { + continue; + } + Trace("uf-ho-debug") << " found relevant ordinary application " << n + << std::endl; + Assert(ee->hasTerm(op)); + Node r = ee->getRepresentative(op); + Assert(d_lambdaEqc.find(r) != d_lambdaEqc.end()); + Node lf = d_lambdaEqc[r]; + Node lam = d_ll.getLambdaFor(lf); + Assert(!lam.isNull() && lam.getKind() == LAMBDA); + // a normal function g equal to a lambda, say f --> lambda(f) + // need to infer f = g => g(t) = f(t) for all terms g(t) + // that occur in the equality engine. + Node premise = op.eqNode(lf); + args.insert(args.begin(), lam); + Node rhs = nm->mkNode(n.getKind(), args); + rhs = rewrite(rhs); + Node conc = n.eqNode(rhs); + Node lem = nm->mkNode(IMPLIES, premise, conc); + if (cacheLemma(lem)) + { + d_im.lemma(lem, InferenceId::UF_HO_LAMBDA_APP_REDUCE); + numLemmas++; + } + } + } + return numLemmas; +} + unsigned HoExtension::check() { Trace("uf-ho") << "HoExtension::checkHigherOrder..." << std::endl; @@ -429,14 +644,24 @@ unsigned HoExtension::check() } } while (num_facts > 0); - unsigned num_lemmas = 0; - - num_lemmas = checkExtensionality(); - if (num_lemmas > 0) + // Apply extensionality, lazy lambda schemas in order. We make lazy lambda + // handling come last as it may introduce quantifiers. + for (size_t i = 0; i < 2; i++) { - Trace("uf-ho") << "...extensionality returned " << num_lemmas << " lemmas." - << std::endl; - return num_lemmas; + unsigned num_lemmas = 0; + // apply the schema + switch (i) + { + case 0: num_lemmas = checkExtensionality(); break; + case 1: num_lemmas = checkLazyLambda(); break; + default: break; + } + // finish if we added lemmas + if (num_lemmas > 0) + { + Trace("uf-ho") << "...returned " << num_lemmas << " lemmas." << std::endl; + return num_lemmas; + } } Trace("uf-ho") << "...finished check higher order." << std::endl; @@ -464,6 +689,16 @@ bool HoExtension::collectModelInfoHo(TheoryModel* m, // non-standard alternative to using a type enumerator over function // values to assign unique values. int addedLemmas = checkExtensionality(m); + // for equivalence classes that we know to assign a lambda directly + for (const std::pair<const Node, Node>& p : d_lambdaEqc) + { + Node lam = d_ll.getLambdaFor(p.second); + Assert(!lam.isNull()); + m->assertEquality(p.second, lam, true); + m->assertSkeleton(lam); + // assign it as the function definition for all variables in this class + m->assignFunctionDefinition(p.second, lam); + } return addedLemmas == 0; } @@ -484,6 +719,17 @@ bool HoExtension::collectModelInfoHoTerm(Node n, TheoryModel* m) return true; } +bool HoExtension::cacheLemma(TNode lem) +{ + Node rewritten = rewrite(lem); + if (d_cachedLemmas.find(rewritten) != d_cachedLemmas.end()) + { + return false; + } + d_cachedLemmas.insert(rewritten); + return true; +} + } // namespace uf } // namespace theory } // namespace cvc5 |