From 0073464e433e80311269ce07e0fa5de417b5eefc Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Sat, 11 Apr 2020 10:16:05 -0700 Subject: Add skip predicate to node traversal. (#4222) Sometime you want to skip specific sub-DAGs when traversing a node. For example, you might be doing a transformation with a cache, and want to skip sub-DAGs that you've already processed. This PR would add a skipIf builder method to NodeDfsIterable, which allows the user to provide a predicate according to which nodes will be omitted from the subsequent traversal. --- src/expr/node_traversal.cpp | 33 +++++++++++++++++++++++++++------ src/expr/node_traversal.h | 10 +++++++++- 2 files changed, 36 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/expr/node_traversal.cpp b/src/expr/node_traversal.cpp index 9e7a82c24..ad1a9ec71 100644 --- a/src/expr/node_traversal.cpp +++ b/src/expr/node_traversal.cpp @@ -16,11 +16,14 @@ namespace CVC4 { -NodeDfsIterator::NodeDfsIterator(TNode n, bool postorder) +NodeDfsIterator::NodeDfsIterator(TNode n, + bool postorder, + std::function skipIf) : d_stack{n}, d_visited(), d_postorder(postorder), - d_current(TNode()) + d_current(TNode()), + d_skipIf(skipIf) { } @@ -28,7 +31,8 @@ NodeDfsIterator::NodeDfsIterator(bool postorder) : d_stack(), d_visited(), d_postorder(postorder), - d_current(TNode()) + d_current(TNode()), + d_skipIf([](TNode) { return false; }) { } @@ -64,7 +68,8 @@ bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const // The stack and current node uniquely represent traversal state. We need not // use the scheduled node set. // - // Users should not compare iterators for traversals of different nodes. + // Users should not compare iterators for traversals of different nodes, or + // traversals with different skipIfs. Assert(d_postorder == other.d_postorder); return d_stack == other.d_stack && d_current == other.d_current; } @@ -84,6 +89,12 @@ void NodeDfsIterator::advanceToNextVisit() if (visitEntry == d_visited.end()) { // if we haven't pre-visited this node, pre-visit it + if (d_skipIf(back)) + { + // actually, skip it if the skip predicate says so... + d_stack.pop_back(); + continue; + } d_visited[back] = false; d_current = back; // Use integer underflow to reverse-iterate @@ -123,7 +134,10 @@ void NodeDfsIterator::initializeIfUninitialized() } } -NodeDfsIterable::NodeDfsIterable(TNode n) : d_node(n), d_postorder(true) {} +NodeDfsIterable::NodeDfsIterable(TNode n) + : d_node(n), d_postorder(true), d_skipIf([](TNode) { return false; }) +{ +} NodeDfsIterable& NodeDfsIterable::inPostorder() { @@ -137,9 +151,16 @@ NodeDfsIterable& NodeDfsIterable::inPreorder() return *this; } +NodeDfsIterable& NodeDfsIterable::skipIf( + std::function skipCondition) +{ + d_skipIf = skipCondition; + return *this; +} + NodeDfsIterator NodeDfsIterable::begin() const { - return NodeDfsIterator(d_node, d_postorder); + return NodeDfsIterator(d_node, d_postorder, d_skipIf); } NodeDfsIterator NodeDfsIterable::end() const diff --git a/src/expr/node_traversal.h b/src/expr/node_traversal.h index fffc1d746..1078f08c8 100644 --- a/src/expr/node_traversal.h +++ b/src/expr/node_traversal.h @@ -18,6 +18,7 @@ #define CVC4__EXPR__NODE_TRAVERSAL_H #include +#include #include #include #include @@ -39,7 +40,7 @@ class NodeDfsIterator using difference_type = std::ptrdiff_t; // Construct a traversal iterator beginning at `n` - NodeDfsIterator(TNode n, bool postorder); + NodeDfsIterator(TNode n, bool postorder, std::function skipIf); // Construct an end-of-traversal iterator NodeDfsIterator(bool postorder); @@ -97,6 +98,9 @@ class NodeDfsIterator // Current referent node. A valid node to visit if non-null. // Null after construction (but before first access) and at the end. TNode d_current; + + // When to omit a node and its descendants from the traversal + std::function d_skipIf; }; // Node wrapper that is iterable in DAG post-order @@ -111,6 +115,9 @@ class NodeDfsIterable // Modify this iterable to be in pre-order NodeDfsIterable& inPreorder(); + // Skip a node (and its descendants) if true. + NodeDfsIterable& skipIf(std::function skipCondition); + // Move/copy construction and assignment. Destructor. NodeDfsIterable(NodeDfsIterable&&) = default; NodeDfsIterable& operator=(NodeDfsIterable&&) = default; @@ -124,6 +131,7 @@ class NodeDfsIterable private: TNode d_node; bool d_postorder; + std::function d_skipIf; }; } // namespace CVC4 -- cgit v1.2.3