diff options
Diffstat (limited to 'src/rewriter/mkrewrites.py')
-rw-r--r-- | src/rewriter/mkrewrites.py | 48 |
1 files changed, 45 insertions, 3 deletions
diff --git a/src/rewriter/mkrewrites.py b/src/rewriter/mkrewrites.py index a0019308f..ae905dad9 100644 --- a/src/rewriter/mkrewrites.py +++ b/src/rewriter/mkrewrites.py @@ -16,6 +16,8 @@ def gen_kind(op): Op.OR: 'OR', Op.IMPLIES: 'IMPLIES', Op.EQ: 'EQUAL', + Op.LAMBDA: 'LAMBDA', + Op.BOUND_VARS: 'BOUND_VAR_LIST', Op.UMINUS: 'UMINUS', Op.PLUS: 'PLUS', Op.MINUS: 'MINUS', @@ -122,7 +124,7 @@ def gen_mk_node(defns, expr): def gen_rewrite_db_rule(defns, rule): fvs_list = ', '.join(bvar.name for bvar in rule.bvars) - fixed_point_arg = 'true' if rule.is_fixed_point else 'false' + fixed_point_arg = gen_mk_node(defns, rule.rhs_context) if rule.rhs_context else 'Node::null()' return f'db.addRule(DslPfRule::{rule.get_enum()}, {{ {fvs_list} }}, {gen_mk_node(defns, rule.lhs)}, {gen_mk_node(defns, rule.rhs)}, {gen_mk_node(defns, rule.cond)}, {fixed_point_arg});' @@ -147,6 +149,8 @@ def type_check(expr): sort = None if expr.op == Op.UMINUS: sort = Sort(BaseSort.Int) + elif expr.op == Op.STRING_LENGTH: + sort = Sort(BaseSort.Int) if sort: sort.is_const = all(child.sort and child.sort.is_const @@ -188,6 +192,43 @@ def validate_rule(rule): type_check(rule.cond) +def preprocess_rule(rule, decls): + if not rule.rhs_context: + return + + # Resolve placeholders + bvar = Var(fresh_name('t'), rule.rhs.sort) + decls.append(bvar) + result = dict() + to_visit = [rule.rhs_context] + while to_visit: + curr = to_visit.pop() + + if isinstance(curr, App) and curr in result: + if result[curr]: + continue + + new_args = [] + for child in curr.children: + new_args.append(result[child]) + + result[curr] = App(curr.op, new_args) + continue + + if isinstance(curr, Placeholder): + result[curr] = bvar + elif isinstance(curr, App): + to_visit.append(curr) + result[curr] = None + else: + result[curr] = curr + + to_visit.extend(curr.children) + + rule.rhs_context = App(Op.LAMBDA, [App(Op.BOUND_VARS, [bvar]), result[rule.rhs_context]]) + type_check(rule.rhs_context) + + def gen_rewrite_db(args): block_tpl = ''' {{ @@ -207,6 +248,7 @@ def gen_rewrite_db(args): for rule in rules: file_decls.extend(rule.bvars) validate_rule(rule) + preprocess_rule(rule, file_decls) rewrites.append(Rewrites(rewrites_file.name, file_decls, rules)) decls.extend(file_decls) @@ -215,12 +257,12 @@ def gen_rewrite_db(args): expr_counts = defaultdict(lambda: 0) to_visit = [ expr for rewrite in rewrites for rule in rewrite.rules - for expr in [rule.cond, rule.lhs, rule.rhs] + for expr in [rule.cond, rule.lhs, rule.rhs, rule.rhs_context] ] while to_visit: curr = to_visit.pop() - if isinstance(curr, Var): + if not curr or isinstance(curr, Var): # Don't generate definitions for variables continue |