summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2021-10-05 14:42:41 -0700
committerAndres Noetzli <andres.noetzli@gmail.com>2021-10-05 14:42:41 -0700
commit7a07fefb1b2bbdae0aef64a2dfff5a0e17e0998f (patch)
treee37ed26890ea6a2e8e5a71c727b69e19e6e3f4e4
parent8eabbfd08f54061ceb3e679f0726b89c3a27cb69 (diff)
-rw-r--r--src/rewriter/mkrewrites.py48
-rw-r--r--src/rewriter/node.py16
-rw-r--r--src/rewriter/parser.py13
-rw-r--r--src/rewriter/rewrite_db.cpp5
-rw-r--r--src/rewriter/rewrite_db.h2
-rw-r--r--src/rewriter/rewrite_db_proof_cons.cpp19
-rw-r--r--src/rewriter/rewrite_proof_rule.cpp10
-rw-r--r--src/rewriter/rewrite_proof_rule.h8
-rw-r--r--src/rewriter/rule.py8
-rw-r--r--src/theory/strings/rewrites8
10 files changed, 107 insertions, 30 deletions
diff --git a/src/rewriter/mkrewrites.py b/src/rewriter/mkrewrites.py
index a0019308f..ae905dad9 100644
--- a/src/rewriter/mkrewrites.py
+++ b/src/rewriter/mkrewrites.py
@@ -16,6 +16,8 @@ def gen_kind(op):
Op.OR: 'OR',
Op.IMPLIES: 'IMPLIES',
Op.EQ: 'EQUAL',
+ Op.LAMBDA: 'LAMBDA',
+ Op.BOUND_VARS: 'BOUND_VAR_LIST',
Op.UMINUS: 'UMINUS',
Op.PLUS: 'PLUS',
Op.MINUS: 'MINUS',
@@ -122,7 +124,7 @@ def gen_mk_node(defns, expr):
def gen_rewrite_db_rule(defns, rule):
fvs_list = ', '.join(bvar.name for bvar in rule.bvars)
- fixed_point_arg = 'true' if rule.is_fixed_point else 'false'
+ fixed_point_arg = gen_mk_node(defns, rule.rhs_context) if rule.rhs_context else 'Node::null()'
return f'db.addRule(DslPfRule::{rule.get_enum()}, {{ {fvs_list} }}, {gen_mk_node(defns, rule.lhs)}, {gen_mk_node(defns, rule.rhs)}, {gen_mk_node(defns, rule.cond)}, {fixed_point_arg});'
@@ -147,6 +149,8 @@ def type_check(expr):
sort = None
if expr.op == Op.UMINUS:
sort = Sort(BaseSort.Int)
+ elif expr.op == Op.STRING_LENGTH:
+ sort = Sort(BaseSort.Int)
if sort:
sort.is_const = all(child.sort and child.sort.is_const
@@ -188,6 +192,43 @@ def validate_rule(rule):
type_check(rule.cond)
+def preprocess_rule(rule, decls):
+ if not rule.rhs_context:
+ return
+
+ # Resolve placeholders
+ bvar = Var(fresh_name('t'), rule.rhs.sort)
+ decls.append(bvar)
+ result = dict()
+ to_visit = [rule.rhs_context]
+ while to_visit:
+ curr = to_visit.pop()
+
+ if isinstance(curr, App) and curr in result:
+ if result[curr]:
+ continue
+
+ new_args = []
+ for child in curr.children:
+ new_args.append(result[child])
+
+ result[curr] = App(curr.op, new_args)
+ continue
+
+ if isinstance(curr, Placeholder):
+ result[curr] = bvar
+ elif isinstance(curr, App):
+ to_visit.append(curr)
+ result[curr] = None
+ else:
+ result[curr] = curr
+
+ to_visit.extend(curr.children)
+
+ rule.rhs_context = App(Op.LAMBDA, [App(Op.BOUND_VARS, [bvar]), result[rule.rhs_context]])
+ type_check(rule.rhs_context)
+
+
def gen_rewrite_db(args):
block_tpl = '''
{{
@@ -207,6 +248,7 @@ def gen_rewrite_db(args):
for rule in rules:
file_decls.extend(rule.bvars)
validate_rule(rule)
+ preprocess_rule(rule, file_decls)
rewrites.append(Rewrites(rewrites_file.name, file_decls, rules))
decls.extend(file_decls)
@@ -215,12 +257,12 @@ def gen_rewrite_db(args):
expr_counts = defaultdict(lambda: 0)
to_visit = [
expr for rewrite in rewrites for rule in rewrite.rules
- for expr in [rule.cond, rule.lhs, rule.rhs]
+ for expr in [rule.cond, rule.lhs, rule.rhs, rule.rhs_context]
]
while to_visit:
curr = to_visit.pop()
- if isinstance(curr, Var):
+ if not curr or isinstance(curr, Var):
# Don't generate definitions for variables
continue
diff --git a/src/rewriter/node.py b/src/rewriter/node.py
index 458ba7c03..8676cac00 100644
--- a/src/rewriter/node.py
+++ b/src/rewriter/node.py
@@ -90,6 +90,8 @@ class Op(Enum):
EQ = auto()
ITE = auto()
+ LAMBDA = auto()
+ BOUND_VARS = auto()
###########################################################################
# Strings
@@ -197,6 +199,20 @@ class Sort(Node):
return self.base == BaseSort.Int
+class Placeholder(Node):
+ def __init__(self):
+ super().__init__([], None)
+
+ def __eq__(self, other):
+ return isinstance(other, Placeholder)
+
+ def __hash__(self):
+ return hash('_')
+
+ def __repr__(self):
+ return '_'
+
+
class Var(Node):
def __init__(self, name, sort=None):
super().__init__([], sort)
diff --git a/src/rewriter/parser.py b/src/rewriter/parser.py
index e69350111..e6c76a285 100644
--- a/src/rewriter/parser.py
+++ b/src/rewriter/parser.py
@@ -104,6 +104,7 @@ class SymbolTable:
self.consts = {
're.none': App(Op.REGEXP_EMPTY, []),
're.allchar': App(Op.REGEXP_SIGMA, []),
+ '_': Placeholder(),
}
self.symbols = {}
@@ -250,25 +251,25 @@ class Parser:
lambda s, l, t: self.var_decl_action(t[0], t[1], t[2:])))
return (pp.Suppress('(') + pp.ZeroOrMore(decl) + pp.Suppress(')'))
- def rule_action(self, name, cond, lhs, rhs, is_fixed_point):
+ def rule_action(self, name, cond, lhs, rhs, is_fixed_point, rhs_context):
bvars = self.symbols.symbols.values()
self.symbols.pop()
- return Rule(name, bvars, cond, lhs, rhs, is_fixed_point)
+ return Rule(name, bvars, cond, lhs, rhs, is_fixed_point, rhs_context)
def parse_rules(self, s):
rule = (
pp.Suppress('(') +
(pp.Keyword('define-rule*') | pp.Keyword('define-rule')) +
- self.symbol() + self.var_list() + self.expr() + self.expr() +
+ self.symbol() + self.var_list() + self.expr() + self.expr() + pp.Optional(self.expr()) +
pp.Suppress(')')).setParseAction(lambda s, l, t: self.rule_action(
- t[1], CBool(True), t[2], t[3], t[0] == 'define-rule*'))
+ t[1], CBool(True), t[2], t[3], t[0] == 'define-rule*', t[4] if len(t) == 5 else None))
cond_rule = (
pp.Suppress('(') +
(pp.Keyword('define-cond-rule*') | pp.Keyword('define-cond-rule'))
+ self.symbol() + self.var_list() + self.expr() + self.expr() +
- self.expr() +
+ self.expr() + pp.Optional(self.expr()) +
pp.Suppress(')')).setParseAction(lambda s, l, t: self.rule_action(
- t[1], t[2], t[3], t[4], t[0] == 'define-cond-rule*'))
+ t[1], t[2], t[3], t[4], t[0] == 'define-cond-rule*', t[5] if len(t) == 6 else None))
rules = pp.OneOrMore(rule | cond_rule) + pp.StringEnd()
rules.ignore(';' + pp.restOfLine)
return rules.parseString(s)
diff --git a/src/rewriter/rewrite_db.cpp b/src/rewriter/rewrite_db.cpp
index 15252ae9e..cd6abde37 100644
--- a/src/rewriter/rewrite_db.cpp
+++ b/src/rewriter/rewrite_db.cpp
@@ -118,7 +118,7 @@ void RewriteDb::addRule(DslPfRule id,
Node a,
Node b,
Node cond,
- bool isFixedPoint,
+ Node context,
bool isFlatForm)
{
NodeManager* nm = NodeManager::currentNM();
@@ -154,6 +154,7 @@ void RewriteDb::addRule(DslPfRule id,
<< " == " << b << std::endl;
Assert(a.getType().isComparableTo(b.getType()));
Node cr = d_canon.getCanonicalTerm(tmpi, false, false);
+ context = d_canon.getCanonicalTerm(context, false, false);
Node condC = cr[1];
std::vector<Node> conds;
@@ -229,7 +230,7 @@ void RewriteDb::addRule(DslPfRule id,
}
// initialize rule
- d_rewDbRule[id].init(id, ofvs, cfvs, conds, eqC, isFixedPoint, isFlatForm);
+ d_rewDbRule[id].init(id, ofvs, cfvs, conds, eqC, context, isFlatForm);
d_concToRules[eqC].push_back(id);
d_headToRules[eqC[0]].push_back(id);
}
diff --git a/src/rewriter/rewrite_db.h b/src/rewriter/rewrite_db.h
index 6725b864f..819ae0fb8 100644
--- a/src/rewriter/rewrite_db.h
+++ b/src/rewriter/rewrite_db.h
@@ -56,7 +56,7 @@ class RewriteDb
Node a,
Node b,
Node cond,
- bool isFixedPoint,
+ Node context,
bool isFlatForm = false);
/** get matches */
void getMatches(Node eq, expr::NotifyMatch* ntm);
diff --git a/src/rewriter/rewrite_db_proof_cons.cpp b/src/rewriter/rewrite_db_proof_cons.cpp
index 11b1624f1..e1282e0ba 100644
--- a/src/rewriter/rewrite_db_proof_cons.cpp
+++ b/src/rewriter/rewrite_db_proof_cons.cpp
@@ -715,7 +715,7 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr,
d_currFixedPointId = rpr.getId();
// check if stgt also rewrites with the same rule?
bool continueFixedPoint;
- std::vector<Node> transEq;
+ std::vector<Node> steps;
// start from the source, match again to start the chain. Notice this is
// required for uniformity since we want to successfully cache the first
// step, independent of the target.
@@ -728,13 +728,26 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr,
if (!d_currFixedPointConc.isNull())
{
// currently avoid accidental loops: arbitrarily bound to 1000
- continueFixedPoint = transEq.size() <= 1000;
+ continueFixedPoint = steps.size() <= 1000;
Assert(d_currFixedPointConc.getKind() == EQUAL);
- transEq.push_back(stgt.eqNode(d_currFixedPointConc[1]));
+ steps.push_back(d_currFixedPointConc[1]);
stgt = d_currFixedPointConc[1];
}
d_currFixedPointConc = Node::null();
} while (continueFixedPoint);
+
+ std::vector<Node> transEq;
+ Node prev = ssrc;
+ Node context = rpr.getContext();
+ Node placeholder = context[0][0];
+ for (Node& step : steps) {
+ Node stepConc = context[1].substitute(TNode(placeholder), TNode(step));
+ stepConc = expr::narySubstitute(stepConc, vars, subs);
+ transEq.push_back(prev.eqNode(stepConc));
+ prev = stepConc;
+ }
+ std::cout << transEq << std::endl;
+
d_currFixedPointId = DslPfRule::FAIL;
// add the transistivity rule here if needed
if (transEq.size() >= 2)
diff --git a/src/rewriter/rewrite_proof_rule.cpp b/src/rewriter/rewrite_proof_rule.cpp
index ace56f9fe..2f5ddf400 100644
--- a/src/rewriter/rewrite_proof_rule.cpp
+++ b/src/rewriter/rewrite_proof_rule.cpp
@@ -40,7 +40,7 @@ bool getDslPfRule(TNode n, DslPfRule& id)
}
RewriteProofRule::RewriteProofRule()
- : d_id(DslPfRule::FAIL), d_isFixedPoint(false), d_isFlatForm(false)
+ : d_id(DslPfRule::FAIL), d_isFlatForm(false)
{
}
@@ -49,7 +49,7 @@ void RewriteProofRule::init(DslPfRule id,
const std::vector<Node>& fvs,
const std::vector<Node>& cond,
Node conc,
- bool isFixedPoint,
+ Node context,
bool isFlatForm)
{
// not initialized yet
@@ -73,7 +73,7 @@ void RewriteProofRule::init(DslPfRule id,
d_obGen.push_back(cc);
}
d_conc = conc;
- d_isFixedPoint = isFixedPoint;
+ d_context = context;
d_isFlatForm = isFlatForm;
if (!expr::getListVarContext(conc, d_listVarCtx))
{
@@ -97,7 +97,7 @@ void RewriteProofRule::init(DslPfRule id,
}
}
// if fixed point, initialize match utility
- if (d_isFixedPoint)
+ if (d_context != Node::null())
{
d_mt.addTerm(conc[0]);
}
@@ -266,7 +266,7 @@ Node RewriteProofRule::getConclusionFor(const std::vector<Node>& ss) const
Assert(d_fvs.size() == ss.size());
return expr::narySubstitute(d_conc, d_fvs, ss);
}
-bool RewriteProofRule::isFixedPoint() const { return d_isFixedPoint; }
+bool RewriteProofRule::isFixedPoint() const { return d_context != Node::null(); }
bool RewriteProofRule::isFlatForm() const { return d_isFlatForm; }
} // namespace rewriter
} // namespace cvc5
diff --git a/src/rewriter/rewrite_proof_rule.h b/src/rewriter/rewrite_proof_rule.h
index 91279b2e7..c5cc475d5 100644
--- a/src/rewriter/rewrite_proof_rule.h
+++ b/src/rewriter/rewrite_proof_rule.h
@@ -61,7 +61,7 @@ class RewriteProofRule
const std::vector<Node>& fvs,
const std::vector<Node>& cond,
Node conc,
- bool isFixedPoint,
+ Node context,
bool isFlatForm);
/** get id */
DslPfRule getId() const;
@@ -71,6 +71,8 @@ class RewriteProofRule
const std::vector<Node>& getUserVarList() const;
/** Get variable list */
const std::vector<Node>& getVarList() const;
+ /** The context that the rule is applied in */
+ Node getContext() const { return d_context; }
/** Does this rule have conditions? */
bool hasConditions() const;
/** Get (declared) conditions */
@@ -134,8 +136,8 @@ class RewriteProofRule
std::vector<Node> d_obGen;
/** The conclusion of the rule (an equality) */
Node d_conc;
- /** Is the rule applied until a fixed point is reached? */
- bool d_isFixedPoint;
+ /** Is the rule applied in some fixed point context? */
+ Node d_context;
/** Whether the rule is in flat form */
bool d_isFlatForm;
/** the ordered list of free variables, provided by the user */
diff --git a/src/rewriter/rule.py b/src/rewriter/rule.py
index c3ef1f91e..164bf8dc1 100644
--- a/src/rewriter/rule.py
+++ b/src/rewriter/rule.py
@@ -1,15 +1,17 @@
class Rule:
- def __init__(self, name, bvars, cond, lhs, rhs, is_fixed_point):
+ def __init__(self, name, bvars, cond, lhs, rhs, is_fixed_point, rhs_context):
self.name = name
self.bvars = bvars
self.cond = cond
self.lhs = lhs
self.rhs = rhs
self.is_fixed_point = is_fixed_point
+ self.rhs_context = rhs_context
def get_enum(self):
return self.name.replace('-', '_').upper()
def __repr__(self):
- bvars_str = ' '.join(str(bvar) for bvar in bvars)
- return f"(define-rule {self.name} ({bvars_str}) {self.lhs} {self.rhs})"
+ bvars_str = ' '.join(str(bvar) for bvar in self.bvars)
+ rhs_context_str = f' {self.rhs_context}' if self.rhs_context else ''
+ return f"(define-rule {self.name} ({bvars_str}) {self.lhs} {self.rhs}{rhs_context_str})"
diff --git a/src/theory/strings/rewrites b/src/theory/strings/rewrites
index 1e0494ca6..8d14e1f22 100644
--- a/src/theory/strings/rewrites
+++ b/src/theory/strings/rewrites
@@ -23,10 +23,6 @@
(= (str.substr s n m) "")
(= s ""))
-(define-rule* str-len-concat ((s1 String) (s2 String) (s3 String :list))
- (str.len (str.++ s1 s2 s3))
- (+ (str.len s1) (str.len (str.++ s2 s3))))
-
(define-cond-rule str-len-replace-inv ((t String) (s String) (r String))
(= (str.len s) (str.len r))
(str.len (str.replace t s r))
@@ -97,3 +93,7 @@
(define-rule re-concat-star-swap ((xs RegLan :list) (r RegLan) (ys RegLan :list)) (re.++ xs (re.* r) r ys) (re.++ xs r (re.* r) ys))
+(define-rule* str-len-concat-rec ((s1 String) (s2 String) (s3 String :list))
+ (str.len (str.++ s1 s2 s3))
+ (str.len (str.++ s2 s3))
+ (+ (str.len s1) _))
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback