From eb5d13fbb251b6b87579086566c18e9050b858ac Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Mon, 13 Apr 2020 01:12:03 -0700 Subject: LHS eval + optimization --- src/proof/proof_manager.cpp | 6 +-- src/proof/rewrite_proof.cpp | 4 +- src/theory/bv/theory_bv_rewriter.cpp | 8 ++++ src/theory/rewriter/compiler.py | 79 ++++++++++++++++++++++------------- src/theory/rewriter/ir.py | 62 +++++++++++++++++++++++---- src/theory/rewriter/node.py | 30 +++++++++++++ src/theory/rewriter/parser.py | 2 +- src/theory/rewriter/rules/basic.rules | 12 +++++- 8 files changed, 159 insertions(+), 44 deletions(-) diff --git a/src/proof/proof_manager.cpp b/src/proof/proof_manager.cpp index 419095cf3..f59c6cd0d 100644 --- a/src/proof/proof_manager.cpp +++ b/src/proof/proof_manager.cpp @@ -878,20 +878,20 @@ void LFSCProof::printPreprocessedAssertions(const NodeSet& assertions, << std::endl; os << "(th_let_pf _ "; - RewriteProof rp; + // RewriteProof rp; if ((*it).getKind() == kind::NOT && (*it)[0] == NodeManager::currentNM()->mkConst(false)) { os << "t_eq_n_f "; } - else if (theory::Rewriter::rewriteWithProof(inputAssertion, &rp) == *it) + /*else if (theory::Rewriter::rewriteWithProof(inputAssertion, &rp) == *it) { theory::rules::RewriteProofPrinter::printProof( ProofManager::currentPM()->getTheoryProofEngine(), rp, os, globalLetMap); - } + }*/ else { os << "(trust_f (iff "; diff --git a/src/proof/rewrite_proof.cpp b/src/proof/rewrite_proof.cpp index e0b7c171c..dfb0e6f4d 100644 --- a/src/proof/rewrite_proof.cpp +++ b/src/proof/rewrite_proof.cpp @@ -204,8 +204,8 @@ void RewriteProof::printCachedProofs(TheoryProofEngine* tp, { os << std::endl; os << "(@ let" << iter->second->d_id << " "; - theory::rules::RewriteProofPrinter::printRewriteProof( - false, tp, iter->second, os, globalLetMap); + //theory::rules::RewriteProofPrinter::printRewriteProof( + // false, tp, iter->second, os, globalLetMap); paren << ")"; } } diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index c4975d710..25db555fa 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -56,11 +56,13 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { } RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) { + /* RewriteResponse response = rules::UltZero(node); if (response.d_node != node) { return response; } + */ // reduce common subexpressions on both sides Node resultNode = LinearRewriteStrategy @@ -418,11 +420,13 @@ RewriteResponse TheoryBVRewriter::RewriteSub(TNode node, bool prerewrite){ RewriteResponse TheoryBVRewriter::RewriteNeg(TNode node, bool prerewrite) { Node resultNode = node; + /* RewriteResponse response = rules::NegIdemp(node); if (response.d_node != node) { return response; } + */ resultNode = LinearRewriteStrategy < RewriteRule, @@ -574,11 +578,13 @@ RewriteResponse TheoryBVRewriter::RewriteRepeat(TNode node, bool prerewrite) { } RewriteResponse TheoryBVRewriter::RewriteZeroExtend(TNode node, bool prerewrite){ + /* RewriteResponse response = rules::ZeroExtendNZ(node); if (response.d_node != node) { return response; } + */ Node resultNode = LinearRewriteStrategy < RewriteRule @@ -659,11 +665,13 @@ RewriteResponse TheoryBVRewriter::RewriteIntToBV(TNode node, bool prerewrite) { } RewriteResponse TheoryBVRewriter::RewriteEqual(TNode node, bool prerewrite) { + /* RewriteResponse response = rules::ReflexivityEq(node); if (response.d_node != node) { return response; } + */ if (prerewrite) { Node resultNode = LinearRewriteStrategy diff --git a/src/theory/rewriter/compiler.py b/src/theory/rewriter/compiler.py index c12b4cb70..cda5e2261 100755 --- a/src/theory/rewriter/compiler.py +++ b/src/theory/rewriter/compiler.py @@ -33,7 +33,8 @@ op_to_kind = { } op_to_const_eval = { - Op.BVSHL: '{}.leftShift({})', + Op.BVSHL: '({}.leftShift({}))', + Op.BVNOT: '(~{})', Op.PLUS: '({} + {})', Op.MINUS: '({} - {})', Op.EQ: '({} == {})', @@ -82,17 +83,21 @@ op_to_nindex = { def rule_to_in_ir(rvars, lhs): def expr_to_ir(expr, path, vars_seen, out_ir, in_index = False): if isinstance(expr, Fn): - out_ir.append( - Assert( - Fn(Op.EQ, - [Fn(Op.GET_KIND, [GetChild(path)]), - KindConst(expr.op)]))) - for i, child in enumerate(expr.children): - index = i if i < op_to_nindex[expr.op] else i - op_to_nindex[expr.op] - expr_to_ir(child, path + [index], vars_seen, out_ir, i < op_to_nindex[expr.op]) + if expr.sort.const: + out_ir.append( + Assert(Fn( + Op.EQ, + [GetChild(path), Fn(Op.MK_CONST, [expr])]))) + else: + out_ir.append( + Assert( + Fn(Op.EQ, + [Fn(Op.GET_KIND, [GetChild(path)]), + KindConst(expr.op)]))) + for i, child in enumerate(expr.children): + index = i if i < op_to_nindex[expr.op] else i - op_to_nindex[expr.op] + expr_to_ir(child, path + [index], vars_seen, out_ir, i < op_to_nindex[expr.op]) - if isinstance(expr.op, Fn): - pass elif isinstance(expr, Var): if expr.name in vars_seen: @@ -103,16 +108,18 @@ def rule_to_in_ir(rvars, lhs): if in_index: index_expr = GetIndex(path) index_expr.sort = Sort(BaseSort.Int, []) - out_ir.append(Assign(expr.name, index_expr)) + out_ir.append(Assign(expr, index_expr)) else: - out_ir.append(Assign(expr.name, GetChild(path))) + out_ir.append(Assign(expr, GetChild(path))) if expr.sort is not None and expr.sort.base == BaseSort.BitVec: width = expr.sort.children[0] if isinstance(width, Var) and not width.name in vars_seen: bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)]) - bv_size_expr.sort = Sort(BaseSort.Int, []) - out_ir.append(Assign(width.name, bv_size_expr)) + bv_size_expr.sort = Sort(BaseSort.Int, [], True) + # TODO: should resolve earlier? + width.sort = Sort(BaseSort.Int, [], True) + out_ir.append(Assign(width, bv_size_expr)) vars_seen.add(width.name) vars_seen.add(expr.name) @@ -206,12 +213,19 @@ def rule_to_out_expr(cfg, next_block, res, expr): for var, child in zip(new_vars, expr.children): next_block = rule_to_out_expr(cfg, next_block, var, child) return next_block - elif isinstance(expr, BoolConst) or isinstance(expr, BVConst): - return Fn(Op.MK_CONST, [expr]) + elif isinstance(expr, BoolConst): + assign_block = fresh_name('block') + res.sort = Sort(BaseSort.Bool, []) + assign = Assign(res, expr) + if next_block: + cfg[assign_block] = CFGNode([assign], [CFGEdge(BoolConst(True), next_block)]) + else: + assign.expr = Fn(Op.MK_CONST, [assign.expr]) + cfg[assign_block] = CFGNode([assign], []) + return assign_block elif isinstance(expr, IntConst): assign_block = fresh_name('block') res.sort = Sort(BaseSort.Int, []) - print('{} should be int {}'.format(res, expr.val)) assign = Assign(res, expr) if next_block: cfg[assign_block] = CFGNode([assign], [CFGEdge(BoolConst(True), next_block)]) @@ -272,20 +286,20 @@ def expr_to_code(expr): def sort_to_code(sort): - if not sort: + if not sort or not sort.const: # TODO: should not happen return 'Node' elif sort.base == BaseSort.Int: return 'uint32_t' elif sort.base == BaseSort.BitVec: - return 'BitVector' if sort.const else 'Node' + return 'BitVector' def ir_to_code(match_instrs): code = [] for instr in match_instrs: if isinstance(instr, Assign): - code.append('{} {} = {};'.format(sort_to_code(instr.expr.sort), + code.append('{} = {};'.format( instr.name, expr_to_code(instr.expr))) elif isinstance(instr, Assert): @@ -327,6 +341,7 @@ def gen_rule(rule): RewriteResponse {}(TNode __node) {{ NodeManager* nm = NodeManager::currentNM(); {} + {} return RewriteResponse(REWRITE_AGAIN, {}, RewriteRule::{}); }}""" @@ -335,14 +350,20 @@ def gen_rule(rule): entry = rule_to_out_expr(cfg, None, out_var, rule.rhs) match_block_name = fresh_name('block') cfg[match_block_name] = CFGNode(match_block, [CFGEdge(BoolConst(True), entry)]) - - optimize_cfg(match_block_name, cfg) - print(cfg_to_str(cfg)) out_ir = rule_to_out_expr(cfg, None, out_var, rule.rhs) + + optimize_cfg(out_var, match_block_name, cfg) # ir = in_ir + [rule.cond] + out_ir # opt_ir = optimize_ir(out_var, ir) + + cfg_vars = cfg_collect_vars(cfg) + var_decls = '' + for var in cfg_vars: + var_decls += '{} {};\n'.format(sort_to_code(var.sort), var.name) + body = cfg_to_code(match_block_name, cfg) - result = rule_pattern.format(rule.name, body, out_var, + + result = rule_pattern.format(rule.name, var_decls, body, out_var, name_to_enum(rule.name)) print(result) return result @@ -662,7 +683,7 @@ def type_check(rules): infer_types(rule.rvars, rule.rhs) # Ensure that we were able to compute the types for both sides - assert rule.lhs.sort is not None and rule.rhs.sort is not None + assert isinstance(rule.lhs.sort, Sort) and isinstance(rule.rhs.sort, Sort) def main(): @@ -700,10 +721,10 @@ def main(): args.rulesfile.write(gen_enum(rules)) args.implementationfile.write(gen_rules_implementation(rules)) - args.printerfile.write(gen_proof_printer(rules)) + #args.printerfile.write(gen_proof_printer(rules)) - for rule in rules: - rule_to_lfsc(rule) + #for rule in rules: + # rule_to_lfsc(rule) if __name__ == "__main__": diff --git a/src/theory/rewriter/ir.py b/src/theory/rewriter/ir.py index 76c27af2b..5aabfcc3c 100644 --- a/src/theory/rewriter/ir.py +++ b/src/theory/rewriter/ir.py @@ -1,4 +1,6 @@ -from node import collect_vars, BoolConst +from collections import defaultdict + +from node import collect_vars, count_vars, subst, BoolConst, Var class CFGEdge: def __init__(self, cond, target): @@ -30,6 +32,7 @@ class IRNode: class Assign(IRNode): def __init__(self, name, expr): super(IRNode, self).__init__() + assert isinstance(name, Var) self.name = name self.expr = expr @@ -60,7 +63,8 @@ def optimize_ir(out_var, instrs): return opt_instrs -def optimize_cfg(entry, cfg): +def optimize_cfg(out_var, entry, cfg): + # Merge basic blocks change = True while change: change = False @@ -71,16 +75,60 @@ def optimize_cfg(entry, cfg): cfg[label].instrs += next_block.instrs cfg[label].edges = next_block.edges + # Remove unused blocks not_called = set(cfg.keys()) - not_called.remove(entry) - for label, node in cfg.items(): - for edge in node.edges: - if edge.target in not_called: - not_called.remove(edge.target) + to_visit = [entry] + while to_visit: + curr = to_visit[-1] + to_visit.pop() + + not_called.remove(curr) + for edge in cfg[curr].edges: + to_visit.append(edge.target) for target in not_called: del cfg[target] + # Inline assignments that are used only once + used_count = defaultdict(lambda: 0) + substs = dict() + for label, node in cfg.items(): + for instr in node.instrs: + if isinstance(instr, Assert): + count_vars(used_count, instr.expr) + elif isinstance(instr, Assign): + count_vars(used_count, instr.expr) + substs[instr.name] = instr.expr + + del substs[out_var] + + # Only keep the substitutions for variables that are unused or appear once + substs = dict(filter(lambda kv: used_count[kv[0]] <= 1, substs.items())) + + # Apply substitutions to themselves + # Note: since each substituted variable can appear at most once, each + # substitution cannot be applied more than once, so it should be enough to + # do this n times where n is the number of substitutions + for i in range(len(substs)): + substs = {name:subst(expr, substs) for name, expr in substs.items()} + + # Remove unused instructions and apply substitutions + for label, node in cfg.items(): + node.instrs = list(filter(lambda instr: (not isinstance(instr, Assign)) or (instr.name not in substs), node.instrs)) + for instr in node.instrs: + instr.expr = subst(instr.expr, substs) + + for edge in node.edges: + edge.cond = subst(edge.cond, substs) + +def cfg_collect_vars(cfg): + cfg_vars = set() + for label, node in cfg.items(): + for instr in node.instrs: + if isinstance(instr, Assign): + cfg_vars.add(instr.name) + return cfg_vars + def cfg_to_str(cfg): result = '' for label, node in cfg.items(): diff --git a/src/theory/rewriter/node.py b/src/theory/rewriter/node.py index 0a4d9ba8e..b2faa648d 100644 --- a/src/theory/rewriter/node.py +++ b/src/theory/rewriter/node.py @@ -157,6 +157,12 @@ class GetChild(Node): def __repr__(self): return '__node' + ''.join('[{}]'.format(p) for p in self.path) + def __eq__(self, other): + return isinstance(other, GetChild) and self.path == other.path + + def __hash__(self): + return hash(tuple(self.path)) + class GetIndex(Node): def __init__(self, path): super().__init__([]) @@ -165,6 +171,12 @@ class GetIndex(Node): def __repr__(self): return 'index(__node{}, {})'.format(''.join('[{}]'.format(p) for p in self.path[:-1]), self.path[-1]) + def __eq__(self, other): + return isinstance(other, GetIndex) and self.path == other.path + + def __hash__(self): + return hash(tuple(self.path)) + class Fn(Node): def __init__(self, op, args): super().__init__(args) @@ -208,6 +220,24 @@ def collect_vars(node): return result +def count_vars(counts, node): + if isinstance(node, Var): + counts[node] += 1 + else: + for child in node.children: + count_vars(counts, child) + + +def subst(node, substs): + # TODO: non-destructive substitution? + if node in substs: + return substs[node] + else: + new_children = [] + for child in node.children: + new_children.append(subst(child, substs)) + node.children = new_children + return node def unify_types(t1, t2): assert t1.base == t2.base diff --git a/src/theory/rewriter/parser.py b/src/theory/rewriter/parser.py index 09d4cf141..7a977c774 100644 --- a/src/theory/rewriter/parser.py +++ b/src/theory/rewriter/parser.py @@ -93,7 +93,7 @@ def sort(): bv_sort = (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) + expr() + pp.Suppress(')') ).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]])) - int_sort = pp.Keyword('Int').setParseAction( + int_sort = pp.Keyword('Num').setParseAction( lambda s, l, t: Sort(BaseSort.Int, [], True)) return bv_sort | int_sort diff --git a/src/theory/rewriter/rules/basic.rules b/src/theory/rewriter/rules/basic.rules index 6bee44747..cb058656b 100644 --- a/src/theory/rewriter/rules/basic.rules +++ b/src/theory/rewriter/rules/basic.rules @@ -1,10 +1,18 @@ -(define-rule ZeroExtendEliminate ((i Int) (x (_ BitVec n))) +(define-rule NegIdemp ((n Num) (x (_ BitVec n))) + (bvneg (bvneg x)) + x) + +(define-rule ZeroExtendEliminate ((i Num) (x (_ BitVec n))) ((_ zero_extend i) x) (cond ((= i 0) x) (concat (_ bv 0 i) x))) -(define-rule SltEliminate ((n Int) (x (_ BitVec n)) (y (_ BitVec n))) +(define-rule SltEliminate ((n Num) (x (_ BitVec n)) (y (_ BitVec n))) (bvslt x y) (let ((pow_two (bvshl (_ bv 1 n) (_ bv (- n 1) n)))) (bvult (bvadd x pow_two) (bvadd y pow_two)))) + +(define-rule UleMax ((n Num) (x (_ BitVec n))) + (bvult x (bvnot (_ bv 0 n))) + true) -- cgit v1.2.3