summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2019-09-12 17:31:22 -0500
committerGitHub <noreply@github.com>2019-09-12 17:31:22 -0500
commita117e2b45539a822aa480b90558c2c0da6031dd9 (patch)
tree5006484b1794943a1f049247ebbb2a63cb82dbfb
parentbee3c7f6840e531bc91d990b98f2b331d1f2f82c (diff)
Update to standard implementation of contains term (#3270)
-rw-r--r--src/expr/node_algorithm.cpp77
-rw-r--r--src/expr/node_algorithm.h9
-rw-r--r--src/theory/quantifiers/cegqi/ceg_instantiator.cpp3
-rw-r--r--src/theory/quantifiers/instantiate.cpp4
-rw-r--r--src/theory/quantifiers/quantifiers_rewriter.cpp7
-rw-r--r--src/theory/quantifiers/term_util.cpp44
-rw-r--r--src/theory/quantifiers/term_util.h4
7 files changed, 94 insertions, 54 deletions
diff --git a/src/expr/node_algorithm.cpp b/src/expr/node_algorithm.cpp
index 50ac8297c..59e3d3b03 100644
--- a/src/expr/node_algorithm.cpp
+++ b/src/expr/node_algorithm.cpp
@@ -34,16 +34,26 @@ bool hasSubterm(TNode n, TNode t, bool strict)
toProcess.push_back(n);
+ // incrementally iterate and add to toProcess
for (unsigned i = 0; i < toProcess.size(); ++i)
{
TNode current = toProcess[i];
- if (current.hasOperator() && current.getOperator() == t)
+ for (unsigned j = 0, j_end = current.getNumChildren(); j <= j_end; ++j)
{
- return true;
- }
- for (unsigned j = 0, j_end = current.getNumChildren(); j < j_end; ++j)
- {
- TNode child = current[j];
+ TNode child;
+ // try children then operator
+ if (j < j_end)
+ {
+ child = current[j];
+ }
+ else if (current.hasOperator())
+ {
+ child = current.getOperator();
+ }
+ else
+ {
+ break;
+ }
if (child == t)
{
return true;
@@ -118,6 +128,61 @@ bool hasSubtermMulti(TNode n, TNode t)
return false;
}
+bool hasSubterm(TNode n, const std::vector<Node>& t, bool strict)
+{
+ if (t.empty())
+ {
+ return false;
+ }
+ if (!strict && std::find(t.begin(), t.end(), n) != t.end())
+ {
+ return true;
+ }
+
+ std::unordered_set<TNode, TNodeHashFunction> visited;
+ std::vector<TNode> toProcess;
+
+ toProcess.push_back(n);
+
+ // incrementally iterate and add to toProcess
+ for (unsigned i = 0; i < toProcess.size(); ++i)
+ {
+ TNode current = toProcess[i];
+ for (unsigned j = 0, j_end = current.getNumChildren(); j <= j_end; ++j)
+ {
+ TNode child;
+ // try children then operator
+ if (j < j_end)
+ {
+ child = current[j];
+ }
+ else if (current.hasOperator())
+ {
+ child = current.getOperator();
+ }
+ else
+ {
+ break;
+ }
+ if (std::find(t.begin(), t.end(), child) != t.end())
+ {
+ return true;
+ }
+ if (visited.find(child) != visited.end())
+ {
+ continue;
+ }
+ else
+ {
+ visited.insert(child);
+ toProcess.push_back(child);
+ }
+ }
+ }
+
+ return false;
+}
+
struct HasBoundVarTag
{
};
diff --git a/src/expr/node_algorithm.h b/src/expr/node_algorithm.h
index 17d7d951b..e5a21d565 100644
--- a/src/expr/node_algorithm.h
+++ b/src/expr/node_algorithm.h
@@ -45,6 +45,15 @@ bool hasSubterm(TNode n, TNode t, bool strict = false);
bool hasSubtermMulti(TNode n, TNode t);
/**
+ * Check if the node n has a subterm that occurs in t.
+ * @param n The node to search in
+ * @param t The set of subterms to search for
+ * @param strict If true, a term is not considered to be a subterm of itself
+ * @return true iff there is a term in t that is a subterm in n
+ */
+bool hasSubterm(TNode n, const std::vector<Node>& t, bool strict = false);
+
+/**
* Returns true iff the node n contains a bound variable, that is a node of
* kind BOUND_VARIABLE. This bound variable may or may not be free.
* @param n The node under investigation
diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp
index 104e40d8b..1713c21e2 100644
--- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp
+++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp
@@ -1245,7 +1245,8 @@ Node CegInstantiator::applySubstitution( TypeNode tn, Node n, std::vector< Node
Node nretc = children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( PLUS, children );
nretc = Rewriter::rewrite( nretc );
//ensure that nret does not contain vars
- if( !TermUtil::containsTerms( nretc, vars ) ){
+ if (!expr::hasSubterm(nretc, vars))
+ {
//result is ( nret / pv_prop.d_coeff )
nret = nretc;
}else{
diff --git a/src/theory/quantifiers/instantiate.cpp b/src/theory/quantifiers/instantiate.cpp
index ea90ddd66..c6427a4c4 100644
--- a/src/theory/quantifiers/instantiate.cpp
+++ b/src/theory/quantifiers/instantiate.cpp
@@ -14,6 +14,7 @@
#include "theory/quantifiers/instantiate.h"
+#include "expr/node_algorithm.h"
#include "options/quantifiers_options.h"
#include "smt/smt_statistics_registry.h"
#include "theory/quantifiers/cegqi/inst_strategy_cegqi.h"
@@ -170,8 +171,7 @@ bool Instantiate::addInstantiation(
<< terms[i] << std::endl;
bad_inst = true;
}
- else if (quantifiers::TermUtil::containsTerms(
- terms[i], d_term_util->d_inst_constants[q]))
+ else if (expr::hasSubterm(terms[i], d_term_util->d_inst_constants[q]))
{
Trace("inst") << "***& inst contains inst constants : " << terms[i]
<< std::endl;
diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp
index f5159a630..33da46675 100644
--- a/src/theory/quantifiers/quantifiers_rewriter.cpp
+++ b/src/theory/quantifiers/quantifiers_rewriter.cpp
@@ -1808,9 +1808,10 @@ Node QuantifiersRewriter::computeMiniscoping( std::vector< Node >& args, Node bo
Node newBody = body;
NodeBuilder<> body_split(kind::OR);
NodeBuilder<> tb(kind::OR);
- for( unsigned i=0; i<body.getNumChildren(); i++ ){
- Node trm = body[i];
- if( TermUtil::containsTerms( body[i], args ) ){
+ for (const Node& trm : body)
+ {
+ if (expr::hasSubterm(trm, args))
+ {
tb << trm;
}else{
body_split << trm;
diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp
index ffd808ed3..48dc88537 100644
--- a/src/theory/quantifiers/term_util.cpp
+++ b/src/theory/quantifiers/term_util.cpp
@@ -489,15 +489,17 @@ Node TermUtil::rewriteVtsSymbols( Node n ) {
bool TermUtil::containsVtsTerm( Node n, bool isFree ) {
std::vector< Node > t;
getVtsTerms( t, isFree, false );
- return containsTerms( n, t );
+ return expr::hasSubterm(n, t);
}
bool TermUtil::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
std::vector< Node > t;
getVtsTerms( t, isFree, false );
if( !t.empty() ){
- for( unsigned i=0; i<n.size(); i++ ){
- if( containsTerms( n[i], t ) ){
+ for (const Node& nc : n)
+ {
+ if (expr::hasSubterm(nc, t))
+ {
return true;
}
}
@@ -508,7 +510,7 @@ bool TermUtil::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
bool TermUtil::containsVtsInfinity( Node n, bool isFree ) {
std::vector< Node > t;
getVtsTerms( t, isFree, false, false );
- return containsTerms( n, t );
+ return expr::hasSubterm(n, t);
}
Node TermUtil::ensureType( Node n, TypeNode tn ) {
@@ -524,40 +526,6 @@ Node TermUtil::ensureType( Node n, TypeNode tn ) {
}
}
-bool TermUtil::containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited ) {
- if (visited.find(n) == visited.end())
- {
- if( std::find( t.begin(), t.end(), n )!=t.end() ){
- return true;
- }
- visited[n] = true;
- if (n.hasOperator())
- {
- if (containsTerms2(n.getOperator(), t, visited))
- {
- return true;
- }
- }
- for (const Node& nc : n)
- {
- if (containsTerms2(nc, t, visited))
- {
- return true;
- }
- }
- }
- return false;
-}
-
-bool TermUtil::containsTerms( Node n, std::vector< Node >& t ) {
- if( t.empty() ){
- return false;
- }else{
- std::map< Node, bool > visited;
- return containsTerms2( n, t, visited );
- }
-}
-
int TermUtil::getTermDepth( Node n ) {
if (!n.hasAttribute(TermDepthAttribute()) ){
int maxDepth = -1;
diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h
index b39a4e129..99ea483d9 100644
--- a/src/theory/quantifiers/term_util.h
+++ b/src/theory/quantifiers/term_util.h
@@ -219,8 +219,6 @@ public:
//general utilities
// TODO #1216 : promote these?
private:
- //helper for contains term
- static bool containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited );
/** cache for getTypeValue */
std::unordered_map<TypeNode,
std::unordered_map<int, Node>,
@@ -244,8 +242,6 @@ public:
d_type_value_offset_status;
public:
- /** simple check for contains term, true if contains at least one term in t */
- static bool containsTerms( Node n, std::vector< Node >& t );
/** contains uninterpreted constant */
static bool containsUninterpretedConstant( Node n );
/** get the term depth of n */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback