summaryrefslogtreecommitdiff
path: root/src/theory/arith/nl/cad/cdcac.cpp
diff options
context:
space:
mode:
authorGereon Kremer <nafur42@gmail.com>2021-10-22 14:49:50 -0700
committerGitHub <noreply@github.com>2021-10-22 21:49:50 +0000
commit0dfbf4b80f25bc9edd1c843ba9a9bb37bace79a9 (patch)
tree432a07f4119f031f14c0247982f86f6d76ca9ab0 /src/theory/arith/nl/cad/cdcac.cpp
parentf1db161860d0283cb5537ad8847e0b52d1485e28 (diff)
Fix out-of-sync pruning in CDCAC proofs (#7470)
This PR resolves a subtle issue with CDCAC proofs. The CDCAC proof is maintained as a tree where (mostly) every node corresponds to an (infeasible) interval generated within the CDCAC method. We prune these intervals regularly to get rid of redundant intervals, which also sorts intervals. The pruning however relied on a stable ordering of both intervals and child nodes within the proof tree, as there was no easy way to map nodes back to intervals. This PR adds an objectId field to the proof tree nodes and assigns ids to the CDCAC intervals. This allows for a robust mapping between the two, even if the interval list is reordered. Fixes cvc5/cvc5-projects#313.
Diffstat (limited to 'src/theory/arith/nl/cad/cdcac.cpp')
-rw-r--r--src/theory/arith/nl/cad/cdcac.cpp104
1 files changed, 65 insertions, 39 deletions
diff --git a/src/theory/arith/nl/cad/cdcac.cpp b/src/theory/arith/nl/cad/cdcac.cpp
index d259bc096..4a2709cf8 100644
--- a/src/theory/arith/nl/cad/cdcac.cpp
+++ b/src/theory/arith/nl/cad/cdcac.cpp
@@ -55,6 +55,7 @@ void CDCAC::reset()
{
d_constraints.reset();
d_assignment.clear();
+ d_nextIntervalId = 1;
}
void CDCAC::computeVariableOrdering()
@@ -150,7 +151,7 @@ std::vector<CACInterval> CDCAC::getUnsatIntervals(std::size_t cur_variable)
m.pushDownPolys(d, d_variableOrdering[cur_variable]);
if (!is_minus_infinity(get_lower(i))) l = m;
if (!is_plus_infinity(get_upper(i))) u = m;
- res.emplace_back(CACInterval{i, l, u, m, d, {n}});
+ res.emplace_back(CACInterval{d_nextIntervalId++, i, l, u, m, d, {n}});
if (isProofEnabled())
{
d_proof->addDirect(
@@ -160,7 +161,8 @@ std::vector<CACInterval> CDCAC::getUnsatIntervals(std::size_t cur_variable)
d_assignment,
sc,
i,
- n);
+ n,
+ res.back().d_id);
}
}
}
@@ -293,18 +295,21 @@ PolyVector requiredCoefficientsLazardModified(
PolyVector CDCAC::requiredCoefficients(const poly::Polynomial& p)
{
- if (Trace.isOn("cdcac"))
+ if (Trace.isOn("cdcac::projection"))
{
- Trace("cdcac") << "Poly: " << p << " over " << d_assignment << std::endl;
- Trace("cdcac") << "Lazard: "
- << requiredCoefficientsLazard(p, d_assignment) << std::endl;
- Trace("cdcac") << "LMod: "
- << requiredCoefficientsLazardModified(
- p, d_assignment, d_constraints.varMapper())
- << std::endl;
- Trace("cdcac") << "Original: "
- << requiredCoefficientsOriginal(p, d_assignment)
- << std::endl;
+ Trace("cdcac::projection")
+ << "Poly: " << p << " over " << d_assignment << std::endl;
+ Trace("cdcac::projection")
+ << "Lazard: " << requiredCoefficientsLazard(p, d_assignment)
+ << std::endl;
+ Trace("cdcac::projection")
+ << "LMod: "
+ << requiredCoefficientsLazardModified(
+ p, d_assignment, d_constraints.varMapper())
+ << std::endl;
+ Trace("cdcac::projection")
+ << "Original: " << requiredCoefficientsOriginal(p, d_assignment)
+ << std::endl;
}
switch (options().arith.nlCadProjection)
{
@@ -346,15 +351,16 @@ PolyVector CDCAC::constructCharacterization(std::vector<CACInterval>& intervals)
}
for (const auto& p : i.d_mainPolys)
{
- Trace("cdcac") << "Discriminant of " << p << " -> " << discriminant(p)
- << std::endl;
+ Trace("cdcac::projection")
+ << "Discriminant of " << p << " -> " << discriminant(p) << std::endl;
// Add all discriminants
res.add(discriminant(p));
for (const auto& q : requiredCoefficients(p))
{
// Add all required coefficients
- Trace("cdcac") << "Coeff of " << p << " -> " << q << std::endl;
+ Trace("cdcac::projection")
+ << "Coeff of " << p << " -> " << q << std::endl;
res.add(q);
}
for (const auto& q : i.d_lowerPolys)
@@ -362,8 +368,8 @@ PolyVector CDCAC::constructCharacterization(std::vector<CACInterval>& intervals)
if (p == q) continue;
// Check whether p(s \times a) = 0 for some a <= l
if (!hasRootBelow(q, get_lower(i.d_interval))) continue;
- Trace("cdcac") << "Resultant of " << p << " and " << q << " -> "
- << resultant(p, q) << std::endl;
+ Trace("cdcac::projection") << "Resultant of " << p << " and " << q
+ << " -> " << resultant(p, q) << std::endl;
res.add(resultant(p, q));
}
for (const auto& q : i.d_upperPolys)
@@ -371,8 +377,8 @@ PolyVector CDCAC::constructCharacterization(std::vector<CACInterval>& intervals)
if (p == q) continue;
// Check whether p(s \times a) = 0 for some a >= u
if (!hasRootAbove(q, get_upper(i.d_interval))) continue;
- Trace("cdcac") << "Resultant of " << p << " and " << q << " -> "
- << resultant(p, q) << std::endl;
+ Trace("cdcac::projection") << "Resultant of " << p << " and " << q
+ << " -> " << resultant(p, q) << std::endl;
res.add(resultant(p, q));
}
}
@@ -385,8 +391,8 @@ PolyVector CDCAC::constructCharacterization(std::vector<CACInterval>& intervals)
{
for (const auto& q : intervals[i + 1].d_lowerPolys)
{
- Trace("cdcac") << "Resultant of " << p << " and " << q << " -> "
- << resultant(p, q) << std::endl;
+ Trace("cdcac::projection") << "Resultant of " << p << " and " << q
+ << " -> " << resultant(p, q) << std::endl;
res.add(resultant(p, q));
}
}
@@ -477,25 +483,31 @@ CACInterval CDCAC::intervalFromCharacterization(
if (lower == upper)
{
// construct a point interval
- return CACInterval{
- poly::Interval(lower, false, upper, false), l, u, m, d, {}};
+ return CACInterval{d_nextIntervalId++,
+ poly::Interval(lower, false, upper, false),
+ l,
+ u,
+ m,
+ d,
+ {}};
}
else
{
// construct an open interval
Assert(lower < upper);
- return CACInterval{
- poly::Interval(lower, true, upper, true), l, u, m, d, {}};
+ return CACInterval{d_nextIntervalId++,
+ poly::Interval(lower, true, upper, true),
+ l,
+ u,
+ m,
+ d,
+ {}};
}
}
-std::vector<CACInterval> CDCAC::getUnsatCover(std::size_t curVariable,
- bool returnFirstInterval)
+std::vector<CACInterval> CDCAC::getUnsatCoverImpl(std::size_t curVariable,
+ bool returnFirstInterval)
{
- if (isProofEnabled())
- {
- d_proof->startRecursive();
- }
Trace("cdcac") << "Looking for unsat cover for "
<< d_variableOrdering[curVariable] << std::endl;
std::vector<CACInterval> intervals = getUnsatIntervals(curVariable);
@@ -537,9 +549,10 @@ std::vector<CACInterval> CDCAC::getUnsatCover(std::size_t curVariable,
if (isProofEnabled())
{
d_proof->startScope();
+ d_proof->startRecursive();
}
// Recurse to next variable
- auto cov = getUnsatCover(curVariable + 1);
+ auto cov = getUnsatCoverImpl(curVariable + 1);
if (cov.empty())
{
// Found SAT!
@@ -558,6 +571,7 @@ std::vector<CACInterval> CDCAC::getUnsatCover(std::size_t curVariable,
intervals.emplace_back(newInterval);
if (isProofEnabled())
{
+ d_proof->endRecursive(newInterval.d_id);
auto cell = d_proof->constructCell(
d_constraints.varMapper()(d_variableOrdering[curVariable]),
newInterval,
@@ -596,11 +610,21 @@ std::vector<CACInterval> CDCAC::getUnsatCover(std::size_t curVariable,
Trace("cdcac") << "-> " << i.d_interval << std::endl;
}
}
+ return intervals;
+}
+
+std::vector<CACInterval> CDCAC::getUnsatCover(bool returnFirstInterval)
+{
+ if (isProofEnabled())
+ {
+ d_proof->startRecursive();
+ }
+ auto res = getUnsatCoverImpl(0, returnFirstInterval);
if (isProofEnabled())
{
- d_proof->endRecursive();
+ d_proof->endRecursive(0);
}
- return intervals;
+ return res;
}
void CDCAC::startNewProof()
@@ -639,7 +663,8 @@ CACInterval CDCAC::buildIntegralityInterval(std::size_t cur_variable,
poly::Integer below = poly::floor(value);
poly::Integer above = poly::ceil(value);
// construct var \in (below, above)
- return CACInterval{poly::Interval(below, above),
+ return CACInterval{d_nextIntervalId++,
+ poly::Interval(below, above),
{var - below},
{var - above},
{var - below, var - above},
@@ -669,10 +694,11 @@ void CDCAC::pruneRedundantIntervals(std::vector<CACInterval>& intervals)
{
if (isProofEnabled())
{
- std::vector<CACInterval> allIntervals = intervals;
cleanIntervals(intervals);
- d_proof->pruneChildren([&allIntervals, &intervals](std::size_t i) {
- return std::find(intervals.begin(), intervals.end(), allIntervals[i])
+ d_proof->pruneChildren([&intervals](std::size_t id) {
+ return std::find_if(intervals.begin(),
+ intervals.end(),
+ [id](const CACInterval& i) { return i.d_id == id; })
!= intervals.end();
});
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback