summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/theory/arith/arith_msum.cpp78
-rw-r--r--src/theory/arith/arith_msum.h13
-rw-r--r--src/theory/arith/arith_rewriter.cpp2
-rw-r--r--src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp11
-rw-r--r--src/theory/quantifiers/cegqi/ceg_arith_instantiator.h2
-rw-r--r--src/theory/quantifiers/fmf/bounded_integers.cpp7
-rw-r--r--src/theory/quantifiers/relevant_domain.cpp10
7 files changed, 62 insertions, 61 deletions
diff --git a/src/theory/arith/arith_msum.cpp b/src/theory/arith/arith_msum.cpp
index a8edb0e79..0621c1391 100644
--- a/src/theory/arith/arith_msum.cpp
+++ b/src/theory/arith/arith_msum.cpp
@@ -81,7 +81,8 @@ bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
{
- if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
+ if (lit.getKind() == GEQ
+ || (lit.getKind() == EQUAL && lit[0].getType().isRealOrInt()))
{
if (getMonomialSum(lit[0], msum))
{
@@ -96,6 +97,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
NodeManager* nm = NodeManager::currentNM();
if (getMonomialSum(lit[1], msum2))
{
+ TypeNode tn = lit[0].getType();
for (std::map<Node, Node>::iterator it = msum2.begin();
it != msum2.end();
++it)
@@ -103,20 +105,20 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
std::map<Node, Node>::iterator it2 = msum.find(it->first);
if (it2 != msum.end())
{
- Node r = nm->mkNode(MINUS,
- it2->second.isNull()
- ? nm->mkConst(CONST_RATIONAL, Rational(1))
- : it2->second,
- it->second.isNull()
- ? nm->mkConst(CONST_RATIONAL, Rational(1))
- : it->second);
- msum[it->first] = Rewriter::rewrite(r);
+ Rational r1 = it2->second.isNull()
+ ? Rational(1)
+ : it2->second.getConst<Rational>();
+ Rational r2 = it->second.isNull()
+ ? Rational(1)
+ : it->second.getConst<Rational>();
+ msum[it->first] = nm->mkConstRealOrInt(tn, r1 - r2);
}
else
{
msum[it->first] = it->second.isNull()
- ? nm->mkConst(CONST_RATIONAL, Rational(-1))
- : negate(it->second);
+ ? nm->mkConstRealOrInt(tn, Rational(-1))
+ : nm->mkConstRealOrInt(
+ tn, -it->second.getConst<Rational>());
}
}
return true;
@@ -127,7 +129,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
return false;
}
-Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
+Node ArithMSum::mkNode(TypeNode tn, const std::map<Node, Node>& msum)
{
NodeManager* nm = NodeManager::currentNM();
std::vector<Node> children;
@@ -146,10 +148,10 @@ Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
}
children.push_back(m);
}
- return children.size() > 1 ? nm->mkNode(PLUS, children)
- : (children.size() == 1
- ? children[0]
- : nm->mkConst(CONST_RATIONAL, Rational(0)));
+ return children.size() > 1
+ ? nm->mkNode(PLUS, children)
+ : (children.size() == 1 ? children[0]
+ : nm->mkConstRealOrInt(tn, Rational(0)));
}
int ArithMSum::isolate(
@@ -159,11 +161,13 @@ int ArithMSum::isolate(
std::map<Node, Node>::const_iterator itv = msum.find(v);
if (itv != msum.end())
{
+ NodeManager* nm = NodeManager::currentNM();
std::vector<Node> children;
Rational r =
itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
if (r.sgn() != 0)
{
+ TypeNode vtn = v.getType();
for (std::map<Node, Node>::const_iterator it = msum.begin();
it != msum.end();
++it)
@@ -182,27 +186,25 @@ int ArithMSum::isolate(
children.push_back(m);
}
}
- val = children.size() > 1
- ? NodeManager::currentNM()->mkNode(PLUS, children)
- : (children.size() == 1 ? children[0]
- : NodeManager::currentNM()->mkConst(
- CONST_RATIONAL, Rational(0)));
+ val =
+ children.size() > 1
+ ? nm->mkNode(PLUS, children)
+ : (children.size() == 1 ? children[0]
+ : nm->mkConstRealOrInt(vtn, Rational(0)));
if (!r.isOne() && !r.isNegativeOne())
{
- if (v.getType().isInteger())
+ if (vtn.isInteger())
{
- veq_c = NodeManager::currentNM()->mkConst(CONST_RATIONAL, r.abs());
+ veq_c = nm->mkConstInt(r.abs());
}
else
{
- val = NodeManager::currentNM()->mkNode(
- MULT,
- val,
- NodeManager::currentNM()->mkConst(CONST_RATIONAL,
- Rational(1) / r.abs()));
+ val = nm->mkNode(MULT, val, nm->mkConstReal(Rational(1) / r.abs()));
}
}
- val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
+ val = r.sgn() == 1
+ ? nm->mkNode(MULT, nm->mkConstRealOrInt(vtn, Rational(-1)), val)
+ : val;
return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
}
}
@@ -284,29 +286,13 @@ bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
{
coeff = it->second;
msum.erase(v);
- rem = mkNode(msum);
+ rem = mkNode(n.getType(), msum);
return true;
}
}
return false;
}
-Node ArithMSum::negate(Node t)
-{
- Node tt = NodeManager::currentNM()->mkNode(
- MULT, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(-1)), t);
- tt = Rewriter::rewrite(tt);
- return tt;
-}
-
-Node ArithMSum::offset(Node t, int i)
-{
- Node tt = NodeManager::currentNM()->mkNode(
- PLUS, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(i)), t);
- tt = Rewriter::rewrite(tt);
- return tt;
-}
-
void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
{
for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
diff --git a/src/theory/arith/arith_msum.h b/src/theory/arith/arith_msum.h
index 87f56e64f..ae57ee1cb 100644
--- a/src/theory/arith/arith_msum.h
+++ b/src/theory/arith/arith_msum.h
@@ -103,8 +103,13 @@ class ArithMSum
*
* Make the Node corresponding to the interpretation of msum, [msum], where:
* [msum] = sum_{( v, c ) \in msum } [c]*[v]
+ *
+ * @param tn The type of the node to return, which is used only if msum is
+ * empty
+ * @param msum The monomial sum
+ * @return The node corresponding to the monomial sum
*/
- static Node mkNode(const std::map<Node, Node>& msum);
+ static Node mkNode(TypeNode tn, const std::map<Node, Node>& msum);
/** make coefficent term
*
@@ -173,12 +178,6 @@ class ArithMSum
*/
static bool decompose(Node n, Node v, Node& coeff, Node& rem);
- /** return the rewritten form of (UMINUS t) */
- static Node negate(Node t);
-
- /** return the rewritten form of (PLUS t (CONST_RATIONAL i)) */
- static Node offset(Node t, int i);
-
/** debug print for a monmoial sum, prints to Trace(c) */
static void debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c);
};
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
index 4c01f25f3..af6f23c1f 100644
--- a/src/theory/arith/arith_rewriter.cpp
+++ b/src/theory/arith/arith_rewriter.cpp
@@ -516,7 +516,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
msum.erase(pi);
if (!msum.empty())
{
- rem = ArithMSum::mkNode(msum);
+ rem = ArithMSum::mkNode(t[0].getType(), msum);
}
}
}
diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
index 2d483d502..56debbbac 100644
--- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
+++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
@@ -818,7 +818,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
// multiply by the coefficient we will isolate for
if (itv->second.isNull())
{
- vts_coeff[t] = ArithMSum::negate(vts_coeff[t]);
+ vts_coeff[t] = negate(vts_coeff[t]);
}
else
{
@@ -833,7 +833,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
}
else if (itv->second.getConst<Rational>().sgn() == 1)
{
- vts_coeff[t] = ArithMSum::negate(vts_coeff[t]);
+ vts_coeff[t] = negate(vts_coeff[t]);
}
}
}
@@ -1040,6 +1040,13 @@ Node ArithInstantiator::getModelBasedProjectionValue(CegInstantiator* ci,
return val;
}
+Node ArithInstantiator::negate(const Node& t) const
+{
+ NodeManager* nm = NodeManager::currentNM();
+ return rewrite(
+ nm->mkNode(MULT, nm->mkConstRealOrInt(t.getType(), Rational(-1)), t));
+}
+
} // namespace quantifiers
} // namespace theory
} // namespace cvc5
diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h
index e102b834e..d44ab4993 100644
--- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h
+++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h
@@ -206,6 +206,8 @@ class ArithInstantiator : public Instantiator
Node theta,
Node inf_coeff,
Node delta_coeff);
+ /** Return the rewritten form of the negation of t */
+ Node negate(const Node& t) const;
};
} // namespace quantifiers
diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp
index 18a63d245..5c0283863 100644
--- a/src/theory/quantifiers/fmf/bounded_integers.cpp
+++ b/src/theory/quantifiers/fmf/bounded_integers.cpp
@@ -223,6 +223,7 @@ void BoundedIntegers::process( Node q, Node n, bool pol,
std::map< Node, Node > msum;
if (ArithMSum::getMonomialSumLit(n, msum))
{
+ NodeManager* nm = NodeManager::currentNM();
Trace("bound-int-debug") << "literal (polarity = " << pol << ") " << n << " is monomial sum : " << std::endl;
ArithMSum::debugPrintMonomialSum(msum, "bound-int-debug");
for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
@@ -239,11 +240,11 @@ void BoundedIntegers::process( Node q, Node n, bool pol,
n1 = veq[1];
n2 = veq[0];
if( n1.getKind()==BOUND_VARIABLE ){
- n2 = ArithMSum::offset(n2, 1);
+ n2 = nm->mkNode(PLUS, n2, nm->mkConstInt(Rational(1)));
}else{
- n1 = ArithMSum::offset(n1, -1);
+ n1 = nm->mkNode(PLUS, n1, nm->mkConstInt(Rational(-1)));
}
- veq = NodeManager::currentNM()->mkNode( GEQ, n1, n2 );
+ veq = nm->mkNode(GEQ, n1, n2);
}
Trace("bound-int-debug") << "Isolated for " << it->first << " : (" << n1 << " >= " << n2 << ")" << std::endl;
Node t = n1==it->first ? n2 : n1;
diff --git a/src/theory/quantifiers/relevant_domain.cpp b/src/theory/quantifiers/relevant_domain.cpp
index 0f3699990..f0684f04a 100644
--- a/src/theory/quantifiers/relevant_domain.cpp
+++ b/src/theory/quantifiers/relevant_domain.cpp
@@ -24,6 +24,7 @@
#include "theory/quantifiers/term_database.h"
#include "theory/quantifiers/term_registry.h"
#include "theory/quantifiers/term_util.h"
+#include "util/rational.h"
using namespace cvc5::kind;
@@ -301,6 +302,7 @@ void RelevantDomain::computeRelevantDomainOpCh( RDomain * rf, Node n ) {
void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, Node n ) {
if( d_rel_dom_lit[hasPol][pol].find( n )==d_rel_dom_lit[hasPol][pol].end() ){
+ NodeManager* nm = NodeManager::currentNM();
RDomainLit& rdl = d_rel_dom_lit[hasPol][pol][n];
rdl.d_merge = false;
int varCount = 0;
@@ -405,10 +407,14 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No
if( ( !hasPol || pol ) && n[0].getType().isInteger() ){
if( n.getKind()==EQUAL ){
for( unsigned i=0; i<2; i++ ){
- rdl.d_val.push_back(ArithMSum::offset(r_add, i == 0 ? 1 : -1));
+ Node roff = nm->mkNode(
+ PLUS, r_add, nm->mkConstInt(Rational(i == 0 ? 1 : -1)));
+ rdl.d_val.push_back(roff);
}
}else if( n.getKind()==GEQ ){
- rdl.d_val.push_back(ArithMSum::offset(r_add, varLhs ? 1 : -1));
+ Node roff = nm->mkNode(
+ PLUS, r_add, nm->mkConstInt(Rational(varLhs ? 1 : -1)));
+ rdl.d_val.push_back(roff);
}
}
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback