diff options
-rw-r--r-- | proofs/signatures/rewrites.plf | 61 | ||||
-rw-r--r-- | src/theory/booleans/theory_bool_rewriter.cpp | 10 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_rewriter.cpp | 18 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_utils.cpp | 27 | ||||
-rw-r--r-- | src/theory/bv/theory_bv_utils.h | 2 | ||||
-rwxr-xr-x | src/theory/rewriter/compiler.py | 246 | ||||
-rw-r--r-- | src/theory/rewriter/node.py | 116 | ||||
-rw-r--r-- | src/theory/rewriter/parser.py | 33 | ||||
-rw-r--r-- | src/theory/rewriter/rules/basic.rules | 452 |
9 files changed, 905 insertions, 60 deletions
diff --git a/proofs/signatures/rewrites.plf b/proofs/signatures/rewrites.plf index 2b5a7e7cc..e8d2382d8 100644 --- a/proofs/signatures/rewrites.plf +++ b/proofs/signatures/rewrites.plf @@ -1,4 +1,20 @@ -(declare trusted_rewrite_f +(declare trusted_formula_rewrite + (! x formula + (! y formula + (! u (th_holds (iff x y)) + (! z formula + (th_holds (iff x z))))))) + +(declare trusted_term_rewrite + (! s sort + (! x (term s) + (! y (term s) + (! u (th_holds (= s x y)) + (! z (term s) + (th_holds (= s x z)))))))) + +; TODO: Side condition that checks the evaluation +(declare const_eval_f (! x formula (! y formula (! u (th_holds (iff x y)) @@ -34,11 +50,48 @@ (! u2 (th_holds (iff y w)) (th_holds (iff (op x y) (op z w))))))))))) -(declare neg_idemp +(declare symm_bvpred + (! op bvpred (! n mpz (! x (term (BitVec n)) (! y (term (BitVec n)) - (! u (th_holds (= (BitVec n) x (bvneg _ (bvneg _ y)))) - (th_holds (= (BitVec n) x y))))))) + (! z (term (BitVec n)) + (! w (term (BitVec n)) + (! u1 (th_holds (= (BitVec n) x z)) + (! u2 (th_holds (= (BitVec n) y w)) + (th_holds (iff (op n x y) (op n z w)))))))))))) + +(declare neg_idemp + (! n mpz + (! lhs (term (BitVec n)) + (! y (term (BitVec n)) + (! u (th_holds (= (BitVec n) lhs (bvneg _ (bvneg _ y)))) + (th_holds (= _ lhs y))))))) + +(declare reflexivity_eq + (! n mpz + (! x formula + (! y (term (BitVec n)) + (! u (th_holds (iff x (= (BitVec n) y y))) + (th_holds (iff x true))))))) + + (declare ult_zero + (! n mpz +(! bv0_n bv +(! original formula +(! x (term (BitVec n)) + (! u (th_holds (iff original (bvult _ x (a_bv _ bv0_n)))) + (th_holds (iff original false)))))))) + + (declare zero_extend_n_z + (! zeamount mpz +(! zebv mpz +(! n mpz +(! m mpz +(! original (term (BitVec n)) +(! m (term (BitVec n)) +(! x (term (BitVec n)) + (! u (th_holds (= _ original (zero_extend zebv m _ x))) + (th_holds (= _ original (concat _ (a_bv _ bv0_m) x)))))))))))) (declare t_eq_n_f (th_holds (iff true (not false)))) diff --git a/src/theory/booleans/theory_bool_rewriter.cpp b/src/theory/booleans/theory_bool_rewriter.cpp index ca22467d4..0cdb124e5 100644 --- a/src/theory/booleans/theory_bool_rewriter.cpp +++ b/src/theory/booleans/theory_bool_rewriter.cpp @@ -15,11 +15,13 @@ ** \todo document this file **/ +#include "theory/booleans/theory_bool_rewriter.h" + #include <algorithm> #include <unordered_set> #include "expr/node_value.h" -#include "theory/booleans/theory_bool_rewriter.h" +#include "theory/rewriter/rules.h" namespace CVC4 { namespace theory { @@ -141,8 +143,10 @@ RewriteResponse TheoryBoolRewriter::preRewrite(TNode n) { switch(n.getKind()) { case kind::NOT: { - if (n[0] == tt) return RewriteResponse(REWRITE_DONE, ff); - if (n[0] == ff) return RewriteResponse(REWRITE_DONE, tt); + if (n[0] == tt) + return RewriteResponse(REWRITE_DONE, ff, rules::RewriteRule::CONST_EVAL); + if (n[0] == ff) + return RewriteResponse(REWRITE_DONE, tt, rules::RewriteRule::CONST_EVAL); if (n[0].getKind() == kind::NOT) return RewriteResponse(REWRITE_AGAIN, n[0][0]); break; } diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 97189aab2..c4975d710 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -56,6 +56,12 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { } RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) { + RewriteResponse response = rules::UltZero(node); + if (response.d_node != node) + { + return response; + } + // reduce common subexpressions on both sides Node resultNode = LinearRewriteStrategy < RewriteRule<EvalUlt>, // if both arguments are constants evaluates @@ -568,6 +574,12 @@ RewriteResponse TheoryBVRewriter::RewriteRepeat(TNode node, bool prerewrite) { } RewriteResponse TheoryBVRewriter::RewriteZeroExtend(TNode node, bool prerewrite){ + RewriteResponse response = rules::ZeroExtendNZ(node); + if (response.d_node != node) + { + return response; + } + Node resultNode = LinearRewriteStrategy < RewriteRule<ZeroExtendEliminate > >::apply(node); @@ -647,6 +659,12 @@ RewriteResponse TheoryBVRewriter::RewriteIntToBV(TNode node, bool prerewrite) { } RewriteResponse TheoryBVRewriter::RewriteEqual(TNode node, bool prerewrite) { + RewriteResponse response = rules::ReflexivityEq(node); + if (response.d_node != node) + { + return response; + } + if (prerewrite) { Node resultNode = LinearRewriteStrategy < RewriteRule<FailEq>, diff --git a/src/theory/bv/theory_bv_utils.cpp b/src/theory/bv/theory_bv_utils.cpp index b98aacb2f..07fe4a49e 100644 --- a/src/theory/bv/theory_bv_utils.cpp +++ b/src/theory/bv/theory_bv_utils.cpp @@ -56,6 +56,33 @@ unsigned getSignExtendAmount(TNode node) return node.getOperator().getConst<BitVectorSignExtend>().d_signExtendAmount; } +uint32_t getIndex(TNode node, size_t index) +{ + if (node.getKind() == kind::BITVECTOR_ZERO_EXTEND) + { + if (index == 0) + { + return node.getOperator() + .getConst<BitVectorZeroExtend>() + .d_zeroExtendAmount; + } + Unreachable(); + } + if (node.getKind() == kind::BITVECTOR_EXTRACT) + { + if (index == 0) + { + return node.getOperator().getConst<BitVectorExtract>().d_high; + } + else if (index == 1) + { + return node.getOperator().getConst<BitVectorExtract>().d_low; + } + Unreachable(); + } + Unreachable(); +} + /* ------------------------------------------------------------------------- */ bool isOnes(TNode node) diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index 4303926f1..0cf50dbdc 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -52,6 +52,8 @@ unsigned getExtractLow(TNode node); /* Get the number of bits by which a given node is extended. */ unsigned getSignExtendAmount(TNode node); +uint32_t getIndex(TNode node, size_t index); + /* Returns true if given node represents a bit-vector comprised of ones. */ bool isOnes(TNode node); diff --git a/src/theory/rewriter/compiler.py b/src/theory/rewriter/compiler.py index 75d73d2d4..3e9d644cc 100755 --- a/src/theory/rewriter/compiler.py +++ b/src/theory/rewriter/compiler.py @@ -10,20 +10,66 @@ from node import * from rule import Rule from parser import parse_rules +from backend_lfsc import collect_params + op_to_kind = { + Op.BVUGT: 'BITVECTOR_UGT', + Op.BVUGE: 'BITVECTOR_UGE', Op.BVSGT: 'BITVECTOR_SGT', + Op.BVSGE: 'BITVECTOR_SGE', Op.BVSLT: 'BITVECTOR_SLT', + Op.BVSLE: 'BITVECTOR_SLE', Op.BVULT: 'BITVECTOR_ULT', Op.BVULE: 'BITVECTOR_ULE', Op.BVNEG: 'BITVECTOR_NEG', + Op.BVADD: 'BITVECTOR_PLUS', + Op.BVSUB: 'BITVECTOR_SUB', + Op.CONCAT: 'BITVECTOR_CONCAT', Op.ZERO_EXTEND: 'BITVECTOR_ZERO_EXTEND', Op.NOT: 'NOT', Op.EQ: 'EQUAL', } +op_to_lfsc = { + Op.BVUGT: 'bvugt', + Op.BVUGE: 'bvuge', + Op.BVSGT: 'bvsgt', + Op.BVSGE: 'bvsge', + Op.BVSLT: 'bvslt', + Op.BVSLE: 'bvsle', + Op.BVULT: 'bvult', + Op.BVULE: 'bvule', + Op.BVNEG: 'bvneg', + Op.BVADD: 'bvadd', + Op.BVSUB: 'bvsub', + Op.CONCAT: 'concat', + Op.ZERO_EXTEND: 'zero_extend', + Op.NOT: 'not', + Op.EQ: '=', +} + + +op_to_nindex = { + Op.BVUGT: 0, + Op.BVUGE: 0, + Op.BVSGT: 0, + Op.BVSGE: 0, + Op.BVSLT: 0, + Op.BVSLE: 0, + Op.BVULT: 0, + Op.BVULE: 0, + Op.BVNEG: 0, + Op.BVADD: 0, + Op.BVSUB: 0, + Op.CONCAT: 0, + Op.ZERO_EXTEND: 1, + Op.NOT: 0, + Op.EQ: 0, +} + def rule_to_in_ir(rvars, lhs): - def expr_to_ir(expr, path, vars_seen, out_ir): + def expr_to_ir(expr, path, vars_seen, out_ir, in_index = False): if isinstance(expr, Fn): out_ir.append( Assert( @@ -31,7 +77,8 @@ def rule_to_in_ir(rvars, lhs): [Fn(Op.GET_KIND, [GetChild(path)]), KindConst(expr.op)]))) for i, child in enumerate(expr.children): - expr_to_ir(child, path + [i], vars_seen, out_ir) + index = i if i < op_to_nindex[expr.op] else i - op_to_nindex[expr.op] + expr_to_ir(child, path + [index], vars_seen, out_ir, i < op_to_nindex[expr.op]) if isinstance(expr.op, Fn): pass @@ -42,10 +89,15 @@ def rule_to_in_ir(rvars, lhs): Assert(Fn(Op.EQ, [Var(expr.name), GetChild(path)]))) else: - out_ir.append(Assign(expr.name, GetChild(path))) + if in_index: + index_expr = GetIndex(path) + index_expr.sort = Sort(BaseSort.Int, []) + out_ir.append(Assign(expr.name, index_expr)) + else: + out_ir.append(Assign(expr.name, GetChild(path))) if expr.sort is not None and expr.sort.base == BaseSort.BitVec: - width = expr.sort.args[0] + width = expr.sort.children[0] if isinstance(width, Var) and not width.name in vars_seen: bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)]) bv_size_expr.sort = Sort(BaseSort.Int, []) @@ -104,6 +156,9 @@ def expr_to_code(expr): elif isinstance(expr, GetChild): path_str = ''.join(['[{}]'.format(i) for i in expr.path]) return '__node{}'.format(path_str) + elif isinstance(expr, GetIndex): + path_str = ''.join(['[{}]'.format(i) for i in expr.path[:-1]]) + return 'bv::utils::getIndex(__node{}, {})'.format(path_str, expr.path[-1]) elif isinstance(expr, BoolConst): return ('true' if expr.val else 'false') elif isinstance(expr, BVConst): @@ -148,7 +203,6 @@ def gen_rule(rule): return RewriteResponse(REWRITE_AGAIN, {}, RewriteRule::{}); }}""" - infer_types(rule.rvars, rule.lhs) in_ir = rule_to_in_ir(rule.rvars, rule.lhs) out_ir = [Assign(out_var, rule_to_out_expr(rule.rhs))] ir = in_ir + [rule.cond] + out_ir @@ -162,13 +216,18 @@ def gen_rule_printer(rule): rule_printer_pattern = """ if (step->d_tag == RewriteRule::{}) {{ - os << "({} _ _ _ "; + os << "({} {} _ _ "; printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); os << ")"; return; }} """ - return rule_printer_pattern.format(name_to_enum(rule.name), name_to_enum(rule.name).lower()) + + # TODO: put in ProofRule instead of recomputing + params = collect_params(rule) + params_str = ' '.join(['_'] * len(params)) + + return rule_printer_pattern.format(name_to_enum(rule.name), name_to_enum(rule.name).lower(), params_str) def gen_proof_printer(rules): @@ -192,23 +251,20 @@ def gen_proof_printer(rules): {{ if (step->d_tag == RewriteRule::NONE && step->d_children.size() == 0) {{ - switch (step->d_original.getKind()) + TypeNode tn = step->d_original.getType(); + if (tn.isBoolean()) {{ - case kind::EQUAL: - {{ - os << "(iff_symm "; - tp->printTheoryTerm(step->d_original.toExpr(), os, globalLetMap); - os << ")"; - return; - }} - - default: - {{ - os << "(refl _ "; - tp->printTheoryTerm(step->d_original.toExpr(), os, globalLetMap); - os << ")"; - return; - }} + os << "(iff_symm "; + tp->printTheoryTerm(step->d_original.toExpr(), os, globalLetMap); + os << ")"; + return; + }} + else + {{ + os << "(refl _ "; + tp->printTheoryTerm(step->d_original.toExpr(), os, globalLetMap); + os << ")"; + return; }} }} else if (step->d_tag == RewriteRule::NONE) @@ -231,6 +287,16 @@ def gen_proof_printer(rules): return; }} + case kind::BITVECTOR_ULT: + {{ + os << "(symm_bvpred bvult _ _ _ _ _ "; + printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); + os << " "; + printRewriteProof(useCache, tp, step->d_children[1], os, globalLetMap); + os << ")"; + return; + }} + case kind::IMPLIES: {{ os << "(symm_formula_op2 impl _ _ _ _ "; @@ -251,6 +317,16 @@ def gen_proof_printer(rules): return; }} + case kind::OR: + {{ + os << "(symm_formula_op2 or _ _ _ _ "; + printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); + os << " "; + printRewriteProof(useCache, tp, step->d_children[1], os, globalLetMap); + os << ")"; + return; + }} + case kind::EQUAL: {{ os << "(symm_equal _ _ _ _ _ "; @@ -261,17 +337,43 @@ def gen_proof_printer(rules): return; }} - default: Unimplemented(); + default: Unimplemented() << "Not supported: " << step->d_original.getKind(); }} }} else if (step->d_tag == RewriteRule::UNKNOWN) {{ - os << "(trusted_rewrite_f _ _ "; - printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); - os << " "; - tp->printTheoryTerm(step->d_rewritten.toExpr(), os, globalLetMap); - os << ")"; - return; + TypeNode tn = step->d_original.getType(); + if (tn.isBoolean()) + {{ + os << "(trusted_formula_rewrite _ _ "; + printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); + os << " "; + tp->printTheoryTerm(step->d_rewritten.toExpr(), os, globalLetMap); + os << ")"; + return; + }} + else + {{ + os << "(trusted_term_rewrite _ _ _ "; + printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); + os << " "; + tp->printTheoryTerm(step->d_rewritten.toExpr(), os, globalLetMap); + os << ")"; + return; + }} + }} + else if (step->d_tag == RewriteRule::CONST_EVAL) + {{ + if (step->d_rewritten.getType().isBoolean()) + {{ + os << "(const_eval_f _ _ "; + printRewriteProof(useCache, tp, step->d_children[0], os, globalLetMap); + os << " "; + tp->printTheoryTerm(step->d_rewritten.toExpr(), os, globalLetMap); + os << ")"; + return; + }} + Unreachable(); }} {} }} @@ -308,6 +410,7 @@ def gen_enum(rules): enum class RewriteRule {{ {}, UNKNOWN, + CONST_EVAL, NONE }}; @@ -352,6 +455,79 @@ def format_cpp(s): return out.decode() +def sort_to_lfsc(sort): + if sort and sort.base == BaseSort.Bool: + return 'formula' + else: # if sort.base == BaseSort.BitVec: + return '(term (BitVec n))' + +def expr_to_lfsc(expr): + if isinstance(expr, Fn): + if expr.op in [Op.ZERO_EXTEND]: + args = [expr_to_lfsc(arg) for arg in expr.children] + return '({} zebv {} _ {})'.format(op_to_lfsc[expr.op], ' '.join(args[:op_to_nindex[expr.op]]), ' '.join(args[op_to_nindex[expr.op]:])) + else: + args = [expr_to_lfsc(arg) for arg in expr.children] + return '({} _ {})'.format(op_to_lfsc[expr.op], ' '.join(args)) + + elif isinstance(expr, Var): + return expr.name + elif isinstance(expr, BVConst): + return '(a_bv _ {})'.format('bv{}_{}'.format(expr.val, expr.bw)) + elif isinstance(expr, BoolConst): + return ('true' if expr.val else 'false') + +def rule_to_lfsc(rule): + rule_pattern = """ + (declare {} + {} + (! u (th_holds {}) + (th_holds {}))){}""" + closing_parens = '' + + rule_name = name_to_enum(rule.name).lower() + + params = collect_params(rule) + + varargs = [] + + for param in params: + sort_str = '' + if param.sort.base == BaseSort.Int: + sort_str = 'mpz' + elif param.sort.base == BaseSort.BitVec: + sort_str = 'bv' + else: + print('Unsupported sort: {}'.format(param.sort_base)) + assert False + varargs.append('(! {} {}'.format(param.name, sort_str)) + closing_parens += ')' + + varargs.append('(! original {}'.format(sort_to_lfsc(rule.lhs.sort))) + closing_parens += ')' + + for name, sort in rule.rvars.items(): + varargs.append('(! {} {}'.format(name, sort_to_lfsc(sort))) + closing_parens += ')' + + if rule.lhs.sort.base == BaseSort.Bool: + lhs = '(iff original {})'.format(expr_to_lfsc(rule.lhs)) + rhs = '(iff original {})'.format(expr_to_lfsc(rule.rhs)) + else: + lhs = '(= _ original {})'.format(expr_to_lfsc(rule.lhs)) + rhs = '(= _ original {})'.format(expr_to_lfsc(rule.rhs)) + + print(rule_pattern.format(rule_name, '\n'.join(varargs), lhs, rhs, closing_parens)) + + +def type_check(rules): + for rule in rules: + infer_types(rule.rvars, rule.lhs) + + # Ensure that we were able to compute the types for the whole left-hand side + assert rule.lhs.sort is not None + + def main(): # (define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) (bvsgt x y) (bvsgt y x)) @@ -383,16 +559,14 @@ def main(): rules = parse_rules(args.infile.read()) + type_check(rules) + args.rulesfile.write(gen_enum(rules)) args.implementationfile.write(gen_rules_implementation(rules)) args.printerfile.write(gen_proof_printer(rules)) - # zero_extend_eliminate = Rule('ZeroExtendEliminate', - # [Var('x', Sort(BaseSort.BitVec, [Var('n', int_sort)]))], - # BoolConst(True), - # Fn(Fn(Op.ZERO_EXTEND, [IntConst(0)]), [Var('x')]), - # Var('x')) - # print(format_cpp(gen_rule(zero_extend_eliminate))) + for rule in rules: + rule_to_lfsc(rule) if __name__ == "__main__": diff --git a/src/theory/rewriter/node.py b/src/theory/rewriter/node.py index 5c3876488..d8c9c0ec9 100644 --- a/src/theory/rewriter/node.py +++ b/src/theory/rewriter/node.py @@ -2,18 +2,32 @@ from enum import Enum, auto class Op(Enum): + BVUGT = auto() + BVUGE = auto() BVSGT = auto() + BVSGE = auto() BVSLT = auto() + BVSLE = auto() BVULT = auto() BVULE = auto() + BVNEG = auto() + BVADD = auto() + BVSUB = auto() + + CONCAT = auto() + ZERO_EXTEND = auto() + + PLUS = auto() + NOT = auto() EQ = auto() GET_KIND = auto() GET_CHILD = auto() + GET_INDEX = auto() MK_NODE = auto() MK_CONST = auto() BV_SIZE = auto() @@ -31,6 +45,15 @@ class Node: self.children = children self.sort = sort + def __eq__(self, other): + if len(self.children) != len(other.children): + return False + + for c1, c2 in zip(self.children, other.children): + if c1 != c2: + return False + + return True class Var(Node): def __init__(self, name, sort=None): @@ -38,11 +61,27 @@ class Var(Node): self.name = name + def __eq__(self, other): + return self.name == other.name + + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return self.name + + class BoolConst(Node): def __init__(self, val): super().__init__([]) self.val = val + def __eq__(self, other): + return self.val == other.val + + def __hash__(self): + return hash(self.val) class BVConst(Node): def __init__(self, val, bw): @@ -50,36 +89,67 @@ class BVConst(Node): self.val = val self.bw = bw + def __eq__(self, other): + return self.val == other.val and self.bw == other.bw + + def __hash__(self): + return hash((self.bw, self.val)) class KindConst(Node): def __init__(self, val): super().__init__([]) self.val = val + def __eq__(self, other): + return self.val == other.val + + def __hash__(self): + return hash(self.val) class IntConst(Node): def __init__(self, val): super().__init__([]) self.val = val + def __eq__(self, other): + return self.val == other.val + + def __hash__(self): + return hash(self.val) class GetChild(Node): def __init__(self, path): super().__init__([]) self.path = path +class GetIndex(Node): + def __init__(self, path): + super().__init__([]) + self.path = path class Fn(Node): def __init__(self, op, args): super().__init__(args) self.op = op + def __eq__(self, other): + return self.op == other.op and super().__eq__(other) + + def __hash__(self): + return hash((self.op, tuple(self.children))) + -class Sort: +class Sort(Node): def __init__(self, base, args): + super().__init__(args) self.base = base - self.args = args + print(base, args) + + def __eq__(self, other): + return self.base == other.base and super().__eq__(other) + def __hash__(self): + return hash((self.base, tupe(self.children))) def collect_vars(node): if isinstance(node, Var): @@ -116,7 +186,45 @@ def infer_types(rvars, node): sort = None if isinstance(node, Fn): - if node.op in [Op.EQ, Op.BVSGT, Op.BVSLT, Op.BVULT]: - sort = unify_types(node.children[0].sort, node.children[1].sort) + if node.op in [ + Op.BVUGT, + Op.BVUGE, + Op.BVSGT, + Op.BVSGE, + Op.BVSLT, + Op.BVSLE, + Op.BVULT, + Op.BVULE]: + assert node.children[0].sort.base == BaseSort.BitVec + assert node.children[0].sort == node.children[1].sort + sort = Sort(BaseSort.Bool, []) + elif node.op in [ + Op.BVADD, + Op.BVSUB]: + assert node.children[0].sort.base == BaseSort.BitVec + assert node.children[0].sort == node.children[1].sort + sort = node.children[0].sort + elif node.op in [Op.CONCAT]: + assert node.children[0].sort.base == BaseSort.BitVec + assert node.children[1].sort.base == BaseSort.BitVec + sort = Sort(BaseSort.BitVec, [Fn(Op.PLUS, [node.children[0].sort.children[0], node.children[1].sort.children[1]])]) + elif node.op in [Op.ZERO_EXTEND]: + assert len(node.children) == 2 + assert node.children[0].sort.base == BaseSort.Int + assert node.children[1].sort.base == BaseSort.BitVec + sort = Sort(BaseSort.BitVec, [Fn(Op.PLUS, [node.children[0], node.children[1].sort.children[0]])]) + elif node.op in [Op.BVNEG]: + assert node.children[0].sort.base == BaseSort.BitVec + sort = node.children[0].sort + elif node.op in [Op.NOT]: + assert node.children[0].sort.base == BaseSort.Bool + sort = Sort(BaseSort.Bool, []) + elif node.op in [Op.EQ]: + assert node.children[0].sort == node.children[1].sort + sort = Sort(BaseSort.Bool, []) + else: + print('Unsupported operator: {}'.format(node.op)) + assert False node.sort = sort + print(node.op, sort) diff --git a/src/theory/rewriter/parser.py b/src/theory/rewriter/parser.py index 0f56edb79..af32a5321 100644 --- a/src/theory/rewriter/parser.py +++ b/src/theory/rewriter/parser.py @@ -4,11 +4,19 @@ from node import * from rule import Rule symbol_to_op = { + 'bvugt': Op.BVUGT, + 'bvuge': Op.BVUGE, 'bvsgt': Op.BVSGT, + 'bvsge': Op.BVSGE, 'bvslt': Op.BVSLT, + 'bvsle': Op.BVSLE, 'bvult': Op.BVULT, 'bvule': Op.BVULE, 'bvneg': Op.BVNEG, + 'bvadd': Op.BVADD, + 'bvsub': Op.BVSUB, + 'concat': Op.CONCAT, + 'zeroextend': Op.ZERO_EXTEND, 'not': Op.NOT, '=': Op.EQ } @@ -27,18 +35,27 @@ def parse_expr(): bvconst = ( pp.Suppress('(') + pp.Suppress('_') + pp.Word(pp.alphanums) + expr + ')').setParseAction(lambda s, l, t: BVConst(bv_to_int(t[0]), t[1])) + indexed_app = (pp.Suppress('(') + pp.Suppress('(') + pp.Suppress('_') + + pp.Word(pp.alphas + '=' + '_') + pp.OneOrMore(expr) + + pp.Suppress(')') + pp.OneOrMore(expr) + + pp.Suppress(')')).setParseAction( + lambda s, l, t: Fn(symbol_to_op[t[0]], t[1:])) app = (pp.Suppress('(') + pp.Word(pp.alphas + '=') + pp.OneOrMore(expr) + pp.Suppress(')') ).setParseAction(lambda s, l, t: Fn(symbol_to_op[t[0]], t[1:])) - expr <<= bconst | bvconst | app | pp.Word( + expr <<= bconst | bvconst | indexed_app | app | pp.Word( pp.alphas).setParseAction(lambda s, l, t: Var(t[0])) return expr def parse_sort(): - return (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) + - parse_expr() + pp.Suppress(')') - ).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]])) + + bv_sort = (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) + + parse_expr() + pp.Suppress(')') + ).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]])) + int_sort = pp.Keyword('Int').setParseAction( + lambda s, l, t: Sort(BaseSort.Int, [])) + return bv_sort | int_sort def parse_var(): @@ -54,9 +71,9 @@ def parse_var_list(): def parse_rules(s): comments = pp.ZeroOrMore(pp.Suppress(pp.cStyleComment)) - rule = (pp.Suppress('(') + pp.Keyword('define-rule') + pp.Word(pp.alphas) + - parse_var_list() + parse_expr() + parse_expr() + - pp.Suppress(')')).setParseAction( - lambda s, l, t: Rule(t[1], t[2], BoolConst(True), t[3], t[4])) + rule = comments + (pp.Suppress('(') + pp.Keyword('define-rule') + pp.Word( + pp.alphas) + parse_var_list() + parse_expr() + parse_expr() + + pp.Suppress(')')).setParseAction(lambda s, l, t: Rule( + t[1], t[2], BoolConst(True), t[3], t[4])) rules = pp.OneOrMore(rule) return rules.parseString(s) diff --git a/src/theory/rewriter/rules/basic.rules b/src/theory/rewriter/rules/basic.rules index 20b7d3769..55f52b3c7 100644 --- a/src/theory/rewriter/rules/basic.rules +++ b/src/theory/rewriter/rules/basic.rules @@ -1,7 +1,256 @@ +/****************************************************************************** + * Operator Elimination + ******************************************************************************/ + +(define-rule UgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvugt x y) + (bvult y x)) + +(define-rule UgeEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvuge x y) + (bvule y x)) + (define-rule SgtEliminate ((x (_ BitVec n)) (y (_ BitVec n))) (bvsgt x y) (bvslt y x)) +(define-rule SgeEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsge x y) + (bvsle y x)) + +(define-rule SltEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvslt x y) + (let (pow_two (bvshl (_ bv1 n) ((_ nat2bv n) (- size 1))) + (bvult (bvadd x pow_two) (bvadd y pow_two))))) + +(define-rule SleEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsle x y) + (not (bvslt y x))) + +(define-rule UleEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvule x y) + (not (bvult y x))) + +(define-rule CompEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvcomp x y) + (ite (= x y) (_ bv 1 1) (_ bv 0 1))) + +(define-rule SubEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsub x y) + (bvadd x (bvneg y))) + +/* RepeatEliminate: Worth it? */ + +(define-rule RotateLeftEliminate ((i Int) (x (_ BitVec n))) + ((_ rotate_left i) x) + (ite (= i 0) x (concat ((_ extract (- n (+ 1 i)) 0) ) ((_ extract (- n 1) (- n i)) x)))) + +(define-rule RotateRightEliminate ((i Int) (x (_ BitVec n))) + ((_ rotate_right i) x) + (ite (= i 0) x (concat ((_ extract (- i 1) 0) ) ((_ extract (- n 1) i) x)))) + +/* BVToNatEliminate: COMPLEX */ + +/* IntToBVEliminate: COMPLEX */ + +(define-rule NandEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvnand x y) + (bvnot (bvand x y))) + +(define-rule NorEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvnor x y) + (bvnot (bvor x y))) + +(define-rule XnorEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvxnor x y) + (bvnot (bvxor x y))) + +(define-rule SdivEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsdiv x y) + (let ((x_lt_0 (= ((_ extract (- n 1) (- n 1)) x) (_ bv 1 1))) + (y_lt_0 (= ((_ extract (- n 1) (- n 1)) y) (_ bv 1 1))) + (abs_x (ite x_lt_0 (bvneg x) x)) + (abs_y (ite y_lt_0 (bvneg y) y)) + (x_udiv_y (bvudiv abs_x abs_y))) + (ite (xor x_lt_0 y_lt_0) + (bvneg x_udiv_y) + x_udiv_y))) + +(define-rule SremEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsdiv x y) + (let ((x_lt_0 (= ((_ extract (- n 1) (- n 1)) x) (_ bv 1 1))) + (y_lt_0 (= ((_ extract (- n 1) (- n 1)) y) (_ bv 1 1))) + (abs_x (ite x_lt_0 (bvneg x) x)) + (abs_y (ite y_lt_0 (bvneg y) y)) + (x_urem_y (bvurem abs_x abs_y))) + (ite x_lt_0 + (bvneg x_urem_y) + x_urem_y))) + +(define-rule SmodEliminate ((x (_ BitVec n)) (y (_ BitVec n))) + (bvsdiv x y) + (let ((msb_x ((_ extract (- n 1) (- n 1)) x)) + (msb_y ((_ extract (- n 1) (- n 1)) y)) + (x_lt_0 (= msb_x (_ bv 1 1))) + (y_lt_0 (= msb_y (_ bv 1 1))) + (abs_x (ite x_lt_0 (bvneg x) x)) + (abs_y (ite y_lt_0 (bvneg y) y)) + (x_urem_y (bvurem abs_x abs_y))) + (ite (= x_urem_y (_ bv 0 n)) + x_urem_y + (ite (and (= msb_x (_ bv 0 1)) (= msb_y (_ bv 0 1))) + x_urem_y + (ite (and (= msb_x (_ bv 1 1)) (= msb_y (_ bv 0 1))) + (bvadd (bvneg x_urem_y) y) + (ite (and (= msb_x (_ bv 0 1)) (= msb_y (_ bv 1 1))) + (bvadd x_urem_y y) + (bvneg x_urem_y))))))) + +(define-rule ZeroExtendEliminate ((i Int) (x (_ BitVec n))) + ((_ zero_extend i) x) + (ite (= i 0) x (concat (_ bv 0 i) x))) + +(define-rule SignExtendEliminate ((i Int) (x (_ BitVec n))) + ((_ zero_extend i) x) + (ite (= i 0) x (concat ((_ repeat i) ((_ extract (- n 1) (- n 1)) x)) x))) + +(define-rule RedorEliminate ((x (_ BitVec n))) + (redor x) + (not (= x (_ bv 0 n)))) + +(define-rule RedandEliminate ((x (_ BitVec n))) + (redand x) + (= x (bvnot (_ bv 0 n)))) + +/****************************************************************************** + * Simplification + ******************************************************************************/ + +(define-rule BvIteConstCond ((x (_ BitVec n)) (y (_ BitVec n))) + (bvite (_ bv m 1) x y) + (ite (= m 1) x y)) + +(define-rule BvIteEqualChildren ((c (_ BitVec 1)) (x (_ BitVec m))) + (bvite c x x) + x) + +(define-rule BvIteConstChildren ((n Int) (m Int) (x (_ BitVec o))) + (bvite x (_ bv n 1) (_ bv m 1)) + (ite (and (= n 1) (= m 0)) + x + (bvnot x))) + +(define-rule BvIteEqualCond1 ((c (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n)) (z (_ BitVec n))) + (bvite c (bvite c x y) z) + (bvite c x z)) + +(define-rule BvIteEqualCond2 ((c (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n)) (z (_ BitVec n))) + (bvite c x (bvite c y z)) + (bvite c x z)) + +(define-rule BvIteMergeThenIf ((c0 (_ BitVec 1)) (c1 (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n))) + (bvite c0 (bvite c1 x y) x) + (bvite (bvand c0 (bvnot c1)) y x)) + +(define-rule BvIteMergeElseIf ((c0 (_ BitVec 1)) (c1 (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n))) + (bvite c0 (bvite c1 x y) y) + (bvite (bvand c0 c1) x y)) + +(define-rule BvIteMergeThenElse ((c0 (_ BitVec 1)) (c1 (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n))) + (bvite c0 x (bvite c1 x y)) + (bvite (bvand (bvnot c0) (bvnot c1)) y x)) + +(define-rule BvIteMergeElseElse ((c0 (_ BitVec 1)) (c1 (_ BitVec 1)) (x (_ BitVec n)) (y (_ BitVec n))) + (bvite c0 x (bvite c1 y x)) + (bvite (bvand (bvnot c0) c1) y x)) + +(define-rule BvComp ((n Int) (x (_ BitVec 1))) + (bvcomp (_ bv n 1) x) + (ite (= n 0) (bvnot x) x)) + +(define-rule ShlByConst (()) + (bvshl x (_ bv n m)) + (ite (= n 0) + x + (ite (>= n m) (_ bv 0 m) (concat ((_ extract (- m (+ n 1)) 0) ) (_ bv 0 n))))) + +(define-rule LshrByConst (()) + (bvshl x (_ bv n m)) + (ite (= n 0) + x + (ite (>= n m) (_ bv 0 m) (concat (_ bv 0 n) ((_ extract (- m 1) n)))))) + +(define-rule AshrByConst (()) + (bvshl x (_ bv n m)) + (ite (= n 0) + x + (let (sign_bit ((_ extract (- m 1) (- m 1)) x)) + (ite (>= n m) ((_ repeat m) sign_bit) (concat ((_ repeat n) sign_bit) ((_ extract (- m (+ n 1)) 0))))))) + +/* Note: rewrite is limited to 2 children */ +(define-rule BitwiseIdempAnd ((x (_ BitVec n))) + (bvand x x) + x) + +/* Note: rewrite is limited to 2 children */ +(define-rule BitwiseIdempOr ((x (_ BitVec n))) + (bvor x x) + x) + +/* Note: rewrite is limited to 2 children */ +(define-rule AndZero ((x (_ BitVec n))) + (bvand x (_ bv 0 n)) + (_ bv 0 n)) + +/* Note: rewrite is limited to 2 children */ +(define-rule AndOne ((x (_ BitVec n))) + (bvand x (bvnot (_ bv 0 n))) + x) + +/* AndOrXorConcatPullUp */ + +/* Note: rewrite is limited to 2 children */ +(define-rule OrZero ((x (_ BitVec n))) + (bvor x (_ bv 0 n)) + x) + +/* Note: rewrite is limited to 2 children */ +(define-rule OrOne ((x (_ BitVec n))) + (bvor x (bvnot (_ bv 0 n))) + (bvnot (_ bv 0 n))) + +/* Note: rewrite is limited to 2 children */ +(define-rule XorDuplicate ((x (_ BitVec n))) + (bvor x x) + (_ bv 0 n)) + +/* FIXME */ +(define-rule XorOne ((x (_ BitVec n))) + (bvxor x (bvnot (_ bv 0 n))) + (bvnot (_ bv 0 n))) + +/* XorZero */ + +/* Note: rewrite is limited to 2 children */ +(define-rule BitwiseNotAnd ((x (_ BitVec n))) + (bvand x (bvnot x)) + (_ bv 0 n)) + +/* Note: rewrite is limited to 2 children */ +(define-rule BitwiseNotOr ((x (_ BitVec n))) + (bvor x (bvnot x)) + (bvnot (_ bv 0 n))) + +/* XorNot: DISABLED */ + +(define-rule NotXor ((x (_ BitVec n)) (xs (_ Set (_ BitVec n)))) + (bvnot (bvxor x xs)) + (bvxor (bvnot x) xs)) + +(define-rule NotIdemp ((x (_ BitVec n))) + (bvnot (bvnot x)) + x) + (define-rule LtSelfUlt ((x (_ BitVec n))) (bvult x x) false) @@ -10,22 +259,40 @@ (bvslt x x) false) +(define-rule LteSelfUle ((x (_ BitVec n))) + (bvule x x) + true) + +(define-rule LteSelfSle ((x (_ BitVec n))) + (bvsle x x) + true) + (define-rule ZeroUlt ((x (_ BitVec n))) - (bvult (_ bv0 n) x) - (not (= (_ bv0 n) x))) + (bvult (_ bv 0 n) x) + (not (= (_ bv 0 n) x))) (define-rule UltZero ((x (_ BitVec n))) - (bvult x (_ bv0 n)) + (bvult x (_ bv 0 n)) false) (define-rule UltOne ((x (_ BitVec n))) - (bvult x (_ bv1 n)) - (= x (_ bv0 n))) + (bvult x (_ bv 1 n)) + (= x (_ bv 0 n))) + +(define-rule SltZero ((x (_ BitVec n))) + (bvslt x (_ bv 0 n)) + (= ((_ extract (- n 1) (- n 1)) x) (_ bv 1 1))) + +/* Note: Duplicate of LtSelfUlt */ +(define-rule UltSelf ((x (_ BitVec n))) + (bvult x x) + false) (define-rule UleZero ((x (_ BitVec n))) (bvule x (_ bv0 n)) (= x (_ bv0 n))) +/* Note: Duplicate of LteSelfUle */ (define-rule UleSelf ((x (_ BitVec n))) (bvule x x) true) @@ -38,10 +305,185 @@ (not (bvult x y)) (bvule y x)) +(define-rule UleMax ((x (_ BitVec n))) + (bvult x (bvnot (_ bv 0 n))) + true) + +(define-rule NotUlt ((x (_ BitVec n)) (y (_ BitVec n))) + (not (bvult x y)) + (bvule y x)) + (define-rule NotUle ((x (_ BitVec n)) (y (_ BitVec n))) (not (bvule x y)) (bvult y x)) +/* MultPow2 */ + +/* ExtractMultLeadingBit: BROKEN? */ + (define-rule NegIdemp ((x (_ BitVec n))) (bvneg (bvneg x)) x) + +/* UdivPow2 */ + +(define-rule UdivZero ((x (_ BitVec n))) + (bvudiv x (_ bv 0 n)) + (bvnot (_ bv 0 n))) + +(define-rule UdivOne ((x (_ BitVec n))) + (bvudiv x (_ bv 1 n)) + x) + +/* UremPow2 */ + +(define-rule UremOne ((x (_ BitVec n))) + (bvudiv x (_ bv 1 n)) + (_ bv 0 n)) + +(define-rule UremSelf ((x (_ BitVec n))) + (bvudiv x x) + (_ bv 0 n)) + +(define-rule ShiftZeroShl ((x (_ BitVec n))) + (bvshl (_ bv 0 n) x) + (_ bv 0 n)) + +(define-rule ShiftZeroLshr ((x (_ BitVec n))) + (bvlshr (_ bv 0 n) x) + (_ bv 0 n)) + +(define-rule ShiftZeroAshr ((x (_ BitVec n))) + (bvashr (_ bv 0 n) x) + (_ bv 0 n)) + +/* BBPlusNeg */ + +/* MergeSignExtend */ + +/* ZeroExtendEqConst */ + +/* SignExtendEqConst */ + +/* ZeroExtendUltConst */ + +/* SignExtendUltConst */ + +/* MultSlice */ + +/* UltPlusOne */ + +/* IsPowerOfTwo */ + +/* MultSltMult */ + +/****************************************************************************** + * Core + ******************************************************************************/ + +/* ConcatFlatten: FLATTEN IMPLICIT */ + +/* ConcatExtractMerge */ + +/* ConcatConstantMerge */ + +(define-rule ExtractWhole ((x (_ BitVec n))) + ((_ extract (- n 1) 0) x) + x) + +/* ExtractConstant: CONST EVAL */ + +/* ExtractConcat */ + +(define-rule ExtractExtract ((i Int) (j Int) (k Int) (l Int) (x (_ BitVec n))) + ((_ extract i j) ((_ extract k l) x)) + ((_ extract (+ k j) (+ l j)) x)) + +/* FailEq: CONST EVAL */ + +(define-rule SimplifyEq ((x (_ BitVec n))) + (= x x) + true) + +/* Do we want that as a rule? */ +(define-rule SimplifyEq ((x (_ BitVec n)) (y (_ BitVec n))) + (= x y) + (< (id x) (id y)) + (= y x)) + +/****************************************************************************** + * Normalization + ******************************************************************************/ + +(define-rule ExtractBitwiseAnd ((i Int) (j Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i j) (bvand xs)) + (bvand (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i j) x))))) + +(define-rule ExtractBitwiseOr ((i Int) (j Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i j) (bvor xs)) + (bvor (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i j) x))))) + +(define-rule ExtractBitwiseXor ((i Int) (j Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i j) (bvxor xs)) + (bvxor (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i j) x))))) + +(define-rule ExtractNot ((i Int) (j Int) (x (_ BitVec n))) + ((_ extract i j) (bvnot x)) + (bvnot ((_ extract i j) x))) + +/* ExtractSignExtend */ + +(define-rule ExtractArithPlus ((i Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i 0) (bvadd xs)) + (bvadd (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i 0) x))))) + +(define-rule ExtractArithMult ((i Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i 0) (bvmul xs)) + (bvmul (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i 0) x))))) + +(define-rule ExtractArith2Plus ((i Int) (j Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i j) (bvadd xs)) + ((_ extract i j) (bvadd (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i 0) x)))))) + +(define-rule ExtractArith2Mult ((i Int) (j Int) (xs (_ Set (_ BitVec n)))) + ((_ extract i j) (bvmul xs)) + ((_ extract i j) (bvmul (map xs (lambda ((_ x (_ BitVec n))) ((_ extract i 0) x)))))) + +/* FlattenAssocCommut: FLATTEN IMPLICIT */ + +/* PlusCombineLikeTerms: COMPLEX */ + +/* MultSimplify */ + +/* MultDistribConst */ + +/* MultDistrib */ + +/* ConcatToMult */ + +/* SolveEq: COMPLEX */ + +/* BitwiseEq */ + +/* Note: original rewrite directly constant folds */ +(define-rule NegMult ((xs (_ Set (_ BitVec n)))) + (bvneg (bvmul xs)) + (bvmul (bvneg (_ bv 1)) xs)) + +(define-rule NegSub ((x (_ BitVec n)) (y (_ BitVec n))) + (bvneg (bvsub x y)) + (bvsub y x)) + +/* NegPlus */ + +/* AndSimplify */ + +/* FlattenAssocCommutNoDuplicates: FLATTEN IMPLICIT */ + +/* OrSimplify */ + +/* XorSimplify */ + +/* BitwiseSlicing */ + +/* NormalizeEqPlusNeg */ |