summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2020-03-16 00:01:03 -0700
committerAndres Noetzli <andres.noetzli@gmail.com>2020-03-16 00:01:03 -0700
commite963eed14cefab140cf31fdc8c16f1d059d5a716 (patch)
tree188fd01d8bf9197b01dd5c6f6b982f741f6a586d
parent227cd8c26c508b7b444fbed6f2868f90c8281eed (diff)
Basic rules compiling
-rw-r--r--src/CMakeLists.txt3
-rw-r--r--src/theory/CMakeLists.txt2
-rw-r--r--src/theory/bv/theory_bv_rewriter.cpp1
-rw-r--r--src/theory/rewriter/CMakeLists.txt18
-rw-r--r--src/theory/rewriter/__init__.py0
-rwxr-xr-xsrc/theory/rewriter/compiler.py216
-rw-r--r--src/theory/rewriter/ir.py33
-rw-r--r--src/theory/rewriter/node.py122
-rw-r--r--src/theory/rewriter/parser.py62
-rw-r--r--src/theory/rewriter/rule.py7
-rw-r--r--src/theory/rewriter/rules/basic.rules47
11 files changed, 510 insertions, 1 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index bb2b95960..bcc3753c2 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -796,7 +796,8 @@ target_compile_definitions(cvc4
-D__STDC_FORMAT_MACROS
)
# Add libcvc4 dependencies for generated sources.
-add_dependencies(cvc4 gen-expr gen-gitinfo gen-options gen-tags gen-theory)
+add_dependencies(cvc4
+ gen-expr gen-gitinfo gen-options gen-tags gen-theory gen-rewriter)
# Add library/include dependencies
if(ENABLE_VALGRIND)
diff --git a/src/theory/CMakeLists.txt b/src/theory/CMakeLists.txt
index 4c2f66a0e..3f33356a1 100644
--- a/src/theory/CMakeLists.txt
+++ b/src/theory/CMakeLists.txt
@@ -43,3 +43,5 @@ add_custom_target(gen-theory
theory_traits.h
rewriter_tables.h
)
+
+add_subdirectory(rewriter)
diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp
index 6a04d6e4e..7bda28923 100644
--- a/src/theory/bv/theory_bv_rewriter.cpp
+++ b/src/theory/bv/theory_bv_rewriter.cpp
@@ -24,6 +24,7 @@
#include "theory/bv/theory_bv_rewrite_rules_simplification.h"
#include "theory/bv/theory_bv_rewriter.h"
#include "theory/theory.h"
+#include "theory/rewriter/rules.h"
using namespace CVC4;
using namespace CVC4::theory;
diff --git a/src/theory/rewriter/CMakeLists.txt b/src/theory/rewriter/CMakeLists.txt
new file mode 100644
index 000000000..fe7b7b0aa
--- /dev/null
+++ b/src/theory/rewriter/CMakeLists.txt
@@ -0,0 +1,18 @@
+libcvc4_add_sources(GENERATED
+ rules.h
+)
+
+add_custom_command(
+ OUTPUT rules.h
+ COMMAND
+ ${CMAKE_CURRENT_LIST_DIR}/compiler.py
+ ${CMAKE_CURRENT_LIST_DIR}/rules/basic.rules
+ > ${CMAKE_CURRENT_BINARY_DIR}/rules.h
+ DEPENDS
+ ${CMAKE_CURRENT_LIST_DIR}/compiler.py
+)
+
+add_custom_target(gen-rewriter
+ DEPENDS
+ rules.h
+)
diff --git a/src/theory/rewriter/__init__.py b/src/theory/rewriter/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/theory/rewriter/__init__.py
diff --git a/src/theory/rewriter/compiler.py b/src/theory/rewriter/compiler.py
new file mode 100755
index 000000000..5fd54e6a8
--- /dev/null
+++ b/src/theory/rewriter/compiler.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+
+import argparse
+import sys
+
+from subprocess import Popen, PIPE, STDOUT
+from ir import Assign, Assert, optimize_ir
+from node import *
+from rule import Rule
+from parser import parse_rules
+
+op_to_kind = {
+ Op.BVSGT: 'BITVECTOR_SGT',
+ Op.BVSLT: 'BITVECTOR_SLT',
+ Op.BVULT: 'BITVECTOR_ULT',
+ Op.BVULE: 'BITVECTOR_ULE',
+ Op.BVNEG: 'BITVECTOR_NEG',
+ Op.ZERO_EXTEND: 'BITVECTOR_ZERO_EXTEND',
+ Op.NOT: 'NOT',
+ Op.EQ: 'EQUAL',
+}
+
+
+def rule_to_in_ir(rvars, lhs):
+ def expr_to_ir(expr, path, vars_seen, out_ir):
+ if isinstance(expr, Fn):
+ out_ir.append(
+ Assert(
+ Fn(Op.EQ,
+ [Fn(Op.GET_KIND, [GetChild(path)]),
+ KindConst(expr.op)])))
+ for i, child in enumerate(expr.children):
+ expr_to_ir(child, path + [i], vars_seen, out_ir)
+
+ if isinstance(expr.op, Fn):
+ pass
+
+ elif isinstance(expr, Var):
+ if expr.name in vars_seen:
+ out_ir.append(
+ Assert(Fn(Op.EQ,
+ [Var(expr.name), GetChild(path)])))
+ else:
+ out_ir.append(Assign(expr.name, GetChild(path)))
+
+ if expr.sort is not None and expr.sort.base == BaseSort.BitVec:
+ width = expr.sort.args[0]
+ if isinstance(width, Var) and not width.name in vars_seen:
+ bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)])
+ bv_size_expr.sort = Sort(BaseSort.Int, [])
+ out_ir.append(
+ Assign(
+ width.name,
+ bv_size_expr))
+ vars_seen.add(width.name)
+
+ vars_seen.add(expr.name)
+ elif isinstance(expr, BVConst):
+ if isinstance(expr.bw, Var) and not expr.bw.name in vars_seen:
+ bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)])
+ bv_size_expr.sort = Sort(BaseSort.Int, [])
+ out_ir.append(
+ Assign(
+ expr.bw.name,
+ bv_size_expr))
+ vars_seen.add(expr.bw.name)
+
+ out_ir.append(
+ Assert(Fn(
+ Op.EQ,
+ [GetChild(path), Fn(Op.MK_CONST, [expr])])))
+ elif isinstance(expr, IntConst):
+ out_ir.append(
+ Assert(
+ Fn(Op.EQ,
+ [Fn(Op.GET_KIND, [GetChild(path)]),
+ KindConst(expr.op)])))
+
+ out_ir = []
+ vars_seen = set()
+
+ expr_to_ir(lhs, [], vars_seen, out_ir)
+ return out_ir
+
+
+def rule_to_out_expr(expr):
+ if isinstance(expr, Fn):
+ new_children = [rule_to_out_expr(child) for child in expr.children]
+ return Fn(Op.MK_NODE, [KindConst(expr.op)] + new_children)
+ elif isinstance(expr, BoolConst) or isinstance(expr, BVConst):
+ return Fn(Op.MK_CONST, [expr])
+ else:
+ return expr
+
+
+def expr_to_code(expr):
+ if isinstance(expr, Fn):
+ args = [expr_to_code(child) for child in expr.children]
+ if expr.op == Op.EQ:
+ return '({} == {})'.format(args[0], args[1])
+ elif expr.op == Op.GET_KIND:
+ return '{}.getKind()'.format(args[0])
+ elif expr.op == Op.BV_SIZE:
+ return 'utils::getSize({})'.format(args[0])
+ elif expr.op == Op.MK_CONST:
+ return 'nm->mkConst({})'.format(', '.join(args))
+ elif expr.op == Op.MK_NODE:
+ return 'nm->mkNode({})'.format(', '.join(args))
+ elif isinstance(expr, GetChild):
+ path_str = ''.join(['[{}]'.format(i) for i in expr.path])
+ return '__node{}'.format(path_str)
+ elif isinstance(expr, BoolConst):
+ return ('true' if expr.val else 'false')
+ elif isinstance(expr, BVConst):
+ bw_code = expr_to_code(expr.bw)
+ return 'BitVector({}, Integer({}))'.format(bw_code, expr.val)
+ elif isinstance(expr, KindConst):
+ return 'kind::{}'.format(op_to_kind[expr.val])
+ elif isinstance(expr, Var):
+ return expr.name
+
+
+def sort_to_code(sort):
+ return 'uint32_t' if sort and sort.base == BaseSort.Int else 'Node'
+
+
+def ir_to_code(match_instrs):
+ code = []
+ for instr in match_instrs:
+ if isinstance(instr, Assign):
+ code.append('{} {} = {};'.format(sort_to_code(instr.expr.sort), instr.name,
+ expr_to_code(instr.expr)))
+ elif isinstance(instr, Assert):
+ code.append('if (!({})) return __node;'.format(
+ expr_to_code(instr.expr)))
+
+ return '\n'.join(code)
+
+
+def gen_rule(rule):
+ out_var = '__ret'
+ rule_pattern = """
+ Node {}(TNode __node) {{
+ NodeManager* nm = NodeManager::currentNM();
+ {}
+ return {};
+ }}"""
+
+ infer_types(rule.rvars, rule.lhs)
+ in_ir = rule_to_in_ir(rule.rvars, rule.lhs)
+ out_ir = [Assign(out_var, rule_to_out_expr(rule.rhs))]
+ ir = in_ir + [rule.cond] + out_ir
+ opt_ir = optimize_ir(out_var, ir)
+ body = ir_to_code(opt_ir)
+ return format_cpp(rule_pattern.format(rule.name, body, out_var))
+
+
+def format_cpp(s):
+ p = Popen(['clang-format'], stdout=PIPE, stdin=PIPE, stderr=STDOUT)
+ out = p.communicate(input=s.encode())[0]
+ return out.decode()
+
+
+def main():
+ # (define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) (bvsgt x y) (bvsgt y x))
+
+ # sgt_eliminate = Rule('SgtEliminate',
+ # {'x': Sort(BaseSort.BitVec, [Var('n', int_sort)]),
+ # 'y': Sort(BaseSort.BitVec, [Var('n', int_sort)])},
+ # BoolConst(True),
+ # Fn(Op.BVSGT, [Var('x'), Var('y')]),
+ # Fn(Op.BVSLT, [Var('y'), Var('x')]))
+
+ file_pattern = """
+ #include "expr/node.h"
+ #include "theory/bv/theory_bv_utils.h"
+ #include "util/bitvector.h"
+
+ namespace CVC4 {{
+ namespace theory {{
+ namespace bv {{
+ namespace rules {{
+
+ {}
+
+ }}
+ }}
+ }}
+ }}"""
+
+ parser = argparse.ArgumentParser(description='Compile rewrite rules.')
+ parser.add_argument('infile',
+ nargs='?',
+ type=argparse.FileType('r'),
+ default=sys.stdin,
+ help='Rule file')
+ args = parser.parse_args()
+
+ rules = parse_rules(args.infile.read())
+ rules_code = []
+ for rule in rules:
+ rules_code.append(gen_rule(rule))
+
+
+ print(format_cpp(file_pattern.format('\n'.join(rules_code))))
+
+ # zero_extend_eliminate = Rule('ZeroExtendEliminate',
+ # [Var('x', Sort(BaseSort.BitVec, [Var('n', int_sort)]))],
+ # BoolConst(True),
+ # Fn(Fn(Op.ZERO_EXTEND, [IntConst(0)]), [Var('x')]),
+ # Var('x'))
+ # print(format_cpp(gen_rule(zero_extend_eliminate)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/theory/rewriter/ir.py b/src/theory/rewriter/ir.py
new file mode 100644
index 000000000..c9c9fa9c3
--- /dev/null
+++ b/src/theory/rewriter/ir.py
@@ -0,0 +1,33 @@
+from node import collect_vars
+
+class IRNode:
+ def __init__(self):
+ pass
+
+
+class Assign(IRNode):
+ def __init__(self, name, expr):
+ super(IRNode, self).__init__()
+ self.name = name
+ self.expr = expr
+
+
+class Assert(IRNode):
+ def __init__(self, expr):
+ super(IRNode, self).__init__()
+ self.expr = expr
+
+
+def optimize_ir(out_var, instrs):
+ used_vars = set([out_var])
+ for instr in instrs:
+ if isinstance(instr, Assert):
+ used_vars |= collect_vars(instr.expr)
+ elif isinstance(instr, Assign):
+ used_vars |= collect_vars(instr.expr)
+
+ opt_instrs = []
+ for instr in instrs:
+ if not(isinstance(instr, Assign)) or instr.name in used_vars:
+ opt_instrs.append(instr)
+ return opt_instrs
diff --git a/src/theory/rewriter/node.py b/src/theory/rewriter/node.py
new file mode 100644
index 000000000..5c3876488
--- /dev/null
+++ b/src/theory/rewriter/node.py
@@ -0,0 +1,122 @@
+from enum import Enum, auto
+
+
+class Op(Enum):
+ BVSGT = auto()
+ BVSLT = auto()
+ BVULT = auto()
+ BVULE = auto()
+ BVNEG = auto()
+ ZERO_EXTEND = auto()
+ NOT = auto()
+
+ EQ = auto()
+
+ GET_KIND = auto()
+ GET_CHILD = auto()
+ MK_NODE = auto()
+ MK_CONST = auto()
+ BV_SIZE = auto()
+
+
+class BaseSort(Enum):
+ Bool = auto()
+ BitVec = auto()
+ Int = auto()
+ Kind = auto()
+
+
+class Node:
+ def __init__(self, children, sort=None):
+ self.children = children
+ self.sort = sort
+
+
+class Var(Node):
+ def __init__(self, name, sort=None):
+ super().__init__([], sort)
+ self.name = name
+
+
+class BoolConst(Node):
+ def __init__(self, val):
+ super().__init__([])
+ self.val = val
+
+
+class BVConst(Node):
+ def __init__(self, val, bw):
+ super().__init__([], Sort(BaseSort.BitVec, [bw]))
+ self.val = val
+ self.bw = bw
+
+
+class KindConst(Node):
+ def __init__(self, val):
+ super().__init__([])
+ self.val = val
+
+
+class IntConst(Node):
+ def __init__(self, val):
+ super().__init__([])
+ self.val = val
+
+
+class GetChild(Node):
+ def __init__(self, path):
+ super().__init__([])
+ self.path = path
+
+
+class Fn(Node):
+ def __init__(self, op, args):
+ super().__init__(args)
+ self.op = op
+
+
+class Sort:
+ def __init__(self, base, args):
+ self.base = base
+ self.args = args
+
+
+def collect_vars(node):
+ if isinstance(node, Var):
+ return set(node.name)
+
+ result = set()
+ for child in node.children:
+ result |= collect_vars(child)
+
+ if isinstance(node, BVConst):
+ result |= collect_vars(node.bw)
+
+ return result
+
+
+def unify_types(t1, t2):
+ assert t1.base == t2.base
+ if t1.base == BaseSort.BitVec:
+ if isinstance(t1.args[0], Var) and isinstance(t2.args[0], Var):
+ if t1.args[0].name == t2.args[0].name:
+ return t1
+
+
+def infer_types(rvars, node):
+ if node.sort:
+ return
+
+ if isinstance(node, Var):
+ node.sort = rvars[node.name]
+ return
+
+ for child in node.children:
+ infer_types(rvars, child)
+
+ sort = None
+ if isinstance(node, Fn):
+ if node.op in [Op.EQ, Op.BVSGT, Op.BVSLT, Op.BVULT]:
+ sort = unify_types(node.children[0].sort, node.children[1].sort)
+
+ node.sort = sort
diff --git a/src/theory/rewriter/parser.py b/src/theory/rewriter/parser.py
new file mode 100644
index 000000000..0f56edb79
--- /dev/null
+++ b/src/theory/rewriter/parser.py
@@ -0,0 +1,62 @@
+import pyparsing as pp
+
+from node import *
+from rule import Rule
+
+symbol_to_op = {
+ 'bvsgt': Op.BVSGT,
+ 'bvslt': Op.BVSLT,
+ 'bvult': Op.BVULT,
+ 'bvule': Op.BVULE,
+ 'bvneg': Op.BVNEG,
+ 'not': Op.NOT,
+ '=': Op.EQ
+}
+
+
+def bv_to_int(s):
+ assert s.startswith('bv')
+ return int(s[2:])
+
+
+def parse_expr():
+ expr = pp.Forward()
+ bconst = pp.Keyword('true').setParseAction(
+ lambda s, l, t: BoolConst(True)) | pp.Keyword('false').setParseAction(
+ lambda s, l, t: BoolConst(False))
+ bvconst = (
+ pp.Suppress('(') + pp.Suppress('_') + pp.Word(pp.alphanums) + expr +
+ ')').setParseAction(lambda s, l, t: BVConst(bv_to_int(t[0]), t[1]))
+ app = (pp.Suppress('(') + pp.Word(pp.alphas + '=') + pp.OneOrMore(expr) +
+ pp.Suppress(')')
+ ).setParseAction(lambda s, l, t: Fn(symbol_to_op[t[0]], t[1:]))
+ expr <<= bconst | bvconst | app | pp.Word(
+ pp.alphas).setParseAction(lambda s, l, t: Var(t[0]))
+ return expr
+
+
+def parse_sort():
+ return (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) +
+ parse_expr() + pp.Suppress(')')
+ ).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]]))
+
+
+def parse_var():
+ return (pp.Suppress('(') + pp.Word(pp.alphas) + parse_sort() +
+ pp.Suppress(')')).setParseAction(lambda s, l, t: (t[0], t[1]))
+
+
+def parse_var_list():
+ return (pp.Suppress('(') + pp.OneOrMore(parse_var()) +
+ pp.Suppress(')')).setParseAction(lambda s, l, t: dict(t[:]))
+
+
+def parse_rules(s):
+ comments = pp.ZeroOrMore(pp.Suppress(pp.cStyleComment))
+
+ rule = (pp.Suppress('(') + pp.Keyword('define-rule') + pp.Word(pp.alphas) +
+ parse_var_list() + parse_expr() + parse_expr() +
+ pp.Suppress(')')).setParseAction(
+ lambda s, l, t: Rule(t[1], t[2], BoolConst(True), t[3], t[4]))
+ rules = pp.OneOrMore(rule)
+ return rules.parseString(s)
diff --git a/src/theory/rewriter/rule.py b/src/theory/rewriter/rule.py
new file mode 100644
index 000000000..f1a0f4b71
--- /dev/null
+++ b/src/theory/rewriter/rule.py
@@ -0,0 +1,7 @@
+class Rule:
+ def __init__(self, name, rvars, cond, lhs, rhs):
+ self.name = name
+ self.rvars = rvars
+ self.cond = cond
+ self.lhs = lhs
+ self.rhs = rhs
diff --git a/src/theory/rewriter/rules/basic.rules b/src/theory/rewriter/rules/basic.rules
new file mode 100644
index 000000000..20b7d3769
--- /dev/null
+++ b/src/theory/rewriter/rules/basic.rules
@@ -0,0 +1,47 @@
+(define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n)))
+ (bvsgt x y)
+ (bvslt y x))
+
+(define-rule LtSelfUlt ((x (_ BitVec n)))
+ (bvult x x)
+ false)
+
+(define-rule LtSelfSlt ((x (_ BitVec n)))
+ (bvslt x x)
+ false)
+
+(define-rule ZeroUlt ((x (_ BitVec n)))
+ (bvult (_ bv0 n) x)
+ (not (= (_ bv0 n) x)))
+
+(define-rule UltZero ((x (_ BitVec n)))
+ (bvult x (_ bv0 n))
+ false)
+
+(define-rule UltOne ((x (_ BitVec n)))
+ (bvult x (_ bv1 n))
+ (= x (_ bv0 n)))
+
+(define-rule UleZero ((x (_ BitVec n)))
+ (bvule x (_ bv0 n))
+ (= x (_ bv0 n)))
+
+(define-rule UleSelf ((x (_ BitVec n)))
+ (bvule x x)
+ true)
+
+(define-rule ZeroUle ((x (_ BitVec n)))
+ (bvule (_ bv0 n) x)
+ true)
+
+(define-rule NotUlt ((x (_ BitVec n)) (y (_ BitVec n)))
+ (not (bvult x y))
+ (bvule y x))
+
+(define-rule NotUle ((x (_ BitVec n)) (y (_ BitVec n)))
+ (not (bvule x y))
+ (bvult y x))
+
+(define-rule NegIdemp ((x (_ BitVec n)))
+ (bvneg (bvneg x))
+ x)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback