diff options
Diffstat (limited to 'src/theory/arrays/theory_arrays_rewriter.cpp')
-rw-r--r-- | src/theory/arrays/theory_arrays_rewriter.cpp | 103 |
1 files changed, 59 insertions, 44 deletions
diff --git a/src/theory/arrays/theory_arrays_rewriter.cpp b/src/theory/arrays/theory_arrays_rewriter.cpp index 1072ffaf4..dd7a56d33 100644 --- a/src/theory/arrays/theory_arrays_rewriter.cpp +++ b/src/theory/arrays/theory_arrays_rewriter.cpp @@ -20,6 +20,9 @@ #include "expr/array_store_all.h" #include "expr/attribute.h" +#include "proof/conv_proof_generator.h" +#include "proof/eager_proof_generator.h" +#include "theory/arrays/skolem_cache.h" #include "util/cardinality.h" namespace cvc5 { @@ -48,6 +51,11 @@ void setMostFrequentValueCount(TNode store, uint64_t count) { return store.setAttribute(ArrayConstantMostFrequentValueCountAttr(), count); } +TheoryArraysRewriter::TheoryArraysRewriter(ProofNodeManager* pnm) + : d_epg(pnm ? new EagerProofGenerator(pnm) : nullptr) +{ +} + Node TheoryArraysRewriter::normalizeConstant(TNode node) { return normalizeConstant(node, node[1].getType().getCardinality()); @@ -271,6 +279,48 @@ Node TheoryArraysRewriter::normalizeConstant(TNode node, Cardinality indexCard) return n; } +Node TheoryArraysRewriter::expandEqRange(TNode node) +{ + Assert(node.getKind() == kind::EQ_RANGE); + + NodeManager* nm = NodeManager::currentNM(); + TNode a = node[0]; + TNode b = node[1]; + TNode i = node[2]; + TNode j = node[3]; + Node k = SkolemCache::getEqRangeVar(node); + Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); + TypeNode type = k.getType(); + + Kind kle; + Node range; + if (type.isBitVector()) + { + kle = kind::BITVECTOR_ULE; + } + else if (type.isFloatingPoint()) + { + kle = kind::FLOATINGPOINT_LEQ; + } + else if (type.isInteger() || type.isReal()) + { + kle = kind::LEQ; + } + else + { + Unimplemented() << "Type " << type << " is not supported for predicate " + << node.getKind(); + } + + range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); + + Node eq = nm->mkNode(kind::EQUAL, + nm->mkNode(kind::SELECT, a, k), + nm->mkNode(kind::SELECT, b, k)); + Node implies = nm->mkNode(kind::IMPLIES, range, eq); + return nm->mkNode(kind::FORALL, bvl, implies); +} + RewriteResponse TheoryArraysRewriter::postRewrite(TNode node) { Trace("arrays-postrewrite") @@ -610,57 +660,22 @@ RewriteResponse TheoryArraysRewriter::preRewrite(TNode node) TrustNode TheoryArraysRewriter::expandDefinition(Node node) { - NodeManager* nm = NodeManager::currentNM(); Kind kind = node.getKind(); - /* Expand - * - * (eqrange a b i j) - * - * to - * - * forall k . i <= k <= j => a[k] = b[k] - * - */ if (kind == kind::EQ_RANGE) { - TNode a = node[0]; - TNode b = node[1]; - TNode i = node[2]; - TNode j = node[3]; - Node k = nm->mkBoundVar(i.getType()); - Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); - TypeNode type = k.getType(); - - Kind kle; - Node range; - if (type.isBitVector()) - { - kle = kind::BITVECTOR_ULE; - } - else if (type.isFloatingPoint()) + Node expandedEqRange = expandEqRange(node); + if (d_epg) { - kle = kind::FLOATINGPOINT_LEQ; + TrustNode tn = d_epg->mkTrustNode(node.eqNode(expandedEqRange), + PfRule::ARRAYS_EQ_RANGE_EXPAND, + {}, + {node}); + return TrustNode::mkTrustRewrite(node, expandedEqRange, d_epg.get()); } - else if (type.isInteger() || type.isReal()) - { - kle = kind::LEQ; - } - else - { - Unimplemented() << "Type " << type << " is not supported for predicate " - << kind; - } - - range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); - - Node eq = nm->mkNode(kind::EQUAL, - nm->mkNode(kind::SELECT, a, k), - nm->mkNode(kind::SELECT, b, k)); - Node implies = nm->mkNode(kind::IMPLIES, range, eq); - Node ret = nm->mkNode(kind::FORALL, bvl, implies); - return TrustNode::mkTrustRewrite(node, ret, nullptr); + return TrustNode::mkTrustRewrite(node, expandedEqRange, nullptr); } + return TrustNode::null(); } |