summaryrefslogtreecommitdiff
path: root/src/theory/arith/arith_static_learner.cpp
diff options
context:
space:
mode:
authorTim King <taking@cs.nyu.edu>2012-10-24 21:46:34 +0000
committerTim King <taking@cs.nyu.edu>2012-10-24 21:46:34 +0000
commita6ac7fefed613c4d83e577361f98c28a8e18f3a9 (patch)
treecd9cf557f8bac184c9ffc5e85280f073eafb1b53 /src/theory/arith/arith_static_learner.cpp
parent203435906c670095b3b753077f09ad334f278bf7 (diff)
Updated the ArithStaticLearner to be user context dependent.
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r--src/theory/arith/arith_static_learner.cpp84
1 files changed, 43 insertions, 41 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp
index a5d2b0a53..af2f0c9bc 100644
--- a/src/theory/arith/arith_static_learner.cpp
+++ b/src/theory/arith/arith_static_learner.cpp
@@ -35,10 +35,10 @@ namespace theory {
namespace arith {
-ArithStaticLearner::ArithStaticLearner(SubstitutionMap& pbSubstitutions) :
- d_miplibTrick(),
- d_miplibTrickKeys(),
- d_pbSubstitutions(pbSubstitutions),
+ArithStaticLearner::ArithStaticLearner(context::Context* userContext) :
+ d_miplibTrick(userContext),
+ d_minMap(userContext),
+ d_maxMap(userContext),
d_statistics()
{}
@@ -108,11 +108,7 @@ void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
}
-void ArithStaticLearner::clear(){
- d_miplibTrick.clear();
- d_miplibTrickKeys.clear();
- // do not clear d_pbSubstitutions, as it is shared
-}
+
void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
@@ -140,11 +136,9 @@ void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet
if(rewriteEqTo.getKind() == CONST_RATIONAL){
TNode var = n[1][0];
- if(d_miplibTrick.find(var) == d_miplibTrick.end()){
- d_miplibTrick.insert(make_pair(var, set<Node>()));
- d_miplibTrickKeys.push_back(var);
- }
- d_miplibTrick[var].insert(n);
+ Node current = (d_miplibTrick.find(var) == d_miplibTrick.end()) ?
+ mkBoolNode(false) : d_miplibTrick[var];
+ d_miplibTrick.insert(var, n.orNode(current));
Debug("arith::miplib") << "insert " << var << " const " << n << endl;
}
}
@@ -249,9 +243,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
Debug("arith::static") << "iteConstant(" << n << ")" << endl;
if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) {
- DeltaRational min = std::min(d_minMap[n[1]], d_minMap[n[2]]);
- NodeToMinMaxMap::iterator minFind = d_minMap.find(n);
- if (minFind == d_minMap.end() || minFind->second < min) {
+ const DeltaRational& first = d_minMap[n[1]];
+ const DeltaRational& second = d_minMap[n[2]];
+ DeltaRational min = std::min(first, second);
+ CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n);
+ if (minFind == d_minMap.end() || (*minFind).second < min) {
d_minMap[n] = min;
Node nGeqMin;
if (min.getInfinitesimalPart() == 0) {
@@ -266,9 +262,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
}
if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) {
- DeltaRational max = std::max(d_maxMap[n[1]], d_maxMap[n[2]]);
- NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n);
- if (maxFind == d_maxMap.end() || maxFind->second > max) {
+ const DeltaRational& first = d_minMap[n[1]];
+ const DeltaRational& second = d_minMap[n[2]];
+ DeltaRational max = std::max(first, second);
+ CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n);
+ if (maxFind == d_maxMap.end() || (*maxFind).second > max) {
d_maxMap[n] = max;
Node nLeqMax;
if (max.getInfinitesimalPart() == 0) {
@@ -283,14 +281,29 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
}
}
+std::set<Node> listToSet(TNode l){
+ std::set<Node> ret;
+ while(l.getKind() == OR){
+ Assert(l.getNumChildren() == 2);
+ ret.insert(l[0]);
+ l = l[1];
+ }
+ return ret;
+}
void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
// == 3-FINITE VALUE SET ==
- list<TNode>::iterator keyIter = d_miplibTrickKeys.begin();
- list<TNode>::iterator endKeys = d_miplibTrickKeys.end();
+ CDNodeToNodeListMap::const_iterator keyIter = d_miplibTrick.begin();
+ CDNodeToNodeListMap::const_iterator endKeys = d_miplibTrick.end();
while(keyIter != endKeys) {
- TNode var = *keyIter;
- const set<Node>& imps = d_miplibTrick[var];
+ TNode var = (*keyIter).first;
+ Node list = (*keyIter).second;
+ const set<Node> imps = listToSet(list);
+
+ if(imps.empty()){
+ ++keyIter;
+ continue;
+ }
Assert(!imps.empty());
vector<Node> conditions;
@@ -325,20 +338,9 @@ void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
Result isTaut = PropositionalQuery::isTautology(possibleTaut);
if(isTaut == Result(Result::VALID)){
miplibTrick(var, values, learned);
- d_miplibTrick.erase(var);
- // also have to erase from keys list
- if(keyIter == endKeys) {
- // last element is special: exit loop
- d_miplibTrickKeys.erase(keyIter);
- break;
- } else {
- // non-last element: make sure iterator is incremented before erase
- list<TNode>::iterator eraseIter = keyIter++;
- d_miplibTrickKeys.erase(eraseIter);
- }
- } else {
- ++keyIter;
+ d_miplibTrick.insert(var, mkBoolNode(false));
}
+ ++keyIter;
}
}
@@ -384,8 +386,8 @@ void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuild
void ArithStaticLearner::addBound(TNode n) {
- NodeToMinMaxMap::iterator minFind = d_minMap.find(n[0]);
- NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n[0]);
+ CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]);
+ CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]);
Rational constant = n[1].getConst<Rational>();
DeltaRational bound = constant;
@@ -395,7 +397,7 @@ void ArithStaticLearner::addBound(TNode n) {
bound = DeltaRational(constant, -1);
/* fall through */
case kind::LEQ:
- if (maxFind == d_maxMap.end() || maxFind->second > bound) {
+ if (maxFind == d_maxMap.end() || (*maxFind).second > bound) {
d_maxMap[n[0]] = bound;
Debug("arith::static") << "adding bound " << n << endl;
}
@@ -404,7 +406,7 @@ void ArithStaticLearner::addBound(TNode n) {
bound = DeltaRational(constant, 1);
/* fall through */
case kind::GEQ:
- if (minFind == d_minMap.end() || minFind->second < bound) {
+ if (minFind == d_minMap.end() || (*minFind).second < bound) {
d_minMap[n[0]] = bound;
Debug("arith::static") << "adding bound " << n << endl;
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback