diff options
-rw-r--r-- | src/theory/arith/idl/idl_extension.cpp | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/src/theory/arith/idl/idl_extension.cpp b/src/theory/arith/idl/idl_extension.cpp index ee92c41ff..36e9dc78c 100644 --- a/src/theory/arith/idl/idl_extension.cpp +++ b/src/theory/arith/idl/idl_extension.cpp @@ -167,10 +167,49 @@ Node IdlExtension::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) } return ppRewrite(ret, lems); } - else if (atom[0].getKind() == kind::VARIABLE) { + else if (atom[0].getKind() == kind::VARIABLE + && atom[1].getKind() == kind::CONST_RATIONAL) { Node new_lhs = nm->mkNode(kind::MINUS, atom[0], zero_node); return ppRewrite(nm->mkNode(atom.getKind(), new_lhs, atom[1]), lems); } + // Handle ([op] ([+-] x 10) ([+-] y 20)) + // Note that not all of these are valid, e.g., x + 10 <= 10 - y. + // This is a 'just barely enough' implementation to get the few smtlib + // benchmarks using this syntax to work. It should probably be cleaned + // up/simplified. + else if ((atom[1].getKind() == kind::PLUS || atom[1].getKind() == kind::MINUS) + || (atom[0].getKind() == kind::PLUS + && atom[1].getKind() == kind::CONST_RATIONAL)) { + Node lhs = atom[0], rhs = atom[1]; + if (lhs.getKind() == kind::VARIABLE) { + lhs = nm->mkNode(kind::PLUS, lhs, nm->mkConstInt(Rational())); + } + if (rhs.getKind() == kind::CONST_RATIONAL) { + rhs = nm->mkNode(kind::PLUS, rhs, zero_node); + } + size_t lhs_var_idx = (lhs[0].getKind() == kind::VARIABLE) ? 0 : 1, + rhs_var_idx = (rhs[0].getKind() == kind::VARIABLE) ? 0 : 1; + Node lhs_var = lhs[lhs_var_idx], + rhs_var = rhs[rhs_var_idx]; + Assert(lhs[1 - lhs_var_idx].getKind() == kind::CONST_RATIONAL); + Assert(rhs[1 - rhs_var_idx].getKind() == kind::CONST_RATIONAL); + const Rational& lhs_const = lhs[1 - lhs_var_idx].getConst<Rational>(); + const Rational& rhs_const = rhs[1 - rhs_var_idx].getConst<Rational>(); + bool lhs_is_plus = lhs.getKind() == kind::PLUS, + rhs_is_plus = rhs.getKind() == kind::PLUS; + bool lhs_var_subtracted = (!lhs_is_plus) && lhs_var_idx == 1, + rhs_var_subtracted = (!rhs_is_plus) && rhs_var_idx == 1; + // Move the rhs to the lhs. + Assert(!lhs_var_subtracted || rhs_var_subtracted); + Node new_lhs_first_slot = lhs_var_subtracted ? rhs_var : lhs_var, + new_lhs_second_slot = lhs_var_subtracted ? lhs_var : rhs_var; + Node new_lhs + = nm->mkNode(kind::MINUS, new_lhs_first_slot, new_lhs_second_slot); + bool lhs_const_subtracted = (!lhs_is_plus) && lhs_var_idx == 0; + Node new_rhs + = nm->mkConstInt(rhs_const + (lhs_const_subtracted ? lhs_const : -lhs_const)); + return ppRewrite(nm->mkNode(atom.getKind(), new_lhs, new_rhs), lems); + } // NOTE: Does *NOT* handle cases like x = 5. switch (atom.getKind()) @@ -180,6 +219,7 @@ Node IdlExtension::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) Node l_le_r = nm->mkNode(kind::LEQ, atom[0], atom[1]); Assert(atom[0].getKind() == kind::MINUS); Node negated_left = nm->mkNode(kind::MINUS, atom[0][1], atom[0][0]); + Assert(atom[1].getKind() == kind::CONST_RATIONAL); const Rational& right = atom[1].getConst<Rational>(); Node negated_right = nm->mkConstInt(-right); Node r_le_l = nm->mkNode(kind::LEQ, negated_left, negated_right); @@ -192,6 +232,7 @@ Node IdlExtension::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) case kind::LT: // a - b < c == a - b <= c - 1 (integers) { Assert(atom[0].getKind() == kind::MINUS); + Assert(atom[1].getKind() == kind::CONST_RATIONAL); const Rational& right = atom[1].getConst<Rational>(); Node tight_right = nm->mkConstInt(right - 1); return nm->mkNode(kind::LEQ, atom[0], tight_right); @@ -200,6 +241,7 @@ Node IdlExtension::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) case kind::GT: // a - b > c == -(a - b) < -c == b - a <= -c - 1 { Node negated_left = nm->mkNode(kind::MINUS, atom[0][1], atom[0][0]); + Assert(atom[1].getKind() == kind::CONST_RATIONAL); const Rational& right = atom[1].getConst<Rational>(); Node updated_right = nm->mkConstInt(-right - 1); return nm->mkNode(kind::LEQ, negated_left, updated_right); @@ -207,6 +249,7 @@ Node IdlExtension::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) case kind::GEQ: // a - b >= c == b - a <= -c { Node negated_left = nm->mkNode(kind::MINUS, atom[0][1], atom[0][0]); + Assert(atom[1].getKind() == kind::CONST_RATIONAL); const Rational& right = atom[1].getConst<Rational>(); Node negated_right = nm->mkConstInt(-right); return nm->mkNode(kind::LEQ, negated_left, negated_right); |