diff options
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r-- | src/theory/arith/arith_static_learner.cpp | 298 |
1 files changed, 298 insertions, 0 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp new file mode 100644 index 000000000..6fb538fac --- /dev/null +++ b/src/theory/arith/arith_static_learner.cpp @@ -0,0 +1,298 @@ +/********************* */ +/*! \file arith_rewriter.cpp + ** \verbatim + ** Original author: taking + ** Major contributors: dejan + ** Minor contributors (to current version): mdeters + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "theory/rewriter.h" + +#include "theory/arith/arith_utilities.h" +#include "theory/arith/arith_static_learner.h" + +#include "util/propositional_query.h" + +#include <vector> + +using namespace std; + +using namespace CVC4; +using namespace CVC4::kind; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; + + +ArithStaticLearner::ArithStaticLearner(): + d_miplibTrick(), + d_statistics() +{} + +ArithStaticLearner::Statistics::Statistics(): + d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0), + d_iteConstantApplications("theory::arith::iteConstantApplications", 0), + d_miplibtrickApplications("theory::arith::miplibtrickApplications", 0), + d_avgNumMiplibtrickValues("theory::arith::avgNumMiplibtrickValues") +{ + StatisticsRegistry::registerStat(&d_iteMinMaxApplications); + StatisticsRegistry::registerStat(&d_iteConstantApplications); + StatisticsRegistry::registerStat(&d_miplibtrickApplications); + StatisticsRegistry::registerStat(&d_avgNumMiplibtrickValues); +} + +ArithStaticLearner::Statistics::~Statistics(){ + StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications); + StatisticsRegistry::unregisterStat(&d_iteConstantApplications); + StatisticsRegistry::unregisterStat(&d_miplibtrickApplications); + StatisticsRegistry::unregisterStat(&d_avgNumMiplibtrickValues); +} + +void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){ + + vector<TNode> workList; + workList.push_back(n); + TNodeSet processed; + + //Contains an underapproximation of nodes that must hold. + TNodeSet defTrue; + + defTrue.insert(n); + + while(!workList.empty()) { + n = workList.back(); + + bool unprocessedChildren = false; + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + if(processed.find(*i) == processed.end()) { + // unprocessed child + workList.push_back(*i); + unprocessedChildren = true; + } + } + if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){ + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + defTrue.insert(*i); + } + } + + if(unprocessedChildren) { + continue; + } + + workList.pop_back(); + // has node n been processed in the meantime ? + if(processed.find(n) != processed.end()) { + continue; + } + processed.insert(n); + + process(n,learned, defTrue); + + } + + postProcess(learned); +} + +void ArithStaticLearner::clear(){ + d_miplibTrick.clear(); +} + + +void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){ + Debug("arith::static") << "===================== looking at" << n << endl; + + switch(n.getKind()){ + case ITE: + if(n[0].getKind() != EQUAL && + isRelationOperator(n[0].getKind()) ){ + iteMinMax(n, learned); + } + + if((n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) && + (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) { + iteConstant(n, learned); + } + break; + case IMPLIES: + // == 3-FINITE VALUE SET : Collect information == + if(n[1].getKind() == EQUAL && + n[1][0].getMetaKind() == metakind::VARIABLE && + defTrue.find(n) != defTrue.end()){ + Node eqTo = n[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + if(rewriteEqTo.getKind() == CONST_RATIONAL){ + + TNode var = n[1][0]; + if(d_miplibTrick.find(var) == d_miplibTrick.end()){ + d_miplibTrick.insert(make_pair(var, set<Node>())); + } + d_miplibTrick[var].insert(n); + Debug("arith::miplib") << "insert " << var << " const " << n << endl; + } + } + break; + default: // Do nothing + break; + } +} + +void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == kind::ITE); + Assert(n[0].getKind() != EQUAL); + Assert(isRelationOperator(n[0].getKind())); + + TNode c = n[0]; + Kind k = simplifiedKind(c); + TNode t = n[1]; + TNode e = n[2]; + TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0]; + TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1]; + + if((t == cright) && (e == cleft)){ + TNode tmp = t; + t = e; + e = tmp; + k = reverseRelationKind(k); + } + if(t == cleft && e == cright){ + // t == cleft && e == cright + Assert( t == cleft ); + Assert( e == cright ); + switch(k){ + case LT: // (ite (< x y) x y) + case LEQ: { // (ite (<= x y) x y) + Node nLeqX = NodeBuilder<2>(LEQ) << n << t; + Node nLeqY = NodeBuilder<2>(LEQ) << n << e; + Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl; + learned << nLeqX << nLeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + case GT: // (ite (> x y) x y) + case GEQ: { // (ite (>= x y) x y) + Node nGeqX = NodeBuilder<2>(GEQ) << n << t; + Node nGeqY = NodeBuilder<2>(GEQ) << n << e; + Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl; + learned << nGeqX << nGeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + default: Unreachable(); + } + } +} + +void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == ITE); + Assert(n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER ); + Assert(n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER ); + + Rational t = coerceToRational(n[1]); + Rational e = coerceToRational(n[2]); + TNode min = (t <= e) ? n[1] : n[2]; + TNode max = (t >= e) ? n[1] : n[2]; + + Node nGeqMin = NodeBuilder<2>(GEQ) << n << min; + Node nLeqMax = NodeBuilder<2>(LEQ) << n << max; + Debug("arith::static") << n << " iteConstant" << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + ++(d_statistics.d_iteConstantApplications); +} + + +void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ + + // == 3-FINITE VALUE SET == + VarToNodeSetMap::iterator i = d_miplibTrick.begin(); + VarToNodeSetMap::iterator endMipLibTrick = d_miplibTrick.end(); + for(; i != endMipLibTrick; ++i){ + TNode var = i->first; + const set<Node>& imps = i->second; + + Assert(!imps.empty()); + vector<Node> conditions; + set<Rational> values; + set<Node>::const_iterator j=imps.begin(), impsEnd=imps.end(); + for(; j != impsEnd; ++j){ + TNode imp = *j; + Assert(imp.getKind() == IMPLIES); + Assert(imp[1].getKind() == EQUAL); + + Node eqTo = imp[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + Assert(rewriteEqTo.getKind() == CONST_RATIONAL); + + conditions.push_back(imp[0]); + values.insert(rewriteEqTo.getConst<Rational>()); + } + + Node possibleTaut = Node::null(); + if(conditions.size() == 1){ + possibleTaut = conditions.front(); + }else{ + NodeBuilder<> orBuilder(OR); + orBuilder.append(conditions); + possibleTaut = orBuilder; + } + + + Debug("arith::miplib") << "var: " << var << endl; + Debug("arith::miplib") << "possibleTaut: " << possibleTaut << endl; + + Result isTaut = PropositionalQuery::isTautology(possibleTaut); + if(isTaut == Result(Result::VALID)){ + miplibTrick(var, values, learned); + } + } +} + + +void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuilder<>& learned){ + + Debug("arith::miplib") << var << " found a tautology!"<< endl; + + const Rational& min = *(values.begin()); + const Rational& max = *(values.rbegin()); + + Debug("arith::miplib") << "min: " << min << endl; + Debug("arith::miplib") << "max: " << max << endl; + + Assert(min <= max); + ++(d_statistics.d_miplibtrickApplications); + (d_statistics.d_avgNumMiplibtrickValues).addEntry(values.size()); + + Node nGeqMin = NodeBuilder<2>(GEQ) << var << mkRationalNode(min); + Node nLeqMax = NodeBuilder<2>(LEQ) << var << mkRationalNode(max); + Debug("arith::miplib") << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + set<Rational>::iterator valuesIter = values.begin(); + set<Rational>::iterator valuesEnd = values.end(); + set<Rational>::iterator valuesPrev = valuesIter; + ++valuesIter; + for(; valuesIter != valuesEnd; valuesPrev = valuesIter, ++valuesIter){ + const Rational& prev = *valuesPrev; + const Rational& curr = *valuesIter; + Assert(prev < curr); + + //The interval (last,curr) can be excluded: + //(not (and (> var prev) (< var curr)) + //<=> (or (not (> var prev)) (not (< var curr))) + //<=> (or (<= var prev) (>= var curr)) + Node leqPrev = NodeBuilder<2>(LEQ) << var << mkRationalNode(prev); + Node geqCurr = NodeBuilder<2>(GEQ) << var << mkRationalNode(curr); + Node excludedMiddle = NodeBuilder<2>(OR) << leqPrev << geqCurr; + Debug("arith::miplib") << excludedMiddle << endl; + learned << excludedMiddle; + } +} |