diff options
Diffstat (limited to 'src/expr/node.h')
-rw-r--r-- | src/expr/node.h | 169 |
1 files changed, 154 insertions, 15 deletions
diff --git a/src/expr/node.h b/src/expr/node.h index a40b3fce5..9351293f8 100644 --- a/src/expr/node.h +++ b/src/expr/node.h @@ -11,7 +11,7 @@ ** See the file COPYING in the top-level source directory for licensing ** information.\endverbatim ** - ** \brief Reference-counted encapsulation of a pointer to node information. + ** \brief Reference-counted encapsulation of a pointer to node information ** ** Reference-counted encapsulation of a pointer to node information. **/ @@ -29,7 +29,7 @@ #include <iostream> #include <stdint.h> -#include "type.h" +#include "expr/type.h" #include "expr/kind.h" #include "expr/metakind.h" #include "expr/expr.h" @@ -38,6 +38,8 @@ #include "util/output.h" #include "util/exception.h" #include "util/language.h" +#include "util/utility.h" +#include "util/hash.h" namespace CVC4 { @@ -156,6 +158,16 @@ namespace kind { template <bool ref_count> class NodeTemplate { + // for hash_maps, hash_sets.. + template <bool ref_count1> + struct HashFunction { + size_t operator()(CVC4::NodeTemplate<ref_count1> node) const { + return (size_t) node.getId(); + } + };/* struct HashFunction */ + + typedef HashFunction<false> TNodeHashFunction; + /** * The NodeValue has access to the private constructors, so that the * iterators can can create new nodes. @@ -209,6 +221,30 @@ class NodeTemplate { } } + /** + * Cache-aware, recursive version of substitute() used by the public + * member function with a similar signature. + */ + Node substitute(TNode node, TNode replacement, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const; + + /** + * Cache-aware, recursive version of substitute() used by the public + * member function with a similar signature. + */ + template <class Iterator1, class Iterator2> + Node substitute(Iterator1 nodesBegin, Iterator1 nodesEnd, + Iterator2 replacementsBegin, Iterator2 replacementsEnd, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const; + + /** + * Cache-aware, recursive version of substitute() used by the public + * member function with a similar signature. + */ + template <class Iterator> + Node substitute(Iterator substitutionsBegin, Iterator substitutionsEnd, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const; + public: /** Default constructor, makes a null expression. */ @@ -444,7 +480,7 @@ public: * type checking is not requested, getType() will do the minimum * amount of checking required to return a valid result. * - * @param check whether we should check the type as we compute it + * @param check whether we should check the type as we compute it * (default: false) */ TypeNode getType(bool check = false) const @@ -456,7 +492,9 @@ public: Node substitute(TNode node, TNode replacement) const; /** - * Simultaneous substitution of Nodes. + * Simultaneous substitution of Nodes. Elements in the Iterator1 + * range will be replaced by their corresponding element in the + * Iterator2 range. Both ranges should have the same size. */ template <class Iterator1, class Iterator2> Node substitute(Iterator1 nodesBegin, @@ -465,6 +503,14 @@ public: Iterator2 replacementsEnd) const; /** + * Simultaneous substitution of Nodes. Iterators should be over + * pairs (x,y) for the rewrites [x->y]. + */ + template <class Iterator> + Node substitute(Iterator substitutionsBegin, + Iterator substitutionsEnd) const; + + /** * Returns the kind of this node. * @return the kind */ @@ -1146,39 +1192,81 @@ TypeNode NodeTemplate<ref_count>::getType(bool check) const } template <bool ref_count> -Node NodeTemplate<ref_count>::substitute(TNode node, - TNode replacement) const { +inline Node +NodeTemplate<ref_count>::substitute(TNode node, TNode replacement) const { + std::hash_map<TNode, TNode, TNodeHashFunction> cache; + return substitute(node, replacement, cache); +} + +template <bool ref_count> +Node +NodeTemplate<ref_count>::substitute(TNode node, TNode replacement, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const { + // in cache? + typename std::hash_map<TNode, TNode, TNodeHashFunction>::const_iterator i = cache.find(*this); + if(i != cache.end()) { + return (*i).second; + } + + // otherwise compute NodeBuilder<> nb(getKind()); if(getMetaKind() == kind::metakind::PARAMETERIZED) { // push the operator nb << getOperator(); } - for(TNode::const_iterator i = begin(), + for(const_iterator i = begin(), iend = end(); i != iend; ++i) { if(*i == node) { nb << replacement; } else { - (*i).substitute(node, replacement); + (*i).substitute(node, replacement, cache); } } + + // put in cache Node n = nb; + cache[*this] = n; return n; } template <bool ref_count> template <class Iterator1, class Iterator2> -Node NodeTemplate<ref_count>::substitute(Iterator1 nodesBegin, - Iterator1 nodesEnd, - Iterator2 replacementsBegin, - Iterator2 replacementsEnd) const { +inline Node +NodeTemplate<ref_count>::substitute(Iterator1 nodesBegin, + Iterator1 nodesEnd, + Iterator2 replacementsBegin, + Iterator2 replacementsEnd) const { + std::hash_map<TNode, TNode, TNodeHashFunction> cache; + return substitute(nodesBegin, nodesEnd, + replacementsBegin, replacementsEnd, cache); +} + +template <bool ref_count> +template <class Iterator1, class Iterator2> +Node +NodeTemplate<ref_count>::substitute(Iterator1 nodesBegin, + Iterator1 nodesEnd, + Iterator2 replacementsBegin, + Iterator2 replacementsEnd, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const { + // in cache? + typename std::hash_map<TNode, TNode, TNodeHashFunction>::const_iterator i = cache.find(*this); + if(i != cache.end()) { + return (*i).second; + } + + // otherwise compute Assert( nodesEnd - nodesBegin == replacementsEnd - replacementsBegin, "Substitution iterator ranges must be equal size" ); Iterator1 j = find(nodesBegin, nodesEnd, *this); if(j != nodesEnd) { - return *(replacementsBegin + (j - nodesBegin)); + Node n = *(replacementsBegin + (j - nodesBegin)); + cache[*this] = n; + return n; } else if(getNumChildren() == 0) { + cache[*this] = *this; return *this; } else { NodeBuilder<> nb(getKind()); @@ -1186,14 +1274,65 @@ Node NodeTemplate<ref_count>::substitute(Iterator1 nodesBegin, // push the operator nb << getOperator(); } - for(TNode::const_iterator i = begin(), + for(const_iterator i = begin(), iend = end(); i != iend; ++i) { nb << (*i).substitute(nodesBegin, nodesEnd, - replacementsBegin, replacementsEnd); + replacementsBegin, replacementsEnd, + cache); + } + Node n = nb; + cache[*this] = n; + return n; + } +} + +template <bool ref_count> +template <class Iterator> +inline Node +NodeTemplate<ref_count>::substitute(Iterator substitutionsBegin, + Iterator substitutionsEnd) const { + std::hash_map<TNode, TNode, TNodeHashFunction> cache; + return substitute(substitutionsBegin, substitutionsEnd, cache); +} + +template <bool ref_count> +template <class Iterator> +Node +NodeTemplate<ref_count>::substitute(Iterator substitutionsBegin, + Iterator substitutionsEnd, + std::hash_map<TNode, TNode, TNodeHashFunction>& cache) const { + // in cache? + typename std::hash_map<TNode, TNode, TNodeHashFunction>::const_iterator i = cache.find(*this); + if(i != cache.end()) { + return (*i).second; + } + + // otherwise compute + Iterator j = find_if(substitutionsBegin, substitutionsEnd, + bind2nd(first_equal_to<typename Iterator::value_type::first_type, typename Iterator::value_type::second_type>(), *this)); + if(j != substitutionsEnd) { + Node n = (*j).second; + cache[*this] = n; + return n; + } else if(getNumChildren() == 0) { + cache[*this] = *this; + return *this; + } else { + NodeBuilder<> nb(getKind()); + if(getMetaKind() == kind::metakind::PARAMETERIZED) { + // push the operator + nb << getOperator(); + } + for(const_iterator i = begin(), + iend = end(); + i != iend; + ++i) { + nb << (*i).substitute(substitutionsBegin, substitutionsEnd, cache); } Node n = nb; + cache[*this] = n; return n; } } |