diff options
author | Haniel Barbosa <hanielbbarbosa@gmail.com> | 2021-01-22 13:15:43 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-22 13:15:43 -0300 |
commit | 109e7e43efdeb557ff17880da83da438db35eb3e (patch) | |
tree | 313d66098519bd3ac782a3c21a830fa41d79214a /src/smt/proof_post_processor.cpp | |
parent | 98d2ca3ee48cb87e8baa7537c97016cc85ab048d (diff) |
[proof-new] Expanding MACRO_RESOLUTION in post-processing (#5755)
Breaks down resolution, factoring and reordering. The hardest part of this process is making getting rid of the so-called "crowding literals", i.e., duplicate literals introduced during the series of resolutions and removed implicitly by the SAT solver. A naive removal via addition of premises to the chain resolution can lead to exponential behavior, so instead the removal is done by breaking the resolution and applying a factoring step midway through. This guarantees non-exponential behavior.
Diffstat (limited to 'src/smt/proof_post_processor.cpp')
-rw-r--r-- | src/smt/proof_post_processor.cpp | 345 |
1 files changed, 345 insertions, 0 deletions
diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 049eb02c0..a620a4d22 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -130,6 +130,246 @@ bool ProofPostprocessCallback::updateInternal(Node res, return update(res, id, children, args, cdp, continueUpdate); } +Node ProofPostprocessCallback::eliminateCrowdingLits( + const std::vector<Node>& clauseLits, + const std::vector<Node>& targetClauseLits, + const std::vector<Node>& children, + const std::vector<Node>& args, + CDProof* cdp) +{ + NodeManager* nm = NodeManager::currentNM(); + Node trueNode = nm->mkConst(true); + // get crowding lits and the position of the last clause that includes + // them. The factoring step must be added after the last inclusion and before + // its elimination. + std::unordered_set<TNode, TNodeHashFunction> crowding; + std::vector<std::pair<Node, size_t>> lastInclusion; + // positions of eliminators of crowding literals, which are the positions of + // the clauses that eliminate crowding literals *after* their last inclusion + std::vector<size_t> eliminators; + for (size_t i = 0, size = clauseLits.size(); i < size; ++i) + { + if (!crowding.count(clauseLits[i]) + && std::find( + targetClauseLits.begin(), targetClauseLits.end(), clauseLits[i]) + == targetClauseLits.end()) + { + Node crowdLit = clauseLits[i]; + crowding.insert(crowdLit); + // found crowding lit, now get its last inclusion position, which is the + // position of the last resolution link that introduces the crowding + // literal. Note that this position has to be *before* the last link, as a + // link *after* the last inclusion must eliminate the crowding literal. + size_t j; + for (j = children.size() - 1; j > 0; --j) + { + // notice that only non-unit clauses may be introducing the crowding + // literal, so we don't need to differentiate unit from non-unit + if (children[j - 1].getKind() != kind::OR) + { + continue; + } + if (std::find(children[j - 1].begin(), children[j - 1].end(), crowdLit) + != children[j - 1].end()) + { + break; + } + } + Assert(j > 0); + lastInclusion.emplace_back(crowdLit, j - 1); + Trace("smt-proof-pp-debug2") << "crowding lit " << crowdLit << "\n"; + Trace("smt-proof-pp-debug2") << "last inc " << j - 1 << "\n"; + // get elimination position, starting from the following link as the last + // inclusion one. The result is the last (in the chain, but first from + // this point on) resolution link that eliminates the crowding literal. A + // literal l is eliminated by a link if it contains a literal l' with + // opposite polarity to l. + for (; j < children.size(); ++j) + { + bool posFirst = args[(2 * j) - 1] == trueNode; + Node pivot = args[(2 * j)]; + Trace("smt-proof-pp-debug2") + << "\tcheck w/ args " << posFirst << " / " << pivot << "\n"; + // To eliminate the crowding literal (crowdLit), the clause must contain + // it with opposite polarity. There are three successful cases, + // according to the pivot and its sign + // + // - crowdLit is the same as the pivot and posFirst is true, which means + // that the clause contains its negation and eliminates it + // + // - crowdLit is the negation of the pivot and posFirst is false, so the + // clause contains the node whose negation is crowdLit. Note that this + // case may either be crowdLit.notNode() == pivot or crowdLit == + // pivot.notNode(). + if ((crowdLit == pivot && posFirst) + || (crowdLit.notNode() == pivot && !posFirst) + || (pivot.notNode() == crowdLit && !posFirst)) + { + Trace("smt-proof-pp-debug2") << "\t\tfound it!\n"; + eliminators.push_back(j); + break; + } + } + Assert(j < children.size()); + } + } + Assert(!lastInclusion.empty()); + // order map so that we process crowding literals in the order of the clauses + // that last introduce them + auto cmp = [](std::pair<Node, size_t>& a, std::pair<Node, size_t>& b) { + return a.second < b.second; + }; + std::sort(lastInclusion.begin(), lastInclusion.end(), cmp); + // order eliminators + std::sort(eliminators.begin(), eliminators.end()); + if (Trace.isOn("smt-proof-pp-debug")) + { + Trace("smt-proof-pp-debug") << "crowding lits last inclusion:\n"; + for (const auto& pair : lastInclusion) + { + Trace("smt-proof-pp-debug") + << "\t- [" << pair.second << "] : " << pair.first << "\n"; + } + Trace("smt-proof-pp-debug") << "eliminators:"; + for (size_t elim : eliminators) + { + Trace("smt-proof-pp-debug") << " " << elim; + } + Trace("smt-proof-pp-debug") << "\n"; + } + // TODO (cvc4-wishues/issues/77): implement also simpler version and compare + // + // We now start to break the chain, one step at a time. Naively this breaking + // down would be one resolution/factoring to each crowding literal, but we can + // merge some of the cases. Effectively we do the following: + // + // + // lastClause children[start] ... children[end] + // ---------------------------------------------- CHAIN_RES + // C + // ----------- FACTORING + // lastClause' children[start'] ... children[end'] + // -------------------------------------------------------------- CHAIN_RES + // ... + // + // where + // lastClause_0 = children[0] + // start_0 = 1 + // end_0 = eliminators[0] - 1 + // start_i+1 = nextGuardedElimPos - 1 + // + // The important point is how end_i+1 is computed. It is based on what we call + // the "nextGuardedElimPos", i.e., the next elimination position that requires + // removal of duplicates. The intuition is that a factoring step may eliminate + // the duplicates of crowding literals l1 and l2. If the last inclusion of l2 + // is before the elimination of l1, then we can go ahead and also perform the + // elimination of l2 without another factoring. However if another literal l3 + // has its last inclusion after the elimination of l2, then the elimination of + // l3 is the next guarded elimination. + // + // To do the above computation then we determine, after a resolution/factoring + // step, the first crowded literal to have its last inclusion after "end". The + // first elimination position to be bigger than the position of that crowded + // literal is the next guarded elimination position. + size_t lastElim = 0; + Node lastClause = children[0]; + std::vector<Node> childrenRes; + std::vector<Node> childrenResArgs; + Node resPlaceHolder; + size_t nextGuardedElimPos = eliminators[0]; + do + { + size_t start = lastElim + 1; + size_t end = nextGuardedElimPos - 1; + Trace("smt-proof-pp-debug2") + << "res with:\n\tlastClause: " << lastClause << "\n\tstart: " << start + << "\n\tend: " << end << "\n"; + childrenRes.push_back(lastClause); + // note that the interval of insert is exclusive in the end, so we add 1 + childrenRes.insert(childrenRes.end(), + children.begin() + start, + children.begin() + end + 1); + childrenResArgs.insert(childrenResArgs.end(), + args.begin() + (2 * start) - 1, + args.begin() + (2 * end) + 1); + Trace("smt-proof-pp-debug2") << "res children: " << childrenRes << "\n"; + Trace("smt-proof-pp-debug2") << "res args: " << childrenResArgs << "\n"; + resPlaceHolder = d_pnm->getChecker()->checkDebug(PfRule::CHAIN_RESOLUTION, + childrenRes, + childrenResArgs, + Node::null(), + ""); + Trace("smt-proof-pp-debug2") + << "resPlaceHorder: " << resPlaceHolder << "\n"; + cdp->addStep( + resPlaceHolder, PfRule::CHAIN_RESOLUTION, childrenRes, childrenResArgs); + // I need to add factoring if end < children.size(). Otherwise, this is + // to be handled by the caller + if (end < children.size() - 1) + { + lastClause = d_pnm->getChecker()->checkDebug( + PfRule::FACTORING, {resPlaceHolder}, {}, Node::null(), ""); + if (!lastClause.isNull()) + { + cdp->addStep(lastClause, PfRule::FACTORING, {resPlaceHolder}, {}); + } + else + { + lastClause = resPlaceHolder; + } + Trace("smt-proof-pp-debug2") << "lastClause: " << lastClause << "\n"; + } + else + { + lastClause = resPlaceHolder; + break; + } + // update for next round + childrenRes.clear(); + childrenResArgs.clear(); + lastElim = end; + + // find the position of the last inclusion of the next crowded literal + size_t nextCrowdedInclusionPos = lastInclusion.size(); + for (size_t i = 0, size = lastInclusion.size(); i < size; ++i) + { + if (lastInclusion[i].second > lastElim) + { + nextCrowdedInclusionPos = i; + break; + } + } + Trace("smt-proof-pp-debug2") + << "nextCrowdedInclusion/Pos: " + << lastInclusion[nextCrowdedInclusionPos].second << "/" + << nextCrowdedInclusionPos << "\n"; + // if there is none, then the remaining literals will be used in the next + // round + if (nextCrowdedInclusionPos == lastInclusion.size()) + { + nextGuardedElimPos = children.size(); + } + else + { + nextGuardedElimPos = children.size(); + for (size_t i = 0, size = eliminators.size(); i < size; ++i) + { + // nextGuardedElimPos is the largest element of + // eliminators bigger the next crowded literal's last inclusion + if (eliminators[i] > lastInclusion[nextCrowdedInclusionPos].second) + { + nextGuardedElimPos = eliminators[i]; + break; + } + } + Assert(nextGuardedElimPos < children.size()); + } + Trace("smt-proof-pp-debug2") + << "nextGuardedElimPos: " << nextGuardedElimPos << "\n"; + } while (true); + return lastClause; +} + Node ProofPostprocessCallback::expandMacros(PfRule id, const std::vector<Node>& children, const std::vector<Node>& args, @@ -375,6 +615,111 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, cdp->addStep(args[0], PfRule::EQ_RESOLVE, {children[0], eq}, {}); return args[0]; } + else if (id == PfRule::MACRO_RESOLUTION) + { + // first generate the naive chain_resolution + std::vector<Node> chainResArgs{args.begin() + 1, args.end()}; + Node chainConclusion = d_pnm->getChecker()->checkDebug( + PfRule::CHAIN_RESOLUTION, children, chainResArgs, Node::null(), ""); + Trace("smt-proof-pp-debug") << "Original conclusion: " << args[0] << "\n"; + Trace("smt-proof-pp-debug") + << "chainRes conclusion: " << chainConclusion << "\n"; + // There are n cases: + // - if the conclusion is the same, just replace + // - if they have the same literals but in different quantity, add a + // FACTORING step + // - if the order is not the same, add a REORDERING step + // - if there are literals in chainConclusion that are not in the original + // conclusion, we need to transform the MACRO_RESOLUTION into a series of + // CHAIN_RESOLUTION + FACTORING steps, so that we explicitly eliminate all + // these "crowding" literals. We do this via FACTORING so we avoid adding + // an exponential number of premises, which would happen if we just + // repeated in the premises the clauses needed for eliminating crowding + // literals, which could themselves add crowding literals. + if (chainConclusion == args[0]) + { + cdp->addStep( + chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs); + return chainConclusion; + } + NodeManager* nm = NodeManager::currentNM(); + // If we got here, then chainConclusion is NECESSARILY an OR node + Assert(chainConclusion.getKind() == kind::OR); + // get the literals in the chain conclusion + std::vector<Node> chainConclusionLits{chainConclusion.begin(), + chainConclusion.end()}; + std::set<Node> chainConclusionLitsSet{chainConclusion.begin(), + chainConclusion.end()}; + // is args[0] a unit clause? If it's not an OR node, then yes. Otherwise, + // it's only a unit if it occurs in chainConclusionLitsSet + std::vector<Node> conclusionLits; + // whether conclusion is unit + if (chainConclusionLitsSet.count(args[0])) + { + conclusionLits.push_back(args[0]); + } + else + { + Assert(args[0].getKind() == kind::OR); + conclusionLits.insert( + conclusionLits.end(), args[0].begin(), args[0].end()); + } + std::set<Node> conclusionLitsSet{conclusionLits.begin(), + conclusionLits.end()}; + // If the sets are different, there are "crowding" literals, i.e. literals + // that were removed by implicit multi-usage of premises in the resolution + // chain. + if (chainConclusionLitsSet != conclusionLitsSet) + { + chainConclusion = eliminateCrowdingLits( + chainConclusionLits, conclusionLits, children, args, cdp); + // update vector of lits. Note that the set is no longer used, so we don't + // need to update it + chainConclusionLits.clear(); + chainConclusionLits.insert(chainConclusionLits.end(), + chainConclusion.begin(), + chainConclusion.end()); + } + else + { + cdp->addStep( + chainConclusion, PfRule::CHAIN_RESOLUTION, children, chainResArgs); + } + // Placeholder for running conclusion + Node n = chainConclusion; + // factoring + if (chainConclusionLits.size() != conclusionLits.size()) + { + // We build it rather than taking conclusionLits because the order may be + // different + std::vector<Node> factoredLits; + std::unordered_set<TNode, TNodeHashFunction> clauseSet; + for (size_t i = 0, size = chainConclusionLits.size(); i < size; ++i) + { + if (clauseSet.count(chainConclusionLits[i])) + { + continue; + } + factoredLits.push_back(n[i]); + clauseSet.insert(n[i]); + } + Node factored = factoredLits.empty() + ? nm->mkConst(false) + : factoredLits.size() == 1 + ? factoredLits[0] + : nm->mkNode(kind::OR, factoredLits); + cdp->addStep(factored, PfRule::FACTORING, {n}, {}); + n = factored; + } + // either same node or n as a clause + Assert(n == args[0] || n.getKind() == kind::OR); + // reordering + if (n != args[0]) + { + cdp->addStep(args[0], PfRule::REORDERING, {n}, {args[0]}); + } + return args[0]; + } else if (id == PfRule::SUBS) { NodeManager* nm = NodeManager::currentNM(); |