diff options
Diffstat (limited to 'src/theory/quantifiers/term_util.cpp')
-rw-r--r-- | src/theory/quantifiers/term_util.cpp | 91 |
1 files changed, 57 insertions, 34 deletions
diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 7cebf0e1e..b3915bd5d 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -267,51 +267,74 @@ Node TermUtil::substituteInstConstants(Node n, Node q, std::vector<Node>& terms) terms.end()); } -void TermUtil::computeVarContains( Node n, std::vector< Node >& varContains ) { - std::map< Node, bool > visited; - computeVarContains2( n, INST_CONSTANT, varContains, visited ); +void TermUtil::computeInstConstContains(Node n, std::vector<Node>& ics) +{ + computeVarContainsInternal(n, INST_CONSTANT, ics); } -void TermUtil::computeQuantContains( Node n, std::vector< Node >& quantContains ) { - std::map< Node, bool > visited; - computeVarContains2( n, FORALL, quantContains, visited ); +void TermUtil::computeVarContains(Node n, std::vector<Node>& vars) +{ + computeVarContainsInternal(n, BOUND_VARIABLE, vars); } +void TermUtil::computeQuantContains(Node n, std::vector<Node>& quants) +{ + computeVarContainsInternal(n, FORALL, quants); +} -void TermUtil::computeVarContains2( Node n, Kind k, std::vector< Node >& varContains, std::map< Node, bool >& visited ){ - if( visited.find( n )==visited.end() ){ - visited[n] = true; - if( n.getKind()==k ){ - if( std::find( varContains.begin(), varContains.end(), n )==varContains.end() ){ - varContains.push_back( n ); - } - }else{ - if (n.hasOperator()) +void TermUtil::computeVarContainsInternal(Node n, + Kind k, + std::vector<Node>& vars) +{ + std::unordered_set<TNode, TNodeHashFunction> visited; + std::unordered_set<TNode, TNodeHashFunction>::iterator it; + std::vector<TNode> visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + visited.insert(cur); + if (cur.getKind() == k) { - computeVarContains2(n.getOperator(), k, varContains, visited); + if (std::find(vars.begin(), vars.end(), cur) == vars.end()) + { + vars.push_back(cur); + } } - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - computeVarContains2( n[i], k, varContains, visited ); + else + { + if (cur.hasOperator()) + { + visit.push_back(cur.getOperator()); + } + for (const Node& cn : cur) + { + visit.push_back(cn); + } } } - } + } while (!visit.empty()); } -void TermUtil::getVarContains( Node f, std::vector< Node >& pats, std::map< Node, std::vector< Node > >& varContains ){ - for( unsigned i=0; i<pats.size(); i++ ){ - varContains[ pats[i] ].clear(); - getVarContainsNode( f, pats[i], varContains[ pats[i] ] ); - } -} - -void TermUtil::getVarContainsNode( Node f, Node n, std::vector< Node >& varContains ){ - std::vector< Node > vars; - computeVarContains( n, vars ); - for( unsigned j=0; j<vars.size(); j++ ){ - Node v = vars[j]; - if( v.getAttribute(InstConstantAttribute())==f ){ - if( std::find( varContains.begin(), varContains.end(), v )==varContains.end() ){ - varContains.push_back( v ); +void TermUtil::computeInstConstContainsForQuant(Node q, + Node n, + std::vector<Node>& vars) +{ + std::vector<Node> ics; + computeInstConstContains(n, ics); + for (const Node& v : ics) + { + if (v.getAttribute(InstConstantAttribute()) == q) + { + if (std::find(vars.begin(), vars.end(), v) == vars.end()) + { + vars.push_back(v); } } } |