summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndres Noetzli <andres.noetzli@gmail.com>2020-04-13 01:12:03 -0700
committerAndres Noetzli <andres.noetzli@gmail.com>2020-04-13 01:12:03 -0700
commiteb5d13fbb251b6b87579086566c18e9050b858ac (patch)
tree72cdf74b7554692c4b8e58cf0c1b6c0b04ddef73
parent218c08fd77344dfc11b7879143bf94eaa5a3ba7e (diff)
LHS eval + optimization
-rw-r--r--src/proof/proof_manager.cpp6
-rw-r--r--src/proof/rewrite_proof.cpp4
-rw-r--r--src/theory/bv/theory_bv_rewriter.cpp8
-rwxr-xr-xsrc/theory/rewriter/compiler.py79
-rw-r--r--src/theory/rewriter/ir.py62
-rw-r--r--src/theory/rewriter/node.py30
-rw-r--r--src/theory/rewriter/parser.py2
-rw-r--r--src/theory/rewriter/rules/basic.rules12
8 files changed, 159 insertions, 44 deletions
diff --git a/src/proof/proof_manager.cpp b/src/proof/proof_manager.cpp
index 419095cf3..f59c6cd0d 100644
--- a/src/proof/proof_manager.cpp
+++ b/src/proof/proof_manager.cpp
@@ -878,20 +878,20 @@ void LFSCProof::printPreprocessedAssertions(const NodeSet& assertions,
<< std::endl;
os << "(th_let_pf _ ";
- RewriteProof rp;
+ // RewriteProof rp;
if ((*it).getKind() == kind::NOT
&& (*it)[0] == NodeManager::currentNM()->mkConst<bool>(false))
{
os << "t_eq_n_f ";
}
- else if (theory::Rewriter::rewriteWithProof(inputAssertion, &rp) == *it)
+ /*else if (theory::Rewriter::rewriteWithProof(inputAssertion, &rp) == *it)
{
theory::rules::RewriteProofPrinter::printProof(
ProofManager::currentPM()->getTheoryProofEngine(),
rp,
os,
globalLetMap);
- }
+ }*/
else
{
os << "(trust_f (iff ";
diff --git a/src/proof/rewrite_proof.cpp b/src/proof/rewrite_proof.cpp
index e0b7c171c..dfb0e6f4d 100644
--- a/src/proof/rewrite_proof.cpp
+++ b/src/proof/rewrite_proof.cpp
@@ -204,8 +204,8 @@ void RewriteProof::printCachedProofs(TheoryProofEngine* tp,
{
os << std::endl;
os << "(@ let" << iter->second->d_id << " ";
- theory::rules::RewriteProofPrinter::printRewriteProof(
- false, tp, iter->second, os, globalLetMap);
+ //theory::rules::RewriteProofPrinter::printRewriteProof(
+ // false, tp, iter->second, os, globalLetMap);
paren << ")";
}
}
diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp
index c4975d710..25db555fa 100644
--- a/src/theory/bv/theory_bv_rewriter.cpp
+++ b/src/theory/bv/theory_bv_rewriter.cpp
@@ -56,11 +56,13 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
}
RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) {
+ /*
RewriteResponse response = rules::UltZero(node);
if (response.d_node != node)
{
return response;
}
+ */
// reduce common subexpressions on both sides
Node resultNode = LinearRewriteStrategy
@@ -418,11 +420,13 @@ RewriteResponse TheoryBVRewriter::RewriteSub(TNode node, bool prerewrite){
RewriteResponse TheoryBVRewriter::RewriteNeg(TNode node, bool prerewrite) {
Node resultNode = node;
+ /*
RewriteResponse response = rules::NegIdemp(node);
if (response.d_node != node)
{
return response;
}
+ */
resultNode = LinearRewriteStrategy
< RewriteRule<EvalNeg>,
@@ -574,11 +578,13 @@ RewriteResponse TheoryBVRewriter::RewriteRepeat(TNode node, bool prerewrite) {
}
RewriteResponse TheoryBVRewriter::RewriteZeroExtend(TNode node, bool prerewrite){
+ /*
RewriteResponse response = rules::ZeroExtendNZ(node);
if (response.d_node != node)
{
return response;
}
+ */
Node resultNode = LinearRewriteStrategy
< RewriteRule<ZeroExtendEliminate >
@@ -659,11 +665,13 @@ RewriteResponse TheoryBVRewriter::RewriteIntToBV(TNode node, bool prerewrite) {
}
RewriteResponse TheoryBVRewriter::RewriteEqual(TNode node, bool prerewrite) {
+ /*
RewriteResponse response = rules::ReflexivityEq(node);
if (response.d_node != node)
{
return response;
}
+ */
if (prerewrite) {
Node resultNode = LinearRewriteStrategy
diff --git a/src/theory/rewriter/compiler.py b/src/theory/rewriter/compiler.py
index c12b4cb70..cda5e2261 100755
--- a/src/theory/rewriter/compiler.py
+++ b/src/theory/rewriter/compiler.py
@@ -33,7 +33,8 @@ op_to_kind = {
}
op_to_const_eval = {
- Op.BVSHL: '{}.leftShift({})',
+ Op.BVSHL: '({}.leftShift({}))',
+ Op.BVNOT: '(~{})',
Op.PLUS: '({} + {})',
Op.MINUS: '({} - {})',
Op.EQ: '({} == {})',
@@ -82,17 +83,21 @@ op_to_nindex = {
def rule_to_in_ir(rvars, lhs):
def expr_to_ir(expr, path, vars_seen, out_ir, in_index = False):
if isinstance(expr, Fn):
- out_ir.append(
- Assert(
- Fn(Op.EQ,
- [Fn(Op.GET_KIND, [GetChild(path)]),
- KindConst(expr.op)])))
- for i, child in enumerate(expr.children):
- index = i if i < op_to_nindex[expr.op] else i - op_to_nindex[expr.op]
- expr_to_ir(child, path + [index], vars_seen, out_ir, i < op_to_nindex[expr.op])
+ if expr.sort.const:
+ out_ir.append(
+ Assert(Fn(
+ Op.EQ,
+ [GetChild(path), Fn(Op.MK_CONST, [expr])])))
+ else:
+ out_ir.append(
+ Assert(
+ Fn(Op.EQ,
+ [Fn(Op.GET_KIND, [GetChild(path)]),
+ KindConst(expr.op)])))
+ for i, child in enumerate(expr.children):
+ index = i if i < op_to_nindex[expr.op] else i - op_to_nindex[expr.op]
+ expr_to_ir(child, path + [index], vars_seen, out_ir, i < op_to_nindex[expr.op])
- if isinstance(expr.op, Fn):
- pass
elif isinstance(expr, Var):
if expr.name in vars_seen:
@@ -103,16 +108,18 @@ def rule_to_in_ir(rvars, lhs):
if in_index:
index_expr = GetIndex(path)
index_expr.sort = Sort(BaseSort.Int, [])
- out_ir.append(Assign(expr.name, index_expr))
+ out_ir.append(Assign(expr, index_expr))
else:
- out_ir.append(Assign(expr.name, GetChild(path)))
+ out_ir.append(Assign(expr, GetChild(path)))
if expr.sort is not None and expr.sort.base == BaseSort.BitVec:
width = expr.sort.children[0]
if isinstance(width, Var) and not width.name in vars_seen:
bv_size_expr = Fn(Op.BV_SIZE, [GetChild(path)])
- bv_size_expr.sort = Sort(BaseSort.Int, [])
- out_ir.append(Assign(width.name, bv_size_expr))
+ bv_size_expr.sort = Sort(BaseSort.Int, [], True)
+ # TODO: should resolve earlier?
+ width.sort = Sort(BaseSort.Int, [], True)
+ out_ir.append(Assign(width, bv_size_expr))
vars_seen.add(width.name)
vars_seen.add(expr.name)
@@ -206,12 +213,19 @@ def rule_to_out_expr(cfg, next_block, res, expr):
for var, child in zip(new_vars, expr.children):
next_block = rule_to_out_expr(cfg, next_block, var, child)
return next_block
- elif isinstance(expr, BoolConst) or isinstance(expr, BVConst):
- return Fn(Op.MK_CONST, [expr])
+ elif isinstance(expr, BoolConst):
+ assign_block = fresh_name('block')
+ res.sort = Sort(BaseSort.Bool, [])
+ assign = Assign(res, expr)
+ if next_block:
+ cfg[assign_block] = CFGNode([assign], [CFGEdge(BoolConst(True), next_block)])
+ else:
+ assign.expr = Fn(Op.MK_CONST, [assign.expr])
+ cfg[assign_block] = CFGNode([assign], [])
+ return assign_block
elif isinstance(expr, IntConst):
assign_block = fresh_name('block')
res.sort = Sort(BaseSort.Int, [])
- print('{} should be int {}'.format(res, expr.val))
assign = Assign(res, expr)
if next_block:
cfg[assign_block] = CFGNode([assign], [CFGEdge(BoolConst(True), next_block)])
@@ -272,20 +286,20 @@ def expr_to_code(expr):
def sort_to_code(sort):
- if not sort:
+ if not sort or not sort.const:
# TODO: should not happen
return 'Node'
elif sort.base == BaseSort.Int:
return 'uint32_t'
elif sort.base == BaseSort.BitVec:
- return 'BitVector' if sort.const else 'Node'
+ return 'BitVector'
def ir_to_code(match_instrs):
code = []
for instr in match_instrs:
if isinstance(instr, Assign):
- code.append('{} {} = {};'.format(sort_to_code(instr.expr.sort),
+ code.append('{} = {};'.format(
instr.name,
expr_to_code(instr.expr)))
elif isinstance(instr, Assert):
@@ -327,6 +341,7 @@ def gen_rule(rule):
RewriteResponse {}(TNode __node) {{
NodeManager* nm = NodeManager::currentNM();
{}
+ {}
return RewriteResponse(REWRITE_AGAIN, {}, RewriteRule::{});
}}"""
@@ -335,14 +350,20 @@ def gen_rule(rule):
entry = rule_to_out_expr(cfg, None, out_var, rule.rhs)
match_block_name = fresh_name('block')
cfg[match_block_name] = CFGNode(match_block, [CFGEdge(BoolConst(True), entry)])
-
- optimize_cfg(match_block_name, cfg)
- print(cfg_to_str(cfg))
out_ir = rule_to_out_expr(cfg, None, out_var, rule.rhs)
+
+ optimize_cfg(out_var, match_block_name, cfg)
# ir = in_ir + [rule.cond] + out_ir
# opt_ir = optimize_ir(out_var, ir)
+
+ cfg_vars = cfg_collect_vars(cfg)
+ var_decls = ''
+ for var in cfg_vars:
+ var_decls += '{} {};\n'.format(sort_to_code(var.sort), var.name)
+
body = cfg_to_code(match_block_name, cfg)
- result = rule_pattern.format(rule.name, body, out_var,
+
+ result = rule_pattern.format(rule.name, var_decls, body, out_var,
name_to_enum(rule.name))
print(result)
return result
@@ -662,7 +683,7 @@ def type_check(rules):
infer_types(rule.rvars, rule.rhs)
# Ensure that we were able to compute the types for both sides
- assert rule.lhs.sort is not None and rule.rhs.sort is not None
+ assert isinstance(rule.lhs.sort, Sort) and isinstance(rule.rhs.sort, Sort)
def main():
@@ -700,10 +721,10 @@ def main():
args.rulesfile.write(gen_enum(rules))
args.implementationfile.write(gen_rules_implementation(rules))
- args.printerfile.write(gen_proof_printer(rules))
+ #args.printerfile.write(gen_proof_printer(rules))
- for rule in rules:
- rule_to_lfsc(rule)
+ #for rule in rules:
+ # rule_to_lfsc(rule)
if __name__ == "__main__":
diff --git a/src/theory/rewriter/ir.py b/src/theory/rewriter/ir.py
index 76c27af2b..5aabfcc3c 100644
--- a/src/theory/rewriter/ir.py
+++ b/src/theory/rewriter/ir.py
@@ -1,4 +1,6 @@
-from node import collect_vars, BoolConst
+from collections import defaultdict
+
+from node import collect_vars, count_vars, subst, BoolConst, Var
class CFGEdge:
def __init__(self, cond, target):
@@ -30,6 +32,7 @@ class IRNode:
class Assign(IRNode):
def __init__(self, name, expr):
super(IRNode, self).__init__()
+ assert isinstance(name, Var)
self.name = name
self.expr = expr
@@ -60,7 +63,8 @@ def optimize_ir(out_var, instrs):
return opt_instrs
-def optimize_cfg(entry, cfg):
+def optimize_cfg(out_var, entry, cfg):
+ # Merge basic blocks
change = True
while change:
change = False
@@ -71,16 +75,60 @@ def optimize_cfg(entry, cfg):
cfg[label].instrs += next_block.instrs
cfg[label].edges = next_block.edges
+ # Remove unused blocks
not_called = set(cfg.keys())
- not_called.remove(entry)
- for label, node in cfg.items():
- for edge in node.edges:
- if edge.target in not_called:
- not_called.remove(edge.target)
+ to_visit = [entry]
+ while to_visit:
+ curr = to_visit[-1]
+ to_visit.pop()
+
+ not_called.remove(curr)
+ for edge in cfg[curr].edges:
+ to_visit.append(edge.target)
for target in not_called:
del cfg[target]
+ # Inline assignments that are used only once
+ used_count = defaultdict(lambda: 0)
+ substs = dict()
+ for label, node in cfg.items():
+ for instr in node.instrs:
+ if isinstance(instr, Assert):
+ count_vars(used_count, instr.expr)
+ elif isinstance(instr, Assign):
+ count_vars(used_count, instr.expr)
+ substs[instr.name] = instr.expr
+
+ del substs[out_var]
+
+ # Only keep the substitutions for variables that are unused or appear once
+ substs = dict(filter(lambda kv: used_count[kv[0]] <= 1, substs.items()))
+
+ # Apply substitutions to themselves
+ # Note: since each substituted variable can appear at most once, each
+ # substitution cannot be applied more than once, so it should be enough to
+ # do this n times where n is the number of substitutions
+ for i in range(len(substs)):
+ substs = {name:subst(expr, substs) for name, expr in substs.items()}
+
+ # Remove unused instructions and apply substitutions
+ for label, node in cfg.items():
+ node.instrs = list(filter(lambda instr: (not isinstance(instr, Assign)) or (instr.name not in substs), node.instrs))
+ for instr in node.instrs:
+ instr.expr = subst(instr.expr, substs)
+
+ for edge in node.edges:
+ edge.cond = subst(edge.cond, substs)
+
+def cfg_collect_vars(cfg):
+ cfg_vars = set()
+ for label, node in cfg.items():
+ for instr in node.instrs:
+ if isinstance(instr, Assign):
+ cfg_vars.add(instr.name)
+ return cfg_vars
+
def cfg_to_str(cfg):
result = ''
for label, node in cfg.items():
diff --git a/src/theory/rewriter/node.py b/src/theory/rewriter/node.py
index 0a4d9ba8e..b2faa648d 100644
--- a/src/theory/rewriter/node.py
+++ b/src/theory/rewriter/node.py
@@ -157,6 +157,12 @@ class GetChild(Node):
def __repr__(self):
return '__node' + ''.join('[{}]'.format(p) for p in self.path)
+ def __eq__(self, other):
+ return isinstance(other, GetChild) and self.path == other.path
+
+ def __hash__(self):
+ return hash(tuple(self.path))
+
class GetIndex(Node):
def __init__(self, path):
super().__init__([])
@@ -165,6 +171,12 @@ class GetIndex(Node):
def __repr__(self):
return 'index(__node{}, {})'.format(''.join('[{}]'.format(p) for p in self.path[:-1]), self.path[-1])
+ def __eq__(self, other):
+ return isinstance(other, GetIndex) and self.path == other.path
+
+ def __hash__(self):
+ return hash(tuple(self.path))
+
class Fn(Node):
def __init__(self, op, args):
super().__init__(args)
@@ -208,6 +220,24 @@ def collect_vars(node):
return result
+def count_vars(counts, node):
+ if isinstance(node, Var):
+ counts[node] += 1
+ else:
+ for child in node.children:
+ count_vars(counts, child)
+
+
+def subst(node, substs):
+ # TODO: non-destructive substitution?
+ if node in substs:
+ return substs[node]
+ else:
+ new_children = []
+ for child in node.children:
+ new_children.append(subst(child, substs))
+ node.children = new_children
+ return node
def unify_types(t1, t2):
assert t1.base == t2.base
diff --git a/src/theory/rewriter/parser.py b/src/theory/rewriter/parser.py
index 09d4cf141..7a977c774 100644
--- a/src/theory/rewriter/parser.py
+++ b/src/theory/rewriter/parser.py
@@ -93,7 +93,7 @@ def sort():
bv_sort = (pp.Suppress('(') + (pp.Suppress('_') + pp.Keyword('BitVec')) +
expr() + pp.Suppress(')')
).setParseAction(lambda s, l, t: Sort(BaseSort.BitVec, [t[1]]))
- int_sort = pp.Keyword('Int').setParseAction(
+ int_sort = pp.Keyword('Num').setParseAction(
lambda s, l, t: Sort(BaseSort.Int, [], True))
return bv_sort | int_sort
diff --git a/src/theory/rewriter/rules/basic.rules b/src/theory/rewriter/rules/basic.rules
index 6bee44747..cb058656b 100644
--- a/src/theory/rewriter/rules/basic.rules
+++ b/src/theory/rewriter/rules/basic.rules
@@ -1,10 +1,18 @@
-(define-rule ZeroExtendEliminate ((i Int) (x (_ BitVec n)))
+(define-rule NegIdemp ((n Num) (x (_ BitVec n)))
+ (bvneg (bvneg x))
+ x)
+
+(define-rule ZeroExtendEliminate ((i Num) (x (_ BitVec n)))
((_ zero_extend i) x)
(cond
((= i 0) x)
(concat (_ bv 0 i) x)))
-(define-rule SltEliminate ((n Int) (x (_ BitVec n)) (y (_ BitVec n)))
+(define-rule SltEliminate ((n Num) (x (_ BitVec n)) (y (_ BitVec n)))
(bvslt x y)
(let ((pow_two (bvshl (_ bv 1 n) (_ bv (- n 1) n))))
(bvult (bvadd x pow_two) (bvadd y pow_two))))
+
+(define-rule UleMax ((n Num) (x (_ BitVec n)))
+ (bvult x (bvnot (_ bv 0 n)))
+ true)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback