From 7a07fefb1b2bbdae0aef64a2dfff5a0e17e0998f Mon Sep 17 00:00:00 2001 From: Andres Noetzli Date: Tue, 5 Oct 2021 14:42:41 -0700 Subject: tmp --- src/rewriter/mkrewrites.py | 48 +++++++++++++++++++++++++++++++--- src/rewriter/node.py | 16 ++++++++++++ src/rewriter/parser.py | 13 ++++----- src/rewriter/rewrite_db.cpp | 5 ++-- src/rewriter/rewrite_db.h | 2 +- src/rewriter/rewrite_db_proof_cons.cpp | 19 +++++++++++--- src/rewriter/rewrite_proof_rule.cpp | 10 +++---- src/rewriter/rewrite_proof_rule.h | 8 +++--- src/rewriter/rule.py | 8 +++--- src/theory/strings/rewrites | 8 +++--- 10 files changed, 107 insertions(+), 30 deletions(-) diff --git a/src/rewriter/mkrewrites.py b/src/rewriter/mkrewrites.py index a0019308f..ae905dad9 100644 --- a/src/rewriter/mkrewrites.py +++ b/src/rewriter/mkrewrites.py @@ -16,6 +16,8 @@ def gen_kind(op): Op.OR: 'OR', Op.IMPLIES: 'IMPLIES', Op.EQ: 'EQUAL', + Op.LAMBDA: 'LAMBDA', + Op.BOUND_VARS: 'BOUND_VAR_LIST', Op.UMINUS: 'UMINUS', Op.PLUS: 'PLUS', Op.MINUS: 'MINUS', @@ -122,7 +124,7 @@ def gen_mk_node(defns, expr): def gen_rewrite_db_rule(defns, rule): fvs_list = ', '.join(bvar.name for bvar in rule.bvars) - fixed_point_arg = 'true' if rule.is_fixed_point else 'false' + fixed_point_arg = gen_mk_node(defns, rule.rhs_context) if rule.rhs_context else 'Node::null()' return f'db.addRule(DslPfRule::{rule.get_enum()}, {{ {fvs_list} }}, {gen_mk_node(defns, rule.lhs)}, {gen_mk_node(defns, rule.rhs)}, {gen_mk_node(defns, rule.cond)}, {fixed_point_arg});' @@ -147,6 +149,8 @@ def type_check(expr): sort = None if expr.op == Op.UMINUS: sort = Sort(BaseSort.Int) + elif expr.op == Op.STRING_LENGTH: + sort = Sort(BaseSort.Int) if sort: sort.is_const = all(child.sort and child.sort.is_const @@ -188,6 +192,43 @@ def validate_rule(rule): type_check(rule.cond) +def preprocess_rule(rule, decls): + if not rule.rhs_context: + return + + # Resolve placeholders + bvar = Var(fresh_name('t'), rule.rhs.sort) + decls.append(bvar) + result = dict() + to_visit = [rule.rhs_context] + while to_visit: + curr = to_visit.pop() + + if isinstance(curr, App) and curr in result: + if result[curr]: + continue + + new_args = [] + for child in curr.children: + new_args.append(result[child]) + + result[curr] = App(curr.op, new_args) + continue + + if isinstance(curr, Placeholder): + result[curr] = bvar + elif isinstance(curr, App): + to_visit.append(curr) + result[curr] = None + else: + result[curr] = curr + + to_visit.extend(curr.children) + + rule.rhs_context = App(Op.LAMBDA, [App(Op.BOUND_VARS, [bvar]), result[rule.rhs_context]]) + type_check(rule.rhs_context) + + def gen_rewrite_db(args): block_tpl = ''' {{ @@ -207,6 +248,7 @@ def gen_rewrite_db(args): for rule in rules: file_decls.extend(rule.bvars) validate_rule(rule) + preprocess_rule(rule, file_decls) rewrites.append(Rewrites(rewrites_file.name, file_decls, rules)) decls.extend(file_decls) @@ -215,12 +257,12 @@ def gen_rewrite_db(args): expr_counts = defaultdict(lambda: 0) to_visit = [ expr for rewrite in rewrites for rule in rewrite.rules - for expr in [rule.cond, rule.lhs, rule.rhs] + for expr in [rule.cond, rule.lhs, rule.rhs, rule.rhs_context] ] while to_visit: curr = to_visit.pop() - if isinstance(curr, Var): + if not curr or isinstance(curr, Var): # Don't generate definitions for variables continue diff --git a/src/rewriter/node.py b/src/rewriter/node.py index 458ba7c03..8676cac00 100644 --- a/src/rewriter/node.py +++ b/src/rewriter/node.py @@ -90,6 +90,8 @@ class Op(Enum): EQ = auto() ITE = auto() + LAMBDA = auto() + BOUND_VARS = auto() ########################################################################### # Strings @@ -197,6 +199,20 @@ class Sort(Node): return self.base == BaseSort.Int +class Placeholder(Node): + def __init__(self): + super().__init__([], None) + + def __eq__(self, other): + return isinstance(other, Placeholder) + + def __hash__(self): + return hash('_') + + def __repr__(self): + return '_' + + class Var(Node): def __init__(self, name, sort=None): super().__init__([], sort) diff --git a/src/rewriter/parser.py b/src/rewriter/parser.py index e69350111..e6c76a285 100644 --- a/src/rewriter/parser.py +++ b/src/rewriter/parser.py @@ -104,6 +104,7 @@ class SymbolTable: self.consts = { 're.none': App(Op.REGEXP_EMPTY, []), 're.allchar': App(Op.REGEXP_SIGMA, []), + '_': Placeholder(), } self.symbols = {} @@ -250,25 +251,25 @@ class Parser: lambda s, l, t: self.var_decl_action(t[0], t[1], t[2:]))) return (pp.Suppress('(') + pp.ZeroOrMore(decl) + pp.Suppress(')')) - def rule_action(self, name, cond, lhs, rhs, is_fixed_point): + def rule_action(self, name, cond, lhs, rhs, is_fixed_point, rhs_context): bvars = self.symbols.symbols.values() self.symbols.pop() - return Rule(name, bvars, cond, lhs, rhs, is_fixed_point) + return Rule(name, bvars, cond, lhs, rhs, is_fixed_point, rhs_context) def parse_rules(self, s): rule = ( pp.Suppress('(') + (pp.Keyword('define-rule*') | pp.Keyword('define-rule')) + - self.symbol() + self.var_list() + self.expr() + self.expr() + + self.symbol() + self.var_list() + self.expr() + self.expr() + pp.Optional(self.expr()) + pp.Suppress(')')).setParseAction(lambda s, l, t: self.rule_action( - t[1], CBool(True), t[2], t[3], t[0] == 'define-rule*')) + t[1], CBool(True), t[2], t[3], t[0] == 'define-rule*', t[4] if len(t) == 5 else None)) cond_rule = ( pp.Suppress('(') + (pp.Keyword('define-cond-rule*') | pp.Keyword('define-cond-rule')) + self.symbol() + self.var_list() + self.expr() + self.expr() + - self.expr() + + self.expr() + pp.Optional(self.expr()) + pp.Suppress(')')).setParseAction(lambda s, l, t: self.rule_action( - t[1], t[2], t[3], t[4], t[0] == 'define-cond-rule*')) + t[1], t[2], t[3], t[4], t[0] == 'define-cond-rule*', t[5] if len(t) == 6 else None)) rules = pp.OneOrMore(rule | cond_rule) + pp.StringEnd() rules.ignore(';' + pp.restOfLine) return rules.parseString(s) diff --git a/src/rewriter/rewrite_db.cpp b/src/rewriter/rewrite_db.cpp index 15252ae9e..cd6abde37 100644 --- a/src/rewriter/rewrite_db.cpp +++ b/src/rewriter/rewrite_db.cpp @@ -118,7 +118,7 @@ void RewriteDb::addRule(DslPfRule id, Node a, Node b, Node cond, - bool isFixedPoint, + Node context, bool isFlatForm) { NodeManager* nm = NodeManager::currentNM(); @@ -154,6 +154,7 @@ void RewriteDb::addRule(DslPfRule id, << " == " << b << std::endl; Assert(a.getType().isComparableTo(b.getType())); Node cr = d_canon.getCanonicalTerm(tmpi, false, false); + context = d_canon.getCanonicalTerm(context, false, false); Node condC = cr[1]; std::vector conds; @@ -229,7 +230,7 @@ void RewriteDb::addRule(DslPfRule id, } // initialize rule - d_rewDbRule[id].init(id, ofvs, cfvs, conds, eqC, isFixedPoint, isFlatForm); + d_rewDbRule[id].init(id, ofvs, cfvs, conds, eqC, context, isFlatForm); d_concToRules[eqC].push_back(id); d_headToRules[eqC[0]].push_back(id); } diff --git a/src/rewriter/rewrite_db.h b/src/rewriter/rewrite_db.h index 6725b864f..819ae0fb8 100644 --- a/src/rewriter/rewrite_db.h +++ b/src/rewriter/rewrite_db.h @@ -56,7 +56,7 @@ class RewriteDb Node a, Node b, Node cond, - bool isFixedPoint, + Node context, bool isFlatForm = false); /** get matches */ void getMatches(Node eq, expr::NotifyMatch* ntm); diff --git a/src/rewriter/rewrite_db_proof_cons.cpp b/src/rewriter/rewrite_db_proof_cons.cpp index 11b1624f1..e1282e0ba 100644 --- a/src/rewriter/rewrite_db_proof_cons.cpp +++ b/src/rewriter/rewrite_db_proof_cons.cpp @@ -715,7 +715,7 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, d_currFixedPointId = rpr.getId(); // check if stgt also rewrites with the same rule? bool continueFixedPoint; - std::vector transEq; + std::vector steps; // start from the source, match again to start the chain. Notice this is // required for uniformity since we want to successfully cache the first // step, independent of the target. @@ -728,13 +728,26 @@ Node RewriteDbProofCons::getRuleConclusion(const RewriteProofRule& rpr, if (!d_currFixedPointConc.isNull()) { // currently avoid accidental loops: arbitrarily bound to 1000 - continueFixedPoint = transEq.size() <= 1000; + continueFixedPoint = steps.size() <= 1000; Assert(d_currFixedPointConc.getKind() == EQUAL); - transEq.push_back(stgt.eqNode(d_currFixedPointConc[1])); + steps.push_back(d_currFixedPointConc[1]); stgt = d_currFixedPointConc[1]; } d_currFixedPointConc = Node::null(); } while (continueFixedPoint); + + std::vector transEq; + Node prev = ssrc; + Node context = rpr.getContext(); + Node placeholder = context[0][0]; + for (Node& step : steps) { + Node stepConc = context[1].substitute(TNode(placeholder), TNode(step)); + stepConc = expr::narySubstitute(stepConc, vars, subs); + transEq.push_back(prev.eqNode(stepConc)); + prev = stepConc; + } + std::cout << transEq << std::endl; + d_currFixedPointId = DslPfRule::FAIL; // add the transistivity rule here if needed if (transEq.size() >= 2) diff --git a/src/rewriter/rewrite_proof_rule.cpp b/src/rewriter/rewrite_proof_rule.cpp index ace56f9fe..2f5ddf400 100644 --- a/src/rewriter/rewrite_proof_rule.cpp +++ b/src/rewriter/rewrite_proof_rule.cpp @@ -40,7 +40,7 @@ bool getDslPfRule(TNode n, DslPfRule& id) } RewriteProofRule::RewriteProofRule() - : d_id(DslPfRule::FAIL), d_isFixedPoint(false), d_isFlatForm(false) + : d_id(DslPfRule::FAIL), d_isFlatForm(false) { } @@ -49,7 +49,7 @@ void RewriteProofRule::init(DslPfRule id, const std::vector& fvs, const std::vector& cond, Node conc, - bool isFixedPoint, + Node context, bool isFlatForm) { // not initialized yet @@ -73,7 +73,7 @@ void RewriteProofRule::init(DslPfRule id, d_obGen.push_back(cc); } d_conc = conc; - d_isFixedPoint = isFixedPoint; + d_context = context; d_isFlatForm = isFlatForm; if (!expr::getListVarContext(conc, d_listVarCtx)) { @@ -97,7 +97,7 @@ void RewriteProofRule::init(DslPfRule id, } } // if fixed point, initialize match utility - if (d_isFixedPoint) + if (d_context != Node::null()) { d_mt.addTerm(conc[0]); } @@ -266,7 +266,7 @@ Node RewriteProofRule::getConclusionFor(const std::vector& ss) const Assert(d_fvs.size() == ss.size()); return expr::narySubstitute(d_conc, d_fvs, ss); } -bool RewriteProofRule::isFixedPoint() const { return d_isFixedPoint; } +bool RewriteProofRule::isFixedPoint() const { return d_context != Node::null(); } bool RewriteProofRule::isFlatForm() const { return d_isFlatForm; } } // namespace rewriter } // namespace cvc5 diff --git a/src/rewriter/rewrite_proof_rule.h b/src/rewriter/rewrite_proof_rule.h index 91279b2e7..c5cc475d5 100644 --- a/src/rewriter/rewrite_proof_rule.h +++ b/src/rewriter/rewrite_proof_rule.h @@ -61,7 +61,7 @@ class RewriteProofRule const std::vector& fvs, const std::vector& cond, Node conc, - bool isFixedPoint, + Node context, bool isFlatForm); /** get id */ DslPfRule getId() const; @@ -71,6 +71,8 @@ class RewriteProofRule const std::vector& getUserVarList() const; /** Get variable list */ const std::vector& getVarList() const; + /** The context that the rule is applied in */ + Node getContext() const { return d_context; } /** Does this rule have conditions? */ bool hasConditions() const; /** Get (declared) conditions */ @@ -134,8 +136,8 @@ class RewriteProofRule std::vector d_obGen; /** The conclusion of the rule (an equality) */ Node d_conc; - /** Is the rule applied until a fixed point is reached? */ - bool d_isFixedPoint; + /** Is the rule applied in some fixed point context? */ + Node d_context; /** Whether the rule is in flat form */ bool d_isFlatForm; /** the ordered list of free variables, provided by the user */ diff --git a/src/rewriter/rule.py b/src/rewriter/rule.py index c3ef1f91e..164bf8dc1 100644 --- a/src/rewriter/rule.py +++ b/src/rewriter/rule.py @@ -1,15 +1,17 @@ class Rule: - def __init__(self, name, bvars, cond, lhs, rhs, is_fixed_point): + def __init__(self, name, bvars, cond, lhs, rhs, is_fixed_point, rhs_context): self.name = name self.bvars = bvars self.cond = cond self.lhs = lhs self.rhs = rhs self.is_fixed_point = is_fixed_point + self.rhs_context = rhs_context def get_enum(self): return self.name.replace('-', '_').upper() def __repr__(self): - bvars_str = ' '.join(str(bvar) for bvar in bvars) - return f"(define-rule {self.name} ({bvars_str}) {self.lhs} {self.rhs})" + bvars_str = ' '.join(str(bvar) for bvar in self.bvars) + rhs_context_str = f' {self.rhs_context}' if self.rhs_context else '' + return f"(define-rule {self.name} ({bvars_str}) {self.lhs} {self.rhs}{rhs_context_str})" diff --git a/src/theory/strings/rewrites b/src/theory/strings/rewrites index 1e0494ca6..8d14e1f22 100644 --- a/src/theory/strings/rewrites +++ b/src/theory/strings/rewrites @@ -23,10 +23,6 @@ (= (str.substr s n m) "") (= s "")) -(define-rule* str-len-concat ((s1 String) (s2 String) (s3 String :list)) - (str.len (str.++ s1 s2 s3)) - (+ (str.len s1) (str.len (str.++ s2 s3)))) - (define-cond-rule str-len-replace-inv ((t String) (s String) (r String)) (= (str.len s) (str.len r)) (str.len (str.replace t s r)) @@ -97,3 +93,7 @@ (define-rule re-concat-star-swap ((xs RegLan :list) (r RegLan) (ys RegLan :list)) (re.++ xs (re.* r) r ys) (re.++ xs r (re.* r) ys)) +(define-rule* str-len-concat-rec ((s1 String) (s2 String) (s3 String :list)) + (str.len (str.++ s1 s2 s3)) + (str.len (str.++ s2 s3)) + (+ (str.len s1) _)) -- cgit v1.2.3