summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Ozdemir <aozdemir@hmc.edu>2020-04-11 10:16:05 -0700
committerGitHub <noreply@github.com>2020-04-11 12:16:05 -0500
commit0073464e433e80311269ce07e0fa5de417b5eefc (patch)
treee42eafe127a7135057e98ca1c0222c548ac80975
parent4e310461b2e41f9ccf1426797b5d8b58e27bc1c7 (diff)
Add skip predicate to node traversal. (#4222)
Sometime you want to skip specific sub-DAGs when traversing a node. For example, you might be doing a transformation with a cache, and want to skip sub-DAGs that you've already processed. This PR would add a skipIf builder method to NodeDfsIterable, which allows the user to provide a predicate according to which nodes will be omitted from the subsequent traversal.
-rw-r--r--src/expr/node_traversal.cpp33
-rw-r--r--src/expr/node_traversal.h10
-rw-r--r--test/unit/expr/node_traversal_black.h32
3 files changed, 68 insertions, 7 deletions
diff --git a/src/expr/node_traversal.cpp b/src/expr/node_traversal.cpp
index 9e7a82c24..ad1a9ec71 100644
--- a/src/expr/node_traversal.cpp
+++ b/src/expr/node_traversal.cpp
@@ -16,11 +16,14 @@
namespace CVC4 {
-NodeDfsIterator::NodeDfsIterator(TNode n, bool postorder)
+NodeDfsIterator::NodeDfsIterator(TNode n,
+ bool postorder,
+ std::function<bool(TNode)> skipIf)
: d_stack{n},
d_visited(),
d_postorder(postorder),
- d_current(TNode())
+ d_current(TNode()),
+ d_skipIf(skipIf)
{
}
@@ -28,7 +31,8 @@ NodeDfsIterator::NodeDfsIterator(bool postorder)
: d_stack(),
d_visited(),
d_postorder(postorder),
- d_current(TNode())
+ d_current(TNode()),
+ d_skipIf([](TNode) { return false; })
{
}
@@ -64,7 +68,8 @@ bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const
// The stack and current node uniquely represent traversal state. We need not
// use the scheduled node set.
//
- // Users should not compare iterators for traversals of different nodes.
+ // Users should not compare iterators for traversals of different nodes, or
+ // traversals with different skipIfs.
Assert(d_postorder == other.d_postorder);
return d_stack == other.d_stack && d_current == other.d_current;
}
@@ -84,6 +89,12 @@ void NodeDfsIterator::advanceToNextVisit()
if (visitEntry == d_visited.end())
{
// if we haven't pre-visited this node, pre-visit it
+ if (d_skipIf(back))
+ {
+ // actually, skip it if the skip predicate says so...
+ d_stack.pop_back();
+ continue;
+ }
d_visited[back] = false;
d_current = back;
// Use integer underflow to reverse-iterate
@@ -123,7 +134,10 @@ void NodeDfsIterator::initializeIfUninitialized()
}
}
-NodeDfsIterable::NodeDfsIterable(TNode n) : d_node(n), d_postorder(true) {}
+NodeDfsIterable::NodeDfsIterable(TNode n)
+ : d_node(n), d_postorder(true), d_skipIf([](TNode) { return false; })
+{
+}
NodeDfsIterable& NodeDfsIterable::inPostorder()
{
@@ -137,9 +151,16 @@ NodeDfsIterable& NodeDfsIterable::inPreorder()
return *this;
}
+NodeDfsIterable& NodeDfsIterable::skipIf(
+ std::function<bool(TNode)> skipCondition)
+{
+ d_skipIf = skipCondition;
+ return *this;
+}
+
NodeDfsIterator NodeDfsIterable::begin() const
{
- return NodeDfsIterator(d_node, d_postorder);
+ return NodeDfsIterator(d_node, d_postorder, d_skipIf);
}
NodeDfsIterator NodeDfsIterable::end() const
diff --git a/src/expr/node_traversal.h b/src/expr/node_traversal.h
index fffc1d746..1078f08c8 100644
--- a/src/expr/node_traversal.h
+++ b/src/expr/node_traversal.h
@@ -18,6 +18,7 @@
#define CVC4__EXPR__NODE_TRAVERSAL_H
#include <cstddef>
+#include <functional>
#include <iterator>
#include <unordered_map>
#include <vector>
@@ -39,7 +40,7 @@ class NodeDfsIterator
using difference_type = std::ptrdiff_t;
// Construct a traversal iterator beginning at `n`
- NodeDfsIterator(TNode n, bool postorder);
+ NodeDfsIterator(TNode n, bool postorder, std::function<bool(TNode)> skipIf);
// Construct an end-of-traversal iterator
NodeDfsIterator(bool postorder);
@@ -97,6 +98,9 @@ class NodeDfsIterator
// Current referent node. A valid node to visit if non-null.
// Null after construction (but before first access) and at the end.
TNode d_current;
+
+ // When to omit a node and its descendants from the traversal
+ std::function<bool(TNode)> d_skipIf;
};
// Node wrapper that is iterable in DAG post-order
@@ -111,6 +115,9 @@ class NodeDfsIterable
// Modify this iterable to be in pre-order
NodeDfsIterable& inPreorder();
+ // Skip a node (and its descendants) if true.
+ NodeDfsIterable& skipIf(std::function<bool(TNode)> skipCondition);
+
// Move/copy construction and assignment. Destructor.
NodeDfsIterable(NodeDfsIterable&&) = default;
NodeDfsIterable& operator=(NodeDfsIterable&&) = default;
@@ -124,6 +131,7 @@ class NodeDfsIterable
private:
TNode d_node;
bool d_postorder;
+ std::function<bool(TNode)> d_skipIf;
};
} // namespace CVC4
diff --git a/test/unit/expr/node_traversal_black.h b/test/unit/expr/node_traversal_black.h
index b4a7c449c..b751a0999 100644
--- a/test/unit/expr/node_traversal_black.h
+++ b/test/unit/expr/node_traversal_black.h
@@ -161,6 +161,22 @@ class NodePostorderTraversalBlack : public CxxTest::TestSuite
std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
TS_ASSERT_EQUALS(actual, expected);
}
+
+ void testSkipIf()
+ {
+ Node tb = d_nodeManager->mkConst(true);
+ Node eb = d_nodeManager->mkConst(false);
+ Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+ Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+ std::vector<TNode> expected = {top};
+
+ auto traversal = NodeDfsIterable(top).inPostorder().skipIf(
+ [&cnd](TNode n) { return n == cnd; });
+
+ std::vector<TNode> actual;
+ std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+ TS_ASSERT_EQUALS(actual, expected);
+ }
};
class NodePreorderTraversalBlack : public CxxTest::TestSuite
@@ -278,4 +294,20 @@ class NodePreorderTraversalBlack : public CxxTest::TestSuite
std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
TS_ASSERT_EQUALS(actual, expected);
}
+
+ void testSkipIf()
+ {
+ Node tb = d_nodeManager->mkConst(true);
+ Node eb = d_nodeManager->mkConst(false);
+ Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+ Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+ std::vector<TNode> expected = {top, cnd, eb};
+
+ auto traversal = NodeDfsIterable(top).inPreorder().skipIf(
+ [&tb](TNode n) { return n == tb; });
+
+ std::vector<TNode> actual;
+ std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+ TS_ASSERT_EQUALS(actual, expected);
+ }
};
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback