diff options
author | Tim King <taking@cs.nyu.edu> | 2011-03-17 20:38:32 +0000 |
---|---|---|
committer | Tim King <taking@cs.nyu.edu> | 2011-03-17 20:38:32 +0000 |
commit | 232042b3e2e265dbfe9c693d018d48388be91018 (patch) | |
tree | 55dc6d39bbd5eff9b0b1c220ed33dac3d4bdd316 /src/theory/arith/arith_static_learner.cpp | |
parent | 68f8b6b2589320dac3a022a1e5058e5a65cc570b (diff) |
- Removes arith_constants.h
- Adds ArithStaticLearner. Consolidates and cleans up the code for static learning in arithmetic. Static learning is now associated with a small amount of state between calls. This is used to track the data for the miplib trick. The goal is to make this inference work without relying on the fact that all of the miplib problem is asserted under the same AND node.
- This commit contains miscellaneous other arithmetic cleanup.
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r-- | src/theory/arith/arith_static_learner.cpp | 298 |
1 files changed, 298 insertions, 0 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp new file mode 100644 index 000000000..6fb538fac --- /dev/null +++ b/src/theory/arith/arith_static_learner.cpp @@ -0,0 +1,298 @@ +/********************* */ +/*! \file arith_rewriter.cpp + ** \verbatim + ** Original author: taking + ** Major contributors: dejan + ** Minor contributors (to current version): mdeters + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "theory/rewriter.h" + +#include "theory/arith/arith_utilities.h" +#include "theory/arith/arith_static_learner.h" + +#include "util/propositional_query.h" + +#include <vector> + +using namespace std; + +using namespace CVC4; +using namespace CVC4::kind; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; + + +ArithStaticLearner::ArithStaticLearner(): + d_miplibTrick(), + d_statistics() +{} + +ArithStaticLearner::Statistics::Statistics(): + d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0), + d_iteConstantApplications("theory::arith::iteConstantApplications", 0), + d_miplibtrickApplications("theory::arith::miplibtrickApplications", 0), + d_avgNumMiplibtrickValues("theory::arith::avgNumMiplibtrickValues") +{ + StatisticsRegistry::registerStat(&d_iteMinMaxApplications); + StatisticsRegistry::registerStat(&d_iteConstantApplications); + StatisticsRegistry::registerStat(&d_miplibtrickApplications); + StatisticsRegistry::registerStat(&d_avgNumMiplibtrickValues); +} + +ArithStaticLearner::Statistics::~Statistics(){ + StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications); + StatisticsRegistry::unregisterStat(&d_iteConstantApplications); + StatisticsRegistry::unregisterStat(&d_miplibtrickApplications); + StatisticsRegistry::unregisterStat(&d_avgNumMiplibtrickValues); +} + +void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){ + + vector<TNode> workList; + workList.push_back(n); + TNodeSet processed; + + //Contains an underapproximation of nodes that must hold. + TNodeSet defTrue; + + defTrue.insert(n); + + while(!workList.empty()) { + n = workList.back(); + + bool unprocessedChildren = false; + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + if(processed.find(*i) == processed.end()) { + // unprocessed child + workList.push_back(*i); + unprocessedChildren = true; + } + } + if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){ + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + defTrue.insert(*i); + } + } + + if(unprocessedChildren) { + continue; + } + + workList.pop_back(); + // has node n been processed in the meantime ? + if(processed.find(n) != processed.end()) { + continue; + } + processed.insert(n); + + process(n,learned, defTrue); + + } + + postProcess(learned); +} + +void ArithStaticLearner::clear(){ + d_miplibTrick.clear(); +} + + +void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){ + Debug("arith::static") << "===================== looking at" << n << endl; + + switch(n.getKind()){ + case ITE: + if(n[0].getKind() != EQUAL && + isRelationOperator(n[0].getKind()) ){ + iteMinMax(n, learned); + } + + if((n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) && + (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) { + iteConstant(n, learned); + } + break; + case IMPLIES: + // == 3-FINITE VALUE SET : Collect information == + if(n[1].getKind() == EQUAL && + n[1][0].getMetaKind() == metakind::VARIABLE && + defTrue.find(n) != defTrue.end()){ + Node eqTo = n[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + if(rewriteEqTo.getKind() == CONST_RATIONAL){ + + TNode var = n[1][0]; + if(d_miplibTrick.find(var) == d_miplibTrick.end()){ + d_miplibTrick.insert(make_pair(var, set<Node>())); + } + d_miplibTrick[var].insert(n); + Debug("arith::miplib") << "insert " << var << " const " << n << endl; + } + } + break; + default: // Do nothing + break; + } +} + +void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == kind::ITE); + Assert(n[0].getKind() != EQUAL); + Assert(isRelationOperator(n[0].getKind())); + + TNode c = n[0]; + Kind k = simplifiedKind(c); + TNode t = n[1]; + TNode e = n[2]; + TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0]; + TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1]; + + if((t == cright) && (e == cleft)){ + TNode tmp = t; + t = e; + e = tmp; + k = reverseRelationKind(k); + } + if(t == cleft && e == cright){ + // t == cleft && e == cright + Assert( t == cleft ); + Assert( e == cright ); + switch(k){ + case LT: // (ite (< x y) x y) + case LEQ: { // (ite (<= x y) x y) + Node nLeqX = NodeBuilder<2>(LEQ) << n << t; + Node nLeqY = NodeBuilder<2>(LEQ) << n << e; + Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl; + learned << nLeqX << nLeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + case GT: // (ite (> x y) x y) + case GEQ: { // (ite (>= x y) x y) + Node nGeqX = NodeBuilder<2>(GEQ) << n << t; + Node nGeqY = NodeBuilder<2>(GEQ) << n << e; + Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl; + learned << nGeqX << nGeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + default: Unreachable(); + } + } +} + +void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == ITE); + Assert(n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER ); + Assert(n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER ); + + Rational t = coerceToRational(n[1]); + Rational e = coerceToRational(n[2]); + TNode min = (t <= e) ? n[1] : n[2]; + TNode max = (t >= e) ? n[1] : n[2]; + + Node nGeqMin = NodeBuilder<2>(GEQ) << n << min; + Node nLeqMax = NodeBuilder<2>(LEQ) << n << max; + Debug("arith::static") << n << " iteConstant" << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + ++(d_statistics.d_iteConstantApplications); +} + + +void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ + + // == 3-FINITE VALUE SET == + VarToNodeSetMap::iterator i = d_miplibTrick.begin(); + VarToNodeSetMap::iterator endMipLibTrick = d_miplibTrick.end(); + for(; i != endMipLibTrick; ++i){ + TNode var = i->first; + const set<Node>& imps = i->second; + + Assert(!imps.empty()); + vector<Node> conditions; + set<Rational> values; + set<Node>::const_iterator j=imps.begin(), impsEnd=imps.end(); + for(; j != impsEnd; ++j){ + TNode imp = *j; + Assert(imp.getKind() == IMPLIES); + Assert(imp[1].getKind() == EQUAL); + + Node eqTo = imp[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + Assert(rewriteEqTo.getKind() == CONST_RATIONAL); + + conditions.push_back(imp[0]); + values.insert(rewriteEqTo.getConst<Rational>()); + } + + Node possibleTaut = Node::null(); + if(conditions.size() == 1){ + possibleTaut = conditions.front(); + }else{ + NodeBuilder<> orBuilder(OR); + orBuilder.append(conditions); + possibleTaut = orBuilder; + } + + + Debug("arith::miplib") << "var: " << var << endl; + Debug("arith::miplib") << "possibleTaut: " << possibleTaut << endl; + + Result isTaut = PropositionalQuery::isTautology(possibleTaut); + if(isTaut == Result(Result::VALID)){ + miplibTrick(var, values, learned); + } + } +} + + +void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuilder<>& learned){ + + Debug("arith::miplib") << var << " found a tautology!"<< endl; + + const Rational& min = *(values.begin()); + const Rational& max = *(values.rbegin()); + + Debug("arith::miplib") << "min: " << min << endl; + Debug("arith::miplib") << "max: " << max << endl; + + Assert(min <= max); + ++(d_statistics.d_miplibtrickApplications); + (d_statistics.d_avgNumMiplibtrickValues).addEntry(values.size()); + + Node nGeqMin = NodeBuilder<2>(GEQ) << var << mkRationalNode(min); + Node nLeqMax = NodeBuilder<2>(LEQ) << var << mkRationalNode(max); + Debug("arith::miplib") << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + set<Rational>::iterator valuesIter = values.begin(); + set<Rational>::iterator valuesEnd = values.end(); + set<Rational>::iterator valuesPrev = valuesIter; + ++valuesIter; + for(; valuesIter != valuesEnd; valuesPrev = valuesIter, ++valuesIter){ + const Rational& prev = *valuesPrev; + const Rational& curr = *valuesIter; + Assert(prev < curr); + + //The interval (last,curr) can be excluded: + //(not (and (> var prev) (< var curr)) + //<=> (or (not (> var prev)) (not (< var curr))) + //<=> (or (<= var prev) (>= var curr)) + Node leqPrev = NodeBuilder<2>(LEQ) << var << mkRationalNode(prev); + Node geqCurr = NodeBuilder<2>(GEQ) << var << mkRationalNode(curr); + Node excludedMiddle = NodeBuilder<2>(OR) << leqPrev << geqCurr; + Debug("arith::miplib") << excludedMiddle << endl; + learned << excludedMiddle; + } +} |