summaryrefslogtreecommitdiff
path: root/src/theory/strings/array_solver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/theory/strings/array_solver.cpp')
-rw-r--r--src/theory/strings/array_solver.cpp96
1 files changed, 77 insertions, 19 deletions
diff --git a/src/theory/strings/array_solver.cpp b/src/theory/strings/array_solver.cpp
index 09e3aefdd..65ff4cde4 100644
--- a/src/theory/strings/array_solver.cpp
+++ b/src/theory/strings/array_solver.cpp
@@ -15,8 +15,10 @@
#include "theory/strings/array_solver.h"
+#include "expr/sequence.h"
#include "theory/strings/arith_entail.h"
#include "theory/strings/theory_strings_utils.h"
+#include "theory/strings/word.h"
#include "util/rational.h"
using namespace cvc5::context;
@@ -42,7 +44,7 @@ ArraySolver::ArraySolver(Env& env,
d_eqProc(context())
{
NodeManager* nm = NodeManager::currentNM();
- d_zero = nm->mkConst(Rational(0));
+ d_zero = nm->mkConst(CONST_RATIONAL, Rational(0));
}
ArraySolver::~ArraySolver() {}
@@ -81,6 +83,7 @@ void ArraySolver::checkTerms(Kind k)
Node r = d_state.getRepresentative(t[0]);
NormalForm& nf = d_csolver.getNormalForm(r);
Trace("seq-array-debug") << "...normal form " << nf.d_nf << std::endl;
+ std::vector<Node> nfChildren;
if (nf.d_nf.empty())
{
// updates should have been reduced (UPD_EMPTYSTR)
@@ -92,8 +95,16 @@ void ArraySolver::checkTerms(Kind k)
{
Trace("seq-array-debug") << "...norm form size 1" << std::endl;
// NOTE: could split on n=0 if needed, do not introduce ITE
- if (nf.d_nf[0].getKind() == SEQ_UNIT)
+ Kind ck = nf.d_nf[0].getKind();
+ // Note that (seq.unit c) is rewritten to CONST_SEQUENCE{c}, hence we
+ // check two cases here. It is important for completeness of this schema
+ // to handle this differently from STRINGS_ARRAY_UPDATE_CONCAT /
+ // STRINGS_ARRAY_NTH_CONCAT. Otherwise we would conclude a trivial
+ // equality when update/nth is applied to a constant of length one.
+ if (ck == SEQ_UNIT
+ || (ck == CONST_SEQUENCE && Word::getLength(nf.d_nf[0]) == 1))
{
+ Trace("seq-array-debug") << "...unit case" << std::endl;
// do we know whether n = 0 ?
// x = (seq.unit m) => (seq.update x n z) = ite(n=0, z, (seq.unit m))
// x = (seq.unit m) => (seq.nth x n) = ite(n=0, m, Uf(x, n))
@@ -109,7 +120,15 @@ void ArraySolver::checkTerms(Kind k)
else
{
Assert(k == SEQ_NTH);
- thenBranch = nf.d_nf[0][0];
+ if (ck == CONST_SEQUENCE)
+ {
+ const Sequence& seq = nf.d_nf[0].getConst<Sequence>();
+ thenBranch = seq.getVec()[0];
+ }
+ else
+ {
+ thenBranch = nf.d_nf[0][0];
+ }
Node uf = SkolemCache::mkSkolemSeqNth(t[0].getType(), "Uf");
elseBranch = nm->mkNode(APPLY_UF, uf, t[0], t[1]);
iid = InferenceId::STRINGS_ARRAY_NTH_UNIT;
@@ -126,17 +145,33 @@ void ArraySolver::checkTerms(Kind k)
d_eqProc.insert(eq);
d_im.sendInference(exp, eq, iid);
}
+ continue;
}
- // otherwise, the equivalence class is pure wrt concatenation
- d_currTerms[k].push_back(t);
- continue;
+ else if (ck != CONST_SEQUENCE)
+ {
+ // otherwise, if the normal form is not a constant sequence, the
+ // equivalence class is pure wrt concatenation.
+ d_currTerms[k].push_back(t);
+ continue;
+ }
+ // if the normal form is a constant sequence, it is treated as a
+ // concatenation. We split per character and case split on whether the
+ // nth/update falls on each character below, which must have a size
+ // greater than one.
+ std::vector<Node> chars = Word::getChars(nf.d_nf[0]);
+ Assert (chars.size()>1);
+ nfChildren.insert(nfChildren.end(), chars.begin(), chars.end());
+ }
+ else
+ {
+ nfChildren.insert(nfChildren.end(), nf.d_nf.begin(), nf.d_nf.end());
}
// otherwise, we are the concatenation of the components
// NOTE: for nth, split on index vs component lengths, do not introduce ITE
std::vector<Node> cond;
std::vector<Node> cchildren;
std::vector<Node> lacc;
- for (const Node& c : nf.d_nf)
+ for (const Node& c : nfChildren)
{
Trace("seq-array-debug") << "...process " << c << std::endl;
Node clen = nm->mkNode(STRING_LENGTH, c);
@@ -146,26 +181,49 @@ void ArraySolver::checkTerms(Kind k)
Node currSum = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc);
currIndex = nm->mkNode(MINUS, currIndex, currSum);
}
- if (k == STRING_UPDATE)
+ Node cc;
+ // If it is a constant of length one, then the update/nth is determined
+ // in this interval. Notice this is done here as
+ // an optimization to short cut introducing terms like
+ // (seq.nth (seq.unit c) i), which by construction is only relevant in
+ // the context where i = 0, hence we replace by c here.
+ if (c.getKind() == CONST_SEQUENCE)
{
- Node cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]);
- Trace("seq-array-debug") << "......component " << cc << std::endl;
- cchildren.push_back(cc);
+ const Sequence& seq = c.getConst<Sequence>();
+ if (seq.size() == 1)
+ {
+ if (k == STRING_UPDATE)
+ {
+ cc = nm->mkNode(ITE, t[1].eqNode(d_zero), t[2], c);
+ }
+ else
+ {
+ cc = seq.getVec()[0];
+ }
+ }
}
- else
+ // if we did not process as a constant of length one
+ if (cc.isNull())
{
- Assert(k == SEQ_NTH);
- Node cc = nm->mkNode(SEQ_NTH, c, currIndex);
- Trace("seq-array-debug") << "......component " << cc << std::endl;
- cchildren.push_back(cc);
+ if (k == STRING_UPDATE)
+ {
+ cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]);
+ }
+ else
+ {
+ Assert(k == SEQ_NTH);
+ cc = nm->mkNode(SEQ_NTH, c, currIndex);
+ }
}
+ Trace("seq-array-debug") << "......component " << cc << std::endl;
+ cchildren.push_back(cc);
lacc.push_back(clen);
if (k == SEQ_NTH)
{
Node currSumPost = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc);
- Node cc = nm->mkNode(LT, t[1], currSumPost);
- Trace("seq-array-debug") << "......condition " << cc << std::endl;
- cond.push_back(cc);
+ Node cf = nm->mkNode(LT, t[1], currSumPost);
+ Trace("seq-array-debug") << "......condition " << cf << std::endl;
+ cond.push_back(cf);
}
}
// z = (seq.++ x y) =>
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback