summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/theory/arith/nonlinear_extension.cpp79
1 files changed, 64 insertions, 15 deletions
diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp
index c9a4c5075..be8f22222 100644
--- a/src/theory/arith/nonlinear_extension.cpp
+++ b/src/theory/arith/nonlinear_extension.cpp
@@ -1056,6 +1056,9 @@ void NonlinearExtension::addCheckModelSubstitution(TNode v, TNode s)
void NonlinearExtension::addCheckModelBound(TNode v, TNode l, TNode u)
{
Assert(!hasCheckModelAssignment(v));
+ Assert(l.isConst());
+ Assert(u.isConst());
+ Assert(l.getConst<Rational>() <= u.getConst<Rational>());
d_check_model_bounds[v] = std::pair<Node, Node>(l, u);
}
@@ -1294,6 +1297,14 @@ bool NonlinearExtension::solveEqualitySimple(Node eq)
MULT, coeffa, nm->mkNode(r == 0 ? MINUS : PLUS, negb, val));
approx = Rewriter::rewrite(approx);
bounds[r][b] = approx;
+ Assert(approx.isConst());
+ }
+ if (bounds[r][0].getConst<Rational>() > bounds[r][1].getConst<Rational>())
+ {
+ // ensure bound is (lower, upper)
+ Node tmp = bounds[r][0];
+ bounds[r][0] = bounds[r][1];
+ bounds[r][1] = tmp;
}
Node diff =
nm->mkNode(MINUS,
@@ -1448,26 +1459,31 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit)
t = Rewriter::rewrite(t);
Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
<< t << "..." << std::endl;
+ Trace("nl-ext-cms-debug") << " a = " << a << std::endl;
+ Trace("nl-ext-cms-debug") << " b = " << b << std::endl;
// find maximal/minimal value on the interval
Node apex = nm->mkNode(
DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
apex = Rewriter::rewrite(apex);
Assert(apex.isConst());
+ // for lower, upper, whether we are greater than the apex
bool cmp[2];
Node boundn[2];
for (unsigned r = 0; r < 2; r++)
{
boundn[r] = r == 0 ? bit->second.first : bit->second.second;
- Node cmpn = nm->mkNode(LT, boundn[r], apex);
+ Node cmpn = nm->mkNode(GT, boundn[r], apex);
cmpn = Rewriter::rewrite(cmpn);
Assert(cmpn.isConst());
cmp[r] = cmpn.getConst<bool>();
}
Trace("nl-ext-cms-debug") << " apex " << apex << std::endl;
Trace("nl-ext-cms-debug")
- << " min " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
+ << " lower " << boundn[0] << ", cmp: " << cmp[0] << std::endl;
Trace("nl-ext-cms-debug")
- << " max " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
+ << " upper " << boundn[1] << ", cmp: " << cmp[1] << std::endl;
+ Assert(boundn[0].getConst<Rational>()
+ <= boundn[1].getConst<Rational>());
Node s;
qvars.push_back(v);
if (cmp[0] != cmp[1])
@@ -1497,19 +1513,25 @@ bool NonlinearExtension::simpleCheckModelLit(Node lit)
<< " ...both sides of apex, compare " << tcmp << std::endl;
tcmp = Rewriter::rewrite(tcmp);
Assert(tcmp.isConst());
- unsigned bindex_use = tcmp.getConst<bool>() == pol ? 1 : 0;
+ unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
Trace("nl-ext-cms-debug")
- << " ...set to " << (bindex_use == 1 ? "max" : "min")
+ << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
<< std::endl;
s = boundn[bindex_use];
}
}
else
{
- // both to one side
- unsigned bindex_use = ((asgn == 1) == cmp[0]) == pol ? 0 : 1;
+ // both to one side of the apex
+ // we figure out which bound to use (lower or upper) based on
+ // three factors:
+ // (1) whether a's sign is positive,
+ // (2) whether we are greater than the apex of the parabola,
+ // (3) the polarity of the constraint, i.e. >= or <=.
+ // there are 8 cases of these factors, which we test here.
+ unsigned bindex_use = (((asgn == 1) == cmp[0]) == pol) ? 0 : 1;
Trace("nl-ext-cms-debug")
- << " ...set to " << (bindex_use == 1 ? "max" : "min")
+ << " ...set to " << (bindex_use == 1 ? "upper" : "lower")
<< std::endl;
s = boundn[bindex_use];
}
@@ -1589,6 +1611,9 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
int choose_index = -1;
std::vector<Node> ls;
std::vector<Node> us;
+ // the relevant sign information for variables with odd exponents:
+ // 1: both signs of the interval of this variable are positive,
+ // -1: both signs of the interval of this variable are negative.
std::vector<int> signs;
Trace("nl-ext-cms-debug") << "get sign information..." << std::endl;
for (unsigned i = 0, size = vars.size(); i < size; i++)
@@ -1613,9 +1638,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
Node u = bit->second.second;
ls.push_back(l);
us.push_back(u);
- int vsign = 1;
+ int vsign = 0;
if (vcfact % 2 == 1)
{
+ vsign = 1;
int lsgn = l.getConst<Rational>().sgn();
int usgn = u.getConst<Rational>().sgn();
Trace("nl-ext-cms-debug")
@@ -1658,7 +1684,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
}
}
// whether we will try to minimize/maximize (-1/1) the absolute value
- int minimizeAbs = set_lower == has_neg_factor ? -1 : 1;
+ int setAbs = (set_lower == has_neg_factor) ? 1 : -1;
+ Trace("nl-ext-cms-debug")
+ << "set absolute value to " << (setAbs == 1 ? "maximal" : "minimal")
+ << std::endl;
std::vector<Node> vbs;
Trace("nl-ext-cms-debug") << "set bounds..." << std::endl;
@@ -1669,6 +1698,10 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
Node l = ls[i];
Node u = us[i];
bool vc_set_lower;
+ int vcsign = signs[i];
+ Trace("nl-ext-cms-debug")
+ << "Bounds for " << vc << " : " << l << ", " << u
+ << ", sign : " << vcsign << ", factor : " << vcfact << std::endl;
if (l == u)
{
// by convention, always say it is lower if they are the same
@@ -1678,15 +1711,31 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
}
else
{
- if (signs[i] == 0)
+ if (vcfact % 2 == 0)
+ {
+ // minimize or maximize its absolute value
+ Rational la = l.getConst<Rational>().abs();
+ Rational ua = u.getConst<Rational>().abs();
+ if (la == ua)
+ {
+ // by convention, always say it is lower if abs are the same
+ vc_set_lower = true;
+ Trace("nl-ext-cms-debug")
+ << "..." << vc << " equal abs, set to lower" << std::endl;
+ }
+ else
+ {
+ vc_set_lower = (la > ua) == (setAbs == 1);
+ }
+ }
+ else if (signs[i] == 0)
{
// we choose this index to match the overall set_lower
vc_set_lower = set_lower;
}
else
{
- // minimize or maximize its absolute value
- vc_set_lower = (signs[i] == minimizeAbs);
+ vc_set_lower = (signs[i] == setAbs);
}
Trace("nl-ext-cms-debug")
<< "..." << vc << " set to " << (vc_set_lower ? "lower" : "upper")
@@ -1704,8 +1753,8 @@ bool NonlinearExtension::simpleCheckModelMsum(const std::map<Node, Node>& msum,
<< " failed due to conflicting bound for " << vc << std::endl;
return false;
}
- // must over/under approximate
- Node vb = set_lower ? l : u;
+ // must over/under approximate based on vc_set_lower, computed above
+ Node vb = vc_set_lower ? l : u;
for (unsigned i = 0; i < vcfact; i++)
{
vbs.push_back(vb);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback