summaryrefslogtreecommitdiff
path: root/src/theory/arith/arith_static_learner.cpp
diff options
context:
space:
mode:
authorMorgan Deters <mdeters@gmail.com>2011-09-02 17:56:43 +0000
committerMorgan Deters <mdeters@gmail.com>2011-09-02 17:56:43 +0000
commit487e610b88f2a634e3285886ff96717c103338de (patch)
tree7f034b5c9f537195df72ac9ecd7666226dc2ed9f /src/theory/arith/arith_static_learner.cpp
parent90267f8729799f44c6fb33ace11b971a16e78dff (diff)
Partial merge of integers work; this is simple B&B and some pseudoboolean
infrastructure, and takes care not to affect CVC4's performance on LRA benchmarks.
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r--src/theory/arith/arith_static_learner.cpp127
1 files changed, 124 insertions, 3 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp
index 77a89b54b..a5fa428c6 100644
--- a/src/theory/arith/arith_static_learner.cpp
+++ b/src/theory/arith/arith_static_learner.cpp
@@ -24,6 +24,9 @@
#include "util/propositional_query.h"
+#include "expr/expr.h"
+#include "expr/convenience_node_builders.h"
+
#include <vector>
using namespace std;
@@ -34,9 +37,10 @@ using namespace CVC4::theory;
using namespace CVC4::theory::arith;
-ArithStaticLearner::ArithStaticLearner():
+ArithStaticLearner::ArithStaticLearner(SubstitutionMap& pbSubstitutions) :
d_miplibTrick(),
d_miplibTrickKeys(),
+ d_pbSubstitutions(pbSubstitutions),
d_statistics()
{}
@@ -105,9 +109,11 @@ void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
postProcess(learned);
}
+
void ArithStaticLearner::clear(){
d_miplibTrick.clear();
d_miplibTrickKeys.clear();
+ // do not clear d_pbSubstitutions, as it is shared
}
@@ -151,11 +157,101 @@ void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet
d_minMap[n] = coerceToRational(n);
d_maxMap[n] = coerceToRational(n);
break;
+ case OR: {
+ // Look for things like "x = 0 OR x = 1" (that are defTrue) and
+ // turn them into a pseudoboolean. We catch "x >= 0
+ if(defTrue.find(n) == defTrue.end() ||
+ n.getNumChildren() != 2 ||
+ n[0].getKind() != EQUAL ||
+ n[1].getKind() != EQUAL) {
+ break;
+ }
+ Node var, c1, c2;
+ if(n[0][0].getMetaKind() == metakind::VARIABLE &&
+ n[0][1].getMetaKind() == metakind::CONSTANT) {
+ var = n[0][0];
+ c1 = n[0][1];
+ } else if(n[0][1].getMetaKind() == metakind::VARIABLE &&
+ n[0][0].getMetaKind() == metakind::CONSTANT) {
+ var = n[0][1];
+ c1 = n[0][0];
+ } else {
+ break;
+ }
+ if(!var.getType().isInteger() ||
+ !c1.getType().isReal()) {
+ break;
+ }
+ if(var == n[1][0]) {
+ c2 = n[1][1];
+ } else if(var == n[1][1]) {
+ c2 = n[1][0];
+ } else {
+ break;
+ }
+ if(!c2.getType().isReal()) {
+ break;
+ }
+
+ Integer k1, k2;
+ if(c1.getType().getConst<TypeConstant>() == INTEGER_TYPE) {
+ k1 = c1.getConst<Integer>();
+ } else {
+ Rational r = c1.getConst<Rational>();
+ if(r.getDenominator() == 1) {
+ k1 = r.getNumerator();
+ } else {
+ break;
+ }
+ }
+ if(c2.getType().getConst<TypeConstant>() == INTEGER_TYPE) {
+ k2 = c2.getConst<Integer>();
+ } else {
+ Rational r = c2.getConst<Rational>();
+ if(r.getDenominator() == 1) {
+ k2 = r.getNumerator();
+ } else {
+ break;
+ }
+ }
+ if(k1 > k2) {
+ swap(k1, k2);
+ }
+ if(k1 + 1 == k2) {
+ Debug("arith::static") << "==> found " << n << endl
+ << " which indicates " << var << " \\in { "
+ << k1 << " , " << k2 << " }" << endl;
+ c1 = NodeManager::currentNM()->mkConst(k1);
+ c2 = NodeManager::currentNM()->mkConst(k2);
+ Node lhs = NodeBuilder<2>(kind::GEQ) << var << c1;
+ Node rhs = NodeBuilder<2>(kind::LEQ) << var << c2;
+ Node l = lhs && rhs;
+ Debug("arith::static") << " learned: " << l << endl;
+ learned << l;
+ if(k1 == 0) {
+ Assert(k2 == 1);
+ replaceWithPseudoboolean(var);
+ }
+ }
+ break;
+ }
default: // Do nothing
break;
}
}
+void ArithStaticLearner::replaceWithPseudoboolean(TNode var) {
+ AssertArgument(var.getMetaKind() == kind::metakind::VARIABLE, var);
+ TypeNode pbType = NodeManager::currentNM()->pseudobooleanType();
+ Node pbVar = NodeManager::currentNM()->mkVar(string("PB[") + var.toString() + ']', pbType);
+ d_pbSubstitutions.addSubstitution(var, pbVar);
+
+ if(Debug.isOn("pb")) {
+ Expr::printtypes::Scope pts(Debug("pb"), true);
+ Debug("pb") << "will replace " << var << " with " << pbVar << endl;
+ }
+}
+
void ArithStaticLearner::iteMinMax(TNode n, NodeBuilder<>& learned){
Assert(n.getKind() == kind::ITE);
Assert(n[0].getKind() != EQUAL);
@@ -341,6 +437,27 @@ void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuild
}
}
+void ArithStaticLearner::checkBoundsForPseudobooleanReplacement(TNode n) {
+ NodeToMinMaxMap::iterator minFind = d_minMap.find(n);
+ NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n);
+
+ if( n.getType().isInteger() &&
+ minFind != d_minMap.end() &&
+ maxFind != d_maxMap.end() &&
+ ( ( (*minFind).second.getNoninfinitesimalPart() == 1 &&
+ (*minFind).second.getInfinitesimalPart() == 0 ) ||
+ ( (*minFind).second.getNoninfinitesimalPart() == 0 &&
+ (*minFind).second.getInfinitesimalPart() > 0 ) ) &&
+ ( ( (*maxFind).second.getNoninfinitesimalPart() == 1 &&
+ (*maxFind).second.getInfinitesimalPart() == 0 ) ||
+ ( (*maxFind).second.getNoninfinitesimalPart() == 2 &&
+ (*maxFind).second.getInfinitesimalPart() < 0 ) ) ) {
+ // eligible for pseudoboolean replacement
+ Debug("pb") << "eligible for pseudoboolean replacement: " << n << endl;
+ replaceWithPseudoboolean(n);
+ }
+}
+
void ArithStaticLearner::addBound(TNode n) {
NodeToMinMaxMap::iterator minFind = d_minMap.find(n[0]);
@@ -349,25 +466,29 @@ void ArithStaticLearner::addBound(TNode n) {
Rational constant = coerceToRational(n[1]);
DeltaRational bound = constant;
- switch(n.getKind()) {
+ switch(Kind k = n.getKind()) {
case kind::LT:
bound = DeltaRational(constant, -1);
+ /* fall through */
case kind::LEQ:
if (maxFind == d_maxMap.end() || maxFind->second > bound) {
d_maxMap[n[0]] = bound;
Debug("arith::static") << "adding bound " << n << endl;
+ checkBoundsForPseudobooleanReplacement(n[0]);
}
break;
case kind::GT:
bound = DeltaRational(constant, 1);
+ /* fall through */
case kind::GEQ:
if (minFind == d_minMap.end() || minFind->second < bound) {
d_minMap[n[0]] = bound;
Debug("arith::static") << "adding bound " << n << endl;
+ checkBoundsForPseudobooleanReplacement(n[0]);
}
break;
default:
- // nothing else
+ Unhandled(k);
break;
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback