summaryrefslogtreecommitdiff
path: root/src/preprocessing/passes/bool_to_bv.cpp
diff options
context:
space:
mode:
authormakaimann <makaim@stanford.edu>2018-12-10 08:37:11 -0800
committerMathias Preiner <mathias.preiner@gmail.com>2018-12-10 08:37:11 -0800
commite1dc39321cd4ab29b436025badfb05714f5649b3 (patch)
treec2f02cd7370157fbea51ec6602ad174b149cd850 /src/preprocessing/passes/bool_to_bv.cpp
parent7270b2a800c45fa87ef4cdcad8fc353ccb8cd471 (diff)
BoolToBV modes (off, ite, all) (#2530)
Diffstat (limited to 'src/preprocessing/passes/bool_to_bv.cpp')
-rw-r--r--src/preprocessing/passes/bool_to_bv.cpp326
1 files changed, 198 insertions, 128 deletions
diff --git a/src/preprocessing/passes/bool_to_bv.cpp b/src/preprocessing/passes/bool_to_bv.cpp
index c8a59bdc4..252ab941c 100644
--- a/src/preprocessing/passes/bool_to_bv.cpp
+++ b/src/preprocessing/passes/bool_to_bv.cpp
@@ -9,17 +9,17 @@
** All rights reserved. See the file COPYING in the top-level source
** directory for licensing information.\endverbatim
**
- ** \brief The BoolToBv preprocessing pass
+ ** \brief The BoolToBV preprocessing pass
**
**/
#include "preprocessing/passes/bool_to_bv.h"
#include <string>
-#include <unordered_map>
-#include <vector>
+#include "base/map_util.h"
#include "expr/node.h"
+#include "options/bv_options.h"
#include "smt/smt_statistics_registry.h"
#include "theory/rewriter.h"
#include "theory/theory.h"
@@ -30,183 +30,253 @@ namespace passes {
using namespace CVC4::theory;
BoolToBV::BoolToBV(PreprocessingPassContext* preprocContext)
- : PreprocessingPass(preprocContext, "bool-to-bv"),
- d_lowerCache(),
- d_one(bv::utils::mkOne(1)),
- d_zero(bv::utils::mkZero(1)),
- d_statistics(){};
+ : PreprocessingPass(preprocContext, "bool-to-bv"), d_statistics(){};
PreprocessingPassResult BoolToBV::applyInternal(
AssertionPipeline* assertionsToPreprocess)
{
NodeManager::currentResourceManager()->spendResource(
options::preprocessStep());
- std::vector<Node> new_assertions;
- lowerBoolToBv(assertionsToPreprocess->ref(), new_assertions);
- for (unsigned i = 0; i < assertionsToPreprocess->size(); ++i)
+
+ unsigned size = assertionsToPreprocess->size();
+ for (unsigned i = 0; i < size; ++i)
{
- assertionsToPreprocess->replace(i, Rewriter::rewrite(new_assertions[i]));
+ assertionsToPreprocess->replace(
+ i, Rewriter::rewrite(lowerAssertion((*assertionsToPreprocess)[i])));
}
- return PreprocessingPassResult::NO_CONFLICT;
-}
-void BoolToBV::addToLowerCache(TNode term, Node new_term)
-{
- Assert(new_term != Node());
- Assert(!hasLowerCache(term));
- d_lowerCache[term] = new_term;
+ return PreprocessingPassResult::NO_CONFLICT;
}
-Node BoolToBV::getLowerCache(TNode term) const
+Node BoolToBV::fromCache(TNode n) const
{
- Assert(hasLowerCache(term));
- return d_lowerCache.find(term)->second;
+ if (d_lowerCache.find(n) != d_lowerCache.end())
+ {
+ return d_lowerCache.find(n)->second;
+ }
+ return n;
}
-bool BoolToBV::hasLowerCache(TNode term) const
+bool BoolToBV::needToRebuild(TNode n) const
{
- return d_lowerCache.find(term) != d_lowerCache.end();
+ // check if any children were rebuilt
+ for (const Node& nn : n)
+ {
+ if (ContainsKey(d_lowerCache, nn))
+ {
+ return true;
+ }
+ }
+ return false;
}
-Node BoolToBV::lowerNode(TNode current, bool topLevel)
+Node BoolToBV::lowerAssertion(const TNode& a)
{
- Node result;
+ bool optionITE = options::boolToBitvector() == BOOL_TO_BV_ITE;
NodeManager* nm = NodeManager::currentNM();
- if (hasLowerCache(current))
- {
- result = getLowerCache(current);
- }
- else
+ std::vector<TNode> visit;
+ visit.push_back(a);
+ std::unordered_set<TNode, TNodeHashFunction> visited;
+ // for ite mode, keeps track of whether you're in an ite condition
+ // for all mode, unused
+ std::unordered_set<TNode, TNodeHashFunction> ite_cond;
+
+ while (!visit.empty())
{
- if (current.getNumChildren() == 0)
+ TNode n = visit.back();
+ visit.pop_back();
+
+ int numChildren = n.getNumChildren();
+ Kind k = n.getKind();
+ Debug("bool-to-bv") << "BoolToBV::lowerAssertion Post-traversal with " << n
+ << " and visited = " << ContainsKey(visited, n)
+ << std::endl;
+
+ // Mark as visited
+ /* Optimization: if it's a leaf, don't need to wait to do the work */
+ if (!ContainsKey(visited, n) && (numChildren > 0))
{
- if (current.getKind() == kind::CONST_BOOLEAN)
+ visited.insert(n);
+ visit.push_back(n);
+
+ // insert children in reverse order so that they're processed in order
+ // important for rewriting which sorts by node id
+ for (int i = numChildren - 1; i >= 0; --i)
{
- result = (current == bv::utils::mkTrue()) ? d_one : d_zero;
+ visit.push_back(n[i]);
}
- else
+
+ if (optionITE)
{
- result = current;
+ // check for ite-conditions
+ if (k == kind::ITE)
+ {
+ ite_cond.insert(n[0]);
+ }
+ else if (ContainsKey(ite_cond, n))
+ {
+ // being part of an ite condition is inherited from the parent
+ ite_cond.insert(n.begin(), n.end());
+ }
}
}
+ /* Optimization for ite mode */
+ else if (optionITE && !ContainsKey(ite_cond, n) && !needToRebuild(n))
+ {
+ Debug("bool-to-bv")
+ << "BoolToBV::lowerAssertion Skipping because don't need to rebuild: "
+ << n << std::endl;
+ // in ite mode, if you've already visited the node but it's not
+ // in an ite condition and doesn't need to be rebuilt, then
+ // don't need to do anything
+ continue;
+ }
else
{
- Kind kind = current.getKind();
- Kind new_kind = kind;
- switch (kind)
- {
- case kind::EQUAL:
- if (current[0].getType().isBitVector()
- || current[0].getType().isBoolean())
- {
- new_kind = kind::BITVECTOR_COMP;
- }
- break;
- case kind::AND: new_kind = kind::BITVECTOR_AND; break;
- case kind::OR: new_kind = kind::BITVECTOR_OR; break;
- case kind::NOT: new_kind = kind::BITVECTOR_NOT; break;
- case kind::XOR: new_kind = kind::BITVECTOR_XOR; break;
- case kind::IMPLIES: new_kind = kind::BITVECTOR_OR; break;
- case kind::ITE:
- if (current.getType().isBitVector() || current.getType().isBoolean())
- {
- new_kind = kind::BITVECTOR_ITE;
- }
- break;
- case kind::BITVECTOR_ULT: new_kind = kind::BITVECTOR_ULTBV; break;
- case kind::BITVECTOR_SLT: new_kind = kind::BITVECTOR_SLTBV; break;
- case kind::BITVECTOR_ULE:
- case kind::BITVECTOR_UGT:
- case kind::BITVECTOR_UGE:
- case kind::BITVECTOR_SLE:
- case kind::BITVECTOR_SGT:
- case kind::BITVECTOR_SGE:
- // Should have been removed by rewriting.
- Unreachable();
- default: break;
- }
- NodeBuilder<> builder(new_kind);
- if (kind != new_kind)
- {
- ++(d_statistics.d_numTermsLowered);
- }
- if (current.getMetaKind() == kind::metakind::PARAMETERIZED)
- {
- builder << current.getOperator();
- }
- Node converted;
- if (new_kind == kind::ITE)
+ lowerNode(n);
+ }
+ }
+
+ if (fromCache(a).getType().isBitVector())
+ {
+ return nm->mkNode(kind::EQUAL, fromCache(a), bv::utils::mkOne(1));
+ }
+ else
+ {
+ Assert(a == fromCache(a));
+ return a;
+ }
+}
+
+void BoolToBV::lowerNode(const TNode& n)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ Kind k = n.getKind();
+
+ bool all_bv = true;
+ // check if it was able to convert all children to bitvectors
+ for (const Node& nn : n)
+ {
+ all_bv = all_bv && fromCache(nn).getType().isBitVector();
+ if (!all_bv)
+ {
+ break;
+ }
+ }
+
+ if (!all_bv || (n.getNumChildren() == 0))
+ {
+ if ((options::boolToBitvector() == BOOL_TO_BV_ALL)
+ && n.getType().isBoolean())
+ {
+ if (k == kind::CONST_BOOLEAN)
{
- // Special-case ITE because need condition to be Boolean.
- converted = lowerNode(current[0], true);
- builder << converted;
- converted = lowerNode(current[1]);
- builder << converted;
- converted = lowerNode(current[2]);
- builder << converted;
- }
- else if (kind == kind::IMPLIES) {
- // Special-case IMPLIES because needs to be rewritten.
- converted = lowerNode(current[0]);
- builder << nm->mkNode(kind::BITVECTOR_NOT, converted);
- converted = lowerNode(current[1]);
- builder << converted;
+ d_lowerCache[n] = (n == bv::utils::mkTrue()) ? bv::utils::mkOne(1)
+ : bv::utils::mkZero(1);
}
else
{
- for (unsigned i = 0; i < current.getNumChildren(); ++i)
- {
- converted = lowerNode(current[i]);
- builder << converted;
- }
+ d_lowerCache[n] =
+ nm->mkNode(kind::ITE, n, bv::utils::mkOne(1), bv::utils::mkZero(1));
}
- result = builder;
+
+ Debug("bool-to-bv") << "BoolToBV::lowerNode " << n << " =>\n"
+ << fromCache(n) << std::endl;
+ ++(d_statistics.d_numTermsForcedLowered);
+ return;
}
- if (result.getType().isBoolean())
+ else
{
- ++(d_statistics.d_numTermsForcedLowered);
- result = nm->mkNode(kind::ITE, result, d_one, d_zero);
+ // invariant
+ // either one of the children is not a bit-vector or bool
+ // i.e. something that can't be 'forced' to a bitvector
+ // or it's in 'ite' mode which will give up on bools that
+ // can't be converted easily
+
+ Debug("bool-to-bv") << "BoolToBV::lowerNode skipping: " << n << std::endl;
+ return;
}
- addToLowerCache(current, result);
}
- if (topLevel)
+
+ Kind new_kind = k;
+ switch (k)
{
- result = nm->mkNode(kind::EQUAL, result, d_one);
+ case kind::EQUAL: new_kind = kind::BITVECTOR_COMP; break;
+ case kind::AND: new_kind = kind::BITVECTOR_AND; break;
+ case kind::OR: new_kind = kind::BITVECTOR_OR; break;
+ case kind::NOT: new_kind = kind::BITVECTOR_NOT; break;
+ case kind::XOR: new_kind = kind::BITVECTOR_XOR; break;
+ case kind::IMPLIES: new_kind = kind::BITVECTOR_OR; break;
+ case kind::ITE: new_kind = kind::BITVECTOR_ITE; break;
+ case kind::BITVECTOR_ULT: new_kind = kind::BITVECTOR_ULTBV; break;
+ case kind::BITVECTOR_SLT: new_kind = kind::BITVECTOR_SLTBV; break;
+ case kind::BITVECTOR_ULE:
+ case kind::BITVECTOR_UGT:
+ case kind::BITVECTOR_UGE:
+ case kind::BITVECTOR_SLE:
+ case kind::BITVECTOR_SGT:
+ case kind::BITVECTOR_SGE:
+ // Should have been removed by rewriting.
+ Unreachable();
+ default: break;
}
- Assert(result != Node());
- Debug("bool-to-bv") << "BoolToBV::lowerNode " << current << " => \n"
- << result << "\n";
- return result;
-}
-void BoolToBV::lowerBoolToBv(const std::vector<Node>& assertions,
- std::vector<Node>& new_assertions)
-{
- for (unsigned i = 0; i < assertions.size(); ++i)
+ NodeBuilder<> builder(new_kind);
+ if ((options::boolToBitvector() == BOOL_TO_BV_ALL) && (new_kind != k))
+ {
+ ++(d_statistics.d_numTermsLowered);
+ }
+
+ if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
+ {
+ builder << n.getOperator();
+ }
+
+ // special case IMPLIES because needs to be rewritten
+ if (k == kind::IMPLIES)
+ {
+ builder << nm->mkNode(kind::BITVECTOR_NOT, fromCache(n[0]));
+ builder << fromCache(n[1]);
+ }
+ else
{
- Node new_assertion = lowerNode(assertions[i], true);
- new_assertions.push_back(new_assertion);
- Trace("bool-to-bv") << " " << assertions[i] << " => " << new_assertions[i]
- << "\n";
+ for (const Node& nn : n)
+ {
+ builder << fromCache(nn);
+ }
}
+
+ Debug("bool-to-bv") << "BoolToBV::lowerNode " << n << " =>\n"
+ << builder << std::endl;
+
+ d_lowerCache[n] = builder.constructNode();
}
BoolToBV::Statistics::Statistics()
- : d_numTermsLowered("preprocessing::passes::BoolToBV::NumTermsLowered", 0),
- d_numAtomsLowered("preprocessing::passes::BoolToBV::NumAtomsLowered", 0),
+ : d_numIteToBvite("preprocessing::passes::BoolToBV::NumIteToBvite", 0),
+ d_numTermsLowered("preprocessing::passes:BoolToBV::NumTermsLowered", 0),
d_numTermsForcedLowered(
"preprocessing::passes::BoolToBV::NumTermsForcedLowered", 0)
{
- smtStatisticsRegistry()->registerStat(&d_numTermsLowered);
- smtStatisticsRegistry()->registerStat(&d_numAtomsLowered);
- smtStatisticsRegistry()->registerStat(&d_numTermsForcedLowered);
+ smtStatisticsRegistry()->registerStat(&d_numIteToBvite);
+ if (options::boolToBitvector() == BOOL_TO_BV_ALL)
+ {
+ // these statistics wouldn't be correct in the ITE mode,
+ // because it might discard rebuilt nodes if it fails to
+ // convert a bool to width-one bit-vector (never forces)
+ smtStatisticsRegistry()->registerStat(&d_numTermsLowered);
+ smtStatisticsRegistry()->registerStat(&d_numTermsForcedLowered);
+ }
}
BoolToBV::Statistics::~Statistics()
{
- smtStatisticsRegistry()->unregisterStat(&d_numTermsLowered);
- smtStatisticsRegistry()->unregisterStat(&d_numAtomsLowered);
- smtStatisticsRegistry()->unregisterStat(&d_numTermsForcedLowered);
+ smtStatisticsRegistry()->unregisterStat(&d_numIteToBvite);
+ if (options::boolToBitvector() == BOOL_TO_BV_ALL)
+ {
+ smtStatisticsRegistry()->unregisterStat(&d_numTermsLowered);
+ smtStatisticsRegistry()->unregisterStat(&d_numTermsForcedLowered);
+ }
}
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback