diff options
Diffstat (limited to 'src/theory/arith/congruence_manager.cpp')
-rw-r--r-- | src/theory/arith/congruence_manager.cpp | 374 |
1 files changed, 304 insertions, 70 deletions
diff --git a/src/theory/arith/congruence_manager.cpp b/src/theory/arith/congruence_manager.cpp index 858098b70..57214e0f8 100644 --- a/src/theory/arith/congruence_manager.cpp +++ b/src/theory/arith/congruence_manager.cpp @@ -2,10 +2,10 @@ /*! \file congruence_manager.cpp ** \verbatim ** Top contributors (to current version): - ** Tim King, Dejan Jovanovic, Paul Meng + ** Tim King, Andrew Reynolds, Dejan Jovanovic ** This file is part of the CVC4 project. ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS - ** in the top-level source directory) and their institutional affiliations. + ** in the top-level source directory and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** @@ -29,10 +29,12 @@ namespace arith { ArithCongruenceManager::ArithCongruenceManager( context::Context* c, + context::UserContext* u, ConstraintDatabase& cd, SetupLiteralCallBack setup, const ArithVariables& avars, - RaiseEqualityEngineConflict raiseConflict) + RaiseEqualityEngineConflict raiseConflict, + ProofNodeManager* pnm) : d_inConflict(c), d_raiseConflict(raiseConflict), d_notify(*this), @@ -42,16 +44,44 @@ ArithCongruenceManager::ArithCongruenceManager( d_constraintDatabase(cd), d_setupLiteral(setup), d_avariables(avars), - d_ee(d_notify, c, "theory::arith::ArithCongruenceManager", true) + d_ee(nullptr), + d_satContext(c), + d_userContext(u), + d_pnm(pnm), + // Construct d_pfGenEe with the SAT context, since its proof include + // unclosed assumptions of theory literals. + d_pfGenEe( + new EagerProofGenerator(pnm, c, "ArithCongruenceManager::pfGenEe")), + // Construct d_pfGenEe with the USER context, since its proofs are closed. + d_pfGenExplain(new EagerProofGenerator( + pnm, u, "ArithCongruenceManager::pfGenExplain")), + d_pfee(nullptr) { - d_ee.addFunctionKind(kind::NONLINEAR_MULT); - d_ee.addFunctionKind(kind::EXPONENTIAL); - d_ee.addFunctionKind(kind::SINE); - d_ee.addFunctionKind(kind::IAND); } ArithCongruenceManager::~ArithCongruenceManager() {} +bool ArithCongruenceManager::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::arith::ArithCongruenceManager"; + return true; +} + +void ArithCongruenceManager::finishInit(eq::EqualityEngine* ee, + eq::ProofEqEngine* pfee) +{ + Assert(ee != nullptr); + d_ee = ee; + d_ee->addFunctionKind(kind::NONLINEAR_MULT); + d_ee->addFunctionKind(kind::EXPONENTIAL); + d_ee->addFunctionKind(kind::SINE); + d_ee->addFunctionKind(kind::IAND); + // have proof equality engine only if proofs are enabled + Assert(isProofEnabled() == (pfee != nullptr)); + d_pfee = pfee; +} + ArithCongruenceManager::Statistics::Statistics(): d_watchedVariables("theory::arith::congruence::watchedVariables", 0), d_watchedVariableIsZero("theory::arith::congruence::watchedVariableIsZero", 0), @@ -84,16 +114,17 @@ ArithCongruenceManager::ArithCongruenceNotify::ArithCongruenceNotify(ArithCongru : d_acm(acm) {} -bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerEquality(TNode equality, bool value) { - Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; +bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerPredicate( + TNode predicate, bool value) +{ + Assert(predicate.getKind() == kind::EQUAL); + Debug("arith::congruences") + << "ArithCongruenceNotify::eqNotifyTriggerPredicate(" << predicate << ", " + << (value ? "true" : "false") << ")" << std::endl; if (value) { - return d_acm.propagate(equality); - } else { - return d_acm.propagate(equality.notNode()); + return d_acm.propagate(predicate); } -} -bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerPredicate(TNode predicate, bool value) { - Unreachable(); + return d_acm.propagate(predicate.notNode()); } bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { @@ -110,18 +141,20 @@ void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyConstantTermMerge(TN } void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyNewClass(TNode t) { } -void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyPreMerge(TNode t1, TNode t2) { -} -void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyPostMerge(TNode t1, TNode t2) { +void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyMerge(TNode t1, + TNode t2) +{ } void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) { } -void ArithCongruenceManager::raiseConflict(Node conflict){ +void ArithCongruenceManager::raiseConflict(Node conflict, + std::shared_ptr<ProofNode> pf) +{ Assert(!inConflict()); Debug("arith::conflict") << "difference manager conflict " << conflict << std::endl; d_inConflict.raise(); - d_raiseConflict.raiseEEConflict(conflict); + d_raiseConflict.raiseEEConflict(conflict, pf); } bool ArithCongruenceManager::inConflict() const{ return d_inConflict.isRaised(); @@ -141,10 +174,6 @@ bool ArithCongruenceManager::canExplain(TNode n) const { return d_explanationMap.find(n) != d_explanationMap.end(); } -void ArithCongruenceManager::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_ee.setMasterEqualityEngine(eq); -} - Node ArithCongruenceManager::externalToInternal(TNode n) const{ Assert(canExplain(n)); ExplainMap::const_iterator iter = d_explanationMap.find(n); @@ -184,13 +213,32 @@ void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP lb, ConstraintCP ++(d_statistics.d_watchedVariableIsZero); ArithVar s = lb->getVariable(); - Node reason = Constraint::externalExplainByAssertions(lb,ub); + TNode eq = d_watchedEqualities[s]; + ConstraintCP eqC = d_constraintDatabase.getConstraint( + s, ConstraintType::Equality, lb->getValue()); + NodeBuilder<> reasonBuilder(Kind::AND); + auto pfLb = lb->externalExplainByAssertions(reasonBuilder); + auto pfUb = ub->externalExplainByAssertions(reasonBuilder); + Node reason = safeConstructNary(reasonBuilder); + std::shared_ptr<ProofNode> pf{}; + if (isProofEnabled()) + { + pf = d_pnm->mkNode( + PfRule::ARITH_TRICHOTOMY, {pfLb, pfUb}, {eqC->getProofLiteral()}); + pf = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM, {pf}, {eq}); + } d_keepAlive.push_back(reason); - assertionToEqualityEngine(true, s, reason); + Trace("arith-ee") << "Asserting an equality on " << s << ", on trichotomy" + << std::endl; + Trace("arith-ee") << " based on " << lb << std::endl; + Trace("arith-ee") << " based on " << ub << std::endl; + assertionToEqualityEngine(true, s, reason, pf); } void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP eq){ + Debug("arith::cong") << "Cong::watchedVariableIsZero: " << *eq << std::endl; + Assert(eq->isEquality()); Assert(eq->getValue().sgn() == 0); @@ -201,23 +249,86 @@ void ArithCongruenceManager::watchedVariableIsZero(ConstraintCP eq){ //Explain for conflict is correct as these proofs are generated //and stored eagerly //These will be safe for propagation later as well - Node reason = eq->externalExplainByAssertions(); + NodeBuilder<> nb(Kind::AND); + // An open proof of eq from literals now in reason. + if (Debug.isOn("arith::cong")) + { + eq->printProofTree(Debug("arith::cong")); + } + auto pf = eq->externalExplainByAssertions(nb); + if (isProofEnabled()) + { + pf = d_pnm->mkNode( + PfRule::MACRO_SR_PRED_TRANSFORM, {pf}, {d_watchedEqualities[s]}); + } + Node reason = safeConstructNary(nb); d_keepAlive.push_back(reason); - assertionToEqualityEngine(true, s, reason); + assertionToEqualityEngine(true, s, reason, pf); } void ArithCongruenceManager::watchedVariableCannotBeZero(ConstraintCP c){ + Debug("arith::cong::notzero") + << "Cong::watchedVariableCannotBeZero " << *c << std::endl; ++(d_statistics.d_watchedVariableIsNotZero); ArithVar s = c->getVariable(); + Node disEq = d_watchedEqualities[s].negate(); //Explain for conflict is correct as these proofs are generated and stored eagerly //These will be safe for propagation later as well - Node reason = c->externalExplainByAssertions(); - + NodeBuilder<> nb(Kind::AND); + // An open proof of eq from literals now in reason. + auto pf = c->externalExplainByAssertions(nb); + if (Debug.isOn("arith::cong::notzero")) + { + Debug("arith::cong::notzero") << " original proof "; + pf->printDebug(Debug("arith::cong::notzero")); + Debug("arith::cong::notzero") << std::endl; + } + Node reason = safeConstructNary(nb); + if (isProofEnabled()) + { + if (c->getType() == ConstraintType::Disequality) + { + Assert(c->getLiteral() == d_watchedEqualities[s].negate()); + // We have to prove equivalence to the watched disequality. + pf = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM, {pf}, {disEq}); + } + else + { + Debug("arith::cong::notzero") + << " proof modification needed" << std::endl; + + // Four cases: + // c has form x_i = d, d > 0 => multiply c by -1 in Farkas proof + // c has form x_i = d, d > 0 => multiply c by 1 in Farkas proof + // c has form x_i <= d, d < 0 => multiply c by 1 in Farkas proof + // c has form x_i >= d, d > 0 => multiply c by -1 in Farkas proof + const bool scaleCNegatively = c->getType() == ConstraintType::LowerBound + || (c->getType() == ConstraintType::Equality + && c->getValue().sgn() > 0); + const int cSign = scaleCNegatively ? -1 : 1; + TNode isZero = d_watchedEqualities[s]; + const auto isZeroPf = d_pnm->mkAssume(isZero); + const auto nm = NodeManager::currentNM(); + const auto sumPf = d_pnm->mkNode( + PfRule::ARITH_SCALE_SUM_UPPER_BOUNDS, + {isZeroPf, pf}, + // Trick for getting correct, opposing signs. + {nm->mkConst(Rational(-1 * cSign)), nm->mkConst(Rational(cSign))}); + const auto botPf = d_pnm->mkNode( + PfRule::MACRO_SR_PRED_TRANSFORM, {sumPf}, {nm->mkConst(false)}); + std::vector<Node> assumption = {isZero}; + pf = d_pnm->mkScope(botPf, assumption, false); + Debug("arith::cong::notzero") << " new proof "; + pf->printDebug(Debug("arith::cong::notzero")); + Debug("arith::cong::notzero") << std::endl; + } + Assert(pf->getResult() == disEq); + } d_keepAlive.push_back(reason); - assertionToEqualityEngine(false, s, reason); + assertionToEqualityEngine(false, s, reason, pf); } @@ -236,11 +347,22 @@ bool ArithCongruenceManager::propagate(TNode x){ if(rewritten.getConst<bool>()){ return true; }else{ + // x rewrites to false. ++(d_statistics.d_conflicts); - - Node conf = flattenAnd(explainInternal(x)); - raiseConflict(conf); + TrustNode trn = explainInternal(x); + Node conf = flattenAnd(trn.getNode()); Debug("arith::congruenceManager") << "rewritten to false "<<x<<" with explanation "<< conf << std::endl; + if (isProofEnabled()) + { + auto pf = trn.getGenerator()->getProofFor(trn.getProven()); + auto confPf = d_pnm->mkNode( + PfRule::MACRO_SR_PRED_TRANSFORM, {pf}, {conf.negate()}); + raiseConflict(conf, confPf); + } + else + { + raiseConflict(conf); + } return false; } } @@ -262,9 +384,10 @@ bool ArithCongruenceManager::propagate(TNode x){ << c->negationHasProof() << std::endl; if(c->negationHasProof()){ - Node expC = explainInternal(x); + TrustNode texpC = explainInternal(x); + Node expC = texpC.getNode(); ConstraintCP negC = c->getNegation(); - Node neg = negC->externalExplainByAssertions(); + Node neg = Constraint::externalExplainByAssertions({negC}); Node conf = expC.andNode(neg); Node final = flattenAnd(conf); @@ -320,9 +443,9 @@ bool ArithCongruenceManager::propagate(TNode x){ void ArithCongruenceManager::explain(TNode literal, std::vector<TNode>& assumptions) { if (literal.getKind() != kind::NOT) { - d_ee.explainEquality(literal[0], literal[1], true, assumptions); + d_ee->explainEquality(literal[0], literal[1], true, assumptions); } else { - d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions); + d_ee->explainEquality(literal[0][0], literal[0][1], false, assumptions); } } @@ -334,28 +457,44 @@ void ArithCongruenceManager::enqueueIntoNB(const std::set<TNode> s, NodeBuilder< } } -Node ArithCongruenceManager::explainInternal(TNode internal){ - std::vector<TNode> assumptions; - explain(internal, assumptions); - - std::set<TNode> assumptionSet; - assumptionSet.insert(assumptions.begin(), assumptions.end()); - - if (assumptionSet.size() == 1) { - // All the same, or just one - return assumptions[0]; - }else{ - NodeBuilder<> conjunction(kind::AND); - enqueueIntoNB(assumptionSet, conjunction); - return conjunction; +TrustNode ArithCongruenceManager::explainInternal(TNode internal) +{ + if (isProofEnabled()) + { + return d_pfee->explain(internal); } + // otherwise, explain without proof generator + Node exp = d_ee->mkExplainLit(internal); + return TrustNode::mkTrustPropExp(internal, exp, nullptr); } -Node ArithCongruenceManager::explain(TNode external){ +TrustNode ArithCongruenceManager::explain(TNode external) +{ Trace("arith-ee") << "Ask for explanation of " << external << std::endl; Node internal = externalToInternal(external); Trace("arith-ee") << "...internal = " << internal << std::endl; - return explainInternal(internal); + TrustNode trn = explainInternal(internal); + if (isProofEnabled() && trn.getProven()[1] != external) + { + Assert(trn.getKind() == TrustNodeKind::PROP_EXP); + Assert(trn.getProven().getKind() == Kind::IMPLIES); + Assert(trn.getGenerator() != nullptr); + Trace("arith-ee") << "tweaking proof to prove " << external << " not " + << trn.getProven()[1] << std::endl; + std::vector<std::shared_ptr<ProofNode>> assumptionPfs; + std::vector<Node> assumptions = andComponents(trn.getNode()); + assumptionPfs.push_back(trn.toProofNode()); + for (const auto& a : assumptions) + { + assumptionPfs.push_back( + d_pnm->mkNode(PfRule::TRUE_INTRO, {d_pnm->mkAssume(a)}, {})); + } + auto litPf = d_pnm->mkNode( + PfRule::MACRO_SR_PRED_TRANSFORM, {assumptionPfs}, {external}); + auto extPf = d_pnm->mkScope(litPf, assumptions); + return d_pfGenExplain->mkTrustedPropagation(external, trn.getNode(), extPf); + } + return trn; } void ArithCongruenceManager::explain(TNode external, NodeBuilder<>& out){ @@ -384,18 +523,86 @@ void ArithCongruenceManager::addWatchedPair(ArithVar s, TNode x, TNode y){ d_watchedEqualities.set(s, eq); } -void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar s, TNode reason){ +void ArithCongruenceManager::assertLitToEqualityEngine( + Node lit, TNode reason, std::shared_ptr<ProofNode> pf) +{ + bool isEquality = lit.getKind() != Kind::NOT; + Node eq = isEquality ? lit : lit[0]; + Assert(eq.getKind() == Kind::EQUAL); + + Trace("arith-ee") << "Assert to Eq " << lit << ", reason " << reason + << std::endl; + if (isProofEnabled()) + { + if (CDProof::isSame(lit, reason)) + { + Trace("arith-pfee") << "Asserting only, b/c implied by symm" << std::endl; + // The equality engine doesn't ref-count for us... + d_keepAlive.push_back(eq); + d_keepAlive.push_back(reason); + d_ee->assertEquality(eq, isEquality, reason); + } + else if (hasProofFor(lit)) + { + Trace("arith-pfee") << "Skipping b/c already done" << std::endl; + } + else + { + setProofFor(lit, pf); + Trace("arith-pfee") << "Actually asserting" << std::endl; + if (Debug.isOn("arith-pfee")) + { + Trace("arith-pfee") << "Proof: "; + pf->printDebug(Trace("arith-pfee")); + Trace("arith-pfee") << std::endl; + } + // The proof equality engine *does* ref-count for us... + d_pfee->assertFact(lit, reason, d_pfGenEe.get()); + } + } + else + { + // The equality engine doesn't ref-count for us... + d_keepAlive.push_back(eq); + d_keepAlive.push_back(reason); + d_ee->assertEquality(eq, isEquality, reason); + } +} + +void ArithCongruenceManager::assertionToEqualityEngine( + bool isEquality, ArithVar s, TNode reason, std::shared_ptr<ProofNode> pf) +{ Assert(isWatchedVariable(s)); TNode eq = d_watchedEqualities[s]; Assert(eq.getKind() == kind::EQUAL); - Trace("arith-ee") << "Assert " << eq << ", pol " << isEquality << ", reason " << reason << std::endl; - if(isEquality){ - d_ee.assertEquality(eq, true, reason); - }else{ - d_ee.assertEquality(eq, false, reason); + Node lit = isEquality ? Node(eq) : eq.notNode(); + Trace("arith-ee") << "Assert to Eq " << eq << ", pol " << isEquality + << ", reason " << reason << std::endl; + assertLitToEqualityEngine(lit, reason, pf); +} + +bool ArithCongruenceManager::hasProofFor(TNode f) const +{ + Assert(isProofEnabled()); + if (d_pfGenEe->hasProofFor(f)) + { + return true; } + Node sym = CDProof::getSymmFact(f); + Assert(!sym.isNull()); + return d_pfGenEe->hasProofFor(sym); +} + +void ArithCongruenceManager::setProofFor(TNode f, + std::shared_ptr<ProofNode> pf) const +{ + Assert(!hasProofFor(f)); + d_pfGenEe->mkTrustNode(f, pf); + Node symF = CDProof::getSymmFact(f); + auto symPf = d_pnm->mkNode(PfRule::SYMM, {pf}, {}); + d_pfGenEe->mkTrustNode(symF, symPf); } void ArithCongruenceManager::equalsConstant(ConstraintCP c){ @@ -408,16 +615,18 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP c){ Node xAsNode = d_avariables.asNode(x); Node asRational = mkRationalNode(c->getValue().getNoninfinitesimalPart()); - - //No guarentee this is in normal form! + // No guarentee this is in normal form! + // Note though, that it happens to be in proof normal form! Node eq = xAsNode.eqNode(asRational); d_keepAlive.push_back(eq); - Node reason = c->externalExplainByAssertions(); + NodeBuilder<> nb(Kind::AND); + auto pf = c->externalExplainByAssertions(nb); + Node reason = safeConstructNary(nb); d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + assertLitToEqualityEngine(eq, reason, pf); } void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ @@ -430,22 +639,47 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ << ub << std::endl; ArithVar x = lb->getVariable(); - Node reason = Constraint::externalExplainByAssertions(lb,ub); + NodeBuilder<> nb(Kind::AND); + auto pfLb = lb->externalExplainByAssertions(nb); + auto pfUb = ub->externalExplainByAssertions(nb); + Node reason = safeConstructNary(nb); Node xAsNode = d_avariables.asNode(x); Node asRational = mkRationalNode(lb->getValue().getNoninfinitesimalPart()); - //No guarentee this is in normal form! + // No guarentee this is in normal form! + // Note though, that it happens to be in proof normal form! Node eq = xAsNode.eqNode(asRational); + std::shared_ptr<ProofNode> pf; + if (isProofEnabled()) + { + pf = d_pnm->mkNode(PfRule::ARITH_TRICHOTOMY, {pfLb, pfUb}, {eq}); + } d_keepAlive.push_back(eq); d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant2 " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + + assertLitToEqualityEngine(eq, reason, pf); } -void ArithCongruenceManager::addSharedTerm(Node x){ - d_ee.addTriggerTerm(x, THEORY_ARITH); +bool ArithCongruenceManager::isProofEnabled() const { return d_pnm != nullptr; } + +std::vector<Node> andComponents(TNode an) +{ + auto nm = NodeManager::currentNM(); + if (an == nm->mkConst(true)) + { + return {}; + } + else if (an.getKind() != Kind::AND) + { + return {an}; + } + std::vector<Node> a{}; + a.reserve(an.getNumChildren()); + a.insert(a.end(), an.begin(), an.end()); + return a; } }/* CVC4::theory::arith namespace */ |