diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2019-09-16 15:36:12 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-09-16 15:36:12 -0500 |
commit | 558baff63737f1441776ea69b893754ea02f680b (patch) | |
tree | ee61f63bf76a8ea2c83e4b6b6ed3fd74bb5910eb /src/theory/builtin | |
parent | f6cfc98f37bf92ccbf11aad20c1419071d8704f8 (diff) |
Fix HO model construction for functions having Boolean arguments (#3158)
Diffstat (limited to 'src/theory/builtin')
-rw-r--r-- | src/theory/builtin/theory_builtin_rewriter.cpp | 118 |
1 files changed, 88 insertions, 30 deletions
diff --git a/src/theory/builtin/theory_builtin_rewriter.cpp b/src/theory/builtin/theory_builtin_rewriter.cpp index 5b893ffc6..63d08b0ef 100644 --- a/src/theory/builtin/theory_builtin_rewriter.cpp +++ b/src/theory/builtin/theory_builtin_rewriter.cpp @@ -77,7 +77,6 @@ RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) { Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl; Node anode = getArrayRepresentationForLambda( node ); if( !anode.isNull() ){ - anode = Rewriter::rewrite( anode ); Assert( anode.getType().isArray() ); //must get the standard bound variable list Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() ); @@ -198,6 +197,7 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n, TypeNode retType) { Assert( n.getKind()==kind::LAMBDA ); + NodeManager* nm = NodeManager::currentNM(); Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl; Node first_arg = n[0][0]; @@ -207,33 +207,71 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n, for( unsigned i=1; i<n[0].getNumChildren(); i++ ){ args.push_back( n[0][i] ); } - rec_bvl = NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, args ); + rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args); } Trace("builtin-rewrite-debug2") << " process body..." << std::endl; std::vector< Node > conds; std::vector< Node > vals; Node curr = n[1]; - while( curr.getKind()==kind::ITE || curr.getKind()==kind::EQUAL || curr.getKind()==kind::NOT ){ - Trace("builtin-rewrite-debug2") << " process condition : " << curr[0] << std::endl; + Kind ck = curr.getKind(); + while (ck == kind::ITE || ck == kind::EQUAL || ck == kind::NOT + || ck == kind::BOUND_VARIABLE) + { Node index_eq; Node curr_val; Node next; - if( curr.getKind()==kind::ITE ){ + // Each iteration of this loop infers an entry in the function, e.g. it + // has a value under some condition. + + // [1] We infer that the entry has value "curr_val" under condition + // "index_eq". We set "next" to the node that is the remainder of the + // function to process. + if (ck == kind::ITE) + { + Trace("builtin-rewrite-debug2") + << " process condition : " << curr[0] << std::endl; index_eq = curr[0]; curr_val = curr[1]; next = curr[2]; - }else{ - bool pol = curr.getKind()!=kind::NOT; - //Boolean case, e.g. lambda x. (= x v) is lambda x. (ite (= x v) true false) - index_eq = curr.getKind()==kind::NOT ? curr[0] : curr; - curr_val = NodeManager::currentNM()->mkConst( pol ); - next = NodeManager::currentNM()->mkConst( !pol ); } - if( index_eq.getKind()!=kind::EQUAL ){ - // non-equality condition - Trace("builtin-rewrite-debug2") << " ...non-equality condition." << std::endl; - return Node::null(); + else + { + Trace("builtin-rewrite-debug2") + << " process base : " << curr << std::endl; + // Boolean return case, e.g. lambda x. (= x v) becomes + // lambda x. (ite (= x v) true false) + index_eq = curr; + curr_val = nm->mkConst(true); + next = nm->mkConst(false); + } + + // [2] We ensure that "index_eq" is an equality, if possible. + if (index_eq.getKind() != kind::EQUAL) + { + bool pol = index_eq.getKind() != kind::NOT; + Node indexEqAtom = pol ? index_eq : index_eq[0]; + if (indexEqAtom.getKind() == kind::BOUND_VARIABLE) + { + if (!indexEqAtom.getType().isBoolean()) + { + // Catches default case of non-Boolean variable, e.g. + // lambda x : Int. x. In this case, it is not canonical and we fail. + Trace("builtin-rewrite-debug2") + << " ...non-Boolean variable." << std::endl; + return Node::null(); + } + // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as + // lambda x. (ite (= x true) t s) + index_eq = indexEqAtom.eqNode(nm->mkConst(pol)); + } + else + { + // non-equality condition + Trace("builtin-rewrite-debug2") + << " ...non-equality condition." << std::endl; + return Node::null(); + } } else if (Rewriter::rewrite(index_eq) != index_eq) { @@ -242,6 +280,9 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n, return Node::null(); } + // [3] We ensure that "index_eq" is an equality that is equivalent to + // "first_arg" = "curr_index", where curr_index is a constant, and + // "first_arg" is the current argument we are processing, if possible. Node curr_index; for( unsigned r=0; r<2; r++ ){ Node arg = index_eq[r]; @@ -259,25 +300,36 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n, } } } - if( !curr_index.isNull() ){ - if( !rec_bvl.isNull() ){ - curr_val = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr_val ); - curr_val = getArrayRepresentationForLambdaRec(curr_val, retType); - if( curr_val.isNull() ){ - Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl; - return Node::null(); - } - } - Trace("builtin-rewrite-debug2") << " ...condition is index " << curr_val << std::endl; - }else{ - Trace("builtin-rewrite-debug2") << " ...non-constant value." << std::endl; + if (curr_index.isNull()) + { + Trace("builtin-rewrite-debug2") + << " ...could not infer index value." << std::endl; return Node::null(); } + + // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the + // remaining arguments (rec_bvl). + if (!rec_bvl.isNull()) + { + curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val); + curr_val = getArrayRepresentationForLambdaRec(curr_val, retType); + if (curr_val.isNull()) + { + Trace("builtin-rewrite-debug2") + << " ...failed to recursively find value." << std::endl; + return Node::null(); + } + } + Trace("builtin-rewrite-debug2") + << " ...condition is index " << curr_val << std::endl; + + // [5] Add the entry conds.push_back( curr_index ); vals.push_back( curr_val ); - TypeNode vtype = curr_val.getType(); - //recurse + + // we will now process the remainder curr = next; + ck = curr.getKind(); } if( !rec_bvl.isNull() ){ curr = NodeManager::currentNM()->mkNode( kind::LAMBDA, rec_bvl, curr ); @@ -314,7 +366,13 @@ Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n) Assert( n.getKind()==kind::LAMBDA ); // must carry the overall return type to deal with cases like (lambda ((x Int)(y Int)) (ite (= x _) 0.5 0.0)), // where the inner construction for the else case about should be (arraystoreall (Array Int Real) 0.0) - return getArrayRepresentationForLambdaRec(n, n[1].getType()); + Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType()); + if (anode.isNull()) + { + return anode; + } + // must rewrite it to make canonical + return Rewriter::rewrite(anode); } }/* CVC4::theory::builtin namespace */ |