diff options
Diffstat (limited to 'src/theory/arith/arith_static_learner.cpp')
-rw-r--r-- | src/theory/arith/arith_static_learner.cpp | 84 |
1 files changed, 43 insertions, 41 deletions
diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp index a5d2b0a53..af2f0c9bc 100644 --- a/src/theory/arith/arith_static_learner.cpp +++ b/src/theory/arith/arith_static_learner.cpp @@ -35,10 +35,10 @@ namespace theory { namespace arith { -ArithStaticLearner::ArithStaticLearner(SubstitutionMap& pbSubstitutions) : - d_miplibTrick(), - d_miplibTrickKeys(), - d_pbSubstitutions(pbSubstitutions), +ArithStaticLearner::ArithStaticLearner(context::Context* userContext) : + d_miplibTrick(userContext), + d_minMap(userContext), + d_maxMap(userContext), d_statistics() {} @@ -108,11 +108,7 @@ void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){ } -void ArithStaticLearner::clear(){ - d_miplibTrick.clear(); - d_miplibTrickKeys.clear(); - // do not clear d_pbSubstitutions, as it is shared -} + void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){ @@ -140,11 +136,9 @@ void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet if(rewriteEqTo.getKind() == CONST_RATIONAL){ TNode var = n[1][0]; - if(d_miplibTrick.find(var) == d_miplibTrick.end()){ - d_miplibTrick.insert(make_pair(var, set<Node>())); - d_miplibTrickKeys.push_back(var); - } - d_miplibTrick[var].insert(n); + Node current = (d_miplibTrick.find(var) == d_miplibTrick.end()) ? + mkBoolNode(false) : d_miplibTrick[var]; + d_miplibTrick.insert(var, n.orNode(current)); Debug("arith::miplib") << "insert " << var << " const " << n << endl; } } @@ -249,9 +243,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ Debug("arith::static") << "iteConstant(" << n << ")" << endl; if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) { - DeltaRational min = std::min(d_minMap[n[1]], d_minMap[n[2]]); - NodeToMinMaxMap::iterator minFind = d_minMap.find(n); - if (minFind == d_minMap.end() || minFind->second < min) { + const DeltaRational& first = d_minMap[n[1]]; + const DeltaRational& second = d_minMap[n[2]]; + DeltaRational min = std::min(first, second); + CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n); + if (minFind == d_minMap.end() || (*minFind).second < min) { d_minMap[n] = min; Node nGeqMin; if (min.getInfinitesimalPart() == 0) { @@ -266,9 +262,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ } if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) { - DeltaRational max = std::max(d_maxMap[n[1]], d_maxMap[n[2]]); - NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n); - if (maxFind == d_maxMap.end() || maxFind->second > max) { + const DeltaRational& first = d_minMap[n[1]]; + const DeltaRational& second = d_minMap[n[2]]; + DeltaRational max = std::max(first, second); + CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n); + if (maxFind == d_maxMap.end() || (*maxFind).second > max) { d_maxMap[n] = max; Node nLeqMax; if (max.getInfinitesimalPart() == 0) { @@ -283,14 +281,29 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ } } +std::set<Node> listToSet(TNode l){ + std::set<Node> ret; + while(l.getKind() == OR){ + Assert(l.getNumChildren() == 2); + ret.insert(l[0]); + l = l[1]; + } + return ret; +} void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ // == 3-FINITE VALUE SET == - list<TNode>::iterator keyIter = d_miplibTrickKeys.begin(); - list<TNode>::iterator endKeys = d_miplibTrickKeys.end(); + CDNodeToNodeListMap::const_iterator keyIter = d_miplibTrick.begin(); + CDNodeToNodeListMap::const_iterator endKeys = d_miplibTrick.end(); while(keyIter != endKeys) { - TNode var = *keyIter; - const set<Node>& imps = d_miplibTrick[var]; + TNode var = (*keyIter).first; + Node list = (*keyIter).second; + const set<Node> imps = listToSet(list); + + if(imps.empty()){ + ++keyIter; + continue; + } Assert(!imps.empty()); vector<Node> conditions; @@ -325,20 +338,9 @@ void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ Result isTaut = PropositionalQuery::isTautology(possibleTaut); if(isTaut == Result(Result::VALID)){ miplibTrick(var, values, learned); - d_miplibTrick.erase(var); - // also have to erase from keys list - if(keyIter == endKeys) { - // last element is special: exit loop - d_miplibTrickKeys.erase(keyIter); - break; - } else { - // non-last element: make sure iterator is incremented before erase - list<TNode>::iterator eraseIter = keyIter++; - d_miplibTrickKeys.erase(eraseIter); - } - } else { - ++keyIter; + d_miplibTrick.insert(var, mkBoolNode(false)); } + ++keyIter; } } @@ -384,8 +386,8 @@ void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuild void ArithStaticLearner::addBound(TNode n) { - NodeToMinMaxMap::iterator minFind = d_minMap.find(n[0]); - NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n[0]); + CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]); + CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]); Rational constant = n[1].getConst<Rational>(); DeltaRational bound = constant; @@ -395,7 +397,7 @@ void ArithStaticLearner::addBound(TNode n) { bound = DeltaRational(constant, -1); /* fall through */ case kind::LEQ: - if (maxFind == d_maxMap.end() || maxFind->second > bound) { + if (maxFind == d_maxMap.end() || (*maxFind).second > bound) { d_maxMap[n[0]] = bound; Debug("arith::static") << "adding bound " << n << endl; } @@ -404,7 +406,7 @@ void ArithStaticLearner::addBound(TNode n) { bound = DeltaRational(constant, 1); /* fall through */ case kind::GEQ: - if (minFind == d_minMap.end() || minFind->second < bound) { + if (minFind == d_minMap.end() || (*minFind).second < bound) { d_minMap[n[0]] = bound; Debug("arith::static") << "adding bound " << n << endl; } |