summaryrefslogtreecommitdiff
path: root/src/theory/rewriter/ir.py
blob: 5aabfcc3c6589f042f60eaaf214c3a190ec02ca6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from collections import defaultdict

from node import collect_vars, count_vars, subst, BoolConst, Var

class CFGEdge:
    """An outgoing CFG edge guarded by ``cond`` that leads to block ``target``.

    ``target`` is the label of the destination block (a key of the CFG dict).
    """

    def __init__(self, cond, target):
        self.cond, self.target = cond, target

    def __repr__(self):
        return '{cond} -> {target}'.format(cond=self.cond, target=self.target)

class CFGNode:
    """A CFG block: a list of IR instructions plus its outgoing edges."""

    def __init__(self, instrs, edges):
        self.instrs, self.edges = instrs, edges

    def __repr__(self):
        # One line per instruction, followed by one line per outgoing edge.
        instr_lines = ['{}\n'.format(instr) for instr in self.instrs]
        edge_lines = [str(edge) + '\n' for edge in self.edges]
        return ''.join(instr_lines + edge_lines)

class IRNode:
    """Common base class of the IR instruction types (``Assign``, ``Assert``)."""

    def __init__(self):
        """Nothing to set up; subclasses store their own fields."""

class Assign(IRNode):
    """IR instruction ``name := expr``: bind variable ``name`` to ``expr``."""

    def __init__(self, name, expr):
        # super() must name *this* class; passing IRNode started the MRO
        # lookup after IRNode and silently skipped IRNode.__init__.
        super(Assign, self).__init__()
        # Assignment targets are always variables (node.Var).
        assert isinstance(name, Var)
        self.name = name
        self.expr = expr

    def __repr__(self):
        return '{} := {}'.format(self.name, self.expr)

class Assert(IRNode):
    """IR instruction ``assert expr``."""

    def __init__(self, expr):
        # super() must name *this* class; passing IRNode started the MRO
        # lookup after IRNode and silently skipped IRNode.__init__.
        super(Assert, self).__init__()
        self.expr = expr

    def __repr__(self):
        return 'assert {}'.format(self.expr)


def optimize_ir(out_var, instrs):
    """Remove assignments to variables that nothing reads.

    A variable counts as read if it is ``out_var`` or occurs in the expression
    of any ``Assert`` or ``Assign`` in ``instrs``. Every instruction that is
    not an ``Assign`` is kept, and kept instructions keep their order.

    This is a single pass. Reads made by assignments that are themselves
    removed still count. So in a chain of dead assignments, only the last
    link (the one nobody reads) is dropped.

    :param out_var: the result variable; its assignment is always kept.
    :param instrs: list of IR instructions; not modified.
    :returns: a new list without the unread assignments.
    """
    live = {out_var}
    for instr in instrs:
        if isinstance(instr, (Assert, Assign)):
            live |= collect_vars(instr.expr)

    return [
        instr
        for instr in instrs
        if not isinstance(instr, Assign) or instr.name in live
    ]


def optimize_cfg(out_var, entry, cfg):
    """Simplify a control-flow graph in place.

    Three passes run in order:

    1. A block whose only edge is unconditional absorbs its successor's
       instructions and outgoing edges, repeated until nothing changes.
    2. Blocks not reachable from ``entry`` are deleted from ``cfg``.
    3. Assignments to variables that are read at most once in instruction
       expressions are substituted into their uses and removed.
       The assignment to ``out_var`` is never removed.

    :param out_var: the result variable; never inlined away.
    :param entry: label of the entry block; must be a key of ``cfg``.
    :param cfg: dict from block label to a block object with ``instrs`` and
                ``edges`` attributes (e.g. CFGNode); modified in place.
    :returns: None.
    """
    _merge_blocks(cfg)
    _remove_unreachable(entry, cfg)
    _inline_single_use(out_var, cfg)


def _merge_blocks(cfg):
    """Fold single unconditional successors into their predecessor, to a fixpoint."""
    # Successor labels already folded into each block. Folding the same
    # successor into a block twice, or a block into itself, is refused.
    # This guarantees termination when unconditional edges form a cycle:
    # every merge adds a new label to a finite set.
    absorbed = defaultdict(set)
    change = True
    while change:
        change = False
        for label, node in cfg.items():
            if len(node.edges) != 1:
                continue
            assert node.edges[0].cond == BoolConst(True)
            target = node.edges[0].target
            if target == label or target in absorbed[label]:
                continue
            next_block = cfg[target]
            absorbed[label].add(target)
            node.instrs += next_block.instrs
            # Copy so the two blocks do not share one edge list object.
            node.edges = list(next_block.edges)
            change = True


def _remove_unreachable(entry, cfg):
    """Delete every block of ``cfg`` that cannot be reached from ``entry``."""
    reachable = set()
    to_visit = [entry]
    while to_visit:
        curr = to_visit.pop()
        if curr in reachable:
            # Join points and loops put a block on the worklist more than once.
            continue
        reachable.add(curr)
        to_visit.extend(edge.target for edge in cfg[curr].edges)

    for label in set(cfg) - reachable:
        del cfg[label]


def _inline_single_use(out_var, cfg):
    """Substitute assignments read at most once into their uses, then drop them."""
    used_count = defaultdict(int)
    substs = dict()
    for node in cfg.values():
        for instr in node.instrs:
            if isinstance(instr, Assert):
                count_vars(used_count, instr.expr)
            elif isinstance(instr, Assign):
                count_vars(used_count, instr.expr)
                # NOTE(review): if a name is assigned in several places, only
                # the last expression seen is kept here. This assumes each
                # name is assigned once — confirm against the CFG builder.
                substs[instr.name] = instr.expr

    # The output variable must stay a real assignment. It may not be assigned
    # at all, so tolerate its absence instead of raising KeyError.
    substs.pop(out_var, None)

    # Only keep substitutions for variables that are unused or read once.
    # NOTE(review): reads in edge conditions are not counted here, although
    # the substitution below is also applied to edge conditions.
    substs = {
        name: expr for name, expr in substs.items() if used_count[name] <= 1
    }

    # Close the substitutions over each other. Each round resolves one more
    # level of nesting, and (for non-cyclic assignments) a dependency chain
    # among n substitutions has at most n links, so n rounds suffice.
    for _ in range(len(substs)):
        substs = {name: subst(expr, substs) for name, expr in substs.items()}

    # Drop the inlined assignments and rewrite every remaining use.
    for node in cfg.values():
        node.instrs = [
            instr
            for instr in node.instrs
            if not isinstance(instr, Assign) or instr.name not in substs
        ]
        for instr in node.instrs:
            instr.expr = subst(instr.expr, substs)

        for edge in node.edges:
            edge.cond = subst(edge.cond, substs)

def cfg_collect_vars(cfg):
    """Return the set of variables assigned by any ``Assign`` in ``cfg``.

    Only assignment targets are collected; variables that are only read do
    not appear in the result.
    """
    return {
        instr.name
        for node in cfg.values()
        for instr in node.instrs
        if isinstance(instr, Assign)
    }

def cfg_to_str(cfg):
    """Render ``cfg`` as text.

    Each block is printed as its label and a colon on one line, then
    ``str(node)``, then an extra newline, in the dict's iteration order.
    """
    chunks = []
    for label, node in cfg.items():
        chunks.append('{}:\n'.format(label))
        chunks.append(str(node))
        chunks.append('\n')
    return ''.join(chunks)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback