diff options
author | Andres Noetzli <andres.noetzli@gmail.com> | 2020-03-16 00:01:03 -0700 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2020-03-16 00:01:03 -0700 |
commit | e963eed14cefab140cf31fdc8c16f1d059d5a716 (patch) | |
tree | 188fd01d8bf9197b01dd5c6f6b982f741f6a586d | |
parent | 227cd8c26c508b7b444fbed6f2868f90c8281eed (diff) |
Basic rules compiling
-rw-r--r-- | src/CMakeLists.txt | 3 | ||||
-rw-r--r-- | src/theory/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_rewriter.cpp | 1 | ||||
-rw-r--r-- | src/theory/rewriter/CMakeLists.txt | 18 | ||||
-rw-r--r-- | src/theory/rewriter/__init__.py | 0 | ||||
-rwxr-xr-x | src/theory/rewriter/compiler.py | 216 | ||||
-rw-r--r-- | src/theory/rewriter/ir.py | 33 | ||||
-rw-r--r-- | src/theory/rewriter/node.py | 122 | ||||
-rw-r--r-- | src/theory/rewriter/parser.py | 62 | ||||
-rw-r--r-- | src/theory/rewriter/rule.py | 7 | ||||
-rw-r--r-- | src/theory/rewriter/rules/basic.rules | 47 |
11 files changed, 510 insertions, 1 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bb2b95960..bcc3753c2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -796,7 +796,8 @@ target_compile_definitions(cvc4 -D__STDC_FORMAT_MACROS ) # Add libcvc4 dependencies for generated sources. -add_dependencies(cvc4 gen-expr gen-gitinfo gen-options gen-tags gen-theory) +add_dependencies(cvc4 + gen-expr gen-gitinfo gen-options gen-tags gen-theory gen-rewriter) # Add library/include dependencies if(ENABLE_VALGRIND) diff --git a/src/theory/CMakeLists.txt b/src/theory/CMakeLists.txt index 4c2f66a0e..3f33356a1 100644 --- a/src/theory/CMakeLists.txt +++ b/src/theory/CMakeLists.txt @@ -43,3 +43,5 @@ add_custom_target(gen-theory theory_traits.h rewriter_tables.h ) + +add_subdirectory(rewriter) diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 6a04d6e4e..7bda28923 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -24,6 +24,7 @@ #include "theory/bv/theory_bv_rewrite_rules_simplification.h" #include "theory/bv/theory_bv_rewriter.h" #include "theory/theory.h" +#include "theory/rewriter/rules.h" using namespace CVC4; using namespace CVC4::theory; diff --git a/src/theory/rewriter/CMakeLists.txt b/src/theory/rewriter/CMakeLists.txt new file mode 100644 index 000000000..fe7b7b0aa --- /dev/null +++ b/src/theory/rewriter/CMakeLists.txt @@ -0,0 +1,18 @@ +libcvc4_add_sources(GENERATED + rules.h +) + +add_custom_command( + OUTPUT rules.h + COMMAND + ${CMAKE_CURRENT_LIST_DIR}/compiler.py + ${CMAKE_CURRENT_LIST_DIR}/rules/basic.rules + > ${CMAKE_CURRENT_BINARY_DIR}/rules.h + DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/compiler.py +) + +add_custom_target(gen-rewriter + DEPENDS + rules.h +) diff --git a/src/theory/rewriter/__init__.py b/src/theory/rewriter/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/theory/rewriter/__init__.py diff --git a/src/theory/rewriter/compiler.py b/src/theory/rewriter/compiler.py new file mode 100755 index 000000000..5fd54e6a8 --- /dev/null +++ b/src/theory/rewriter/compiler.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +import argparse +import sys + +from subprocess import Popen, PIPE, STDOUT +from ir import Assign, Assert, optimize_ir +from node import * +from rule import Rule +from parser import parse_rules + +op_to_kind = { + Op.BVSGT: 'BITVECTOR_SGT', + Op.BVSLT: 'BITVECTOR_SLT', + Op.BVULT: 'BITVECTOR_ULT', + Op.BVULE: 'BITVECTOR_ULE', + Op.BVNEG: 'BITVECTOR_NEG', + Op.ZERO_EXTEND: 'BITVECTOR_ZERO_EXTEND', + Op.NOT: 'NOT', + Op.EQ: 'EQUAL', +} + + +def rule_to_in_ir(rvars, lhs): + def expr_to_ir(expr, path, vars_seen, out_ir): + if isinstance(expr, Fn): + out_ir.append( + Assert( + Fn(Op.EQ, + [Fn(Op.GET_KIND, [GetChild(path)]), + KindConst(expr.op)]))) + for i, child in enumerate(expr.children): + expr_to_ir(child, path + [i], vars_seen, out_ir) + + if isinstance(expr.op, Fn): + pass + + elif isinstance(expr, Var): + if expr.name in vars_seen: + out_ir.append( + Assert(Fn(Op.EQ, + [Var(expr.name), GetChild(path)]))) + else: + out_ir.append(Assign(expr.name, GetChild(path))) + + if expr.sort is not None and expr.sort.base == BaseSort.BitVec: + width = expr.sort.args[0] + if isinstance(width, Var) and not width.name in vars_seen: + bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)]) + bv_size_expr.sort = Sort(BaseSort.Int, []) + out_ir.append( + Assign( + width.name, + bv_size_expr)) + vars_seen.add(width.name) + + vars_seen.add(expr.name) + elif isinstance(expr, BVConst): + if isinstance(expr.bw, Var) and not expr.bw.name in vars_seen: + bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)]) + bv_size_expr.sort = Sort(BaseSort.Int, []) + out_ir.append( + Assign( + expr.bw.name, + bv_size_expr)) + vars_seen.add(expr.bw.name) + + out_ir.append( + Assert(Fn( + Op.EQ, + [GetChild(path), Fn(Op.MK_CONST, [expr])]))) + elif isinstance(expr, IntConst): + out_ir.append( + Assert( + Fn(Op.EQ, + [Fn(Op.GET_KIND, [GetChild(path)]), + KindConst(expr.op)]))) + + out_ir = [] + vars_seen = set() + + expr_to_ir(lhs, [], vars_seen, out_ir) + return out_ir + + +def rule_to_out_expr(expr): + if isinstance(expr, Fn): + new_children = [rule_to_out_expr(child) for child in expr.children] + return Fn(Op.MK_NODE, [KindConst(expr.op)] + new_children) + elif isinstance(expr, BoolConst) or isinstance(expr, BVConst): + return Fn(Op.MK_CONST, [expr]) + else: + return expr + + +def expr_to_code(expr): + if isinstance(expr, Fn): + args = [expr_to_code(child) for child in expr.children] + if expr.op == Op.EQ: + return '({} == {})'.format(args[0], args[1]) + elif expr.op == Op.GET_KIND: + return '{}.getKind()'.format(args[0]) + elif expr.op == Op.BV_SIZE: + return 'utils::getSize({})'.format(args[0]) + elif expr.op == Op.MK_CONST: + return 'nm->mkConst({})'.format(', '.join(args)) + elif expr.op == Op.MK_NODE: + return 'nm->mkNode({})'.format(', '.join(args)) + elif isinstance(expr, GetChild): + path_str = ''.join(['[{}]'.format(i) for i in expr.path]) + return '__node{}'.format(path_str) + elif isinstance(expr, BoolConst): + return ('true' if expr.val else 'false') + elif isinstance(expr, BVConst): + bw_code = expr_to_code(expr.bw) + return 'BitVector({}, Integer({}))'.format(bw_code, expr.val) + elif isinstance(expr, KindConst): + return 'kind::{}'.format(op_to_kind[expr.val]) + elif isinstance(expr, Var): + return expr.name + + +def sort_to_code(sort): + return 'uint32_t' if sort and sort.base == BaseSort.Int else 'Node' + + +def ir_to_code(match_instrs): + code = [] + for instr in match_instrs: + if isinstance(instr, Assign): + code.append('{} {} = {};'.format(sort_to_code(instr.expr.sort), instr.name, + expr_to_code(instr.expr))) + elif isinstance(instr, Assert): + code.append('if (!({})) return __node;'.format( + expr_to_code(instr.expr))) + + return '\n'.join(code) + + +def gen_rule(rule): + out_var = '__ret' + rule_pattern = """ + Node {}(TNode __node) {{ + NodeManager* nm = NodeManager::currentNM(); + {} + return {}; + }}""" + + infer_types(rule.rvars, rule.lhs) + in_ir = rule_to_in_ir(rule.rvars, rule.lhs) + out_ir = [Assign(out_var, rule_to_out_expr(rule.rhs))] + ir = in_ir + [rule.cond] + out_ir + opt_ir = optimize_ir(out_var, ir) + body = ir_to_code(opt_ir) + return format_cpp(rule_pattern.format(rule.name, body, out_var)) + + +def format_cpp(s): + p = Popen(['clang-format'], stdout=PIPE, stdin=PIPE, stderr=STDOUT) + out = p.communicate(input=s.encode())[0] + return out.decode() + + +def main(): + # (define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) (bvsgt x y) (bvsgt y x)) + + # sgt_eliminate = Rule('SgtEliminate', + # {'x': Sort(BaseSort.BitVec, [Var('n', int_sort)]), + # 'y': Sort(BaseSort.BitVec, [Var('n', int_sort)])}, + # BoolConst(True), + # Fn(Op.BVSGT, [Var('x'), Var('y')]), + # Fn(Op.BVSLT, [Var('y'), Var('x')])) + + file_pattern = """ + #include "expr/node.h" + #include "theory/bv/theory_bv_utils.h" + #include "util/bitvector.h" + + namespace CVC4 {{ + namespace theory {{ + namespace bv {{ + namespace rules {{ + + {} + + }} + }} + }} + }}""" + + parser = argparse.ArgumentParser(description='Compile rewrite rules.') + parser.add_argument('infile', + nargs='?', + type=argparse.FileType('r'), + default=sys.stdin, + help='Rule file') + args = parser.parse_args() + + rules = parse_rules(args.infile.read()) + rules_code = [] + for rule in rules: + rules_code.append(gen_rule(rule)) + + + print(format_cpp(file_pattern.format('\n'.join(rules_code)))) + + # zero_extend_eliminate = Rule('ZeroExtendEliminate', + # [Var('x', Sort(BaseSort.BitVec, [Var('n', int_sort)]))], + # BoolConst(True), + # Fn(Fn(Op.ZERO_EXTEND, [IntConst(0)]), [Var('x')]), + # Var('x')) + # print(format_cpp(gen_rule(zero_extend_eliminate))) + + +if __name__ == "__main__": + main() diff --git a/src/theory/rewriter/ir.py b/src/theory/rewriter/ir.py new file mode 100644 index 000000000..c9c9fa9c3 --- /dev/null +++ b/src/theory/rewriter/ir.py @@ -0,0 +1,33 @@ +from node import collect_vars + +class IRNode: + def __init__(self): + pass + + +class Assign(IRNode): + def __init__(self, name, expr): + super(IRNode, self).__init__() + self.name = name + self.expr = expr + + +class Assert(IRNode): + def __init__(self, expr): + super(IRNode, self).__init__() + self.expr = expr + + +def optimize_ir(out_var, instrs): + used_vars = set([out_var]) + for instr in instrs: + if isinstance(instr, Assert): + used_vars |= collect_vars(instr.expr) + elif isinstance(instr, Assign): + used_vars |= collect_vars(instr.expr) + + opt_instrs = [] + for instr in instrs: + if not(isinstance(instr, Assign)) or instr.name in used_vars: + opt_instrs.append(instr) + return opt_instrs diff --git a/src/theory/rewriter/node.py b/src/theory/rewriter/node.py new file mode 100644 index 000000000..5c3876488 --- /dev/null +++ b/src/theory/rewriter/node.py @@ -0,0 +1,122 @@ +from enum import Enum, auto + + +class Op(Enum): + BVSGT = auto() + BVSLT = auto() + BVULT = auto() + BVULE = auto() + BVNEG = auto() + ZERO_EXTEND = auto() + NOT = auto() + + EQ = auto() + + GET_KIND = auto() + GET_CHILD = auto() + MK_NODE = auto() + MK_CONST = auto() + BV_SIZE = auto() + + +class BaseSort(Enum): + Bool = auto() + BitVec = auto() + Int = auto() + Kind = auto() + + +class Node: + def __init__(self, children, sort=None): + self.children = children + self.sort = sort + + +class Var(Node): + def __init__(self, name, sort=None): + super().__init__([], sort) + self.name = name + + +class BoolConst(Node): + def __init__(self, val): + super().__init__([]) + self.val = val + + +class BVConst(Node): + def __init__(self, val, bw): + super().__init__([], Sort(BaseSort.BitVec, [bw])) + self.val = val + self.bw = bw + + +class KindConst(Node): + def __init__(self, val): + super().__init__([]) + self.val = val + + +class IntConst(Node): + def __init__(self, val): + super().__init__([]) + self.val = val + + +class GetChild(Node): + def __init__(self, path): + super().__init__([]) + self.path = path + + +class Fn(Node): + def __init__(self, op, args): + super().__init__(args) + self.op = op + + +class Sort: + def __init__(self, base, args): + self.base = base + self.args = args + + +def collect_vars(node): + if isinstance(node, Var): + return set(node.name) + + result = set() + for child in node.children: + result |= collect_vars(child) + + if isinstance(node, BVConst): + result |= collect_vars(node.bw) + + return result + + +def unify_types(t1, t2): + assert t1.base == t2.base + if t1.base == BaseSort.BitVec: + if isinstance(t1.args[0], Var) and isinstance(t2.args[0], Var): + if t1.args[0].name == t2.args[0].name: + return t1 + + +def infer_types(rvars, node): + if node.sort: + return + + if isinstance(node, Var): + node.sort = rvars[node.name] + return + + for child in node.children: + infer_types(rvars, child) + + sort = None + if isinstance(node, Fn): + if node.op in [Op.EQ, Op.BVSGT, Op.BVSLT, Op.BVULT]: + sort = unify_types(node.children[0].sort, node.children[1].sort) + + node.sort = sort diff --git a/src/theory/rewriter/parser.py b/src/theory/rewriter/parser.py new file mode 100644 index 000000000..0f56edb79 --- /dev/null +++ b/src/theory/rewriter/parser.py @@ -0,0 +1,62 @@ +import pyparsing as pp + +from node import * +from rule import Rule + +symbol_to_op = { + 'bvsgt': Op.BVSGT, + 'bvslt': Op.BVSLT, + 'bvult': Op.BVULT, + 'bvule': Op.BVULE, + 'bvneg': Op.BVNEG, + 'not': Op.NOT, + '=': Op.EQ +} + + +def bv_to_int(s): + assert s.startswith('bv') + return int(s[2:]) + + +def parse_expr(): + expr = pp.Forward() + bconst = pp.Keyword('true').setParseAction( + lambda s, l, t: BoolConst(True)) | pp.Keyword('false').setParseAction( + lambda s, l, t: BoolConst(False)) + bvconst = ( + pp.Suppress('(') + pp.Suppress('_') + pp.Word(pp.alphanums) + expr + + ')').setParseAction(lambda s, l, t: BVConst(bv_to_int(t[0]), t[1])) + app = (pp.Suppress('(') + pp.Word(pp.alphas + '=') + pp.OneOrMore(expr) + + pp.Suppress(')') + ).setParseAction(lambda s, l, t: Fn(symbol_to_op[t[0]], t[1:])) + expr <<= bconst | bvconst | app | pp.Word( + pp.alphas).setParseAction(lambda s, l, t: Var(t[0])) + return expr + + +def parse_sort(): + return (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) + + parse_expr() + pp.Suppress(')') + ).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]])) + + +def parse_var(): + return (pp.Suppress('(') + pp.Word(pp.alphas) + parse_sort() + + pp.Suppress(')')).setParseAction(lambda s, l, t: (t[0], t[1])) + + +def parse_var_list(): + return (pp.Suppress('(') + pp.OneOrMore(parse_var()) + + pp.Suppress(')')).setParseAction(lambda s, l, t: dict(t[:])) + + +def parse_rules(s): + comments = pp.ZeroOrMore(pp.Suppress(pp.cStyleComment)) + + rule = (pp.Suppress('(') + pp.Keyword('define-rule') + pp.Word(pp.alphas) + + parse_var_list() + parse_expr() + parse_expr() + + pp.Suppress(')')).setParseAction( + lambda s, l, t: Rule(t[1], t[2], BoolConst(True), t[3], t[4])) + rules = pp.OneOrMore(rule) + return rules.parseString(s) diff --git a/src/theory/rewriter/rule.py b/src/theory/rewriter/rule.py new file mode 100644 index 000000000..f1a0f4b71 --- /dev/null +++ b/src/theory/rewriter/rule.py @@ -0,0 +1,7 @@ +class Rule: + def __init__(self, name, rvars, cond, lhs, rhs): + self.name = name + self.rvars = rvars + self.cond = cond + self.lhs = lhs + self.rhs = rhs diff --git a/src/theory/rewriter/rules/basic.rules b/src/theory/rewriter/rules/basic.rules new file mode 100644 index 000000000..20b7d3769 --- /dev/null +++ b/src/theory/rewriter/rules/basic.rules @@ -0,0 +1,47 @@ +(define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsgt x y) + (bvslt y x)) + +(define-rule LtSelfUlt ((x (_ BitVec n))) + (bvult x x) + false) + +(define-rule LtSelfSlt ((x (_ BitVec n))) + (bvslt x x) + false) + +(define-rule ZeroUlt ((x (_ BitVec n))) + (bvult (_ bv0 n) x) + (not (= (_ bv0 n) x))) + +(define-rule UltZero ((x (_ BitVec n))) + (bvult x (_ bv0 n)) + false) + +(define-rule UltOne ((x (_ BitVec n))) + (bvult x (_ bv1 n)) + (= x (_ bv0 n))) + +(define-rule UleZero ((x (_ BitVec n))) + (bvule x (_ bv0 n)) + (= x (_ bv0 n))) + +(define-rule UleSelf ((x (_ BitVec n))) + (bvule x x) + true) + +(define-rule ZeroUle ((x (_ BitVec n))) + (bvule (_ bv0 n) x) + true) + +(define-rule NotUlt ((x (_ BitVec n)) (y (_ BitVec n))) + (not (bvult x y)) + (bvule y x)) + +(define-rule NotUle ((x (_ BitVec n)) (y (_ BitVec n))) + (not (bvule x y)) + (bvult y x)) + +(define-rule NegIdemp ((x (_ BitVec n))) + (bvneg (bvneg x)) + x) |