summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHaniel Barbosa <hanielbbarbosa@gmail.com>2018-08-24 20:19:14 -0500
committerGitHub <noreply@github.com>2018-08-24 20:19:14 -0500
commit7b9c2529c149a9cd046083af401cbdeadf406804 (patch)
treebbae5bbf4c9538181f01fae61f0e38bbf46dc3d2
parent248f841f37b8b2d514d7308faa8f4573115f82e9 (diff)
Refactor nlExtPurify preprocessing pass (#1963)
-rw-r--r--src/Makefile.am2
-rw-r--r--src/preprocessing/passes/nl_ext_purify.cpp130
-rw-r--r--src/preprocessing/passes/nl_ext_purify.h57
-rw-r--r--src/smt/smt_engine.cpp102
-rw-r--r--test/regress/Makefile.tests1
-rw-r--r--test/regress/regress0/nl/nlExtPurify-test.smt215
6 files changed, 217 insertions, 90 deletions
diff --git a/src/Makefile.am b/src/Makefile.am
index 3b8a12fa5..d399602cb 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -85,6 +85,8 @@ libcvc4_la_SOURCES = \
preprocessing/passes/ite_removal.h \
preprocessing/passes/ite_simp.cpp \
preprocessing/passes/ite_simp.h \
+ preprocessing/passes/nl_ext_purify.cpp \
+ preprocessing/passes/nl_ext_purify.h \
preprocessing/passes/pseudo_boolean_processor.cpp \
preprocessing/passes/pseudo_boolean_processor.h \
preprocessing/passes/bool_to_bv.cpp \
diff --git a/src/preprocessing/passes/nl_ext_purify.cpp b/src/preprocessing/passes/nl_ext_purify.cpp
new file mode 100644
index 000000000..afb092571
--- /dev/null
+++ b/src/preprocessing/passes/nl_ext_purify.cpp
@@ -0,0 +1,130 @@
+/********************* */
+/*! \file nl_ext_purify.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Haniel Barbosa
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The NlExtPurify preprocessing pass
+ **
+ ** Purifies non-linear terms
+ **/
+
+#include "preprocessing/passes/nl_ext_purify.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+using namespace CVC4::theory;
+
+Node NlExtPurify::purifyNlTerms(TNode n,
+ NodeMap& cache,
+ NodeMap& bcache,
+ std::vector<Node>& var_eq,
+ bool beneathMult)
+{
+ if (beneathMult)
+ {
+ NodeMap::iterator find = bcache.find(n);
+ if (find != bcache.end())
+ {
+ return (*find).second;
+ }
+ }
+ else
+ {
+ NodeMap::iterator find = cache.find(n);
+ if (find != cache.end())
+ {
+ return (*find).second;
+ }
+ }
+ Node ret = n;
+ if (n.getNumChildren() > 0)
+ {
+ if (beneathMult
+ && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS))
+ {
+ // don't do it if it rewrites to a constant
+ Node nr = Rewriter::rewrite(n);
+ if (nr.isConst())
+ {
+ // return the rewritten constant
+ ret = nr;
+ }
+ else
+ {
+ // new variable
+ ret = NodeManager::currentNM()->mkSkolem(
+ "__purifyNl_var",
+ n.getType(),
+ "Variable introduced in purifyNl pass");
+ Node np = purifyNlTerms(n, cache, bcache, var_eq, false);
+ var_eq.push_back(np.eqNode(ret));
+ Trace("nl-ext-purify") << "Purify : " << ret << " -> " << np
+ << std::endl;
+ }
+ }
+ else
+ {
+ bool beneathMultNew = beneathMult || n.getKind() == kind::MULT;
+ bool childChanged = false;
+ std::vector<Node> children;
+ for (unsigned i = 0, size = n.getNumChildren(); i < size; ++i)
+ {
+ Node nc = purifyNlTerms(n[i], cache, bcache, var_eq, beneathMultNew);
+ childChanged = childChanged || nc != n[i];
+ children.push_back(nc);
+ }
+ if (childChanged)
+ {
+ ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
+ }
+ }
+ }
+ if (beneathMult)
+ {
+ bcache[n] = ret;
+ }
+ else
+ {
+ cache[n] = ret;
+ }
+ return ret;
+}
+
+NlExtPurify::NlExtPurify(PreprocessingPassContext* preprocContext)
+ : PreprocessingPass(preprocContext, "nl-ext-purify"){};
+
+PreprocessingPassResult NlExtPurify::applyInternal(
+ AssertionPipeline* assertionsToPreprocess)
+{
+ unordered_map<Node, Node, NodeHashFunction> cache;
+ unordered_map<Node, Node, NodeHashFunction> bcache;
+ std::vector<Node> var_eq;
+ unsigned size = assertionsToPreprocess->size();
+ for (unsigned i = 0; i < size; ++i)
+ {
+ Node a = (*assertionsToPreprocess)[i];
+ assertionsToPreprocess->replace(i, purifyNlTerms(a, cache, bcache, var_eq));
+ Trace("nl-ext-purify") << "Purify : " << a << " -> "
+ << (*assertionsToPreprocess)[i] << "\n";
+ }
+ if (!var_eq.empty())
+ {
+ unsigned lastIndex = size - 1;
+ var_eq.insert(var_eq.begin(), (*assertionsToPreprocess)[lastIndex]);
+ assertionsToPreprocess->replace(
+ lastIndex, NodeManager::currentNM()->mkNode(kind::AND, var_eq));
+ }
+ return PreprocessingPassResult::NO_CONFLICT;
+}
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace CVC4
diff --git a/src/preprocessing/passes/nl_ext_purify.h b/src/preprocessing/passes/nl_ext_purify.h
new file mode 100644
index 000000000..8d28b0742
--- /dev/null
+++ b/src/preprocessing/passes/nl_ext_purify.h
@@ -0,0 +1,57 @@
+/********************* */
+/*! \file nl_ext_purify.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Haniel Barbosa
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The NlExtPurify preprocessing pass
+ **
+ ** Purifies non-linear terms by replacing sums under multiplications by fresh
+ ** variables
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H
+#define __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H
+
+#include <unordered_map>
+#include <vector>
+
+#include "expr/node.h"
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+using NodeMap = std::unordered_map<Node, Node, NodeHashFunction>;
+
+class NlExtPurify : public PreprocessingPass
+{
+ public:
+ NlExtPurify(PreprocessingPassContext* preprocContext);
+
+ protected:
+ PreprocessingPassResult applyInternal(
+ AssertionPipeline* assertionsToPreprocess) override;
+
+ private:
+ Node purifyNlTerms(TNode n,
+ NodeMap& cache,
+ NodeMap& bcache,
+ std::vector<Node>& var_eq,
+ bool beneathMult = false);
+};
+
+} // namespace passes
+} // namespace preprocessing
+} // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H */
diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp
index deafcc96c..70e575487 100644
--- a/src/smt/smt_engine.cpp
+++ b/src/smt/smt_engine.cpp
@@ -83,6 +83,7 @@
#include "preprocessing/passes/int_to_bv.h"
#include "preprocessing/passes/ite_removal.h"
#include "preprocessing/passes/ite_simp.h"
+#include "preprocessing/passes/nl_ext_purify.h"
#include "preprocessing/passes/pseudo_boolean_processor.h"
#include "preprocessing/passes/quantifiers_preprocess.h"
#include "preprocessing/passes/real_to_int.h"
@@ -567,14 +568,6 @@ class SmtEnginePrivate : public NodeManagerListener {
bool nonClausalSimplify();
/**
- * Performs static learning on the assertions.
- */
- void staticLearning();
-
- Node realToInt(TNode n, NodeToNodeHashMap& cache, std::vector< Node >& var_eq);
- Node purifyNlTerms(TNode n, NodeToNodeHashMap& cache, NodeToNodeHashMap& bcache, std::vector< Node >& var_eq, bool beneathMult = false);
-
- /**
* Helper function to fix up assertion list to restore invariants needed after
* ite removal.
*/
@@ -790,7 +783,7 @@ class SmtEnginePrivate : public NodeManagerListener {
/** Process a user push.
*/
void notifyPush() {
-
+
}
/**
@@ -872,13 +865,13 @@ class SmtEnginePrivate : public NodeManagerListener {
std::ostream* getReplayLog() const {
return d_managedReplayLog.getReplayLog();
}
-
+
//------------------------------- expression names
// implements setExpressionName, as described in smt_engine.h
void setExpressionName(Expr e, std::string name) {
d_exprNames[Node::fromExpr(e)] = name;
}
-
+
// implements getExpressionName, as described in smt_engine.h
bool getExpressionName(Expr e, std::string& name) const {
context::CDHashMap< Node, std::string, NodeHashFunction >::const_iterator it = d_exprNames.find(e);
@@ -2657,6 +2650,8 @@ void SmtEnginePrivate::finishInit()
new IntToBV(d_preprocessingPassContext.get()));
std::unique_ptr<ITESimp> iteSimp(
new ITESimp(d_preprocessingPassContext.get()));
+ std::unique_ptr<NlExtPurify> nlExtPurify(
+ new NlExtPurify(d_preprocessingPassContext.get()));
std::unique_ptr<QuantifiersPreprocess> quantifiersPreprocess(
new QuantifiersPreprocess(d_preprocessingPassContext.get()));
std::unique_ptr<PseudoBooleanProcessor> pbProc(
@@ -2700,6 +2695,8 @@ void SmtEnginePrivate::finishInit()
std::move(globalNegate));
d_preprocessingPassRegistry.registerPass("int-to-bv", std::move(intToBV));
d_preprocessingPassRegistry.registerPass("ite-simp", std::move(iteSimp));
+ d_preprocessingPassRegistry.registerPass("nl-ext-purify",
+ std::move(nlExtPurify));
d_preprocessingPassRegistry.registerPass("quantifiers-preprocess",
std::move(quantifiersPreprocess));
d_preprocessingPassRegistry.registerPass("pseudo-boolean-processor",
@@ -2712,7 +2709,7 @@ void SmtEnginePrivate::finishInit()
std::move(sepSkolemEmp));
d_preprocessingPassRegistry.registerPass("sort-inference",
std::move(sortInfer));
- d_preprocessingPassRegistry.registerPass("static-learning",
+ d_preprocessingPassRegistry.registerPass("static-learning",
std::move(staticLearning));
d_preprocessingPassRegistry.registerPass("sygus-infer",
std::move(sygusInfer));
@@ -2903,68 +2900,6 @@ Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map<Node, Node, Node
return result.top();
}
-typedef std::unordered_map<Node, Node, NodeHashFunction> NodeMap;
-
-Node SmtEnginePrivate::purifyNlTerms(TNode n, NodeMap& cache, NodeMap& bcache, std::vector< Node >& var_eq, bool beneathMult) {
- if( beneathMult ){
- NodeMap::iterator find = bcache.find(n);
- if (find != bcache.end()) {
- return (*find).second;
- }
- }else{
- NodeMap::iterator find = cache.find(n);
- if (find != cache.end()) {
- return (*find).second;
- }
- }
- Node ret = n;
- if( n.getNumChildren()>0 ){
- if (beneathMult
- && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS))
- {
- // don't do it if it rewrites to a constant
- Node nr = Rewriter::rewrite(n);
- if (nr.isConst())
- {
- // return the rewritten constant
- ret = nr;
- }
- else
- {
- // new variable
- ret = NodeManager::currentNM()->mkSkolem(
- "__purifyNl_var",
- n.getType(),
- "Variable introduced in purifyNl pass");
- Node np = purifyNlTerms(n, cache, bcache, var_eq, false);
- var_eq.push_back(np.eqNode(ret));
- Trace("nl-ext-purify")
- << "Purify : " << ret << " -> " << np << std::endl;
- }
- }
- else
- {
- bool beneathMultNew = beneathMult || n.getKind()==kind::MULT;
- bool childChanged = false;
- std::vector< Node > children;
- for( unsigned i=0; i<n.getNumChildren(); i++ ){
- Node nc = purifyNlTerms( n[i], cache, bcache, var_eq, beneathMultNew );
- childChanged = childChanged || nc!=n[i];
- children.push_back( nc );
- }
- if( childChanged ){
- ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
- }
- }
- }
- if( beneathMult ){
- bcache[n] = ret;
- }else{
- cache[n] = ret;
- }
- return ret;
-}
-
// do dumping (before/after any preprocessing pass)
static void dumpAssertions(const char* key,
const AssertionPipeline& assertionList) {
@@ -4037,20 +3972,7 @@ void SmtEnginePrivate::processAssertions() {
}
if( options::nlExtPurify() ){
- unordered_map<Node, Node, NodeHashFunction> cache;
- unordered_map<Node, Node, NodeHashFunction> bcache;
- std::vector< Node > var_eq;
- for (unsigned i = 0; i < d_assertions.size(); ++ i) {
- Node a = d_assertions[i];
- d_assertions.replace(i, purifyNlTerms(a, cache, bcache, var_eq));
- Trace("nl-ext-purify")
- << "Purify : " << a << " -> " << d_assertions[i] << std::endl;
- }
- if( !var_eq.empty() ){
- unsigned lastIndex = d_assertions.size()-1;
- var_eq.insert( var_eq.begin(), d_assertions[lastIndex] );
- d_assertions.replace(lastIndex, NodeManager::currentNM()->mkNode( kind::AND, var_eq ) );
- }
+ d_preprocessingPassRegistry.getPass("nl-ext-purify")->apply(&d_assertions);
}
if( options::ceGuidedInst() ){
@@ -5527,7 +5449,7 @@ Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict)
Assert( inst_qs.size()<=1 );
Node ret_n;
if( inst_qs.size()==1 ){
- Node top_q = inst_qs[0];
+ Node top_q = inst_qs[0];
//Node top_q = Rewriter::rewrite( nn_e ).negate();
Assert( top_q.getKind()==kind::FORALL );
Trace("smt-qe") << "Get qe for " << top_q << std::endl;
@@ -5950,7 +5872,7 @@ void SmtEngine::setReplayStream(ExprStream* replayStream) {
AlwaysAssert(!d_fullyInited,
"Cannot set replay stream once fully initialized");
d_replayStream = replayStream;
-}
+}
bool SmtEngine::getExpressionName(Expr e, std::string& name) const {
return d_private->getExpressionName(e, name);
diff --git a/test/regress/Makefile.tests b/test/regress/Makefile.tests
index 2922085ca..f707da219 100644
--- a/test/regress/Makefile.tests
+++ b/test/regress/Makefile.tests
@@ -503,6 +503,7 @@ REG0_TESTS = \
regress0/nl/magnitude-wrong-1020-m.smt2 \
regress0/nl/mult-po.smt2 \
regress0/nl/nia-wrong-tl.smt2 \
+ regress0/nl/nlExtPurify-test.smt2 \
regress0/nl/nta/cos-sig-value.smt2 \
regress0/nl/nta/exp-n0.5-lb.smt2 \
regress0/nl/nta/exp-n0.5-ub.smt2 \
diff --git a/test/regress/regress0/nl/nlExtPurify-test.smt2 b/test/regress/regress0/nl/nlExtPurify-test.smt2
new file mode 100644
index 000000000..1a2391c3b
--- /dev/null
+++ b/test/regress/regress0/nl/nlExtPurify-test.smt2
@@ -0,0 +1,15 @@
+; COMMAND-LINE: --nl-ext-purify
+; EXPECT: sat
+(set-info :smt-lib-version 2.6)
+(set-logic QF_NRA)
+(set-info :category "crafted")
+(set-info :status sat)
+(declare-fun skoX () Real)
+(declare-fun skoS3 () Real)
+(declare-fun skoSX () Real)
+
+(assert (and (not (<= skoX 0)) (and (not (<= (* (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX)) 0)) (not (<= skoS3 0)))))
+
+
+(check-sat)
+(exit)
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback