/*********************                                                        */
/*! \file arith_static_learner.cpp
 ** \verbatim
 ** Original author: taking
 ** Major contributors: dejan
 ** Minor contributors (to current version): mdeters
 ** This file is part of the CVC4 prototype.
 ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
 ** Courant Institute of Mathematical Sciences
 ** New York University
 ** See the file COPYING in the top-level source directory for licensing
 ** information.\endverbatim
 **
 ** \brief Static learning for the arithmetic theory.
 **
 ** Walks an input formula and emits learned facts into a NodeBuilder:
 **  - ITE terms that compute a min or max of their condition's operands
 **    are bounded by both operands (iteMinMax);
 **  - ITE terms choosing between two numeric constants are bounded by the
 **    smaller and larger constant (iteConstant);
 **  - the "finite value set" (miplib) trick: if implications
 **    (c_i => var = v_i) are known to hold and (or c_1 ... c_k) is a
 **    propositional tautology, then var lies in {v_1, ..., v_k}; the bounds
 **    and the excluded open intervals between consecutive values are
 **    learned (postProcess / miplibTrick).
 **/

#include "theory/rewriter.h"

#include "theory/arith/arith_utilities.h"
#include "theory/arith/arith_static_learner.h"

#include "util/propositional_query.h"

#include <set>
#include <utility>
#include <vector>

using namespace std;

using namespace CVC4;
using namespace CVC4::kind;
using namespace CVC4::theory;
using namespace CVC4::theory::arith;


ArithStaticLearner::ArithStaticLearner():
  d_miplibTrick(),
  d_statistics()
{}

/** Creates the static-learning statistics and registers them globally. */
ArithStaticLearner::Statistics::Statistics():
  d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0),
  d_iteConstantApplications("theory::arith::iteConstantApplications", 0),
  d_miplibtrickApplications("theory::arith::miplibtrickApplications", 0),
  d_avgNumMiplibtrickValues("theory::arith::avgNumMiplibtrickValues")
{
  StatisticsRegistry::registerStat(&d_iteMinMaxApplications);
  StatisticsRegistry::registerStat(&d_iteConstantApplications);
  StatisticsRegistry::registerStat(&d_miplibtrickApplications);
  StatisticsRegistry::registerStat(&d_avgNumMiplibtrickValues);
}

/** Unregisters every statistic registered by the constructor. */
ArithStaticLearner::Statistics::~Statistics(){
  StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications);
  StatisticsRegistry::unregisterStat(&d_iteConstantApplications);
  StatisticsRegistry::unregisterStat(&d_miplibtrickApplications);
  StatisticsRegistry::unregisterStat(&d_avgNumMiplibtrickValues);
}

/**
 * Entry point: visits every node of the DAG rooted at n exactly once,
 * children before parents, calling process() on each; then runs
 * postProcess() to apply the finite-value-set trick.
 *
 * @param n       root formula; it is assumed to hold (it seeds defTrue)
 * @param learned builder that receives every learned fact
 */
void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){

  vector<TNode> workList;
  workList.push_back(n);
  TNodeSet processed;

  //Contains an underapproximation of nodes that must hold.
  TNodeSet defTrue;

  defTrue.insert(n);

  while(!workList.empty()) {
    n = workList.back();

    bool unprocessedChildren = false;
    for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
      if(processed.find(*i) == processed.end()) {
        // unprocessed child
        workList.push_back(*i);
        unprocessedChildren = true;
      }
    }
    // Every conjunct of an AND that must hold must also hold.  This is done
    // before the children are processed so they see their defTrue status.
    if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){
      for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
        defTrue.insert(*i);
      }
    }

    if(unprocessedChildren) {
      continue;
    }

    workList.pop_back();
    // has node n been processed in the meantime ?
    // (a shared subterm may have been pushed more than once)
    if(processed.find(n) != processed.end()) {
      continue;
    }
    processed.insert(n);

    process(n,learned, defTrue);

  }

  postProcess(learned);
}

/** Forgets all implications collected for the finite-value-set trick. */
void ArithStaticLearner::clear(){
  d_miplibTrick.clear();
}


/**
 * Examines a single node.
 *  - ITE whose condition is a non-EQUAL relation: try iteMinMax.
 *  - ITE whose two branches are numeric constants: apply iteConstant.
 *  - IMPLIES in defTrue whose consequent is (= v c), with v a variable and
 *    c rewriting to a CONST_RATIONAL: remember the implication under v in
 *    d_miplibTrick for postProcess().
 */
void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
  Debug("arith::static") << "===================== looking at" << n << endl;

  switch(n.getKind()){
  case ITE:
    if(n[0].getKind() != EQUAL &&
       isRelationOperator(n[0].getKind())  ){
      iteMinMax(n, learned);
    }

    if((n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) &&
       (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) {
      iteConstant(n, learned);
    }
    break;
  case IMPLIES:
    // == 3-FINITE VALUE SET : Collect information ==
    if(n[1].getKind() == EQUAL &&
       n[1][0].getMetaKind() == metakind::VARIABLE &&
       defTrue.find(n) != defTrue.end()){
      Node eqTo = n[1][1];
      Node rewriteEqTo = Rewriter::rewrite(eqTo);
      if(rewriteEqTo.getKind() == CONST_RATIONAL){

        TNode var = n[1][0];
        if(d_miplibTrick.find(var)  == d_miplibTrick.end()){
          d_miplibTrick.insert(make_pair(var, set<Node>()));
        }
        d_miplibTrick[var].insert(n);
        Debug("arith::miplib") << "insert " << var  << " const " << n << endl;
      }
    }
    break;
  default: // Do nothing
    break;
  }
}

/**
 * Recognizes min/max ITEs of the form (ite (rel x y) x y) or
 * (ite (rel x y) y x).  The second form is normalized to the first by
 * swapping the branches and reversing the relation.
 *  - rel in {<, <=} on (x, y): n is min(x,y), learn n <= x and n <= y.
 *  - rel in {>, >=} on (x, y): n is max(x,y), learn n >= x and n >= y.
 * ITEs whose branches are not the condition's operands are ignored.
 */
void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){
  Assert(n.getKind() == kind::ITE);
  Assert(n[0].getKind() != EQUAL);
  Assert(isRelationOperator(n[0].getKind()));

  TNode c = n[0];
  Kind k = simplifiedKind(c);
  TNode t = n[1];
  TNode e = n[2];
  // NOTE(review): process() only calls this when n[0] is itself a relation,
  // so the NOT cases below look defensive; confirm against isRelationOperator.
  TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0];
  TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1];

  if((t == cright) && (e == cleft)){
    // (ite (rel x y) y x) is the same term as (ite (rev-rel x y) x y).
    TNode tmp = t;
    t = e;
    e = tmp;
    k = reverseRelationKind(k);
  }
  if(t == cleft && e == cright){
    // t == cleft && e == cright
    Assert( t == cleft );
    Assert( e == cright );
    switch(k){
    case LT:   // (ite (< x y) x y)
    case LEQ: { // (ite (<= x y) x y)
      Node nLeqX = NodeBuilder<2>(LEQ) << n << t;
      Node nLeqY = NodeBuilder<2>(LEQ) << n << e;
      Debug("arith::static") << n << "is a min =>"  << nLeqX << nLeqY << endl;
      learned << nLeqX << nLeqY;
      ++(d_statistics.d_iteMinMaxApplications);
      break;
    }
    case GT: // (ite (> x y) x y)
    case GEQ: { // (ite (>= x y) x y)
      Node nGeqX = NodeBuilder<2>(GEQ) << n << t;
      Node nGeqY = NodeBuilder<2>(GEQ) << n << e;
      Debug("arith::static") << n << "is a max =>"  << nGeqX << nGeqY << endl;
      learned << nGeqX << nGeqY;
      ++(d_statistics.d_iteMinMaxApplications);
      break;
    }
    default: Unreachable();
    }
  }
}

/**
 * For (ite c t e) with numeric constants t and e, learns
 * n >= min(t, e) and n <= max(t, e).
 */
void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
  Assert(n.getKind() == ITE);
  Assert(n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER );
  Assert(n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER );

  Rational t = coerceToRational(n[1]);
  Rational e = coerceToRational(n[2]);
  TNode min = (t <= e) ? n[1] : n[2];
  TNode max = (t >= e) ? n[1] : n[2];

  Node nGeqMin = NodeBuilder<2>(GEQ) << n << min;
  Node nLeqMax = NodeBuilder<2>(LEQ) << n << max;
  Debug("arith::static") << n << " iteConstant"  << nGeqMin << nLeqMax << endl;
  learned << nGeqMin << nLeqMax;
  ++(d_statistics.d_iteConstantApplications);
}


/**
 * Applies the finite-value-set trick to every variable collected in
 * d_miplibTrick.
 *
 * For a variable v with recorded implications (c_i => v = k_i), the
 * disjunction of the c_i is checked with PropositionalQuery::isTautology.
 * If the result is VALID, v must equal one of the k_i, so miplibTrick()
 * learns the corresponding facts and v's entry is dropped.  Entries whose
 * disjunction is not shown valid are kept until clear().
 */
void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
  // Snapshot the keys first because entries are erased inside the loop.
  // They are held as Node (not TNode) so each key stays alive even after
  // its map entry has been erased.
  vector<Node> keys;
  VarToNodeSetMap::iterator mipIter = d_miplibTrick.begin();
  VarToNodeSetMap::iterator endMipLibTrick = d_miplibTrick.end();
  for(; mipIter != endMipLibTrick; ++mipIter){
    keys.push_back(mipIter->first);
  }

  // == 3-FINITE VALUE SET ==
  vector<Node>::iterator keyIter = keys.begin();
  vector<Node>::iterator endKeys = keys.end();
  for(; keyIter != endKeys; ++keyIter){
    TNode var = *keyIter;
    const set<Node>& imps = d_miplibTrick[var];

    Assert(!imps.empty());
    vector<Node> conditions;
    set<Rational> values;
    set<Node>::const_iterator j=imps.begin(), impsEnd=imps.end();
    for(; j != impsEnd; ++j){
      TNode imp = *j;
      Assert(imp.getKind() == IMPLIES);
      Assert(imp[1].getKind() == EQUAL);

      Node eqTo = imp[1][1];
      Node rewriteEqTo = Rewriter::rewrite(eqTo);
      Assert(rewriteEqTo.getKind() == CONST_RATIONAL);

      conditions.push_back(imp[0]);
      values.insert(rewriteEqTo.getConst<Rational>());
    }

    Node possibleTaut = Node::null();
    if(conditions.size() == 1){
      possibleTaut = conditions.front();
    }else{
      NodeBuilder<> orBuilder(OR);
      orBuilder.append(conditions);
      possibleTaut = orBuilder;
    }


    Debug("arith::miplib") << "var: " << var << endl;
    Debug("arith::miplib") << "possibleTaut: " << possibleTaut << endl;

    Result isTaut = PropositionalQuery::isTautology(possibleTaut);
    if(isTaut == Result(Result::VALID)){
      miplibTrick(var, values, learned);
      // imps refers into this entry; it is not used after the erase.
      d_miplibTrick.erase(var);
    }
  }
}


/**
 * Given that var must take one of the values in `values` (non-empty, kept
 * sorted by std::set), learns:
 *  - var >= min(values) and var <= max(values);
 *  - for every pair of consecutive values prev < curr,
 *    (or (<= var prev) (>= var curr)), excluding the open interval
 *    between them.
 */
void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuilder<>& learned){

  Debug("arith::miplib") << var << " found a tautology!"<< endl;

  const Rational& min = *(values.begin());
  const Rational& max = *(values.rbegin());

  Debug("arith::miplib") << "min: " << min << endl;
  Debug("arith::miplib") << "max: " << max << endl;

  Assert(min <= max);
  ++(d_statistics.d_miplibtrickApplications);
  (d_statistics.d_avgNumMiplibtrickValues).addEntry(values.size());

  Node nGeqMin = NodeBuilder<2>(GEQ) << var << mkRationalNode(min);
  Node nLeqMax = NodeBuilder<2>(LEQ) << var << mkRationalNode(max);
  Debug("arith::miplib") << nGeqMin << nLeqMax << endl;
  learned << nGeqMin << nLeqMax;
  set<Rational>::iterator valuesIter = values.begin();
  set<Rational>::iterator valuesEnd = values.end();
  set<Rational>::iterator valuesPrev = valuesIter;
  ++valuesIter;
  for(; valuesIter != valuesEnd; valuesPrev = valuesIter, ++valuesIter){
    const Rational& prev = *valuesPrev;
    const Rational& curr = *valuesIter;
    Assert(prev < curr);

    //The interval (prev,curr) can be excluded:
    //(not (and (> var prev) (< var curr))
    //<=> (or (not (> var prev)) (not (< var curr)))
    //<=> (or (<= var prev) (>= var curr))
    Node leqPrev = NodeBuilder<2>(LEQ) << var << mkRationalNode(prev);
    Node geqCurr = NodeBuilder<2>(GEQ) << var << mkRationalNode(curr);
    Node excludedMiddle = NodeBuilder<2>(OR) << leqPrev << geqCurr;
    Debug("arith::miplib") << excludedMiddle << endl;
    learned << excludedMiddle;
  }
}