summaryrefslogtreecommitdiff
path: root/src/theory/arith/arith_static_learner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r--src/theory/arith/arith_static_learner.cpp298
1 files changed, 298 insertions, 0 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp
new file mode 100644
index 000000000..6fb538fac
--- /dev/null
+++ b/src/theory/arith/arith_static_learner.cpp
@@ -0,0 +1,298 @@
+/********************* */
+/*! \file arith_rewriter.cpp
+ ** \verbatim
+ ** Original author: taking
+ ** Major contributors: dejan
+ ** Minor contributors (to current version): mdeters
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#include "theory/rewriter.h"
+
+#include "theory/arith/arith_utilities.h"
+#include "theory/arith/arith_static_learner.h"
+
+#include "util/propositional_query.h"
+
+#include <vector>
+
+using namespace std;
+
+using namespace CVC4;
+using namespace CVC4::kind;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+
+
+ArithStaticLearner::ArithStaticLearner():
+ d_miplibTrick(),
+ d_statistics()
+{}
+
+ArithStaticLearner::Statistics::Statistics():
+ d_iteMinMaxApplications("theory::arith::iteMinMaxApplications", 0),
+ d_iteConstantApplications("theory::arith::iteConstantApplications", 0),
+ d_miplibtrickApplications("theory::arith::miplibtrickApplications", 0),
+ d_avgNumMiplibtrickValues("theory::arith::avgNumMiplibtrickValues")
+{
+ StatisticsRegistry::registerStat(&d_iteMinMaxApplications);
+ StatisticsRegistry::registerStat(&d_iteConstantApplications);
+ StatisticsRegistry::registerStat(&d_miplibtrickApplications);
+ StatisticsRegistry::registerStat(&d_avgNumMiplibtrickValues);
+}
+
+ArithStaticLearner::Statistics::~Statistics(){
+ StatisticsRegistry::unregisterStat(&d_iteMinMaxApplications);
+ StatisticsRegistry::unregisterStat(&d_iteConstantApplications);
+ StatisticsRegistry::unregisterStat(&d_miplibtrickApplications);
+ StatisticsRegistry::unregisterStat(&d_avgNumMiplibtrickValues);
+}
+
+void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
+
+ vector<TNode> workList;
+ workList.push_back(n);
+ TNodeSet processed;
+
+ //Contains an underapproximation of nodes that must hold.
+ TNodeSet defTrue;
+
+ defTrue.insert(n);
+
+ while(!workList.empty()) {
+ n = workList.back();
+
+ bool unprocessedChildren = false;
+ for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
+ if(processed.find(*i) == processed.end()) {
+ // unprocessed child
+ workList.push_back(*i);
+ unprocessedChildren = true;
+ }
+ }
+ if(n.getKind() == AND && defTrue.find(n) != defTrue.end() ){
+ for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) {
+ defTrue.insert(*i);
+ }
+ }
+
+ if(unprocessedChildren) {
+ continue;
+ }
+
+ workList.pop_back();
+ // has node n been processed in the meantime ?
+ if(processed.find(n) != processed.end()) {
+ continue;
+ }
+ processed.insert(n);
+
+ process(n,learned, defTrue);
+
+ }
+
+ postProcess(learned);
+}
+
+void ArithStaticLearner::clear(){
+ d_miplibTrick.clear();
+}
+
+
+void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
+ Debug("arith::static") << "===================== looking at" << n << endl;
+
+ switch(n.getKind()){
+ case ITE:
+ if(n[0].getKind() != EQUAL &&
+ isRelationOperator(n[0].getKind()) ){
+ iteMinMax(n, learned);
+ }
+
+ if((n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) &&
+ (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) {
+ iteConstant(n, learned);
+ }
+ break;
+ case IMPLIES:
+ // == 3-FINITE VALUE SET : Collect information ==
+ if(n[1].getKind() == EQUAL &&
+ n[1][0].getMetaKind() == metakind::VARIABLE &&
+ defTrue.find(n) != defTrue.end()){
+ Node eqTo = n[1][1];
+ Node rewriteEqTo = Rewriter::rewrite(eqTo);
+ if(rewriteEqTo.getKind() == CONST_RATIONAL){
+
+ TNode var = n[1][0];
+ if(d_miplibTrick.find(var) == d_miplibTrick.end()){
+ d_miplibTrick.insert(make_pair(var, set<Node>()));
+ }
+ d_miplibTrick[var].insert(n);
+ Debug("arith::miplib") << "insert " << var << " const " << n << endl;
+ }
+ }
+ break;
+ default: // Do nothing
+ break;
+ }
+}
+
+void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){
+ Assert(n.getKind() == kind::ITE);
+ Assert(n[0].getKind() != EQUAL);
+ Assert(isRelationOperator(n[0].getKind()));
+
+ TNode c = n[0];
+ Kind k = simplifiedKind(c);
+ TNode t = n[1];
+ TNode e = n[2];
+ TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0];
+ TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1];
+
+ if((t == cright) && (e == cleft)){
+ TNode tmp = t;
+ t = e;
+ e = tmp;
+ k = reverseRelationKind(k);
+ }
+ if(t == cleft && e == cright){
+ // t == cleft && e == cright
+ Assert( t == cleft );
+ Assert( e == cright );
+ switch(k){
+ case LT: // (ite (< x y) x y)
+ case LEQ: { // (ite (<= x y) x y)
+ Node nLeqX = NodeBuilder<2>(LEQ) << n << t;
+ Node nLeqY = NodeBuilder<2>(LEQ) << n << e;
+ Debug("arith::static") << n << "is a min =>" << nLeqX << nLeqY << endl;
+ learned << nLeqX << nLeqY;
+ ++(d_statistics.d_iteMinMaxApplications);
+ break;
+ }
+ case GT: // (ite (> x y) x y)
+ case GEQ: { // (ite (>= x y) x y)
+ Node nGeqX = NodeBuilder<2>(GEQ) << n << t;
+ Node nGeqY = NodeBuilder<2>(GEQ) << n << e;
+ Debug("arith::static") << n << "is a max =>" << nGeqX << nGeqY << endl;
+ learned << nGeqX << nGeqY;
+ ++(d_statistics.d_iteMinMaxApplications);
+ break;
+ }
+ default: Unreachable();
+ }
+ }
+}
+
+void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
+ Assert(n.getKind() == ITE);
+ Assert(n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER );
+ Assert(n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER );
+
+ Rational t = coerceToRational(n[1]);
+ Rational e = coerceToRational(n[2]);
+ TNode min = (t <= e) ? n[1] : n[2];
+ TNode max = (t >= e) ? n[1] : n[2];
+
+ Node nGeqMin = NodeBuilder<2>(GEQ) << n << min;
+ Node nLeqMax = NodeBuilder<2>(LEQ) << n << max;
+ Debug("arith::static") << n << " iteConstant" << nGeqMin << nLeqMax << endl;
+ learned << nGeqMin << nLeqMax;
+ ++(d_statistics.d_iteConstantApplications);
+}
+
+
+void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
+
+ // == 3-FINITE VALUE SET ==
+ VarToNodeSetMap::iterator i = d_miplibTrick.begin();
+ VarToNodeSetMap::iterator endMipLibTrick = d_miplibTrick.end();
+ for(; i != endMipLibTrick; ++i){
+ TNode var = i->first;
+ const set<Node>& imps = i->second;
+
+ Assert(!imps.empty());
+ vector<Node> conditions;
+ set<Rational> values;
+ set<Node>::const_iterator j=imps.begin(), impsEnd=imps.end();
+ for(; j != impsEnd; ++j){
+ TNode imp = *j;
+ Assert(imp.getKind() == IMPLIES);
+ Assert(imp[1].getKind() == EQUAL);
+
+ Node eqTo = imp[1][1];
+ Node rewriteEqTo = Rewriter::rewrite(eqTo);
+ Assert(rewriteEqTo.getKind() == CONST_RATIONAL);
+
+ conditions.push_back(imp[0]);
+ values.insert(rewriteEqTo.getConst<Rational>());
+ }
+
+ Node possibleTaut = Node::null();
+ if(conditions.size() == 1){
+ possibleTaut = conditions.front();
+ }else{
+ NodeBuilder<> orBuilder(OR);
+ orBuilder.append(conditions);
+ possibleTaut = orBuilder;
+ }
+
+
+ Debug("arith::miplib") << "var: " << var << endl;
+ Debug("arith::miplib") << "possibleTaut: " << possibleTaut << endl;
+
+ Result isTaut = PropositionalQuery::isTautology(possibleTaut);
+ if(isTaut == Result(Result::VALID)){
+ miplibTrick(var, values, learned);
+ }
+ }
+}
+
+
+void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuilder<>& learned){
+
+ Debug("arith::miplib") << var << " found a tautology!"<< endl;
+
+ const Rational& min = *(values.begin());
+ const Rational& max = *(values.rbegin());
+
+ Debug("arith::miplib") << "min: " << min << endl;
+ Debug("arith::miplib") << "max: " << max << endl;
+
+ Assert(min <= max);
+ ++(d_statistics.d_miplibtrickApplications);
+ (d_statistics.d_avgNumMiplibtrickValues).addEntry(values.size());
+
+ Node nGeqMin = NodeBuilder<2>(GEQ) << var << mkRationalNode(min);
+ Node nLeqMax = NodeBuilder<2>(LEQ) << var << mkRationalNode(max);
+ Debug("arith::miplib") << nGeqMin << nLeqMax << endl;
+ learned << nGeqMin << nLeqMax;
+ set<Rational>::iterator valuesIter = values.begin();
+ set<Rational>::iterator valuesEnd = values.end();
+ set<Rational>::iterator valuesPrev = valuesIter;
+ ++valuesIter;
+ for(; valuesIter != valuesEnd; valuesPrev = valuesIter, ++valuesIter){
+ const Rational& prev = *valuesPrev;
+ const Rational& curr = *valuesIter;
+ Assert(prev < curr);
+
+ //The interval (last,curr) can be excluded:
+ //(not (and (> var prev) (< var curr))
+ //<=> (or (not (> var prev)) (not (< var curr)))
+ //<=> (or (<= var prev) (>= var curr))
+ Node leqPrev = NodeBuilder<2>(LEQ) << var << mkRationalNode(prev);
+ Node geqCurr = NodeBuilder<2>(GEQ) << var << mkRationalNode(curr);
+ Node excludedMiddle = NodeBuilder<2>(OR) << leqPrev << geqCurr;
+ Debug("arith::miplib") << excludedMiddle << endl;
+ learned << excludedMiddle;
+ }
+}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback