From 232042b3e2e265dbfe9c693d018d48388be91018 Mon Sep 17 00:00:00 2001 From: Tim King Date: Thu, 17 Mar 2011 20:38:32 +0000 Subject: - Removes arith_constants.h - Adds ArithStaticLearner. Consolidates and cleans up the code for static learning in arithmetic. Static learning is now associated with a small amount of state between calls. This is used to track the data for the miplib trick. The goal is to make this inference work without relying on the fact that all of the miplib problem is asserted under the same AND node. - This commit contains miscellaneous other arithmetic cleanup. --- src/theory/arith/arith_static_learner.cpp | 298 ++++++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 src/theory/arith/arith_static_learner.cpp (limited to 'src/theory/arith/arith_static_learner.cpp') diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp new file mode 100644 index 000000000..6fb538fac --- /dev/null +++ b/src/theory/arith/arith_static_learner.cpp @@ -0,0 +1,298 @@ +/********************* */ +/*! \file arith_rewriter.cpp + ** \verbatim + ** Original author: taking + ** Major contributors: dejan + ** Minor contributors (to current version): mdeters + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "theory/rewriter.h" + +#include "theory/arith/arith_utilities.h" +#include "theory/arith/arith_static_learner.h" + +#include "util/propositional_query.h" + +#include + +using namespace std; + +using namespace CVC4; +using namespace CVC4::kind; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; + + +ArithStaticLearner::ArithStaticLearner(): + d_miplibTrick(), + d_statistics() +{} + +ArithStaticLearner::Statistics::Statistics(): + d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0), + d_iteConstantApplications("theory::arith::iteConstantApplications", 0), + d_miplibtrickApplications("theory::arith::miplibtrickApplications", 0), + d_avgNumMiplibtrickValues("theory::arith::avgNumMiplibtrickValues") +{ + StatisticsRegistry::registerStat(&d_iteMinMaxApplications); + StatisticsRegistry::registerStat(&d_iteConstantApplications); + StatisticsRegistry::registerStat(&d_miplibtrickApplications); + StatisticsRegistry::registerStat(&d_avgNumMiplibtrickValues); +} + +ArithStaticLearner::Statistics::~Statistics(){ + StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications); + StatisticsRegistry::unregisterStat(&d_iteConstantApplications); + StatisticsRegistry::unregisterStat(&d_miplibtrickApplications); + StatisticsRegistry::unregisterStat(&d_avgNumMiplibtrickValues); +} + +void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){ + + vector workList; + workList.push_back(n); + TNodeSet processed; + + //Contains an underapproximation of nodes that must hold. + TNodeSet defTrue; + + defTrue.insert(n); + + while(!workList.empty()) { + n = workList.back(); + + bool unprocessedChildren = false; + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + if(processed.find(*i) == processed.end()) { + // unprocessed child + workList.push_back(*i); + unprocessedChildren = true; + } + } + if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){ + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + defTrue.insert(*i); + } + } + + if(unprocessedChildren) { + continue; + } + + workList.pop_back(); + // has node n been processed in the meantime ? + if(processed.find(n) != processed.end()) { + continue; + } + processed.insert(n); + + process(n,learned, defTrue); + + } + + postProcess(learned); +} + +void ArithStaticLearner::clear(){ + d_miplibTrick.clear(); +} + + +void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){ + Debug("arith::static") << "===================== looking at" << n << endl; + + switch(n.getKind()){ + case ITE: + if(n[0].getKind() != EQUAL && + isRelationOperator(n[0].getKind()) ){ + iteMinMax(n, learned); + } + + if((n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) && + (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) { + iteConstant(n, learned); + } + break; + case IMPLIES: + // == 3-FINITE VALUE SET : Collect information == + if(n[1].getKind() == EQUAL && + n[1][0].getMetaKind() == metakind::VARIABLE && + defTrue.find(n) != defTrue.end()){ + Node eqTo = n[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + if(rewriteEqTo.getKind() == CONST_RATIONAL){ + + TNode var = n[1][0]; + if(d_miplibTrick.find(var) == d_miplibTrick.end()){ + d_miplibTrick.insert(make_pair(var, set())); + } + d_miplibTrick[var].insert(n); + Debug("arith::miplib") << "insert " << var << " const " << n << endl; + } + } + break; + default: // Do nothing + break; + } +} + +void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == kind::ITE); + Assert(n[0].getKind() != EQUAL); + Assert(isRelationOperator(n[0].getKind())); + + TNode c = n[0]; + Kind k = simplifiedKind(c); + TNode t = n[1]; + TNode e = n[2]; + TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0]; + TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1]; + + if((t == cright) && (e == cleft)){ + TNode tmp = t; + t = e; + e = tmp; + k = reverseRelationKind(k); + } + if(t == cleft && e == cright){ + // t == cleft && e == cright + Assert( t == cleft ); + Assert( e == cright ); + switch(k){ + case LT: // (ite (< x y) x y) + case LEQ: { // (ite (<= x y) x y) + Node nLeqX = NodeBuilder<2>(LEQ) << n << t; + Node nLeqY = NodeBuilder<2>(LEQ) << n << e; + Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl; + learned << nLeqX << nLeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + case GT: // (ite (> x y) x y) + case GEQ: { // (ite (>= x y) x y) + Node nGeqX = NodeBuilder<2>(GEQ) << n << t; + Node nGeqY = NodeBuilder<2>(GEQ) << n << e; + Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl; + learned << nGeqX << nGeqY; + ++(d_statistics.d_iteMinMaxApplications); + break; + } + default: Unreachable(); + } + } +} + +void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ + Assert(n.getKind() == ITE); + Assert(n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER ); + Assert(n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER ); + + Rational t = coerceToRational(n[1]); + Rational e = coerceToRational(n[2]); + TNode min = (t <= e) ? n[1] : n[2]; + TNode max = (t >= e) ? n[1] : n[2]; + + Node nGeqMin = NodeBuilder<2>(GEQ) << n << min; + Node nLeqMax = NodeBuilder<2>(LEQ) << n << max; + Debug("arith::static") << n << " iteConstant" << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + ++(d_statistics.d_iteConstantApplications); +} + + +void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ + + // == 3-FINITE VALUE SET == + VarToNodeSetMap::iterator i = d_miplibTrick.begin(); + VarToNodeSetMap::iterator endMipLibTrick = d_miplibTrick.end(); + for(; i != endMipLibTrick; ++i){ + TNode var = i->first; + const set& imps = i->second; + + Assert(!imps.empty()); + vector conditions; + set values; + set::const_iterator j=imps.begin(), impsEnd=imps.end(); + for(; j != impsEnd; ++j){ + TNode imp = *j; + Assert(imp.getKind() == IMPLIES); + Assert(imp[1].getKind() == EQUAL); + + Node eqTo = imp[1][1]; + Node rewriteEqTo = Rewriter::rewrite(eqTo); + Assert(rewriteEqTo.getKind() == CONST_RATIONAL); + + conditions.push_back(imp[0]); + values.insert(rewriteEqTo.getConst()); + } + + Node possibleTaut = Node::null(); + if(conditions.size() == 1){ + possibleTaut = conditions.front(); + }else{ + NodeBuilder<> orBuilder(OR); + orBuilder.append(conditions); + possibleTaut = orBuilder; + } + + + Debug("arith::miplib") << "var: " << var << endl; + Debug("arith::miplib") << "possibleTaut: " << possibleTaut << endl; + + Result isTaut = PropositionalQuery::isTautology(possibleTaut); + if(isTaut == Result(Result::VALID)){ + miplibTrick(var, values, learned); + } + } +} + + +void ArithStaticLearner::miplibTrick(TNode var, set& values, NodeBuilder<>& learned){ + + Debug("arith::miplib") << var << " found a tautology!"<< endl; + + const Rational& min = *(values.begin()); + const Rational& max = *(values.rbegin()); + + Debug("arith::miplib") << "min: " << min << endl; + Debug("arith::miplib") << "max: " << max << endl; + + Assert(min <= max); + ++(d_statistics.d_miplibtrickApplications); + (d_statistics.d_avgNumMiplibtrickValues).addEntry(values.size()); + + Node nGeqMin = NodeBuilder<2>(GEQ) << var << mkRationalNode(min); + Node nLeqMax = NodeBuilder<2>(LEQ) << var << mkRationalNode(max); + Debug("arith::miplib") << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + set::iterator valuesIter = values.begin(); + set::iterator valuesEnd = values.end(); + set::iterator valuesPrev = valuesIter; + ++valuesIter; + for(; valuesIter != valuesEnd; valuesPrev = valuesIter, ++valuesIter){ + const Rational& prev = *valuesPrev; + const Rational& curr = *valuesIter; + Assert(prev < curr); + + //The interval (last,curr) can be excluded: + //(not (and (> var prev) (< var curr)) + //<=> (or (not (> var prev)) (not (< var curr))) + //<=> (or (<= var prev) (>= var curr)) + Node leqPrev = NodeBuilder<2>(LEQ) << var << mkRationalNode(prev); + Node geqCurr = NodeBuilder<2>(GEQ) << var << mkRationalNode(curr); + Node excludedMiddle = NodeBuilder<2>(OR) << leqPrev << geqCurr; + Debug("arith::miplib") << excludedMiddle << endl; + learned << excludedMiddle; + } +} -- cgit v1.2.3