diff options
author | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2011-09-15 06:53:33 +0000 |
---|---|---|
committer | Dejan Jovanović <dejan.jovanovic@gmail.com> | 2011-09-15 06:53:33 +0000 |
commit | 72f552ead344b13d90832222157b970ae3dec8ff (patch) | |
tree | b02854356d5c5f98b3873f858f38b6762135bdc1 /src/theory/theory_engine.h | |
parent | 62a50760346e130345b24e8a14ad0dac0dca5d38 (diff) |
additional stuff for sharing,
Diffstat (limited to 'src/theory/theory_engine.h')
-rw-r--r-- | src/theory/theory_engine.h | 523 |
1 files changed, 240 insertions, 283 deletions
diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 815a79a5a..04f15e89f 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -29,11 +29,12 @@ #include "expr/command.h" #include "prop/prop_engine.h" #include "context/cdset.h" -#include "theory/shared_term_manager.h" #include "theory/theory.h" #include "theory/substitutions.h" #include "theory/rewriter.h" #include "theory/substitutions.h" +#include "theory/shared_terms_database.h" +#include "theory/term_registration_visitor.h" #include "theory/valuation.h" #include "util/options.h" #include "util/stats.h" @@ -46,6 +47,25 @@ namespace CVC4 { // In terms of abstraction, this is below (and provides services to) // PropEngine. +struct NodeTheoryPair { + Node node; + theory::TheoryId theory; + NodeTheoryPair(Node node, theory::TheoryId theory) + : node(node), theory(theory) {} + NodeTheoryPair() + : theory(theory::THEORY_LAST) {} + bool operator == (const NodeTheoryPair& pair) const { + return node == pair.node && theory == pair.theory; + } +}; + +struct NodeTheoryPairHashFunction { + NodeHashFunction hashFunction; + size_t operator()(const NodeTheoryPair& pair) const { + return hashFunction(pair.node)*0x9e3779b9 + pair.theory; + } +}; + /** * This is essentially an abstraction for a collection of theories. A * TheoryEngine provides services to a PropEngine, making various @@ -60,7 +80,10 @@ class TheoryEngine { /** Our context */ context::Context* d_context; - /** A table of from theory IDs to theory pointers */ + /** + * A table of from theory IDs to theory pointers. Never use this table + * directly, use theoryOf() instead. + */ theory::Theory* d_theoryTable[theory::THEORY_LAST]; /** @@ -73,9 +96,15 @@ class TheoryEngine { theory::Theory::Set d_activeTheories; /** - * Cache from proprocessing of atoms. + * The database of shared terms. */ + SharedTermsDatabase d_sharedTerms; + typedef std::hash_map<Node, Node, NodeHashFunction> NodeMap; + + /** + * Cache from proprocessing of atoms. + */ NodeMap d_atomPreprocessingCache; /** @@ -92,6 +121,39 @@ class TheoryEngine { context::CDSet<TNode, TNodeHashFunction> d_hasPropagated; /** + * Statistics for a particular theory. + */ + class Statistics { + + static std::string mkName(std::string prefix, theory::TheoryId theory, std::string suffix) { + std::stringstream ss; + ss << prefix << theory << suffix; + return ss.str(); + } + + public: + + IntStat conflicts, propagations, lemmas; + + Statistics(theory::TheoryId theory): + conflicts(mkName("theory<", theory, ">::conflicts"), 0), + propagations(mkName("theory<", theory, ">::propagations"), 0), + lemmas(mkName("theory<", theory, ">::lemmas"), 0) + { + StatisticsRegistry::registerStat(&conflicts); + StatisticsRegistry::registerStat(&propagations); + StatisticsRegistry::registerStat(&lemmas); + } + + ~Statistics() { + StatisticsRegistry::unregisterStat(&conflicts); + StatisticsRegistry::unregisterStat(&propagations); + StatisticsRegistry::unregisterStat(&lemmas); + } + };/* class TheoryEngine::Statistics */ + + + /** * An output channel for Theory that passes messages * back to a TheoryEngine. */ @@ -99,113 +161,78 @@ class TheoryEngine { friend class TheoryEngine; + /** + * The theory engine we're communicating with. + */ TheoryEngine* d_engine; - context::Context* d_context; - context::CDO<bool> d_inConflict; - context::CDO<Node> d_explanationNode; /** - * Literals that are propagated by the theory. Note that these are TNodes. - * The theory can only propagate nodes that have an assigned literal in the - * sat solver and are hence referenced in the SAT solver. + * The statistics of the theory interractions. */ - std::vector<TNode> d_propagatedLiterals; + Statistics d_statistics; /** - * Check if the node is in conflict for debug purposes + * The theory owning this chanell. */ - bool isProperConflict(TNode conflictNode) { - bool value; - if (conflictNode.getKind() == kind::AND) { - for (unsigned i = 0; i < conflictNode.getNumChildren(); ++ i) { - if (!d_engine->getPropEngine()->hasValue(conflictNode[i], value)) return false; - if (!value) return false; - } - } else { - if (!d_engine->getPropEngine()->hasValue(conflictNode, value)) return false; - return value; - } - return true; - } + theory::TheoryId d_theory; public: - EngineOutputChannel(TheoryEngine* engine, context::Context* context) : + EngineOutputChannel(TheoryEngine* engine, theory::TheoryId theory) : d_engine(engine), - d_context(context), - d_inConflict(context, false), - d_explanationNode(context) { + d_statistics(theory), + d_theory(theory) + { } - void conflict(TNode conflictNode, bool safe) - throw(theory::Interrupted, AssertionException) { - Trace("theory") << "EngineOutputChannel::conflict(" << conflictNode << ")" << std::endl; - d_inConflict = true; - - if(Dump.isOn("t-conflicts")) { - Dump("t-conflicts") << CommentCommand("theory conflict: expect unsat") << std::endl - << CheckSatCommand(conflictNode.toExpr()) << std::endl; - } - Assert(d_engine->properConflict(conflictNode)); - ++(d_engine->d_statistics.d_statConflicts); - - // Construct the lemma (note that no CNF caching should happen as all the literals already exists) - Assert(isProperConflict(conflictNode)); - d_engine->newLemma(conflictNode, true, false); - - if(safe) { - throw theory::Interrupted(); - } + void conflict(TNode conflictNode) throw(AssertionException) { + Trace("theory") << "EngineOutputChannel<" << d_theory << ">::conflict(" << conflictNode << ")" << std::endl; + ++ d_statistics.conflicts; + d_engine->conflict(conflictNode); } - void propagate(TNode lit, bool) - throw(theory::Interrupted, AssertionException) { - Trace("theory") << "EngineOutputChannel::propagate(" - << lit << ")" << std::endl; - if(Dump.isOn("t-propagations")) { - Dump("t-propagations") - << CommentCommand("negation of theory propagation: expect valid") << std::endl - << QueryCommand(lit.toExpr()) << std::endl; - } - if(Dump.isOn("missed-t-propagations")) { - d_engine->d_hasPropagated.insert(lit); - } - Assert(d_engine->properPropagation(lit)); - d_propagatedLiterals.push_back(lit); - ++(d_engine->d_statistics.d_statPropagate); + void propagate(TNode literal) throw(AssertionException) { + Trace("theory") << "EngineOutputChannel<" << d_theory << ">::propagate(" << literal << ")" << std::endl; + ++ d_statistics.propagations; + d_engine->propagate(literal, d_theory); } - void lemma(TNode node, bool removable = false) - throw(theory::Interrupted, TypeCheckingExceptionPrivate, AssertionException) { - Trace("theory") << "EngineOutputChannel::lemma(" - << node << ")" << std::endl; - if(Dump.isOn("t-lemmas")) { - Dump("t-lemmas") << CommentCommand("theory lemma: expect valid") << std::endl - << QueryCommand(node.toExpr()) << std::endl; - } - ++(d_engine->d_statistics.d_statLemma); - - d_engine->newLemma(node, false, removable); + void lemma(TNode lemma, bool removable = false) throw(TypeCheckingExceptionPrivate, AssertionException) { + Trace("theory") << "EngineOutputChannel<" << d_theory << ">::lemma(" << lemma << ")" << std::endl; + ++ d_statistics.lemmas; + d_engine->lemma(lemma, false, removable); } - void explanation(TNode explanationNode, bool) - throw(theory::Interrupted, AssertionException) { - Trace("theory") << "EngineOutputChannel::explanation(" - << explanationNode << ")" << std::endl; - // handle dumping of explanations elsewhere.. - d_explanationNode = explanationNode; - ++(d_engine->d_statistics.d_statExplanation); - } - - void setIncomplete() throw(theory::Interrupted, AssertionException) { - d_engine->d_incomplete = true; + void setIncomplete() throw(AssertionException) { + d_engine->setIncomplete(d_theory); } };/* class EngineOutputChannel */ - EngineOutputChannel d_theoryOut; + /** + * Output channels for individual theories. + */ + EngineOutputChannel* d_theoryOut[theory::THEORY_LAST]; + + /** + * Are we in conflict. + */ + context::CDO<bool> d_inConflict; + + void conflict(TNode conflict) { + + Assert(properConflict(conflict)); - /** Pointer to Shared Term Manager */ - SharedTermManager* d_sharedTermManager; + // Mark that we are in conflict + d_inConflict = true; + + if(Dump.isOn("t-conflicts")) { + Dump("t-conflicts") << CommentCommand("theory conflict: expect unsat") << std::endl + << CheckSatCommand(conflict.toExpr()) << std::endl; + } + + // Construct the lemma (note that no CNF caching should happen as all the literals already exists) + lemma(conflict, true, false); + } /** * Debugging flag to ensure that shutdown() is called before the @@ -220,10 +247,10 @@ class TheoryEngine { context::CDO<bool> d_incomplete; /** - * Mark a theory active if it's not already. + * Called by the theories to notify that the current branch is incomplete. */ - void markActive(theory::Theory::Set theories) { - d_activeTheories = theory::Theory::setUnion(d_activeTheories, theories); + void setIncomplete(theory::TheoryId theory) { + d_incomplete = true; } /** @@ -233,9 +260,86 @@ class TheoryEngine { return theory::Theory::setContains(theory, d_activeTheories); } + struct SharedEquality { + /** The node/theory pair for the assertion */ + NodeTheoryPair toAssert; + /** This is the node/theory pair that we will use to explain it */ + NodeTheoryPair toExplain; + + SharedEquality(TNode assertion, TNode original, theory::TheoryId sendingTheory, theory::TheoryId receivingTheory) + : toAssert(assertion, sendingTheory), + toExplain(original, receivingTheory) + { } + }; + + /** + * A map from asserted facts to where they came from (for explanations). + */ + context::CDMap<NodeTheoryPair, NodeTheoryPair, NodeTheoryPairHashFunction> d_sharedAssertions; + + /** + * Assert a shared equalities propagated by theories. + */ + void assertSharedEqualities(); + /** The logic of the problem */ std::string d_logic; + /** + * Literals that are propagated by the theory. Note that these are TNodes. + * The theory can only propagate nodes that have an assigned literal in the + * sat solver and are hence referenced in the SAT solver. + */ + context::CDList<TNode> d_propagatedLiterals; + + /** + * The index of the next literal to be propagated by a theory. + */ + context::CDO<unsigned> d_propagatedLiteralsIndex; + + /** + * Shared term equalities that should be asserted to the individual theories. + */ + std::vector<SharedEquality> d_propagatedEqualities; + + /** + * Called by the output channel to propagate literals and facts + */ + void propagate(TNode literal, theory::TheoryId theory); + + /** + * Internal method to call the propagation routines and collect the + * propagated literals. + */ + void propagate(theory::Theory::Effort effort); + + /** + * A variable to mark if we added any lemmas. + */ + bool d_lemmasAdded; + + /** + * Adds a new lemma + */ + void lemma(TNode node, bool negated, bool removable) { + + if(Dump.isOn("t-lemmas")) { + Dump("t-lemmas") << CommentCommand("theory lemma: expect valid") << std::endl + << QueryCommand(node.toExpr()) << std::endl; + } + // Remove the ITEs and assert to prop engine + std::vector<Node> additionalLemmas; + additionalLemmas.push_back(node); + RemoveITE::run(additionalLemmas); + d_propEngine->assertLemma(theory::Rewriter::rewrite(additionalLemmas[0]), negated, removable); + for (unsigned i = 1; i < additionalLemmas.size(); ++ i) { + d_propEngine->assertLemma(theory::Rewriter::rewrite(additionalLemmas[i]), false, removable); + } + + // Mark that we added some lemmas + d_lemmasAdded = true; + } + public: /** Constructs a theory engine */ @@ -249,10 +353,10 @@ public: * there is another theory it will be deleted. */ template <class TheoryClass> - inline void addTheory() { - TheoryClass* theory = new TheoryClass(d_context, d_theoryOut, theory::Valuation(this)); - d_theoryTable[theory->getId()] = theory; - d_sharedTermManager->registerTheory(static_cast<TheoryClass*>(theory)); + inline void addTheory(theory::TheoryId theoryId) { + Assert(d_theoryTable[theoryId] == NULL && d_theoryOut[theoryId] == NULL); + d_theoryOut[theoryId] = new EngineOutputChannel(this, theoryId); + d_theoryTable[theoryId] = new TheoryClass(d_context, *d_theoryOut[theoryId], theory::Valuation(this)); } /** @@ -261,8 +365,11 @@ public: */ void setLogic(std::string logic); - SharedTermManager* getSharedTermManager() { - return d_sharedTermManager; + /** + * Mark a theory active if it's not already. + */ + void markActive(theory::Theory::Set theories) { + d_activeTheories = theory::Theory::setUnion(d_activeTheories, theories); } inline void setPropEngine(prop::PropEngine* propEngine) { @@ -298,26 +405,6 @@ public: void shutdown(); /** - * Get the theory associated to a given Node. - * - * @returns the theory, or NULL if the TNode is - * of built-in type. - */ - inline theory::Theory* theoryOf(TNode node) { - return d_theoryTable[theory::Theory::theoryOf(node)]; - } - - /** - * Get the theory associated to a the given theory id. - * - * @returns the theory, or NULL if the TNode is - * of built-in type. - */ - inline theory::Theory* theoryOf(theory::TheoryId theoryId) { - return d_theoryTable[theoryId]; - } - - /** * Solve the given literal with a theory that owns it. */ theory::Theory::SolveStatus solve(TNode literal, theory::SubstitutionMap& substitionOut); @@ -341,16 +428,7 @@ public: * Assert the formula to the appropriate theory. * @param node the assertion */ - inline void assertFact(TNode node) { - Trace("theory") << "TheoryEngine::assertFact(" << node << ")" << std::endl; - - // Get the atom - TNode atom = node.getKind() == kind::NOT ? node[0] : node; - - theory::Theory* theory = theoryOf(atom); - Trace("theory") << "asserting " << node << " to " << theory->getId() << std::endl; - theory->assertFact(node); - } + void assertFact(TNode node); /** * Check all (currently-active) theories for conflicts. @@ -359,6 +437,11 @@ public: void check(theory::Theory::Effort effort); /** + * Run the combination framework. + */ + void combineTheories(); + + /** * Calls staticLearning() on all theories, accumulating their * combined contributions in the "learned" builder. */ @@ -375,27 +458,12 @@ public: */ void notifyRestart(); - inline const std::vector<TNode>& getPropagatedLiterals() const { - return d_theoryOut.d_propagatedLiterals; - } - - inline void clearPropagatedLiterals() { - d_theoryOut.d_propagatedLiterals.clear(); - } - - inline void newLemma(TNode node, bool negated, bool removable) { - // Remove the ITEs and assert to prop engine - std::vector<Node> additionalLemmas; - additionalLemmas.push_back(node); - RemoveITE::run(additionalLemmas); - d_propEngine->assertLemma(theory::Rewriter::rewrite(additionalLemmas[0]), negated, removable); - for (unsigned i = 1; i < additionalLemmas.size(); ++ i) { - d_propEngine->assertLemma(theory::Rewriter::rewrite(additionalLemmas[i]), false, removable); + void getPropagatedLiterals(std::vector<TNode>& literals) { + for (; d_propagatedLiteralsIndex < d_propagatedLiterals.size(); d_propagatedLiteralsIndex = d_propagatedLiteralsIndex + 1) { + literals.push_back(d_propagatedLiterals[d_propagatedLiteralsIndex]); } } - void propagate(); - Node getExplanation(TNode node, theory::Theory* theory); bool properConflict(TNode conflict) const; @@ -403,160 +471,49 @@ public: bool properExplanation(TNode node, TNode expl) const; inline Node getExplanation(TNode node) { - d_theoryOut.d_explanationNode = Node::null(); TNode atom = node.getKind() == kind::NOT ? node[0] : node; - theoryOf(atom)->explain(node); - Assert(!d_theoryOut.d_explanationNode.get().isNull()); + Node explanation = theoryOf(atom)->explain(node); + Assert(!explanation.isNull()); if(Dump.isOn("t-explanations")) { - Dump("t-explanations") - << CommentCommand(std::string("theory explanation from ") + - theoryOf(atom)->identify() + ": expect valid") << std::endl - << QueryCommand(d_theoryOut.d_explanationNode.get().impNode(node).toExpr()) - << std::endl; + Dump("t-explanations") << CommentCommand(std::string("theory explanation from ") + theoryOf(atom)->identify() + ": expect valid") << std::endl + << QueryCommand(explanation.impNode(node).toExpr()) << std::endl; } - Assert(properExplanation(node, d_theoryOut.d_explanationNode.get())); - return d_theoryOut.d_explanationNode; + Assert(properExplanation(node, explanation)); + return explanation; } Node getValue(TNode node); -private: - class Statistics { - public: - IntStat d_statConflicts, d_statPropagate, d_statLemma, d_statAugLemma, d_statExplanation; - Statistics(): - d_statConflicts("theory::conflicts", 0), - d_statPropagate("theory::propagate", 0), - d_statLemma("theory::lemma", 0), - d_statAugLemma("theory::aug_lemma", 0), - d_statExplanation("theory::explanation", 0) { - StatisticsRegistry::registerStat(&d_statConflicts); - StatisticsRegistry::registerStat(&d_statPropagate); - StatisticsRegistry::registerStat(&d_statLemma); - StatisticsRegistry::registerStat(&d_statAugLemma); - StatisticsRegistry::registerStat(&d_statExplanation); - } - - ~Statistics() { - StatisticsRegistry::unregisterStat(&d_statConflicts); - StatisticsRegistry::unregisterStat(&d_statPropagate); - StatisticsRegistry::unregisterStat(&d_statLemma); - StatisticsRegistry::unregisterStat(&d_statAugLemma); - StatisticsRegistry::unregisterStat(&d_statExplanation); - } - };/* class TheoryEngine::Statistics */ - Statistics d_statistics; - - /////////////////////////// - // Visitors - /////////////////////////// - /** - * Visitor that calls the apropriate theory to preregister the term. + * Get the theory associated to a given Node. + * + * @returns the theory, or NULL if the TNode is + * of built-in type. */ - class PreRegisterVisitor { - - /** The engine */ - TheoryEngine& d_engine; - - /** - * Cache from proprocessing of atoms. - */ - typedef context::CDMap<TNode, theory::Theory::Set, TNodeHashFunction> TNodeVisitedMap; - TNodeVisitedMap d_visited; - - /** - * All the theories of the visitation. - */ - theory::Theory::Set d_theories; - - std::string toString() const { - std::stringstream ss; - TNodeVisitedMap::const_iterator it = d_visited.begin(); - for (; it != d_visited.end(); ++ it) { - ss << (*it).first << ": " << theory::Theory::setToString((*it).second) << std::endl; - } - return ss.str(); - } - - public: - - PreRegisterVisitor(TheoryEngine& engine, context::Context* context): d_engine(engine), d_visited(context), d_theories(0){} - - bool alreadyVisited(TNode current, TNode parent) { - - Debug("register::internal") << "PreRegisterVisitor::alreadyVisited(" << current << "," << parent << ") => "; - - using namespace theory; - - TNodeVisitedMap::iterator find = d_visited.find(current); - - // If node is not visited at all, just return false - if (find == d_visited.end()) { - Debug("register::internal") << "1:false" << std::endl; - return false; - } - - Theory::Set theories = (*find).second; - - TheoryId currentTheoryId = Theory::theoryOf(current); - TheoryId parentTheoryId = Theory::theoryOf(parent); - - if (Theory::setContains(currentTheoryId, theories)) { - // The current theory has already visited it, so now it depends on the parent - Debug("register::internal") << (Theory::setContains(parentTheoryId, theories) ? "2:true" : "2:false") << std::endl; - return Theory::setContains(parentTheoryId, theories); - } else { - // If the current theory is not registered, it still needs to be visited - Debug("register::internal") << "2:false" << std::endl; - return false; - } - } - - void visit(TNode current, TNode parent) { - - Debug("register") << "PreRegisterVisitor::visit(" << current << "," << parent << ")" << std::endl; - Debug("register::internal") << toString() << std::endl; - - using namespace theory; - - // Get the theories of the terms - TheoryId currentTheoryId = Theory::theoryOf(current); - TheoryId parentTheoryId = Theory::theoryOf(parent); - - Theory::Set theories = d_visited[current]; - Debug("register::internal") << "PreRegisterVisitor::visit(" << current << "," << parent << "): previously registered with " << Theory::setToString(theories) << std::endl; - if (!Theory::setContains(currentTheoryId, theories)) { - d_visited[current] = theories = Theory::setInsert(currentTheoryId, theories); - d_engine.theoryOf(currentTheoryId)->preRegisterTerm(current); - d_theories = Theory::setInsert(currentTheoryId, d_theories); - Debug("register::internal") << "PreRegisterVisitor::visit(" << current << "," << parent << "): adding " << currentTheoryId << std::endl; - } - if (!Theory::setContains(parentTheoryId, theories)) { - d_visited[current] = theories = Theory::setInsert(parentTheoryId, theories); - d_engine.theoryOf(parentTheoryId)->preRegisterTerm(current); - d_theories = Theory::setInsert(parentTheoryId, d_theories); - Debug("register::internal") << "PreRegisterVisitor::visit(" << current << "," << parent << "): adding " << parentTheoryId << std::endl; - } - Debug("register::internal") << "PreRegisterVisitor::visit(" << current << "," << parent << "): now registered with " << Theory::setToString(theories) << std::endl; - - Assert(d_visited.find(current) != d_visited.end()); - Assert(alreadyVisited(current, parent)); - } - - void start(TNode node) { - d_theories = 0; - } + inline theory::Theory* theoryOf(TNode node) { + return d_theoryTable[theory::Theory::theoryOf(node)]; + } - void done(TNode node) { - d_engine.markActive(d_theories); - } + /** + * Get the theory associated to a the given theory id. It will also mark the + * theory as currently active, we assume that theories are called only through + * theoryOf. + * + * @returns the theory, or NULL if the TNode is + * of built-in type. + */ + inline theory::Theory* theoryOf(theory::TheoryId theoryId) { + return d_theoryTable[theoryId]; + } - }; +private: /** Default visitor for pre-registration */ PreRegisterVisitor d_preRegistrationVisitor; + /** Visitor for collecting shared terms */ + SharedTermsVisitor d_sharedTermsVisitor; + };/* class TheoryEngine */ }/* CVC4 namespace */ |