summaryrefslogtreecommitdiff
path: root/src/rewriter/mkrewrites.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/rewriter/mkrewrites.py')
-rw-r--r--src/rewriter/mkrewrites.py48
1 files changed, 45 insertions, 3 deletions
diff --git a/src/rewriter/mkrewrites.py b/src/rewriter/mkrewrites.py
index a0019308f..ae905dad9 100644
--- a/src/rewriter/mkrewrites.py
+++ b/src/rewriter/mkrewrites.py
@@ -16,6 +16,8 @@ def gen_kind(op):
Op.OR: 'OR',
Op.IMPLIES: 'IMPLIES',
Op.EQ: 'EQUAL',
+ Op.LAMBDA: 'LAMBDA',
+ Op.BOUND_VARS: 'BOUND_VAR_LIST',
Op.UMINUS: 'UMINUS',
Op.PLUS: 'PLUS',
Op.MINUS: 'MINUS',
@@ -122,7 +124,7 @@ def gen_mk_node(defns, expr):
def gen_rewrite_db_rule(defns, rule):
fvs_list = ', '.join(bvar.name for bvar in rule.bvars)
- fixed_point_arg = 'true' if rule.is_fixed_point else 'false'
+ fixed_point_arg = gen_mk_node(defns, rule.rhs_context) if rule.rhs_context else 'Node::null()'
return f'db.addRule(DslPfRule::{rule.get_enum()}, {{ {fvs_list} }}, {gen_mk_node(defns, rule.lhs)}, {gen_mk_node(defns, rule.rhs)}, {gen_mk_node(defns, rule.cond)}, {fixed_point_arg});'
@@ -147,6 +149,8 @@ def type_check(expr):
sort = None
if expr.op == Op.UMINUS:
sort = Sort(BaseSort.Int)
+ elif expr.op == Op.STRING_LENGTH:
+ sort = Sort(BaseSort.Int)
if sort:
sort.is_const = all(child.sort and child.sort.is_const
@@ -188,6 +192,43 @@ def validate_rule(rule):
type_check(rule.cond)
+def preprocess_rule(rule, decls):
+ if not rule.rhs_context:
+ return
+
+ # Resolve placeholders
+ bvar = Var(fresh_name('t'), rule.rhs.sort)
+ decls.append(bvar)
+ result = dict()
+ to_visit = [rule.rhs_context]
+ while to_visit:
+ curr = to_visit.pop()
+
+ if isinstance(curr, App) and curr in result:
+ if result[curr]:
+ continue
+
+ new_args = []
+ for child in curr.children:
+ new_args.append(result[child])
+
+ result[curr] = App(curr.op, new_args)
+ continue
+
+ if isinstance(curr, Placeholder):
+ result[curr] = bvar
+ elif isinstance(curr, App):
+ to_visit.append(curr)
+ result[curr] = None
+ else:
+ result[curr] = curr
+
+ to_visit.extend(curr.children)
+
+ rule.rhs_context = App(Op.LAMBDA, [App(Op.BOUND_VARS, [bvar]), result[rule.rhs_context]])
+ type_check(rule.rhs_context)
+
+
def gen_rewrite_db(args):
block_tpl = '''
{{
@@ -207,6 +248,7 @@ def gen_rewrite_db(args):
for rule in rules:
file_decls.extend(rule.bvars)
validate_rule(rule)
+ preprocess_rule(rule, file_decls)
rewrites.append(Rewrites(rewrites_file.name, file_decls, rules))
decls.extend(file_decls)
@@ -215,12 +257,12 @@ def gen_rewrite_db(args):
expr_counts = defaultdict(lambda: 0)
to_visit = [
expr for rewrite in rewrites for rule in rewrite.rules
- for expr in [rule.cond, rule.lhs, rule.rhs]
+ for expr in [rule.cond, rule.lhs, rule.rhs, rule.rhs_context]
]
while to_visit:
curr = to_visit.pop()
- if isinstance(curr, Var):
+ if not curr or isinstance(curr, Var):
# Don't generate definitions for variables
continue
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback