summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormudathirmahgoub <mudathirmahgoub@gmail.com>2021-12-01 20:12:30 -0600
committerGitHub <noreply@github.com>2021-12-02 02:12:30 +0000
commit70997d0e3ebf2027279373d9594c66119f3fa656 (patch)
treea7ce0933bf70efa43507cb3cb3db99220113f426 /src
parent6afc21a16e740d4fb4a16cdbd9a6ff745c7ce00c (diff)
add bag.fold operator (#7718)
Diffstat (limited to 'src')
-rw-r--r--src/CMakeLists.txt2
-rw-r--r--src/api/cpp/cvc5.cpp2
-rw-r--r--src/api/cpp/cvc5_kind.h16
-rw-r--r--src/expr/skolem_manager.cpp4
-rw-r--r--src/expr/skolem_manager.h4
-rw-r--r--src/parser/smt2/smt2.cpp1
-rw-r--r--src/printer/smt2/smt2_printer.cpp1
-rw-r--r--src/theory/bags/bag_reduction.cpp119
-rw-r--r--src/theory/bags/bag_reduction.h77
-rw-r--r--src/theory/bags/bags_rewriter.cpp40
-rw-r--r--src/theory/bags/bags_rewriter.h10
-rw-r--r--src/theory/bags/kinds9
-rw-r--r--src/theory/bags/normal_form.cpp38
-rw-r--r--src/theory/bags/normal_form.h6
-rw-r--r--src/theory/bags/rewrites.cpp3
-rw-r--r--src/theory/bags/rewrites.h3
-rw-r--r--src/theory/bags/theory_bags.cpp20
-rw-r--r--src/theory/bags/theory_bags.h4
-rw-r--r--src/theory/bags/theory_bags_type_rules.cpp51
-rw-r--r--src/theory/bags/theory_bags_type_rules.h9
-rw-r--r--src/theory/inference_id.cpp1
-rw-r--r--src/theory/inference_id.h1
22 files changed, 416 insertions, 5 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 025f499e6..96de9afeb 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -535,6 +535,8 @@ libcvc5_add_sources(
theory/bags/bags_rewriter.h
theory/bags/bag_solver.cpp
theory/bags/bag_solver.h
+ theory/bags/bag_reduction.cpp
+ theory/bags/bag_reduction.h
theory/bags/bags_statistics.cpp
theory/bags/bags_statistics.h
theory/bags/infer_info.cpp
diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp
index 6129ff891..c62dde511 100644
--- a/src/api/cpp/cvc5.cpp
+++ b/src/api/cpp/cvc5.cpp
@@ -313,6 +313,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
{BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET},
{BAG_TO_SET, cvc5::Kind::BAG_TO_SET},
{BAG_MAP, cvc5::Kind::BAG_MAP},
+ {BAG_FOLD, cvc5::Kind::BAG_FOLD},
/* Strings ------------------------------------------------------------- */
{STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
{STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP},
@@ -624,6 +625,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
{cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET},
{cvc5::Kind::BAG_TO_SET, BAG_TO_SET},
{cvc5::Kind::BAG_MAP, BAG_MAP},
+ {cvc5::Kind::BAG_FOLD, BAG_FOLD},
/* Strings --------------------------------------------------------- */
{cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
{cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h
index e6a03cbe4..73843f9b5 100644
--- a/src/api/cpp/cvc5_kind.h
+++ b/src/api/cpp/cvc5_kind.h
@@ -2539,6 +2539,22 @@ enum Kind : int32_t
* - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
*/
BAG_MAP,
+ /**
+ * bag.fold operator combines elements of a bag into a single value.
+ * (bag.fold f t B) folds the elements of bag B starting with term t and using
+ * the combining function f.
+ *
+ * Parameters:
+ * - 1: a binary operation of type (-> T1 T2 T2)
+ * - 2: an initial value of type T2
+ * - 2: a bag of type (Bag T1)
+ *
+ * Create with:
+ * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2,
+ * const Term& child3) const`
+ * - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+ */
+ BAG_FOLD,
/* Strings --------------------------------------------------------------- */
diff --git a/src/expr/skolem_manager.cpp b/src/expr/skolem_manager.cpp
index db976559f..476517820 100644
--- a/src/expr/skolem_manager.cpp
+++ b/src/expr/skolem_manager.cpp
@@ -68,6 +68,10 @@ const char* toString(SkolemFunId id)
case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST";
case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT";
case SkolemFunId::BAGS_CHOOSE: return "BAGS_CHOOSE";
+ case SkolemFunId::BAGS_FOLD_CARD: return "BAGS_FOLD_CARD";
+ case SkolemFunId::BAGS_FOLD_COMBINE: return "BAGS_FOLD_COMBINE";
+ case SkolemFunId::BAGS_FOLD_ELEMENTS: return "BAGS_FOLD_ELEMENTS";
+ case SkolemFunId::BAGS_FOLD_UNION_DISJOINT: return "BAGS_FOLD_UNION_DISJOINT";
case SkolemFunId::BAGS_MAP_PREIMAGE: return "BAGS_MAP_PREIMAGE";
case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
diff --git a/src/expr/skolem_manager.h b/src/expr/skolem_manager.h
index a18de8a2e..780413d17 100644
--- a/src/expr/skolem_manager.h
+++ b/src/expr/skolem_manager.h
@@ -112,6 +112,10 @@ enum class SkolemFunId
* i = 0, ..., n.
*/
RE_UNFOLD_POS_COMPONENT,
+ BAGS_FOLD_CARD,
+ BAGS_FOLD_COMBINE,
+ BAGS_FOLD_ELEMENTS,
+ BAGS_FOLD_UNION_DISJOINT,
/** An interpreted function for bag.choose operator:
* (bag.choose A) is expanded as
* (witness ((x elementType))
diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp
index ad380a31c..4e1a8aae8 100644
--- a/src/parser/smt2/smt2.cpp
+++ b/src/parser/smt2/smt2.cpp
@@ -629,6 +629,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
addOperator(api::BAG_FROM_SET, "bag.from_set");
addOperator(api::BAG_TO_SET, "bag.to_set");
addOperator(api::BAG_MAP, "bag.map");
+ addOperator(api::BAG_FOLD, "bag.fold");
}
if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
defineType("String", d_solver->getStringSort(), true, true);
diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp
index 13477b792..875ca7dc2 100644
--- a/src/printer/smt2/smt2_printer.cpp
+++ b/src/printer/smt2/smt2_printer.cpp
@@ -1098,6 +1098,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
case kind::BAG_FROM_SET: return "bag.from_set";
case kind::BAG_TO_SET: return "bag.to_set";
case kind::BAG_MAP: return "bag.map";
+ case kind::BAG_FOLD: return "bag.fold";
// fp theory
case kind::FLOATINGPOINT_FP: return "fp";
diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp
new file mode 100644
index 000000000..9203a1c45
--- /dev/null
+++ b/src/theory/bags/bag_reduction.cpp
@@ -0,0 +1,119 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ * Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * bag reduction.
+ */
+
+#include "theory/bags/bag_reduction.h"
+
+#include "expr/bound_var_manager.h"
+#include "expr/emptybag.h"
+#include "expr/skolem_manager.h"
+#include "theory/quantifiers/fmf/bounded_integers.h"
+#include "util/rational.h"
+
+using namespace cvc5;
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+BagReduction::BagReduction(Env& env) : EnvObj(env) {}
+
+BagReduction::~BagReduction() {}
+
+/**
+ * A bound variable corresponding to the universally quantified integer
+ * variable used to range over the distinct elements in a bag, used
+ * for axiomatizing the behavior of some term.
+ */
+struct IndexVarAttributeId
+{
+};
+typedef expr::Attribute<IndexVarAttributeId, Node> IndexVarAttribute;
+
+Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
+{
+ Assert(node.getKind() == BAG_FOLD);
+ if (d_env.getLogicInfo().isHigherOrder())
+ {
+ NodeManager* nm = NodeManager::currentNM();
+ SkolemManager* sm = nm->getSkolemManager();
+ Node f = node[0];
+ Node t = node[1];
+ Node A = node[2];
+ Node zero = nm->mkConst(CONST_RATIONAL, Rational(0));
+ Node one = nm->mkConst(CONST_RATIONAL, Rational(1));
+ // types
+ TypeNode bagType = A.getType();
+ TypeNode elementType = A.getType().getBagElementType();
+ TypeNode integerType = nm->integerType();
+ TypeNode ufType = nm->mkFunctionType(integerType, elementType);
+ TypeNode resultType = t.getType();
+ TypeNode combineType = nm->mkFunctionType(integerType, resultType);
+ TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
+ // skolem functions
+ Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
+ Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
+ Node unionDisjoint = sm->mkSkolemFunction(
+ SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
+ Node combine = sm->mkSkolemFunction(
+ SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
+
+ BoundVarManager* bvm = nm->getBoundVarManager();
+ Node i = bvm->mkBoundVar<IndexVarAttribute>(node, "i", nm->integerType());
+ Node iList = nm->mkNode(BOUND_VAR_LIST, i);
+ Node iMinusOne = nm->mkNode(MINUS, i, one);
+ Node uf_i = nm->mkNode(APPLY_UF, uf, i);
+ Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
+ Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
+ Node combine_i = nm->mkNode(APPLY_UF, combine, i);
+ Node combine_n = nm->mkNode(APPLY_UF, combine, n);
+ Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
+ Node unionDisjoint_iMinusOne =
+ nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
+ Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
+ Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
+ Node combine_0_equal = combine_0.eqNode(t);
+ Node combine_i_equal =
+ combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
+ Node unionDisjoint_0_equal =
+ unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
+ Node singleton = nm->mkBag(elementType, uf_i, one);
+
+ Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
+ nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
+ Node interval_i =
+ nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
+
+ Node body_i =
+ nm->mkNode(IMPLIES,
+ interval_i,
+ nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
+ Node forAll_i =
+ quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
+ Node nonNegative = nm->mkNode(GEQ, n, zero);
+ Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
+ asserts.push_back(forAll_i);
+ asserts.push_back(combine_0_equal);
+ asserts.push_back(unionDisjoint_0_equal);
+ asserts.push_back(unionDisjoint_n_equal);
+ asserts.push_back(nonNegative);
+ return combine_n;
+ }
+ return Node::null();
+}
+
+} // namespace bags
+} // namespace theory
+} // namespace cvc5
diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h
new file mode 100644
index 000000000..11f091f94
--- /dev/null
+++ b/src/theory/bags/bag_reduction.h
@@ -0,0 +1,77 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ * Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * bag reduction.
+ */
+
+#ifndef CVC5__BAG_REDUCTION_H
+#define CVC5__BAG_REDUCTION_H
+
+#include <vector>
+
+#include "cvc5_private.h"
+#include "smt/env_obj.h"
+#include "theory/bags/inference_manager.h"
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+/**
+ * class for bag reductions
+ */
+class BagReduction : EnvObj
+{
+ public:
+ BagReduction(Env& env);
+ ~BagReduction();
+
+ /**
+ * @param node a term of the form (bag.fold f t A) where
+ * f: (-> T1 T2 T2) is a binary operation
+ * t: T2 is the initial value
+ * A: (Bag T1) is a bag
+ * @param asserts a list of assertions generated by this reduction
+ * @return the reduction term (combine n) such that
+ * (and
+ * (forall ((i Int))
+ * (let ((iMinusOne (- i 1)))
+ * (let ((uf_i (uf i)))
+ * (=>
+ * (and (>= i 1) (<= i n))
+ * (and
+ * (= (combine i) (f uf_i (combine iMinusOne)))
+ * (=
+ * (unionDisjoint i)
+ * (bag.union_disjoint
+ * (bag uf_i 1)
+ * (unionDisjoint iMinusOne))))))))
+ * (= (combine 0) t)
+ * (= (unionDisjoint 0) (as bag.empty (Bag T1)))
+ * (= A (unionDisjoint n))
+ * (>= n 0))
+ * where
+ * n: Int is the cardinality of bag A
+ * uf:Int -> T1 is an uninterpreted function that represents elements of A
+ * combine: Int -> T2 is an uninterpreted function
+ * unionDisjoint: Int -> (Bag T1) is an uninterpreted function
+ */
+ Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+
+ private:
+};
+
+} // namespace bags
+} // namespace theory
+} // namespace cvc5
+
+#endif /* CVC5__BAG_REDUCTION_H */
diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp
index b8f3b80c9..766731806 100644
--- a/src/theory/bags/bags_rewriter.cpp
+++ b/src/theory/bags/bags_rewriter.cpp
@@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
case BAG_FROM_SET: response = rewriteFromSet(n); break;
case BAG_TO_SET: response = rewriteToSet(n); break;
case BAG_MAP: response = postRewriteMap(n); break;
+ case BAG_FOLD: response = postRewriteFold(n); break;
default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
}
}
@@ -560,6 +561,45 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
default: return BagsRewriteResponse(n, Rewrite::NONE);
}
}
+
+BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
+{
+ Assert(n.getKind() == kind::BAG_FOLD);
+ Node f = n[0];
+ Node t = n[1];
+ Node bag = n[2];
+ if (bag.isConst())
+ {
+ Node value = NormalForm::evaluateBagFold(n);
+ return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
+ }
+ Kind k = bag.getKind();
+ switch (k)
+ {
+ case BAG_MAKE:
+ {
+ if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
+ {
+ // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
+ Node value = NormalForm::evaluateBagFold(n);
+ return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
+ }
+ break;
+ }
+ case BAG_UNION_DISJOINT:
+ {
+ // (bag.fold f t (bag.union_disjoint A B)) =
+ // (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+ Node A = bag[0] < bag[1] ? bag[0] : bag[1];
+ Node B = bag[0] < bag[1] ? bag[1] : bag[0];
+ Node foldA = d_nm->mkNode(BAG_FOLD, f, t, A);
+ Node fold = d_nm->mkNode(BAG_FOLD, f, foldA, B);
+ return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT);
+ }
+ default: return BagsRewriteResponse(n, Rewrite::NONE);
+ }
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
} // namespace bags
} // namespace theory
} // namespace cvc5
diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h
index a938b3bd4..d666982a7 100644
--- a/src/theory/bags/bags_rewriter.h
+++ b/src/theory/bags/bags_rewriter.h
@@ -222,6 +222,16 @@ class BagsRewriter : public TheoryRewriter
*/
BagsRewriteResponse postRewriteMap(const TNode& n) const;
+ /**
+ * rewrites for n include:
+ * - (bag.fold f t (as bag.empty (Bag T1))) = t
+ * - (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, where n > 0
+ * - (bag.fold f t (bag.union_disjoint A B)) =
+ * (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+ * where f: T1 -> T2 -> T2
+ */
+ BagsRewriteResponse postRewriteFold(const TNode& n) const;
+
private:
/** Reference to the rewriter statistics. */
NodeManager* d_nm;
diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds
index a5c6e75bf..5e4119fa1 100644
--- a/src/theory/bags/kinds
+++ b/src/theory/bags/kinds
@@ -76,6 +76,14 @@ operator BAG_CHOOSE 1 "return an element in the bag given as a parameter
# of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2).
operator BAG_MAP 2 "bag map function"
+# bag.fold operator combines elements of a bag into a single value.
+# (bag.fold f t B) folds the elements of bag B starting with term t and using
+# the combining function f.
+# f: a binary operation of type (-> T1 T2 T2)
+# t: an initial value of type T2
+# B: a bag of type (Bag T1)
+operator BAG_FOLD 3 "bag fold operator"
+
typerule BAG_UNION_MAX ::cvc5::theory::bags::BinaryOperatorTypeRule
typerule BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule
typerule BAG_INTER_MIN ::cvc5::theory::bags::BinaryOperatorTypeRule
@@ -93,6 +101,7 @@ typerule BAG_IS_SINGLETON ::cvc5::theory::bags::IsSingletonTypeRule
typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule
typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule
typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule
+typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule
construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule
construle BAG_MAKE ::cvc5::theory::bags::BagMakeTypeRule
diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp
index 12bf513b5..9a510c6f5 100644
--- a/src/theory/bags/normal_form.cpp
+++ b/src/theory/bags/normal_form.cpp
@@ -110,6 +110,7 @@ Node NormalForm::evaluate(TNode n)
case BAG_FROM_SET: return evaluateFromSet(n);
case BAG_TO_SET: return evaluateToSet(n);
case BAG_MAP: return evaluateBagMap(n);
+ case BAG_FOLD: return evaluateBagFold(n);
default: break;
}
Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
@@ -169,8 +170,6 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n,
std::map<Node, Rational> NormalForm::getBagElements(TNode n)
{
- Assert(n.isConst()) << "node " << n << " is not in a normal form"
- << std::endl;
std::map<Node, Rational> elements;
if (n.getKind() == BAG_EMPTY)
{
@@ -692,6 +691,41 @@ Node NormalForm::evaluateBagMap(TNode n)
return ret;
}
+Node NormalForm::evaluateBagFold(TNode n)
+{
+ Assert(n.getKind() == BAG_FOLD);
+
+ // Examples
+ // --------
+ // minimum string
+ // - (bag.fold
+ // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
+ // ""
+ // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
+ // = "a"
+
+ Node f = n[0]; // combining function
+ Node ret = n[1]; // initial value
+ Node A = n[2]; // bag
+ std::map<Node, Rational> elements = NormalForm::getBagElements(A);
+
+ std::map<Node, Rational>::iterator it = elements.begin();
+ NodeManager* nm = NodeManager::currentNM();
+ while (it != elements.end())
+ {
+ // apply the combination function n times, where n is the multiplicity
+ Rational count = it->second;
+ Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
+ while (!count.isZero())
+ {
+ ret = nm->mkNode(APPLY_UF, f, it->first, ret);
+ count = count - 1;
+ }
+ ++it;
+ }
+ return ret;
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5
diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h
index 8ceee2881..5275678ff 100644
--- a/src/theory/bags/normal_form.h
+++ b/src/theory/bags/normal_form.h
@@ -75,6 +75,12 @@ class NormalForm
static Node constructBagFromElements(TypeNode t,
const std::map<Node, Node>& elements);
+ /**
+ * @param n has the form (bag.fold f t A) where A is a constant bag
+ * @return a single value which is the result of the fold
+ */
+ static Node evaluateBagFold(TNode n);
+
private:
/**
* a high order helper function that return a constant bag that is the result
diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp
index 896c4f251..1a8f8f849 100644
--- a/src/theory/bags/rewrites.cpp
+++ b/src/theory/bags/rewrites.cpp
@@ -38,6 +38,9 @@ const char* toString(Rewrite r)
case Rewrite::EQ_REFL: return "EQ_REFL";
case Rewrite::EQ_SYM: return "EQ_SYM";
case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON";
+ case Rewrite::FOLD_BAG: return "FOLD_BAG";
+ case Rewrite::FOLD_CONST: return "FOLD_CONST";
+ case Rewrite::FOLD_UNION_DISJOINT: return "FOLD_UNION_DISJOINT";
case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES";
case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT";
case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT";
diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h
index c5050ea72..0b7188599 100644
--- a/src/theory/bags/rewrites.h
+++ b/src/theory/bags/rewrites.h
@@ -42,6 +42,9 @@ enum class Rewrite : uint32_t
EQ_REFL,
EQ_SYM,
FROM_SINGLETON,
+ FOLD_BAG,
+ FOLD_CONST,
+ FOLD_UNION_DISJOINT,
IDENTICAL_NODES,
INTERSECTION_EMPTY_LEFT,
INTERSECTION_EMPTY_RIGHT,
diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp
index 4dffbdb00..68bdb7b1b 100644
--- a/src/theory/bags/theory_bags.cpp
+++ b/src/theory/bags/theory_bags.cpp
@@ -20,6 +20,7 @@
#include "proof/proof_checker.h"
#include "smt/logic_exception.h"
#include "theory/bags/normal_form.h"
+#include "theory/quantifiers/fmf/bounded_integers.h"
#include "theory/rewriter.h"
#include "theory/theory_model.h"
#include "util/rational.h"
@@ -39,7 +40,8 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation)
d_statistics(),
d_rewriter(&d_statistics.d_rewrites),
d_termReg(env, d_state, d_im),
- d_solver(env, d_state, d_im, d_termReg)
+ d_solver(env, d_state, d_im, d_termReg),
+ d_bagReduction(env)
{
// use the official theory state and inference manager objects
d_theoryState = &d_state;
@@ -87,6 +89,18 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
{
case kind::BAG_CHOOSE: return expandChooseOperator(atom, lems);
case kind::BAG_CARD: return expandCardOperator(atom, lems);
+ case kind::BAG_FOLD:
+ {
+ std::vector<Node> asserts;
+ Node ret = d_bagReduction.reduceFoldOperator(atom, asserts);
+ NodeManager* nm = NodeManager::currentNM();
+ Node andNode = nm->mkNode(AND, asserts);
+ d_im.lemma(andNode, InferenceId::BAGS_FOLD);
+ Trace("bags::ppr") << "reduce(" << atom << ") = " << ret
+ << " such that:" << std::endl
+ << asserts << std::endl;
+ return TrustNode::mkTrustRewrite(atom, ret, nullptr);
+ }
default: return TrustNode::null();
}
}
@@ -131,9 +145,9 @@ TrustNode TheoryBags::expandChooseOperator(const Node& node,
return TrustNode::mkTrustRewrite(node, ret, nullptr);
}
-TrustNode TheoryBags::expandCardOperator(TNode n,
- std::vector<SkolemLemma>& vector)
+TrustNode TheoryBags::expandCardOperator(TNode n, std::vector<SkolemLemma>&)
{
+ Assert(n.getKind() == BAG_CARD);
if (d_env.getLogicInfo().isHigherOrder())
{
// (bag.card A) = (bag.count 1 (bag.map (lambda ((x E)) 1) A)),
diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h
index fd28482d4..1a8af780e 100644
--- a/src/theory/bags/theory_bags.h
+++ b/src/theory/bags/theory_bags.h
@@ -18,6 +18,7 @@
#ifndef CVC5__THEORY__BAGS__THEORY_BAGS_H
#define CVC5__THEORY__BAGS__THEORY_BAGS_H
+#include "theory/bags/bag_reduction.h"
#include "theory/bags/bag_solver.h"
#include "theory/bags/bags_rewriter.h"
#include "theory/bags/bags_statistics.h"
@@ -112,6 +113,9 @@ class TheoryBags : public Theory
/** the main solver for bags */
BagSolver d_solver;
+ /** bag reduction */
+ BagReduction d_bagReduction;
+
void eqNotifyNewClass(TNode n);
void eqNotifyMerge(TNode n1, TNode n2);
void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp
index 2623f3ed7..fe81fadf5 100644
--- a/src/theory/bags/theory_bags_type_rules.cpp
+++ b/src/theory/bags/theory_bags_type_rules.cpp
@@ -327,6 +327,57 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager,
return retType;
}
+TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
+ TNode n,
+ bool check)
+{
+ Assert(n.getKind() == kind::BAG_FOLD);
+ TypeNode functionType = n[0].getType(check);
+ TypeNode initialValueType = n[1].getType(check);
+ TypeNode bagType = n[2].getType(check);
+ if (check)
+ {
+ if (!bagType.isBag())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n,
+ "bag.fold operator expects a bag in the third argument, "
+ "a non-bag is found");
+ }
+
+ TypeNode elementType = bagType.getBagElementType();
+
+ if (!(functionType.isFunction()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " T2 T2) as a first argument. "
+ << "Found a term of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ std::vector<TypeNode> argTypes = functionType.getArgTypes();
+ TypeNode rangeType = functionType.getRangeType();
+ if (!(argTypes.size() == 2 && argTypes[0] == elementType
+ && argTypes[1] == rangeType))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " T2 T2). "
+ << "Found a function of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ if (rangeType != initialValueType)
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects an initial value of type "
+ << rangeType << ". Found a term of type '" << initialValueType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ TypeNode retType = n[0].getType().getRangeType();
+ return retType;
+}
+
Cardinality BagsProperties::computeCardinality(TypeNode type)
{
return Cardinality::INTEGERS;
diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h
index d7b8b2737..fa2f78313 100644
--- a/src/theory/bags/theory_bags_type_rules.h
+++ b/src/theory/bags/theory_bags_type_rules.h
@@ -132,6 +132,15 @@ struct BagMapTypeRule
static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
}; /* struct BagMapTypeRule */
+/**
+ * Type rule for (bag.fold f t A) to make sure f is a binary operation of type
+ * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
+ */
+struct BagFoldTypeRule
+{
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
struct BagsProperties
{
static Cardinality computeCardinality(TypeNode type);
diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp
index 82ae674e2..56d2f0500 100644
--- a/src/theory/inference_id.cpp
+++ b/src/theory/inference_id.cpp
@@ -118,6 +118,7 @@ const char* toString(InferenceId i)
case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE";
case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL";
case InferenceId::BAGS_MAP: return "BAGS_MAP";
+ case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h
index ad879d7ab..d98d3ff25 100644
--- a/src/theory/inference_id.h
+++ b/src/theory/inference_id.h
@@ -180,6 +180,7 @@ enum class InferenceId
BAGS_DIFFERENCE_REMOVE,
BAGS_DUPLICATE_REMOVAL,
BAGS_MAP,
+ BAGS_FOLD,
// ---------------------------------- end bags theory
// ---------------------------------- bitvector theory
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback