diff options
134 files changed, 7400 insertions, 3304 deletions
diff --git a/.travis.yml b/.travis.yml index 6ad643a45..562987f3c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,9 +17,6 @@ cache: sudo: required dist: trusty -compiler: - - gcc - - clang env: global: # The next declaration is the encrypted COVERITY_SCAN_TOKEN, created @@ -27,11 +24,6 @@ env: - secure: "fRfdzYwV10VeW5tVSvy5qpR8ZlkXepR7XWzCulzlHs9SRI2YY20BpzWRjyMBiGu2t7IeJKT7qdjq/CJOQEM8WS76ON7QJ1iymKaRDewDs3OhyPJ71fsFKEGgLky9blk7I9qZh23hnRVECj1oJAVry9IK04bc2zyIEjUYpjRkUAQ=" - TEST_GROUPS=2 - CCACHE_COMPRESS=1 - matrix: - - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='production --enable-language-bindings=java,c --with-lfsc' - - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-lfsc --disable-debug-symbols' - - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-cln --enable-gpl --disable-debug-symbols --disable-proof' - - TRAVIS_CVC4=yes TRAVIS_CVC4_DISTCHECK=yes TRAVIS_CVC4_CONFIG='--enable-proof' addons: apt: sources: @@ -119,27 +111,21 @@ script: matrix: fast_finish: true include: - # Test with GCC7 - - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - *common_deps - - g++-7 + # Test with GCC + - compiler: gcc env: - - MATRIX_EVAL='CC=gcc-7 && CXX=g++-7' - - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-lfsc --disable-debug-symbols' TEST_GROUP=0 - - addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - *common_deps - - g++-7 + - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='production --enable-language-bindings=java,c --with-lfsc' + - compiler: gcc + env: + - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-lfsc --disable-debug-symbols' + # + # Test with Clang + - compiler: clang + env: + - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-cln --enable-gpl --disable-debug-symbols --disable-proof' + - compiler: clang env: - - MATRIX_EVAL='CC=gcc-7 && CXX=g++-7' - - TRAVIS_CVC4=yes TRAVIS_WITH_LFSC=yes TRAVIS_CVC4_CONFIG='debug --with-lfsc --disable-debug-symbols' TEST_GROUP=1 + - TRAVIS_CVC4=yes TRAVIS_CVC4_DISTCHECK=yes TRAVIS_CVC4_CONFIG='--enable-proof' # Rule for running Coverity Scan. - os: linux compiler: gcc @@ -7,6 +7,7 @@ the copyright and licensing of CVC4. The core designers and authors of CVC4 are: Kshitij Bansal, New York University, Google + Haniel Barbosa, The University of Iowa Clark Barrett, New York University, Google, Stanford University Francois Bobot, The University of Iowa, Commissariat a l'Energie Atomique Martin Brain, University of Oxford @@ -18,7 +19,9 @@ The core designers and authors of CVC4 are: Tim King, New York University, Universite Joseph Fourier, Google Tianyi Liang, The University of Iowa Paul Meng, The University of Iowa + Aina Niemetz, Stanford University Andres Noetzli, Stanford University + Mathias Preiner, Stanford University Andrew Reynolds, The University of Iowa, EPFL Cesare Tinelli, The University of Iowa diff --git a/contrib/get-abc b/contrib/get-abc index 7cf833e23..0a840fc84 100755 --- a/contrib/get-abc +++ b/contrib/get-abc @@ -1,37 +1,14 @@ #!/bin/bash # -set -e - -commit=53f39c11b58d - -cd "$(dirname "$0")/.." - -if ! [ -e src/parser/cvc/Cvc.g ]; then - echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 - echo "but apparently:" >&2 - echo >&2 - echo " $(pwd)" >&2 - echo >&2 - echo "is not a CVC4 source tree ?!" >&2 - exit 1 -fi - -function webget { - if which wget &>/dev/null; then - wget -c -O "$2" "$1" - elif which curl &>/dev/null; then - curl "$1" >"$2" - else - echo "Can't figure out how to download from web. Please install wget or curl." >&2 - exit 1 - fi -} +source "$(dirname "$0")/get-script-header.sh" if [ -e abc ]; then echo 'error: file or directory "abc" exists; please move it out of the way.' >&2 exit 1 fi +commit=53f39c11b58d + mkdir abc cd abc webget https://bitbucket.org/alanmi/abc/get/$commit.tar.gz abc-$commit.tar.gz diff --git a/contrib/get-antlr-3.4 b/contrib/get-antlr-3.4 index 87d6ea450..4ee23509a 100755 --- a/contrib/get-antlr-3.4 +++ b/contrib/get-antlr-3.4 @@ -1,34 +1,11 @@ #!/bin/bash # -set -e - -cd "$(dirname "$0")/.." +source "$(dirname "$0")/get-script-header.sh" if [ -z "${BUILD_TYPE}" ]; then BUILD_TYPE="--disable-shared --enable-static" fi -if ! [ -e src/parser/cvc/Cvc.g ]; then - echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 - echo "but apparently:" >&2 - echo >&2 - echo " $(pwd)" >&2 - echo >&2 - echo "is not a CVC4 source tree ?!" >&2 - exit 1 -fi - -function webget { - if which curl &>/dev/null; then - curl "$1" >"$2" - elif which wget &>/dev/null; then - wget -c -O "$2" "$1" - else - echo "Can't figure out how to download from web. Please install wget or curl." >&2 - exit 1 - fi -} - if [ -z "${MACHINE_TYPE}" ]; then if ! [ -e config/config.guess ]; then # Attempt to download once diff --git a/contrib/get-cryptominisat4 b/contrib/get-cryptominisat4 index c96bbe03d..c6f2a1ce8 100755 --- a/contrib/get-cryptominisat4 +++ b/contrib/get-cryptominisat4 @@ -1,37 +1,14 @@ #!/bin/bash # -set -e - -version="4.2.0" - -cd "$(dirname "$0")/.." - -if ! [ -e src/parser/cvc/Cvc.g ]; then - echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 - echo "but apparently:" >&2 - echo >&2 - echo " $(pwd)" >&2 - echo >&2 - echo "is not a CVC4 source tree ?!" >&2 - exit 1 -fi - -function webget { - if which wget &>/dev/null; then - wget -c -O "$2" "$1" - elif which curl &>/dev/null; then - curl "$1" >"$2" - else - echo "Can't figure out how to download from web. Please install wget or curl." >&2 - exit 1 - fi -} +source "$(dirname "$0")/get-script-header.sh" if [ -e cryptominisat4 ]; then echo 'error: file or directory "cryptominisat4" exists; please move it out of the way.' >&2 exit 1 fi +version="4.2.0" + mkdir cryptominisat4 cd cryptominisat4 CRYPTOMINISAT_PATH=`pwd` diff --git a/contrib/get-glpk-cut-log b/contrib/get-glpk-cut-log index 5ca18c66d..419fdba90 100755 --- a/contrib/get-glpk-cut-log +++ b/contrib/get-glpk-cut-log @@ -1,32 +1,9 @@ #!/bin/bash # -set -e +source "$(dirname "$0")/get-script-header.sh" commit=b420454e732f4b3d229c552ef7cd46fec75fe65c -cd "$(dirname "$0")/.." - -if ! [ -e src/parser/cvc/Cvc.g ]; then - echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 - echo "but apparently:" >&2 - echo >&2 - echo " $(pwd)" >&2 - echo >&2 - echo "is not a CVC4 source tree ?!" >&2 - exit 1 -fi - -function webget { - if which wget &>/dev/null; then - wget -c -O "$2" "$1" - elif which curl &>/dev/null; then - curl "$1" >"$2" - else - echo "Can't figure out how to download from web. Please install wget or curl." >&2 - exit 1 - fi -} - if [ -e glpk-cut-log ]; then echo 'error: file or directory "glpk-cut-log" exists; please move it out of the way.' >&2 exit 1 diff --git a/contrib/get-lfsc-checker b/contrib/get-lfsc-checker index 495082387..f4c79de2a 100755 --- a/contrib/get-lfsc-checker +++ b/contrib/get-lfsc-checker @@ -1,23 +1,10 @@ #!/bin/bash # -set -e +source "$(dirname "$0")/get-script-header.sh" lfscrepo="https://github.com/CVC4/LFSC.git" dirname="lfsc-checker" - -cd "$(dirname "$0")/.." - -if ! [ -e src/parser/cvc/Cvc.g ]; then - echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 - echo "but apparently:" >&2 - echo >&2 - echo " $(pwd)" >&2 - echo >&2 - echo "is not a CVC4 source tree ?!" >&2 - exit 1 -fi - function gitclone { if which git &> /dev/null then diff --git a/contrib/get-script-header.sh b/contrib/get-script-header.sh new file mode 100644 index 000000000..285d97b35 --- /dev/null +++ b/contrib/get-script-header.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# +set -e + +cd "$(dirname "$0")/.." + +if ! [ -e src/parser/cvc/Cvc.g ]; then + echo "$(basename $0): I expect to be in the contrib/ of a CVC4 source tree," >&2 + echo "but apparently:" >&2 + echo >&2 + echo " $(pwd)" >&2 + echo >&2 + echo "is not a CVC4 source tree ?!" >&2 + exit 1 +fi + +function webget { + if which wget &>/dev/null; then + wget -c -O "$2" "$1" + elif which curl &>/dev/null; then + curl "$1" >"$2" + else + echo "Can't figure out how to download from web. Please install wget or curl." >&2 + exit 1 + fi +} diff --git a/contrib/update-copyright.pl b/contrib/update-copyright.pl index 5f31f48c9..d0dd33cbb 100755 --- a/contrib/update-copyright.pl +++ b/contrib/update-copyright.pl @@ -48,7 +48,7 @@ $excluded_paths .= '$)'; # Years of copyright for the template. E.g., the string # "1985, 1987, 1992, 1997, 2008" or "2006-2009" or whatever. -my $years = '2009-2017'; +my $years = '2009-2018'; my $standard_template = <<EOF; ** This file is part of the CVC4 project. diff --git a/src/Makefile.am b/src/Makefile.am index a6c58e281..0006a8521 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -463,8 +463,12 @@ libcvc4_la_SOURCES = \ theory/quantifiers/sygus_grammar_cons.h \ theory/quantifiers/sygus_grammar_norm.cpp \ theory/quantifiers/sygus_grammar_norm.h \ + theory/quantifiers/sygus_grammar_red.cpp \ + theory/quantifiers/sygus_grammar_red.h \ theory/quantifiers/sygus_process_conj.cpp \ theory/quantifiers/sygus_process_conj.h \ + theory/quantifiers/sygus_sampler.cpp \ + theory/quantifiers/sygus_sampler.h \ theory/quantifiers/symmetry_breaking.cpp \ theory/quantifiers/symmetry_breaking.h \ theory/quantifiers/term_database.cpp \ diff --git a/src/base/Makefile.am b/src/base/Makefile.am index 5537bbbdd..7dd6f47e5 100644 --- a/src/base/Makefile.am +++ b/src/base/Makefile.am @@ -20,6 +20,8 @@ libbase_la_SOURCES = \ configuration_private.h \ cvc4_assert.cpp \ cvc4_assert.h \ + cvc4_check.cpp \ + cvc4_check.h \ exception.cpp \ exception.h \ listener.cpp \ diff --git a/src/base/cvc4_check.cpp b/src/base/cvc4_check.cpp new file mode 100644 index 000000000..5976ac3f7 --- /dev/null +++ b/src/base/cvc4_check.cpp @@ -0,0 +1,44 @@ +/********************* */ +/*! \file cvc4_check.cpp + ** \verbatim + ** Top contributors (to current version): + ** Tim King + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Assertion utility classes, functions and macros. + ** + ** Implementation of assertion utility classes, functions and macros. + **/ + +#include "base/cvc4_check.h" + +#include <cstdlib> +#include <iostream> + +namespace CVC4 { + +FatalStream::FatalStream(const char* function, const char* file, int line) +{ + stream() << "Fatal failure within " << function << " at " << file << ":" + << line << "\n"; +} + +FatalStream::~FatalStream() +{ + Flush(); + abort(); +} + +std::ostream& FatalStream::stream() { return std::cerr; } + +void FatalStream::Flush() +{ + stream() << std::endl; + stream().flush(); +} + +} // namespace CVC4 diff --git a/src/base/cvc4_check.h b/src/base/cvc4_check.h new file mode 100644 index 000000000..fb4ec0bba --- /dev/null +++ b/src/base/cvc4_check.h @@ -0,0 +1,144 @@ +/********************* */ +/*! \file cvc4_check.h + ** \verbatim + ** Top contributors (to current version): + ** Tim King + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Assertion utility classes, functions and macros. + ** + ** The CHECK utility classes, functions and macros are related to the Assert() + ** macros defined in base/cvc4_assert.h. The major distinguishing attribute + ** is the CHECK's abort() the process on failures while Assert() statements + ** throw C++ exceptions. + ** + ** The main usage in the file is the CHECK macros. The CHECK macros assert a + ** condition and aborts()'s the process if the condition is not satisfied. The + ** macro leaves a hanging ostream for the user to specify additional + ** information about the failure. Example usage: + ** CHECK(x >= 0) << "x must be positive."; + ** + ** DCHECK is a CHECK that is only enabled in debug builds. + ** DCHECK(pointer != nullptr); + ** + ** CVC4_FATAL() can be used to indicate unreachable code. + ** + ** The CHECK and DCHECK macros are not safe for use in signal-handling code. + ** TODO(taking): Add a signal-handling safe version of CHECK. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__CHECK_H +#define __CVC4__CHECK_H + +#include <ostream> + +// Define CVC4_NO_RETURN macro replacement for [[noreturn]]. +#if defined(SWIG) +#define CVC4_NO_RETURN +// SWIG does not yet support [[noreturn]] so emit nothing instead. +#else +#define CVC4_NO_RETURN [[noreturn]] +// Not checking for whether the compiler supports [[noreturn]] using +// __has_cpp_attribute as GCC 4.8 is too widespread and does not support this. +// We instead assume this is C++11 (or later) and [[noreturn]] is available. +#endif // defined(SWIG) + +// Define CVC4_PREDICT_FALSE(x) that helps the compiler predict that x will be +// false (if there is compiler support). +#ifdef __has_builtin +#if __has_builtin(__builtin_expect) +#define CVC4_PREDICT_FALSE(x) (__builtin_expect(x, false)) +#else +#define CVC4_PREDICT_FALSE(x) x +#endif +#else +#define CVC4_PREDICT_FALSE(x) x +#endif + +namespace CVC4 { + +// Implementation notes: +// To understand FatalStream and OStreamVoider, it is useful to understand +// how a CHECK is structured. CHECK(cond) is roughly the following pattern: +// cond ? (void)0 : OstreamVoider() & FatalStream().stream() +// This is a carefully crafted message to achieve a hanging ostream using +// operator precedence. The line `CHECK(cond) << foo << bar;` will bind as +// follows: +// `cond ? ((void)0) : (OSV() & ((FS().stream() << foo) << bar));` +// Once the expression is evaluated, the destructor ~FatalStream() of the +// temporary object is then run, which abort()'s the process. The role of the +// OStreamVoider() is to match the void type of the true branch. + +// Class that provides an ostream and whose destructor aborts! Direct usage of +// this class is discouraged. +class FatalStream +{ + public: + FatalStream(const char* function, const char* file, int line); + CVC4_NO_RETURN ~FatalStream(); + + std::ostream& stream(); + + private: + void Flush(); +}; + +// Helper class that changes the type of an std::ostream& into a void. See +// "Implementation notes" for more information. +class OstreamVoider +{ + public: + OstreamVoider() {} + // The operator precedence between operator& and operator<< is critical here. + void operator&(std::ostream&) {} +}; + +// CVC4_FATAL() always aborts a function and provides a convenient way of +// formatting error messages. This can be used instead of a return type. +// +// Example function that returns a type Foo: +// Foo bar(T t) { +// switch(t.type()) { +// ... +// default: +// CVC4_FATAL() << "Unknown T type " << t.enum(); +// } +// } +#define CVC4_FATAL() \ + FatalStream(__PRETTY_FUNCTION__, __FILE__, __LINE__).stream() + +// If `cond` is true, log an error message and abort the process. +// Otherwise, does nothing. This leaves a hanging std::ostream& that can be +// inserted into. +#define CVC4_FATAL_IF(cond, function, file, line) \ + CVC4_PREDICT_FALSE(!(cond)) \ + ? (void)0 : OstreamVoider() & FatalStream(function, file, line).stream() + +// If `cond` is false, log an error message and abort()'s the process. +// Otherwise, does nothing. This leaves a hanging std::ostream& that can be +// inserted into using operator<<. Example usages: +// CHECK(x >= 0); +// CHECK(x >= 0) << "x must be positive"; +// CHECK(x >= 0) << "expected a positive value. Got " << x << " instead"; +#define CHECK(cond) \ + CVC4_FATAL_IF(!(cond), __PRETTY_FUNCTION__, __FILE__, __LINE__) \ + << "Check failure\n\n " << #cond << "\n" + +// DCHECK is a variant of CHECK() that is only checked when CVC4_ASSERTIONS is +// defined. We rely on the optimizer to remove the deadcode. +#ifdef CVC4_ASSERTIONS +#define DCHECK(cond) CHECK(cond) +#else +#define DCHECK(cond) \ + CVC4_FATAL_IF(false, __PRETTY_FUNCTION__, __FILE__, __LINE__) +#endif /* CVC4_DEBUG */ + +} // namespace CVC4 + +#endif /* __CVC4__CHECK_H */ diff --git a/src/context/cdhashset_forward.h b/src/context/cdhashset_forward.h index ed665ce1b..426f8917d 100644 --- a/src/context/cdhashset_forward.h +++ b/src/context/cdhashset_forward.h @@ -26,19 +26,13 @@ #ifndef __CVC4__CONTEXT__CDSET_FORWARD_H #define __CVC4__CONTEXT__CDSET_FORWARD_H -/// \cond internals - -namespace __gnu_cxx { - template <class Key> struct hash; -}/* __gnu_cxx namespace */ +#include <functional> namespace CVC4 { - namespace context { - template <class V, class HashFcn = __gnu_cxx::hash<V> > - class CDHashSet; - }/* CVC4::context namespace */ -}/* CVC4 namespace */ - -/// \endcond +namespace context { +template <class V, class HashFcn = std::hash<V> > +class CDHashSet; +} // namespace context +} // namespace CVC4 #endif /* __CVC4__CONTEXT__CDSET_FORWARD_H */ diff --git a/src/context/cdinsert_hashmap_forward.h b/src/context/cdinsert_hashmap_forward.h index 05501f1a2..d3f46791a 100644 --- a/src/context/cdinsert_hashmap_forward.h +++ b/src/context/cdinsert_hashmap_forward.h @@ -23,16 +23,16 @@ #include "cvc4_public.h" -#pragma once +#ifndef __CVC4__CONTEXT__CDINSERT_HASHMAP_FORWARD_H +#define __CVC4__CONTEXT__CDINSERT_HASHMAP_FORWARD_H -namespace __gnu_cxx { - template <class Key> struct hash; -}/* __gnu_cxx namespace */ +#include <functional> namespace CVC4 { - namespace context { - template <class Key, class Data, class HashFcn = __gnu_cxx::hash<Key> > - class CDInsertHashMap; - }/* CVC4::context namespace */ -}/* CVC4 namespace */ +namespace context { +template <class Key, class Data, class HashFcn = std::hash<Key> > +class CDInsertHashMap; +} // namespace context +} // namespace CVC4 +#endif /* __CVC4__CONTEXT__CDINSERT_HASHMAP_FORWARD_H */ diff --git a/src/context/cdlist_forward.h b/src/context/cdlist_forward.h index 49a077349..e599c037c 100644 --- a/src/context/cdlist_forward.h +++ b/src/context/cdlist_forward.h @@ -36,10 +36,6 @@ /// \cond internals -namespace __gnu_cxx { - template <class Key> struct hash; -}/* __gnu_cxx namespace */ - namespace CVC4 { namespace context { diff --git a/src/expr/expr_manager_template.cpp b/src/expr/expr_manager_template.cpp index 951b92e1c..3993fc9b6 100644 --- a/src/expr/expr_manager_template.cpp +++ b/src/expr/expr_manager_template.cpp @@ -372,18 +372,11 @@ Expr ExprManager::mkExpr(Kind kind, Expr child1, } Expr ExprManager::mkExpr(Expr opExpr) { - const unsigned n = 0; - Kind kind = NodeManager::operatorToKind(opExpr.getNode()); + const Kind kind = NodeManager::operatorToKind(opExpr.getNode()); PrettyCheckArgument( opExpr.getKind() == kind::BUILTIN || kind::metaKindOf(kind) == kind::metakind::PARAMETERIZED, opExpr, "This Expr constructor is for parameterized kinds only"); - PrettyCheckArgument( - n >= minArity(kind) && n <= maxArity(kind), kind, - "Exprs with kind %s must have at least %u children and " - "at most %u children (the one under construction has %u)", - kind::kindToString(kind).c_str(), - minArity(kind), maxArity(kind), n); NodeManagerScope nms(d_nodeManager); try { INC_STAT(kind); diff --git a/src/expr/type_checker_template.cpp b/src/expr/type_checker_template.cpp index bb02528c7..ed615c874 100644 --- a/src/expr/type_checker_template.cpp +++ b/src/expr/type_checker_template.cpp @@ -65,7 +65,7 @@ bool TypeChecker::computeIsConst(NodeManager* nodeManager, TNode n) switch(n.getKind()) { ${construles} -#line 70 "${template}" +#line 69 "${template}" default:; } @@ -81,7 +81,7 @@ bool TypeChecker::neverIsConst(NodeManager* nodeManager, TNode n) switch(n.getKind()) { ${neverconstrules} -#line 87 "${template}" +#line 85 "${template}" default:; } diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 8001ca3df..9e61e713b 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -319,10 +319,8 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) { } case kind::SEXPR_TYPE: Unimplemented("haven't implemented leastCommonType for symbolic expressions yet"); - return TypeNode(); default: Unimplemented("don't have a commonType for types `%s' and `%s'", t0.toString().c_str(), t1.toString().c_str()); - return TypeNode(); } } diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 72d00a5a2..14c4222a6 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -233,7 +233,7 @@ public: */ inline Node getOperator() const { Assert(getMetaKind() == kind::metakind::PARAMETERIZED); - return Node(d_nv->getChild(-1)); + return Node(d_nv->getOperator()); } /** diff --git a/src/main/command_executor.cpp b/src/main/command_executor.cpp index 7c8ee7827..a7666dfcf 100644 --- a/src/main/command_executor.cpp +++ b/src/main/command_executor.cpp @@ -193,7 +193,10 @@ bool smtEngineInvoke(SmtEngine* smt, Command* cmd, std::ostream *out) return !cmd->fail(); } -void printStatsIncremental(std::ostream& out, const std::string& prvsStatsString, const std::string& curStatsString) { +void printStatsIncremental(std::ostream& out, + const std::string& prvsStatsString, + const std::string& curStatsString) +{ if(prvsStatsString == "") { out << curStatsString; return; @@ -229,9 +232,11 @@ void printStatsIncremental(std::ostream& out, const std::string& prvsStatsString (std::istringstream(curStatValue) >> curFloat); if(isFloat) { + const std::streamsize old_precision = out.precision(); out << curStatName << ", " << curStatValue << " " << "(" << std::setprecision(8) << (curFloat-prvsFloat) << ")" << std::endl; + out.precision(old_precision); } else { out << curStatName << ", " << curStatValue << std::endl; } diff --git a/src/options/options_handler.cpp b/src/options/options_handler.cpp index c29cfc4d2..61f7646ee 100644 --- a/src/options/options_handler.cpp +++ b/src/options/options_handler.cpp @@ -492,6 +492,25 @@ all \n\ \n\ "; +const std::string OptionsHandler::s_cegisSampleHelp = + "\ +Modes for sampling with counterexample-guided inductive synthesis (CEGIS),\ +supported by --cegis-sample:\n\ +\n\ +none (default) \n\ ++ Do not use sampling with CEGIS.\n\ +\n\ +use \n\ ++ Use sampling to accelerate CEGIS. This will rule out solutions for a\ + conjecture when they are not satisfied by a sample point.\n\ +\n\ +trust \n\ ++ Trust that when a solution for a conjecture is always true under sampling,\ + then it is indeed a solution. Note this option may print out spurious\ + solutions for synthesis conjectures.\n\ +\n\ +"; + const std::string OptionsHandler::s_sygusInvTemplHelp = "\ Template modes for sygus invariant synthesis, supported by --sygus-inv-templ:\n\ \n\ @@ -877,6 +896,34 @@ OptionsHandler::stringToCegqiSingleInvMode(std::string option, } } +theory::quantifiers::CegisSampleMode OptionsHandler::stringToCegisSampleMode( + std::string option, std::string optarg) +{ + if (optarg == "none") + { + return theory::quantifiers::CEGIS_SAMPLE_NONE; + } + else if (optarg == "use") + { + return theory::quantifiers::CEGIS_SAMPLE_USE; + } + else if (optarg == "trust") + { + return theory::quantifiers::CEGIS_SAMPLE_TRUST; + } + else if (optarg == "help") + { + puts(s_cegisSampleHelp.c_str()); + exit(1); + } + else + { + throw OptionException(std::string("unknown option for --cegis-sample: `") + + optarg + + "'. Try --cegis-sample help."); + } +} + theory::quantifiers::SygusInvTemplMode OptionsHandler::stringToSygusInvTemplMode(std::string option, std::string optarg) diff --git a/src/options/options_handler.h b/src/options/options_handler.h index e7bd87ebd..304009a98 100644 --- a/src/options/options_handler.h +++ b/src/options/options_handler.h @@ -108,6 +108,8 @@ public: std::string option, std::string optarg); theory::quantifiers::CegqiSingleInvMode stringToCegqiSingleInvMode( std::string option, std::string optarg); + theory::quantifiers::CegisSampleMode stringToCegisSampleMode( + std::string option, std::string optarg); theory::quantifiers::SygusInvTemplMode stringToSygusInvTemplMode( std::string option, std::string optarg); theory::quantifiers::MacrosQuantMode stringToMacrosQuantMode( @@ -243,6 +245,7 @@ public: static const std::string s_sygusSolutionOutModeHelp; static const std::string s_cbqiBvIneqModeHelp; static const std::string s_cegqiSingleInvHelp; + static const std::string s_cegisSampleHelp; static const std::string s_sygusInvTemplHelp; static const std::string s_termDbModeHelp; static const std::string s_theoryOfModeHelp; diff --git a/src/options/quantifiers_modes.h b/src/options/quantifiers_modes.h index 6274269ce..91fab54ff 100644 --- a/src/options/quantifiers_modes.h +++ b/src/options/quantifiers_modes.h @@ -216,6 +216,16 @@ enum CegqiSingleInvMode { CEGQI_SI_MODE_ALL, }; +enum CegisSampleMode +{ + /** do not use samples for CEGIS */ + CEGIS_SAMPLE_NONE, + /** use samples for CEGIS */ + CEGIS_SAMPLE_USE, + /** trust samples for CEGQI */ + CEGIS_SAMPLE_TRUST, +}; + enum SygusInvTemplMode { /** synthesize I( x ) */ SYGUS_INV_TEMPL_MODE_NONE, diff --git a/src/options/quantifiers_options b/src/options/quantifiers_options index 2166f0add..48a577faf 100644 --- a/src/options/quantifiers_options +++ b/src/options/quantifiers_options @@ -270,8 +270,6 @@ option sygusQePreproc --sygus-qe-preproc bool :default false option sygusMinGrammar --sygus-min-grammar bool :default true statically minimize sygus grammars -option sygusMinGrammarAgg --sygus-min-grammar-agg bool :default false - aggressively minimize sygus grammars option sygusAddConstGrammar --sygus-add-const-grammar bool :default true statically add constants appearing in conjecture to grammars option sygusGrammarNorm --sygus-grammar-norm bool :default false @@ -294,10 +292,23 @@ option sygusCRefEval --sygus-cref-eval bool :default true direct evaluation of refinement lemmas for conflict analysis option sygusCRefEvalMinExp --sygus-cref-eval-min-exp bool :default true use min explain for direct evaluation of refinement lemmas for conflict analysis - -option sygusStream --sygus-stream bool :default false + +option sygusStream --sygus-stream bool :read-write :default false enumerate a stream of solutions instead of terminating after the first one +option cegisSample --cegis-sample=MODE CVC4::theory::quantifiers::CegisSampleMode :read-write :default CVC4::theory::quantifiers::CEGIS_SAMPLE_NONE :include "options/quantifiers_modes.h" :handler stringToCegisSampleMode + mode for using samples in the counterexample-guided inductive synthesis loop + +# internal uses of sygus +option sygusRewSynth --sygus-rr-synth bool :default false + use sygus to enumerate candidate rewrite rules via sampling +option sygusRewVerify --sygus-rr-verify bool :default false + use sygus to verify the correctness of rewrite rules via sampling +option sygusSamples --sygus-samples=N int :read-write :default 100 :read-write + number of points to consider when doing sygus rewriter sample testing +option sygusSampleGrammar --sygus-sample-grammar bool :default true + when applicable, use grammar for choosing sample points + # CEGQI applied to general quantified formulas option cbqi --cbqi bool :read-write :default false turns on counterexample-based quantifier instantiation @@ -315,6 +326,10 @@ option cbqiMultiInst --cbqi-multi-inst bool :read-write :default false when applicable, do multi instantiations per quantifier per round in counterexample-based quantifier instantiation option cbqiRepeatLit --cbqi-repeat-lit bool :read-write :default false solve literals more than once in counterexample-based quantifier instantiation +option cbqiInnermost --cbqi-innermost bool :read-write :default true + only process innermost quantified formulas in counterexample-based quantifier instantiation +option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false + process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation # CEGQI for arithmetic option cbqiUseInfInt --cbqi-use-inf-int bool :read-write :default false @@ -333,10 +348,6 @@ option cbqiNopt --cbqi-nopt bool :default true non-optimal bounds for counterexample-based quantifier instantiation option cbqiLitDepend --cbqi-lit-dep bool :default true dependency lemmas for quantifier alternation in counterexample-based quantifier instantiation -option cbqiInnermost --cbqi-innermost bool :read-write :default true - only process innermost quantified formulas in counterexample-based quantifier instantiation -option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false - process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation # CEGQI for EPR option quantEpr --quant-epr bool :default false :read-write @@ -345,7 +356,7 @@ option quantEprMatching --quant-epr-match bool :default true use matching heuristics for EPR instantiation # CEGQI for BV -option cbqiBv --cbqi-bv bool :read-write :default true +option cbqiBv cbqi-bv --cbqi-bv bool :read-write :default true use word-level inversion approach for counterexample-guided quantifier instantiation for bit-vectors option cbqiBvInterleaveValue --cbqi-bv-interleave-value bool :read-write :default false interleave model value instantiation with word-level inversion approach @@ -355,6 +366,8 @@ option cbqiBvRmExtract --cbqi-bv-rm-extract bool :read-write :default true replaces extract terms with variables for counterexample-guided instantiation for bit-vectors option cbqiBvLinearize --cbqi-bv-linear bool :read-write :default true linearize adder chains for variables +option cbqiBvConcInv cbqi-bv-concat-inv --cbqi-bv-concat-inv bool :read-write :default true + compute inverse for concat over equalities rather than producing an invertibility condition ### local theory extensions options diff --git a/src/options/smt_options b/src/options/smt_options index fa6c3ae4e..b19420060 100644 --- a/src/options/smt_options +++ b/src/options/smt_options @@ -54,6 +54,9 @@ option dumpUnsatCores --dump-unsat-cores bool :default false :link --produce-uns option dumpUnsatCoresFull dump-unsat-cores-full --dump-unsat-cores-full bool :default false :link --dump-unsat-cores :link-smt dump-unsat-cores :notify notifyBeforeSearch dump the full unsat core, including unlabeled assertions +option checkSynthSol --check-synth-sol bool :default false + checks whether produced solutions to functions-to-synthesize satisfy the conjecture + option produceAssignments produce-assignments --produce-assignments bool :default false :notify notifyBeforeSearch support the get-assignment command diff --git a/src/parser/antlr_input.h b/src/parser/antlr_input.h index d2bb8667d..422ad9796 100644 --- a/src/parser/antlr_input.h +++ b/src/parser/antlr_input.h @@ -234,7 +234,7 @@ protected: void setAntlr3Parser(pANTLR3_PARSER pParser); /** Set the Parser object for this input. */ - virtual void setParser(Parser& parser); + void setParser(Parser& parser) override; };/* class AntlrInput */ inline std::string AntlrInput::tokenText(pANTLR3_COMMON_TOKEN token) { diff --git a/src/parser/cvc/cvc_input.h b/src/parser/cvc/cvc_input.h index c35d8d963..c02c4f452 100644 --- a/src/parser/cvc/cvc_input.h +++ b/src/parser/cvc/cvc_input.h @@ -69,7 +69,7 @@ class CvcInput : public AntlrInput { * * @throws ParserException if an error is encountered during parsing. */ - Expr parseExpr(); + Expr parseExpr() override; private: /** Initialize the class. Called from the constructors once the input stream diff --git a/src/parser/smt1/smt1_input.h b/src/parser/smt1/smt1_input.h index 7577b7bff..cd285255f 100644 --- a/src/parser/smt1/smt1_input.h +++ b/src/parser/smt1/smt1_input.h @@ -74,7 +74,7 @@ public: * * @throws ParserException if an error is encountered during parsing. */ - Expr parseExpr(); + Expr parseExpr() override; private: /** diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index e4f6569b8..77b50af4c 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -61,6 +61,17 @@ void Smt2::addArithmeticOperators() { addOperator(kind::SINE, "sin"); addOperator(kind::COSINE, "cos"); addOperator(kind::TANGENT, "tan"); + addOperator(kind::COSECANT, "csc"); + addOperator(kind::SECANT, "sec"); + addOperator(kind::COTANGENT, "cot"); + addOperator(kind::ARCSINE, "arcsin"); + addOperator(kind::ARCCOSINE, "arccos"); + addOperator(kind::ARCTANGENT, "arctan"); + addOperator(kind::ARCCOSECANT, "arccsc"); + addOperator(kind::ARCSECANT, "arcsec"); + addOperator(kind::ARCCOTANGENT, "arccot"); + + addOperator(kind::SQRT, "sqrt"); } void Smt2::addBitvectorOperators() { diff --git a/src/parser/smt2/smt2_input.h b/src/parser/smt2/smt2_input.h index 0acb5462d..44187cd2d 100644 --- a/src/parser/smt2/smt2_input.h +++ b/src/parser/smt2/smt2_input.h @@ -85,7 +85,7 @@ class Smt2Input : public AntlrInput { * * @throws ParserException if an error is encountered during parsing. */ - Expr parseExpr(); + Expr parseExpr() override; };/* class Smt2Input */ diff --git a/src/parser/smt2/sygus_input.h b/src/parser/smt2/sygus_input.h index 0dca60a82..58d78fb76 100644 --- a/src/parser/smt2/sygus_input.h +++ b/src/parser/smt2/sygus_input.h @@ -82,7 +82,7 @@ class SygusInput : public AntlrInput { * * @throws ParserException if an error is encountered during parsing. */ - Expr parseExpr(); + Expr parseExpr() override; };/* class SygusInput */ diff --git a/src/parser/tptp/tptp_input.h b/src/parser/tptp/tptp_input.h index 5dd56034d..9a820f26d 100644 --- a/src/parser/tptp/tptp_input.h +++ b/src/parser/tptp/tptp_input.h @@ -82,7 +82,7 @@ class TptpInput : public AntlrInput { * * @throws ParserException if an error is encountered during parsing. */ - Expr parseExpr(); + Expr parseExpr() override; };/* class TptpInput */ diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index f20cb7cce..27105c3b4 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -1010,121 +1010,121 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const }/* CvcPrinter::toStream(CommandStatus*) */ -void CvcPrinter::toStream(std::ostream& out, - const Model& m, - const Command* c) const +namespace { + +void DeclareTypeCommandToStream(std::ostream& out, + const theory::TheoryModel& model, + const DeclareTypeCommand& command) { - const theory::TheoryModel& tm = (const theory::TheoryModel&) m; - if(dynamic_cast<const DeclareTypeCommand*>(c) != NULL) { - TypeNode tn = TypeNode::fromType( ((const DeclareTypeCommand*)c)->getType() ); - if (options::modelUninterpDtEnum() && tn.isSort()) + TypeNode type_node = TypeNode::fromType(command.getType()); + const std::vector<Node>* type_reps = + model.getRepSet()->getTypeRepsOrNull(type_node); + if (options::modelUninterpDtEnum() && type_node.isSort() + && type_reps != nullptr) + { + out << "DATATYPE" << std::endl; + out << " " << command.getSymbol() << " = "; + for (size_t i = 0; i < type_reps->size(); i++) { - const theory::RepSet* rs = tm.getRepSet(); - if (rs->d_type_reps.find(tn) != rs->d_type_reps.end()) + if (i > 0) { - out << "DATATYPE" << std::endl; - out << " " << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol() - << " = "; - for (size_t i = 0; i < (*rs->d_type_reps.find(tn)).second.size(); i++) - { - if (i > 0) - { - out << "| "; - } - out << (*rs->d_type_reps.find(tn)).second[i] << " "; - } - out << std::endl << "END;" << std::endl; + out << "| "; } - else + out << (*type_reps)[i] << " "; + } + out << std::endl << "END;" << std::endl; + } + else if (type_node.isSort() && type_reps != nullptr) + { + out << "% cardinality of " << type_node << " is " << type_reps->size() + << std::endl; + out << command << std::endl; + for (Node type_rep : *type_reps) + { + if (type_rep.isVar()) { - if (tn.isSort()) - { - // print the cardinality - if (rs->d_type_reps.find(tn) != rs->d_type_reps.end()) - { - out << "% cardinality of " << tn << " is " - << (*rs->d_type_reps.find(tn)).second.size() << std::endl; - } - } - out << c << std::endl; - if (tn.isSort()) - { - // print the representatives - if (rs->d_type_reps.find(tn) != rs->d_type_reps.end()) - { - for (size_t i = 0; i < (*rs->d_type_reps.find(tn)).second.size(); - i++) - { - if ((*rs->d_type_reps.find(tn)).second[i].isVar()) - { - out << (*rs->d_type_reps.find(tn)).second[i] << " : " << tn - << ";" << std::endl; - } - else - { - out << "% rep: " << (*rs->d_type_reps.find(tn)).second[i] - << std::endl; - } - } - } - } + out << type_rep << " : " << type_node << ";" << std::endl; } - } - } else if(dynamic_cast<const DeclareFunctionCommand*>(c) != NULL) { - Node n = Node::fromExpr( ((const DeclareFunctionCommand*)c)->getFunction() ); - if(n.getKind() == kind::SKOLEM) { - // don't print out internal stuff - return; - } - TypeNode tn = n.getType(); - out << n << " : "; - if( tn.isFunction() || tn.isPredicate() ){ - out << "("; - for( size_t i=0; i<tn.getNumChildren()-1; i++ ){ - if( i>0 ) out << ", "; - out << tn[i]; + else + { + out << "% rep: " << type_rep << std::endl; } - out << ") -> " << tn.getRangeType(); - }else{ - out << tn; } - Node val = Node::fromExpr(tm.getSmtEngine()->getValue(n.toExpr())); - if( options::modelUninterpDtEnum() && val.getKind() == kind::STORE ) { - const theory::RepSet* rs = tm.getRepSet(); - TypeNode tn = val[1].getType(); - if (tn.isSort() && rs->d_type_reps.find(tn) != rs->d_type_reps.end()) + } + else + { + out << command << std::endl; + } +} + +void DeclareFunctionCommandToStream(std::ostream& out, + const theory::TheoryModel& model, + const DeclareFunctionCommand& command) +{ + Node n = Node::fromExpr(command.getFunction()); + if (n.getKind() == kind::SKOLEM) + { + // don't print out internal stuff + return; + } + TypeNode tn = n.getType(); + out << n << " : "; + if (tn.isFunction() || tn.isPredicate()) + { + out << "("; + for (size_t i = 0; i < tn.getNumChildren() - 1; i++) + { + if (i > 0) { - Cardinality indexCard((*rs->d_type_reps.find(tn)).second.size()); - val = theory::arrays::TheoryArraysRewriter::normalizeConstant( val, indexCard ); + out << ", "; } + out << tn[i]; } - out << " = " << val << ";" << std::endl; - -/* - //for table format (work in progress) - bool printedModel = false; - if( tn.isFunction() ){ - if( options::modelFormatMode()==MODEL_FORMAT_MODE_TABLE ){ - //specialized table format for functions - RepSetIterator riter( &d_rep_set ); - riter.setFunctionDomain( n ); - while( !riter.isFinished() ){ - std::vector< Node > children; - children.push_back( n ); - for( int i=0; i<riter.getNumTerms(); i++ ){ - children.push_back( riter.getTerm( i ) ); - } - Node nn = NodeManager::currentNM()->mkNode( APPLY_UF, children ); - Node val = getValue( nn ); - out << val << " "; - riter.increment(); - } - printedModel = true; + out << ") -> " << tn.getRangeType(); + } + else + { + out << tn; + } + Node val = Node::fromExpr(model.getSmtEngine()->getValue(n.toExpr())); + if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE) + { + TypeNode type_node = val[1].getType(); + if (tn.isSort()) + { + if (const std::vector<Node>* type_reps = + model.getRepSet()->getTypeRepsOrNull(type_node)) + { + Cardinality indexCard(type_reps->size()); + val = theory::arrays::TheoryArraysRewriter::normalizeConstant( + val, indexCard); } } -*/ - }else{ - out << c << std::endl; + } + out << " = " << val << ";" << std::endl; +} + +} // namespace + +void CvcPrinter::toStream(std::ostream& out, + const Model& model, + const Command* command) const +{ + const auto* theory_model = dynamic_cast<const theory::TheoryModel*>(&model); + AlwaysAssert(theory_model != nullptr); + if (const auto* declare_type_command = + dynamic_cast<const DeclareTypeCommand*>(command)) + { + DeclareTypeCommandToStream(out, *theory_model, *declare_type_command); + } + else if (const auto* dfc = + dynamic_cast<const DeclareFunctionCommand*>(command)) + { + DeclareFunctionCommandToStream(out, *theory_model, *dfc); + } + else + { + out << command << std::endl; } } diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 54fc10719..e06f8c062 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -425,6 +425,16 @@ void Smt2Printer::toStream(std::ostream& out, case kind::SINE: case kind::COSINE: case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: case kind::MINUS: case kind::UMINUS: case kind::LT: @@ -891,6 +901,16 @@ static string smtKindString(Kind k) case kind::SINE: return "sin"; case kind::COSINE: return "cos"; case kind::TANGENT: return "tan"; + case kind::COSECANT: return "csc"; + case kind::SECANT: return "sec"; + case kind::COTANGENT: return "cot"; + case kind::ARCSINE: return "arcsin"; + case kind::ARCCOSINE: return "arccos"; + case kind::ARCTANGENT: return "arctan"; + case kind::ARCCOSECANT: return "arccsc"; + case kind::ARCSECANT: return "arcsec"; + case kind::ARCCOTANGENT: return "arccot"; + case kind::SQRT: return "sqrt"; case kind::MINUS: return "-"; case kind::UMINUS: return "-"; case kind::LT: return "<"; @@ -1277,116 +1297,140 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const } } -void Smt2Printer::toStream(std::ostream& out, - const Model& m, - const Command* c) const +namespace { + +void DeclareTypeCommandToStream(std::ostream& out, + const theory::TheoryModel& model, + const DeclareTypeCommand& command, + Variant variant) { - const theory::TheoryModel& tm = (const theory::TheoryModel&) m; - if(dynamic_cast<const DeclareTypeCommand*>(c) != NULL) { - TypeNode tn = TypeNode::fromType( ((const DeclareTypeCommand*)c)->getType() ); - const theory::RepSet* rs = tm.getRepSet(); - const std::map<TypeNode, std::vector<Node> >& type_reps = rs->d_type_reps; - - std::map< TypeNode, std::vector< Node > >::const_iterator tn_iterator = type_reps.find( tn ); - if( options::modelUninterpDtEnum() && tn.isSort() && tn_iterator != type_reps.end() ){ - if(d_variant == smt2_6_variant) { - out << "(declare-datatypes ((" << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol() << " 0)) ("; - }else{ - out << "(declare-datatypes () ((" << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol() << " "; - } - for( size_t i=0, N = tn_iterator->second.size(); i < N; i++ ){ - out << "(" << (*tn_iterator).second[i] << ")"; - } - out << ")))" << endl; - } else { - if( tn.isSort() ){ - //print the cardinality - if( tn_iterator != type_reps.end() ) { - out << "; cardinality of " << tn << " is " << tn_iterator->second.size() << endl; - } + TypeNode tn = TypeNode::fromType(command.getType()); + const std::vector<Node>* type_refs = model.getRepSet()->getTypeRepsOrNull(tn); + if (options::modelUninterpDtEnum() && tn.isSort() && type_refs != nullptr) + { + if (variant == smt2_6_variant) + { + out << "(declare-datatypes ((" << command.getSymbol() << " 0)) ("; + } + else + { + out << "(declare-datatypes () ((" << command.getSymbol() << " "; + } + for (Node type_ref : *type_refs) + { + out << "(" << type_ref << ")"; + } + out << ")))" << endl; + } + else if (tn.isSort() && type_refs != nullptr) + { + // print the cardinality + out << "; cardinality of " << tn << " is " << type_refs->size() << endl; + out << command << endl; + // print the representatives + for (Node type_ref : *type_refs) + { + if (type_ref.isVar()) + { + out << "(declare-fun " << quoteSymbol(type_ref) << " () " << tn << ")" + << endl; } - out << c << endl; - if( tn.isSort() ){ - //print the representatives - if( tn_iterator != type_reps.end() ){ - for( size_t i = 0, N = (*tn_iterator).second.size(); i < N; i++ ){ - TNode current = (*tn_iterator).second[i]; - if( current.isVar() ){ - out << "(declare-fun " << quoteSymbol(current) << " () " << tn << ")" << endl; - }else{ - out << "; rep: " << current << endl; - } - } - } + else + { + out << "; rep: " << type_ref << endl; } } - } else if(dynamic_cast<const DeclareFunctionCommand*>(c) != NULL) { - const DeclareFunctionCommand* dfc = (const DeclareFunctionCommand*)c; - Node n = Node::fromExpr( dfc->getFunction() ); - if(dfc->getPrintInModelSetByUser()) { - if(!dfc->getPrintInModel()) { - return; - } - } else if(n.getKind() == kind::SKOLEM) { - // don't print out internal stuff + } + else + { + out << command << endl; + } +} + +void DeclareFunctionCommandToStream(std::ostream& out, + const theory::TheoryModel& model, + const DeclareFunctionCommand& command) +{ + Node n = Node::fromExpr(command.getFunction()); + if (command.getPrintInModelSetByUser()) + { + if (!command.getPrintInModel()) + { return; } - Node val = Node::fromExpr(tm.getSmtEngine()->getValue(n.toExpr())); - if(val.getKind() == kind::LAMBDA) { - out << "(define-fun " << n << " " << val[0] - << " " << n.getType().getRangeType() - << " " << val[1] << ")" << endl; - } else { - if( options::modelUninterpDtEnum() && val.getKind() == kind::STORE ) { - TypeNode tn = val[1].getType(); - const theory::RepSet* rs = tm.getRepSet(); - if (tn.isSort() && rs->d_type_reps.find(tn) != rs->d_type_reps.end()) - { - Cardinality indexCard((*rs->d_type_reps.find(tn)).second.size()); - val = theory::arrays::TheoryArraysRewriter::normalizeConstant( val, indexCard ); - } - } - out << "(define-fun " << n << " () " - << n.getType() << " "; - if(val.getType().isInteger() && n.getType().isReal() && !n.getType().isInteger()) { - //toStreamReal(out, val, true); - toStreamRational(out, val.getConst<Rational>(), true); - //out << val << ".0"; - } else { - out << val; + } + else if (n.getKind() == kind::SKOLEM) + { + // don't print out internal stuff + return; + } + Node val = Node::fromExpr(model.getSmtEngine()->getValue(n.toExpr())); + if (val.getKind() == kind::LAMBDA) + { + out << "(define-fun " << n << " " << val[0] << " " + << n.getType().getRangeType() << " " << val[1] << ")" << endl; + } + else + { + if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE) + { + TypeNode tn = val[1].getType(); + const std::vector<Node>* type_refs = + model.getRepSet()->getTypeRepsOrNull(tn); + if (tn.isSort() && type_refs != nullptr) + { + Cardinality indexCard(type_refs->size()); + val = theory::arrays::TheoryArraysRewriter::normalizeConstant( + val, indexCard); } - out << ")" << endl; } -/* - //for table format (work in progress) - bool printedModel = false; - if( tn.isFunction() ){ - if( options::modelFormatMode()==MODEL_FORMAT_MODE_TABLE ){ - //specialized table format for functions - RepSetIterator riter( &d_rep_set ); - riter.setFunctionDomain( n ); - while( !riter.isFinished() ){ - std::vector< Node > children; - children.push_back( n ); - for( int i=0; i<riter.getNumTerms(); i++ ){ - children.push_back( riter.getTerm( i ) ); - } - Node nn = NodeManager::currentNM()->mkNode( APPLY_UF, children ); - Node val = getValue( nn ); - out << val << " "; - riter.increment(); - } - printedModel = true; - } + out << "(define-fun " << n << " () " << n.getType() << " "; + if (val.getType().isInteger() && n.getType().isReal() + && !n.getType().isInteger()) + { + // toStreamReal(out, val, true); + toStreamRational(out, val.getConst<Rational>(), true); + // out << val << ".0"; } -*/ - } else { - DatatypeDeclarationCommand* c1 = (DatatypeDeclarationCommand*)c; - const vector<DatatypeType>& datatypes = c1->getDatatypes(); - if (!datatypes[0].isTuple()) { - out << c << endl; + else + { + out << val; + } + out << ")" << endl; + } +} + +} // namespace + +void Smt2Printer::toStream(std::ostream& out, + const Model& model, + const Command* command) const +{ + const theory::TheoryModel* theory_model = + dynamic_cast<const theory::TheoryModel*>(&model); + AlwaysAssert(theory_model != nullptr); + if (const DeclareTypeCommand* dtc = + dynamic_cast<const DeclareTypeCommand*>(command)) + { + DeclareTypeCommandToStream(out, *theory_model, *dtc, d_variant); + } + else if (const DeclareFunctionCommand* dfc = + dynamic_cast<const DeclareFunctionCommand*>(command)) + { + DeclareFunctionCommandToStream(out, *theory_model, *dfc); + } + else if (const DatatypeDeclarationCommand* datatype_declaration_command = + dynamic_cast<const DatatypeDeclarationCommand*>(command)) + { + if (!datatype_declaration_command->getDatatypes()[0].isTuple()) + { + out << command << endl; } } + else + { + Unreachable(); + } } void Smt2Printer::toStreamSygus(std::ostream& out, TNode n) const diff --git a/src/prop/cryptominisat.h b/src/prop/cryptominisat.h index bb2f47783..fca2c7aa1 100644 --- a/src/prop/cryptominisat.h +++ b/src/prop/cryptominisat.h @@ -87,46 +87,7 @@ public: Statistics d_statistics; }; -} // CVC4::prop -} // CVC4 -#else // CVC4_USE_CRYPTOMINISAT - -namespace CVC4 { -namespace prop { - -class CryptoMinisatSolver : public SatSolver { - -public: - CryptoMinisatSolver(StatisticsRegistry* registry, - const std::string& name = "") { Unreachable(); } - /** Assert a clause in the solver. */ - ClauseId addClause(SatClause& clause, bool removable) { - Unreachable(); - } - - /** Return true if the solver supports native xor resoning */ - bool nativeXor() { Unreachable(); } - - /** Add a clause corresponding to rhs = l1 xor .. xor ln */ - ClauseId addXorClause(SatClause& clause, bool rhs, bool removable) { - Unreachable(); - } - - SatVariable newVar(bool isTheoryAtom, bool preRegister, bool canErase) { Unreachable(); } - SatVariable trueVar() { Unreachable(); } - SatVariable falseVar() { Unreachable(); } - SatValue solve() { Unreachable(); } - SatValue solve(long unsigned int&) { Unreachable(); } - void interrupt() { Unreachable(); } - SatValue value(SatLiteral l) { Unreachable(); } - SatValue modelValue(SatLiteral l) { Unreachable(); } - unsigned getAssertionLevel() const { Unreachable(); } - bool ok() const { return false;}; - - -};/* class CryptoMinisatSolver */ -} // CVC4::prop -} // CVC4 - -#endif // CVC4_USE_CRYPTOMINISAT +} // namespace prop +} // namespace CVC4 +#endif // CVC4_USE_CRYPTOMINISAT diff --git a/src/prop/sat_solver_factory.cpp b/src/prop/sat_solver_factory.cpp index 27e2daf11..135fc300d 100644 --- a/src/prop/sat_solver_factory.cpp +++ b/src/prop/sat_solver_factory.cpp @@ -16,6 +16,8 @@ #include "prop/sat_solver_factory.h" +// Cryptominisat header has to come first since there are name clashes for +// var_Undef, l_True, ... (static const in Cryptominisat vs. #define in Minisat) #include "prop/cryptominisat.h" #include "prop/minisat/minisat.h" #include "prop/bvminisat/bvminisat.h" @@ -23,19 +25,29 @@ namespace CVC4 { namespace prop { -BVSatSolverInterface* SatSolverFactory::createMinisat(context::Context* mainSatContext, StatisticsRegistry* registry, const std::string& name) { +BVSatSolverInterface* SatSolverFactory::createMinisat( + context::Context* mainSatContext, + StatisticsRegistry* registry, + const std::string& name) +{ return new BVMinisatSatSolver(registry, mainSatContext, name); } SatSolver* SatSolverFactory::createCryptoMinisat(StatisticsRegistry* registry, - const std::string& name) { -return new CryptoMinisatSolver(registry, name); + const std::string& name) +{ +#ifdef CVC4_USE_CRYPTOMINISAT + return new CryptoMinisatSolver(registry, name); +#else + Unreachable("CVC4 was not compiled with Cryptominisat support."); +#endif } - -DPLLSatSolverInterface* SatSolverFactory::createDPLLMinisat(StatisticsRegistry* registry) { +DPLLSatSolverInterface* SatSolverFactory::createDPLLMinisat( + StatisticsRegistry* registry) +{ return new MinisatSatSolver(registry); } -} /* CVC4::prop namespace */ -} /* CVC4 namespace */ +} // namespace prop +} // namespace CVC4 diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index b2d43ac51..6af5e38d5 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -647,10 +647,9 @@ public: * * Returns false if the formula simplifies to "false" */ - bool simplifyAssertions() throw(TypeCheckingException, LogicException, - UnsafeInterruptException); + bool simplifyAssertions(); -public: + public: SmtEnginePrivate(SmtEngine& smt) : d_smt(smt), @@ -732,7 +731,8 @@ public: new SetToDefaultSourceListener(&d_managedReplayLog), true)); } - ~SmtEnginePrivate() throw() { + ~SmtEnginePrivate() + { delete d_listenerRegistrations; if(d_propagatorNeedsFinish) { @@ -743,7 +743,8 @@ public: } ResourceManager* getResourceManager() { return d_resourceManager; } - void spendResource(unsigned amount) throw(UnsafeInterruptException) { + void spendResource(unsigned amount) + { d_resourceManager->spendResource(amount); } @@ -840,13 +841,12 @@ public: * even be simplified. * the 2nd and 3rd arguments added for bookkeeping for proofs */ - void addFormula(TNode n, bool inUnsatCore, bool inInput = true) - throw(TypeCheckingException, LogicException); + void addFormula(TNode n, bool inUnsatCore, bool inInput = true); /** Expand definitions in n. */ - Node expandDefinitions(TNode n, NodeToNodeHashMap& cache, - bool expandOnly = false) - throw(TypeCheckingException, LogicException, UnsafeInterruptException); + Node expandDefinitions(TNode n, + NodeToNodeHashMap& cache, + bool expandOnly = false); /** * Simplify node "in" by expanding definitions and applying any @@ -983,7 +983,7 @@ public: }/* namespace CVC4::smt */ -SmtEngine::SmtEngine(ExprManager* em) throw() +SmtEngine::SmtEngine(ExprManager* em) : d_context(new Context()), d_userLevels(), d_userContext(new UserContext()), @@ -1176,7 +1176,8 @@ void SmtEngine::shutdown() { } } -SmtEngine::~SmtEngine() throw() { +SmtEngine::~SmtEngine() +{ SmtScope smts(this); try { @@ -1248,7 +1249,8 @@ SmtEngine::~SmtEngine() throw() { } } -void SmtEngine::setLogic(const LogicInfo& logic) throw(ModalException) { +void SmtEngine::setLogic(const LogicInfo& logic) +{ SmtScope smts(this); if(d_fullyInited) { throw ModalException("Cannot set logic in SmtEngine after the engine has " @@ -1259,7 +1261,7 @@ void SmtEngine::setLogic(const LogicInfo& logic) throw(ModalException) { } void SmtEngine::setLogic(const std::string& s) - throw(ModalException, LogicException) { +{ SmtScope smts(this); try { setLogic(LogicInfo(s)); @@ -1268,16 +1270,12 @@ void SmtEngine::setLogic(const std::string& s) } } -void SmtEngine::setLogic(const char* logic) - throw(ModalException, LogicException) { - setLogic(string(logic)); -} - +void SmtEngine::setLogic(const char* logic) { setLogic(string(logic)); } LogicInfo SmtEngine::getLogicInfo() const { return d_logic; } - -void SmtEngine::setLogicInternal() throw() { +void SmtEngine::setLogicInternal() +{ Assert(!d_fullyInited, "setting logic in SmtEngine but the engine has already" " finished initializing for this run"); d_logic.lock(); @@ -1355,13 +1353,13 @@ void SmtEngine::setDefaults() { */ } - if(options::checkModels()) { - if(! options::produceAssertions()) { + if ((options::checkModels() || options::checkSynthSol()) + && !options::produceAssertions()) + { Notice() << "SmtEngine: turning on produce-assertions to support " - << "check-models." << endl; + << "check-models or check-synth-sol." << endl; setOption("produce-assertions", SExpr("true")); } - } if(options::unsatCores()) { if(options::simplificationMode() != SIMPLIFICATION_MODE_NONE) { @@ -1887,6 +1885,11 @@ void SmtEngine::setDefaults() { if( !options::instNoEntail.wasSetByUser() ){ options::instNoEntail.set( false ); } + if (options::sygusRewSynth()) + { + // rewrite rule synthesis implies that sygus stream must be true + options::sygusStream.set(true); + } if (options::sygusStream()) { // PBE and streaming modes are incompatible @@ -2124,8 +2127,7 @@ void SmtEngine::setDefaults() { } void SmtEngine::setInfo(const std::string& key, const CVC4::SExpr& value) - throw(OptionException, ModalException) { - +{ SmtScope smts(this); Trace("smt") << "SMT setInfo(" << key << ", " << value << ")" << endl; @@ -2494,8 +2496,7 @@ void SmtEnginePrivate::finishInit() { } Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map<Node, Node, NodeHashFunction>& cache, bool expandOnly) - throw(TypeCheckingException, LogicException, UnsafeInterruptException) { - +{ stack< triple<Node, Node, bool> > worklist; stack<Node> result; worklist.push(make_triple(Node(n), Node(n), false)); @@ -3877,7 +3878,7 @@ void SmtEnginePrivate::doMiplibTrick() { // returns false if simplification led to "false" bool SmtEnginePrivate::simplifyAssertions() - throw(TypeCheckingException, LogicException, UnsafeInterruptException) { +{ spendResource(options::preprocessStep()); Assert(d_smt.d_pendingPops == 0); try { @@ -4618,8 +4619,7 @@ void SmtEnginePrivate::processAssertions() { } void SmtEnginePrivate::addFormula(TNode n, bool inUnsatCore, bool inInput) - throw(TypeCheckingException, LogicException) { - +{ if (n == d_true) { // nothing to do return; @@ -4652,7 +4652,8 @@ void SmtEnginePrivate::addFormula(TNode n, bool inUnsatCore, bool inInput) //d_assertions.push_back(Rewriter::rewrite(n)); } -void SmtEngine::ensureBoolean(const Expr& e) throw(TypeCheckingException) { +void SmtEngine::ensureBoolean(const Expr& e) +{ Type type = e.getType(options::typeChecking()); Type boolType = d_exprManager->booleanType(); if(type != boolType) { @@ -4664,11 +4665,13 @@ void SmtEngine::ensureBoolean(const Expr& e) throw(TypeCheckingException) { } } -Result SmtEngine::checkSat(const Expr& ex, bool inUnsatCore) throw(Exception) { +Result SmtEngine::checkSat(const Expr& ex, bool inUnsatCore) +{ return checkSatisfiability(ex, inUnsatCore, false); } /* SmtEngine::checkSat() */ -Result SmtEngine::query(const Expr& ex, bool inUnsatCore) throw(Exception) { +Result SmtEngine::query(const Expr& ex, bool inUnsatCore) +{ Assert(!ex.isNull()); return checkSatisfiability(ex, inUnsatCore, true); } /* SmtEngine::query() */ @@ -4702,17 +4705,19 @@ Result SmtEngine::checkSatisfiability(const Expr& ex, bool inUnsatCore, bool isQ d_needPostsolve = false; } - // Push the context - internalPush(); - // Note that a query has been made d_queryMade = true; // reset global negation d_globalNegation = false; + bool didInternalPush = false; // Add the formula if(!e.isNull()) { + // Push the context + internalPush(); + didInternalPush = true; + d_problemExtended = true; Expr ea = isQuery ? e.notExpr() : e; if(d_assertionList != NULL) { @@ -4763,7 +4768,10 @@ Result SmtEngine::checkSatisfiability(const Expr& ex, bool inUnsatCore, bool isQ } // Pop the context - internalPop(); + if (didInternalPush) + { + internalPop(); + } // Remember the status d_status = r; @@ -4793,6 +4801,12 @@ Result SmtEngine::checkSatisfiability(const Expr& ex, bool inUnsatCore, bool isQ checkUnsatCore(); } } + // Check that synthesis solutions satisfy the conjecture + if (options::checkSynthSol() + && r.asSatisfiabilityResult().isSat() == Result::UNSAT) + { + checkSynthSolution(); + } return r; } catch (UnsafeInterruptException& e) { @@ -4803,7 +4817,8 @@ Result SmtEngine::checkSatisfiability(const Expr& ex, bool inUnsatCore, bool isQ } } -Result SmtEngine::checkSynth(const Expr& e) throw(Exception) { +Result SmtEngine::checkSynth(const Expr& e) +{ SmtScope smts(this); Trace("smt") << "Check synth: " << e << std::endl; Trace("smt-synth") << "Check synthesis conjecture: " << e << std::endl; @@ -4928,7 +4943,8 @@ Result SmtEngine::checkSynth(const Expr& e) throw(Exception) { return checkSatisfiability( e_check, true, false ); } -Result SmtEngine::assertFormula(const Expr& ex, bool inUnsatCore) throw(TypeCheckingException, LogicException, UnsafeInterruptException) { +Result SmtEngine::assertFormula(const Expr& ex, bool inUnsatCore) +{ Assert(ex.getExprManager() == d_exprManager); SmtScope smts(this); finalOptionsAreSet(); @@ -4955,7 +4971,8 @@ Node SmtEngine::postprocess(TNode node, TypeNode expectedType) const { return node; } -Expr SmtEngine::simplify(const Expr& ex) throw(TypeCheckingException, LogicException, UnsafeInterruptException) { +Expr SmtEngine::simplify(const Expr& ex) +{ Assert(ex.getExprManager() == d_exprManager); SmtScope smts(this); finalOptionsAreSet(); @@ -4978,7 +4995,8 @@ Expr SmtEngine::simplify(const Expr& ex) throw(TypeCheckingException, LogicExcep return n.toExpr(); } -Expr SmtEngine::expandDefinitions(const Expr& ex) throw(TypeCheckingException, LogicException, UnsafeInterruptException) { +Expr SmtEngine::expandDefinitions(const Expr& ex) +{ d_private->spendResource(options::preprocessStep()); Assert(ex.getExprManager() == d_exprManager); @@ -5004,7 +5022,8 @@ Expr SmtEngine::expandDefinitions(const Expr& ex) throw(TypeCheckingException, L } // TODO(#1108): Simplify the error reporting of this method. -Expr SmtEngine::getValue(const Expr& ex) const throw(ModalException, TypeCheckingException, LogicException, UnsafeInterruptException) { +Expr SmtEngine::getValue(const Expr& ex) const +{ Assert(ex.getExprManager() == d_exprManager); SmtScope smts(this); @@ -5475,6 +5494,104 @@ void SmtEngine::checkModel(bool hardFailure) { Notice() << "SmtEngine::checkModel(): all assertions checked out OK !" << endl; } +void SmtEngine::checkSynthSolution() +{ + NodeManager* nm = NodeManager::currentNM(); + Notice() << "SmtEngine::checkSynthSolution(): checking synthesis solution" << endl; + map<Node, Node> sol_map; + /* Get solutions and build auxiliary vectors for substituting */ + d_theoryEngine->getSynthSolutions(sol_map); + Trace("check-synth-sol") << "Got solution map:\n"; + std::vector<Node> function_vars, function_sols; + for (const auto& pair : sol_map) + { + Trace("check-synth-sol") << pair.first << " --> " << pair.second << "\n"; + function_vars.push_back(pair.first); + function_sols.push_back(pair.second); + } + Trace("check-synth-sol") << "Starting new SMT Engine\n"; + /* Start new SMT engine to check solutions */ + SmtEngine solChecker(d_exprManager); + solChecker.setLogic(getLogicInfo()); + setOption("check-synth-sol", SExpr("false")); + + Trace("check-synth-sol") << "Retrieving assertions\n"; + // Build conjecture from original assertions + if (d_assertionList == NULL) + { + Trace("check-synth-sol") << "No assertions to check\n"; + return; + } + for (AssertionList::const_iterator i = d_assertionList->begin(); + i != d_assertionList->end(); + ++i) + { + Notice() << "SmtEngine::checkSynthSolution(): checking assertion " << *i << endl; + Trace("check-synth-sol") << "Retrieving assertion " << *i << "\n"; + Node conj = Node::fromExpr(*i); + // Apply any define-funs from the problem. + { + unordered_map<Node, Node, NodeHashFunction> cache; + conj = d_private->expandDefinitions(conj, cache); + } + Notice() << "SmtEngine::checkSynthSolution(): -- expands to " << conj << endl; + Trace("check-synth-sol") << "Expanded assertion " << conj << "\n"; + + // Apply solution map to conjecture body + Node conjBody; + /* Whether property is quantifier free */ + if (conj[1].getKind() != kind::EXISTS) + { + conjBody = conj[1].substitute(function_vars.begin(), + function_vars.end(), + function_sols.begin(), + function_sols.end()); + } + else + { + conjBody = conj[1][1].substitute(function_vars.begin(), + function_vars.end(), + function_sols.begin(), + function_sols.end()); + + /* Skolemize property */ + std::vector<Node> vars, skos; + for (unsigned j = 0, size = conj[1][0].getNumChildren(); j < size; ++j) + { + vars.push_back(conj[1][0][j]); + std::stringstream ss; + ss << "sk_" << j; + skos.push_back(nm->mkSkolem(ss.str(), conj[1][0][j].getType())); + Trace("check-synth-sol") << "\tSkolemizing " << conj[1][0][j] << " to " + << skos.back() << "\n"; + } + conjBody = conjBody.substitute( + vars.begin(), vars.end(), skos.begin(), skos.end()); + } + Notice() << "SmtEngine::checkSynthSolution(): -- body substitutes to " + << conjBody << endl; + Trace("check-synth-sol") << "Substituted body of assertion to " << conjBody + << "\n"; + solChecker.assertFormula(conjBody.toExpr()); + Result r = solChecker.checkSat(); + Notice() << "SmtEngine::checkSynthSolution(): result is " << r << endl; + Trace("check-synth-sol") << "Satsifiability check: " << r << "\n"; + if (r.asSatisfiabilityResult().isUnknown()) + { + InternalError( + "SmtEngine::checkSynthSolution(): could not check solution, result " + "unknown."); + } + else if (r.asSatisfiabilityResult().isSat()) + { + InternalError( + "SmtEngine::checkSynhtSol(): produced solution allows satisfiable " + "negated conjecture."); + } + solChecker.resetAssertions(); + } +} + // TODO(#1108): Simplify the error reporting of this method. UnsatCore SmtEngine::getUnsatCore() { Trace("smt") << "SMT getUnsatCore()" << endl; @@ -5553,8 +5670,8 @@ void SmtEngine::printSynthSolution( std::ostream& out ) { } } -Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, - bool strict) throw(Exception) { +Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict) +{ SmtScope smts(this); if(!d_logic.isPure(THEORY_ARITH) && strict){ Warning() << "Unexpected logic for quantifier elimination " << d_logic << endl; @@ -5677,7 +5794,8 @@ vector<Expr> SmtEngine::getAssertions() { return vector<Expr>(d_assertionList->begin(), d_assertionList->end()); } -void SmtEngine::push() throw(ModalException, LogicException, UnsafeInterruptException) { +void SmtEngine::push() +{ SmtScope smts(this); finalOptionsAreSet(); doPendingPops(); @@ -5787,7 +5905,8 @@ void SmtEngine::doPendingPops() { } } -void SmtEngine::reset() throw() { +void SmtEngine::reset() +{ SmtScope smts(this); ExprManager *em = d_exprManager; Trace("smt") << "SMT reset()" << endl; @@ -5801,7 +5920,8 @@ void SmtEngine::reset() throw() { new(this) SmtEngine(em); } -void SmtEngine::resetAssertions() throw() { +void SmtEngine::resetAssertions() +{ SmtScope smts(this); doPendingPops(); @@ -5823,7 +5943,8 @@ void SmtEngine::resetAssertions() throw() { d_context->push(); } -void SmtEngine::interrupt() throw(ModalException) { +void SmtEngine::interrupt() +{ if(!d_fullyInited) { return; } @@ -5846,19 +5967,23 @@ unsigned long SmtEngine::getTimeUsage() const { return d_private->getResourceManager()->getTimeUsage(); } -unsigned long SmtEngine::getResourceRemaining() const throw(ModalException) { +unsigned long SmtEngine::getResourceRemaining() const +{ return d_private->getResourceManager()->getResourceRemaining(); } -unsigned long SmtEngine::getTimeRemaining() const throw(ModalException) { +unsigned long SmtEngine::getTimeRemaining() const +{ return d_private->getResourceManager()->getTimeRemaining(); } -Statistics SmtEngine::getStatistics() const throw() { +Statistics SmtEngine::getStatistics() const +{ return Statistics(*d_statisticsRegistry); } -SExpr SmtEngine::getStatistic(std::string name) const throw() { +SExpr SmtEngine::getStatistic(std::string name) const +{ return d_statisticsRegistry->getStatistic(name); } @@ -5901,9 +6026,8 @@ void SmtEngine::setPrintFuncInModel(Expr f, bool p) { } } - - -void SmtEngine::beforeSearch() throw(ModalException) { +void SmtEngine::beforeSearch() +{ if(d_fullyInited) { throw ModalException( "SmtEngine::beforeSearch called after initialization."); @@ -5912,8 +6036,7 @@ void SmtEngine::beforeSearch() throw(ModalException) { void SmtEngine::setOption(const std::string& key, const CVC4::SExpr& value) - throw(OptionException, ModalException) { - +{ NodeManagerScope nms(d_nodeManager); Trace("smt") << "SMT setOption(" << key << ", " << value << ")" << endl; @@ -5949,8 +6072,7 @@ void SmtEngine::setOption(const std::string& key, const CVC4::SExpr& value) } CVC4::SExpr SmtEngine::getOption(const std::string& key) const - throw(OptionException) { - +{ NodeManagerScope nms(d_nodeManager); Trace("smt") << "SMT getOption(" << key << ")" << endl; diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 6d648ccda..e768bf826 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -287,6 +287,16 @@ class CVC4_PUBLIC SmtEngine { void checkModel(bool hardFailure = true); /** + * Check that a solution to a synthesis conjecture is indeed a solution. + * + * The check is made by determining if the negation of the synthesis + * conjecture in which the functions-to-synthesize have been replaced by the + * synthesized solutions, which is a quantifier-free formula, is + * unsatisfiable. If not, then the found solutions are wrong. + */ + void checkSynthSolution(); + + /** * Postprocess a value for output to the user. Involves doing things * like turning datatypes back into tuples, length-1-bitvectors back * into booleans, etc. @@ -338,7 +348,7 @@ class CVC4_PUBLIC SmtEngine { * Fully type-check the argument, and also type-check that it's * actually Boolean. */ - void ensureBoolean(const Expr& e) throw(TypeCheckingException); + void ensureBoolean(const Expr& e) /* throw(TypeCheckingException) */; void internalPush(); @@ -350,7 +360,7 @@ class CVC4_PUBLIC SmtEngine { * Internally handle the setting of a logic. This function should always * be called when d_logic is updated. */ - void setLogicInternal() throw(); + void setLogicInternal() /* throw() */; // TODO (Issue #1096): Remove this friend relationship. friend class ::CVC4::preprocessing::PreprocessingPassContext; @@ -413,27 +423,28 @@ class CVC4_PUBLIC SmtEngine { /** * Construct an SmtEngine with the given expression manager. */ - SmtEngine(ExprManager* em) throw(); + SmtEngine(ExprManager* em) /* throw() */; /** * Destruct the SMT engine. */ - ~SmtEngine() throw(); + ~SmtEngine(); /** * Set the logic of the script. */ - void setLogic(const std::string& logic) throw(ModalException, LogicException); + void setLogic( + const std::string& logic) /* throw(ModalException, LogicException) */; /** * Set the logic of the script. */ - void setLogic(const char* logic) throw(ModalException, LogicException); + void setLogic(const char* logic) /* throw(ModalException, LogicException) */; /** * Set the logic of the script. */ - void setLogic(const LogicInfo& logic) throw(ModalException); + void setLogic(const LogicInfo& logic) /* throw(ModalException) */; /** * Get the logic information currently set @@ -444,7 +455,7 @@ class CVC4_PUBLIC SmtEngine { * Set information about the script executing. */ void setInfo(const std::string& key, const CVC4::SExpr& value) - throw(OptionException, ModalException); + /* throw(OptionException, ModalException) */; /** * Query information about the SMT environment. @@ -455,13 +466,13 @@ class CVC4_PUBLIC SmtEngine { * Set an aspect of the current SMT execution environment. */ void setOption(const std::string& key, const CVC4::SExpr& value) - throw(OptionException, ModalException); + /* throw(OptionException, ModalException) */; /** * Get an aspect of the current SMT execution environment. */ CVC4::SExpr getOption(const std::string& key) const - throw(OptionException); + /* throw(OptionException) */; /** * Define function func in the current context to be: @@ -515,27 +526,29 @@ class CVC4_PUBLIC SmtEngine { * takes a Boolean flag to determine whether to include this asserted * formula in an unsat core (if one is later requested). */ - Result assertFormula(const Expr& e, bool inUnsatCore = true) throw(TypeCheckingException, LogicException, UnsafeInterruptException); + Result assertFormula(const Expr& e, bool inUnsatCore = true) + /* throw(TypeCheckingException, LogicException, UnsafeInterruptException) */ + ; /** * Check validity of an expression with respect to the current set * of assertions by asserting the query expression's negation and * calling check(). Returns valid, invalid, or unknown result. */ - Result query(const Expr& e, bool inUnsatCore = true) throw(Exception); + Result query(const Expr& e, bool inUnsatCore = true) /* throw(Exception) */; /** * Assert a formula (if provided) to the current context and call * check(). Returns sat, unsat, or unknown result. */ Result checkSat(const Expr& e = Expr(), - bool inUnsatCore = true) throw(Exception); + bool inUnsatCore = true) /* throw(Exception) */; /** * Assert a synthesis conjecture to the current context and call * check(). Returns sat, unsat, or unknown result. */ - Result checkSynth(const Expr& e) throw(Exception); + Result checkSynth(const Expr& e) /* throw(Exception) */; /** * Simplify a formula without doing "much" work. Does not involve @@ -546,20 +559,28 @@ class CVC4_PUBLIC SmtEngine { * @todo (design) is this meant to give an equivalent or an * equisatisfiable formula? */ - Expr simplify(const Expr& e) throw(TypeCheckingException, LogicException, UnsafeInterruptException); + Expr simplify( + const Expr& + e) /* throw(TypeCheckingException, LogicException, UnsafeInterruptException) */ + ; /** * Expand the definitions in a term or formula. No other * simplification or normalization is done. */ - Expr expandDefinitions(const Expr& e) throw(TypeCheckingException, LogicException, UnsafeInterruptException); + Expr expandDefinitions( + const Expr& + e) /* throw(TypeCheckingException, LogicException, UnsafeInterruptException) */ + ; /** * Get the assigned value of an expr (only if immediately preceded * by a SAT or INVALID query). Only permitted if the SmtEngine is * set to operate interactively and produce-models is on. */ - Expr getValue(const Expr& e) const throw(ModalException, TypeCheckingException, LogicException, UnsafeInterruptException); + Expr getValue(const Expr& e) const + /* throw(ModalException, TypeCheckingException, LogicException, UnsafeInterruptException) */ + ; /** * Add a function to the set of expressions whose value is to be @@ -645,8 +666,9 @@ class CVC4_PUBLIC SmtEngine { * The argument strict is whether to output * warnings, such as when an unexpected logic is used. */ - Expr doQuantifierElimination(const Expr& e, bool doFull, - bool strict = true) throw(Exception); + Expr doQuantifierElimination(const Expr& e, + bool doFull, + bool strict = true) /* throw(Exception) */; /** * Get list of quantified formulas that were instantiated @@ -675,7 +697,8 @@ class CVC4_PUBLIC SmtEngine { /** * Push a user-level context. */ - void push() throw(ModalException, LogicException, UnsafeInterruptException); + void + push() /* throw(ModalException, LogicException, UnsafeInterruptException) */; /** * Pop a user-level context. Throws an exception if nothing to pop. @@ -687,19 +710,19 @@ class CVC4_PUBLIC SmtEngine { * recreated. The result is as if newly constructed (so it still * retains the same options structure and ExprManager). */ - void reset() throw(); + void reset() /* throw() */; /** * Reset all assertions, global declarations, etc. */ - void resetAssertions() throw(); + void resetAssertions() /* throw() */; /** * Interrupt a running query. This can be called from another thread * or from a signal handler. Throws a ModalException if the SmtEngine * isn't currently in a query. */ - void interrupt() throw(ModalException); + void interrupt() /* throw(ModalException) */; /** * Set a resource limit for SmtEngine operations. This is like a time @@ -784,7 +807,7 @@ class CVC4_PUBLIC SmtEngine { * is not a cumulative resource limit set, this function throws a * ModalException. */ - unsigned long getResourceRemaining() const throw(ModalException); + unsigned long getResourceRemaining() const /* throw(ModalException) */; /** * Get the remaining number of milliseconds that can be consumed by @@ -792,7 +815,7 @@ class CVC4_PUBLIC SmtEngine { * If there is not a cumulative resource limit set, this function * throws a ModalException. */ - unsigned long getTimeRemaining() const throw(ModalException); + unsigned long getTimeRemaining() const /* throw(ModalException) */; /** * Permit access to the underlying ExprManager. @@ -804,12 +827,12 @@ class CVC4_PUBLIC SmtEngine { /** * Export statistics from this SmtEngine. */ - Statistics getStatistics() const throw(); + Statistics getStatistics() const /* throw() */; /** * Get the value of one named statistic from this SmtEngine. */ - SExpr getStatistic(std::string name) const throw(); + SExpr getStatistic(std::string name) const /* throw() */; /** * Flush statistic from this SmtEngine. Safe to use in a signal handler. @@ -819,10 +842,7 @@ class CVC4_PUBLIC SmtEngine { /** * Returns the most recent result of checkSat/query or (set-info :status). */ - Result getStatusOfLastCommand() const throw() { - return d_status; - } - + Result getStatusOfLastCommand() const /* throw() */ { return d_status; } /** * Set user attribute. * This function is called when an attribute is set by a user. @@ -840,7 +860,7 @@ class CVC4_PUBLIC SmtEngine { /** Throws a ModalException if the SmtEngine has been fully initialized. */ - void beforeSearch() throw(ModalException); + void beforeSearch() /* throw(ModalException) */; LemmaChannels* channels() { return d_channels; } diff --git a/src/theory/arith/arith_ite_utils.cpp b/src/theory/arith/arith_ite_utils.cpp index 3e767b6db..c67af2a5d 100644 --- a/src/theory/arith/arith_ite_utils.cpp +++ b/src/theory/arith/arith_ite_utils.cpp @@ -138,7 +138,6 @@ Node ArithIteUtils::reduceVariablesInItes(Node n){ break; } Unreachable(); - return Node::null(); } ArithIteUtils::ArithIteUtils(ContainsTermITEVisitor& contains, diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index b47cb1e60..a9761ade4 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -107,7 +107,15 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ case kind::SINE: case kind::COSINE: case kind::TANGENT: - return preRewriteTranscendental(t); + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: return preRewriteTranscendental(t); case kind::INTS_DIVISION: case kind::INTS_MODULUS: return RewriteResponse(REWRITE_DONE, t); @@ -163,7 +171,15 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ case kind::SINE: case kind::COSINE: case kind::TANGENT: - return postRewriteTranscendental(t); + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: return postRewriteTranscendental(t); case kind::INTS_DIVISION: case kind::INTS_MODULUS: return RewriteResponse(REWRITE_DONE, t); @@ -360,28 +376,30 @@ RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) { RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl; + NodeManager* nm = NodeManager::currentNM(); switch( t.getKind() ){ case kind::EXPONENTIAL: { if(t[0].getKind() == kind::CONST_RATIONAL){ - Node one = NodeManager::currentNM()->mkConst(Rational(1)); + Node one = nm->mkConst(Rational(1)); if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){ - return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::POW, NodeManager::currentNM()->mkNode( kind::EXPONENTIAL, one ), t[0])); + return RewriteResponse( + REWRITE_AGAIN, + nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0])); }else{ return RewriteResponse(REWRITE_DONE, t); } }else if(t[0].getKind() == kind::PLUS ){ std::vector<Node> product; for( unsigned i=0; i<t[0].getNumChildren(); i++ ){ - product.push_back( NodeManager::currentNM()->mkNode( kind::EXPONENTIAL, t[0][i] ) ); + product.push_back(nm->mkNode(kind::EXPONENTIAL, t[0][i])); } - return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::MULT, product)); + return RewriteResponse(REWRITE_AGAIN, nm->mkNode(kind::MULT, product)); } } break; case kind::SINE: if(t[0].getKind() == kind::CONST_RATIONAL){ const Rational& rat = t[0].getConst<Rational>(); - NodeManager* nm = NodeManager::currentNM(); if(rat.sgn() == 0){ return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0))); } @@ -433,26 +451,29 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { if (r_abs > rone) { //add/substract 2*pi beyond scope - Node ra_div_two = NodeManager::currentNM()->mkNode( + Node ra_div_two = nm->mkNode( kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo); Node new_pi_factor; if( r.sgn()==1 ){ - new_pi_factor = NodeManager::currentNM()->mkNode( kind::MINUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) ); + new_pi_factor = + nm->mkNode(kind::MINUS, + pi_factor, + nm->mkNode(kind::MULT, ntwo, ra_div_two)); }else{ Assert( r.sgn()==-1 ); - new_pi_factor = NodeManager::currentNM()->mkNode( kind::PLUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) ); + new_pi_factor = + nm->mkNode(kind::PLUS, + pi_factor, + nm->mkNode(kind::MULT, ntwo, ra_div_two)); } - Node new_arg = - NodeManager::currentNM()->mkNode(kind::MULT, new_pi_factor, pi); + Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi); if (!rem.isNull()) { - new_arg = - NodeManager::currentNM()->mkNode(kind::PLUS, new_arg, rem); + new_arg = nm->mkNode(kind::PLUS, new_arg, rem); } // sin( 2*n*PI + x ) = sin( x ) - return RewriteResponse( - REWRITE_AGAIN_FULL, - NodeManager::currentNM()->mkNode(kind::SINE, new_arg)); + return RewriteResponse(REWRITE_AGAIN_FULL, + nm->mkNode(kind::SINE, new_arg)); } else if (r_abs == rone) { @@ -465,9 +486,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { return RewriteResponse( REWRITE_AGAIN_FULL, - NodeManager::currentNM()->mkNode( - kind::UMINUS, - NodeManager::currentNM()->mkNode(kind::SINE, rem))); + nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem))); } } else if (rem.isNull()) @@ -498,16 +517,48 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { } break; case kind::COSINE: { - return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode( kind::SINE, - NodeManager::currentNM()->mkNode( kind::MINUS, - NodeManager::currentNM()->mkNode( kind::MULT, - NodeManager::currentNM()->mkConst( Rational(1)/Rational(2) ), - NodeManager::currentNM()->mkNullaryOperator( NodeManager::currentNM()->realType(), kind::PI ) ), - t[0] ) ) ); - } break; + return RewriteResponse( + REWRITE_AGAIN_FULL, + nm->mkNode(kind::SINE, + nm->mkNode(kind::MINUS, + nm->mkNode(kind::MULT, + nm->mkConst(Rational(1) / Rational(2)), + mkPi()), + t[0]))); + } + break; case kind::TANGENT: - return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode(kind::DIVISION, NodeManager::currentNM()->mkNode( kind::SINE, t[0] ), - NodeManager::currentNM()->mkNode( kind::COSINE, t[0] ) )); + { + return RewriteResponse(REWRITE_AGAIN_FULL, + nm->mkNode(kind::DIVISION, + nm->mkNode(kind::SINE, t[0]), + nm->mkNode(kind::COSINE, t[0]))); + } + break; + case kind::COSECANT: + { + return RewriteResponse(REWRITE_AGAIN_FULL, + nm->mkNode(kind::DIVISION, + mkRationalNode(Rational(1)), + nm->mkNode(kind::SINE, t[0]))); + } + break; + case kind::SECANT: + { + return RewriteResponse(REWRITE_AGAIN_FULL, + nm->mkNode(kind::DIVISION, + mkRationalNode(Rational(1)), + nm->mkNode(kind::COSINE, t[0]))); + } + break; + case kind::COTANGENT: + { + return RewriteResponse(REWRITE_AGAIN_FULL, + nm->mkNode(kind::DIVISION, + nm->mkNode(kind::COSINE, t[0]), + nm->mkNode(kind::SINE, t[0]))); + } + break; default: break; } @@ -592,7 +643,6 @@ RewriteResponse ArithRewriter::postRewrite(TNode t){ return response; }else{ Unreachable(); - return RewriteResponse(REWRITE_DONE, Node::null()); } } @@ -603,7 +653,6 @@ RewriteResponse ArithRewriter::preRewrite(TNode t){ return preRewriteAtom(t); }else{ Unreachable(); - return RewriteResponse(REWRITE_DONE, Node::null()); } } diff --git a/src/theory/arith/cut_log.cpp b/src/theory/arith/cut_log.cpp index ad04cfe22..08fe0bc1e 100644 --- a/src/theory/arith/cut_log.cpp +++ b/src/theory/arith/cut_log.cpp @@ -26,6 +26,7 @@ #include "theory/arith/constraint.h" #include "theory/arith/cut_log.h" #include "theory/arith/normal_form.h" +#include "util/ostream_util.h" using namespace std; @@ -84,8 +85,9 @@ void PrimitiveVec::setup(int l){ } void PrimitiveVec::print(std::ostream& out) const{ Assert(initialized()); - out << len << " "; - out.precision(15); + StreamFormatScope scope(out); + + out << len << " " << std::setprecision(15); for(int i = 1; i <= len; ++i){ out << "["<< inds[i] <<", " << coeffs[i]<<"]"; } diff --git a/src/theory/arith/delta_rational.cpp b/src/theory/arith/delta_rational.cpp index fba7fdaf6..7a94674d6 100644 --- a/src/theory/arith/delta_rational.cpp +++ b/src/theory/arith/delta_rational.cpp @@ -83,8 +83,8 @@ DeltaRationalException::DeltaRationalException(const char* op, } DeltaRationalException::~DeltaRationalException() {} - -Integer DeltaRational::euclidianDivideQuotient(const DeltaRational& y) const throw(DeltaRationalException){ +Integer DeltaRational::euclidianDivideQuotient(const DeltaRational& y) const +{ if(isIntegral() && y.isIntegral()){ Integer ti = floor(); Integer yi = y.floor(); @@ -94,7 +94,8 @@ Integer DeltaRational::euclidianDivideQuotient(const DeltaRational& y) const thr } } -Integer DeltaRational::euclidianDivideRemainder(const DeltaRational& y) const throw(DeltaRationalException){ +Integer DeltaRational::euclidianDivideRemainder(const DeltaRational& y) const +{ if(isIntegral() && y.isIntegral()){ Integer ti = floor(); Integer yi = y.floor(); diff --git a/src/theory/arith/delta_rational.h b/src/theory/arith/delta_rational.h index 7a1c18ea2..5e4b2c3a8 100644 --- a/src/theory/arith/delta_rational.h +++ b/src/theory/arith/delta_rational.h @@ -116,7 +116,8 @@ public: * This can be done whenever this->k or a.k is 0. * Otherwise, the result is not a DeltaRational and a DeltaRationalException is thrown. */ - DeltaRational operator*(const DeltaRational& a) const throw(DeltaRationalException){ + DeltaRational operator*(const DeltaRational& a) const + /* throw(DeltaRationalException) */ { if(infinitesimalIsZero()){ return a * (this->getNoninfinitesimalPart()); }else if(a.infinitesimalIsZero()){ @@ -153,7 +154,8 @@ public: * This can be done when a.k is 0 and a.c is non-zero. * Otherwise, the result is not a DeltaRational and a DeltaRationalException is thrown. */ - DeltaRational operator/(const DeltaRational& a) const throw(DeltaRationalException){ + DeltaRational operator/(const DeltaRational& a) const + /* throw(DeltaRationalException) */ { if(a.infinitesimalIsZero()){ return (*this) / a.getNoninfinitesimalPart(); }else{ @@ -258,11 +260,12 @@ public: } /** Only well defined if both this and y are integral. */ - Integer euclidianDivideQuotient(const DeltaRational& y) const throw(DeltaRationalException); + Integer euclidianDivideQuotient(const DeltaRational& y) const + /* throw(DeltaRationalException) */; /** Only well defined if both this and y are integral. */ - Integer euclidianDivideRemainder(const DeltaRational& y) const throw(DeltaRationalException); - + Integer euclidianDivideRemainder(const DeltaRational& y) const + /* throw(DeltaRationalException) */; std::string toString() const; diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds index 34ae30f4c..3073d0078 100644 --- a/src/theory/arith/kinds +++ b/src/theory/arith/kinds @@ -31,6 +31,17 @@ operator EXPONENTIAL 1 "exponential" operator SINE 1 "sine" operator COSINE 1 "consine" operator TANGENT 1 "tangent" +operator COSECANT 1 "cosecant" +operator SECANT 1 "secant" +operator COTANGENT 1 "cotangent" +operator ARCSINE 1 "arc sine" +operator ARCCOSINE 1 "arc consine" +operator ARCTANGENT 1 "arc tangent" +operator ARCCOSECANT 1 "arc cosecant" +operator ARCSECANT 1 "arc secant" +operator ARCCOTANGENT 1 "arc cotangent" + +operator SQRT 1 "square root" constant DIVISIBLE_OP \ ::CVC4::Divisible \ @@ -105,6 +116,17 @@ typerule EXPONENTIAL ::CVC4::theory::arith::RealOperatorTypeRule typerule SINE ::CVC4::theory::arith::RealOperatorTypeRule typerule COSINE ::CVC4::theory::arith::RealOperatorTypeRule typerule TANGENT ::CVC4::theory::arith::RealOperatorTypeRule +typerule COSECANT ::CVC4::theory::arith::RealOperatorTypeRule +typerule SECANT ::CVC4::theory::arith::RealOperatorTypeRule +typerule COTANGENT ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCSINE ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCCOSINE ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCTANGENT ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCCOSECANT ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCSECANT ::CVC4::theory::arith::RealOperatorTypeRule +typerule ARCCOTANGENT ::CVC4::theory::arith::RealOperatorTypeRule + +typerule SQRT ::CVC4::theory::arith::RealOperatorTypeRule nullaryoperator PI "pi" diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp index 65a7597f1..e8f8b9fa5 100644 --- a/src/theory/arith/nonlinear_extension.cpp +++ b/src/theory/arith/nonlinear_extension.cpp @@ -216,12 +216,14 @@ bool hasNewMonomials(Node n, const std::vector<Node>& existing) { NonlinearExtension::NonlinearExtension(TheoryArith& containing, eq::EqualityEngine* ee) - : d_lemmas(containing.getUserContext()), + : d_def_lemmas(containing.getUserContext()), + d_lemmas(containing.getUserContext()), d_zero_split(containing.getUserContext()), d_skolem_atoms(containing.getUserContext()), d_containing(containing), d_ee(ee), - d_needsLastCall(false) { + d_needsLastCall(false) +{ d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); d_zero = NodeManager::currentNM()->mkConst(Rational(0)); @@ -1032,7 +1034,9 @@ Kind NonlinearExtension::transKinds(Kind k1, Kind k2) { } bool NonlinearExtension::isTranscendentalKind(Kind k) { - Assert(k != TANGENT && k != COSINE); // eliminated + // many operators are eliminated during rewriting + Assert(k != TANGENT && k != COSINE && k != COSECANT && k != SECANT + && k != COTANGENT); return k == EXPONENTIAL || k == SINE || k == PI; } @@ -1161,10 +1165,12 @@ bool NonlinearExtension::checkModelTf(const std::vector<Node>& assertions) if (check_assertions.empty()) { + Trace("nl-ext-tf-check-model") << "...simple check succeeded." << std::endl; return true; } else { + Trace("nl-ext-tf-check-model") << "...simple check failed." << std::endl; // TODO (#1450) check model for general case return false; } @@ -1250,7 +1256,6 @@ bool NonlinearExtension::simpleCheckModelTfLit(Node lit) return comp == d_true; } } - Trace("nl-ext-tf-check-model-simple") << " failed due to unknown literal." << std::endl; return false; @@ -1288,11 +1293,13 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions, d_tf_check_model_bounds.clear(); int lemmas_proc = 0; - std::vector<Node> lemmas; - + std::vector<Node> lemmas; + NodeManager* nm = NodeManager::currentNM(); + Trace("nl-ext-mv") << "Extended terms : " << std::endl; // register the extended function terms std::map< Node, Node > mvarg_to_term; + std::vector<Node> trig_no_base; for( unsigned i=0; i<xts.size(); i++ ){ Node a = xts[i]; computeModelValue(a, 0); @@ -1341,38 +1348,11 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions, { if( d_trig_is_base.find( a )==d_trig_is_base.end() ){ consider = false; - if( d_trig_base.find( a )==d_trig_base.end() ){ - Node y = NodeManager::currentNM()->mkSkolem("y",NodeManager::currentNM()->realType(),"phase shifted trigonometric arg"); - Node new_a = NodeManager::currentNM()->mkNode( a.getKind(), y ); - d_trig_is_base[new_a] = true; - d_trig_base[a] = new_a; - Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a << std::endl; - if( d_pi.isNull() ){ - mkPi(); - getCurrentPiBounds( lemmas ); - } - Node shift = NodeManager::currentNM()->mkSkolem( "s", NodeManager::currentNM()->integerType(), "number of shifts" ); - // FIXME : do not introduce shift here, instead needs model-based - // refinement for constant shifts (#1284) - Node shift_lem = NodeManager::currentNM()->mkNode( - AND, - mkValidPhase(y, d_pi), - a[0].eqNode(NodeManager::currentNM()->mkNode( - PLUS, - y, - NodeManager::currentNM()->mkNode( - MULT, - NodeManager::currentNM()->mkConst(Rational(2)), - shift, - d_pi))), - // particular case of above for shift=0 - NodeManager::currentNM()->mkNode( - IMPLIES, mkValidPhase(a[0], d_pi), a[0].eqNode(y)), - new_a.eqNode(a)); - //must do preprocess on this one - Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : shift : " << shift_lem << std::endl; - d_containing.getOutputChannel().lemma(shift_lem, false, true); - lemmas_proc++; + trig_no_base.push_back(a); + if (d_pi.isNull()) + { + mkPi(); + getCurrentPiBounds(lemmas); } } } @@ -1383,7 +1363,7 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions, //verify they have the same model value if( d_mv[1][a]!=d_mv[1][itrm->second] ){ // if not, add congruence lemma - Node cong_lemma = NodeManager::currentNM()->mkNode( + Node cong_lemma = nm->mkNode( IMPLIES, a[0].eqNode(itrm->second[0]), a.eqNode(itrm->second)); lemmas.push_back( cong_lemma ); //Assert( false ); @@ -1407,6 +1387,45 @@ int NonlinearExtension::checkLastCall(const std::vector<Node>& assertions, return lemmas_proc; } + // process SINE phase shifting + for (const Node& a : trig_no_base) + { + if (d_trig_base.find(a) == d_trig_base.end()) + { + Node y = + nm->mkSkolem("y", nm->realType(), "phase shifted trigonometric arg"); + Node new_a = nm->mkNode(a.getKind(), y); + d_trig_is_base[new_a] = true; + d_trig_base[a] = new_a; + Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a + << std::endl; + Assert(!d_pi.isNull()); + Node shift = nm->mkSkolem("s", nm->integerType(), "number of shifts"); + // FIXME : do not introduce shift here, instead needs model-based + // refinement for constant shifts (#1284) + Node shift_lem = nm->mkNode( + AND, + mkValidPhase(y, d_pi), + a[0].eqNode(nm->mkNode( + PLUS, + y, + nm->mkNode(MULT, nm->mkConst(Rational(2)), shift, d_pi))), + // particular case of above for shift=0 + nm->mkNode(IMPLIES, mkValidPhase(a[0], d_pi), a[0].eqNode(y)), + new_a.eqNode(a)); + // must do preprocess on this one + Trace("nl-ext-lemma") + << "NonlinearExtension::Lemma : shift : " << shift_lem << std::endl; + d_containing.getOutputChannel().lemma(shift_lem, false, true); + lemmas_proc++; + } + } + if (lemmas_proc > 0) + { + Trace("nl-ext") << " ...finished with " << lemmas_proc + << " new lemmas SINE phase shifting." << std::endl; + return lemmas_proc; + } // register constants registerMonomial(d_one); @@ -1742,6 +1761,24 @@ void NonlinearExtension::check(Theory::Effort e) { } } +void NonlinearExtension::addDefinition(Node lem) +{ + Trace("nl-ext") << "NonlinearExtension::addDefinition : " << lem << std::endl; + d_def_lemmas.insert(lem); +} + +void NonlinearExtension::presolve() +{ + Trace("nl-ext") << "NonlinearExtension::presolve, #defs = " + << d_def_lemmas.size() << std::endl; + for (NodeSet::const_iterator it = d_def_lemmas.begin(); + it != d_def_lemmas.end(); + ++it) + { + flushLemma(*it); + } +} + void NonlinearExtension::assignOrderIds(std::vector<Node>& vars, NodeMultiset& order, unsigned orderType) { @@ -3237,8 +3274,6 @@ std::vector<Node> NonlinearExtension::checkTranscendentalTangentPlanes() // Figure 3: P_l, P_u // mapped to for signs of c std::map<int, Node> poly_approx_bounds[2]; - std::map<int, Node> - poly_approx_bounds_neg[2]; // the negative case is different for exp // n is the Taylor degree we are currently considering unsigned n = 2 * d_taylor_degree; // n must be even @@ -3487,6 +3522,10 @@ std::vector<Node> NonlinearExtension::checkTranscendentalTangentPlanes() antec.size() == 1 ? antec[0] : nm->mkNode(AND, antec); lem = nm->mkNode(IMPLIES, antec_n, lem); } + Trace("nl-ext-tf-tplanes-debug") + << "*** Tangent plane lemma (pre-rewrite): " << lem + << std::endl; + lem = Rewriter::rewrite(lem); Trace("nl-ext-tf-tplanes") << "*** Tangent plane lemma : " << lem << std::endl; // Figure 3 : line 9 @@ -3607,6 +3646,10 @@ std::vector<Node> NonlinearExtension::checkTranscendentalTangentPlanes() nm->mkNode(GEQ, tf[0], s == 0 ? bounds[s] : c), nm->mkNode(LEQ, tf[0], s == 0 ? c : bounds[s])); lem = nm->mkNode(IMPLIES, antec_n, lem); + Trace("nl-ext-tf-tplanes-debug") + << "*** Secant plane lemma (pre-rewrite) : " << lem + << std::endl; + lem = Rewriter::rewrite(lem); Trace("nl-ext-tf-tplanes") << "*** Secant plane lemma : " << lem << std::endl; // Figure 3 : line 22 diff --git a/src/theory/arith/nonlinear_extension.h b/src/theory/arith/nonlinear_extension.h index 34da28f6c..84acc0269 100644 --- a/src/theory/arith/nonlinear_extension.h +++ b/src/theory/arith/nonlinear_extension.h @@ -121,6 +121,23 @@ class NonlinearExtension { void check(Theory::Effort e); /** Does this class need a call to check(...) at last call effort? */ bool needsCheckLastEffort() const { return d_needsLastCall; } + /** add definition + * + * This function notifies this class that lem is a formula that defines or + * constrains an auxiliary variable. For example, during + * TheoryArith::expandDefinitions, we replace a term like arcsin( x ) with an + * auxiliary variable k. The lemmas 0 <= k < pi and sin( x ) = k are added as + * definitions to this class. + */ + void addDefinition(Node lem); + /** presolve + * + * This function is called during TheoryArith's presolve command. + * In this function, we send lemmas we accumulated during preprocessing, + * for instance, definitional lemmas from expandDefinitions are sent out + * on the output channel of TheoryArith in this function. + */ + void presolve(); /** Compare arithmetic terms i and j based an ordering. * * orderType = 0 : compare concrete model values @@ -387,8 +404,11 @@ class NonlinearExtension { // ( x*y, x*z, y ) for each pair of monomials ( x*y, x*z ) with common factors std::map<Node, std::map<Node, Node> > d_mono_diff; - // cache of all lemmas sent + /** cache of definition lemmas (user-context-dependent) */ + NodeSet d_def_lemmas; + /** cache of all lemmas sent on the output channel (user-context-dependent) */ NodeSet d_lemmas; + /** cache of terms t for which we have added the lemma ( t = 0 V t != 0 ). */ NodeSet d_zero_split; // literals with Skolems (need not be satisfied by model) diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 30b9ca0b5..76782d8a5 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -95,7 +95,16 @@ bool Variable::isTranscendentalMember(Node n) { case kind::SINE: case kind::COSINE: case kind::TANGENT: - return Polynomial::isMember(n[0]); + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: return Polynomial::isMember(n[0]); case kind::PI: return true; default: diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h index 21301da91..ba740146b 100644 --- a/src/theory/arith/normal_form.h +++ b/src/theory/arith/normal_form.h @@ -245,6 +245,16 @@ public: case kind::SINE: case kind::COSINE: case kind::TANGENT: + case kind::COSECANT: + case kind::SECANT: + case kind::COTANGENT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + case kind::SQRT: case kind::PI: return isTranscendentalMember(n); case kind::ABS: diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index e354305d7..1390cbee6 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -42,8 +42,6 @@ TheoryArith::TheoryArith(context::Context* c, context::UserContext* u, getExtTheory()->addFunctionKind(kind::NONLINEAR_MULT); getExtTheory()->addFunctionKind(kind::EXPONENTIAL); getExtTheory()->addFunctionKind(kind::SINE); - getExtTheory()->addFunctionKind(kind::COSINE); - getExtTheory()->addFunctionKind(kind::TANGENT); getExtTheory()->addFunctionKind(kind::PI); } } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 1c10bde0d..4f3a13b4d 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -48,41 +48,47 @@ public: /** * Does non-context dependent setup for a node connected to a theory. */ - void preRegisterTerm(TNode n); + void preRegisterTerm(TNode n) override; - Node expandDefinition(LogicRequest &logicRequest, Node node); + Node expandDefinition(LogicRequest& logicRequest, Node node) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - void check(Effort e); - bool needsCheckLastEffort(); - void propagate(Effort e); - Node explain(TNode n); - bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ); - bool isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ); + void check(Effort e) override; + bool needsCheckLastEffort() override; + void propagate(Effort e) override; + Node explain(TNode n) override; + bool getCurrentSubstitution(int effort, + std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; + bool isExtfReduced(int effort, + Node n, + Node on, + std::vector<Node>& exp) override; bool collectModelInfo(TheoryModel* m) override; - void shutdown(){ } + void shutdown() override {} - void presolve(); - void notifyRestart(); - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); - Node ppRewrite(TNode atom); - void ppStaticLearn(TNode in, NodeBuilder<>& learned); + void presolve() override; + void notifyRestart() override; + PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; + Node ppRewrite(TNode atom) override; + void ppStaticLearn(TNode in, NodeBuilder<>& learned) override; - std::string identify() const { return std::string("TheoryArith"); } + std::string identify() const override { return std::string("TheoryArith"); } - EqualityStatus getEqualityStatus(TNode a, TNode b); + EqualityStatus getEqualityStatus(TNode a, TNode b) override; - void addSharedTerm(TNode n); + void addSharedTerm(TNode n) override; - Node getModelValue(TNode var); + Node getModelValue(TNode var) override; - - std::pair<bool, Node> entailmentCheck(TNode lit, - const EntailmentCheckParameters* params, - EntailmentCheckSideEffects* out); + std::pair<bool, Node> entailmentCheck( + TNode lit, + const EntailmentCheckParameters* params, + EntailmentCheckSideEffects* out) override; };/* class TheoryArith */ diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index f05f47595..fc0673d21 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -85,67 +85,84 @@ namespace arith { static Node toSumNode(const ArithVariables& vars, const DenseMap<Rational>& sum); static bool complexityBelow(const DenseMap<Rational>& row, uint32_t cap); - -TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo) : - d_containing(containing), - d_nlIncomplete( false), - d_rowTracking(), - d_constraintDatabase(c, u, d_partialModel, d_congruenceManager, RaiseConflict(*this)), - d_qflraStatus(Result::SAT_UNKNOWN), - d_unknownsInARow(0), - d_hasDoneWorkSinceCut(false), - d_learner(u), - d_assertionsThatDoNotMatchTheirLiterals(c), - d_nextIntegerCheckVar(0), - d_constantIntegerVariables(c), - d_diseqQueue(c, false), - d_currentPropagationList(), - d_learnedBounds(c), - d_partialModel(c, DeltaComputeCallback(*this)), - d_errorSet(d_partialModel, TableauSizes(&d_tableau), BoundCountingLookup(*this)), - d_tableau(), - d_linEq(d_partialModel, d_tableau, d_rowTracking, BasicVarModelUpdateCallBack(*this)), - d_diosolver(c), - d_restartsCounter(0), - d_tableauSizeHasBeenModified(false), - d_tableauResetDensity(1.6), - d_tableauResetPeriod(10), - d_conflicts(c), - d_blackBoxConflict(c, Node::null()), - d_congruenceManager(c, d_constraintDatabase, SetupLiteralCallBack(*this), d_partialModel, RaiseEqualityEngineConflict(*this)), - d_cmEnabled(c, true), - - d_dualSimplex(d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_fcSimplex(d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_soiSimplex(d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_attemptSolSimplex(d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_nonlinearExtension( NULL ), - d_pass1SDP(NULL), - d_otherSDP(NULL), - d_lastContextIntegerAttempted(c,-1), - - - d_DELTA_ZERO(0), - d_approxCuts(c), - d_fullCheckCounter(0), - d_cutCount(c, 0), - d_cutInContext(c), - d_likelyIntegerInfeasible(c, false), - d_guessedCoeffSet(c, false), - d_guessedCoeffs(), - d_treeLog(NULL), - d_replayVariables(), - d_replayConstraints(), - d_lhsTmp(), - d_approxStats(NULL), - d_attemptSolveIntTurnedOff(u, 0), - d_dioSolveResources(0), - d_solveIntMaybeHelp(0u), - d_solveIntAttempts(0u), - d_statistics(), - d_to_int_skolem(u), - d_div_skolem(u), - d_int_div_skolem(u) +TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, + context::Context* c, + context::UserContext* u, + OutputChannel& out, + Valuation valuation, + const LogicInfo& logicInfo) + : d_containing(containing), + d_nlIncomplete(false), + d_rowTracking(), + d_constraintDatabase( + c, u, d_partialModel, d_congruenceManager, RaiseConflict(*this)), + d_qflraStatus(Result::SAT_UNKNOWN), + d_unknownsInARow(0), + d_hasDoneWorkSinceCut(false), + d_learner(u), + d_assertionsThatDoNotMatchTheirLiterals(c), + d_nextIntegerCheckVar(0), + d_constantIntegerVariables(c), + d_diseqQueue(c, false), + d_currentPropagationList(), + d_learnedBounds(c), + d_partialModel(c, DeltaComputeCallback(*this)), + d_errorSet( + d_partialModel, TableauSizes(&d_tableau), BoundCountingLookup(*this)), + d_tableau(), + d_linEq(d_partialModel, + d_tableau, + d_rowTracking, + BasicVarModelUpdateCallBack(*this)), + d_diosolver(c), + d_restartsCounter(0), + d_tableauSizeHasBeenModified(false), + d_tableauResetDensity(1.6), + d_tableauResetPeriod(10), + d_conflicts(c), + d_blackBoxConflict(c, Node::null()), + d_congruenceManager(c, + d_constraintDatabase, + SetupLiteralCallBack(*this), + d_partialModel, + RaiseEqualityEngineConflict(*this)), + d_cmEnabled(c, true), + + d_dualSimplex( + d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), + d_fcSimplex( + d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), + d_soiSimplex( + d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), + d_attemptSolSimplex( + d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), + d_nonlinearExtension(NULL), + d_pass1SDP(NULL), + d_otherSDP(NULL), + d_lastContextIntegerAttempted(c, -1), + + d_DELTA_ZERO(0), + d_approxCuts(c), + d_fullCheckCounter(0), + d_cutCount(c, 0), + d_cutInContext(c), + d_likelyIntegerInfeasible(c, false), + d_guessedCoeffSet(c, false), + d_guessedCoeffs(), + d_treeLog(NULL), + d_replayVariables(), + d_replayConstraints(), + d_lhsTmp(), + d_approxStats(NULL), + d_attemptSolveIntTurnedOff(u, 0), + d_dioSolveResources(0), + d_solveIntMaybeHelp(0u), + d_solveIntAttempts(0u), + d_statistics(), + d_to_int_skolem(u), + d_div_skolem(u), + d_int_div_skolem(u), + d_nlin_inverse_skolem(u) { if( options::nlExt() ){ d_nonlinearExtension = new NonlinearExtension( @@ -4120,7 +4137,7 @@ void TheoryArithPrivate::propagate(Theory::Effort e) { } DeltaRational TheoryArithPrivate::getDeltaValue(TNode term) const - throw(DeltaRationalException, ModelException) { +{ AlwaysAssert(d_qflraStatus != Result::SAT_UNKNOWN); Debug("arith::value") << term << std::endl; @@ -4419,6 +4436,11 @@ void TheoryArithPrivate::presolve(){ Debug("arith::oldprop") << " lemma lemma duck " <<lem << endl; outputLemma(lem); } + + if (options::nlExt()) + { + d_nonlinearExtension->presolve(); + } } EqualityStatus TheoryArithPrivate::getEqualityStatus(TNode a, TNode b) { @@ -4863,90 +4885,178 @@ const BoundsInfo& TheoryArithPrivate::boundsInfo(ArithVar basic) const{ Node TheoryArithPrivate::expandDefinition(LogicRequest &logicRequest, Node node) { NodeManager* nm = NodeManager::currentNM(); - // eliminate here since involves division - if( node.getKind()==kind::TANGENT ){ - node = nm->mkNode(kind::DIVISION, nm->mkNode( kind::SINE, node[0] ), - nm->mkNode( kind::COSINE, node[0] ) ); + // eliminate here since the rewritten form of these may introduce division + Kind k = node.getKind(); + if (k == kind::TANGENT || k == kind::COSECANT || k == kind::SECANT + || k == kind::COTANGENT) + { + node = Rewriter::rewrite(node); + k = node.getKind(); } - switch(node.getKind()) { - case kind::DIVISION: { - TNode num = node[0], den = node[1]; - Node ret = nm->mkNode(kind::DIVISION_TOTAL, num, den); - if (!den.isConst() || den.getConst<Rational>().sgn() == 0) + switch (k) + { + case kind::DIVISION: { - // partial function: division - if (d_divByZero.isNull()) + TNode num = node[0], den = node[1]; + Node ret = nm->mkNode(kind::DIVISION_TOTAL, num, den); + if (!den.isConst() || den.getConst<Rational>().sgn() == 0) { - d_divByZero = - nm->mkSkolem("divByZero", - nm->mkFunctionType(nm->realType(), nm->realType()), - "partial real division", - NodeManager::SKOLEM_EXACT_NAME); - logicRequest.widenLogic(THEORY_UF); + // partial function: division + if (d_divByZero.isNull()) + { + d_divByZero = + nm->mkSkolem("divByZero", + nm->mkFunctionType(nm->realType(), nm->realType()), + "partial real division", + NodeManager::SKOLEM_EXACT_NAME); + logicRequest.widenLogic(THEORY_UF); + } + Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); + Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num); + ret = nm->mkNode(kind::ITE, denEq0, divByZeroNum, ret); } - Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); - Node divByZeroNum = nm->mkNode(kind::APPLY_UF, d_divByZero, num); - ret = nm->mkNode(kind::ITE, denEq0, divByZeroNum, ret); + return ret; + break; } - return ret; - break; - } - case kind::INTS_DIVISION: { - // partial function: integer div - TNode num = node[0], den = node[1]; - Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den); - if (!den.isConst() || den.getConst<Rational>().sgn() == 0) + case kind::INTS_DIVISION: { - if (d_intDivByZero.isNull()) + // partial function: integer div + TNode num = node[0], den = node[1]; + Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, num, den); + if (!den.isConst() || den.getConst<Rational>().sgn() == 0) { - d_intDivByZero = nm->mkSkolem( - "intDivByZero", - nm->mkFunctionType(nm->integerType(), nm->integerType()), - "partial integer division", - NodeManager::SKOLEM_EXACT_NAME); - logicRequest.widenLogic(THEORY_UF); + if (d_intDivByZero.isNull()) + { + d_intDivByZero = nm->mkSkolem( + "intDivByZero", + nm->mkFunctionType(nm->integerType(), nm->integerType()), + "partial integer division", + NodeManager::SKOLEM_EXACT_NAME); + logicRequest.widenLogic(THEORY_UF); + } + Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); + Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num); + ret = nm->mkNode(kind::ITE, denEq0, intDivByZeroNum, ret); } - Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); - Node intDivByZeroNum = nm->mkNode(kind::APPLY_UF, d_intDivByZero, num); - ret = nm->mkNode(kind::ITE, denEq0, intDivByZeroNum, ret); + return ret; + break; } - return ret; - break; - } - case kind::INTS_MODULUS: { - // partial function: mod - TNode num = node[0], den = node[1]; - Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den); - if (!den.isConst() || den.getConst<Rational>().sgn() == 0) + case kind::INTS_MODULUS: { - if (d_modZero.isNull()) + // partial function: mod + TNode num = node[0], den = node[1]; + Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, num, den); + if (!den.isConst() || den.getConst<Rational>().sgn() == 0) { - d_modZero = nm->mkSkolem( - "modZero", - nm->mkFunctionType(nm->integerType(), nm->integerType()), - "partial modulus", - NodeManager::SKOLEM_EXACT_NAME); - logicRequest.widenLogic(THEORY_UF); + if (d_modZero.isNull()) + { + d_modZero = nm->mkSkolem( + "modZero", + nm->mkFunctionType(nm->integerType(), nm->integerType()), + "partial modulus", + NodeManager::SKOLEM_EXACT_NAME); + logicRequest.widenLogic(THEORY_UF); + } + Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); + Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num); + ret = nm->mkNode(kind::ITE, denEq0, modZeroNum, ret); } - Node denEq0 = nm->mkNode(kind::EQUAL, den, nm->mkConst(Rational(0))); - Node modZeroNum = nm->mkNode(kind::APPLY_UF, d_modZero, num); - ret = nm->mkNode(kind::ITE, denEq0, modZeroNum, ret); + return ret; + break; } - return ret; - break; - } - case kind::ABS: { - return nm->mkNode(kind::ITE, nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), nm->mkNode(kind::UMINUS, node[0]), node[0]); - break; - } + case kind::ABS: + { + return nm->mkNode(kind::ITE, + nm->mkNode(kind::LT, node[0], nm->mkConst(Rational(0))), + nm->mkNode(kind::UMINUS, node[0]), + node[0]); + break; + } + case kind::SQRT: + case kind::ARCSINE: + case kind::ARCCOSINE: + case kind::ARCTANGENT: + case kind::ARCCOSECANT: + case kind::ARCSECANT: + case kind::ARCCOTANGENT: + { + // eliminate inverse functions here + NodeMap::const_iterator it = d_nlin_inverse_skolem.find(node); + if (it == d_nlin_inverse_skolem.end()) + { + Node var = nm->mkSkolem("nonlinearInv", + nm->realType(), + "the result of a non-linear inverse function"); + d_nlin_inverse_skolem[node] = var; + Node lem; + if (k == kind::SQRT) + { + lem = nm->mkNode(kind::MULT, node[0], node[0]).eqNode(var); + } + else + { + Node pi = mkPi(); + + // range of the skolem + Node rlem; + if (k == kind::ARCSINE || k == ARCTANGENT || k == ARCCOSECANT) + { + Node half = nm->mkConst(Rational(1) / Rational(2)); + Node pi2 = nm->mkNode(kind::MULT, half, pi); + Node npi2 = nm->mkNode(kind::MULT, nm->mkConst(Rational(-1)), pi2); + // -pi/2 < var <= pi/2 + rlem = nm->mkNode( + AND, nm->mkNode(LT, npi2, var), nm->mkNode(LEQ, var, pi2)); + } + else + { + // 0 <= var < pi + rlem = nm->mkNode(AND, + nm->mkNode(LEQ, nm->mkConst(Rational(0)), var), + nm->mkNode(LT, var, pi)); + } + if (options::nlExt()) + { + d_nonlinearExtension->addDefinition(rlem); + } - default: - return node; - break; + Kind rk = k == kind::ARCSINE + ? kind::SINE + : (k == kind::ARCCOSINE + ? kind::COSINE + : (k == kind::ARCTANGENT + ? kind::TANGENT + : (k == kind::ARCCOSECANT + ? kind::COSECANT + : (k == kind::ARCSECANT + ? kind::SECANT + : kind::COTANGENT)))); + Node invTerm = nm->mkNode(rk, var); + // since invTerm may introduce division, + // we must also call expandDefinition on the result + invTerm = expandDefinition(logicRequest, invTerm); + lem = invTerm.eqNode(node[0]); + } + Assert(!lem.isNull()); + if (options::nlExt()) + { + d_nonlinearExtension->addDefinition(lem); + } + else + { + d_nlIncomplete = true; + } + return var; + } + return (*it).second; + break; + } + + default: return node; break; } Unreachable(); diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index 912bae5e6..23712016d 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -414,7 +414,7 @@ private: * precondition: The linear abstraction of the nodes must be satisfiable. */ DeltaRational getDeltaValue(TNode term) const - throw(DeltaRationalException, ModelException); + /* throw(DeltaRationalException, ModelException) */; Node axiomIteForTotalDivision(Node div_tot); Node axiomIteForTotalIntDivision(Node int_div_like); @@ -848,16 +848,17 @@ private: * semantics. Needed to deal with partial function "mod". */ Node d_modZero; - - /** - * Maps for Skolems for to-integer, real/integer div-by-k. - * Introduced during ppRewriteTerms. + + /** + * Maps for Skolems for to-integer, real/integer div-by-k, and inverse + * non-linear operators that are introduced during ppRewriteTerms. */ typedef context::CDHashMap< Node, Node, NodeHashFunction > NodeMap; NodeMap d_to_int_skolem; NodeMap d_div_skolem; NodeMap d_int_div_skolem; - + NodeMap d_nlin_inverse_skolem; + };/* class TheoryArithPrivate */ }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 24c286e92..caf466c0c 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -143,9 +143,9 @@ class TheoryArrays : public Theory { std::string name = ""); ~TheoryArrays(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - std::string identify() const { return std::string("TheoryArrays"); } + std::string identify() const override { return std::string("TheoryArrays"); } ///////////////////////////////////////////////////////////////////////////// // PREPROCESSING @@ -174,17 +174,15 @@ class TheoryArrays : public Theory { bool ppDisequal(TNode a, TNode b); Node solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck); - public: - - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); - Node ppRewrite(TNode atom); + public: + PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; + Node ppRewrite(TNode atom) override; ///////////////////////////////////////////////////////////////////////////// // T-PROPAGATION / REGISTRATION ///////////////////////////////////////////////////////////////////////////// - private: - + private: /** Literals to propagate */ context::CDList<Node> d_literalsToPropagate; @@ -204,19 +202,17 @@ class TheoryArrays : public Theory { /** Helper for preRegisterTerm, also used internally */ void preRegisterTermInternal(TNode n); - public: - - void preRegisterTerm(TNode n); - void propagate(Effort e); + public: + void preRegisterTerm(TNode n) override; + void propagate(Effort e) override; Node explain(TNode n, eq::EqProof* proof); - Node explain(TNode n); + Node explain(TNode n) override; ///////////////////////////////////////////////////////////////////////////// // SHARING ///////////////////////////////////////////////////////////////////////////// - private: - + private: class MayEqualNotifyClass { public: bool notify(TNode propagation) { return true; } @@ -232,46 +228,40 @@ class TheoryArrays : public Theory { // Helper for computeCareGraph void checkPair(TNode r1, TNode r2); - public: - - void addSharedTerm(TNode t); - EqualityStatus getEqualityStatus(TNode a, TNode b); - void computeCareGraph(); + public: + void addSharedTerm(TNode t) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; + void computeCareGraph() override; bool isShared(TNode t) - { return (d_sharedArrays.find(t) != d_sharedArrays.end()); } - + { + return (d_sharedArrays.find(t) != d_sharedArrays.end()); + } ///////////////////////////////////////////////////////////////////////////// // MODEL GENERATION ///////////////////////////////////////////////////////////////////////////// - private: - - public: - bool collectModelInfo(TheoryModel* m) override; - - ///////////////////////////////////////////////////////////////////////////// - // NOTIFICATIONS - ///////////////////////////////////////////////////////////////////////////// + public: + bool collectModelInfo(TheoryModel* m) override; - private: - public: + ///////////////////////////////////////////////////////////////////////////// + // NOTIFICATIONS + ///////////////////////////////////////////////////////////////////////////// - Node getNextDecisionRequest( unsigned& priority ); + public: + Node getNextDecisionRequest(unsigned& priority) override; - void presolve(); - void shutdown() { } + void presolve() override; + void shutdown() override {} ///////////////////////////////////////////////////////////////////////////// // MAIN SOLVER ///////////////////////////////////////////////////////////////////////////// - public: - - void check(Effort e); - - private: + public: + void check(Effort e) override; + private: TNode weakEquivGetRep(TNode node); TNode weakEquivGetRepIndex(TNode node, TNode index); void visitAllLeaves(TNode reason, std::vector<TNode>& conjunctions); @@ -454,11 +444,8 @@ class TheoryArrays : public Theory { /** An equality-engine callback for proof reconstruction */ ArrayProofReconstruction d_proofReconstruction; - public: - - eq::EqualityEngine* getEqualityEngine() { - return &d_equalityEngine; - } + public: + eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } };/* class TheoryArrays */ diff --git a/src/theory/booleans/type_enumerator.h b/src/theory/booleans/type_enumerator.h index 32c6bae42..ac0435442 100644 --- a/src/theory/booleans/type_enumerator.h +++ b/src/theory/booleans/type_enumerator.h @@ -38,7 +38,7 @@ class BooleanEnumerator : public TypeEnumeratorBase<BooleanEnumerator> { type.getConst<TypeConstant>() == BOOLEAN_TYPE); } - Node operator*() { + Node operator*() override { switch(d_value) { case FALSE: return NodeManager::currentNM()->mkConst(false); diff --git a/src/theory/bv/aig_bitblaster.cpp b/src/theory/bv/aig_bitblaster.cpp index 5459340f6..010eaf4e5 100644 --- a/src/theory/bv/aig_bitblaster.cpp +++ b/src/theory/bv/aig_bitblaster.cpp @@ -15,7 +15,10 @@ **/ #include "bitblaster_template.h" + #include "cvc4_private.h" + +#include "base/cvc4_check.h" #include "options/bv_options.h" #include "prop/cnf_stream.h" #include "prop/sat_solver_factory.h" @@ -155,8 +158,7 @@ AigBitblaster::AigBitblaster() d_satSolver = prop::SatSolverFactory::createCryptoMinisat(smtStatisticsRegistry(), "AigBitblaster"); break; - default: - Unreachable("Unknown SAT solver type"); + default: CVC4_FATAL() << "Unknown SAT solver type"; } } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 8cefe03b2..1992c0ae3 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -64,40 +64,43 @@ public: ~TheoryBV(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - Node expandDefinition(LogicRequest &logicRequest, Node node); + Node expandDefinition(LogicRequest& logicRequest, Node node) override; void mkAckermanizationAssertions(std::vector<Node>& assertions); - void preRegisterTerm(TNode n); + void preRegisterTerm(TNode n) override; - void check(Effort e); - - bool needsCheckLastEffort(); + void check(Effort e) override; + + bool needsCheckLastEffort() override; - void propagate(Effort e); + void propagate(Effort e) override; - Node explain(TNode n); + Node explain(TNode n) override; bool collectModelInfo(TheoryModel* m) override; - std::string identify() const { return std::string("TheoryBV"); } + std::string identify() const override { return std::string("TheoryBV"); } /** equality engine */ - eq::EqualityEngine * getEqualityEngine(); - bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ); - int getReduction( int effort, Node n, Node& nr ); - - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); + eq::EqualityEngine* getEqualityEngine() override; + bool getCurrentSubstitution(int effort, + std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; + int getReduction(int effort, Node n, Node& nr) override; + + PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; void enableCoreTheorySlicer(); - Node ppRewrite(TNode t); + Node ppRewrite(TNode t) override; - void ppStaticLearn(TNode in, NodeBuilder<>& learned); + void ppStaticLearn(TNode in, NodeBuilder<>& learned) override; - void presolve(); + void presolve() override; bool applyAbstraction(const std::vector<Node>& assertions, std::vector<Node>& new_assertions); @@ -206,13 +209,13 @@ private: */ void explain(TNode literal, std::vector<TNode>& assumptions); - void addSharedTerm(TNode t); + void addSharedTerm(TNode t) override; bool isSharedTerm(TNode t) { return d_sharedTermsSet.contains(t); } - EqualityStatus getEqualityStatus(TNode a, TNode b); + EqualityStatus getEqualityStatus(TNode a, TNode b) override; - Node getModelValue(TNode var); + Node getModelValue(TNode var) override; inline std::string indent() { diff --git a/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h b/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h index b53f7bb08..503fe5157 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h +++ b/src/theory/bv/theory_bv_rewrite_rules_constant_evaluation.h @@ -31,7 +31,7 @@ bool RewriteRule<EvalAnd>::applies(TNode node) { Unreachable(); return (node.getKind() == kind::BITVECTOR_AND && node.getNumChildren() == 2 && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -50,7 +50,7 @@ bool RewriteRule<EvalOr>::applies(TNode node) { Unreachable(); return (node.getKind() == kind::BITVECTOR_OR && node.getNumChildren() == 2 && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -69,7 +69,7 @@ bool RewriteRule<EvalXor>::applies(TNode node) { Unreachable(); return (node.getKind() == kind::BITVECTOR_XOR && node.getNumChildren() == 2 && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -86,7 +86,7 @@ Node RewriteRule<EvalXor>::apply(TNode node) { // template<> inline // bool RewriteRule<EvalXnor>::applies(TNode node) { // return (node.getKind() == kind::BITVECTOR_XNOR && -// utils::isBVGroundTerm(node)); +// utils::isBvConstTerm(node)); // } // template<> inline @@ -101,7 +101,7 @@ Node RewriteRule<EvalXor>::apply(TNode node) { template<> inline bool RewriteRule<EvalNot>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_NOT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -115,7 +115,7 @@ Node RewriteRule<EvalNot>::apply(TNode node) { // template<> inline // bool RewriteRule<EvalComp>::applies(TNode node) { // return (node.getKind() == kind::BITVECTOR_COMP && -// utils::isBVGroundTerm(node)); +// utils::isBvConstTerm(node)); // } // template<> inline @@ -136,7 +136,7 @@ Node RewriteRule<EvalNot>::apply(TNode node) { template<> inline bool RewriteRule<EvalMult>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_MULT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -153,7 +153,7 @@ Node RewriteRule<EvalMult>::apply(TNode node) { template<> inline bool RewriteRule<EvalPlus>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_PLUS && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -170,7 +170,7 @@ Node RewriteRule<EvalPlus>::apply(TNode node) { // template<> inline // bool RewriteRule<EvalSub>::applies(TNode node) { // return (node.getKind() == kind::BITVECTOR_SUB && -// utils::isBVGroundTerm(node)); +// utils::isBvConstTerm(node)); // } // template<> inline @@ -185,7 +185,7 @@ Node RewriteRule<EvalPlus>::apply(TNode node) { template<> inline bool RewriteRule<EvalNeg>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_NEG && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -198,7 +198,7 @@ Node RewriteRule<EvalNeg>::apply(TNode node) { } template<> inline bool RewriteRule<EvalUdiv>::applies(TNode node) { - return (utils::isBVGroundTerm(node) && + return (utils::isBvConstTerm(node) && (node.getKind() == kind::BITVECTOR_UDIV_TOTAL || (node.getKind() == kind::BITVECTOR_UDIV && node[1].isConst()))); } @@ -214,7 +214,7 @@ Node RewriteRule<EvalUdiv>::apply(TNode node) { } template<> inline bool RewriteRule<EvalUrem>::applies(TNode node) { - return (utils::isBVGroundTerm(node) && + return (utils::isBvConstTerm(node) && (node.getKind() == kind::BITVECTOR_UREM_TOTAL || (node.getKind() == kind::BITVECTOR_UREM && node[1].isConst()))); } @@ -231,7 +231,7 @@ Node RewriteRule<EvalUrem>::apply(TNode node) { template<> inline bool RewriteRule<EvalShl>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_SHL && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -247,7 +247,7 @@ Node RewriteRule<EvalShl>::apply(TNode node) { template<> inline bool RewriteRule<EvalLshr>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_LSHR && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -263,7 +263,7 @@ Node RewriteRule<EvalLshr>::apply(TNode node) { template<> inline bool RewriteRule<EvalAshr>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_ASHR && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -279,7 +279,7 @@ Node RewriteRule<EvalAshr>::apply(TNode node) { template<> inline bool RewriteRule<EvalUlt>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_ULT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -297,7 +297,7 @@ Node RewriteRule<EvalUlt>::apply(TNode node) { template<> inline bool RewriteRule<EvalUltBv>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_ULTBV && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -315,7 +315,7 @@ Node RewriteRule<EvalUltBv>::apply(TNode node) { template<> inline bool RewriteRule<EvalSlt>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_SLT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -334,7 +334,7 @@ Node RewriteRule<EvalSlt>::apply(TNode node) { template<> inline bool RewriteRule<EvalSltBv>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_SLTBV && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -354,7 +354,7 @@ template<> inline bool RewriteRule<EvalITEBv>::applies(TNode node) { Debug("bv-rewrite") << "RewriteRule<EvalITEBv>::applies(" << node << ")" << std::endl; return (node.getKind() == kind::BITVECTOR_ITE && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -373,7 +373,7 @@ Node RewriteRule<EvalITEBv>::apply(TNode node) { template<> inline bool RewriteRule<EvalUle>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_ULE && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -391,7 +391,7 @@ Node RewriteRule<EvalUle>::apply(TNode node) { template<> inline bool RewriteRule<EvalSle>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_SLE && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -409,7 +409,7 @@ Node RewriteRule<EvalSle>::apply(TNode node) { template<> inline bool RewriteRule<EvalExtract>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_EXTRACT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -427,7 +427,7 @@ Node RewriteRule<EvalExtract>::apply(TNode node) { template<> inline bool RewriteRule<EvalConcat>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_CONCAT && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -445,7 +445,7 @@ Node RewriteRule<EvalConcat>::apply(TNode node) { template<> inline bool RewriteRule<EvalSignExtend>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_SIGN_EXTEND && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -461,7 +461,7 @@ Node RewriteRule<EvalSignExtend>::apply(TNode node) { template<> inline bool RewriteRule<EvalEquals>::applies(TNode node) { return (node.getKind() == kind::EQUAL && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline @@ -479,7 +479,7 @@ Node RewriteRule<EvalEquals>::apply(TNode node) { template<> inline bool RewriteRule<EvalComp>::applies(TNode node) { return (node.getKind() == kind::BITVECTOR_COMP && - utils::isBVGroundTerm(node)); + utils::isBvConstTerm(node)); } template<> inline diff --git a/src/theory/bv/theory_bv_rewrite_rules_normalization.h b/src/theory/bv/theory_bv_rewrite_rules_normalization.h index d17f20152..ad5b37a2d 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_normalization.h +++ b/src/theory/bv/theory_bv_rewrite_rules_normalization.h @@ -394,19 +394,22 @@ Node RewriteRule<MultSimplify>::apply(TNode node) { std::vector<Node> children; for (const TNode& current : node) { - if (current.getKind() == kind::CONST_BITVECTOR) { - BitVector value = current.getConst<BitVector>(); + Node c = current; + if (c.getKind() == kind::BITVECTOR_NEG) + { + isNeg = !isNeg; + c = c[0]; + } + + if (c.getKind() == kind::CONST_BITVECTOR) + { + BitVector value = c.getConst<BitVector>(); constant = constant * value; if(constant == BitVector(size, (unsigned) 0)) { return utils::mkConst(size, 0); } - } - else if (current.getKind() == kind::BITVECTOR_NEG) - { - isNeg = !isNeg; - children.push_back(current[0]); } else { - children.push_back(current); + children.push_back(c); } } BitVector oValue = BitVector(size, static_cast<unsigned>(1)); @@ -414,8 +417,7 @@ Node RewriteRule<MultSimplify>::apply(TNode node) { if (children.empty()) { - Assert(!isNeg); - return utils::mkConst(constant); + return utils::mkConst(isNeg ? -constant : constant); } std::sort(children.begin(), children.end()); diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index 067440ee2..fb083c568 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -790,7 +790,15 @@ inline Node RewriteRule<MultPow2>::apply(TNode node) } } - Node a = utils::mkNode(kind::BITVECTOR_MULT, children); + Node a; + if (children.empty()) + { + a = utils::mkOne(size); + } + else + { + a = utils::mkNode(kind::BITVECTOR_MULT, children); + } if (isNeg && size > 1) { diff --git a/src/theory/bv/theory_bv_utils.cpp b/src/theory/bv/theory_bv_utils.cpp index 783d04492..9b66574f6 100644 --- a/src/theory/bv/theory_bv_utils.cpp +++ b/src/theory/bv/theory_bv_utils.cpp @@ -2,17 +2,16 @@ /*! \file theory_bv_utils.cpp ** \verbatim ** Top contributors (to current version): - ** Liana Hadarean, Tim King, Paul Meng + ** Aina Niemetz, Liana Hadarean, Tim King ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS ** in the top-level source directory) and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief [[ Add one-line brief description here ]] + ** \brief Util functions for theory BV. ** - ** [[ Add lengthier description here ]] - ** \todo document this file + ** Util functions for theory BV. **/ #include "theory/bv/theory_bv_utils.h" @@ -26,21 +25,422 @@ namespace theory { namespace bv { namespace utils { -Node mkSum(std::vector<Node>& children, unsigned width) +/* ------------------------------------------------------------------------- */ + +uint32_t pow2(uint32_t n) +{ + Assert(n < 32); + uint32_t one = 1; + return one << n; +} + +/* ------------------------------------------------------------------------- */ + +BitVector mkBitVectorOnes(unsigned size) +{ + Assert(size > 0); + return BitVector(1, Integer(1)).signExtend(size - 1); +} + +BitVector mkBitVectorMinSigned(unsigned size) +{ + Assert(size > 0); + return BitVector(size).setBit(size - 1); +} + +BitVector mkBitVectorMaxSigned(unsigned size) +{ + Assert(size > 0); + return ~mkBitVectorMinSigned(size); +} + +/* ------------------------------------------------------------------------- */ + +unsigned getSize(TNode node) +{ + return node.getType().getBitVectorSize(); +} + +const bool getBit(TNode node, unsigned i) +{ + Assert(i < utils::getSize(node) && node.getKind() == kind::CONST_BITVECTOR); + Integer bit = node.getConst<BitVector>().extract(i, i).getValue(); + return (bit == 1u); +} + +/* ------------------------------------------------------------------------- */ + +unsigned getExtractHigh(TNode node) +{ + return node.getOperator().getConst<BitVectorExtract>().high; +} + +unsigned getExtractLow(TNode node) +{ + return node.getOperator().getConst<BitVectorExtract>().low; +} + +unsigned getSignExtendAmount(TNode node) +{ + return node.getOperator().getConst<BitVectorSignExtend>().signExtendAmount; +} + +/* ------------------------------------------------------------------------- */ + +bool isZero(TNode node) +{ + if (!node.isConst()) return false; + return node == utils::mkConst(utils::getSize(node), 0u); +} + +unsigned isPow2Const(TNode node, bool& isNeg) +{ + if (node.getKind() != kind::CONST_BITVECTOR) + { + return false; + } + + BitVector bv = node.getConst<BitVector>(); + unsigned p = bv.isPow2(); + if (p != 0) + { + isNeg = false; + return p; + } + BitVector nbv = -bv; + p = nbv.isPow2(); + if (p != 0) + { + isNeg = true; + return p; + } + return false; +} + +bool isBvConstTerm(TNode node) +{ + if (node.getNumChildren() == 0) + { + return node.isConst(); + } + + for (size_t i = 0; i < node.getNumChildren(); ++i) + { + if (!node[i].isConst()) + { + return false; + } + } + return true; +} + +bool isBVPredicate(TNode node) +{ + if (node.getKind() == kind::EQUAL || node.getKind() == kind::BITVECTOR_ULT + || node.getKind() == kind::BITVECTOR_SLT + || node.getKind() == kind::BITVECTOR_UGT + || node.getKind() == kind::BITVECTOR_UGE + || node.getKind() == kind::BITVECTOR_SGT + || node.getKind() == kind::BITVECTOR_SGE + || node.getKind() == kind::BITVECTOR_ULE + || node.getKind() == kind::BITVECTOR_SLE + || node.getKind() == kind::BITVECTOR_REDOR + || node.getKind() == kind::BITVECTOR_REDAND + || (node.getKind() == kind::NOT + && (node[0].getKind() == kind::EQUAL + || node[0].getKind() == kind::BITVECTOR_ULT + || node[0].getKind() == kind::BITVECTOR_SLT + || node[0].getKind() == kind::BITVECTOR_UGT + || node[0].getKind() == kind::BITVECTOR_UGE + || node[0].getKind() == kind::BITVECTOR_SGT + || node[0].getKind() == kind::BITVECTOR_SGE + || node[0].getKind() == kind::BITVECTOR_ULE + || node[0].getKind() == kind::BITVECTOR_SLE + || node[0].getKind() == kind::BITVECTOR_REDOR + || node[0].getKind() == kind::BITVECTOR_REDAND))) + { + return true; + } + else + { + return false; + } +} + +bool isCoreTerm(TNode term, TNodeBoolMap& cache) +{ + term = term.getKind() == kind::NOT ? term[0] : term; + TNodeBoolMap::const_iterator it = cache.find(term); + if (it != cache.end()) + { + return it->second; + } + + if (term.getNumChildren() == 0) return true; + + if (theory::Theory::theoryOf(theory::THEORY_OF_TERM_BASED, term) == THEORY_BV) + { + Kind k = term.getKind(); + if (k != kind::CONST_BITVECTOR && k != kind::BITVECTOR_CONCAT + && k != kind::BITVECTOR_EXTRACT && k != kind::EQUAL + && term.getMetaKind() != kind::metakind::VARIABLE) + { + cache[term] = false; + return false; + } + } + + for (unsigned i = 0; i < term.getNumChildren(); ++i) + { + if (!isCoreTerm(term[i], cache)) + { + cache[term] = false; + return false; + } + } + + cache[term] = true; + return true; +} + +bool isEqualityTerm(TNode term, TNodeBoolMap& cache) +{ + term = term.getKind() == kind::NOT ? term[0] : term; + TNodeBoolMap::const_iterator it = cache.find(term); + if (it != cache.end()) + { + return it->second; + } + + if (term.getNumChildren() == 0) return true; + + if (theory::Theory::theoryOf(theory::THEORY_OF_TERM_BASED, term) == THEORY_BV) + { + Kind k = term.getKind(); + if (k != kind::CONST_BITVECTOR && k != kind::EQUAL + && term.getMetaKind() != kind::metakind::VARIABLE) + { + cache[term] = false; + return false; + } + } + + for (unsigned i = 0; i < term.getNumChildren(); ++i) + { + if (!isEqualityTerm(term[i], cache)) + { + cache[term] = false; + return false; + } + } + + cache[term] = true; + return true; +} + +bool isBitblastAtom(Node lit) +{ + TNode atom = lit.getKind() == kind::NOT ? lit[0] : lit; + return atom.getKind() != kind::EQUAL || atom[0].getType().isBitVector(); +} + +/* ------------------------------------------------------------------------- */ + +Node mkTrue() +{ + return NodeManager::currentNM()->mkConst<bool>(true); +} + +Node mkFalse() +{ + return NodeManager::currentNM()->mkConst<bool>(false); +} + +Node mkOnes(unsigned size) +{ + BitVector val = mkBitVectorOnes(size); + return NodeManager::currentNM()->mkConst<BitVector>(val); +} + +Node mkZero(unsigned size) +{ + return mkConst(size, 0u); +} + +Node mkOne(unsigned size) +{ + return mkConst(size, 1u); +} + +/* ------------------------------------------------------------------------- */ + +Node mkConst(unsigned size, unsigned int value) +{ + BitVector val(size, value); + return NodeManager::currentNM()->mkConst<BitVector>(val); +} + +Node mkConst(unsigned size, Integer& value) +{ + return NodeManager::currentNM()->mkConst<BitVector>(BitVector(size, value)); +} + +Node mkConst(const BitVector& value) +{ + return NodeManager::currentNM()->mkConst<BitVector>(value); +} + +/* ------------------------------------------------------------------------- */ + +Node mkVar(unsigned size) +{ + NodeManager* nm = NodeManager::currentNM(); + + return nm->mkSkolem("BVSKOLEM$$", + nm->mkBitVectorType(size), + "is a variable created by the theory of bitvectors"); +} + +/* ------------------------------------------------------------------------- */ + +Node mkNode(Kind kind, TNode child) +{ + return NodeManager::currentNM()->mkNode(kind, child); +} + +Node mkNode(Kind kind, TNode child1, TNode child2) { - std::size_t nchildren = children.size(); + return NodeManager::currentNM()->mkNode(kind, child1, child2); +} - if (nchildren == 0) +Node mkNode(Kind kind, TNode child1, TNode child2, TNode child3) +{ + return NodeManager::currentNM()->mkNode(kind, child1, child2, child3); +} + +Node mkNode(Kind kind, std::vector<Node>& children) +{ + Assert(children.size() > 0); + if (children.size() == 1) { - return mkZero(width); + return children[0]; } - else if (nchildren == 1) + return NodeManager::currentNM()->mkNode(kind, children); +} + +/* ------------------------------------------------------------------------- */ + +Node mkSortedNode(Kind kind, TNode child1, TNode child2) +{ + Assert(kind == kind::BITVECTOR_AND || kind == kind::BITVECTOR_OR + || kind == kind::BITVECTOR_XOR); + + if (child1 < child2) + { + return NodeManager::currentNM()->mkNode(kind, child1, child2); + } + else + { + return NodeManager::currentNM()->mkNode(kind, child2, child1); + } +} + +Node mkSortedNode(Kind kind, std::vector<Node>& children) +{ + Assert(kind == kind::BITVECTOR_AND || kind == kind::BITVECTOR_OR + || kind == kind::BITVECTOR_XOR); + Assert(children.size() > 0); + if (children.size() == 1) { return children[0]; } - return NodeManager::currentNM()->mkNode(kind::BITVECTOR_PLUS, children); + std::sort(children.begin(), children.end()); + return NodeManager::currentNM()->mkNode(kind, children); +} + +/* ------------------------------------------------------------------------- */ + +Node mkNot(Node child) +{ + return NodeManager::currentNM()->mkNode(kind::NOT, child); +} + +Node mkAnd(TNode node1, TNode node2) +{ + return NodeManager::currentNM()->mkNode(kind::AND, node1, node2); +} + +Node mkOr(TNode node1, TNode node2) +{ + return NodeManager::currentNM()->mkNode(kind::OR, node1, node2); } +Node mkXor(TNode node1, TNode node2) +{ + return NodeManager::currentNM()->mkNode(kind::XOR, node1, node2); +} + +/* ------------------------------------------------------------------------- */ + +Node mkSignExtend(TNode node, unsigned amount) +{ + NodeManager* nm = NodeManager::currentNM(); + Node signExtendOp = + nm->mkConst<BitVectorSignExtend>(BitVectorSignExtend(amount)); + return nm->mkNode(signExtendOp, node); +} + +/* ------------------------------------------------------------------------- */ + +Node mkExtract(TNode node, unsigned high, unsigned low) +{ + Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>( + BitVectorExtract(high, low)); + std::vector<Node> children; + children.push_back(node); + return NodeManager::currentNM()->mkNode(extractOp, children); +} + +Node mkBitOf(TNode node, unsigned index) +{ + Node bitOfOp = + NodeManager::currentNM()->mkConst<BitVectorBitOf>(BitVectorBitOf(index)); + return NodeManager::currentNM()->mkNode(bitOfOp, node); +} + +/* ------------------------------------------------------------------------- */ + +Node mkConcat(TNode t1, TNode t2) +{ + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2); +} + +Node mkConcat(std::vector<Node>& children) +{ + if (children.size() > 1) + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, children); + else + return children[0]; +} + +Node mkConcat(TNode node, unsigned repeat) +{ + Assert(repeat); + if (repeat == 1) + { + return node; + } + NodeBuilder<> result(kind::BITVECTOR_CONCAT); + for (unsigned i = 0; i < repeat; ++i) + { + result << node; + } + Node resultNode = result; + return resultNode; +} + +/* ------------------------------------------------------------------------- */ + Node mkInc(TNode t) { return NodeManager::currentNM()->mkNode( @@ -53,6 +453,8 @@ Node mkDec(TNode t) kind::BITVECTOR_SUB, t, bv::utils::mkOne(bv::utils::getSize(t))); } +/* ------------------------------------------------------------------------- */ + Node mkUmulo(TNode t1, TNode t2) { unsigned w = getSize(t1); @@ -76,96 +478,233 @@ Node mkUmulo(TNode t1, TNode t2) return nm->mkNode(kind::EQUAL, nm->mkNode(kind::BITVECTOR_OR, tmp), mkOne(1)); } -bool isCoreTerm(TNode term, TNodeBoolMap& cache) { - term = term.getKind() == kind::NOT ? term[0] : term; - TNodeBoolMap::const_iterator it = cache.find(term); - if (it != cache.end()) { - return it->second; +/* ------------------------------------------------------------------------- */ + +Node mkConjunction(const std::set<TNode> nodes) +{ + std::set<TNode> expandedNodes; + + std::set<TNode>::const_iterator it = nodes.begin(); + std::set<TNode>::const_iterator it_end = nodes.end(); + while (it != it_end) + { + TNode current = *it; + if (current != mkTrue()) + { + Assert(current.getKind() == kind::EQUAL + || (current.getKind() == kind::NOT + && current[0].getKind() == kind::EQUAL)); + expandedNodes.insert(current); + } + ++it; } - if (term.getNumChildren() == 0) - return true; + Assert(expandedNodes.size() > 0); + if (expandedNodes.size() == 1) + { + return *expandedNodes.begin(); + } - if (theory::Theory::theoryOf(theory::THEORY_OF_TERM_BASED, term) == THEORY_BV) { - Kind k = term.getKind(); - if (k != kind::CONST_BITVECTOR && - k != kind::BITVECTOR_CONCAT && - k != kind::BITVECTOR_EXTRACT && - k != kind::EQUAL && - term.getMetaKind() != kind::metakind::VARIABLE) { - cache[term] = false; - return false; + NodeBuilder<> conjunction(kind::AND); + + it = expandedNodes.begin(); + it_end = expandedNodes.end(); + while (it != it_end) + { + conjunction << *it; + ++it; + } + + return conjunction; +} + +Node mkConjunction(const std::vector<TNode>& nodes) +{ + std::vector<TNode> expandedNodes; + + std::vector<TNode>::const_iterator it = nodes.begin(); + std::vector<TNode>::const_iterator it_end = nodes.end(); + while (it != it_end) + { + TNode current = *it; + + if (current != mkTrue()) + { + Assert(isBVPredicate(current)); + expandedNodes.push_back(current); } + ++it; } - for (unsigned i = 0; i < term.getNumChildren(); ++i) { - if (!isCoreTerm(term[i], cache)) { - cache[term] = false; - return false; + if (expandedNodes.size() == 0) + { + return mkTrue(); + } + + if (expandedNodes.size() == 1) + { + return *expandedNodes.begin(); + } + + NodeBuilder<> conjunction(kind::AND); + + it = expandedNodes.begin(); + it_end = expandedNodes.end(); + while (it != it_end) + { + conjunction << *it; + ++it; + } + + return conjunction; +} + +/* ------------------------------------------------------------------------- */ + +void getConjuncts(TNode node, std::set<TNode>& conjuncts) +{ + if (node.getKind() != kind::AND) + { + conjuncts.insert(node); + } + else + { + for (unsigned i = 0; i < node.getNumChildren(); ++i) + { + getConjuncts(node[i], conjuncts); } } +} - cache[term]= true; - return true; +void getConjuncts(std::vector<TNode>& nodes, std::set<TNode>& conjuncts) +{ + for (unsigned i = 0, i_end = nodes.size(); i < i_end; ++i) + { + getConjuncts(nodes[i], conjuncts); + } } -bool isEqualityTerm(TNode term, TNodeBoolMap& cache) { - term = term.getKind() == kind::NOT ? term[0] : term; - TNodeBoolMap::const_iterator it = cache.find(term); - if (it != cache.end()) { - return it->second; +Node flattenAnd(std::vector<TNode>& queue) +{ + TNodeSet nodes; + while (!queue.empty()) + { + TNode current = queue.back(); + queue.pop_back(); + if (current.getKind() == kind::AND) + { + for (unsigned i = 0; i < current.getNumChildren(); ++i) + { + if (nodes.count(current[i]) == 0) + { + queue.push_back(current[i]); + } + } + } + else + { + nodes.insert(current); + } + } + std::vector<TNode> children; + for (TNodeSet::const_iterator it = nodes.begin(); it != nodes.end(); ++it) + { + children.push_back(*it); } + return mkAnd(children); +} - if (term.getNumChildren() == 0) - return true; +/* ------------------------------------------------------------------------- */ - if (theory::Theory::theoryOf(theory::THEORY_OF_TERM_BASED, term) == THEORY_BV) { - Kind k = term.getKind(); - if (k != kind::CONST_BITVECTOR && - k != kind::EQUAL && - term.getMetaKind() != kind::metakind::VARIABLE) { - cache[term] = false; - return false; +std::string setToString(const std::set<TNode>& nodeSet) { + std::stringstream out; + out << "["; + std::set<TNode>::const_iterator it = nodeSet.begin(); + std::set<TNode>::const_iterator it_end = nodeSet.end(); + bool first = true; + while (it != it_end) { + if (!first) { + out << ","; } + first = false; + out << *it; + ++ it; } + out << "]"; + return out.str(); +} - for (unsigned i = 0; i < term.getNumChildren(); ++i) { - if (!isEqualityTerm(term[i], cache)) { - cache[term] = false; - return false; +std::string vectorToString(const std::vector<Node>& nodes) +{ + std::stringstream out; + out << "["; + for (unsigned i = 0; i < nodes.size(); ++i) + { + if (i > 0) + { + out << ","; } + out << nodes[i]; } + out << "]"; + return out.str(); +} - cache[term]= true; - return true; +/* ------------------------------------------------------------------------- */ + +// FIXME: dumb code +void intersect(const std::vector<uint32_t>& v1, + const std::vector<uint32_t>& v2, + std::vector<uint32_t>& intersection) { + for (unsigned i = 0; i < v1.size(); ++i) { + bool found = false; + for (unsigned j = 0; j < v2.size(); ++j) { + if (v2[j] == v1[i]) { + found = true; + break; + } + } + if (found) { + intersection.push_back(v1[i]); + } + } } +/* ------------------------------------------------------------------------- */ -uint64_t numNodes(TNode node, NodeSet& seen) { - if (seen.find(node) != seen.end()) - return 0; +uint64_t numNodes(TNode node, NodeSet& seen) +{ + if (seen.find(node) != seen.end()) return 0; uint64_t size = 1; - for (unsigned i = 0; i < node.getNumChildren(); ++i) { + for (unsigned i = 0; i < node.getNumChildren(); ++i) + { size += numNodes(node[i], seen); } seen.insert(node); return size; } -void collectVariables(TNode node, NodeSet& vars) { - if (vars.find(node) != vars.end()) - return; +/* ------------------------------------------------------------------------- */ - if (Theory::isLeafOf(node, THEORY_BV) && node.getKind() != kind::CONST_BITVECTOR) { +void collectVariables(TNode node, NodeSet& vars) +{ + if (vars.find(node) != vars.end()) return; + + if (Theory::isLeafOf(node, THEORY_BV) + && node.getKind() != kind::CONST_BITVECTOR) + { vars.insert(node); return; } - for (unsigned i = 0; i < node.getNumChildren(); ++i) { + for (unsigned i = 0; i < node.getNumChildren(); ++i) + { collectVariables(node[i], vars); } } +/* ------------------------------------------------------------------------- */ + }/* CVC4::theory::bv::utils namespace */ }/* CVC4::theory::bv namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index e304e4801..f6784621f 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -2,17 +2,16 @@ /*! \file theory_bv_utils.h ** \verbatim ** Top contributors (to current version): - ** Liana Hadarean, Dejan Jovanovic, Morgan Deters + ** Aina Niemetz, Dejan Jovanovic, Morgan Deters ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS ** in the top-level source directory) and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim ** - ** \brief [[ Add one-line brief description here ]] + ** \brief Util functions for theory BV. ** - ** [[ Add lengthier description here ]] - ** \todo document this file + ** Util functions for theory BV. **/ #include "cvc4_private.h" @@ -36,205 +35,153 @@ typedef std::unordered_set<TNode, TNodeHashFunction> TNodeSet; namespace utils { -inline uint32_t pow2(uint32_t power) { - Assert (power < 32); - uint32_t one = 1; - return one << power; -} - -inline unsigned getExtractHigh(TNode node) { - return node.getOperator().getConst<BitVectorExtract>().high; -} - -inline unsigned getExtractLow(TNode node) { - return node.getOperator().getConst<BitVectorExtract>().low; -} +typedef std::unordered_map<TNode, bool, TNodeHashFunction> TNodeBoolMap; +typedef std::unordered_set<Node, NodeHashFunction> NodeSet; -inline unsigned getSize(TNode node) { - return node.getType().getBitVectorSize(); -} +/* Compute 2^n. */ +uint32_t pow2(uint32_t n); -inline unsigned getSignExtendAmount(TNode node) +/* Compute the greatest common divisor for two objects of Type T. */ +template <class T> +T gcd(T a, T b) { - return node.getOperator().getConst<BitVectorSignExtend>().signExtendAmount; -} - -inline const bool getBit(TNode node, unsigned i) { - Assert (i < utils::getSize(node) && - node.getKind() == kind::CONST_BITVECTOR); - Integer bit = node.getConst<BitVector>().extract(i, i).getValue(); - return (bit == 1u); -} - -inline Node mkTrue() { - return NodeManager::currentNM()->mkConst<bool>(true); -} - -inline Node mkFalse() { - return NodeManager::currentNM()->mkConst<bool>(false); -} - -inline Node mkVar(unsigned size) { - NodeManager* nm = NodeManager::currentNM(); - - return nm->mkSkolem("BVSKOLEM$$", nm->mkBitVectorType(size), "is a variable created by the theory of bitvectors"); -} - - -inline Node mkSortedNode(Kind kind, std::vector<Node>& children) { - Assert (kind == kind::BITVECTOR_AND || - kind == kind::BITVECTOR_OR || - kind == kind::BITVECTOR_XOR); - Assert(children.size() > 0); - if (children.size() == 1) { - return children[0]; - } - std::sort(children.begin(), children.end()); - return NodeManager::currentNM()->mkNode(kind, children); -} - - -inline Node mkNode(Kind kind, std::vector<Node>& children) { - Assert (children.size() > 0); - if (children.size() == 1) { - return children[0]; - } - return NodeManager::currentNM()->mkNode(kind, children); -} - -inline Node mkNode(Kind kind, TNode child) { - return NodeManager::currentNM()->mkNode(kind, child); -} - -inline Node mkNode(Kind kind, TNode child1, TNode child2) { - return NodeManager::currentNM()->mkNode(kind, child1, child2); -} - - -inline Node mkSortedNode(Kind kind, TNode child1, TNode child2) { - Assert (kind == kind::BITVECTOR_AND || - kind == kind::BITVECTOR_OR || - kind == kind::BITVECTOR_XOR); - - if (child1 < child2) { - return NodeManager::currentNM()->mkNode(kind, child1, child2); - } else { - return NodeManager::currentNM()->mkNode(kind, child2, child1); + while (b != 0) + { + T t = b; + b = a % t; + a = t; } + return a; } -inline Node mkNode(Kind kind, TNode child1, TNode child2, TNode child3) { - return NodeManager::currentNM()->mkNode(kind, child1, child2, child3); -} - - -inline Node mkNot(Node child) { - return NodeManager::currentNM()->mkNode(kind::NOT, child); -} - -inline Node mkAnd(TNode node1, TNode node2) { - return NodeManager::currentNM()->mkNode(kind::AND, node1, node2); -} - -inline Node mkOr(TNode node1, TNode node2) { - return NodeManager::currentNM()->mkNode(kind::OR, node1, node2); -} - -inline Node mkXor(TNode node1, TNode node2) { - return NodeManager::currentNM()->mkNode(kind::XOR, node1, node2); -} - -inline Node mkSignExtend(TNode node, unsigned amount) -{ - NodeManager* nm = NodeManager::currentNM(); - Node signExtendOp = - nm->mkConst<BitVectorSignExtend>(BitVectorSignExtend(amount)); - return nm->mkNode(signExtendOp, node); -} - -inline Node mkExtract(TNode node, unsigned high, unsigned low) { - Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low)); - std::vector<Node> children; - children.push_back(node); - return NodeManager::currentNM()->mkNode(extractOp, children); -} +/* Create bit-vector of ones of given size. */ +BitVector mkBitVectorOnes(unsigned size); +/* Create bit-vector representing the minimum signed value of given size. */ +BitVector mkBitVectorMinSigned(unsigned size); +/* Create bit-vector representing the maximum signed value of given size. */ +BitVector mkBitVectorMaxSigned(unsigned size); -inline Node mkBitOf(TNode node, unsigned index) { - Node bitOfOp = NodeManager::currentNM()->mkConst<BitVectorBitOf>(BitVectorBitOf(index)); - return NodeManager::currentNM()->mkNode(bitOfOp, node); -} +/* Get the bit-width of given node. */ +unsigned getSize(TNode node); -Node mkSum(std::vector<Node>& children, unsigned width); +/* Get bit at given index. */ +const bool getBit(TNode node, unsigned i); -inline Node mkConcat(TNode node, unsigned repeat) { - Assert (repeat); - if(repeat == 1) { - return node; - } - NodeBuilder<> result(kind::BITVECTOR_CONCAT); - for (unsigned i = 0; i < repeat; ++i) { - result << node; - } - Node resultNode = result; - return resultNode; -} +/* Get the upper index of given extract node. */ +unsigned getExtractHigh(TNode node); +/* Get the lower index of given extract node. */ +unsigned getExtractLow(TNode node); -inline Node mkConcat(std::vector<Node>& children) { - if (children.size() > 1) - return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, children); - else - return children[0]; -} +/* Get the number of bits by which a given node is extended. */ +unsigned getSignExtendAmount(TNode node); -inline Node mkConcat(TNode t1, TNode t2) { - return NodeManager::currentNM()->mkNode(kind::BITVECTOR_CONCAT, t1, t2); -} - - -inline BitVector mkBitVectorOnes(unsigned size) { - Assert(size > 0); - return BitVector(1, Integer(1)).signExtend(size - 1); -} - -inline BitVector mkBitVectorMinSigned(unsigned size) +/* Returns true if given node represents a zero bit-vector. */ +bool isZero(TNode node); +/* If node is a constant of the form 2^c or -2^c, then this function returns + * c+1. Otherwise, this function returns 0. The flag isNeg is updated to + * indicate whether node is negative. */ +unsigned isPow2Const(TNode node, bool& isNeg); +/* Returns true if node or all of its children is const. */ +bool isBvConstTerm(TNode node); +/* Returns true if node is a predicate over bit-vector nodes. */ +bool isBVPredicate(TNode node); +/* Returns true if given term is a THEORY_BV term. */ +bool isCoreTerm(TNode term, TNodeBoolMap& cache); +/* Returns true if given term is a bv constant, variable or equality term. */ +bool isEqualityTerm(TNode term, TNodeBoolMap& cache); +/* Returns true if given node is an atom that is bit-blasted. */ +bool isBitblastAtom(Node lit); + +/* Create Boolean node representing true. */ +Node mkTrue(); +/* Create Boolean node representing false. */ +Node mkFalse(); +/* Create bit-vector node representing a bit-vector of ones of given size. */ +Node mkOnes(unsigned size); +/* Create bit-vector node representing a zero bit-vector of given size. */ +Node mkZero(unsigned size); +/* Create bit-vector node representing a bit-vector value one of given size. */ +Node mkOne(unsigned size); + +/* Create bit-vector constant of given size and value. */ +Node mkConst(unsigned size, unsigned int value); +Node mkConst(unsigned size, Integer& value); +/* Create bit-vector constant from given bit-vector. */ +Node mkConst(const BitVector& value); + +/* Create bit-vector variable. */ +Node mkVar(unsigned size); + +/* Create n-ary node of given kind. */ +Node mkNode(Kind kind, TNode child); +Node mkNode(Kind kind, TNode child1, TNode child2); +Node mkNode(Kind kind, TNode child1, TNode child2, TNode child3); +Node mkNode(Kind kind, std::vector<Node>& children); + +/* Create n-ary bit-vector node of kind BITVECTOR_AND, BITVECTOR_OR or + * BITVECTOR_XOR where its children are sorted */ +Node mkSortedNode(Kind kind, TNode child1, TNode child2); +Node mkSortedNode(Kind kind, std::vector<Node>& children); + +/* Create node of kind NOT. */ +Node mkNot(Node child); +/* Create node of kind AND. */ +Node mkAnd(TNode node1, TNode node2); +/* Create n-ary node of kind AND. */ +template<bool ref_count> +Node mkAnd(const std::vector<NodeTemplate<ref_count>>& conjunctions) { - Assert(size > 0); - return BitVector(size).setBit(size - 1); -} + std::set<TNode> all(conjunctions.begin(), conjunctions.end()); -inline BitVector mkBitVectorMaxSigned(unsigned size) -{ - Assert(size > 0); - return ~mkBitVectorMinSigned(size); -} + if (all.size() == 0) { return mkTrue(); } -inline Node mkOnes(unsigned size) { - BitVector val = mkBitVectorOnes(size); - return NodeManager::currentNM()->mkConst<BitVector>(val); -} + /* All the same, or just one */ + if (all.size() == 1) { return conjunctions[0]; } -inline Node mkConst(unsigned size, unsigned int value) { - BitVector val(size, value); - return NodeManager::currentNM()->mkConst<BitVector>(val); + NodeBuilder<> conjunction(kind::AND); + for (const Node& n : all) { conjunction << n; } + return conjunction; } - -inline Node mkConst(unsigned size, Integer& value) +/* Create node of kind OR. */ +Node mkOr(TNode node1, TNode node2); +/* Create n-ary node of kind OR. */ +template<bool ref_count> +Node mkOr(const std::vector<NodeTemplate<ref_count>>& nodes) { - return NodeManager::currentNM()->mkConst<BitVector>(BitVector(size, value)); -} + std::set<TNode> all(nodes.begin(), nodes.end()); -inline Node mkConst(const BitVector& value) { - return NodeManager::currentNM()->mkConst<BitVector>(value); -} + if (all.size() == 0) { return mkTrue(); } -inline Node mkZero(unsigned size) { return mkConst(size, 0u); } + /* All the same, or just one */ + if (all.size() == 1) { return nodes[0]; } -inline Node mkOne(unsigned size) { return mkConst(size, 1u); } - -/* Increment */ + NodeBuilder<> disjunction(kind::OR); + for (const Node& n : all) { disjunction << n; } + return disjunction; +} +/* Create node of kind XOR. */ +Node mkXor(TNode node1, TNode node2); + +/* Create signed extension node where given node is extended by given amount. */ +Node mkSignExtend(TNode node, unsigned amount); + +/* Create extract node where bits from index high to index low are extracted + * from given node. */ +Node mkExtract(TNode node, unsigned high, unsigned low); +/* Create extract node of bit-width 1 where the resulting node represents + * the bit at given index. */ +Node mkBitOf(TNode node, unsigned index); + +/* Create n-ary concat node of given children. */ +Node mkConcat(TNode t1, TNode t2); +Node mkConcat(std::vector<Node>& children); +/* Create concat by repeating given node n times. + * Returns given node if n = 1. */ +Node mkConcat(TNode node, unsigned repeat); + +/* Create bit-vector addition node representing the increment of given node. */ Node mkInc(TNode t); - -/* Decrement */ +/* Create bit-vector addition node representing the decrement of given node. */ Node mkDec(TNode t); /* Unsigned multiplication overflow detection. @@ -243,345 +190,33 @@ Node mkDec(TNode t); * http://ieeexplore.ieee.org/document/987767 */ Node mkUmulo(TNode t1, TNode t2); -inline void getConjuncts(TNode node, std::set<TNode>& conjuncts) { - if (node.getKind() != kind::AND) { - conjuncts.insert(node); - } else { - for (unsigned i = 0; i < node.getNumChildren(); ++ i) { - getConjuncts(node[i], conjuncts); - } - } -} - -inline void getConjuncts(std::vector<TNode>& nodes, std::set<TNode>& conjuncts) { - for (unsigned i = 0, i_end = nodes.size(); i < i_end; ++ i) { - getConjuncts(nodes[i], conjuncts); - } -} +/* Create conjunction over a set of (dis)equalities. */ +Node mkConjunction(const std::set<TNode> nodes); +Node mkConjunction(const std::vector<TNode>& nodes); -inline Node mkConjunction(const std::set<TNode> nodes) { - std::set<TNode> expandedNodes; - - std::set<TNode>::const_iterator it = nodes.begin(); - std::set<TNode>::const_iterator it_end = nodes.end(); - while (it != it_end) { - TNode current = *it; - if (current != mkTrue()) { - Assert(current.getKind() == kind::EQUAL || (current.getKind() == kind::NOT && current[0].getKind() == kind::EQUAL)); - expandedNodes.insert(current); - } - ++ it; - } +/* Get a set of all operands of nested and nodes. */ +void getConjuncts(TNode node, std::set<TNode>& conjuncts); +void getConjuncts(std::vector<TNode>& nodes, std::set<TNode>& conjuncts); +/* Create a flattened and node. */ +Node flattenAnd(std::vector<TNode>& queue); - Assert(expandedNodes.size() > 0); - if (expandedNodes.size() == 1) { - return *expandedNodes.begin(); - } +/* Create a string representing a set of nodes. */ +std::string setToString(const std::set<TNode>& nodeSet); - NodeBuilder<> conjunction(kind::AND); +/* Create a string representing a vector of nodes. */ +std::string vectorToString(const std::vector<Node>& nodes); - it = expandedNodes.begin(); - it_end = expandedNodes.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -} - -/** - * If node is a constant of the form 2^c or -2^c, then this function returns - * c+1. Otherwise, this function returns 0. The flag isNeg is updated to - * indicate whether node is negative. - */ -inline unsigned isPow2Const(TNode node, bool& isNeg) -{ - if (node.getKind() != kind::CONST_BITVECTOR) { - return false; - } - - BitVector bv = node.getConst<BitVector>(); - unsigned p = bv.isPow2(); - if (p != 0) - { - isNeg = false; - return p; - } - BitVector nbv = -bv; - p = nbv.isPow2(); - if (p != 0) - { - isNeg = true; - return p; - } - return false; -} - -inline Node mkOr(const std::vector<Node>& nodes) { - std::set<TNode> all; - all.insert(nodes.begin(), nodes.end()); - - if (all.size() == 0) { - return mkTrue(); - } - - if (all.size() == 1) { - // All the same, or just one - return nodes[0]; - } - - - NodeBuilder<> disjunction(kind::OR); - std::set<TNode>::const_iterator it = all.begin(); - std::set<TNode>::const_iterator it_end = all.end(); - while (it != it_end) { - disjunction << *it; - ++ it; - } - - return disjunction; -}/* mkOr() */ - - -inline Node mkAnd(const std::vector<TNode>& conjunctions) { - std::set<TNode> all; - all.insert(conjunctions.begin(), conjunctions.end()); - - if (all.size() == 0) { - return mkTrue(); - } - - if (all.size() == 1) { - // All the same, or just one - return conjunctions[0]; - } - - - NodeBuilder<> conjunction(kind::AND); - std::set<TNode>::const_iterator it = all.begin(); - std::set<TNode>::const_iterator it_end = all.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -}/* mkAnd() */ - -inline Node mkAnd(const std::vector<Node>& conjunctions) { - std::set<TNode> all; - all.insert(conjunctions.begin(), conjunctions.end()); - - if (all.size() == 0) { - return mkTrue(); - } - - if (all.size() == 1) { - // All the same, or just one - return conjunctions[0]; - } - - - NodeBuilder<> conjunction(kind::AND); - std::set<TNode>::const_iterator it = all.begin(); - std::set<TNode>::const_iterator it_end = all.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -}/* mkAnd() */ - -inline bool isZero(TNode node) { - if (!node.isConst()) return false; - return node == utils::mkConst(utils::getSize(node), 0u); -} - -inline Node flattenAnd(std::vector<TNode>& queue) { - TNodeSet nodes; - while(!queue.empty()) { - TNode current = queue.back(); - queue.pop_back(); - if (current.getKind() == kind::AND) { - for (unsigned i = 0; i < current.getNumChildren(); ++i) { - if (nodes.count(current[i]) == 0) { - queue.push_back(current[i]); - } - } - } else { - nodes.insert(current); - } - } - std::vector<TNode> children; - for (TNodeSet::const_iterator it = nodes.begin(); it!= nodes.end(); ++it) { - children.push_back(*it); - } - return mkAnd(children); -} - - -// need a better name, this is not technically a ground term -inline bool isBVGroundTerm(TNode node) { - if (node.getNumChildren() == 0) { - return node.isConst(); - } - - for (size_t i = 0; i < node.getNumChildren(); ++i) { - if(! node[i].isConst()) { - return false; - } - } - return true; -} - -inline bool isBVPredicate(TNode node) { - if (node.getKind() == kind::EQUAL || - node.getKind() == kind::BITVECTOR_ULT || - node.getKind() == kind::BITVECTOR_SLT || - node.getKind() == kind::BITVECTOR_UGT || - node.getKind() == kind::BITVECTOR_UGE || - node.getKind() == kind::BITVECTOR_SGT || - node.getKind() == kind::BITVECTOR_SGE || - node.getKind() == kind::BITVECTOR_ULE || - node.getKind() == kind::BITVECTOR_SLE || - node.getKind() == kind::BITVECTOR_REDOR || - node.getKind() == kind::BITVECTOR_REDAND || - ( node.getKind() == kind::NOT && (node[0].getKind() == kind::EQUAL || - node[0].getKind() == kind::BITVECTOR_ULT || - node[0].getKind() == kind::BITVECTOR_SLT || - node[0].getKind() == kind::BITVECTOR_UGT || - node[0].getKind() == kind::BITVECTOR_UGE || - node[0].getKind() == kind::BITVECTOR_SGT || - node[0].getKind() == kind::BITVECTOR_SGE || - node[0].getKind() == kind::BITVECTOR_ULE || - node[0].getKind() == kind::BITVECTOR_SLE || - node[0].getKind() == kind::BITVECTOR_REDOR || - node[0].getKind() == kind::BITVECTOR_REDAND))) - { - return true; - } - else - { - return false; - } -} - -inline Node mkConjunction(const std::vector<TNode>& nodes) { - std::vector<TNode> expandedNodes; - - std::vector<TNode>::const_iterator it = nodes.begin(); - std::vector<TNode>::const_iterator it_end = nodes.end(); - while (it != it_end) { - TNode current = *it; - - if (current != mkTrue()) { - Assert(isBVPredicate(current)); - expandedNodes.push_back(current); - } - ++ it; - } - - if (expandedNodes.size() == 0) { - return mkTrue(); - } - - if (expandedNodes.size() == 1) { - return *expandedNodes.begin(); - } - - NodeBuilder<> conjunction(kind::AND); - - it = expandedNodes.begin(); - it_end = expandedNodes.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -} - - - -// Turn a set into a string -inline std::string setToString(const std::set<TNode>& nodeSet) { - std::stringstream out; - out << "["; - std::set<TNode>::const_iterator it = nodeSet.begin(); - std::set<TNode>::const_iterator it_end = nodeSet.end(); - bool first = true; - while (it != it_end) { - if (!first) { - out << ","; - } - first = false; - out << *it; - ++ it; - } - out << "]"; - return out.str(); -} - -// Turn a vector into a string -inline std::string vectorToString(const std::vector<Node>& nodes) { - std::stringstream out; - out << "["; - for (unsigned i = 0; i < nodes.size(); ++ i) { - if (i > 0) { - out << ","; - } - out << nodes[i]; - } - out << "]"; - return out.str(); -} - -// FIXME: dumb code -inline void intersect(const std::vector<uint32_t>& v1, - const std::vector<uint32_t>& v2, - std::vector<uint32_t>& intersection) { - for (unsigned i = 0; i < v1.size(); ++i) { - bool found = false; - for (unsigned j = 0; j < v2.size(); ++j) { - if (v2[j] == v1[i]) { - found = true; - break; - } - } - if (found) { - intersection.push_back(v1[i]); - } - } -} - -template <class T> -inline T gcd(T a, T b) { - while (b != 0) { - T t = b; - b = a % t; - a = t; - } - return a; -} - -typedef std::unordered_map<TNode, bool, TNodeHashFunction> TNodeBoolMap; - -bool isCoreTerm(TNode term, TNodeBoolMap& cache); -bool isEqualityTerm(TNode term, TNodeBoolMap& cache); -typedef std::unordered_set<Node, NodeHashFunction> NodeSet; +/* Create the intersection of two vectors of uint32_t. */ +void intersect(const std::vector<uint32_t>& v1, + const std::vector<uint32_t>& v2, + std::vector<uint32_t>& intersection); +/* Determine the total number of nodes that a given node consists of. */ uint64_t numNodes(TNode node, NodeSet& seen); +/* Collect all variables under a given a node. */ void collectVariables(TNode node, NodeSet& vars); -// is bitblast atom -inline bool isBitblastAtom( Node lit ) { - TNode atom = lit.getKind()==kind::NOT ? lit[0] : lit; - return atom.getKind()!=kind::EQUAL || atom[0].getType().isBitVector(); -} - } } } diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index ff3f75998..cc8edadd0 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -23,29 +23,30 @@ namespace datatypes { RewriteResponse DatatypesRewriter::postRewrite(TNode in) { Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl; - if (in.getKind() == kind::APPLY_CONSTRUCTOR) + Kind k = in.getKind(); + NodeManager* nm = NodeManager::currentNM(); + if (k == kind::APPLY_CONSTRUCTOR) { return rewriteConstructor(in); } - else if (in.getKind() == kind::APPLY_SELECTOR_TOTAL) + else if (k == kind::APPLY_SELECTOR_TOTAL) { return rewriteSelector(in); } - else if (in.getKind() == kind::APPLY_TESTER) + else if (k == kind::APPLY_TESTER) { return rewriteTester(in); } - else if (in.getKind() == kind::DT_SIZE) + else if (k == kind::DT_SIZE) { if (in[0].getKind() == kind::APPLY_CONSTRUCTOR) { std::vector<Node> children; - for (unsigned i = 0; i < in[0].getNumChildren(); i++) + for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++) { if (in[0][i].getType().isDatatype()) { - children.push_back( - NodeManager::currentNM()->mkNode(kind::DT_SIZE, in[0][i])); + children.push_back(nm->mkNode(kind::DT_SIZE, in[0][i])); } } TNode constructor = in[0].getOperator(); @@ -53,17 +54,16 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) const Datatype& dt = Datatype::datatypeOf(constructor.toExpr()); const DatatypeConstructor& c = dt[constructorIndex]; unsigned weight = c.getWeight(); - children.push_back(NodeManager::currentNM()->mkConst(Rational(weight))); - Node res = children.size() == 1 - ? children[0] - : NodeManager::currentNM()->mkNode(kind::PLUS, children); + children.push_back(nm->mkConst(Rational(weight))); + Node res = + children.size() == 1 ? children[0] : nm->mkNode(kind::PLUS, children); Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: rewrite size " << in << " to " << res << std::endl; return RewriteResponse(REWRITE_AGAIN_FULL, res); } } - else if (in.getKind() == kind::DT_HEIGHT_BOUND) + else if (k == kind::DT_HEIGHT_BOUND) { if (in[0].getKind() == kind::APPLY_CONSTRUCTOR) { @@ -71,31 +71,25 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) Node res; Rational r = in[1].getConst<Rational>(); Rational rmo = Rational(r - Rational(1)); - for (unsigned i = 0; i < in[0].getNumChildren(); i++) + for (unsigned i = 0, size = in [0].getNumChildren(); i < size; i++) { if (in[0][i].getType().isDatatype()) { if (r.isZero()) { - res = NodeManager::currentNM()->mkConst(false); + res = nm->mkConst(false); break; } - else - { - children.push_back(NodeManager::currentNM()->mkNode( - kind::DT_HEIGHT_BOUND, - in[0][i], - NodeManager::currentNM()->mkConst(rmo))); - } + children.push_back( + nm->mkNode(kind::DT_HEIGHT_BOUND, in[0][i], nm->mkConst(rmo))); } } if (res.isNull()) { res = children.size() == 0 - ? NodeManager::currentNM()->mkConst(true) + ? nm->mkConst(true) : (children.size() == 1 ? children[0] - : NodeManager::currentNM()->mkNode( - kind::AND, children)); + : nm->mkNode(kind::AND, children)); } Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: rewrite height " << in << " to " @@ -103,53 +97,42 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) return RewriteResponse(REWRITE_AGAIN_FULL, res); } } - else if (in.getKind() == kind::DT_SIZE_BOUND) + else if (k == kind::DT_SIZE_BOUND) { if (in[0].isConst()) { - Node res = NodeManager::currentNM()->mkNode( - kind::LEQ, - NodeManager::currentNM()->mkNode(kind::DT_SIZE, in[0]), - in[1]); + Node res = nm->mkNode(kind::LEQ, nm->mkNode(kind::DT_SIZE, in[0]), in[1]); return RewriteResponse(REWRITE_AGAIN_FULL, res); } } - if (in.getKind() == kind::EQUAL) + if (k == kind::EQUAL) { if (in[0] == in[1]) { - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(true)); + return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); } - else + std::vector<Node> rew; + if (checkClash(in[0], in[1], rew)) { - std::vector<Node> rew; - if (checkClash(in[0], in[1], rew)) - { - Trace("datatypes-rewrite") << "Rewrite clashing equality " << in - << " to false" << std::endl; - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(false)); - //}else if( rew.size()==1 && rew[0]!=in ){ - // Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " << - // rew[0] << std::endl; - // return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] ); - } - else if (in[1] < in[0]) - { - Node ins = NodeManager::currentNM()->mkNode(in.getKind(), in[1], in[0]); - Trace("datatypes-rewrite") << "Swap equality " << in << " to " << ins - << std::endl; - return RewriteResponse(REWRITE_DONE, ins); - } - else - { - Trace("datatypes-rewrite-debug") << "Did not rewrite equality " << in - << " " << in[0].getKind() << " " - << in[1].getKind() << std::endl; - } + Trace("datatypes-rewrite") + << "Rewrite clashing equality " << in << " to false" << std::endl; + return RewriteResponse(REWRITE_DONE, nm->mkConst(false)); + //}else if( rew.size()==1 && rew[0]!=in ){ + // Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " << + // rew[0] << std::endl; + // return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] ); + } + else if (in[1] < in[0]) + { + Node ins = nm->mkNode(in.getKind(), in[1], in[0]); + Trace("datatypes-rewrite") + << "Swap equality " << in << " to " << ins << std::endl; + return RewriteResponse(REWRITE_DONE, ins); } + Trace("datatypes-rewrite-debug") + << "Did not rewrite equality " << in << " " << in[0].getKind() << " " + << in[1].getKind() << std::endl; } return RewriteResponse(REWRITE_DONE, in); @@ -215,10 +198,7 @@ RewriteResponse DatatypesRewriter::rewriteConstructor(TNode in) << inn << std::endl; return RewriteResponse(REWRITE_DONE, inn); } - else - { - return RewriteResponse(REWRITE_DONE, in); - } + return RewriteResponse(REWRITE_DONE, in); } return RewriteResponse(REWRITE_DONE, in); } @@ -322,33 +302,30 @@ RewriteResponse DatatypesRewriter::rewriteTester(TNode in) return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(result)); } - else + const Datatype& dt = static_cast<DatatypeType>(in[0].getType().toType()).getDatatype(); + if (dt.getNumConstructors() == 1) { - const Datatype& dt = DatatypeType(in[0].getType().toType()).getDatatype(); - if (dt.getNumConstructors() == 1) - { - // only one constructor, so it must be - Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " - << "only one ctor for " << dt.getName() - << " and that is " << dt[0].getName() - << std::endl; - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(true)); - } - // could try dt.getNumConstructors()==2 && - // Datatype::indexOf(in.getOperator())==1 ? - else if (!options::dtUseTesters()) - { - unsigned tindex = Datatype::indexOf(in.getOperator().toExpr()); - Trace("datatypes-rewrite-debug") << "Convert " << in << " to equality " - << in[0] << " " << tindex << std::endl; - Node neq = mkTester(in[0], tindex, dt); - Assert(neq != in); - Trace("datatypes-rewrite") - << "DatatypesRewriter::postRewrite: Rewrite tester " << in << " to " - << neq << std::endl; - return RewriteResponse(REWRITE_AGAIN_FULL, neq); - } + // only one constructor, so it must be + Trace("datatypes-rewrite") + << "DatatypesRewriter::postRewrite: " + << "only one ctor for " << dt.getName() << " and that is " + << dt[0].getName() << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + // could try dt.getNumConstructors()==2 && + // Datatype::indexOf(in.getOperator())==1 ? + else if (!options::dtUseTesters()) + { + unsigned tindex = Datatype::indexOf(in.getOperator().toExpr()); + Trace("datatypes-rewrite-debug") << "Convert " << in << " to equality " + << in[0] << " " << tindex << std::endl; + Node neq = mkTester(in[0], tindex, dt); + Assert(neq != in); + Trace("datatypes-rewrite") + << "DatatypesRewriter::postRewrite: Rewrite tester " << in << " to " + << neq << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, neq); } return RewriteResponse(REWRITE_DONE, in); } @@ -368,7 +345,7 @@ bool DatatypesRewriter::checkClash(Node n1, Node n2, std::vector<Node>& rew) return true; } Assert(n1.getNumChildren() == n2.getNumChildren()); - for (unsigned i = 0; i < n1.getNumChildren(); i++) + for (unsigned i = 0, size = n1.getNumChildren(); i < size; i++) { if (checkClash(n1[i], n2[i], rew)) { @@ -397,18 +374,17 @@ Node DatatypesRewriter::getInstCons(Node n, const Datatype& dt, int index) { Assert(index >= 0 && index < (int)dt.getNumConstructors()); std::vector<Node> children; + NodeManager* nm = NodeManager::currentNM(); children.push_back(Node::fromExpr(dt[index].getConstructor())); Type t = n.getType().toType(); - for (unsigned i = 0; i < dt[index].getNumArgs(); i++) + for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) { - Node nc = NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[index].getSelectorInternal(t, i)), - n); + Node nc = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[index].getSelectorInternal(t, i)), + n); children.push_back(nc); } - Node n_ic = - NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children); + Node n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children); if (dt.isParametric()) { TypeNode tn = TypeNode::fromType(t); @@ -424,12 +400,10 @@ Node DatatypesRewriter::getInstCons(Node n, const Datatype& dt, int index) dt[index].getSpecializedConstructorType(n.getType().toType()); Debug("datatypes-parametric") << "Type specification is " << tspec << std::endl; - children[0] = NodeManager::currentNM()->mkNode( - kind::APPLY_TYPE_ASCRIPTION, - NodeManager::currentNM()->mkConst(AscriptionType(tspec)), - children[0]); - n_ic = - NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, children); + children[0] = nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType(tspec)), + children[0]); + n_ic = nm->mkNode(kind::APPLY_CONSTRUCTOR, children); Assert(n_ic.getType() == tn); } } @@ -445,7 +419,7 @@ int DatatypesRewriter::isInstCons(Node t, Node n, const Datatype& dt) int index = Datatype::indexOf(n.getOperator().toExpr()); const DatatypeConstructor& c = dt[index]; Type nt = n.getType().toType(); - for (unsigned i = 0; i < n.getNumChildren(); i++) + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { if (n[i].getKind() != kind::APPLY_SELECTOR_TOTAL || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i)) @@ -530,19 +504,28 @@ Node DatatypesRewriter::mkTester(Node n, int i, const Datatype& dt) return NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n); } - else - { #ifdef CVC4_ASSERTIONS - Node ret = n.eqNode(DatatypesRewriter::getInstCons(n, dt, i)); - Node a; - int ii = isTester(ret, a); - Assert(ii == i); - Assert(a == n); - return ret; + Node ret = n.eqNode(DatatypesRewriter::getInstCons(n, dt, i)); + Node a; + int ii = isTester(ret, a); + Assert(ii == i); + Assert(a == n); + return ret; #else - return n.eqNode(DatatypesRewriter::getInstCons(n, dt, i)); + return n.eqNode(DatatypesRewriter::getInstCons(n, dt, i)); #endif +} + +Node DatatypesRewriter::mkSplit(Node n, const Datatype& dt) +{ + std::vector<Node> splits; + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + Node test = mkTester(n, i, dt); + splits.push_back(test); } + NodeManager* nm = NodeManager::currentNM(); + return splits.size() == 1 ? splits[0] : nm->mkNode(kind::OR, splits); } bool DatatypesRewriter::isNullaryApplyConstructor(Node n) @@ -560,7 +543,7 @@ bool DatatypesRewriter::isNullaryApplyConstructor(Node n) bool DatatypesRewriter::isNullaryConstructor(const DatatypeConstructor& c) { - for (unsigned j = 0; j < c.getNumArgs(); j++) + for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { if (c[j].getType().getRangeType().isDatatype()) { @@ -595,7 +578,7 @@ Node DatatypesRewriter::normalizeCodatatypeConstant(Node n) std::map<Node, int> eqc; std::map<int, std::vector<Node> > eqc_nodes; // partition based on top symbol - for (unsigned j = 0; j < terms.size(); j++) + for (unsigned j = 0, size = terms.size(); j < size; j++) { Node t = terms[j]; Trace("dt-nconst") << " " << t << ", cdt=" << cdts[t] << std::endl; @@ -643,7 +626,8 @@ Node DatatypesRewriter::normalizeCodatatypeConstant(Node n) prt.clear(); // partition based on children : for the first child that causes a // split, break - for (unsigned k = 0; k < eqc_nodes[eqc_curr].size(); k++) + for (unsigned k = 0, size = eqc_nodes[eqc_curr].size(); k < size; + k++) { Node t = eqc_nodes[eqc_curr][k]; Assert(t.getNumChildren() == nchildren); @@ -664,14 +648,12 @@ Node DatatypesRewriter::normalizeCodatatypeConstant(Node n) } } // move into new eqc(s) - for (std::map<int, std::vector<Node> >::iterator it = prt.begin(); - it != prt.end(); - ++it) + for (const std::pair<const int, std::vector<Node> >& p : prt) { int e = eqc_count; - for (unsigned j = 0; j < it->second.size(); j++) + for (unsigned j = 0, size = p.second.size(); j < size; j++) { - Node t = it->second[j]; + Node t = p.second[j]; eqc[t] = e; eqc_nodes[e].push_back(t); } @@ -701,11 +683,8 @@ Node DatatypesRewriter::normalizeCodatatypeConstant(Node n) std::map<int, int> eqc_stack; return normalizeCodatatypeConstantEqc(s, eqc_stack, eqc, 0); } - else - { - Trace("dt-nconst") << "...invalid." << std::endl; - return Node::null(); - } + Trace("dt-nconst") << "...invalid." << std::endl; + return Node::null(); } // normalize constant : apply to top-level codatatype constants @@ -722,7 +701,7 @@ Node DatatypesRewriter::normalizeConstant(Node n) { std::vector<Node> children; bool childrenChanged = false; - for (unsigned i = 0; i < n.getNumChildren(); i++) + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { Node nc = normalizeConstant(n[i]); children.push_back(nc); @@ -732,16 +711,9 @@ Node DatatypesRewriter::normalizeConstant(Node n) { return NodeManager::currentNM()->mkNode(n.getKind(), children); } - else - { - return n; - } } } - else - { - return n; - } + return n; } Node DatatypesRewriter::collectRef(Node n, @@ -773,7 +745,7 @@ Node DatatypesRewriter::collectRef(Node n, std::vector<Node> children; children.push_back(n.getOperator()); bool childChanged = false; - for (unsigned i = 0; i < n.getNumChildren(); i++) + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { Node nc = collectRef(n[i], sk, rf, rf_pending, terms, cdts); if (nc.isNull()) @@ -808,18 +780,15 @@ Node DatatypesRewriter::collectRef(Node n, { return Node::null(); } - else + Assert(sk.size() == rf_pending.size()); + Node r = rf_pending[rf_pending.size() - 1 - index]; + if (r.isNull()) { - Assert(sk.size() == rf_pending.size()); - Node r = rf_pending[rf_pending.size() - 1 - index]; - if (r.isNull()) - { - r = NodeManager::currentNM()->mkBoundVar( - sk[rf_pending.size() - 1 - index].getType()); - rf_pending[rf_pending.size() - 1 - index] = r; - } - return r; + r = NodeManager::currentNM()->mkBoundVar( + sk[rf_pending.size() - 1 - index].getType()); + rf_pending[rf_pending.size() - 1 - index] = r; } + return r; } } } @@ -839,11 +808,7 @@ Node DatatypesRewriter::normalizeCodatatypeConstantEqc( { Trace("dt-nconst-debug") << "normalizeCodatatypeConstantEqc: " << n << " depth=" << depth << std::endl; - if (eqc.find(n) == eqc.end()) - { - return n; - } - else + if (eqc.find(n) != eqc.end()) { int e = eqc[n]; std::map<int, int>::iterator it = eqc_stack.find(e); @@ -853,31 +818,24 @@ Node DatatypesRewriter::normalizeCodatatypeConstantEqc( return NodeManager::currentNM()->mkConst( UninterpretedConstant(n.getType().toType(), debruijn)); } - else + std::vector<Node> children; + bool childChanged = false; + eqc_stack[e] = depth; + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { - std::vector<Node> children; - bool childChanged = false; - eqc_stack[e] = depth; - for (unsigned i = 0; i < n.getNumChildren(); i++) - { - Node nc = - normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1); - children.push_back(nc); - childChanged = childChanged || nc != n[i]; - } - eqc_stack.erase(e); - if (childChanged) - { - Assert(n.getKind() == kind::APPLY_CONSTRUCTOR); - children.insert(children.begin(), n.getOperator()); - return NodeManager::currentNM()->mkNode(n.getKind(), children); - } - else - { - return n; - } + Node nc = normalizeCodatatypeConstantEqc(n[i], eqc_stack, eqc, depth + 1); + children.push_back(nc); + childChanged = childChanged || nc != n[i]; + } + eqc_stack.erase(e); + if (childChanged) + { + Assert(n.getKind() == kind::APPLY_CONSTRUCTOR); + children.insert(children.begin(), n.getOperator()); + return NodeManager::currentNM()->mkNode(n.getKind(), children); } } + return n; } Node DatatypesRewriter::replaceDebruijn(Node n, @@ -898,7 +856,7 @@ Node DatatypesRewriter::replaceDebruijn(Node n, { std::vector<Node> children; bool childChanged = false; - for (unsigned i = 0; i < n.getNumChildren(); i++) + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { Node nc = replaceDebruijn(n[i], orig, orig_tn, depth + 1); children.push_back(nc); diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index dd318765b..8d9ddbf50 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -61,6 +61,12 @@ public: static int isTester(Node n); /** make tester is-C( n ), where C is the i^{th} constructor of dt */ static Node mkTester(Node n, int i, const Datatype& dt); + /** make tester split + * + * Returns the formula (OR is-C1( n ) ... is-Ck( n ) ), where C1...Ck + * are the constructors of n's type (dt). + */ + static Node mkSplit(Node n, const Datatype& dt); /** returns true iff n is a constructor term with no datatype children */ static bool isNullaryApplyConstructor(Node n); /** returns true iff c is a constructor with no datatype children */ @@ -69,6 +75,12 @@ public: * * This returns the normal form of the codatatype constant n. This runs a * DFA minimization algorithm based on the private functions below. + * + * In particular, we first call collectRefs to setup initial information + * about what terms occur in n. Then, we run a DFA minimization algorithm to + * partition these subterms in equivalence classes. Finally, we call + * normalizeCodatatypeConstantEqc to construct the normalized codatatype + * constant that is equivalent to n. */ static Node normalizeCodatatypeConstant(Node n); /** normalize constant @@ -99,18 +111,90 @@ private: /** rewrite tester term in */ static RewriteResponse rewriteTester(TNode in); - /** TODO (#1436) document these */ + /** collect references + * + * This function, given as input a codatatype term n, collects the necessary + * information for constructing a (canonical) codatatype constant that is + * equivalent to n if one exists, or null otherwise. + * + * In particular it returns a term ret such that all non-codatatype datatype + * subterms of n are replaced by a constant that is equal to them via a + * (mutually) recursive call to normalizeConstant above. Additionally, this + * function replaces references to mu-binders with fresh variables. + * In detail, mu-terms are represented by uninterpreted constants of datatype + * type that carry their Debruijn index. + * + * Consider the example of a codatatype representing a stream of integers: + * Stream := cons( head : Int, tail : Stream ) + * The stream 1,0,1,0,1,0... when written in mu-notation is the term: + * mu x. cons( 1, mu y. cons( 0, x ) ) + * This is represented in CVC4 by the Node: + * cons( 1, cons( 0, c[1] ) ) + * where c[1] is a uninterpreted constant datatype with Debruijn index 1, + * indicating that c[1] is nested underneath 1 level on the path to the + * term which it binds. On the other hand, the stream 1,0,0,0,0,... is + * represented by the codatatype term: + * cons( 1, cons( 0, c[0] ) ) + * + * Subterms that are references to mu-binders in n are replaced by a new + * variable. If n contains any subterm that is a reference to a mu-binder not + * bound in n, then we return null. For example we return null when n is: + * cons( 1, cons( 0, c[2] ) ) + * since c[2] is not bound by this codatatype term. + * + * All valid references to mu-binders are replaced by a variable that is unique + * for the term it references. For example, for the infinite tree codatatype: + * Tree : node( data : Int, left : Tree, right : Tree ) + * If n is the term: + * node( 0, c[0], node( 1, c[0], c[1] ) ) + * then the return value ret of this function is: + * node( 0, x, node( 1, y, x ) ) + * where x refers to the root of the term and y refers to the right tree of the + * root. + * + * The argument sk stores the current set of node that we are traversing + * beneath. The argument rf_pending stores, for each node that we are + * traversing beneath either null or the free variable that we are using to + * refer to its mu-binder. The remaining arguments store information that is + * relevant when performing normalization of n using the value of ret: + * + * rf : maps subterms of n to the corresponding term in ret for all subterms + * where the corresponding term in ret is different. + * terms : stores all subterms of ret. + * cdts : for each term t in terms, stores whether t is a codatatype. + */ static Node collectRef(Node n, std::vector<Node>& sk, std::map<Node, Node>& rf, std::vector<Node>& rf_pending, std::vector<Node>& terms, std::map<Node, bool>& cdts); - // eqc_stack stores depth + /** normalize codatatype constant eqc + * + * This recursive function returns a codatatype constant that is equivalent to + * n based on a pre-computed partition of the subterms of n into equivalence + * classes, as stored in the mapping eqc, which maps the subterms of n to + * equivalence class ids. The arguments eqc_stack and depth store information + * about the traversal in a term we have recursed, where + * + * eqc_stack : maps the depth of each term we have traversed to its equivalence + * class id. + * depth : the number of levels which we have traversed. + */ static Node normalizeCodatatypeConstantEqc(Node n, std::map<int, int>& eqc_stack, std::map<Node, int>& eqc, int depth); + /** replace debruijn + * + * This function, given codatatype term n, returns a node + * where all subterms of n that have Debruijn indices that refer to a + * term of input depth are replaced by orig. For example, for the infinite Tree + * datatype, + * replaceDebruijn( node( 0, c[0], node( 1, c[0], c[1] ) ), t, Tree, 0 ) + * returns + * node( 0, t, node( 1, c[0], t ) ). + */ static Node replaceDebruijn(Node n, Node orig, TypeNode orig_tn, diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index b06c96e68..5198b44d0 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -34,40 +34,6 @@ using namespace CVC4::context; using namespace CVC4::theory; using namespace CVC4::theory::datatypes; -Node SygusSplitNew::getSygusSplit( quantifiers::TermDbSygus * tds, Node n, const Datatype& dt ) { - TypeNode tnn = n.getType(); - tds->registerSygusType( tnn ); - std::vector< Node > curr_splits; - for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - Trace("sygus-split-debug2") << "Add split " << n << " : constructor " << dt[i].getName() << " : "; - if( !tds->isGenericRedundant( tnn, i ) ){ - std::vector< Node > test_c; - test_c.push_back( DatatypesRewriter::mkTester( n, i, dt ) ); - Node test = test_c.size()==1 ? test_c[0] : NodeManager::currentNM()->mkNode( AND, test_c ); - curr_splits.push_back( test ); - Trace("sygus-split-debug2") << "SUCCESS" << std::endl; - Trace("sygus-split-debug") << "Disjunct #" << curr_splits.size() << " : " << test << std::endl; - }else{ - Trace("sygus-split-debug2") << "redundant operator" << std::endl; - } - } - Assert( !curr_splits.empty() ); - return curr_splits.size()==1 ? curr_splits[0] : NodeManager::currentNM()->mkNode( OR, curr_splits ); - -} - -void SygusSplitNew::getSygusSplits( Node n, const Datatype& dt, std::vector< Node >& splits, std::vector< Node >& lemmas ) { - Assert( dt.isSygus() ); - if( d_splits.find( n )==d_splits.end() ){ - Trace("sygus-split") << "Get sygus splits " << n << std::endl; - Node split = getSygusSplit( d_tds, n, dt ); - Assert( !split.isNull() ); - d_splits[n].push_back( split ); - } - //copy to splits - splits.insert( splits.end(), d_splits[n].begin(), d_splits[n].end() ); -} - SygusSymBreakNew::SygusSymBreakNew(TheoryDatatypes* td, quantifiers::TermDbSygus* tds, context::Context* c) @@ -209,39 +175,42 @@ void SygusSymBreakNew::assertIsConst( Node n, bool polarity, std::vector< Node > } Node SygusSymBreakNew::getTermOrderPredicate( Node n1, Node n2 ) { + NodeManager* nm = NodeManager::currentNM(); std::vector< Node > comm_disj; // (1) size of left is greater than size of right - Node sz_less = NodeManager::currentNM()->mkNode( GT, NodeManager::currentNM()->mkNode( DT_SIZE, n1 ), - NodeManager::currentNM()->mkNode( DT_SIZE, n2 ) ); + Node sz_less = + nm->mkNode(GT, nm->mkNode(DT_SIZE, n1), nm->mkNode(DT_SIZE, n2)); comm_disj.push_back( sz_less ); // (2) ...or sizes are equal and first child is less by term order - std::vector< Node > sz_eq_cases; - Node sz_eq = NodeManager::currentNM()->mkNode( EQUAL, NodeManager::currentNM()->mkNode( DT_SIZE, n1 ), - NodeManager::currentNM()->mkNode( DT_SIZE, n2 ) ); + std::vector<Node> sz_eq_cases; + Node sz_eq = + nm->mkNode(EQUAL, nm->mkNode(DT_SIZE, n1), nm->mkNode(DT_SIZE, n2)); sz_eq_cases.push_back( sz_eq ); if( options::sygusOpt1() ){ TypeNode tnc = n1.getType(); const Datatype& cdt = ((DatatypeType)(tnc).toType()).getDatatype(); for( unsigned j=0; j<cdt.getNumConstructors(); j++ ){ - if( !d_tds->isGenericRedundant( tnc, j ) ){ - std::vector< Node > case_conj; - for( unsigned k=0; k<j; k++ ){ - if( !d_tds->isGenericRedundant( tnc, k ) ){ - case_conj.push_back( DatatypesRewriter::mkTester( n2, k, cdt ).negate() ); - } - } - if( !case_conj.empty() ){ - Node corder = NodeManager::currentNM()->mkNode( kind::OR, DatatypesRewriter::mkTester( n1, j, cdt ).negate(), - case_conj.size()==1 ? case_conj[0] : NodeManager::currentNM()->mkNode( kind::AND, case_conj ) ); - sz_eq_cases.push_back( corder ); - } + std::vector<Node> case_conj; + for (unsigned k = 0; k < j; k++) + { + case_conj.push_back(DatatypesRewriter::mkTester(n2, k, cdt).negate()); + } + if (!case_conj.empty()) + { + Node corder = nm->mkNode( + kind::OR, + DatatypesRewriter::mkTester(n1, j, cdt).negate(), + case_conj.size() == 1 ? case_conj[0] + : nm->mkNode(kind::AND, case_conj)); + sz_eq_cases.push_back(corder); } } } - Node sz_eqc = sz_eq_cases.size()==1 ? sz_eq_cases[0] : NodeManager::currentNM()->mkNode( kind::AND, sz_eq_cases ); + Node sz_eqc = sz_eq_cases.size() == 1 ? sz_eq_cases[0] + : nm->mkNode(kind::AND, sz_eq_cases); comm_disj.push_back( sz_eqc ); - - return NodeManager::currentNM()->mkNode( kind::OR, comm_disj ); + + return nm->mkNode(kind::OR, comm_disj); } void SygusSymBreakNew::registerTerm( Node n, std::vector< Node >& lemmas ) { @@ -441,13 +410,14 @@ Node SygusSymBreakNew::getRelevancyCondition( Node n ) { std::vector< Node > disj; bool excl = false; for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - if( !d_tds->isGenericRedundant( ntn, i ) ){ - int sindexi = dt[i].getSelectorIndexInternal( selExpr ); - if( sindexi!=-1 ){ - disj.push_back( DatatypesRewriter::mkTester( n[0], i, dt ) ); - }else{ - excl = true; - } + int sindexi = dt[i].getSelectorIndexInternal(selExpr); + if (sindexi != -1) + { + disj.push_back(DatatypesRewriter::mkTester(n[0], i, dt)); + } + else + { + excl = true; } } Assert( !disj.empty() ); @@ -624,28 +594,36 @@ Node SygusSymBreakNew::getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned TypeNode tnc = nc.getType(); const Datatype& cdt = ((DatatypeType)(tnc).toType()).getDatatype(); for( unsigned k=0; k<cdt.getNumConstructors(); k++ ){ - // if not already generic redundant - if( !d_tds->isGenericRedundant( tnc, k ) ){ - Kind nck = d_tds->getConsNumKind( tnc, k ); - bool red = false; - //check if the argument is redundant - if( nck!=UNDEFINED_KIND ){ - Trace("sygus-sb-simple-debug") << " argument " << j << " " << k << " is : " << nck << std::endl; - red = !d_tds->considerArgKind( tnc, tn, nck, nk, j ); + Kind nck = d_tds->getConsNumKind(tnc, k); + bool red = false; + // check if the argument is redundant + if (nck != UNDEFINED_KIND) + { + Trace("sygus-sb-simple-debug") + << " argument " << j << " " << k << " is : " << nck + << std::endl; + red = !d_tds->considerArgKind(tnc, tn, nck, nk, j); + } + else + { + Node cc = d_tds->getConsNumConst(tnc, k); + if (!cc.isNull()) + { + Trace("sygus-sb-simple-debug") + << " argument " << j << " " << k + << " is constant : " << cc << std::endl; + red = !d_tds->considerConst(tnc, tn, cc, nk, j); }else{ - Node cc = d_tds->getConsNumConst( tnc, k ); - if( !cc.isNull() ){ - Trace("sygus-sb-simple-debug") << " argument " << j << " " << k << " is constant : " << cc << std::endl; - red = !d_tds->considerConst( tnc, tn, cc, nk, j ); - }else{ - //defined function? - } - } - if( red ){ - Trace("sygus-sb-simple-debug") << " ...redundant." << std::endl; - sbp_conj.push_back( DatatypesRewriter::mkTester( nc, k, cdt ).negate() ); + // defined function? } } + if (red) + { + Trace("sygus-sb-simple-debug") + << " ...redundant." << std::endl; + sbp_conj.push_back( + DatatypesRewriter::mkTester(nc, k, cdt).negate()); + } } } } @@ -808,7 +786,38 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, Trace("sygus-sb-exc") << " ......programs " << prev_bv << " and " << bv << " rewrite to " << bvr << "." << std::endl; } } - + + if (options::sygusRewVerify()) + { + // add to the sampler database object + std::map<Node, quantifiers::SygusSampler>::iterator its = + d_sampler.find(a); + if (its == d_sampler.end()) + { + d_sampler[a].initializeSygus(d_tds, a, options::sygusSamples()); + its = d_sampler.find(a); + } + Node sample_ret = its->second.registerTerm(bv); + d_cache[a].d_search_val_sample[nv] = sample_ret; + if (itsv != d_cache[a].d_search_val[tn].end()) + { + // if the analog of this term and another term were rewritten to the + // same term, then they should be equivalent under examples. + Node prev = itsv->second; + Node prev_sample_ret = d_cache[a].d_search_val_sample[prev]; + if (sample_ret != prev_sample_ret) + { + Node prev_bv = d_tds->sygusToBuiltin(prev, tn); + // we have detected unsoundness in the rewriter + Options& nodeManagerOptions = + NodeManager::currentNM()->getOptions(); + std::ostream* out = nodeManagerOptions.getOut(); + (*out) << "(unsound-rewrite " << prev_bv << " " << bv << ")" + << std::endl; + } + } + } + if( !bad_val_bvr.isNull() ){ Node bad_val = nv; Node bad_val_o = d_cache[a].d_search_val[tn][bad_val_bvr]; @@ -1161,16 +1170,6 @@ void SygusSymBreakNew::check( std::vector< Node >& lemmas ) { } } -void SygusSymBreakNew::getPossibleCons( const Datatype& dt, TypeNode tn, std::vector< bool >& pcons ) { - Assert( pcons.size()==dt.getNumConstructors() ); - d_tds->registerSygusType( tn ); - for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - if( d_tds->isGenericRedundant( tn, i ) ){ - pcons[i] = false; - } - } -} - bool SygusSymBreakNew::debugTesters( Node n, Node vn, int ind, std::vector< Node >& lemmas ) { Assert( vn.getKind()==kind::APPLY_CONSTRUCTOR ); if( Trace.isOn("sygus-sb-warn") ){ @@ -1191,7 +1190,7 @@ bool SygusSymBreakNew::debugTesters( Node n, Node vn, int ind, std::vector< Node Trace("sygus-sb-warn") << "- has tester : " << tst << " : " << ( hastst ? "true" : "false" ); Trace("sygus-sb-warn") << ", value=" << tstrep << std::endl; if( !hastst ){ - Node split = SygusSplitNew::getSygusSplit( d_tds, n, dt ); + Node split = DatatypesRewriter::mkSplit(n, dt); Assert( !split.isNull() ); lemmas.push_back( split ); return false; diff --git a/src/theory/datatypes/datatypes_sygus.h b/src/theory/datatypes/datatypes_sygus.h index 099b45fec..ff2d2a873 100644 --- a/src/theory/datatypes/datatypes_sygus.h +++ b/src/theory/datatypes/datatypes_sygus.h @@ -30,6 +30,7 @@ #include "expr/datatype.h" #include "expr/node.h" #include "theory/quantifiers/ce_guided_conjecture.h" +#include "theory/quantifiers/sygus_sampler.h" #include "theory/quantifiers/term_database.h" namespace CVC4 { @@ -38,19 +39,6 @@ namespace datatypes { class TheoryDatatypes; -class SygusSplitNew -{ -private: - quantifiers::TermDbSygus * d_tds; - std::map< Node, std::vector< Node > > d_splits; -public: - SygusSplitNew( quantifiers::TermDbSygus * tds ) : d_tds( tds ){} - virtual ~SygusSplitNew(){} - /** get sygus splits */ - void getSygusSplits( Node n, const Datatype& dt, std::vector< Node >& splits, std::vector< Node >& lemmas ); - static Node getSygusSplit( quantifiers::TermDbSygus * tds, Node n, const Datatype& dt ); -}; - class SygusSymBreakNew { private: @@ -80,14 +68,34 @@ private: SearchCache(){} std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_search_terms; std::map< TypeNode, std::map< unsigned, std::vector< Node > > > d_sb_lemmas; - // search values + /** search value + * + * For each sygus type, a map from a builtin term to a sygus term for that + * type that we encountered during the search whose analog rewrites to that + * term. The range of this map can be updated if we later encounter a sygus + * term that also rewrites to the builtin value but has a smaller term size. + */ std::map< TypeNode, std::map< Node, Node > > d_search_val; + /** the size of terms in the range of d_search val. */ std::map< TypeNode, std::map< Node, unsigned > > d_search_val_sz; - std::map< TypeNode, std::map< Node, Node > > d_search_val_b; + /** search value sample + * + * This is used for the sygusRewVerify() option. For each sygus term we + * register in this cache, this stores the value returned by calling + * SygusSample::registerTerm(...) on its analog. + */ + std::map<Node, Node> d_search_val_sample; + /** For each term, whether this cache has processed that term */ std::map< Node, bool > d_search_val_proc; }; // anchor -> cache std::map< Node, SearchCache > d_cache; + /** a sygus sampler object for each anchor + * + * This is used for the sygusRewVerify() option to verify the correctness of + * the rewriter. + */ + std::map<Node, quantifiers::SygusSampler> d_sampler; Node d_null; void assertTesterInternal( int tindex, TNode n, Node exp, std::vector< Node >& lemmas ); // register search term @@ -163,7 +171,6 @@ public: void assertFact( Node n, bool polarity, std::vector< Node >& lemmas ); void preRegisterTerm( TNode n, std::vector< Node >& lemmas ); void check( std::vector< Node >& lemmas ); - void getPossibleCons( const Datatype& dt, TypeNode tn, std::vector< bool >& pcons ); public: Node getNextDecisionRequest( unsigned& priority, std::vector< Node >& lemmas ); }; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index c17c022a1..d91eace99 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -74,7 +74,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c, UserContext* u, OutputChannel& out, d_zero = NodeManager::currentNM()->mkConst( Rational(0) ); d_dtfCounter = 0; - d_sygus_split = NULL; d_sygus_sym_break = NULL; } @@ -85,7 +84,6 @@ TheoryDatatypes::~TheoryDatatypes() { Assert(current != NULL); delete current; } - delete d_sygus_split; delete d_sygus_sym_break; } @@ -309,27 +307,7 @@ void TheoryDatatypes::check(Effort e) { d_out->requirePhase( test, true ); }else{ Trace("dt-split") << "*************Split for constructors on " << n << endl; - std::vector< Node > children; - if( dt.isSygus() && d_sygus_split ){ - Trace("dt-split") << "DtSygus : split on " << n - << std::endl; - std::vector< Node > lemmas; - d_sygus_split->getSygusSplits( n, dt, children, lemmas ); - Trace("dt-split") << "Finished compute split, returned " - << lemmas.size() << " lemmas." - << std::endl; - for( unsigned i=0; i<lemmas.size(); i++ ){ - Trace("dt-lemma-sygus") << "Dt sygus lemma : " << lemmas[i] << std::endl; - doSendLemma( lemmas[i] ); - } - }else{ - for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - Node test = DatatypesRewriter::mkTester( n, i, dt ); - children.push_back( test ); - } - } - Assert( !children.empty() ); - Node lemma = children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( kind::OR, children ); + Node lemma = DatatypesRewriter::mkSplit(n, dt); Trace("dt-split-debug") << "Split lemma is : " << lemma << std::endl; //doSendLemma( lemma ); d_out->lemma( lemma, false, false, true ); @@ -551,7 +529,6 @@ void TheoryDatatypes::finishInit() { if( getQuantifiersEngine() && options::ceGuidedInst() ){ quantifiers::TermDbSygus * tds = getQuantifiersEngine()->getTermDatabaseSygus(); Assert( tds!=NULL ); - d_sygus_split = new SygusSplitNew( tds ); d_sygus_sym_break = new SygusSymBreakNew( this, tds, getSatContext() ); } } @@ -1026,10 +1003,6 @@ void TheoryDatatypes::getPossibleCons( EqcInfo* eqc, Node n, std::vector< bool > Assert( tindex!=-1 ); pcons[ tindex ] = false; } - //further limit the possibilities based on grammar minimization - if( d_sygus_sym_break && dt.isSygus() ){ - d_sygus_sym_break->getPossibleCons( dt, tn, pcons ); - } } } } @@ -1157,7 +1130,7 @@ void TheoryDatatypes::addTester( int ttindex, Node t, EqcInfo* eqc, Node n, Node break; } } - Assert( dt.isSygus() || testerIndex!=-1 ); + Assert(testerIndex != -1); //we must explain why each term in the set of testers for this equivalence class is equal std::vector< Node > eq_terms; NodeBuilder<> nb(kind::AND); diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index b3d88bb1c..8052df59a 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -220,21 +220,21 @@ private: /** get eqc constructor */ TNode getEqcConstructor( TNode r ); -protected: + protected: void addCarePairs( quantifiers::TermArgTrie * t1, quantifiers::TermArgTrie * t2, unsigned arity, unsigned depth, unsigned& n_pairs ); /** compute care graph */ - void computeCareGraph(); + void computeCareGraph() override; -public: + public: TheoryDatatypes(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheoryDatatypes(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; /** propagate */ - void propagate(Effort effort); + void propagate(Effort effort) override; /** propagate */ bool propagate(TNode literal); /** explain */ @@ -242,7 +242,7 @@ public: void explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions ); void explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions ); void explain( TNode literal, std::vector<TNode>& assumptions ); - Node explain( TNode literal ); + Node explain(TNode literal) override; Node explain( std::vector< Node >& lits ); /** Conflict when merging two constants */ void conflict(TNode a, TNode b); @@ -255,26 +255,36 @@ public: /** called when two equivalence classes are made disequal */ void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); - void check(Effort e); - bool needsCheckLastEffort(); - void preRegisterTerm(TNode n); - void finishInit(); - Node expandDefinition(LogicRequest &logicRequest, Node n); - Node ppRewrite(TNode n); - void presolve(); - void addSharedTerm(TNode t); - EqualityStatus getEqualityStatus(TNode a, TNode b); + void check(Effort e) override; + bool needsCheckLastEffort() override; + void preRegisterTerm(TNode n) override; + void finishInit() override; + Node expandDefinition(LogicRequest& logicRequest, Node n) override; + Node ppRewrite(TNode n) override; + void presolve() override; + void addSharedTerm(TNode t) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; bool collectModelInfo(TheoryModel* m) override; - void shutdown() { } - std::string identify() const { return std::string("TheoryDatatypes"); } + void shutdown() override {} + std::string identify() const override + { + return std::string("TheoryDatatypes"); + } /** equality engine */ - eq::EqualityEngine * getEqualityEngine() { return &d_equalityEngine; } - bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ); + eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } + bool getCurrentSubstitution(int effort, + std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; /** debug print */ void printModelDebug( const char* c ); /** entailment check */ - virtual std::pair<bool, Node> entailmentCheck(TNode lit, const EntailmentCheckParameters* params = NULL, EntailmentCheckSideEffects* out = NULL); -private: + std::pair<bool, Node> entailmentCheck( + TNode lit, + const EntailmentCheckParameters* params = NULL, + EntailmentCheckSideEffects* out = NULL) override; + + private: /** add tester to equivalence class info */ void addTester( int ttindex, Node t, EqcInfo* eqc, Node n, Node t_arg ); /** add selector to equivalence class info */ @@ -321,11 +331,11 @@ private: bool areCareDisequal( TNode x, TNode y ); TNode getRepresentative( TNode a ); private: - /** sygus utilities */ - SygusSplitNew * d_sygus_split; - SygusSymBreakNew * d_sygus_sym_break; + /** sygus symmetry breaking utility */ + SygusSymBreakNew* d_sygus_sym_break; + public: - Node getNextDecisionRequest( unsigned& priority ); + Node getNextDecisionRequest(unsigned& priority) override; };/* class TheoryDatatypes */ }/* CVC4::theory::datatypes namespace */ diff --git a/src/theory/fp/fp_converter.cpp b/src/theory/fp/fp_converter.cpp index 6ce2195cb..aba95d2ec 100644 --- a/src/theory/fp/fp_converter.cpp +++ b/src/theory/fp/fp_converter.cpp @@ -25,12 +25,10 @@ FpConverter::FpConverter(context::UserContext *user) Node FpConverter::convert(TNode node) { Unimplemented("Conversion not implemented."); - return node; } Node FpConverter::getValue(Valuation &val, TNode var) { Unimplemented("Conversion not implemented."); - return Node::null(); } } // namespace fp diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index 614cbff46..ca80546b8 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -38,21 +38,21 @@ class TheoryFp : public Theory { TheoryFp(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); - Node expandDefinition(LogicRequest& lr, Node node); + Node expandDefinition(LogicRequest& lr, Node node) override; - void preRegisterTerm(TNode node); - void addSharedTerm(TNode node); + void preRegisterTerm(TNode node) override; + void addSharedTerm(TNode node) override; - void check(Effort); + void check(Effort) override; - Node getModelValue(TNode var); + Node getModelValue(TNode var) override; bool collectModelInfo(TheoryModel* m) override; - std::string identify() const { return "THEORY_FP"; } + std::string identify() const override { return "THEORY_FP"; } - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - Node explain(TNode n); + Node explain(TNode n) override; protected: /** Equality engine */ diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index 98ac536ec..9fda5c2f6 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -63,7 +63,6 @@ namespace rewrite { RewriteResponse type (TNode node, bool) { Unreachable("sort kind (%d) found in expression?",node.getKind()); - return RewriteResponse(REWRITE_DONE, node); } RewriteResponse removeDoubleNegation (TNode node, bool) { @@ -143,7 +142,6 @@ namespace rewrite { RewriteResponse removed (TNode node, bool) { Unreachable("kind (%s) should have been removed?",kindToString(node.getKind()).c_str()); - return RewriteResponse(REWRITE_DONE, node); } RewriteResponse variable (TNode node, bool) { @@ -492,11 +490,8 @@ namespace constantFold { return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(arg1 == arg2)); - } else { - Unreachable("Equality of unknown type"); } - - return RewriteResponse(REWRITE_DONE, node); + Unreachable("Equality of unknown type"); } diff --git a/src/theory/fp/type_enumerator.h b/src/theory/fp/type_enumerator.h index 0ae6462bc..4b243c224 100644 --- a/src/theory/fp/type_enumerator.h +++ b/src/theory/fp/type_enumerator.h @@ -48,7 +48,7 @@ class FloatingPointEnumerator return NodeManager::currentNM()->mkConst(createFP()); } - FloatingPointEnumerator& operator++() { + FloatingPointEnumerator& operator++() override { const FloatingPoint current(createFP()); if (current.isNaN()) { d_enumerationComplete = true; @@ -92,7 +92,7 @@ class RoundingModeEnumerator return NodeManager::currentNM()->mkConst(d_rm); } - RoundingModeEnumerator& operator++() { + RoundingModeEnumerator& operator++() override { switch (d_rm) { case roundNearestTiesToEven: d_rm = roundTowardPositive; diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index ec88f229e..be0e4bb31 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -2,9 +2,9 @@ /*! \file bv_inverter.cpp ** \verbatim ** Top contributors (to current version): - ** Andrew Reynolds + ** Aina Niemetz, Mathias Preiner, Andrew Reynolds ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS ** in the top-level source directory) and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim @@ -78,7 +78,7 @@ Node BvInverter::getInversionNode(Node cond, TypeNode tn, BvInverterQuery* m) } } } - + if (c.isNull()) { NodeManager* nm = NodeManager::currentNM(); @@ -198,17 +198,20 @@ Node BvInverter::getPathToPv( static Node dropChild(Node n, unsigned index) { unsigned nchildren = n.getNumChildren(); + Assert(nchildren > 0); Assert(index < nchildren); + + if (nchildren < 2) return Node::null(); + Kind k = n.getKind(); - Assert(k == BITVECTOR_AND || k == BITVECTOR_OR || k == BITVECTOR_MULT - || k == BITVECTOR_PLUS); - NodeBuilder<> nb(NodeManager::currentNM(), k); + NodeBuilder<> nb(k); for (unsigned i = 0; i < nchildren; ++i) { if (i == index) continue; nb << n[i]; } - return nb.constructNode(); + Assert(nb.getNumChildren() > 0); + return nb.getNumChildren() == 1 ? nb[0] : nb.constructNode(); } static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) @@ -224,7 +227,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) if (pol == true) { /* x < t - * with side condition: + * with invertibility condition: * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -235,8 +238,8 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) else { /* x >= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } } @@ -246,7 +249,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) if (pol == true) { /* x > t - * with side condition: + * with invertibility condition: * (distinct t ones) * where * ones = ~0 with getSize(ones) = w */ @@ -257,8 +260,8 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) else { /* x <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } } @@ -279,7 +282,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) if (pol == true) { /* x < t - * with side condition: + * with invertibility condition: * (distinct t min) * where * min is the minimum signed value with getSize(min) = w */ @@ -291,8 +294,8 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) else { /* x >= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } } @@ -302,7 +305,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) if (pol == true) { /* x > t - * with side condition: + * with invertibility condition: * (distinct t max) * where * max is the signed maximum value with getSize(max) = w */ @@ -314,8 +317,8 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) else { /* x <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } } @@ -348,7 +351,7 @@ static Node getScBvMult(bool pol, if (pol) { /* x * s = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvand (bvor (bvneg s) s) t) t) * * is equivalent to: @@ -367,7 +370,7 @@ static Node getScBvMult(bool pol, else { /* x * s != t - * with side condition: + * with invertibility condition: * (or (distinct t z) (distinct s z)) * where * z = 0 with getSize(z) = w */ @@ -379,7 +382,7 @@ static Node getScBvMult(bool pol, if (pol) { /* x * s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -389,7 +392,7 @@ static Node getScBvMult(bool pol, else { /* x * s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvor (bvneg s) s) t) */ Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s); scl = nm->mkNode(BITVECTOR_UGE, o, t); @@ -400,7 +403,7 @@ static Node getScBvMult(bool pol, if (pol) { /* x * s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t (bvor (bvneg s) s)) */ Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s); scl = nm->mkNode(BITVECTOR_ULT, t, o); @@ -408,7 +411,7 @@ static Node getScBvMult(bool pol, else { /* x * s <= t - * true (no side condition) */ + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -417,7 +420,7 @@ static Node getScBvMult(bool pol, if (pol) { /* x * s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvand (bvnot (bvneg t)) (bvor (bvneg s) s)) t) */ Node a1 = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, t)); Node a2 = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s); @@ -426,7 +429,7 @@ static Node getScBvMult(bool pol, else { /* x * s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge (bvand (bvor (bvneg s) s) max) t) * where * max is the signed maximum value with getSize(max) = w */ @@ -436,12 +439,13 @@ static Node getScBvMult(bool pol, scl = nm->mkNode(BITVECTOR_SGE, a, t); } } - else /* litk == BITVECTOR_SGT */ + else { + Assert(litk == BITVECTOR_SGT); if (pol) { /* x * s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt t (bvsub t (bvor (bvor s t) (bvneg s)))) */ Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_OR, s, t), nm->mkNode(BITVECTOR_NEG, s)); @@ -451,7 +455,7 @@ static Node getScBvMult(bool pol, else { /* x * s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (and (= s z) (bvslt t s))) * where * z = 0 with getSize(z) = w */ @@ -493,7 +497,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* x % s = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvnot (bvneg s)) t) */ Node neg = nm->mkNode(BITVECTOR_NEG, s); scl = nm->mkNode(BITVECTOR_UGE, nm->mkNode(BITVECTOR_NOT, neg), t); @@ -501,7 +505,7 @@ static Node getScBvUrem(bool pol, else { /* x % s != t - * with side condition: + * with invertibility condition: * (or (distinct s (_ bv1 w)) (distinct t z)) * where * z = 0 with getSize(z) = w */ @@ -516,7 +520,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* s % x = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvand (bvsub (bvadd t t) s) s) t) * * is equivalent to: @@ -534,7 +538,7 @@ static Node getScBvUrem(bool pol, else { /* s % x != t - * with side condition: + * with invertibility condition: * (or (distinct s z) (distinct t z)) * where * z = 0 with getSize(z) = w */ @@ -550,7 +554,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* x % s < t - * with side condition: + * with invertibility condition: * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -560,7 +564,7 @@ static Node getScBvUrem(bool pol, else { /* x % s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvnot (bvneg s)) t) */ Node neg = nm->mkNode(BITVECTOR_NEG, s); scl = nm->mkNode(BITVECTOR_UGE, nm->mkNode(BITVECTOR_NOT, neg), t); @@ -571,7 +575,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* s % x < t - * with side condition: + * with invertibility condition: * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -581,7 +585,7 @@ static Node getScBvUrem(bool pol, else { /* s % x >= t - * with side condition (combination of = and >): + * with invertibility condition (combination of = and >): * (or * (bvuge (bvand (bvsub (bvadd t t) s) s) t) ; eq, synthesized * (bvult t s)) ; ugt, synthesized */ @@ -601,7 +605,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* x % s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t (bvnot (bvneg s))) */ Node nt = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, s)); scl = nm->mkNode(BITVECTOR_ULT, t, nt); @@ -609,7 +613,7 @@ static Node getScBvUrem(bool pol, else { /* x % s <= t - * true (no side condition) */ + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -618,14 +622,14 @@ static Node getScBvUrem(bool pol, if (pol) { /* s % x > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t s) */ scl = nm->mkNode(BITVECTOR_ULT, t, s); } else { /* s % x <= t - * true (no side condition) */ + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -637,7 +641,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* x % s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvnot t) (bvor (bvneg s) (bvneg t))) */ Node o1 = nm->mkNode(BITVECTOR_NEG, s); Node o2 = nm->mkNode(BITVECTOR_NEG, t); @@ -647,7 +651,7 @@ static Node getScBvUrem(bool pol, else { /* x % s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt t s) (bvsge z s)) * where * z = 0 with getSize(z) = w */ @@ -664,7 +668,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* s % x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt s t) (bvslt z t)) * where * z = 0 with getSize(z) = w */ @@ -675,7 +679,7 @@ static Node getScBvUrem(bool pol, else { /* s % x >= t - * with side condition: + * with invertibility condition: * (and * (=> (bvsge s z) (bvsge s t)) * (=> (and (bvslt s z) (bvsge t z)) (bvugt (bvsub s t) t))) @@ -691,8 +695,9 @@ static Node getScBvUrem(bool pol, } } } - else /* litk == BITVECTOR_SGT */ + else { + Assert(litk == BITVECTOR_SGT); if (idx == 0) { Node z = bv::utils::mkZero(w); @@ -700,7 +705,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* x % s > t - * with side condition: + * with invertibility condition: * * (and * (and @@ -724,7 +729,7 @@ static Node getScBvUrem(bool pol, else { /* x % s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt ones (bvand (bvneg s) t)) * where * z = 0 with getSize(z) = w @@ -738,7 +743,7 @@ static Node getScBvUrem(bool pol, if (pol) { /* s % x > t - * with side condition: + * with invertibility condition: * (and * (=> (bvsge s z) (bvsgt s t)) * (=> (bvslt s z) @@ -757,7 +762,7 @@ static Node getScBvUrem(bool pol, else { /* s % x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvult t min) (bvsge t s)) * where * min is the minimum signed value with getSize(min) = w */ @@ -802,7 +807,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* x udiv s = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvudiv (bvmul s t) s) t) * * is equivalent to: @@ -823,7 +828,7 @@ static Node getScBvUdiv(bool pol, else { /* x udiv s != t - * with side condition: + * with invertibility condition: * (or (distinct s z) (distinct t ones)) * where * z = 0 with getSize(z) = w @@ -837,7 +842,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* s udiv x = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvudiv s (bvudiv s t)) t) * * is equivalent to: @@ -860,10 +865,10 @@ static Node getScBvUdiv(bool pol, else { /* s udiv x != t - * with side condition (w > 1): - * true (no side condition) + * with invertibility condition (w > 1): + * true (no invertibility condition) * - * with side condition (w == 1): + * with invertibility condition (w == 1): * (= (bvand s t) z) * * where @@ -886,7 +891,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* x udiv s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (and (bvult z s) (bvult z t)) * where * z = 0 with getSize(z) = w */ @@ -897,7 +902,7 @@ static Node getScBvUdiv(bool pol, else { /* x udiv s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvand (bvudiv (bvmul s t) t) s) s) */ Node mul = nm->mkNode(BITVECTOR_MULT, s, t); Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, t); @@ -909,7 +914,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* s udiv x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (and (bvult z (bvnot (bvand (bvneg t) s))) (bvult z t)) * where * z = 0 with getSize(z) = w */ @@ -921,7 +926,7 @@ static Node getScBvUdiv(bool pol, else { /* s udiv x >= t - * true (no side condition) */ + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -933,7 +938,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* x udiv s > t - * with side condition: + * with invertibility condition: * (bvugt (bvudiv ones s) t) * where * ones = ~0 with getSize(ones) = w */ @@ -944,7 +949,7 @@ static Node getScBvUdiv(bool pol, else { /* x udiv s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvor s t) (bvnot (bvneg s))) */ Node u1 = nm->mkNode(BITVECTOR_OR, s, t); Node u2 = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, s)); @@ -956,7 +961,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* s udiv x > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t ones) * where * ones = ~0 with getSize(ones) = w */ @@ -966,7 +971,7 @@ static Node getScBvUdiv(bool pol, else { /* s udiv x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult z (bvor (bvnot s) t)) * where * z = 0 with getSize(z) = w */ @@ -982,7 +987,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* x udiv s < t - * with side condition: + * with invertibility condition: * (=> (bvsle t z) (bvslt (bvudiv min s) t)) * where * z = 0 with getSize(z) = w @@ -996,7 +1001,7 @@ static Node getScBvUdiv(bool pol, else { /* x udiv s >= t - * with side condition: + * with invertibility condition: * (or * (bvsge (bvudiv ones s) t) * (bvsge (bvudiv max s) t)) @@ -1017,7 +1022,7 @@ static Node getScBvUdiv(bool pol, if (pol) { /* s udiv x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt s t) (bvsge t z)) * where * z = 0 with getSize(z) = w */ @@ -1028,12 +1033,12 @@ static Node getScBvUdiv(bool pol, else { /* s udiv x >= t - * with side condition (w > 1): + * with invertibility condition (w > 1): * (and * (=> (bvsge s z) (bvsge s t)) * (=> (bvslt s z) (bvsge (bvlshr s (_ bv1 w)) t))) * - * with side condition (w == 1): + * with invertibility condition (w == 1): * (bvsge s t) * * where @@ -1057,14 +1062,15 @@ static Node getScBvUdiv(bool pol, } } } - else /* litk == BITVECTOR_SGT */ + else { + Assert(litk == BITVECTOR_SGT); if (idx == 0) { if (pol) { /* x udiv s > t - * with side condition: + * with invertibility condition: * (or * (bvsgt (bvudiv ones s) t) * (bvsgt (bvudiv max s) t)) @@ -1082,7 +1088,7 @@ static Node getScBvUdiv(bool pol, else { /* x udiv s <= t - * with side condition (combination of = and <): + * with invertibility condition (combination of = and <): * (or * (= (bvudiv (bvmul s t) s) t) ; eq, synthesized * (=> (bvsle t z) (bvslt (bvudiv min s) t))) ; slt @@ -1105,12 +1111,12 @@ static Node getScBvUdiv(bool pol, if (pol) { /* s udiv x > t - * with side condition (w > 1): + * with invertibility condition (w > 1): * (and * (=> (bvsge s z) (bvsgt s t)) * (=> (bvslt s z) (bvsgt (bvlshr s (_ bv1 w)) t))) * - * with side condition (w == 1): + * with invertibility condition (w == 1): * (bvsgt s t) * * where @@ -1134,7 +1140,7 @@ static Node getScBvUdiv(bool pol, else { /* s udiv x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (and (bvslt t (bvnot #x0)) (bvslt t s))) * <-> * (or (bvsge t ones) (bvsge t s)) @@ -1178,7 +1184,7 @@ static Node getScBvAndOr(bool pol, { /* x & s = t * x | s = t - * with side condition: + * with invertibility condition: * (= (bvand t s) t) * (= (bvor t s) t) */ scl = nm->mkNode(EQUAL, t, nm->mkNode(k, t, s)); @@ -1188,7 +1194,7 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s = t - * with side condition: + * with invertibility condition: * (or (distinct s z) (distinct t z)) * where * z = 0 with getSize(z) = w */ @@ -1198,7 +1204,7 @@ static Node getScBvAndOr(bool pol, else { /* x | s = t - * with side condition: + * with invertibility condition: * (or (distinct s ones) (distinct t ones)) * where * ones = ~0 with getSize(ones) = w */ @@ -1214,7 +1220,7 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (distinct t z) * where * z = 0 with getSize(z) = 0 */ @@ -1224,7 +1230,7 @@ static Node getScBvAndOr(bool pol, else { /* x | s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult s t) */ scl = nm->mkNode(BITVECTOR_ULT, s, t); } @@ -1234,15 +1240,15 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge s t) */ scl = nm->mkNode(BITVECTOR_UGE, s, t); } else { /* x | s >= t - * with side condition (synthesized): - * true (no side condition) */ + * with invertibility condition (synthesized): + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1254,14 +1260,14 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t s) */ scl = nm->mkNode(BITVECTOR_ULT, t, s); } else { /* x | s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t ones) * where * ones = ~0 with getSize(ones) = w */ @@ -1273,14 +1279,14 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s <= t - * with side condition (synthesized): - * true (no side condition) */ + * with invertibility condition (synthesized): + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } else { /* x | s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge t s) */ scl = nm->mkNode(BITVECTOR_UGE, t, s); } @@ -1293,7 +1299,7 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvand (bvnot (bvneg t)) s) t) */ Node nnt = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, t)); scl = nm->mkNode(BITVECTOR_SLT, nm->mkNode(BITVECTOR_AND, nnt, s), t); @@ -1301,7 +1307,7 @@ static Node getScBvAndOr(bool pol, else { /* x | s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvor (bvnot (bvsub s t)) s) t) */ Node st = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_SUB, s, t)); scl = nm->mkNode(BITVECTOR_SLT, nm->mkNode(BITVECTOR_OR, st, s), t); @@ -1312,7 +1318,7 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s >= t - * with side condition (case = combined with synthesized bvsgt): + * with invertibility condition (case = combined with synthesized bvsgt): * (or * (= (bvand s t) t) * (bvslt t (bvand (bvsub t s) s))) */ @@ -1326,7 +1332,7 @@ static Node getScBvAndOr(bool pol, else { /* x | s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge s (bvand s t)) */ scl = nm->mkNode(BITVECTOR_SGE, s, nm->mkNode(BITVECTOR_AND, s, t)); } @@ -1339,7 +1345,7 @@ static Node getScBvAndOr(bool pol, { /* x & s > t * x | s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt t (bvand s max)) * (bvslt t (bvor s max)) * where @@ -1352,7 +1358,7 @@ static Node getScBvAndOr(bool pol, if (k == BITVECTOR_AND) { /* x & s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge s (bvand t min)) * where * min is the signed minimum value with getSize(min) = w */ @@ -1362,7 +1368,7 @@ static Node getScBvAndOr(bool pol, else { /* x | s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge t (bvor s min)) * where * min is the signed minimum value with getSize(min) = w */ @@ -1427,7 +1433,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvlshr (bvshl t s) s) t) */ Node shl = nm->mkNode(BITVECTOR_SHL, t, s); Node lshr = nm->mkNode(BITVECTOR_LSHR, shl, s); @@ -1436,7 +1442,7 @@ static Node getScBvLshr(bool pol, else { /* x >> s != t - * with side condition: + * with invertibility condition: * (or (distinct t z) (bvult s w)) * where * z = 0 with getSize(z) = w @@ -1451,7 +1457,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* s >> x = t - * with side condition: + * with invertibility condition: * (or (= (bvlshr s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(EQUAL, BITVECTOR_LSHR, s, t); @@ -1459,7 +1465,7 @@ static Node getScBvLshr(bool pol, else { /* s >> x != t - * with side condition: + * with invertibility condition: * (or (distinct s z) (distinct t z)) * where * z = 0 with getSize(z) = w */ @@ -1474,7 +1480,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -1483,7 +1489,7 @@ static Node getScBvLshr(bool pol, else { /* x >> s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvlshr (bvshl t s) s) t) */ Node ts = nm->mkNode(BITVECTOR_SHL, t, s); scl = nm->mkNode(BITVECTOR_LSHR, ts, s).eqNode(t); @@ -1494,7 +1500,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* s >> x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -1503,7 +1509,7 @@ static Node getScBvLshr(bool pol, else { /* s >> x >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge s t) */ scl = nm->mkNode(BITVECTOR_UGE, s, t); } @@ -1516,7 +1522,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t (bvlshr (bvnot s) s)) */ Node lshr = nm->mkNode(BITVECTOR_LSHR, nm->mkNode(BITVECTOR_NOT, s), s); scl = nm->mkNode(BITVECTOR_ULT, t, lshr); @@ -1524,8 +1530,8 @@ static Node getScBvLshr(bool pol, else { /* x >> s <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1534,15 +1540,15 @@ static Node getScBvLshr(bool pol, if (pol) { /* s >> x > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t s) */ scl = nm->mkNode(BITVECTOR_ULT, t, s); } else { /* s >> x <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1554,7 +1560,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvlshr (bvnot (bvneg t)) s) t) */ Node nnt = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, t)); Node lshr = nm->mkNode(BITVECTOR_LSHR, nnt, s); @@ -1563,7 +1569,7 @@ static Node getScBvLshr(bool pol, else { /* x >> s >= t - * with side condition: + * with invertibility condition: * (=> (not (= s z)) (bvsge (bvlshr ones s) t)) * where * z = 0 with getSize(z) = w @@ -1579,7 +1585,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* s >> x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt s t) (bvslt z t)) * where * z = 0 with getSize(z) = w */ @@ -1590,7 +1596,7 @@ static Node getScBvLshr(bool pol, else { /* s >> x >= t - * with side condition: + * with invertibility condition: * (and * (=> (bvslt s z) (bvsge (bvlshr s (_ bv1 w)) t)) * (=> (bvsge s z) (bvsge s t))) @@ -1613,7 +1619,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt t (bvlshr (bvshl max s) s)) * where * max is the signed maximum value with getSize(max) = w */ @@ -1625,7 +1631,7 @@ static Node getScBvLshr(bool pol, else { /* x >> s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge t (bvlshr t s)) */ scl = nm->mkNode(BITVECTOR_SGE, t, nm->mkNode(BITVECTOR_LSHR, t, s)); } @@ -1635,7 +1641,7 @@ static Node getScBvLshr(bool pol, if (pol) { /* s >> x > t - * with side condition: + * with invertibility condition: * (and * (=> (bvslt s z) (bvsgt (bvlshr s one) t)) * (=> (bvsge s z) (bvsgt s t))) @@ -1651,7 +1657,7 @@ static Node getScBvLshr(bool pol, else { /* s >> x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvult t min) (bvsge t s)) * where * min is the minimum signed value with getSize(min) = w */ @@ -1695,7 +1701,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* x >> s = t - * with side condition: + * with invertibility condition: * (and * (=> (bvult s w) (= (bvashr (bvshl t s) s) t)) * (=> (bvuge s w) (or (= t ones) (= t z))) @@ -1717,7 +1723,7 @@ static Node getScBvAshr(bool pol, else { /* x >> s != t - * true (no side condition) */ + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1726,7 +1732,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* s >> x = t - * with side condition: + * with invertibility condition: * (or (= (bvashr s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(EQUAL, BITVECTOR_ASHR, s, t); @@ -1734,7 +1740,7 @@ static Node getScBvAshr(bool pol, else { /* s >> x != t - * with side condition: + * with invertibility condition: * (and * (or (not (= t z)) (not (= s z))) * (or (not (= t ones)) (not (= s ones)))) @@ -1754,7 +1760,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* x >> s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (distinct t z) * where * z = 0 with getSize(z) = w */ @@ -1763,8 +1769,8 @@ static Node getScBvAshr(bool pol, else { /* x >> s >= t - * with side condition (synthesized): - * true (no side condition) */ + * with invertibility condition (synthesized): + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1773,7 +1779,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* s >> x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (and (not (and (bvuge s t) (bvslt s z))) (not (= t z))) * where * z = 0 with getSize(z) = w */ @@ -1785,7 +1791,7 @@ static Node getScBvAshr(bool pol, else { /* s >> x >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (and (bvult s (bvnot s)) (bvult s t))) */ Node ss = nm->mkNode(BITVECTOR_ULT, s, nm->mkNode(BITVECTOR_NOT, s)); Node st = nm->mkNode(BITVECTOR_ULT, s, t); @@ -1800,7 +1806,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* x >> s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t ones) * where * ones = ~0 with getSize(ones) = w */ @@ -1809,8 +1815,8 @@ static Node getScBvAshr(bool pol, else { /* x >> s <= t - * with side condition (synthesized): - * true (no side condition) */ + * with invertibility condition (synthesized): + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -1819,7 +1825,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* s >> x > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt s (bvlshr s (bvnot t))) (bvult t s)) */ Node lshr = nm->mkNode(BITVECTOR_LSHR, s, nm->mkNode(BITVECTOR_NOT, t)); Node ts = nm->mkNode(BITVECTOR_ULT, t, s); @@ -1829,7 +1835,7 @@ static Node getScBvAshr(bool pol, else { /* s >> x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvult s min) (bvuge t s)) * where * min is the minimum signed value with getSize(min) = w */ @@ -1847,7 +1853,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* x >> s < t - * with side condition: + * with invertibility condition: * (bvslt (bvashr min s) t) * where * min is the minimum signed value with getSize(min) = w */ @@ -1857,7 +1863,7 @@ static Node getScBvAshr(bool pol, else { /* x >> s >= t - * with side condition: + * with invertibility condition: * (bvsge (bvlshr max s) t) * where * max is the signed maximum value with getSize(max) = w */ @@ -1870,7 +1876,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* s >> x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvslt s t) (bvslt z t)) * where * z = 0 and getSize(z) = w */ @@ -1881,7 +1887,7 @@ static Node getScBvAshr(bool pol, else { /* s >> x >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (and (bvult t (bvnot t)) (bvslt s t))) */ Node tt = nm->mkNode(BITVECTOR_ULT, t, nm->mkNode(BITVECTOR_NOT, t)); Node st = nm->mkNode(BITVECTOR_SLT, s, t); @@ -1899,7 +1905,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* x >> s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt t (bvlshr max s))) * where * max is the signed maximum value with getSize(max) = w */ @@ -1908,7 +1914,7 @@ static Node getScBvAshr(bool pol, else { /* x >> s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge t (bvnot (bvlshr max s))) * where * max is the signed maximum value with getSize(max) = w */ @@ -1920,7 +1926,7 @@ static Node getScBvAshr(bool pol, if (pol) { /* s >> x > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (and (bvslt t (bvand s max)) (bvslt t (bvor s max))) * where * max is the signed maximum value with getSize(max) = w */ @@ -1933,7 +1939,7 @@ static Node getScBvAshr(bool pol, else { /* s >> x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (or (bvsge t z) (bvsge t s)) * where * z = 0 and getSize(z) = w */ @@ -1977,7 +1983,7 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s = t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (= (bvshl (bvlshr t s) s) t) */ Node lshr = nm->mkNode(BITVECTOR_LSHR, t, s); Node shl = nm->mkNode(BITVECTOR_SHL, lshr, s); @@ -1986,7 +1992,7 @@ static Node getScBvShl(bool pol, else { /* x << s != t - * with side condition: + * with invertibility condition: * (or (distinct t z) (bvult s w)) * with * w = getSize(s) = getSize(t) @@ -2001,7 +2007,7 @@ static Node getScBvShl(bool pol, if (pol) { /* s << x = t - * with side condition: + * with invertibility condition: * (or (= (bvshl s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(EQUAL, BITVECTOR_SHL, s, t); @@ -2009,7 +2015,7 @@ static Node getScBvShl(bool pol, else { /* s << x != t - * with side condition: + * with invertibility condition: * (or (distinct s z) (distinct t z)) * where * z = 0 with getSize(z) = w */ @@ -2024,14 +2030,14 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (= t z)) */ scl = t.eqNode(z).notNode(); } else { /* x << s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvuge (bvshl ones s) t) */ Node shl = nm->mkNode(BITVECTOR_SHL, bv::utils::mkOnes(w), s); scl = nm->mkNode(BITVECTOR_UGE, shl, t); @@ -2042,14 +2048,14 @@ static Node getScBvShl(bool pol, if (pol) { /* s << x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (not (= t z)) */ scl = t.eqNode(z).notNode(); } else { /* s << x >= t - * with side condition: + * with invertibility condition: * (or (bvuge (bvshl s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(BITVECTOR_UGE, BITVECTOR_SHL, s, t); @@ -2063,7 +2069,7 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult t (bvshl ones s)) * where * ones = ~0 with getSize(ones) = w */ @@ -2073,8 +2079,8 @@ static Node getScBvShl(bool pol, else { /* x << s <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -2083,7 +2089,7 @@ static Node getScBvShl(bool pol, if (pol) { /* s << x > t - * with side condition: + * with invertibility condition: * (or (bvugt (bvshl s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(BITVECTOR_UGT, BITVECTOR_SHL, s, t); @@ -2091,8 +2097,8 @@ static Node getScBvShl(bool pol, else { /* s << x <= t - * with side condition: - * true (no side condition) */ + * with invertibility condition: + * true (no invertibility condition) */ scl = nm->mkConst<bool>(true); } } @@ -2104,7 +2110,7 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt (bvshl (bvlshr min s) s) t) * where * min is the signed minimum value with getSize(min) = w */ @@ -2116,7 +2122,7 @@ static Node getScBvShl(bool pol, else { /* x << s >= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvsge (bvand (bvshl max s) max) t) * where * max is the signed maximum value with getSize(max) = w */ @@ -2130,7 +2136,7 @@ static Node getScBvShl(bool pol, if (pol) { /* s << x < t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult (bvshl min s) (bvadd t min)) * where * min is the signed minimum value with getSize(min) = w */ @@ -2142,7 +2148,7 @@ static Node getScBvShl(bool pol, else { /* s << x >= t - * with side condition: + * with invertibility condition: * (or (bvsge (bvshl s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(BITVECTOR_SGE, BITVECTOR_SHL, s, t); @@ -2157,7 +2163,7 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s > t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvslt t (bvand (bvshl max s) max)) * where * max is the signed maximum value with getSize(max) = w */ @@ -2168,7 +2174,7 @@ static Node getScBvShl(bool pol, else { /* x << s <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult (bvlshr t (bvlshr t s)) min) * where * min is the signed minimum value with getSize(min) = w */ @@ -2182,7 +2188,7 @@ static Node getScBvShl(bool pol, if (pol) { /* s << x > t - * with side condition: + * with invertibility condition: * (or (bvsgt (bvshl s i) t) ...) * for i in 0..w */ scl = defaultShiftSc(BITVECTOR_SGT, BITVECTOR_SHL, s, t); @@ -2190,7 +2196,7 @@ static Node getScBvShl(bool pol, else { /* s << x <= t - * with side condition (synthesized): + * with invertibility condition (synthesized): * (bvult (bvlshr t s) min) * where * min is the signed minimum value with getSize(min) = w */ @@ -2206,6 +2212,609 @@ static Node getScBvShl(bool pol, return sc; } +static Node getScBvConcat(bool pol, + Kind litk, + unsigned idx, + Node x, + Node sv_t, + Node t) +{ + Assert(litk == EQUAL + || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT + || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT); + + Kind k = sv_t.getKind(); + Assert(k == BITVECTOR_CONCAT); + NodeManager* nm = NodeManager::currentNM(); + unsigned nchildren = sv_t.getNumChildren(); + unsigned w1 = 0, w2 = 0; + unsigned w = bv::utils::getSize(t), wx = bv::utils::getSize(x); + NodeBuilder<> nbs1(BITVECTOR_CONCAT), nbs2(BITVECTOR_CONCAT); + Node s1, s2; + Node t1, t2, tx; + Node scl, scr; + + if (idx != 0) + { + if (idx == 1) + { + s1 = sv_t[0]; + } + else + { + for (unsigned i = 0; i < idx; ++i) { nbs1 << sv_t[i]; } + s1 = nbs1.constructNode(); + } + w1 = bv::utils::getSize(s1); + t1 = bv::utils::mkExtract(t, w - 1, w - w1); + } + + tx = bv::utils::mkExtract(t, w - w1 - 1 , w - w1 - wx); + + if (idx != nchildren-1) + { + if (idx == nchildren-2) + { + s2 = sv_t[nchildren-1]; + } + else + { + for (unsigned i = idx+1; i < nchildren; ++i) { nbs2 << sv_t[i]; } + s2 = nbs2.constructNode(); + } + w2 = bv::utils::getSize(s2); + Assert(w2 == w - w1 - wx); + t2 = bv::utils::mkExtract(t, w2 - 1, 0); + } + + Assert(!s1.isNull() || t1.isNull()); + Assert(!s2.isNull() || t2.isNull()); + Assert(!s1.isNull() || !s2.isNull()); + Assert(s1.isNull() || w1 == bv::utils::getSize(t1)); + Assert(s2.isNull() || w2 == bv::utils::getSize(t2)); + + if (litk == EQUAL) + { + if (s1.isNull()) + { + if (pol) + { + /* x o s2 = t (interpret t as tx o t2) + * with invertibility condition: + * (= s2 t2) */ + scl = s2.eqNode(t2); + } + else + { + /* x o s2 != t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + else if (s2.isNull()) + { + if (pol) + { + /* s1 o x = t (interpret t as t1 o tx) + * with invertibility condition: + * (= s1 t1) */ + scl = s1.eqNode(t1); + } + else + { + /* s1 o x != t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + else + { + if (pol) + { + /* s1 o x o s2 = t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and (= s1 t1) (= s2 t2)) */ + scl = nm->mkNode(AND, s1.eqNode(t1), s2.eqNode(t2)); + } + else + { + /* s1 o x o s2 != t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + } + else if (litk == BITVECTOR_ULT) + { + if (s1.isNull()) + { + if (pol) + { + /* x o s2 < t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx z) (bvult s2 t2)) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node ult = nm->mkNode(BITVECTOR_ULT, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(z), ult); + } + else + { + /* x o s2 >= t (interpret t as tx o t2) + * (=> (= tx ones) (bvuge s2 t2)) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node uge = nm->mkNode(BITVECTOR_UGE, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(n), uge); + } + } + else if (s2.isNull()) + { + if (pol) + { + /* s1 o x < t (interpret t as t1 o tx) + * with invertibility condition: + * (and (bvule s1 t1) (=> (= s1 t1) (distinct tx z))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node ule = nm->mkNode(BITVECTOR_ULE, s1, t1); + Node imp = nm->mkNode(IMPLIES, s1.eqNode(t1), tx.eqNode(z).notNode()); + scl = nm->mkNode(AND, ule, imp); + } + else + { + /* s1 o x >= t (interpret t as t1 o tx) + * with invertibility condition: + * (bvuge s1 t1) */ + scl = nm->mkNode(BITVECTOR_UGE, s1, t1); + } + } + else + { + if (pol) + { + /* s1 o x o s2 < t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvule s1 t1) + * (=> (and (= s1 t1) (= tx z)) (bvult s2 t2))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node ule = nm->mkNode(BITVECTOR_ULE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(z)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_ULT, s2, t2)); + scl = nm->mkNode(AND, ule, imp); + } + else + { + /* s1 o x o s2 >= t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvuge s1 t1) + * (=> (and (= s1 t1) (= tx ones)) (bvuge s2 t2))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node uge = nm->mkNode(BITVECTOR_UGE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(n)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_UGE, s2, t2)); + scl = nm->mkNode(AND, uge, imp); + } + } + } + else if (litk == BITVECTOR_UGT) + { + if (s1.isNull()) + { + if (pol) + { + /* x o s2 > t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx ones) (bvugt s2 t2)) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node ugt = nm->mkNode(BITVECTOR_UGT, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(n), ugt); + } + else + { + /* x o s2 <= t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx z) (bvule s2 t2)) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node ule = nm->mkNode(BITVECTOR_ULE, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(z), ule); + } + } + else if (s2.isNull()) + { + if (pol) + { + /* s1 o x > t (interpret t as t1 o tx) + * with invertibility condition: + * (and (bvuge s1 t1) (=> (= s1 t1) (distinct tx ones))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node uge = nm->mkNode(BITVECTOR_UGE, s1, t1); + Node imp = nm->mkNode(IMPLIES, s1.eqNode(t1), tx.eqNode(n).notNode()); + scl = nm->mkNode(AND, uge, imp); + } + else + { + /* s1 o x <= t (interpret t as t1 o tx) + * with invertibility condition: + * (bvule s1 t1) */ + scl = nm->mkNode(BITVECTOR_ULE, s1, t1); + } + } + else + { + if (pol) + { + /* s1 o x o s2 > t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvuge s1 t1) + * (=> (and (= s1 t1) (= tx ones)) (bvugt s2 t2))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node uge = nm->mkNode(BITVECTOR_UGE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(n)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_UGT, s2, t2)); + scl = nm->mkNode(AND, uge, imp); + } + else + { + /* s1 o x o s2 <= t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvule s1 t1) + * (=> (and (= s1 t1) (= tx z)) (bvule s2 t2))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node ule = nm->mkNode(BITVECTOR_ULE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(z)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_ULE, s2, t2)); + scl = nm->mkNode(AND, ule, imp); + } + } + } + else if (litk == BITVECTOR_SLT) + { + if (s1.isNull()) + { + if (pol) + { + /* x o s2 < t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx min) (bvult s2 t2)) + * where + * min is the signed minimum value with getSize(min) = wx */ + Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(wx)); + Node ult = nm->mkNode(BITVECTOR_ULT, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(min), ult); + } + else + { + /* x o s2 >= t (interpret t as tx o t2) + * (=> (= tx max) (bvuge s2 t2)) + * where + * max is the signed maximum value with getSize(max) = wx */ + Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(wx)); + Node uge = nm->mkNode(BITVECTOR_UGE, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(max), uge); + } + } + else if (s2.isNull()) + { + if (pol) + { + /* s1 o x < t (interpret t as t1 o tx) + * with invertibility condition: + * (and (bvsle s1 t1) (=> (= s1 t1) (distinct tx z))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node sle = nm->mkNode(BITVECTOR_SLE, s1, t1); + Node imp = nm->mkNode(IMPLIES, s1.eqNode(t1), tx.eqNode(z).notNode()); + scl = nm->mkNode(AND, sle, imp); + } + else + { + /* s1 o x >= t (interpret t as t1 o tx) + * with invertibility condition: + * (bvsge s1 t1) */ + scl = nm->mkNode(BITVECTOR_SGE, s1, t1); + } + } + else + { + if (pol) + { + /* s1 o x o s2 < t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvsle s1 t1) + * (=> (and (= s1 t1) (= tx z)) (bvult s2 t2))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node sle = nm->mkNode(BITVECTOR_SLE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(z)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_ULT, s2, t2)); + scl = nm->mkNode(AND, sle, imp); + } + else + { + /* s1 o x o s2 >= t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvsge s1 t1) + * (=> (and (= s1 t1) (= tx ones)) (bvuge s2 t2))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node sge = nm->mkNode(BITVECTOR_SGE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(n)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_UGE, s2, t2)); + scl = nm->mkNode(AND, sge, imp); + } + } + } + else + { + Assert(litk == BITVECTOR_SGT); + if (s1.isNull()) + { + if (pol) + { + /* x o s2 > t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx max) (bvugt s2 t2)) + * where + * max is the signed maximum value with getSize(max) = wx */ + Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(wx)); + Node ugt = nm->mkNode(BITVECTOR_UGT, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(max), ugt); + } + else + { + /* x o s2 <= t (interpret t as tx o t2) + * with invertibility condition: + * (=> (= tx min) (bvule s2 t2)) + * where + * min is the signed minimum value with getSize(min) = wx */ + Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(wx)); + Node ule = nm->mkNode(BITVECTOR_ULE, s2, t2); + scl = nm->mkNode(IMPLIES, tx.eqNode(min), ule); + } + } + else if (s2.isNull()) + { + if (pol) + { + /* s1 o x > t (interpret t as t1 o tx) + * with invertibility condition: + * (and (bvsge s1 t1) (=> (= s1 t1) (distinct tx ones))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node sge = nm->mkNode(BITVECTOR_SGE, s1, t1); + Node imp = nm->mkNode(IMPLIES, s1.eqNode(t1), tx.eqNode(n).notNode()); + scl = nm->mkNode(AND, sge, imp); + } + else + { + /* s1 o x <= t (interpret t as t1 o tx) + * with invertibility condition: + * (bvsle s1 t1) */ + scl = nm->mkNode(BITVECTOR_SLE, s1, t1); + } + } + else + { + if (pol) + { + /* s1 o x o s2 > t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvsge s1 t1) + * (=> (and (= s1 t1) (= tx ones)) (bvugt s2 t2))) + * where + * ones = ~0 with getSize(ones) = wx */ + Node n = bv::utils::mkOnes(wx); + Node sge = nm->mkNode(BITVECTOR_SGE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(n)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_UGT, s2, t2)); + scl = nm->mkNode(AND, sge, imp); + } + else + { + /* s1 o x o s2 <= t (interpret t as t1 o tx o t2) + * with invertibility condition: + * (and + * (bvsle s1 t1) + * (=> (and (= s1 t1) (= tx z)) (bvule s2 t2))) + * where + * z = 0 with getSize(z) = wx */ + Node z = bv::utils::mkZero(wx); + Node sle = nm->mkNode(BITVECTOR_SLE, s1, t1); + Node a = nm->mkNode(AND, s1.eqNode(t1), tx.eqNode(z)); + Node imp = nm->mkNode(IMPLIES, a, nm->mkNode(BITVECTOR_ULE, s2, t2)); + scl = nm->mkNode(AND, sle, imp); + } + } + } + scr = s1.isNull() ? x : bv::utils::mkConcat(s1, x); + if (!s2.isNull()) scr = bv::utils::mkConcat(scr, s2); + scr = nm->mkNode(litk, scr, t); + Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode()); + Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl; + return sc; +} + +static Node getScBvSext(bool pol, + Kind litk, + unsigned idx, + Node x, + Node sv_t, + Node t) +{ + Assert(litk == EQUAL + || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT + || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT); + + NodeManager* nm = NodeManager::currentNM(); + Node scl; + Assert(idx == 0); + (void) idx; + unsigned ws = bv::utils::getSignExtendAmount(sv_t); + unsigned w = bv::utils::getSize(t); + + if (litk == EQUAL) + { + if (pol) + { + /* x sext ws = t + * with invertibility condition: + * (or (= ((_ extract u l) t) z) + * (= ((_ extract u l) t) ones)) + * where + * u = w - 1 + * l = w - 1 - ws + * z = 0 with getSize(z) = ws + 1 + * ones = ~0 with getSize(ones) = ws + 1 */ + Node ext = bv::utils::mkExtract(t, w - 1, w - 1 - ws); + Node z = bv::utils::mkZero(ws + 1); + Node n = bv::utils::mkOnes(ws + 1); + scl = nm->mkNode(OR, ext.eqNode(z), ext.eqNode(n)); + } + else + { + /* x sext ws != t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + else if (litk == BITVECTOR_ULT) + { + if (pol) + { + /* x sext ws < t + * with invertibility condition: + * (distinct t z) + * where + * z = 0 with getSize(z) = w */ + Node z = bv::utils::mkZero(w); + scl = t.eqNode(z).notNode(); + } + else + { + /* x sext ws >= t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + else if (litk == BITVECTOR_UGT) + { + if (pol) + { + /* x sext ws > t + * with invertibility condition: + * (distinct t ones) + * where + * ones = ~0 with getSize(ones) = w */ + Node n = bv::utils::mkOnes(w); + scl = t.eqNode(n).notNode(); + } + else + { + /* x sext ws <= t + * true (no invertibility condition) */ + scl = nm->mkConst<bool>(true); + } + } + else if (litk == BITVECTOR_SLT) + { + if (pol) + { + /* x sext ws < t + * with invertibility condition: + * (bvslt ((_ sign_extend ws) min) t) + * where + * min is the signed minimum value with getSize(min) = w - ws */ + Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w - ws)); + Node ext = bv::utils::mkSignExtend(min, ws); + scl = nm->mkNode(BITVECTOR_SLT, ext, t); + } + else + { + /* x sext ws >= t + * with invertibility condition (combination of sgt and eq): + * + * (or + * (or (= ((_ extract u l) t) z) ; eq + * (= ((_ extract u l) t) ones)) + * (bvslt t ((_ zero_extend ws) max))) ; sgt + * where + * u = w - 1 + * l = w - 1 - ws + * z = 0 with getSize(z) = ws + 1 + * ones = ~0 with getSize(ones) = ws + 1 + * max is the signed maximum value with getSize(max) = w - ws */ + Node ext1 = bv::utils::mkExtract(t, w - 1, w - 1 - ws); + Node z = bv::utils::mkZero(ws + 1); + Node n = bv::utils::mkOnes(ws + 1); + Node o1 = nm->mkNode(OR, ext1.eqNode(z), ext1.eqNode(n)); + Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w - ws)); + Node ext2 = bv::utils::mkConcat(bv::utils::mkZero(ws), max); + Node o2 = nm->mkNode(BITVECTOR_SLT, t, ext2); + scl = nm->mkNode(OR, o1, o2); + } + } + else + { + Assert(litk == BITVECTOR_SGT); + if (pol) + { + /* x sext ws > t + * with invertibility condition: + * (bvslt t ((_ zero_extend ws) max)) + * where + * max is the signed maximum value with getSize(max) = w - ws */ + Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w - ws)); + Node ext = bv::utils::mkConcat(bv::utils::mkZero(ws), max); + scl = nm->mkNode(BITVECTOR_SLT, t, ext); + } + else + { + /* x sext ws <= t + * with invertibility condition: + * (bvsge t (bvnot ((_ zero_extend ws) max))) + * where + * max is the signed maximum value with getSize(max) = w - ws */ + Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w - ws)); + Node ext = bv::utils::mkConcat(bv::utils::mkZero(ws), max); + scl = nm->mkNode(BITVECTOR_SGE, t, nm->mkNode(BITVECTOR_NOT, ext)); + } + } + Node scr = nm->mkNode(litk, bv::utils::mkSignExtend(x, ws), t); + Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode()); + Trace("bv-invert") << "Add SC_" << BITVECTOR_SIGN_EXTEND << "(" << x + << "): " << sc << std::endl; + return sc; +} + Node BvInverter::solveBvLit(Node sv, Node lit, std::vector<unsigned>& path, @@ -2214,7 +2823,7 @@ Node BvInverter::solveBvLit(Node sv, Assert(!path.empty()); bool pol = true; - unsigned index, nchildren; + unsigned index; NodeManager* nm = NodeManager::currentNM(); Kind k, litk; @@ -2260,167 +2869,168 @@ Node BvInverter::solveBvLit(Node sv, while (!path.empty()) { + unsigned nchildren = sv_t.getNumChildren(); + Assert(nchildren > 0); index = path.back(); - Assert(index < sv_t.getNumChildren()); + Assert(index < nchildren); path.pop_back(); k = sv_t.getKind(); - nchildren = sv_t.getNumChildren(); - if (k == BITVECTOR_NOT || k == BITVECTOR_NEG) + /* Note: All n-ary kinds except for CONCAT (i.e., BITVECTOR_AND, + * BITVECTOR_OR, MULT, PLUS) are commutative (no case split + * based on index). */ + Node s = dropChild(sv_t, index); + Assert((nchildren == 1 && s.isNull()) || (nchildren > 1 && !s.isNull())); + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node sc; + + if (litk == EQUAL && (k == BITVECTOR_NOT || k == BITVECTOR_NEG)) { t = nm->mkNode(k, t); } + else if (litk == EQUAL && k == BITVECTOR_PLUS) + { + t = nm->mkNode(BITVECTOR_SUB, t, s); + } + else if (litk == EQUAL && k == BITVECTOR_XOR) + { + t = nm->mkNode(BITVECTOR_XOR, t, s); + } + else if (litk == EQUAL && k == BITVECTOR_MULT + && s.isConst() && bv::utils::getBit(s, 0)) + { + unsigned w = bv::utils::getSize(s); + Integer s_val = s.getConst<BitVector>().toInteger(); + Integer mod_val = Integer(1).multiplyByPow2(w); + Trace("bv-invert-debug") + << "Compute inverse : " << s_val << " " << mod_val << std::endl; + Integer inv_val = s_val.modInverse(mod_val); + Trace("bv-invert-debug") << "Inverse : " << inv_val << std::endl; + Node inv = bv::utils::mkConst(w, inv_val); + t = nm->mkNode(BITVECTOR_MULT, inv, t); + } + else if (k == BITVECTOR_MULT) + { + sc = getScBvMult(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_SHL) + { + sc = getScBvShl(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_UREM_TOTAL) + { + sc = getScBvUrem(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_UDIV_TOTAL) + { + sc = getScBvUdiv(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_AND || k == BITVECTOR_OR) + { + sc = getScBvAndOr(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_LSHR) + { + sc = getScBvLshr(pol, litk, k, index, x, s, t); + } + else if (k == BITVECTOR_ASHR) + { + sc = getScBvAshr(pol, litk, k, index, x, s, t); + } else if (k == BITVECTOR_CONCAT) { - /* x = t[upper:lower] - * where - * upper = getSize(t) - 1 - sum(getSize(sv_t[i])) for i < index - * lower = getSize(sv_t[i]) for i > index */ - unsigned upper, lower; - upper = bv::utils::getSize(t) - 1; - lower = 0; - NodeBuilder<> nb(nm, BITVECTOR_CONCAT); - for (unsigned i = 0; i < nchildren; i++) + if (litk == EQUAL && options::cbqiBvConcInv()) + { + /* Compute inverse for s1 o x, x o s2, s1 o x o s2 + * (while disregarding that invertibility depends on si) + * rather than an invertibility condition (the proper handling). + * This improves performance on a considerable number of benchmarks. + * + * x = t[upper:lower] + * where + * upper = getSize(t) - 1 - sum(getSize(sv_t[i])) for i < index + * lower = getSize(sv_t[i]) for i > index */ + unsigned upper, lower; + upper = bv::utils::getSize(t) - 1; + lower = 0; + NodeBuilder<> nb(BITVECTOR_CONCAT); + for (unsigned i = 0; i < nchildren; i++) + { + if (i < index) { upper -= bv::utils::getSize(sv_t[i]); } + else if (i > index) { lower += bv::utils::getSize(sv_t[i]); } + } + t = bv::utils::mkExtract(t, upper, lower); + } + else { - if (i < index) { upper -= bv::utils::getSize(sv_t[i]); } - else if (i > index) { lower += bv::utils::getSize(sv_t[i]); } + sc = getScBvConcat(pol, litk, index, x, sv_t, t); } - t = bv::utils::mkExtract(t, upper, lower); } else if (k == BITVECTOR_SIGN_EXTEND) { - t = bv::utils::mkExtract(t, bv::utils::getSize(sv_t[index]) - 1, 0); + sc = getScBvSext(pol, litk, index, x, sv_t, t); } - else if (k == BITVECTOR_EXTRACT || k == BITVECTOR_COMP) + else if (litk == BITVECTOR_ULT || litk == BITVECTOR_UGT) { - Trace("bv-invert") << "bv-invert : Unsupported for index " << index - << ", from " << sv_t << std::endl; - return Node::null(); + sc = getScBvUltUgt(pol, litk, x, t); + } + else if (litk == BITVECTOR_SLT || litk == BITVECTOR_SGT) + { + sc = getScBvSltSgt(pol, litk, x, t); + } + else if (pol == false) + { + Assert (litk == EQUAL); + sc = nm->mkNode(DISTINCT, x, t); + Trace("bv-invert") << "Add SC_" << litk << "(" << x << "): " << sc + << std::endl; } else { - Assert(nchildren >= 2); - Node s = nchildren == 2 ? sv_t[1 - index] : dropChild(sv_t, index); - Node t_new; - /* Note: All n-ary kinds except for CONCAT (i.e., AND, OR, MULT, PLUS) - * are commutative (no case split based on index). */ - - // handle cases where the inversion has a unique solution - if (k == BITVECTOR_PLUS) - { - t_new = nm->mkNode(BITVECTOR_SUB, t, s); - } - else if (k == BITVECTOR_XOR) - { - t_new = nm->mkNode(BITVECTOR_XOR, t, s); - } - else if (k == BITVECTOR_MULT && s.isConst() && bv::utils::getBit(s, 0)) - { - unsigned w = bv::utils::getSize(s); - Integer s_val = s.getConst<BitVector>().toInteger(); - Integer mod_val = Integer(1).multiplyByPow2(w); - Trace("bv-invert-debug") - << "Compute inverse : " << s_val << " " << mod_val << std::endl; - Integer inv_val = s_val.modInverse(mod_val); - Trace("bv-invert-debug") << "Inverse : " << inv_val << std::endl; - Node inv = bv::utils::mkConst(w, inv_val); - t_new = nm->mkNode(BITVECTOR_MULT, inv, t); - } - - if (!t_new.isNull()) - { - // In this case, s op x = t is equivalent to x = t_new - t = t_new; - } - else - { - TypeNode solve_tn = sv_t[index].getType(); - Node sc; + Trace("bv-invert") << "bv-invert : Unknown kind " << k + << " for bit-vector term " << sv_t << std::endl; + return Node::null(); + } - switch (k) - { - case BITVECTOR_MULT: - sc = getScBvMult( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_SHL: - sc = getScBvShl( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_UREM_TOTAL: - sc = getScBvUrem( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_UDIV_TOTAL: - sc = getScBvUdiv( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_AND: - case BITVECTOR_OR: - sc = getScBvAndOr( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_LSHR: - sc = getScBvLshr( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - case BITVECTOR_ASHR: - sc = getScBvAshr( - pol, litk, k, index, getSolveVariable(solve_tn), s, t); - break; - - default: - Trace("bv-invert") << "bv-invert : Unknown kind " << k - << " for bit-vector term " << sv_t << std::endl; - return Node::null(); - } - Assert(!sc.isNull()); - /* We generate a choice term (choice x0. SC => x0 <k> s <litk> t) for - * x <k> s <litk> t. When traversing down, this choice term determines - * the value for x <k> s = (choice x0. SC => x0 <k> s <litk> t), i.e., - * from here on, the propagated literal is a positive equality. */ - litk = EQUAL; - pol = true; - /* t = fresh skolem constant */ - t = getInversionNode(sc, solve_tn, m); - if (t.isNull()) - { - return t; - } - } + if (!sc.isNull()) + { + /* We generate a choice term (choice x0. SC => x0 <k> s <litk> t) for + * x <k> s <litk> t. When traversing down, this choice term determines + * the value for x <k> s = (choice x0. SC => x0 <k> s <litk> t), i.e., + * from here on, the propagated literal is a positive equality. */ + litk = EQUAL; + pol = true; + /* t = fresh skolem constant */ + t = getInversionNode(sc, solve_tn, m); + if (t.isNull()) { return t; } } + sv_t = sv_t[index]; } + + /* Base case */ Assert(sv_t == sv); + TypeNode solve_tn = sv.getType(); + Node x = getSolveVariable(solve_tn); + Node sc; if (litk == BITVECTOR_ULT || litk == BITVECTOR_UGT) { - TypeNode solve_tn = sv_t.getType(); - Node sc = getScBvUltUgt(pol, litk, getSolveVariable(solve_tn), t); - t = getInversionNode(sc, solve_tn, m); + sc = getScBvUltUgt(pol, litk, x, t); } else if (litk == BITVECTOR_SLT || litk == BITVECTOR_SGT) { - TypeNode solve_tn = sv_t.getType(); - Node sc = getScBvSltSgt(pol, litk, getSolveVariable(solve_tn), t); - t = getInversionNode(sc, solve_tn, m); + sc = getScBvSltSgt(pol, litk, x, t); } else if (pol == false) { Assert (litk == EQUAL); - TypeNode solve_tn = sv_t.getType(); - Node x = getSolveVariable(solve_tn); - Node sc = nm->mkNode(DISTINCT, x, t); + sc = nm->mkNode(DISTINCT, x, t); Trace("bv-invert") << "Add SC_" << litk << "(" << x << "): " << sc << std::endl; - t = getInversionNode(sc, solve_tn, m); } - return t; + + return sc.isNull() ? t : getInversionNode(sc, solve_tn, m); } /*---------------------------------------------------------------------------*/ diff --git a/src/theory/quantifiers/bv_inverter.h b/src/theory/quantifiers/bv_inverter.h index 470c3a71f..10ef6ab4c 100644 --- a/src/theory/quantifiers/bv_inverter.h +++ b/src/theory/quantifiers/bv_inverter.h @@ -2,9 +2,9 @@ /*! \file bv_inverter.h ** \verbatim ** Top contributors (to current version): - ** Andrew Reynolds + ** Mathias Preiner, Andrew Reynolds, Aina Niemetz ** This file is part of the CVC4 project. - ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS ** in the top-level source directory) and their institutional affiliations. ** All rights reserved. See the file COPYING in the top-level source ** directory for licensing information.\endverbatim diff --git a/src/theory/quantifiers/ce_guided_conjecture.cpp b/src/theory/quantifiers/ce_guided_conjecture.cpp index 378b26eef..889a80879 100644 --- a/src/theory/quantifiers/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/ce_guided_conjecture.cpp @@ -112,6 +112,15 @@ void CegConjecture::assign( Node q ) { d_base_inst = Rewriter::rewrite(d_qe->getInstantiate()->getInstantiation( d_embed_quant, vars, d_candidates)); Trace("cegqi") << "Base instantiation is : " << d_base_inst << std::endl; + d_base_body = d_base_inst; + if (d_base_body.getKind() == NOT && d_base_body[0].getKind() == FORALL) + { + for (const Node& v : d_base_body[0][0]) + { + d_base_vars.push_back(v); + } + d_base_body = d_base_body[0][1]; + } // register this term with sygus database and other utilities that impact // the enumerative sygus search @@ -182,7 +191,16 @@ void CegConjecture::assign( Node q ) { Trace("cegqi-lemma") << "Cegqi::Lemma : initial (guarded) lemma : " << lem << std::endl; d_qe->getOutputChannel().lemma( lem ); } - + + // assign the cegis sampler if applicable + if (options::cegisSample() != CEGIS_SAMPLE_NONE) + { + Trace("cegis-sample") << "Initialize sampler for " << d_base_body << "..." + << std::endl; + TypeNode bt = d_base_body.getType(); + d_cegis_sampler.initialize(bt, d_base_vars, options::sygusSamples()); + } + Trace("cegqi") << "...finished, single invocation = " << isSingleInvocation() << std::endl; } @@ -284,6 +302,18 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode //check whether we will run CEGIS on inner skolem variables bool sk_refine = ( !isGround() || d_refine_count==0 ) && ( !d_ceg_pbe->isPbe() || constructed_cand ); if( sk_refine ){ + if (options::cegisSample() == CEGIS_SAMPLE_TRUST) + { + // we have that the current candidate passed a sample test + // since we trust sampling in this mode, we assert there is no + // counterexample to the conjecture here. + NodeManager* nm = NodeManager::currentNM(); + Node lem = nm->mkNode(OR, d_quant.negate(), nm->mkConst(false)); + lem = getStreamGuardedLemma(lem); + lems.push_back(lem); + recordInstantiation(c_model_values); + return; + } Assert( d_ce_sk.empty() ); d_ce_sk.push_back( std::vector< Node >() ); }else{ @@ -329,12 +359,7 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode std::map< Node, Node > visited_n; lem = d_qe->getTermDatabaseSygus()->getEagerUnfold( lem, visited_n ); } - if( options::sygusStream() ){ - // if we are in streaming mode, we guard with the current stream guard - Node curr_stream_guard = getCurrentStreamGuard(); - Assert( !curr_stream_guard.isNull() ); - lem = NodeManager::currentNM()->mkNode( kind::OR, curr_stream_guard.negate(), lem ); - } + lem = getStreamGuardedLemma(lem); lems.push_back( lem ); recordInstantiation( c_model_values ); } @@ -404,17 +429,13 @@ void CegConjecture::doRefine( std::vector< Node >& lems ){ Trace("cegqi-refine") << "doRefine : construct and finalize lemmas..." << std::endl; - Node lem = base_lem; base_lem = base_lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() ); base_lem = Rewriter::rewrite( base_lem ); - d_refinement_lemmas_base.push_back( base_lem ); - - lem = NodeManager::currentNM()->mkNode( OR, getGuard().negate(), lem ); - - lem = lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() ); - lem = Rewriter::rewrite( lem ); - d_refinement_lemmas.push_back( lem ); + d_refinement_lemmas.push_back(base_lem); + + Node lem = + NodeManager::currentNM()->mkNode(OR, getGuard().negate(), base_lem); lems.push_back( lem ); d_ce_sk.clear(); @@ -435,6 +456,12 @@ void CegConjecture::getModelValues( std::vector< Node >& n, std::vector< Node >& std::stringstream ss; Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, nv); Trace("cegqi-engine") << ss.str() << " "; + if (Trace.isOn("cegqi-engine-rr")) + { + Node bv = d_qe->getTermDatabaseSygus()->sygusToBuiltin(nv, tn); + bv = Rewriter::rewrite(bv); + Trace("cegqi-engine-rr") << " -> " << bv << std::endl; + } } Assert( !nv.isNull() ); } @@ -467,6 +494,18 @@ Node CegConjecture::getCurrentStreamGuard() const { } } +Node CegConjecture::getStreamGuardedLemma(Node n) const +{ + if (options::sygusStream()) + { + // if we are in streaming mode, we guard with the current stream guard + Node csg = getCurrentStreamGuard(); + Assert(!csg.isNull()); + return NodeManager::currentNM()->mkNode(kind::OR, csg.negate(), n); + } + return n; +} + Node CegConjecture::getNextDecisionRequest( unsigned& priority ) { // first, must try the guard // which denotes "this conjecture is feasible" @@ -554,81 +593,210 @@ Node CegConjecture::getNextDecisionRequest( unsigned& priority ) { void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation ) { Trace("cegqi-debug") << "Printing synth solution..." << std::endl; Assert( d_quant[0].getNumChildren()==d_embed_quant[0].getNumChildren() ); - for( unsigned i=0; i<d_embed_quant[0].getNumChildren(); i++ ){ + std::vector<Node> sols; + std::vector<int> statuses; + getSynthSolutionsInternal(sols, statuses, singleInvocation); + for (unsigned i = 0, size = d_embed_quant[0].getNumChildren(); i < size; i++) + { + Node sol = sols[i]; + if (!sol.isNull()) + { + Node prog = d_embed_quant[0][i]; + int status = statuses[i]; + TypeNode tn = prog.getType(); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + std::stringstream ss; + ss << prog; + std::string f(ss.str()); + f.erase(f.begin()); + out << "(define-fun " << f << " "; + if( dt.getSygusVarList().isNull() ){ + out << "() "; + }else{ + out << dt.getSygusVarList() << " "; + } + out << dt.getSygusType() << " "; + if( status==0 ){ + out << sol; + }else{ + Printer::getPrinter(options::outputLanguage())->toStreamSygus(out, sol); + } + out << ")" << std::endl; + + if (status != 0 && options::sygusRewSynth()) + { + TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus(); + std::map<Node, SygusSampler>::iterator its = d_sampler.find(prog); + if (its == d_sampler.end()) + { + d_sampler[prog].initializeSygus( + sygusDb, prog, options::sygusSamples()); + its = d_sampler.find(prog); + } + Node solb = sygusDb->sygusToBuiltin(sol, prog.getType()); + Node eq_sol = its->second.registerTerm(solb); + // eq_sol is a candidate solution that is equivalent to sol + if (eq_sol != solb) + { + // one of eq_sol or solb must be ordered + bool eqor = its->second.isOrdered(eq_sol); + bool sor = its->second.isOrdered(solb); + bool outputRewrite = false; + if (eqor || sor) + { + outputRewrite = true; + // if only one is ordered, then the ordered one must contain the + // free variables of the other + if (!eqor) + { + outputRewrite = its->second.containsFreeVariables(solb, eq_sol); + } + else if (!sor) + { + outputRewrite = its->second.containsFreeVariables(eq_sol, solb); + } + } + + if (outputRewrite) + { + // Terms solb and eq_sol are equivalent under sample points but do + // not rewrite to the same term. Hence, this indicates a candidate + // rewrite. + out << "(candidate-rewrite " << solb << " " << eq_sol << ")" + << std::endl; + // if the previous value stored was unordered, but this is + // ordered, we prefer this one. Thus, we force its addition to the + // sampler database. + if (!eqor) + { + its->second.registerTerm(solb, true); + } + } + else + { + Trace("sygus-synth-rr") + << "Alpha equivalent candidate rewrite : " << eq_sol << " " + << solb << std::endl; + } + } + } + } + } +} + +void CegConjecture::getSynthSolutions(std::map<Node, Node>& sol_map, + bool singleInvocation) +{ + NodeManager* nm = NodeManager::currentNM(); + TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus(); + std::vector<Node> sols; + std::vector<int> statuses; + getSynthSolutionsInternal(sols, statuses, singleInvocation); + for (unsigned i = 0, size = d_embed_quant[0].getNumChildren(); i < size; i++) + { + Node sol = sols[i]; + int status = statuses[i]; + // get the builtin solution + Node bsol = sol; + if (status != 0) + { + // convert sygus to builtin here + bsol = sygusDb->sygusToBuiltin(sol, sol.getType()); + } + // convert to lambda + TypeNode tn = d_embed_quant[0][i].getType(); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Node bvl = Node::fromExpr(dt.getSygusVarList()); + if (!bvl.isNull()) + { + bsol = nm->mkNode(LAMBDA, bvl, bsol); + } + // store in map + Node fvar = d_quant[0][i]; + Assert(fvar.getType() == bsol.getType()); + sol_map[fvar] = bsol; + } +} + +void CegConjecture::getSynthSolutionsInternal(std::vector<Node>& sols, + std::vector<int>& statuses, + bool singleInvocation) +{ + for (unsigned i = 0, size = d_embed_quant[0].getNumChildren(); i < size; i++) + { Node prog = d_embed_quant[0][i]; - Trace("cegqi-debug") << " print solution for " << prog << std::endl; - std::stringstream ss; - ss << prog; - std::string f(ss.str()); - f.erase(f.begin()); + Trace("cegqi-debug") << " get solution for " << prog << std::endl; TypeNode tn = prog.getType(); - Assert( tn.isDatatype() ); - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); - Assert( dt.isSygus() ); - //get the solution + Assert(tn.isDatatype()); + // get the solution Node sol; int status = -1; - if( singleInvocation ){ - Assert( d_ceg_si != NULL ); - sol = d_ceg_si->getSolution( i, tn, status, true ); - if( !sol.isNull() ){ - sol = sol.getKind()==LAMBDA ? sol[1] : sol; + if (singleInvocation) + { + Assert(d_ceg_si != NULL); + sol = d_ceg_si->getSolution(i, tn, status, true); + if (!sol.isNull()) + { + sol = sol.getKind() == LAMBDA ? sol[1] : sol; } - }else{ - Node cprog = getCandidate( i ); - if( !d_cinfo[cprog].d_inst.empty() ){ + } + else + { + Node cprog = getCandidate(i); + if (!d_cinfo[cprog].d_inst.empty()) + { // the solution is just the last instantiated term sol = d_cinfo[cprog].d_inst.back(); status = 1; - - //check if there was a template + + // check if there was a template Node sf = d_quant[0][i]; - Node templ = d_ceg_si->getTemplate( sf ); - if( !templ.isNull() ){ - Trace("cegqi-inv-debug") << sf << " used template : " << templ << std::endl; + Node templ = d_ceg_si->getTemplate(sf); + if (!templ.isNull()) + { + Trace("cegqi-inv-debug") + << sf << " used template : " << templ << std::endl; // if it was not embedded into the grammar - if( !options::sygusTemplEmbedGrammar() ){ - TNode templa = d_ceg_si->getTemplateArg( sf ); + if (!options::sygusTemplEmbedGrammar()) + { + TNode templa = d_ceg_si->getTemplateArg(sf); // make the builtin version of the full solution TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus(); - sol = sygusDb->sygusToBuiltin( sol, sol.getType() ); - Trace("cegqi-inv") << "Builtin version of solution is : " - << sol << ", type : " << sol.getType() - << std::endl; + sol = sygusDb->sygusToBuiltin(sol, sol.getType()); + Trace("cegqi-inv") << "Builtin version of solution is : " << sol + << ", type : " << sol.getType() << std::endl; TNode tsol = sol; - sol = templ.substitute( templa, tsol ); + sol = templ.substitute(templa, tsol); Trace("cegqi-inv-debug") << "With template : " << sol << std::endl; - sol = Rewriter::rewrite( sol ); + sol = Rewriter::rewrite(sol); Trace("cegqi-inv-debug") << "Simplified : " << sol << std::endl; // now, reconstruct to the syntax sol = d_ceg_si->reconstructToSyntax(sol, tn, status, true); - sol = sol.getKind()==LAMBDA ? sol[1] : sol; - Trace("cegqi-inv-debug") << "Reconstructed to syntax : " << sol << std::endl; - }else{ - Trace("cegqi-inv-debug") << "...was embedding into grammar." << std::endl; + sol = sol.getKind() == LAMBDA ? sol[1] : sol; + Trace("cegqi-inv-debug") + << "Reconstructed to syntax : " << sol << std::endl; + } + else + { + Trace("cegqi-inv-debug") + << "...was embedding into grammar." << std::endl; } - }else{ - Trace("cegqi-inv-debug") << sf << " did not use template" << std::endl; } - }else{ - Trace("cegqi-warn") << "WARNING : No recorded instantiations for syntax-guided solution!" << std::endl; - } - } - if( !(Trace.isOn("cegqi-stats")) && !sol.isNull() ){ - out << "(define-fun " << f << " "; - if( dt.getSygusVarList().isNull() ){ - out << "() "; - }else{ - out << dt.getSygusVarList() << " "; + else + { + Trace("cegqi-inv-debug") + << sf << " did not use template" << std::endl; + } } - out << dt.getSygusType() << " "; - if( status==0 ){ - out << sol; - }else{ - Printer::getPrinter(options::outputLanguage())->toStreamSygus(out, sol); + else + { + Trace("cegqi-warn") << "WARNING : No recorded instantiations for " + "syntax-guided solution!" + << std::endl; } - out << ")" << std::endl; } + sols.push_back(sol); + statuses.push_back(status); } } @@ -659,6 +827,83 @@ Node CegConjecture::getSymmetryBreakingPredicate( } } +bool CegConjecture::sampleAddRefinementLemma(std::vector<Node>& vals, + std::vector<Node>& lems) +{ + if (Trace.isOn("cegis-sample")) + { + Trace("cegis-sample") << "Check sampling for candidate solution" + << std::endl; + for (unsigned i = 0, size = vals.size(); i < size; i++) + { + Trace("cegis-sample") + << " " << d_candidates[i] << " -> " << vals[i] << std::endl; + } + } + Assert(vals.size() == d_candidates.size()); + Node sbody = d_base_body.substitute( + d_candidates.begin(), d_candidates.end(), vals.begin(), vals.end()); + Trace("cegis-sample-debug") << "Sample " << sbody << std::endl; + // do eager unfolding + std::map<Node, Node> visited_n; + sbody = d_qe->getTermDatabaseSygus()->getEagerUnfold(sbody, visited_n); + Trace("cegis-sample") << "Sample (after unfolding): " << sbody << std::endl; + + NodeManager* nm = NodeManager::currentNM(); + for (unsigned i = 0, size = d_cegis_sampler.getNumSamplePoints(); i < size; + i++) + { + if (d_cegis_sample_refine.find(i) == d_cegis_sample_refine.end()) + { + Node ev = d_cegis_sampler.evaluate(sbody, i); + Trace("cegis-sample-debug") + << "...evaluate point #" << i << " to " << ev << std::endl; + Assert(ev.isConst()); + Assert(ev.getType().isBoolean()); + if (!ev.getConst<bool>()) + { + Trace("cegis-sample-debug") << "...false for point #" << i << std::endl; + // mark this as a CEGIS point (no longer sampled) + d_cegis_sample_refine.insert(i); + std::vector<Node> pt; + d_cegis_sampler.getSamplePoint(i, pt); + Assert(d_base_vars.size() == pt.size()); + Node rlem = d_base_body.substitute( + d_base_vars.begin(), d_base_vars.end(), pt.begin(), pt.end()); + rlem = Rewriter::rewrite(rlem); + if (std::find( + d_refinement_lemmas.begin(), d_refinement_lemmas.end(), rlem) + == d_refinement_lemmas.end()) + { + if (Trace.isOn("cegis-sample")) + { + Trace("cegis-sample") << " false for point #" << i << " : "; + for (const Node& cn : pt) + { + Trace("cegis-sample") << cn << " "; + } + Trace("cegis-sample") << std::endl; + } + Trace("cegqi-engine") << " *** Refine by sampling" << std::endl; + d_refinement_lemmas.push_back(rlem); + // if trust, we are not interested in sending out refinement lemmas + if (options::cegisSample() != CEGIS_SAMPLE_TRUST) + { + Node lem = nm->mkNode(OR, getGuard().negate(), rlem); + lems.push_back(lem); + } + return true; + } + else + { + Trace("cegis-sample-debug") << "...duplicate." << std::endl; + } + } + } + } + return false; +} + }/* namespace CVC4::theory::quantifiers */ }/* namespace CVC4::theory */ }/* namespace CVC4 */ diff --git a/src/theory/quantifiers/ce_guided_conjecture.h b/src/theory/quantifiers/ce_guided_conjecture.h index 0f000bba5..dae261111 100644 --- a/src/theory/quantifiers/ce_guided_conjecture.h +++ b/src/theory/quantifiers/ce_guided_conjecture.h @@ -24,6 +24,7 @@ #include "theory/quantifiers/ce_guided_single_inv.h" #include "theory/quantifiers/sygus_grammar_cons.h" #include "theory/quantifiers/sygus_process_conj.h" +#include "theory/quantifiers/sygus_sampler.h" #include "theory/quantifiers_engine.h" namespace CVC4 { @@ -74,12 +75,26 @@ public: * This is step 2(b) of Figure 3 of Reynolds et al CAV 2015. */ void doRefine(std::vector< Node >& lems); - /** Print the synthesis solution - * singleInvocation is whether the solution was found by single invocation techniques. - */ //-------------------------------end for counterexample-guided check/refine - + /** + * prints the synthesis solution to output stream out. + * + * singleInvocation : set to true if we should consult the single invocation + * module to get synthesis solutions. + */ void printSynthSolution( std::ostream& out, bool singleInvocation ); + /** get synth solutions + * + * This returns a map from function-to-synthesize variables to their + * builtin solution, which has the same type. For example, for synthesis + * conjecture exists f. forall x. f( x )>x, this function may return the map + * containing the entry: + * f -> (lambda x. x+1) + * + * singleInvocation : set to true if we should consult the single invocation + * module to get synthesis solutions. + */ + void getSynthSolutions(std::map<Node, Node>& sol_map, bool singleInvocation); /** get guard, this is "G" in Figure 3 of Reynolds et al CAV 2015 */ Node getGuard(); /** is ground */ @@ -106,10 +121,21 @@ public: //-----------------------------------refinement lemmas /** get number of refinement lemmas we have added so far */ unsigned getNumRefinementLemmas() { return d_refinement_lemmas.size(); } - /** get refinement lemma */ + /** get refinement lemma + * + * If d_embed_quant is forall d. exists y. P( d, y ), then a refinement + * lemma is one of the form ~P( d_candidates, c ) for some c. + */ Node getRefinementLemma( unsigned i ) { return d_refinement_lemmas[i]; } - /** get refinement lemma */ - Node getRefinementBaseLemma( unsigned i ) { return d_refinement_lemmas_base[i]; } + /** sample add refinement lemma + * + * This function will check if there is a sample point in d_sampler that + * refutes the candidate solution (d_quant_vars->vals). If so, it adds a + * refinement lemma to the lists d_refinement_lemmas that corresponds to that + * sample point, and adds a lemma to lems if cegisSample mode is not trust. + */ + bool sampleAddRefinementLemma(std::vector<Node>& vals, + std::vector<Node>& lems); //-----------------------------------end refinement lemmas /** get program by examples utility */ @@ -133,14 +159,21 @@ private: /** grammar utility */ std::unique_ptr<CegGrammarConstructor> d_ceg_gc; /** list of constants for quantified formula - * The Skolems for the negation of d_embed_quant. + * The outer Skolems for the negation of d_embed_quant. */ std::vector< Node > d_candidates; /** base instantiation * If d_embed_quant is forall d. exists y. P( d, y ), then - * this is the formula P( candidates, y ). + * this is the formula exists y. P( d_candidates, y ). */ Node d_base_inst; + /** If d_base_inst is exists y. P( d, y ), then this is y. */ + std::vector<Node> d_base_vars; + /** + * If d_base_inst is exists y. P( d, y ), then this is the formula + * P( d_candidates, y ). + */ + Node d_base_body; /** expand base inst to disjuncts */ std::vector< Node > d_base_disj; /** list of variables on inner quantification */ @@ -152,14 +185,13 @@ private: //-----------------------------------refinement lemmas /** refinement lemmas */ std::vector< Node > d_refinement_lemmas; - std::vector< Node > d_refinement_lemmas_base; //-----------------------------------end refinement lemmas - /** quantified formula asserted */ + /** the asserted (negated) conjecture */ Node d_quant; - /** quantified formula (after simplification) */ + /** (negated) conjecture after simplification */ Node d_simp_quant; - /** quantified formula (after simplification, conversion to deep embedding) */ + /** (negated) conjecture after simplification, conversion to deep embedding */ Node d_embed_quant; /** candidate information */ class CandidateInfo { @@ -183,11 +215,38 @@ private: d_cinfo[d_candidates[i]].d_inst.push_back( vs[i] ); } } + /** get synth solutions internal + * + * This function constructs the body of solutions for all + * functions-to-synthesize in this conjecture and stores them in sols, in + * order. For each solution added to sols, we add an integer indicating what + * kind of solution n is, where if sols[i] = n, then + * if status[i] = 0: n is the (builtin term) corresponding to the solution, + * if status[i] = 1: n is the sygus representation of the solution. + * We store builtin versions under some conditions (such as when the sygus + * grammar is being ignored). + * + * singleInvocation : set to true if we should consult the single invocation + * module to get synthesis solutions. + * + * For example, for conjecture exists fg. forall x. f(x)>g(x), this function + * may set ( sols, status ) to ( { x+1, d_x() }, { 1, 0 } ), where d_x() is + * the sygus datatype constructor corresponding to variable x. + */ + void getSynthSolutionsInternal(std::vector<Node>& sols, + std::vector<int>& status, + bool singleInvocation); //-------------------------------- sygus stream /** the streaming guards for sygus streaming mode */ std::vector< Node > d_stream_guards; /** get current stream guard */ Node getCurrentStreamGuard() const; + /** get stream guarded lemma + * + * If sygusStream is enabled, this returns ( G V n ) where G is the guard + * returned by getCurrentStreamGuard, otherwise this returns n. + */ + Node getStreamGuardedLemma(Node n) const; //-------------------------------- end sygus stream //-------------------------------- non-syntax guided (deprecated) /** Whether we are syntax-guided (e.g. was the input in SyGuS format). @@ -197,6 +256,24 @@ private: /** the guard for non-syntax-guided synthesis */ Node d_nsg_guard; //-------------------------------- end non-syntax guided (deprecated) + /** sygus sampler objects for each program variable + * + * This is used for the sygusRewSynth() option to synthesize new candidate + * rewrite rules. + */ + std::map<Node, SygusSampler> d_sampler; + /** sampler object for the option cegisSample() + * + * This samples points of the type of the inner variables of the synthesis + * conjecture (d_base_vars). + */ + SygusSampler d_cegis_sampler; + /** cegis sample refine points + * + * Stores the list of indices of sample points in d_cegis_sampler we have + * added as refinement lemmas. + */ + std::unordered_set<unsigned> d_cegis_sample_refine; }; } /* namespace CVC4::theory::quantifiers */ diff --git a/src/theory/quantifiers/ce_guided_instantiation.cpp b/src/theory/quantifiers/ce_guided_instantiation.cpp index b54ce4805..38cfb9ba7 100644 --- a/src/theory/quantifiers/ce_guided_instantiation.cpp +++ b/src/theory/quantifiers/ce_guided_instantiation.cpp @@ -238,17 +238,33 @@ void CegInstantiation::checkCegConjecture( CegConjecture * conj ) { void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vector< Node >& vs, std::vector< Node >& ms, std::vector< Node >& lems ) { Trace("sygus-cref-eval") << "Cref eval : conjecture has " << conj->getNumRefinementLemmas() << " refinement lemmas." << std::endl; - if( conj->getNumRefinementLemmas()>0 ){ + unsigned nlemmas = conj->getNumRefinementLemmas(); + if (nlemmas > 0 || options::cegisSample() != CEGIS_SAMPLE_NONE) + { Assert( vs.size()==ms.size() ); TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus(); Node nfalse = d_quantEngine->getTermUtil()->d_false; Node neg_guard = conj->getGuard().negate(); - for( unsigned i=0; i<conj->getNumRefinementLemmas(); i++ ){ + for (unsigned i = 0; i <= nlemmas; i++) + { + if (i == nlemmas) + { + bool addedSample = false; + // find a new one by sampling, if applicable + if (options::cegisSample() != CEGIS_SAMPLE_NONE) + { + addedSample = conj->sampleAddRefinementLemma(ms, lems); + } + if (!addedSample) + { + return; + } + } Node lem; std::map< Node, Node > visited; std::map< Node, std::vector< Node > > exp; - lem = conj->getRefinementBaseLemma( i ); + lem = conj->getRefinementLemma(i); if( !lem.isNull() ){ std::vector< Node > lem_conj; //break into conjunctions @@ -310,14 +326,28 @@ void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vecto } void CegInstantiation::printSynthSolution( std::ostream& out ) { - if( d_conj->isAssigned() ){ - // print the conjecture + if( d_conj->isAssigned() ) + { d_conj->printSynthSolution( out, d_last_inst_si ); - }else{ + } + else + { Assert( false ); } } +void CegInstantiation::getSynthSolutions(std::map<Node, Node>& sol_map) +{ + if (d_conj->isAssigned()) + { + d_conj->getSynthSolutions(sol_map, d_last_inst_si); + } + else + { + Assert(false); + } +} + void CegInstantiation::preregisterAssertion( Node n ) { //check if it sygus conjecture if( QuantAttributes::checkSygusConjecture( n ) ){ diff --git a/src/theory/quantifiers/ce_guided_instantiation.h b/src/theory/quantifiers/ce_guided_instantiation.h index 86f0c4c9f..691363311 100644 --- a/src/theory/quantifiers/ce_guided_instantiation.h +++ b/src/theory/quantifiers/ce_guided_instantiation.h @@ -55,6 +55,17 @@ public: std::string identify() const { return "CegInstantiation"; } /** print solution for synthesis conjectures */ void printSynthSolution( std::ostream& out ); + /** get synth solutions + * + * This function adds entries to sol_map that map functions-to-synthesize + * with their solutions, for all active conjectures (currently just the one + * assigned to d_conj). This should be called immediately after the solver + * answers unsat for sygus input. + * + * For details on what is added to sol_map, see + * CegConjecture::getSynthSolutions. + */ + void getSynthSolutions(std::map<Node, Node>& sol_map); /** preregister assertion (before rewrite) */ void preregisterAssertion( Node n ); public: diff --git a/src/theory/quantifiers/ce_guided_pbe.cpp b/src/theory/quantifiers/ce_guided_pbe.cpp index bee19daeb..7f339be5f 100644 --- a/src/theory/quantifiers/ce_guided_pbe.cpp +++ b/src/theory/quantifiers/ce_guided_pbe.cpp @@ -925,8 +925,13 @@ void CegConjecturePbe::staticLearnRedundantOps( Node c, std::vector< Node >& lem Trace("sygus-unif") << "Strategy for candidate " << c << " is : " << std::endl; std::map<Node, std::map<NodeRole, bool> > visited; std::map<Node, std::map<unsigned, bool> > needs_cons; - staticLearnRedundantOps( - c, d_cinfo[c].getRootEnumerator(), role_equal, visited, needs_cons, 0); + staticLearnRedundantOps(c, + d_cinfo[c].getRootEnumerator(), + role_equal, + visited, + needs_cons, + 0, + false); // now, check the needs_cons map for (std::pair<const Node, std::map<unsigned, bool> >& nce : needs_cons) { @@ -957,19 +962,30 @@ void CegConjecturePbe::staticLearnRedundantOps( NodeRole nrole, std::map<Node, std::map<NodeRole, bool> >& visited, std::map<Node, std::map<unsigned, bool> >& needs_cons, - int ind) + int ind, + bool isCond) { std::map< Node, EnumInfo >::iterator itn = d_einfo.find( e ); Assert( itn!=d_einfo.end() ); - if (visited[e].find(nrole) == visited[e].end()) + + if (visited[e].find(nrole) == visited[e].end() + || (isCond && !itn->second.isConditional())) { visited[e][nrole] = true; - + // if conditional + if (isCond) + { + itn->second.setConditional(); + } indent("sygus-unif", ind); Trace("sygus-unif") << e << " :: node role : " << nrole; Trace("sygus-unif") << ", type : " << ((DatatypeType)e.getType().toType()).getDatatype().getName(); + if (isCond) + { + Trace("sygus-unif") << ", conditional"; + } Trace("sygus-unif") << ", enum role : " << itn->second.getRole(); if( itn->second.isTemplated() ){ @@ -991,57 +1007,65 @@ void CegConjecturePbe::staticLearnRedundantOps( Assert(itsn != tinfo.d_snodes.end()); StrategyNode& snode = itsn->second; - if (!snode.d_strats.empty()) + if (snode.d_strats.empty()) { - std::map<unsigned, bool> needs_cons_curr; - // various strategies - for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++) + return; + } + std::map<unsigned, bool> needs_cons_curr; + // various strategies + for (unsigned j = 0, size = snode.d_strats.size(); j < size; j++) + { + EnumTypeInfoStrat* etis = snode.d_strats[j]; + StrategyType strat = etis->d_this; + bool newIsCond = isCond || strat == strat_ITE; + indent("sygus-unif", ind + 1); + Trace("sygus-unif") << "Strategy : " << strat + << ", from cons : " << etis->d_cons << std::endl; + int cindex = Datatype::indexOf(etis->d_cons.toExpr()); + Assert(cindex != -1); + needs_cons_curr[static_cast<unsigned>(cindex)] = false; + for (std::pair<Node, NodeRole>& cec : etis->d_cenum) { - EnumTypeInfoStrat* etis = snode.d_strats[j]; - StrategyType strat = etis->d_this; - indent("sygus-unif", ind+1); - Trace("sygus-unif") << "Strategy : " << strat - << ", from cons : " << etis->d_cons << std::endl; - int cindex = Datatype::indexOf(etis->d_cons.toExpr()); - Assert(cindex != -1); - needs_cons_curr[static_cast<unsigned>(cindex)] = false; - for (std::pair<Node, NodeRole>& cec : etis->d_cenum) - { - // recurse - staticLearnRedundantOps( - c, cec.first, cec.second, visited, needs_cons, ind + 2); - } + // recurse + staticLearnRedundantOps(c, + cec.first, + cec.second, + visited, + needs_cons, + ind + 2, + newIsCond); } - // get the master enumerator for the type of this enumerator - std::map<TypeNode, Node>::iterator itse = - d_cinfo[c].d_search_enum.find(etn); - if (itse != d_cinfo[c].d_search_enum.end()) + } + // get the master enumerator for the type of this enumerator + std::map<TypeNode, Node>::iterator itse = + d_cinfo[c].d_search_enum.find(etn); + if (itse == d_cinfo[c].d_search_enum.end()) + { + return; + } + Node em = itse->second; + Assert(!em.isNull()); + // get the current datatype + const Datatype& dt = + static_cast<DatatypeType>(etn.toType()).getDatatype(); + // all constructors that are not a part of a strategy are needed + for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++) + { + if (needs_cons_curr.find(j) == needs_cons_curr.end()) { - Node em = itse->second; - Assert(!em.isNull()); - // get the current datatype - const Datatype& dt = - static_cast<DatatypeType>(etn.toType()).getDatatype(); - // all constructors that are not a part of a strategy are needed - for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++) - { - if (needs_cons_curr.find(j) == needs_cons_curr.end()) - { - needs_cons_curr[j] = true; - } - } - // update the constructors that the master enumerator needs - if (needs_cons.find(em) == needs_cons.end()) - { - needs_cons[em] = needs_cons_curr; - } - else - { - for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++) - { - needs_cons[em][j] = needs_cons[em][j] || needs_cons_curr[j]; - } - } + needs_cons_curr[j] = true; + } + } + // update the constructors that the master enumerator needs + if (needs_cons.find(em) == needs_cons.end()) + { + needs_cons[em] = needs_cons_curr; + } + else + { + for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++) + { + needs_cons[em][j] = needs_cons[em][j] || needs_cons_curr[j]; } } } @@ -1116,21 +1140,53 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >& std::map< Node, EnumInfo >::iterator it = d_einfo.find( x ); Assert( it != d_einfo.end() ); Node gstatus = d_qe->getValuation().getSatValue(it->second.d_active_guard); - if (!gstatus.isNull() && gstatus.getConst<bool>()) + if (gstatus.isNull() || !gstatus.getConst<bool>()) { - Assert( std::find( it->second.d_enum_vals.begin(), it->second.d_enum_vals.end(), v )==it->second.d_enum_vals.end() ); - Node c = it->second.d_parent_candidate; - Node exp_exc; - if( d_examples_out_invalid.find( c )==d_examples_out_invalid.end() ){ - std::map< Node, CandidateInfo >::iterator itc = d_cinfo.find( c ); - Assert( itc != d_cinfo.end() ); - TypeNode xtn = x.getType(); - Node bv = d_tds->sygusToBuiltin( v, xtn ); - std::map< Node, std::vector< std::vector< Node > > >::iterator itx = d_examples.find( c ); - std::map< Node, std::vector< Node > >::iterator itxo = d_examples_out.find( c ); - Assert( itx!=d_examples.end() ); - Assert( itxo!=d_examples_out.end() ); - Assert( itx->second.size()==itxo->second.size() ); + Trace("sygus-pbe-enum-debug") << " ...guard is inactive." << std::endl; + return; + } + Assert( + std::find(it->second.d_enum_vals.begin(), it->second.d_enum_vals.end(), v) + == it->second.d_enum_vals.end()); + Node c = it->second.d_parent_candidate; + // The explanation for why the current value should be excluded in future + // iterations. + Node exp_exc; + if (d_examples_out_invalid.find(c) == d_examples_out_invalid.end()) + { + std::map<Node, CandidateInfo>::iterator itc = d_cinfo.find(c); + Assert(itc != d_cinfo.end()); + TypeNode xtn = x.getType(); + Node bv = d_tds->sygusToBuiltin(v, xtn); + std::map<Node, std::vector<std::vector<Node> > >::iterator itx = + d_examples.find(c); + std::map<Node, std::vector<Node> >::iterator itxo = d_examples_out.find(c); + Assert(itx != d_examples.end()); + Assert(itxo != d_examples_out.end()); + Assert(itx->second.size() == itxo->second.size()); + std::vector<Node> base_results; + // compte the results + for (unsigned j = 0, size = itx->second.size(); j < size; j++) + { + Node res = d_tds->evaluateBuiltin(xtn, bv, itx->second[j]); + Trace("sygus-pbe-enum-debug") + << "...got res = " << res << " from " << bv << std::endl; + base_results.push_back(res); + } + // is it excluded for domain-specific reason? + std::vector<Node> exp_exc_vec; + if (getExplanationForEnumeratorExclude( + c, x, v, base_results, it->second, exp_exc_vec)) + { + Assert(!exp_exc_vec.empty()); + exp_exc = exp_exc_vec.size() == 1 + ? exp_exc_vec[0] + : NodeManager::currentNM()->mkNode(AND, exp_exc_vec); + Trace("sygus-pbe-enum") + << " ...fail : term is excluded (domain-specific)" << std::endl; + } + else + { // notify all slaves Assert( !it->second.d_enum_slave.empty() ); //explanation for why this value should be excluded @@ -1153,9 +1209,9 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >& Node templ = itv->second.d_template; TNode templ_var = itv->second.d_template_arg; std::map< Node, bool > cond_vals; - for( unsigned j=0; j<itx->second.size(); j++ ){ - Node res = d_tds->evaluateBuiltin( xtn, bv, itx->second[j] ); - Trace("sygus-pbe-enum-debug") << "...got res = " << res << " from " << bv << std::endl; + for (unsigned j = 0, size = base_results.size(); j < size; j++) + { + Node res = base_results[j]; Assert( res.isConst() ); if( !templ.isNull() ){ TNode tres = res; @@ -1185,7 +1241,9 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >& bool keep = false; if (itv->second.getRole() == enum_io) { - if( cond_vals.find( d_true )!=cond_vals.end() || cond_vals.empty() ){ // latter is the degenerate case of no examples + // latter is the degenerate case of no examples + if (cond_vals.find(d_true) != cond_vals.end() || cond_vals.empty()) + { //check subsumbed/subsuming std::vector< Node > subsume; if( cond_vals.find( d_false )==cond_vals.end() ){ @@ -1217,113 +1275,141 @@ void CegConjecturePbe::addEnumeratedValue( Node x, Node v, std::vector< Node >& Trace("sygus-pbe-enum") << " ...fail : it does not satisfy examples." << std::endl; } }else{ - // is it excluded for domain-specific reason? - std::vector< Node > exp_exc_vec; - if( getExplanationForEnumeratorExclude( c, x, v, results, it->second, exp_exc_vec ) ){ - Assert( !exp_exc_vec.empty() ); - exp_exc = exp_exc_vec.size() == 1 - ? exp_exc_vec[0] - : NodeManager::currentNM()->mkNode(AND, exp_exc_vec); - Trace("sygus-pbe-enum") << " ...fail : term is excluded (domain-specific)" << std::endl; + // must be unique up to examples + Node val = itv->second.d_term_trie.addCond(this, v, results, true); + if (val == v) + { + Trace("sygus-pbe-enum") << " ...success! add to PBE pool : " + << d_tds->sygusToBuiltin(v) << std::endl; + keep = true; }else{ - //if( cond_vals.size()!=2 ){ - // // must discriminate - // Trace("sygus-pbe-enum") << " ...fail : conditional is constant." << std::endl; - // keep = false; - //} - // must be unique up to examples - Node val = itv->second.d_term_trie.addCond( this, v, results, true ); - if( val==v ){ - Trace("sygus-pbe-enum") << " ...success! add to PBE pool : " << d_tds->sygusToBuiltin( v ) << std::endl; - keep = true; - }else{ - Trace("sygus-pbe-enum") << " ...fail : term is not unique" << std::endl; - } - itc->second.d_cond_count++; + Trace("sygus-pbe-enum") + << " ...fail : term is not unique" << std::endl; } + itc->second.d_cond_count++; } if( keep ){ // notify the parent to retry the build of PBE itc->second.d_check_sol = true; itv->second.addEnumValue( this, v, results ); - /* - if( Trace.isOn("sygus-pbe-enum") ){ - if( itv->second.getRole()==enum_io ){ - if( !prevIsCover && itv->second.isFeasible() ){ - Trace("sygus-pbe-enum") << "...PBE : success : Evaluation of " - << xs << " now covers all examples." << std::endl; - } - } - } - */ } } - }else{ - Trace("sygus-pbe-enum-debug") << " ...examples do not have output." << std::endl; - } - //exclude this value on subsequent iterations - Node g = it->second.d_active_guard; - if( exp_exc.isNull() ){ - // if we did not already explain why this should be excluded, use default - exp_exc = d_tds->getExplain()->getExplanationForConstantEquality(x, v); - } - Node exlem = - NodeManager::currentNM()->mkNode(OR, g.negate(), exp_exc.negate()); - Trace("sygus-pbe-enum-lemma") << "CegConjecturePbe : enumeration exclude lemma : " << exlem << std::endl; - lems.push_back( exlem ); + } }else{ - Trace("sygus-pbe-enum-debug") << " ...guard is inactive." << std::endl; + Trace("sygus-pbe-enum-debug") + << " ...examples do not have output." << std::endl; + } + // exclude this value on subsequent iterations + Node g = it->second.d_active_guard; + if (exp_exc.isNull()) + { + // if we did not already explain why this should be excluded, use default + exp_exc = d_tds->getExplain()->getExplanationForConstantEquality(x, v); } + Node exlem = + NodeManager::currentNM()->mkNode(OR, g.negate(), exp_exc.negate()); + Trace("sygus-pbe-enum-lemma") + << "CegConjecturePbe : enumeration exclude lemma : " << exlem + << std::endl; + lems.push_back(exlem); } -bool CegConjecturePbe::getExplanationForEnumeratorExclude( Node c, Node x, Node v, std::vector< Node >& results, EnumInfo& ei, std::vector< Node >& exp ) { - if( ei.d_enum_slave.size()==1 ){ - // this check whether the example evaluates to something that is larger than the output - // if so, then this term is never useful when using a concatenation strategy - if (ei.getRole() == enum_concat_term) +bool CegConjecturePbe::useStrContainsEnumeratorExclude(Node x, EnumInfo& ei) +{ + TypeNode xbt = d_tds->sygusToBuiltinType(x.getType()); + if (xbt.isString()) + { + std::map<Node, bool>::iterator itx = d_use_str_contains_eexc.find(x); + if (itx != d_use_str_contains_eexc.end()) { - if( Trace.isOn("sygus-pbe-cterm-debug") ){ - Trace("sygus-pbe-enum") << std::endl; + return itx->second; + } + Trace("sygus-pbe-enum-debug") + << "Is " << x << " is str.contains exclusion?" << std::endl; + d_use_str_contains_eexc[x] = true; + for (const Node& sn : ei.d_enum_slave) + { + std::map<Node, EnumInfo>::iterator itv = d_einfo.find(sn); + EnumRole er = itv->second.getRole(); + if (er != enum_io && er != enum_concat_term) + { + Trace("sygus-pbe-enum-debug") << " incompatible slave : " << sn + << ", role = " << er << std::endl; + d_use_str_contains_eexc[x] = false; + return false; } + if (itv->second.isConditional()) + { + Trace("sygus-pbe-enum-debug") + << " conditional slave : " << sn << std::endl; + d_use_str_contains_eexc[x] = false; + return false; + } + } + Trace("sygus-pbe-enum-debug") + << "...can use str.contains exclusion." << std::endl; + return d_use_str_contains_eexc[x]; + } + return false; +} - // check if all examples had longer length that the output - std::map< Node, std::vector< Node > >::iterator itxo = d_examples_out.find( c ); - Assert( itxo!=d_examples_out.end() ); - Assert( itxo->second.size()==results.size() ); - Trace("sygus-pbe-cterm-debug") << "Check enumerator exclusion for " << x << " -> " << d_tds->sygusToBuiltin( v ) << " based on containment." << std::endl; - std::vector< unsigned > cmp_indices; - for( unsigned i=0; i<results.size(); i++ ){ - Assert( results[i].isConst() ); - Assert( itxo->second[i].isConst() ); - /* - unsigned vlen = results[i].getConst<String>().size(); - unsigned xlen = itxo->second[i].getConst<String>().size(); - Trace("sygus-pbe-cterm-debug") << " " << results[i] << " <> " << itxo->second[i]; - int index = vlen>xlen ? 1 : ( vlen<xlen ? -1 : 0 ); - Trace("sygus-pbe-cterm-debug") << "..." << index << std::endl; - cmp_indices[index].push_back( i ); - */ - Trace("sygus-pbe-cterm-debug") << " " << results[i] << " <> " << itxo->second[i]; - Node cont = NodeManager::currentNM()->mkNode( - STRING_STRCTN, itxo->second[i], results[i]); - Node contr = Rewriter::rewrite( cont ); - if( contr==d_false ){ - cmp_indices.push_back( i ); - Trace("sygus-pbe-cterm-debug") << "...not contained." << std::endl; - }else{ - Trace("sygus-pbe-cterm-debug") << "...contained." << std::endl; - } +bool CegConjecturePbe::getExplanationForEnumeratorExclude( + Node c, + Node x, + Node v, + std::vector<Node>& results, + EnumInfo& ei, + std::vector<Node>& exp) +{ + if (useStrContainsEnumeratorExclude(x, ei)) + { + NodeManager* nm = NodeManager::currentNM(); + // This check whether the example evaluates to something that is larger than + // the output for some input/output pair. If so, then this term is never + // useful. We generalize its explanation below. + + if (Trace.isOn("sygus-pbe-cterm-debug")) + { + Trace("sygus-pbe-enum") << std::endl; + } + // check if all examples had longer length that the output + std::map<Node, std::vector<Node> >::iterator itxo = d_examples_out.find(c); + Assert(itxo != d_examples_out.end()); + Assert(itxo->second.size() == results.size()); + Trace("sygus-pbe-cterm-debug") + << "Check enumerator exclusion for " << x << " -> " + << d_tds->sygusToBuiltin(v) << " based on str.contains." << std::endl; + std::vector<unsigned> cmp_indices; + for (unsigned i = 0, size = results.size(); i < size; i++) + { + Assert(results[i].isConst()); + Assert(itxo->second[i].isConst()); + Trace("sygus-pbe-cterm-debug") + << " " << results[i] << " <> " << itxo->second[i]; + Node cont = nm->mkNode(STRING_STRCTN, itxo->second[i], results[i]); + Node contr = Rewriter::rewrite(cont); + if (contr == d_false) + { + cmp_indices.push_back(i); + Trace("sygus-pbe-cterm-debug") << "...not contained." << std::endl; } - // TODO : stronger requirement if we incorporate ITE + CONCAT mixed strategy : must be longer than *all* examples - if( !cmp_indices.empty() ){ - //set up the inclusion set - NegContainsSygusInvarianceTest ncset; - ncset.init(d_parent, x, itxo->second, cmp_indices); - d_tds->getExplain()->getExplanationFor(x, v, exp, ncset); - Trace("sygus-pbe-cterm") << "PBE-cterm : enumerator exclude " << d_tds->sygusToBuiltin( v ) << " due to negative containment." << std::endl; - return true; + else + { + Trace("sygus-pbe-cterm-debug") << "...contained." << std::endl; } } + if (!cmp_indices.empty()) + { + // we check invariance with respect to a negative contains test + NegContainsSygusInvarianceTest ncset; + ncset.init(d_parent, x, itxo->second, cmp_indices); + // construct the generalized explanation + d_tds->getExplain()->getExplanationFor(x, v, exp, ncset); + Trace("sygus-pbe-cterm") + << "PBE-cterm : enumerator exclude " << d_tds->sygusToBuiltin(v) + << " due to negative containment." << std::endl; + return true; + } } return false; } diff --git a/src/theory/quantifiers/ce_guided_pbe.h b/src/theory/quantifiers/ce_guided_pbe.h index e8bccaac5..ce1f2bf5e 100644 --- a/src/theory/quantifiers/ce_guided_pbe.h +++ b/src/theory/quantifiers/ce_guided_pbe.h @@ -386,15 +386,26 @@ class CegConjecturePbe { * (possibly multiple) slave enumerators, stored in d_enum_slave, */ class EnumInfo { - public: - EnumInfo() : d_role( enum_io ){} + public: + EnumInfo() : d_role(enum_io), d_is_conditional(false) {} /** initialize this class * c is the parent function-to-synthesize * role is the "role" the enumerator plays in the high-level strategy, * which is one of enum_* above. */ void initialize(Node c, EnumRole role); + /** is this enumerator associated with a template? */ bool isTemplated() { return !d_template.isNull(); } + /** set conditional + * + * This flag is set to true if this enumerator may not apply to all + * input/output examples. For example, if this enumerator is used + * as an output value beneath a conditional in an instance of strat_ITE, + * then this enumerator is conditional. + */ + void setConditional() { d_is_conditional = true; } + /** is conditional */ + bool isConditional() { return d_is_conditional; } void addEnumValue(CegConjecturePbe* pbe, Node v, std::vector<Node>& results); @@ -406,26 +417,30 @@ class CegConjecturePbe { // for template Node d_template; Node d_template_arg; - + Node d_active_guard; - std::vector< Node > d_enum_slave; + std::vector<Node> d_enum_slave; /** values we have enumerated */ - std::vector< Node > d_enum_vals; + std::vector<Node> d_enum_vals; /** - * This either stores the values of f( I ) for inputs - * or the value of f( I ) = O if d_role==enum_io - */ - std::vector< std::vector< Node > > d_enum_vals_res; - std::vector< Node > d_enum_subsume; - std::map< Node, unsigned > d_enum_val_to_index; + * This either stores the values of f( I ) for inputs + * or the value of f( I ) = O if d_role==enum_io + */ + std::vector<std::vector<Node> > d_enum_vals_res; + std::vector<Node> d_enum_subsume; + std::map<Node, unsigned> d_enum_val_to_index; SubsumeTrie d_term_trie; private: - /** whether an enumerated value for this conjecture has solved the entire - * conjecture */ + /** + * Whether an enumerated value for this conjecture has solved the entire + * conjecture. + */ Node d_enum_solved; /** the role of this enumerator (one of enum_* above). */ EnumRole d_role; + /** is this enumerator conditional */ + bool d_is_conditional; }; /** maps enumerators to the information above */ std::map< Node, EnumInfo > d_einfo; @@ -524,9 +539,42 @@ class CegConjecturePbe { std::map< Node, CandidateInfo > d_cinfo; //------------------------------ representation of an enumeration strategy - /** add enumerated value */ + /** add enumerated value + * + * We have enumerated the value v for x. This function adds x->v to the + * relevant data structures that are used for strategy-specific construction + * of solutions when necessary, and returns a set of lemmas, which are added + * to the input argument lems. These lemmas are used to rule out models where + * x = v, to force that a new value is enumerated for x. + */ void addEnumeratedValue( Node x, Node v, std::vector< Node >& lems ); + /** domain-specific enumerator exclusion techniques + * + * Returns true if the value v for x can be excluded based on a + * domain-specific exclusion technique like the ones below. + * + * c : the candidate variable that x is enumerating for, + * results : the values of v under the input examples of c, + * ei : the enumerator information for x, + * exp : if this function returns true, then exp contains a (possibly + * generalize) explanation for why v can be excluded. + */ bool getExplanationForEnumeratorExclude( Node c, Node x, Node v, std::vector< Node >& results, EnumInfo& ei, std::vector< Node >& exp ); + /** returns true if we can exlude values of x based on negative str.contains + * + * Values v for x may be excluded if we realize that the value of v under the + * substitution for some input example will never be contained in some output + * example. For details on this technique, see NegContainsSygusInvarianceTest + * in sygus_invariance.h. + * + * This function depends on whether x is being used to enumerate values + * for any node that is conditional in the strategy graph. For example, + * nodes that are children of ITE strategy nodes are conditional. If any node + * is conditional, then this function returns false. + */ + bool useStrContainsEnumeratorExclude(Node x, EnumInfo& ei); + /** cache for the above function */ + std::map<Node, bool> d_use_str_contains_eexc; //------------------------------ strategy registration /** collect enumerator types @@ -567,6 +615,9 @@ class CegConjecturePbe { * to a map from the constructors that it needs. * * ind is the depth in the strategy graph we are at (for debugging). + * + * isCond is whether the current enumerator is conditional (beneath a + * conditional of an strat_ITE strategy). */ void staticLearnRedundantOps( Node c, @@ -574,7 +625,8 @@ class CegConjecturePbe { NodeRole nrole, std::map<Node, std::map<NodeRole, bool> >& visited, std::map<Node, std::map<unsigned, bool> >& needs_cons, - int ind); + int ind, + bool isCond); //------------------------------ end strategy registration //------------------------------ constructing solutions diff --git a/src/theory/quantifiers/ce_guided_single_inv_sol.cpp b/src/theory/quantifiers/ce_guided_single_inv_sol.cpp index 91c6e3089..74408a7c3 100644 --- a/src/theory/quantifiers/ce_guided_single_inv_sol.cpp +++ b/src/theory/quantifiers/ce_guided_single_inv_sol.cpp @@ -19,19 +19,27 @@ #include "theory/quantifiers/ce_guided_instantiation.h" #include "theory/quantifiers/ce_guided_single_inv.h" #include "theory/quantifiers/first_order_model.h" +#include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers/trigger.h" #include "theory/theory_engine.h" -using namespace CVC4; using namespace CVC4::kind; -using namespace CVC4::theory; -using namespace CVC4::theory::quantifiers; using namespace std; namespace CVC4 { +namespace theory { +namespace quantifiers { + +bool doCompare(Node a, Node b, Kind k) +{ + Node com = NodeManager::currentNM()->mkNode(k, a, b); + com = Rewriter::rewrite(com); + Assert(com.getType().isBoolean()); + return com.isConst() && com.getConst<bool>(); +} CegConjectureSingleInvSol::CegConjectureSingleInvSol(QuantifiersEngine* qe) : d_qe(qe), d_id_count(0), d_root_id() {} @@ -720,7 +728,8 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in return d_rcons_to_id[stn][t]; }else{ status = 1; - d_qe->getTermDatabaseSygus()->registerSygusType( stn ); + // register the type + registerType(stn); int id = allocate( t, stn ); d_rcons_to_status[stn][t] = -1; TypeNode tn = t.getType(); @@ -777,7 +786,7 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in //try constant reconstruction if( min_t.isConst() ){ Trace("csi-rcons-debug") << "...try constant reconstruction." << std::endl; - Node min_t_c = d_qe->getTermDatabaseSygus()->builtinToSygusConst( min_t, stn ); + Node min_t_c = builtinToSygusConst(min_t, stn); if( !min_t_c.isNull() ){ Trace("csi-rcons-debug") << " constant reconstruction success for " << id << ", result = " << min_t_c << std::endl; d_reconstruct[id] = min_t_c; @@ -786,8 +795,8 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in } if( status!=0 ){ //try identity functions - for( unsigned i=0; i<d_qe->getTermDatabaseSygus()->getNumIdFuncs( stn ); i++ ){ - unsigned ii = d_qe->getTermDatabaseSygus()->getIdFuncIndex( stn, i ); + for (unsigned ii : d_id_funcs[stn]) + { Assert( dt[ii].getNumArgs()==1 ); //try to directly reconstruct from single argument std::vector< Node > tchildren; @@ -813,7 +822,8 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in success = false; int index_found; std::vector< Node > args; - if( d_qe->getTermDatabaseSygus()->getMatch( min_t, stn, index_found, args, karg, c_index ) ){ + if (getMatch(min_t, stn, index_found, args, karg, c_index)) + { success = true; status = 0; Node cons = Node::fromExpr( dt[index_found].getConstructor() ); @@ -1068,7 +1078,6 @@ void CegConjectureSingleInvSol::setReconstructed( int id, Node n ) { } void CegConjectureSingleInvSol::getEquivalentTerms( Kind k, Node n, std::vector< Node >& equiv ) { - Assert( n.getKind()!=k ); //? if( k==AND || k==OR ){ equiv.push_back( NodeManager::currentNM()->mkNode( k, n, n ) ); equiv.push_back( NodeManager::currentNM()->mkNode( k, n, NodeManager::currentNM()->mkConst( k==AND ) ) ); @@ -1187,4 +1196,317 @@ void CegConjectureSingleInvSol::registerEquivalentTerms( Node n ) { } } +Node CegConjectureSingleInvSol::builtinToSygusConst(Node c, + TypeNode tn, + int rcons_depth) +{ + std::map<Node, Node>::iterator it = d_builtin_const_to_sygus[tn].find(c); + if (it != d_builtin_const_to_sygus[tn].end()) + { + return it->second; + } + TermDbSygus* tds = d_qe->getTermDatabaseSygus(); + NodeManager* nm = NodeManager::currentNM(); + Node sc; + d_builtin_const_to_sygus[tn][c] = sc; + Assert(c.isConst()); + Assert(tn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " + << dt.getName() << std::endl; + Assert(dt.isSygus()); + // if we are not interested in reconstructing constants, or the grammar allows + // them, return a proxy + if (!options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst()) + { + Node k = nm->mkSkolem("sy", tn, "sygus proxy"); + SygusPrintProxyAttribute spa; + k.setAttribute(spa, c); + sc = k; + } + else + { + int carg = tds->getOpConsNum(tn, c); + if (carg != -1) + { + sc = nm->mkNode(APPLY_CONSTRUCTOR, + Node::fromExpr(dt[carg].getConstructor())); + } + else + { + // identity functions + for (unsigned ii : d_id_funcs[tn]) + { + Assert(dt[ii].getNumArgs() == 1); + // try to directly reconstruct from single argument + TypeNode tnc = tds->getArgType(dt[ii], 0); + Trace("csi-rcons-debug") + << "Based on id function " << dt[ii].getSygusOp() + << ", try reconstructing " << c << " instead in " << tnc + << std::endl; + Node n = builtinToSygusConst(c, tnc, rcons_depth); + if (!n.isNull()) + { + sc = nm->mkNode( + APPLY_CONSTRUCTOR, Node::fromExpr(dt[ii].getConstructor()), n); + break; + } + } + if (sc.isNull()) + { + if (rcons_depth < 1000) + { + // accelerated, recursive reconstruction of constants + Kind pk = tds->getPlusKind(TypeNode::fromType(dt.getSygusType())); + if (pk != UNDEFINED_KIND) + { + int arg = tds->getKindConsNum(tn, pk); + if (arg != -1) + { + Kind ck = + tds->getComparisonKind(TypeNode::fromType(dt.getSygusType())); + Kind pkm = + tds->getPlusKind(TypeNode::fromType(dt.getSygusType()), true); + // get types + Assert(dt[arg].getNumArgs() == 2); + TypeNode tn1 = tds->getArgType(dt[arg], 0); + TypeNode tn2 = tds->getArgType(dt[arg], 1); + // initialize d_const_list for tn1 + registerType(tn1); + // iterate over all positive constants, largest to smallest + int start = d_const_list[tn1].size() - 1; + int end = d_const_list[tn1].size() - d_const_list_pos[tn1]; + for (int i = start; i >= end; --i) + { + Node c1 = d_const_list[tn1][i]; + // only consider if smaller than c, and + if (doCompare(c1, c, ck)) + { + Node c2 = nm->mkNode(pkm, c, c1); + c2 = Rewriter::rewrite(c2); + if (c2.isConst()) + { + // reconstruct constant on the other side + Node sc2 = builtinToSygusConst(c2, tn2, rcons_depth + 1); + if (!sc2.isNull()) + { + Node sc1 = builtinToSygusConst(c1, tn1, rcons_depth); + Assert(!sc1.isNull()); + sc = nm->mkNode(APPLY_CONSTRUCTOR, + Node::fromExpr(dt[arg].getConstructor()), + sc1, + sc2); + break; + } + } + } + } + } + } + } + } + } + } + d_builtin_const_to_sygus[tn][c] = sc; + return sc; +} + +struct sortConstants +{ + Kind d_comp_kind; + bool operator()(Node i, Node j) + { + return i != j && doCompare(i, j, d_comp_kind); + } +}; + +void CegConjectureSingleInvSol::registerType(TypeNode tn) +{ + if (d_const_list_pos.find(tn) != d_const_list_pos.end()) + { + return; + } + d_const_list_pos[tn] = 0; + Assert(tn.isDatatype()); + + TermDbSygus* tds = d_qe->getTermDatabaseSygus(); + // ensure it is registered + tds->registerSygusType(tn); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + TypeNode btn = TypeNode::fromType(dt.getSygusType()); + // for constant reconstruction + Kind ck = tds->getComparisonKind(btn); + Node z = d_qe->getTermUtil()->getTypeValue(btn, 0); + + // iterate over constructors + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + Node n = Node::fromExpr(dt[i].getSygusOp()); + if (n.getKind() != kind::BUILTIN && n.isConst()) + { + d_const_list[tn].push_back(n); + if (ck != UNDEFINED_KIND && doCompare(z, n, ck)) + { + d_const_list_pos[tn]++; + } + } + if (dt[i].isSygusIdFunc()) + { + d_id_funcs[tn].push_back(i); + } + } + // sort the constant list + if (!d_const_list[tn].empty()) + { + if (ck != UNDEFINED_KIND) + { + sortConstants sc; + sc.d_comp_kind = ck; + std::sort(d_const_list[tn].begin(), d_const_list[tn].end(), sc); + } + Trace("csi-rcons") << "Type has " << d_const_list[tn].size() + << " constants..." << std::endl + << " "; + for (unsigned i = 0; i < d_const_list[tn].size(); i++) + { + Trace("csi-rcons") << d_const_list[tn][i] << " "; + } + Trace("csi-rcons") << std::endl; + Trace("csi-rcons") << "Of these, " << d_const_list_pos[tn] + << " are marked as positive." << std::endl; + } +} + +bool CegConjectureSingleInvSol::getMatch(Node p, + Node n, + std::map<int, Node>& s, + std::vector<int>& new_s) +{ + TermDbSygus* tds = d_qe->getTermDatabaseSygus(); + if (tds->isFreeVar(p)) + { + unsigned vnum = tds->getVarNum(p); + Node prev = s[vnum]; + s[vnum] = n; + if (prev.isNull()) + { + new_s.push_back(vnum); + } + return prev.isNull() || prev == n; + } + if (n.getNumChildren() == 0) + { + return p == n; + } + if (n.getKind() == p.getKind() && n.getNumChildren() == p.getNumChildren()) + { + // try both ways? + unsigned rmax = + TermUtil::isComm(n.getKind()) && n.getNumChildren() == 2 ? 2 : 1; + std::vector<int> new_tmp; + for (unsigned r = 0; r < rmax; r++) + { + bool success = true; + for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) + { + int io = r == 0 ? i : (i == 0 ? 1 : 0); + if (!getMatch(p[i], n[io], s, new_tmp)) + { + success = false; + for (unsigned j = 0; j < new_tmp.size(); j++) + { + s.erase(new_tmp[j]); + } + new_tmp.clear(); + break; + } + } + if (success) + { + new_s.insert(new_s.end(), new_tmp.begin(), new_tmp.end()); + return true; + } + } + } + return false; +} + +bool CegConjectureSingleInvSol::getMatch(Node t, + TypeNode st, + int& index_found, + std::vector<Node>& args, + int index_exc, + int index_start) +{ + Assert(st.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(st.toType()).getDatatype(); + Assert(dt.isSygus()); + std::map<Kind, std::vector<Node> > kgens; + std::vector<Node> gens; + for (unsigned i = index_start, ncons = dt.getNumConstructors(); i < ncons; + i++) + { + if ((int)i != index_exc) + { + Node g = getGenericBase(st, dt, i); + gens.push_back(g); + kgens[g.getKind()].push_back(g); + Trace("csi-sol-debug") << "Check generic base : " << g << " from " + << dt[i].getName() << std::endl; + if (g.getKind() == t.getKind()) + { + Trace("csi-sol-debug") << "Possible match ? " << g << " " << t + << " for " << dt[i].getName() << std::endl; + std::map<int, Node> sigma; + std::vector<int> new_s; + if (getMatch(g, t, sigma, new_s)) + { + // we found an exact match + bool msuccess = true; + for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) + { + if (sigma[j].isNull()) + { + msuccess = false; + break; + } + else + { + args.push_back(sigma[j]); + } + } + if (msuccess) + { + index_found = i; + return true; + } + } + } + } + } + return false; +} + +Node CegConjectureSingleInvSol::getGenericBase(TypeNode tn, + const Datatype& dt, + int c) +{ + std::map<int, Node>::iterator it = d_generic_base[tn].find(c); + if (it != d_generic_base[tn].end()) + { + return it->second; + } + TermDbSygus* tds = d_qe->getTermDatabaseSygus(); + Assert(tds->isRegistered(tn)); + std::map<TypeNode, int> var_count; + std::map<int, Node> pre; + Node g = tds->mkGeneric(dt, c, var_count, pre); + Trace("csi-sol-debug") << "Generic is " << g << std::endl; + Node gr = Rewriter::rewrite(g); + Trace("csi-sol-debug") << "Generic rewritten is " << gr << std::endl; + d_generic_base[tn][c] = gr; + return gr; +} +} +} } diff --git a/src/theory/quantifiers/ce_guided_single_inv_sol.h b/src/theory/quantifiers/ce_guided_single_inv_sol.h index c5f976f02..7043e1ecf 100644 --- a/src/theory/quantifiers/ce_guided_single_inv_sol.h +++ b/src/theory/quantifiers/ce_guided_single_inv_sol.h @@ -27,6 +27,12 @@ namespace quantifiers { class CegConjectureSingleInv; +/** CegConjectureSingleInvSol + * + * This function implements Figure 5 of "Counterexample-Guided Quantifier + * Instantiation for Synthesis in SMT", Reynolds et al CAV 2015. + * + */ class CegConjectureSingleInvSol { friend class CegConjectureSingleInv; @@ -47,10 +53,32 @@ private: bool getAssignEquality( Node eq, std::vector< Node >& vars, std::vector< Node >& new_vars, std::vector< Node >& new_subs ); Node simplifySolutionNode( Node sol, TypeNode stn, std::map< Node, bool >& assign, std::vector< Node >& vars, std::vector< Node >& subs, int status ); -public: + + public: + CegConjectureSingleInvSol(QuantifiersEngine* qe); + /** simplify solution + * + * Returns the simplified version of node sol whose syntax is restricted by + * the grammar corresponding to sygus datatype stn. + */ Node simplifySolution( Node sol, TypeNode stn ); -//solution reconstruction -private: + /** reconstruct solution + * + * Returns (if possible) a node that is equivalent to sol those syntax + * matches the grammar corresponding to sygus datatype stn. + * The value reconstructed is set to 1 if we successfully return a node, + * otherwise it is set to -1. + */ + Node reconstructSolution(Node sol, TypeNode stn, int& reconstructed); + /** preregister conjecture + * + * q : the synthesis conjecture this class is for. + * This is used as a heuristic to find terms in the original conjecture which + * may be helpful for using during reconstruction. + */ + void preregisterConjecture(Node q); + + private: int d_id_count; int d_root_id; std::map< int, Node > d_id_node; @@ -85,11 +113,74 @@ private: void getEquivalentTerms( Kind k, Node n, std::vector< Node >& equiv ); //register equivalent terms void registerEquivalentTerms( Node n ); -public: - Node reconstructSolution( Node sol, TypeNode stn, int& reconstructed ); - void preregisterConjecture( Node q ); -public: - CegConjectureSingleInvSol( QuantifiersEngine * qe ); + /** builtin to sygus const + * + * Returns a sygus term of type tn that encodes the builtin constant c. + * If the sygus datatype tn allows any constant, this may return a variable + * with the attribute SygusPrintProxyAttribute that associates it with c. + * + * rcons_depth limits the number of recursive calls when doing accelerated + * constant reconstruction (currently limited to 1000). Notice this is hacky: + * depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 + * vs. 999, 1001. + */ + Node builtinToSygusConst(Node c, TypeNode tn, int rcons_depth = 0); + /** cache for the above function */ + std::map<TypeNode, std::map<Node, Node> > d_builtin_const_to_sygus; + /** sorted list of constants, per type */ + std::map<TypeNode, std::vector<Node> > d_const_list; + /** number of positive constants, per type */ + std::map<TypeNode, unsigned> d_const_list_pos; + /** list of constructor indices whose operators are identity functions */ + std::map<TypeNode, std::vector<int> > d_id_funcs; + /** initialize the above information for sygus type tn */ + void registerType(TypeNode tn); + /** get generic base + * + * This returns the builtin term that is the analog of an application of the + * c^th constructor of dt to fresh variables. + */ + Node getGenericBase(TypeNode tn, const Datatype& dt, int c); + /** cache for the above function */ + std::map<TypeNode, std::map<int, Node> > d_generic_base; + /** get match + * + * This function attempts to find a substitution for which p = n. If + * successful, this function returns a substitution in the form of s/new_s, + * where: + * s : substitution, where the domain are indices of terms in the sygus + * term database, and + * new_s : the members that were added to s on this call. + * Otherwise, this function returns false and s and new_s are unmodified. + */ + bool getMatch(Node p, + Node n, + std::map<int, Node>& s, + std::vector<int>& new_s); + /** get match + * + * This function attempts to find a builtin term that is analog to a value + * of the sygus datatype st that is equivalent to n. If this function returns + * true, then it has found such a term. Then we set: + * index_found : updated to the constructor index of the sygus term whose + * analog to equivalent to n. + * args : builtin terms corresponding to the match, in order. + * Otherwise, this function returns false and index_found and args are + * unmodified. + * For example, for grammar: + * A -> 0 | 1 | x | +( A, A ) + * Given input ( 5 + (x+1) ) and A we would return true, where: + * index_found is set to 3 and args is set to { 5, x+1 }. + * + * index_exc : (if applicable) exclude a constructor index of st + * index_start : start index of constructors of st to try + */ + bool getMatch(Node n, + TypeNode st, + int& index_found, + std::vector<Node>& args, + int index_exc = -1, + int index_start = 0); }; diff --git a/src/theory/quantifiers/ceg_t_instantiator.h b/src/theory/quantifiers/ceg_t_instantiator.h index a607909cc..95295d214 100644 --- a/src/theory/quantifiers/ceg_t_instantiator.h +++ b/src/theory/quantifiers/ceg_t_instantiator.h @@ -187,43 +187,44 @@ class EprInstantiator : public Instantiator { class BvInstantiator : public Instantiator { public: BvInstantiator(QuantifiersEngine* qe, TypeNode tn); - virtual ~BvInstantiator(); - virtual void reset(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - CegInstEffort effort) override; - virtual bool hasProcessAssertion(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - CegInstEffort effort) override + ~BvInstantiator() override; + void reset(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + CegInstEffort effort) override; + bool hasProcessAssertion(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + CegInstEffort effort) override { return true; } - virtual Node hasProcessAssertion(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - Node lit, - CegInstEffort effort) override; - virtual bool processAssertion(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - Node lit, - Node alit, - CegInstEffort effort) override; - virtual bool processAssertions(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - CegInstEffort effort) override; + Node hasProcessAssertion(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + Node lit, + CegInstEffort effort) override; + bool processAssertion(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + Node lit, + Node alit, + CegInstEffort effort) override; + bool processAssertions(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + CegInstEffort effort) override; /** use model value * * We allow model values if we have not already tried an assertion, * and only at levels below full if cbqiFullEffort is false. */ - virtual bool useModelValue(CegInstantiator* ci, - SolvedForm& sf, - Node pv, - CegInstEffort effort) override; - virtual std::string identify() const { return "Bv"; } + bool useModelValue(CegInstantiator* ci, + SolvedForm& sf, + Node pv, + CegInstEffort effort) override; + std::string identify() const override { return "Bv"; } + private: // point to the bv inverter class BvInverter * d_inverter; @@ -281,7 +282,7 @@ class BvInstantiatorPreprocess : public InstantiatorPreprocess { public: BvInstantiatorPreprocess() {} - virtual ~BvInstantiatorPreprocess() {} + ~BvInstantiatorPreprocess() override {} /** register counterexample lemma * * This method modifies the contents of lems based on the extract terms @@ -308,8 +309,8 @@ class BvInstantiatorPreprocess : public InstantiatorPreprocess * since the added equalities ensure we are able to construct the proper * solved forms for variables in t and for the intermediate variables above. */ - virtual void registerCounterexampleLemma(std::vector<Node>& lems, - std::vector<Node>& ce_vars) override; + void registerCounterexampleLemma(std::vector<Node>& lems, + std::vector<Node>& ce_vars) override; private: /** collect extracts diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 956822303..ba0860d38 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -16,8 +16,8 @@ #include "theory/arith/arith_msum.h" #include "theory/datatypes/datatypes_rewriter.h" +#include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" -#include "theory/strings/theory_strings_rewriter.h" using namespace CVC4::kind; using namespace std; @@ -106,6 +106,13 @@ Node ExtendedRewriter::extendedRewrite(Node n) childChanged = nc != n[i] || childChanged; children.push_back(nc); } + // Some commutative operators have rewriters that are agnostic to order, + // thus, we sort here. + if (TermUtil::isComm(n.getKind())) + { + childChanged = true; + std::sort(children.begin(), children.end()); + } if (childChanged) { ret = NodeManager::currentNM()->mkNode(n.getKind(), children); @@ -123,8 +130,6 @@ Node ExtendedRewriter::extendedRewrite(Node n) // simple ITE pulling new_ret = extendedRewritePullIte(ret); } - // TODO (as part of #1343) - // ( ~contains( x, y ) --> false ) => ( ~x=y --> false ) } else if (ret.getKind() == kind::ITE) { diff --git a/src/theory/quantifiers/first_order_model.h b/src/theory/quantifiers/first_order_model.h index 0c4b6b7a4..f33151b4d 100644 --- a/src/theory/quantifiers/first_order_model.h +++ b/src/theory/quantifiers/first_order_model.h @@ -248,19 +248,21 @@ class Def; class FirstOrderModelFmc : public FirstOrderModel { friend class FullModelChecker; -private: + + private: /** models for UF */ std::map<Node, Def * > d_models; std::map<TypeNode, Node > d_type_star; Node intervalOp; /** get current model value */ - void processInitializeModelForTerm(Node n); -public: + void processInitializeModelForTerm(Node n) override; + + public: FirstOrderModelFmc(QuantifiersEngine * qe, context::Context* c, std::string name); ~FirstOrderModelFmc() override; - FirstOrderModelFmc * asFirstOrderModelFmc() { return this; } + FirstOrderModelFmc* asFirstOrderModelFmc() override { return this; } // initialize the model - void processInitialize( bool ispre ); + void processInitialize(bool ispre) override; Node getFunctionValue(Node op, const char* argPrefix ); bool isStar(Node n); diff --git a/src/theory/quantifiers/ho_trigger.h b/src/theory/quantifiers/ho_trigger.h index 4db3a660f..87f7fe07f 100644 --- a/src/theory/quantifiers/ho_trigger.h +++ b/src/theory/quantifiers/ho_trigger.h @@ -167,7 +167,7 @@ class HigherOrderTrigger : public Trigger * matching ground terms to function applications with variable heads. * See examples (EX1)-(EX3) above. */ - virtual bool sendInstantiation(InstMatch& m); + bool sendInstantiation(InstMatch& m) override; private: //-------------------- current information about the match diff --git a/src/theory/quantifiers/inst_match_generator.h b/src/theory/quantifiers/inst_match_generator.h index fc913c7cf..1903a0f95 100644 --- a/src/theory/quantifiers/inst_match_generator.h +++ b/src/theory/quantifiers/inst_match_generator.h @@ -222,7 +222,7 @@ class InstMatchGenerator : public IMGenerator { * * See Trigger::getActiveScore for details. */ - int getActiveScore(QuantifiersEngine* qe); + int getActiveScore(QuantifiersEngine* qe) override; /** exclude match * * Exclude matching d_match_pattern with Node n on subsequent calls to diff --git a/src/theory/quantifiers/inst_strategy_cbqi.h b/src/theory/quantifiers/inst_strategy_cbqi.h index c2520a973..26591c678 100644 --- a/src/theory/quantifiers/inst_strategy_cbqi.h +++ b/src/theory/quantifiers/inst_strategy_cbqi.h @@ -129,18 +129,19 @@ public: }; class InstStrategyCegqi : public InstStrategyCbqi { -protected: + protected: CegqiOutputInstStrategy * d_out; std::map< Node, CegInstantiator * > d_cinst; Node d_small_const; Node d_curr_quant; bool d_check_vts_lemma_lc; /** process functions */ - void processResetInstantiationRound( Theory::Effort effort ); - void process( Node f, Theory::Effort effort, int e ); + void processResetInstantiationRound(Theory::Effort effort) override; + void process(Node f, Theory::Effort effort, int e) override; /** register ce lemma */ - void registerCounterexampleLemma( Node q, Node lem ); -public: + void registerCounterexampleLemma(Node q, Node lem) override; + + public: InstStrategyCegqi( QuantifiersEngine * qe ); ~InstStrategyCegqi() override; @@ -148,14 +149,14 @@ public: bool isEligibleForInstantiation( Node n ); bool addLemma( Node lem ); /** identify */ - std::string identify() const { return std::string("Cegqi"); } + std::string identify() const override { return std::string("Cegqi"); } //get instantiator for quantifier CegInstantiator * getInstantiator( Node q ); //register quantifier - void registerQuantifier( Node q ); + void registerQuantifier(Node q) override; //presolve - void presolve(); + void presolve() override; }; } diff --git a/src/theory/quantifiers/model_builder.h b/src/theory/quantifiers/model_builder.h index 511aebf3b..4eb592b3e 100644 --- a/src/theory/quantifiers/model_builder.h +++ b/src/theory/quantifiers/model_builder.h @@ -178,7 +178,7 @@ class QModelBuilderDefault : public QModelBuilderIG //do InstGen techniques for quantifier, return number of lemmas produced int doInstGen(FirstOrderModel* fm, Node f) override; //theory-specific build models - void constructModelUf( FirstOrderModel* fm, Node op ); + void constructModelUf(FirstOrderModel* fm, Node op) override; protected: std::map< Node, QuantPhaseReq > d_phase_reqs; @@ -189,7 +189,10 @@ class QModelBuilderDefault : public QModelBuilderIG //options bool optReconsiderFuncConstants() { return true; } //has inst gen - bool hasInstGen( Node f ) { return !d_quant_selection_lit[f].isNull(); } + bool hasInstGen(Node f) override + { + return !d_quant_selection_lit[f].isNull(); + } }; }/* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/quantifiers/quant_conflict_find.cpp b/src/theory/quantifiers/quant_conflict_find.cpp index 95f8e3093..23e2ad721 100644 --- a/src/theory/quantifiers/quant_conflict_find.cpp +++ b/src/theory/quantifiers/quant_conflict_find.cpp @@ -34,7 +34,7 @@ namespace CVC4 { namespace theory { namespace quantifiers { -QuantInfo::QuantInfo() : d_unassigned_nvar(0), d_mg(NULL), d_una_index(0) {} +QuantInfo::QuantInfo() : d_unassigned_nvar(0), d_una_index(0), d_mg(nullptr) {} QuantInfo::~QuantInfo() { delete d_mg; diff --git a/src/theory/quantifiers/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus_grammar_norm.cpp index 67c40d6aa..6776aca15 100644 --- a/src/theory/quantifiers/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus_grammar_norm.cpp @@ -21,6 +21,7 @@ #include "smt/smt_engine.h" #include "smt/smt_engine_scope.h" #include "theory/quantifiers/ce_guided_conjecture.h" +#include "theory/quantifiers/sygus_grammar_red.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_util.h" @@ -34,7 +35,7 @@ namespace quantifiers { bool OpPosTrie::getOrMakeType(TypeNode tn, TypeNode& unres_tn, - std::vector<unsigned> op_pos, + const std::vector<unsigned>& op_pos, unsigned ind) { if (ind == op_pos.size()) @@ -118,6 +119,20 @@ void SygusGrammarNorm::TypeObject::buildDatatype(SygusGrammarNorm* sygus_norm, Trace("sygus-grammar-normalize") << "---------------------------------\n"; } +void SygusGrammarNorm::TransfDrop::buildType(SygusGrammarNorm* sygus_norm, + TypeObject& to, + const Datatype& dt, + std::vector<unsigned>& op_pos) +{ + std::vector<unsigned> difference; + std::set_difference(op_pos.begin(), + op_pos.end(), + d_drop_indices.begin(), + d_drop_indices.end(), + std::back_inserter(difference)); + op_pos = difference; +} + /* TODO #1304: have more operators and types. Moreover, have more general ways of finding kind of operator, e.g. if op is (\lambda xy. x + y) this function should realize that it is chainable for integers */ @@ -259,12 +274,38 @@ std::map<TypeNode, Node> SygusGrammarNorm::d_tn_to_id = {}; * * returns true if collected anything */ -SygusGrammarNorm::Transf* SygusGrammarNorm::inferTransf( +std::unique_ptr<SygusGrammarNorm::Transf> SygusGrammarNorm::inferTransf( TypeNode tn, const Datatype& dt, const std::vector<unsigned>& op_pos) { NodeManager* nm = NodeManager::currentNM(); TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType()); - /* TODO #1304: step 0: look for redundant constructors to drop */ + Trace("sygus-gnorm") << "Infer transf for " << dt.getName() << "..." + << std::endl; + Trace("sygus-gnorm") << " #cons = " << op_pos.size() << " / " + << dt.getNumConstructors() << std::endl; + // look for redundant constructors to drop + if (options::sygusMinGrammar() && dt.getNumConstructors() == op_pos.size()) + { + SygusRedundantCons src; + src.initialize(d_qe, tn); + std::vector<unsigned> rindices; + src.getRedundant(rindices); + if (!rindices.empty()) + { + Trace("sygus-gnorm") << "...drop transf, " << rindices.size() << "/" + << op_pos.size() << " constructors." << std::endl; + Assert(rindices.size() < op_pos.size()); + return std::unique_ptr<Transf>(new TransfDrop(rindices)); + } + } + + // if normalization option is not enabled, we do not infer the transformations + // below + if (!options::sygusGrammarNorm()) + { + return nullptr; + } + /* TODO #1304: step 1: look for singleton */ /* step 2: look for chain */ unsigned chain_op_pos = dt.getNumConstructors(); @@ -319,9 +360,10 @@ SygusGrammarNorm::Transf* SygusGrammarNorm::inferTransf( /* Typenode admits a chain transformation for normalization */ if (chain_op_pos != dt.getNumConstructors() && !elem_pos.empty()) { + Trace("sygus-gnorm") << "...chain transf." << std::endl; Trace("sygus-grammar-normalize-infer") << "\tInfering chain transformation\n"; - return new TransfChain(chain_op_pos, elem_pos); + return std::unique_ptr<Transf>(new TransfChain(chain_op_pos, elem_pos)); } return nullptr; } @@ -372,19 +414,16 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn, } /* Creates type object for normalization */ TypeObject to(tn, unres_tn); - /* If normalization option enabled, infer transformations to be applied in the - * type */ - if (options::sygusGrammarNorm()) + + /* Determine normalization transformation based on sygus type and given + * operators */ + std::unique_ptr<Transf> transformation = inferTransf(tn, dt, op_pos); + /* If a transformation was selected, apply it */ + if (transformation != nullptr) { - /* Determine normalization transformation based on sygus type and given - * operators */ - Transf* transformation = inferTransf(tn, dt, op_pos); - /* If a transformation was selected, apply it */ - if (transformation != nullptr) - { - transformation->buildType(this, to, dt, op_pos); - } + transformation->buildType(this, to, dt, op_pos); } + /* Remaining operators are rebuilt as they are */ for (unsigned i = 0, size = op_pos.size(); i < size; ++i) { diff --git a/src/theory/quantifiers/sygus_grammar_norm.h b/src/theory/quantifiers/sygus_grammar_norm.h index 38e3f168e..f72a83e5a 100644 --- a/src/theory/quantifiers/sygus_grammar_norm.h +++ b/src/theory/quantifiers/sygus_grammar_norm.h @@ -17,7 +17,16 @@ #ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_GRAMMAR_NORM_H #define __CVC4__THEORY__QUANTIFIERS__SYGUS_GRAMMAR_NORM_H +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include "expr/datatype.h" +#include "expr/node.h" #include "expr/node_manager_attributes.h" // for VarNameAttr +#include "expr/type.h" +#include "expr/type_node.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" @@ -78,7 +87,7 @@ class OpPosTrie */ bool getOrMakeType(TypeNode tn, TypeNode& unres_tn, - std::vector<unsigned> op_pos, + const std::vector<unsigned>& op_pos, unsigned ind = 0); /** clear all data from this trie */ void clear() { d_children.clear(); } @@ -241,6 +250,8 @@ class SygusGrammarNorm class Transf { public: + virtual ~Transf() {} + /** abstract function for building normalized types * * Builds normalized types for the operators specifed by the positions in @@ -254,6 +265,27 @@ class SygusGrammarNorm std::vector<unsigned>& op_pos) = 0; }; /* class Transf */ + /** Drop transformation class + * + * This class builds a type by dropping a set of redundant constructors, + * whose indices are given as input to the constructor of this class. + */ + class TransfDrop : public Transf + { + public: + TransfDrop(const std::vector<unsigned>& indices) : d_drop_indices(indices) + { + } + /** build type */ + void buildType(SygusGrammarNorm* sygus_norm, + TypeObject& to, + const Datatype& dt, + std::vector<unsigned>& op_pos) override; + + private: + std::vector<unsigned> d_drop_indices; + }; + /** Chain transformation class * * Determines how to build normalized types by chaining the application of one @@ -275,7 +307,7 @@ class SygusGrammarNorm class TransfChain : public Transf { public: - TransfChain(unsigned chain_op_pos, std::vector<unsigned> elem_pos) + TransfChain(unsigned chain_op_pos, const std::vector<unsigned>& elem_pos) : d_chain_op_pos(chain_op_pos), d_elem_pos(elem_pos){}; /** builds types encoding a chain in which each link contains a repetition @@ -303,10 +335,10 @@ class SygusGrammarNorm * transformation and so on until all operators originally given are * considered. */ - virtual void buildType(SygusGrammarNorm* sygus_norm, - TypeObject& to, - const Datatype& dt, - std::vector<unsigned>& op_pos) override; + void buildType(SygusGrammarNorm* sygus_norm, + TypeObject& to, + const Datatype& dt, + std::vector<unsigned>& op_pos) override; /** Whether operator is chainable for the type (e.g. PLUS for Int) * @@ -411,9 +443,9 @@ class SygusGrammarNorm * * TODO: #1304: Infer more complex transformations */ - Transf* inferTransf(TypeNode tn, - const Datatype& dt, - const std::vector<unsigned>& op_pos); + std::unique_ptr<Transf> inferTransf(TypeNode tn, + const Datatype& dt, + const std::vector<unsigned>& op_pos); }; /* class SygusGrammarNorm */ } // namespace quantifiers diff --git a/src/theory/quantifiers/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus_grammar_red.cpp new file mode 100644 index 000000000..056fc455a --- /dev/null +++ b/src/theory/quantifiers/sygus_grammar_red.cpp @@ -0,0 +1,136 @@ +/********************* */ +/*! \file sygus_grammar_red.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of sygus_grammar_red + **/ + +#include "theory/quantifiers/sygus_grammar_red.h" + +#include "options/quantifiers_options.h" +#include "theory/quantifiers/term_database_sygus.h" +#include "theory/quantifiers/term_util.h" + +using namespace std; +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +void SygusRedundantCons::initialize(QuantifiersEngine* qe, TypeNode tn) +{ + Assert(qe != nullptr); + Trace("sygus-red") << "Compute redundant cons for " << tn << std::endl; + d_type = tn; + Assert(tn.isDatatype()); + TermDbSygus* tds = qe->getTermDatabaseSygus(); + tds->registerSygusType(tn); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Assert(dt.isSygus()); + TypeNode btn = TypeNode::fromType(dt.getSygusType()); + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + Trace("sygus-red") << " Is " << dt[i].getName() << " a redundant operator?" + << std::endl; + std::map<int, Node> pre; + Node g = tds->mkGeneric(dt, i, pre); + Trace("sygus-red-debug") << " ...pre-rewrite : " << g << std::endl; + Assert(g.getNumChildren() == dt[i].getNumArgs()); + d_gen_terms[i] = g; + for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) + { + pre[j] = g[j]; + } + std::vector<Node> glist; + getGenericList(tds, dt, i, 0, pre, glist); + // call the extended rewriter + bool red = false; + for (const Node& gr : glist) + { + Trace("sygus-red-debug") << " ...variant : " << gr << std::endl; + std::map<Node, unsigned>::iterator itg = d_gen_cons.find(gr); + if (itg != d_gen_cons.end() && itg->second != i) + { + red = true; + Trace("sygus-red") << " ......redundant, since a variant of " << g + << " and " << d_gen_terms[itg->second] + << " both rewrite to " << gr << std::endl; + break; + } + else + { + d_gen_cons[gr] = i; + Trace("sygus-red") << " ......not redundant." << std::endl; + } + } + d_sygus_red_status.push_back(red ? 1 : 0); + } +} + +void SygusRedundantCons::getRedundant(std::vector<unsigned>& indices) +{ + const Datatype& dt = static_cast<DatatypeType>(d_type.toType()).getDatatype(); + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + if (isRedundant(i)) + { + indices.push_back(i); + } + } +} + +bool SygusRedundantCons::isRedundant(unsigned i) +{ + Assert(i < d_sygus_red_status.size()); + return d_sygus_red_status[i] == 1; +} + +void SygusRedundantCons::getGenericList(TermDbSygus* tds, + const Datatype& dt, + unsigned c, + unsigned index, + std::map<int, Node>& pre, + std::vector<Node>& terms) +{ + if (index == dt[c].getNumArgs()) + { + Node gt = tds->mkGeneric(dt, c, pre); + gt = tds->getExtRewriter()->extendedRewrite(gt); + terms.push_back(gt); + return; + } + // with no swap + getGenericList(tds, dt, c, index + 1, pre, terms); + // swapping is exponential, only use for operators with small # args. + if (dt[c].getNumArgs() <= 5) + { + TypeNode atype = tds->getArgType(dt[c], index); + for (unsigned s = index + 1, nargs = dt[c].getNumArgs(); s < nargs; s++) + { + if (tds->getArgType(dt[c], s) == atype) + { + // swap s and index + Node tmp = pre[s]; + pre[s] = pre[index]; + pre[index] = tmp; + getGenericList(tds, dt, c, index + 1, pre, terms); + // revert + tmp = pre[s]; + pre[s] = pre[index]; + pre[index] = tmp; + } + } + } +} + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_grammar_red.h b/src/theory/quantifiers/sygus_grammar_red.h new file mode 100644 index 000000000..b65a12da2 --- /dev/null +++ b/src/theory/quantifiers/sygus_grammar_red.h @@ -0,0 +1,119 @@ +/********************* */ +/*! \file sygus_grammar_red.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief sygus_grammar_red + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__INSTANTIATE_H +#define __CVC4__THEORY__QUANTIFIERS__INSTANTIATE_H + +#include <map> +#include "theory/quantifiers_engine.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** SygusRedundantCons + * + * This class computes the subset of indices of the constructors of a sygus type + * that are redundant. To use this class, first call initialize( qe, tn ), + * where tn is a sygus tn. Then, use getRedundant and/or isRedundant to get the + * indicies of the constructors of tn that are redundant. + */ +class SygusRedundantCons +{ + public: + SygusRedundantCons() {} + ~SygusRedundantCons() {} + /** register type tn + * + * qe : pointer to the quantifiers engine, + * tn : the (sygus) type to compute redundant constructors for + */ + void initialize(QuantifiersEngine* qe, TypeNode tn); + /** Get the indices of the redundant constructors of the register type */ + void getRedundant(std::vector<unsigned>& indices); + /** + * This function returns true if the i^th constructor of the registered type + * is redundant. + */ + bool isRedundant(unsigned i); + + private: + /** the registered type */ + TypeNode d_type; + /** redundant status + * + * For each constructor, status indicating whether the constructor is + * redundant, where: + * + * 0 : not redundant, + * 1 : redundant since another constructor can be used to construct values for + * this constructor. + * + * For example, for grammar: + * A -> C > B | B < C | not D + * B -> x | y + * C -> 0 | 1 | C+C + * D -> B >= C + * If A is register with this class, then we store may store { 0, 1, 0 }, + * noting that the second constructor of A can be simulated with the first. + * Notice that the third constructor is not considered redundant. + */ + std::vector<int> d_sygus_red_status; + /** + * Map from constructor indices to the generic term for that constructor, + * where the generic term for a constructor is the (canonical) term returned + * by a call to TermDbSygus::mkGeneric. + */ + std::map<unsigned, Node> d_gen_terms; + /** + * Map from the rewritten form of generic terms for constructors of the + * registered type to their corresponding constructor index. + */ + std::map<Node, unsigned> d_gen_cons; + /** get generic list + * + * This function constructs all well-typed variants of a term of the form + * op( x1, ..., xn ) + * where op is the builtin operator for dt[c], and xi = pre[i] for i=1,...,n. + * + * It constructs a list of terms of the form g * sigma, where sigma + * is an automorphism on { x1...xn } such that for all xi -> xj in sigma, + * the type for arguments i and j of dt[c] are the same. We store this + * list of terms in terms. + * + * This function recurses on the arguments of g, index is the current argument + * we are processing, and pre stores the current arguments of + * + * For example, for a sygus grammar + * A -> and( A, A, B ) + * B -> false + * passing arguments such that g=and( x1, x2, x3 ) to this function will add: + * and( x1, x2, x3 ) and and( x2, x1, x3 ) + * to terms. + */ + void getGenericList(TermDbSygus* tds, + const Datatype& dt, + unsigned c, + unsigned index, + std::map<int, Node>& pre, + std::vector<Node>& terms); +}; + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ + +#endif /* __CVC4__THEORY__QUANTIFIERS__INSTANTIATE_H */ diff --git a/src/theory/quantifiers/sygus_invariance.cpp b/src/theory/quantifiers/sygus_invariance.cpp index 6813f4320..1fd6bc7cb 100644 --- a/src/theory/quantifiers/sygus_invariance.cpp +++ b/src/theory/quantifiers/sygus_invariance.cpp @@ -191,6 +191,7 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, Node out = d_exo[ii]; Node cont = NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, out, nbvre); + Trace("sygus-pbe-cterm-debug") << "Check: " << cont << std::endl; Node contr = Rewriter::rewrite(cont); if (contr == tds->d_false) { @@ -216,6 +217,8 @@ bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, } return true; } + Trace("sygus-pbe-cterm-debug2") + << "...check failed, rewrites to : " << contr << std::endl; } } return false; diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp new file mode 100644 index 000000000..f824cd6f7 --- /dev/null +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -0,0 +1,567 @@ +/********************* */ +/*! \file sygus_sampler.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of sygus_sampler + **/ + +#include "theory/quantifiers/sygus_sampler.h" + +#include "options/quantifiers_options.h" +#include "util/bitvector.h" +#include "util/random.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +Node LazyTrie::add(Node n, + LazyTrieEvaluator* ev, + unsigned index, + unsigned ntotal, + bool forceKeep) +{ + LazyTrie* lt = this; + while (lt != NULL) + { + if (index == ntotal) + { + // lazy child holds the leaf data + if (lt->d_lazy_child.isNull() || forceKeep) + { + lt->d_lazy_child = n; + } + return lt->d_lazy_child; + } + std::vector<Node> ex; + if (lt->d_children.empty()) + { + if (lt->d_lazy_child.isNull()) + { + // no one has been here, we are done + lt->d_lazy_child = n; + return lt->d_lazy_child; + } + // evaluate the lazy child + Node e_lc = ev->evaluate(lt->d_lazy_child, index); + // store at next level + lt->d_children[e_lc].d_lazy_child = lt->d_lazy_child; + // replace + lt->d_lazy_child = Node::null(); + } + // recurse + Node e = ev->evaluate(n, index); + lt = <->d_children[e]; + index = index + 1; + } + return Node::null(); +} + +SygusSampler::SygusSampler() : d_tds(nullptr), d_is_valid(false) {} + +void SygusSampler::initialize(TypeNode tn, + std::vector<Node>& vars, + unsigned nsamples) +{ + d_tds = nullptr; + d_is_valid = true; + d_tn = tn; + d_ftn = TypeNode::null(); + d_vars.insert(d_vars.end(), vars.begin(), vars.end()); + for (const Node& sv : vars) + { + TypeNode svt = sv.getType(); + d_var_index[sv] = d_type_vars[svt].size(); + d_type_vars[svt].push_back(sv); + } + d_rvalue_cindices.clear(); + d_rvalue_null_cindices.clear(); + d_var_sygus_types.clear(); + initializeSamples(nsamples); +} + +void SygusSampler::initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples) +{ + d_tds = tds; + d_is_valid = true; + d_ftn = f.getType(); + Assert(d_ftn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(d_ftn.toType()).getDatatype(); + Assert(dt.isSygus()); + d_tn = TypeNode::fromType(dt.getSygusType()); + + Trace("sygus-sample") << "Register sampler for " << f << std::endl; + + d_var_index.clear(); + d_type_vars.clear(); + // get the sygus variable list + Node var_list = Node::fromExpr(dt.getSygusVarList()); + if (!var_list.isNull()) + { + for (const Node& sv : var_list) + { + TypeNode svt = sv.getType(); + d_var_index[sv] = d_type_vars[svt].size(); + d_vars.push_back(sv); + d_type_vars[svt].push_back(sv); + } + } + d_rvalue_cindices.clear(); + d_rvalue_null_cindices.clear(); + d_var_sygus_types.clear(); + registerSygusType(d_ftn); + initializeSamples(nsamples); +} + +void SygusSampler::initializeSamples(unsigned nsamples) +{ + d_samples.clear(); + std::vector<TypeNode> types; + for (const Node& v : d_vars) + { + TypeNode vt = v.getType(); + types.push_back(vt); + Trace("sygus-sample") << " var #" << types.size() << " : " << v << " : " + << vt << std::endl; + } + std::map<unsigned, std::map<Node, std::vector<TypeNode> >::iterator> sts; + if (options::sygusSampleGrammar()) + { + for (unsigned j = 0, size = types.size(); j < size; j++) + { + sts[j] = d_var_sygus_types.find(d_vars[j]); + } + } + + unsigned nduplicates = 0; + for (unsigned i = 0; i < nsamples; i++) + { + std::vector<Node> sample_pt; + for (unsigned j = 0, size = types.size(); j < size; j++) + { + Node v = d_vars[j]; + Node r; + if (options::sygusSampleGrammar()) + { + // choose a random start sygus type, if possible + if (sts[j] != d_var_sygus_types.end()) + { + unsigned ntypes = sts[j]->second.size(); + Assert(ntypes > 0); + unsigned index = Random::getRandom().pick(0, ntypes - 1); + if (index < ntypes) + { + // currently hard coded to 0.0, 0.5 + r = getSygusRandomValue(sts[j]->second[index], 0.0, 0.5); + } + } + } + if (r.isNull()) + { + r = getRandomValue(types[j]); + if (r.isNull()) + { + d_is_valid = false; + } + } + sample_pt.push_back(r); + } + if (d_samples_trie.add(sample_pt)) + { + if (Trace.isOn("sygus-sample")) + { + Trace("sygus-sample") << "Sample point #" << i << " : "; + for (const Node& r : sample_pt) + { + Trace("sygus-sample") << r << " "; + } + Trace("sygus-sample") << std::endl; + } + d_samples.push_back(sample_pt); + } + else + { + i--; + nduplicates++; + if (nduplicates == nsamples * 10) + { + Trace("sygus-sample") + << "...WARNING: excessive duplicates, cut off sampling at " << i + << "/" << nsamples << " points." << std::endl; + break; + } + } + } + + d_trie.clear(); +} + +bool SygusSampler::PtTrie::add(std::vector<Node>& pt) +{ + PtTrie* curr = this; + for (unsigned i = 0, size = pt.size(); i < size; i++) + { + curr = &(curr->d_children[pt[i]]); + } + bool retVal = curr->d_children.empty(); + curr = &(curr->d_children[Node::null()]); + return retVal; +} + +Node SygusSampler::registerTerm(Node n, bool forceKeep) +{ + if (d_is_valid) + { + Assert(n.getType() == d_tn); + return d_trie.add(n, this, 0, d_samples.size(), forceKeep); + } + return n; +} + +bool SygusSampler::isContiguous(Node n) +{ + // compute free variables in n + std::vector<Node> fvs; + computeFreeVariables(n, fvs); + // compute contiguous condition + for (const std::pair<const TypeNode, std::vector<Node> >& p : d_type_vars) + { + bool foundNotFv = false; + for (const Node& v : p.second) + { + bool hasFv = std::find(fvs.begin(), fvs.end(), v) != fvs.end(); + if (!hasFv) + { + foundNotFv = true; + } + else if (foundNotFv) + { + return false; + } + } + } + return true; +} + +void SygusSampler::computeFreeVariables(Node n, std::vector<Node>& fvs) +{ + std::unordered_set<TNode, TNodeHashFunction> visited; + std::unordered_set<TNode, TNodeHashFunction>::iterator it; + std::vector<TNode> visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + if (visited.find(cur) == visited.end()) + { + visited.insert(cur); + if (cur.isVar()) + { + if (d_var_index.find(cur) != d_var_index.end()) + { + fvs.push_back(cur); + } + } + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + } while (!visit.empty()); +} + +bool SygusSampler::isOrdered(Node n) +{ + // compute free variables in n for each type + std::map<TypeNode, std::vector<Node> > fvs; + + std::unordered_set<TNode, TNodeHashFunction> visited; + std::unordered_set<TNode, TNodeHashFunction>::iterator it; + std::vector<TNode> visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + if (visited.find(cur) == visited.end()) + { + visited.insert(cur); + if (cur.isVar()) + { + std::map<Node, unsigned>::iterator itv = d_var_index.find(cur); + if (itv != d_var_index.end()) + { + TypeNode tn = cur.getType(); + // if this variable is out of order + if (itv->second != fvs[tn].size()) + { + return false; + } + fvs[tn].push_back(cur); + } + } + for (unsigned j = 0, nchildren = cur.getNumChildren(); j < nchildren; j++) + { + visit.push_back(cur[(nchildren - j) - 1]); + } + } + } while (!visit.empty()); + return true; +} + +bool SygusSampler::containsFreeVariables(Node a, Node b) +{ + // compute free variables in a + std::vector<Node> fvs; + computeFreeVariables(a, fvs); + + std::unordered_set<TNode, TNodeHashFunction> visited; + std::unordered_set<TNode, TNodeHashFunction>::iterator it; + std::vector<TNode> visit; + TNode cur; + visit.push_back(b); + do + { + cur = visit.back(); + visit.pop_back(); + if (visited.find(cur) == visited.end()) + { + visited.insert(cur); + if (cur.isVar()) + { + if (std::find(fvs.begin(), fvs.end(), cur) == fvs.end()) + { + return false; + } + } + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + } while (!visit.empty()); + return true; +} + +void SygusSampler::getSamplePoint(unsigned index, std::vector<Node>& pt) +{ + Assert(index < d_samples.size()); + std::vector<Node>& spt = d_samples[index]; + pt.insert(pt.end(), spt.begin(), spt.end()); +} + +Node SygusSampler::evaluate(Node n, unsigned index) +{ + Assert(index < d_samples.size()); + // just a substitution + std::vector<Node>& pt = d_samples[index]; + Node ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end()); + ev = Rewriter::rewrite(ev); + Trace("sygus-sample-ev") << "( " << n << ", " << index << " ) -> " << ev + << std::endl; + return ev; +} + +Node SygusSampler::getRandomValue(TypeNode tn) +{ + NodeManager* nm = NodeManager::currentNM(); + if (tn.isBoolean()) + { + return nm->mkConst(Random::getRandom().pickWithProb(0.5)); + } + else if (tn.isBitVector()) + { + unsigned sz = tn.getBitVectorSize(); + std::stringstream ss; + for (unsigned i = 0; i < sz; i++) + { + ss << (Random::getRandom().pickWithProb(0.5) ? "1" : "0"); + } + return nm->mkConst(BitVector(ss.str(), 2)); + } + else if (tn.isString() || tn.isInteger()) + { + std::vector<unsigned> vec; + double ext_freq = .5; + unsigned base = 10; + while (Random::getRandom().pickWithProb(ext_freq)) + { + // add a digit + vec.push_back(Random::getRandom().pick(0, base)); + } + if (tn.isString()) + { + return nm->mkConst(String(vec)); + } + else if (tn.isInteger()) + { + Rational baser(base); + Rational curr(1); + std::vector<Node> sum; + for (unsigned j = 0, size = vec.size(); j < size; j++) + { + Node digit = nm->mkConst(Rational(vec[j]) * curr); + sum.push_back(digit); + curr = curr * baser; + } + Node ret; + if (sum.empty()) + { + ret = nm->mkConst(Rational(0)); + } + else if (sum.size() == 1) + { + ret = sum[0]; + } + else + { + ret = nm->mkNode(kind::PLUS, sum); + } + + if (Random::getRandom().pickWithProb(0.5)) + { + // negative + ret = nm->mkNode(kind::UMINUS, ret); + } + ret = Rewriter::rewrite(ret); + Assert(ret.isConst()); + return ret; + } + } + else if (tn.isReal()) + { + Node s = getRandomValue(nm->integerType()); + Node r = getRandomValue(nm->integerType()); + if (!s.isNull() && !r.isNull()) + { + Rational sr = s.getConst<Rational>(); + Rational rr = s.getConst<Rational>(); + if (rr.sgn() == 0) + { + return s; + } + else + { + return nm->mkConst(sr / rr); + } + } + } + return Node::null(); +} + +Node SygusSampler::getSygusRandomValue(TypeNode tn, + double rchance, + double rinc, + unsigned depth) +{ + Assert(tn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Assert(dt.isSygus()); + Assert(d_rvalue_cindices.find(tn) != d_rvalue_cindices.end()); + Trace("sygus-sample-grammar") + << "Sygus random value " << tn << ", depth = " << depth + << ", rchance = " << rchance << std::endl; + // check if we terminate on this call + // we refuse to enumerate terms of 10+ depth as a hard limit + bool terminate = Random::getRandom().pickWithProb(rchance) || depth >= 10; + // if we terminate, only nullary constructors can be chosen + std::vector<unsigned>& cindices = + terminate ? d_rvalue_null_cindices[tn] : d_rvalue_cindices[tn]; + unsigned ncons = cindices.size(); + // select a random constructor, or random value when index=ncons. + unsigned index = Random::getRandom().pick(0, ncons); + Trace("sygus-sample-grammar") + << "Random index 0..." << ncons << " was : " << index << std::endl; + if (index < ncons) + { + Trace("sygus-sample-grammar") + << "Recurse constructor index #" << index << std::endl; + unsigned cindex = cindices[index]; + Assert(cindex < dt.getNumConstructors()); + const DatatypeConstructor& dtc = dt[cindex]; + // more likely to terminate in recursive calls + double rchance_new = rchance + (1.0 - rchance) * rinc; + std::map<int, Node> pre; + bool success = true; + // generate random values for all arguments + for (unsigned i = 0, nargs = dtc.getNumArgs(); i < nargs; i++) + { + TypeNode tnc = d_tds->getArgType(dtc, i); + Node c = getSygusRandomValue(tnc, rchance_new, rinc, depth + 1); + if (c.isNull()) + { + success = false; + Trace("sygus-sample-grammar") << "...fail." << std::endl; + break; + } + Trace("sygus-sample-grammar") + << " child #" << i << " : " << c << std::endl; + pre[i] = c; + } + if (success) + { + Trace("sygus-sample-grammar") << "mkGeneric" << std::endl; + Node ret = d_tds->mkGeneric(dt, cindex, pre); + Trace("sygus-sample-grammar") << "...returned " << ret << std::endl; + ret = Rewriter::rewrite(ret); + Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl; + Assert(ret.isConst()); + return ret; + } + } + Trace("sygus-sample-grammar") << "...resort to random value" << std::endl; + // if we did not generate based on the grammar, pick a random value + return getRandomValue(TypeNode::fromType(dt.getSygusType())); +} + +// recursion depth bounded by number of types in grammar (small) +void SygusSampler::registerSygusType(TypeNode tn) +{ + if (d_rvalue_cindices.find(tn) == d_rvalue_cindices.end()) + { + d_rvalue_cindices[tn].clear(); + Assert(tn.isDatatype()); + const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + Assert(dt.isSygus()); + for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) + { + const DatatypeConstructor& dtc = dt[i]; + Node sop = Node::fromExpr(dtc.getSygusOp()); + bool isVar = std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end(); + if (isVar) + { + // if it is a variable, add it to the list of sygus types for that var + d_var_sygus_types[sop].push_back(tn); + } + else + { + // otherwise, it is a constructor for sygus random value + d_rvalue_cindices[tn].push_back(i); + if (dtc.getNumArgs() == 0) + { + d_rvalue_null_cindices[tn].push_back(i); + } + } + // recurse on all subfields + for (unsigned j = 0, nargs = dtc.getNumArgs(); j < nargs; j++) + { + TypeNode tnc = d_tds->getArgType(dtc, j); + registerSygusType(tnc); + } + } + } +} + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h new file mode 100644 index 000000000..02b60d155 --- /dev/null +++ b/src/theory/quantifiers/sygus_sampler.h @@ -0,0 +1,307 @@ +/********************* */ +/*! \file sygus_sampler.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief sygus_sampler + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_SAMPLER_H +#define __CVC4__THEORY__QUANTIFIERS__SYGUS_SAMPLER_H + +#include <map> +#include "theory/quantifiers/term_database_sygus.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** abstract evaluator class + * + * This class is used for the LazyTrie data structure below. + */ +class LazyTrieEvaluator +{ + public: + virtual Node evaluate(Node n, unsigned index) = 0; +}; + +/** LazyTrie + * + * This is a trie where terms are added in a lazy fashion. This data structure + * is useful, for instance, when we are only interested in when two terms + * map to the same node in the trie but we need not care about computing + * exactly where they are. + * + * In other words, when a term n is added to this trie, we do not insist + * that n is placed at the maximal depth of the trie. Instead, we place n at a + * minimal depth node that has no children. In this case we say n is partially + * evaluated in this trie. + * + * This class relies on an abstract evaluator interface above, which evaluates + * nodes for indices. + * + * For example, say we have terms a, b, c and an evaluator ev where: + * ev->evaluate( a, 0,1,2 ) = 0, 5, 6 + * ev->evaluate( b, 0,1,2 ) = 1, 3, 0 + * ev->evaluate( c, 0,1,2 ) = 1, 3, 2 + * After adding a to the trie, we get: + * root: a + * After adding b to the resulting trie, we get: + * root: null + * d_children[0]: a + * d_children[1]: b + * After adding c to the resulting trie, we get: + * root: null + * d_children[0]: a + * d_children[1]: null + * d_children[3]: null + * d_children[0] : b + * d_children[2] : c + * Notice that we need not call ev->evalute( a, 1 ) and ev->evalute( a, 2 ). + */ +class LazyTrie +{ + public: + LazyTrie() {} + ~LazyTrie() {} + /** the data at this node, which may be partially evaluated */ + Node d_lazy_child; + /** the children of this node */ + std::map<Node, LazyTrie> d_children; + /** clear the trie */ + void clear() { d_children.clear(); } + /** add n to the trie + * + * This function returns a node that is mapped to the same node in the trie + * if one exists, or n otherwise. + * + * ev is an evaluator which determines where n is placed in the trie + * index is the depth of this node + * ntotal is the maximal depth of the trie + * forceKeep is whether we wish to force that n is chosen as a representative + */ + Node add(Node n, + LazyTrieEvaluator* ev, + unsigned index, + unsigned ntotal, + bool forceKeep); +}; + +/** SygusSampler + * + * This class can be used to test whether two expressions are equivalent + * by random sampling. We use this class for the following options: + * + * sygus-rr-synth: synthesize candidate rewrite rules by finding two terms + * t1 and t2 do not rewrite to the same term, but are identical when considering + * a set of sample points, and + * + * sygus-rr-verify: detect unsound rewrite rules by finding two terms t1 and + * t2 that rewrite to the same term, but not identical when considering a set + * of sample points. + * + * To use this class: + * (1) Call initialize( tds, f, nsamples) where f is sygus datatype term. + * (2) For terms n1....nm enumerated that correspond to builtin analog of sygus + * term f, we call registerTerm( n1 )....registerTerm( nm ). It may be the case + * that registerTerm( ni ) returns nj for some j < i. In this case, we have that + * ni and nj are equivalent under the sample points in this class. + * + * For example, say the grammar for f is: + * A = 0 | 1 | x | y | A+A | ite( B, A, A ) + * B = A <= A + * If we call intialize( tds, f, 5 ), this class will generate 5 random sample + * points for (x,y), say (0,0), (1,1), (0,1), (1,0), (2,2). The return values + * of successive calls to registerTerm are listed below. + * registerTerm( 0 ) -> 0 + * registerTerm( x ) -> x + * registerTerm( x+y ) -> x+y + * registerTerm( y+x ) -> x+y + * registerTerm( x+ite(x <= 1+1, 0, y ) ) -> x + * Notice that the number of sample points can be configured for the above + * options using sygus-samples=N. + */ +class SygusSampler : public LazyTrieEvaluator +{ + public: + SygusSampler(); + virtual ~SygusSampler() {} + /** initialize + * + * tn : the return type of terms we will be testing with this class + * vars : the variables we are testing substitutions for + * nsamples : number of sample points this class will test. + */ + void initialize(TypeNode tn, std::vector<Node>& vars, unsigned nsamples); + /** initialize sygus + * + * tds : pointer to sygus database, + * f : a term of some SyGuS datatype type whose (builtin) values we will be + * testing under the free variables in the grammar of f, + * nsamples : number of sample points this class will test. + */ + void initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples); + /** register term n with this sampler database + * + * forceKeep is whether we wish to force that n is chosen as a representative + * value in the trie. + */ + Node registerTerm(Node n, bool forceKeep = false); + /** is contiguous + * + * This returns whether n's free variables (terms occurring in the range of + * d_type_vars) are a prefix of the list of variables in d_type_vars for each + * type. For instance, if d_type_vars[Int] = { x, y }, then 0, x, x+y, y+x are + * contiguous but y is not. This is useful for excluding terms from + * consideration that are alpha-equivalent to others. + */ + bool isContiguous(Node n); + /** is ordered + * + * This returns whether n's free variables are in order with respect to + * variables in d_type_vars for each type. For instance, if + * d_type_vars[Int] = { x, y }, then 0, x, x+y are ordered but y and y+x + * are not. + */ + bool isOrdered(Node n); + /** contains free variables + * + * Returns true if all free variables of a are contained in b. Free variables + * are those that occur in the range d_type_vars. + */ + bool containsFreeVariables(Node a, Node b); + /** get number of sample points */ + unsigned getNumSamplePoints() const { return d_samples.size(); } + /** get sample point + * + * Appends sample point #index to the vector pt. + */ + void getSamplePoint(unsigned index, std::vector<Node>& pt); + /** evaluate n on sample point index */ + Node evaluate(Node n, unsigned index); + + private: + /** sygus term database of d_qe */ + TermDbSygus* d_tds; + /** samples */ + std::vector<std::vector<Node> > d_samples; + /** data structure to check duplication of sample points */ + class PtTrie + { + public: + /** add pt to this trie, returns true if pt is not a duplicate. */ + bool add(std::vector<Node>& pt); + + private: + /** the children of this node */ + std::map<Node, PtTrie> d_children; + }; + /** a trie for samples */ + PtTrie d_samples_trie; + /** type of nodes we will be registering with this class */ + TypeNode d_tn; + /** the sygus type for this sampler (if applicable). */ + TypeNode d_ftn; + /** all variables */ + std::vector<Node> d_vars; + /** type variables + * + * For each type, a list of variables in the grammar we are considering, for + * that type. These typically correspond to the arguments of the + * function-to-synthesize whose grammar we are considering. + */ + std::map<TypeNode, std::vector<Node> > d_type_vars; + /** + * A map all variables in the grammar we are considering to their index in + * d_type_vars. + */ + std::map<Node, unsigned> d_var_index; + /** constants + * + * For each type, a list of constants in the grammar we are considering, for + * that type. + */ + std::map<TypeNode, std::vector<Node> > d_type_consts; + /** the lazy trie */ + LazyTrie d_trie; + /** is this sampler valid? + * + * A sampler can be invalid if sample points cannot be generated for a type + * of an argument to function f. + */ + bool d_is_valid; + /** + * Compute the variables from the domain of d_var_index that occur in n, + * store these in the vector fvs. + */ + void computeFreeVariables(Node n, std::vector<Node>& fvs); + /** initialize samples + * + * Adds nsamples sample points to d_samples. + */ + void initializeSamples(unsigned nsamples); + /** get random value for a type + * + * Returns a random value for the given type based on the random number + * generator. Currently, supported types: + * + * Bool, Bitvector : returns a random value in the range of that type. + * Int, String : returns a random string of values in (base 10) of random + * length, currently by a repeated coin flip. + * Real : returns the division of two random integers, where the denominator + * is omitted if it is zero. + */ + Node getRandomValue(TypeNode tn); + /** get sygus random value + * + * Returns a random value based on the sygus type tn. The return value is + * a constant in the analog type of tn. This function chooses either to + * return a random value, or otherwise will construct a constant based on + * a random constructor of tn whose builtin operator is not a variable. + * + * rchance: the chance that the call to this function will be forbidden + * from making recursive calls and instead must return a value based on + * a nullary constructor of tn or based on getRandomValue above. + * rinc: the percentage to increment rchance on recursive calls. + * + * For example, consider the grammar: + * A -> x | y | 0 | 1 | +( A, A ) | ite( B, A, A ) + * B -> A = A + * If we call this function on A and rchance is 0.0, there are five evenly + * chosen possibilities, either we return a random value via getRandomValue + * above, or we choose one of the four non-variable constructors of A. + * Say we choose ite, then we recursively call this function for + * B, A, and A, which return constants c1, c2, and c3. Then, this function + * returns the rewritten form of ite( c1, c2, c3 ). + * If on the other hand, rchance was 0.5 and rand() < 0.5. Then, we force + * this call to terminate by either selecting a random value via + * getRandomValue, 0 or 1. + */ + Node getSygusRandomValue(TypeNode tn, + double rchance, + double rinc, + unsigned depth = 0); + /** map from sygus types to non-variable constructors */ + std::map<TypeNode, std::vector<unsigned> > d_rvalue_cindices; + /** map from sygus types to non-variable nullary constructors */ + std::map<TypeNode, std::vector<unsigned> > d_rvalue_null_cindices; + /** map from variables to sygus types that include them */ + std::map<Node, std::vector<TypeNode> > d_var_sygus_types; + /** register sygus type, intializes the above two data structures */ + void registerSygusType(TypeNode tn); +}; + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ + +#endif /* __CVC4__THEORY__QUANTIFIERS__SYGUS_SAMPLER_H */ diff --git a/src/theory/quantifiers/term_database_sygus.cpp b/src/theory/quantifiers/term_database_sygus.cpp index 8b1ff37f1..cda652ee7 100644 --- a/src/theory/quantifiers/term_database_sygus.cpp +++ b/src/theory/quantifiers/term_database_sygus.cpp @@ -109,129 +109,12 @@ TypeNode TermDbSygus::getSygusTypeForVar( Node v ) { return d_fv_stype[v]; } -bool TermDbSygus::getMatch( Node p, Node n, std::map< int, Node >& s ) { - std::vector< int > new_s; - return getMatch2( p, n, s, new_s ); -} - -bool TermDbSygus::getMatch2( Node p, Node n, std::map< int, Node >& s, std::vector< int >& new_s ) { - std::map< Node, int >::iterator it = d_fv_num.find( p ); - if( it!=d_fv_num.end() ){ - Node prev = s[it->second]; - s[it->second] = n; - if( prev.isNull() ){ - new_s.push_back( it->second ); - } - return prev.isNull() || prev==n; - }else if( n.getNumChildren()==0 ){ - return p==n; - }else if( n.getKind()==p.getKind() && n.getNumChildren()==p.getNumChildren() ){ - //try both ways? - unsigned rmax = TermUtil::isComm( n.getKind() ) && n.getNumChildren()==2 ? 2 : 1; - std::vector< int > new_tmp; - for( unsigned r=0; r<rmax; r++ ){ - bool success = true; - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - int io = r==0 ? i : ( i==0 ? 1 : 0 ); - if( !getMatch2( p[i], n[io], s, new_tmp ) ){ - success = false; - for( unsigned j=0; j<new_tmp.size(); j++ ){ - s.erase( new_tmp[j] ); - } - new_tmp.clear(); - break; - } - } - if( success ){ - new_s.insert( new_s.end(), new_tmp.begin(), new_tmp.end() ); - return true; - } - } - } - return false; -} - -bool TermDbSygus::getMatch( Node t, TypeNode st, int& index_found, std::vector< Node >& args, int index_exc, int index_start ) { - Assert( st.isDatatype() ); - const Datatype& dt = ((DatatypeType)(st).toType()).getDatatype(); - Assert( dt.isSygus() ); - std::map< Kind, std::vector< Node > > kgens; - std::vector< Node > gens; - for( unsigned i=index_start; i<dt.getNumConstructors(); i++ ){ - if( (int)i!=index_exc ){ - Node g = getGenericBase( st, dt, i ); - gens.push_back( g ); - kgens[g.getKind()].push_back( g ); - Trace("sygus-db-debug") << "Check generic base : " << g << " from " << dt[i].getName() << std::endl; - if( g.getKind()==t.getKind() ){ - Trace("sygus-db-debug") << "Possible match ? " << g << " " << t << " for " << dt[i].getName() << std::endl; - std::map< int, Node > sigma; - if( getMatch( g, t, sigma ) ){ - //we found an exact match - bool msuccess = true; - for( unsigned j=0; j<dt[i].getNumArgs(); j++ ){ - if( sigma[j].isNull() ){ - msuccess = false; - break; - }else{ - args.push_back( sigma[j] ); - } - } - if( msuccess ){ - index_found = i; - return true; - } - //we found an exact match - //std::map< TypeNode, int > var_count; - //Node new_t = mkGeneric( dt, i, var_count, args ); - //Trace("sygus-db-debug") << "Rewrote to : " << new_t << std::endl; - //return new_t; - } - } - } - } - /* - //otherwise, try to modulate based on kinds - for( std::map< Kind, std::vector< Node > >::iterator it = kgens.begin(); it != kgens.end(); ++it ){ - if( it->second.size()>1 ){ - for( unsigned i=0; i<it->second.size(); i++ ){ - for( unsigned j=0; j<it->second.size(); j++ ){ - if( i!=j ){ - std::map< int, Node > sigma; - if( getMatch( it->second[i], it->second[j], sigma ) ){ - if( sigma.size()==1 ){ - //Node mod_pat = sigma.begin().second; - //Trace("cegqi-si-rcons-debug") << "Modulated pattern " << mod_pat << " from " << it->second[i] << " and " << it->second[j] << std::endl; - } - } - } - } - } - } - } - */ - return false; -} - -Node TermDbSygus::getGenericBase( TypeNode tn, const Datatype& dt, int c ) { - std::map< int, Node >::iterator it = d_generic_base[tn].find( c ); - if( it==d_generic_base[tn].end() ){ - Assert( isRegistered( tn ) ); - std::map< TypeNode, int > var_count; - std::map< int, Node > pre; - Node g = mkGeneric( dt, c, var_count, pre ); - Trace("sygus-db-debug") << "Sygus DB : Generic is " << g << std::endl; - Node gr = Rewriter::rewrite( g ); - Trace("sygus-db-debug") << "Sygus DB : Generic rewritten is " << gr << std::endl; - d_generic_base[tn][c] = gr; - return gr; - }else{ - return it->second; - } -} - -Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ) { - Assert( c>=0 && c<(int)dt.getNumConstructors() ); +Node TermDbSygus::mkGeneric(const Datatype& dt, + unsigned c, + std::map<TypeNode, int>& var_count, + std::map<int, Node>& pre) +{ + Assert(c < dt.getNumConstructors()); Assert( dt.isSygus() ); Assert( !dt[c].getSygusOp().isNull() ); std::vector< Node > children; @@ -240,7 +123,8 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int children.push_back( op ); } Trace("sygus-db-debug") << "mkGeneric " << dt.getName() << " " << op << " " << op.getKind() << "..." << std::endl; - for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){ + for (unsigned i = 0, nargs = dt[c].getNumArgs(); i < nargs; i++) + { TypeNode tna = getArgType( dt[c], i ); Node a; std::map< int, Node >::iterator it = pre.find( i ); @@ -249,11 +133,14 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int }else{ a = getFreeVarInc( tna, var_count, true ); } + Trace("sygus-db-debug") + << " child " << i << " : " << a << " : " << a.getType() << std::endl; Assert( !a.isNull() ); children.push_back( a ); } Node ret; if( op.getKind()==BUILTIN ){ + Trace("sygus-db-debug") << "Make builtin node..." << std::endl; ret = NodeManager::currentNM()->mkNode( op, children ); }else{ Kind ok = getOperatorKind( op ); @@ -268,33 +155,44 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int return ret; } +Node TermDbSygus::mkGeneric(const Datatype& dt, int c, std::map<int, Node>& pre) +{ + std::map<TypeNode, int> var_count; + return mkGeneric(dt, c, var_count, pre); +} + Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { Assert( n.getType()==tn ); Assert( tn.isDatatype() ); std::map< Node, Node >::iterator it = d_sygus_to_builtin[tn].find( n ); if( it==d_sygus_to_builtin[tn].end() ){ Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n << ", type = " << tn << std::endl; - Node ret; const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); if( n.getKind()==APPLY_CONSTRUCTOR ){ unsigned i = Datatype::indexOf( n.getOperator().toExpr() ); Assert( n.getNumChildren()==dt[i].getNumArgs() ); std::map< TypeNode, int > var_count; std::map< int, Node > pre; - for( unsigned j=0; j<n.getNumChildren(); j++ ){ + for (unsigned j = 0, size = n.getNumChildren(); j < size; j++) + { pre[j] = sygusToBuiltin( n[j], getArgType( dt[i], j ) ); } - ret = mkGeneric( dt, i, var_count, pre ); + Node ret = mkGeneric(dt, i, var_count, pre); Trace("sygus-db-debug") << "SygusToBuiltin : Generic is " << ret << std::endl; d_sygus_to_builtin[tn][n] = ret; - }else{ - Assert( isFreeVar( n ) ); - //map to builtin variable type - int fv_num = getVarNum( n ); - Assert( !dt.getSygusType().isNull() ); - TypeNode vtn = TypeNode::fromType( dt.getSygusType() ); - ret = getFreeVar( vtn, fv_num ); + return ret; } + if (n.hasAttribute(SygusPrintProxyAttribute())) + { + // this variable was associated by an attribute to a builtin node + return n.getAttribute(SygusPrintProxyAttribute()); + } + Assert(isFreeVar(n)); + // map to builtin variable type + int fv_num = getVarNum(n); + Assert(!dt.getSygusType().isNull()); + TypeNode vtn = TypeNode::fromType(dt.getSygusType()); + Node ret = getFreeVar(vtn, fv_num); return ret; }else{ return it->second; @@ -305,102 +203,6 @@ Node TermDbSygus::sygusSubstituted( TypeNode tn, Node n, std::vector< Node >& ar Assert( d_var_list[tn].size()==args.size() ); return n.substitute( d_var_list[tn].begin(), d_var_list[tn].end(), args.begin(), args.end() ); } - -//rcons_depth limits the number of recursive calls when doing accelerated constant reconstruction (currently limited to 1000) -//this is hacky : depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 vs. 999, 1001 -Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) { - std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c ); - if( it==d_builtin_const_to_sygus[tn].end() ){ - Node sc; - d_builtin_const_to_sygus[tn][c] = sc; - Assert( c.isConst() ); - Assert( tn.isDatatype() ); - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); - Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl; - Assert( dt.isSygus() ); - // if we are not interested in reconstructing constants, or the grammar allows them, return a proxy - if( !options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst() ){ - Node k = NodeManager::currentNM()->mkSkolem( "sy", tn, "sygus proxy" ); - SygusPrintProxyAttribute spa; - k.setAttribute(spa,c); - sc = k; - }else{ - int carg = getOpConsNum( tn, c ); - if( carg!=-1 ){ - sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); - }else{ - //identity functions - for( unsigned i=0; i<getNumIdFuncs( tn ); i++ ){ - unsigned ii = getIdFuncIndex( tn, i ); - Assert( dt[ii].getNumArgs()==1 ); - //try to directly reconstruct from single argument - TypeNode tnc = getArgType( dt[ii], 0 ); - Trace("csi-rcons-debug") << "Based on id function " << dt[ii].getSygusOp() << ", try reconstructing " << c << " instead in " << tnc << std::endl; - Node n = builtinToSygusConst( c, tnc, rcons_depth ); - if( !n.isNull() ){ - sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[ii].getConstructor() ), n ); - break; - } - } - if( sc.isNull() ){ - if( rcons_depth<1000 ){ - //accelerated, recursive reconstruction of constants - Kind pk = getPlusKind( TypeNode::fromType( dt.getSygusType() ) ); - if( pk!=UNDEFINED_KIND ){ - int arg = getKindConsNum( tn, pk ); - if( arg!=-1 ){ - Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); - Kind pkm = getPlusKind( TypeNode::fromType( dt.getSygusType() ), true ); - //get types - Assert( dt[arg].getNumArgs()==2 ); - TypeNode tn1 = getArgType( dt[arg], 0 ); - TypeNode tn2 = getArgType( dt[arg], 1 ); - //iterate over all positive constants, largest to smallest - int start = d_const_list[tn1].size()-1; - int end = d_const_list[tn1].size()-d_const_list_pos[tn1]; - for( int i=start; i>=end; --i ){ - Node c1 = d_const_list[tn1][i]; - //only consider if smaller than c, and - if( doCompare( c1, c, ck ) ){ - Node c2 = NodeManager::currentNM()->mkNode( pkm, c, c1 ); - c2 = Rewriter::rewrite( c2 ); - if( c2.isConst() ){ - //reconstruct constant on the other side - Node sc2 = builtinToSygusConst( c2, tn2, rcons_depth+1 ); - if( !sc2.isNull() ){ - Node sc1 = builtinToSygusConst( c1, tn1, rcons_depth ); - Assert( !sc1.isNull() ); - sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[arg].getConstructor() ), sc1, sc2 ); - break; - } - } - } - } - } - } - } - } - } - } - d_builtin_const_to_sygus[tn][c] = sc; - return sc; - }else{ - return it->second; - } -} - -Node TermDbSygus::getNormalized(TypeNode t, Node prog) -{ - std::map< Node, Node >::iterator itn = d_normalized[t].find( prog ); - if( itn==d_normalized[t].end() ){ - Node progr = Rewriter::rewrite( prog ); - Trace("sygus-sym-break2") << "...rewrites to " << progr << std::endl; - d_normalized[t][prog] = progr; - return progr; - }else{ - return itn->second; - } -} unsigned TermDbSygus::getSygusTermSize( Node n ){ if( n.getNumChildren()==0 ){ @@ -419,23 +221,6 @@ unsigned TermDbSygus::getSygusTermSize( Node n ){ } } -unsigned TermDbSygus::getSygusConstructors( Node n, std::vector< Node >& cons ) { - Assert( n.getKind()==APPLY_CONSTRUCTOR ); - Node op = n.getOperator(); - if( std::find( cons.begin(), cons.end(), op )==cons.end() ){ - cons.push_back( op ); - } - if( n.getNumChildren()==0 ){ - return 0; - }else{ - unsigned sum = 0; - for( unsigned i=0; i<n.getNumChildren(); i++ ){ - sum += getSygusConstructors( n[i], cons ); - } - return 1+sum; - } -} - class ReqTrie { public: ReqTrie() : d_req_kind( UNDEFINED_KIND ){} @@ -825,61 +610,6 @@ int TermDbSygus::solveForArgument( TypeNode tn, unsigned cindex, unsigned arg ) return -1; } -struct sortConstants { - TermDbSygus * d_tds; - Kind d_comp_kind; - bool operator() (Node i, Node j) { - if( i!=j ){ - return d_tds->doCompare( i, j, d_comp_kind ); - }else{ - return false; - } - } -}; - -class ReconstructTrie { -public: - std::map< Node, ReconstructTrie > d_children; - std::vector< Node > d_reconstruct; - void add( std::vector< Node >& cons, Node r, unsigned index = 0 ){ - if( index==cons.size() ){ - d_reconstruct.push_back( r ); - }else{ - d_children[cons[index]].add( cons, r, index+1 ); - } - } - Node getReconstruct( std::map< Node, int >& rcons, unsigned depth ) { - if( !d_reconstruct.empty() ){ - for( unsigned i=0; i<d_reconstruct.size(); i++ ){ - Node r = d_reconstruct[i]; - if( rcons[r]==0 ){ - Trace("sygus-static-enum") << "...eliminate constructor " << r << std::endl; - rcons[r] = 1; - return r; - } - } - } - if( depth>0 ){ - for( unsigned w=0; w<2; w++ ){ - for( std::map< Node, ReconstructTrie >::iterator it = d_children.begin(); it != d_children.end(); ++it ){ - Node n = it->first; - if( ( w==0 && rcons[n]!=0 ) || ( w==1 && rcons[n]==0 ) ){ - Node r = it->second.getReconstruct( rcons, depth - w ); - if( !r.isNull() ){ - if( w==1 ){ - Trace("sygus-static-enum") << "...use " << n << " to eliminate constructor " << r << std::endl; - rcons[n] = -1; - } - return r; - } - } - } - } - } - return Node::null(); - } -}; - void TermDbSygus::registerSygusType( TypeNode tn ) { std::map< TypeNode, TypeNode >::iterator itr = d_register.find( tn ); if( itr==d_register.end() ){ @@ -902,11 +632,6 @@ void TermDbSygus::registerSygusType( TypeNode tn ) { }else{ // no arguments to synthesis functions } - //for constant reconstruction - Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); - Node z = d_quantEngine->getTermUtil()->getTypeValue( - TypeNode::fromType(dt.getSygusType()), 0); - d_const_list_pos[tn] = 0; //iterate over constructors for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ Expr sop = dt[i].getSygusOp(); @@ -922,180 +647,17 @@ void TermDbSygus::registerSygusType( TypeNode tn ) { Trace("sygus-db") << ", constant"; d_consts[tn][n] = i; d_arg_const[tn][i] = n; - d_const_list[tn].push_back( n ); - if( ck!=UNDEFINED_KIND && doCompare( z, n, ck ) ){ - d_const_list_pos[tn]++; - } - } - if( dt[i].isSygusIdFunc() ){ - d_id_funcs[tn].push_back( i ); } d_ops[tn][n] = i; d_arg_ops[tn][i] = n; Trace("sygus-db") << std::endl; } - //sort the constant list - if( !d_const_list[tn].empty() ){ - if( ck!=UNDEFINED_KIND ){ - sortConstants sc; - sc.d_comp_kind = ck; - sc.d_tds = this; - std::sort( d_const_list[tn].begin(), d_const_list[tn].end(), sc ); - } - Trace("sygus-db") << "Type has " << d_const_list[tn].size() << " constants..." << std::endl << " "; - for( unsigned i=0; i<d_const_list[tn].size(); i++ ){ - Trace("sygus-db") << d_const_list[tn][i] << " "; - } - Trace("sygus-db") << std::endl; - Trace("sygus-db") << "Of these, " << d_const_list_pos[tn] << " are marked as positive." << std::endl; - } //register connected types for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ for( unsigned j=0; j<dt[i].getNumArgs(); j++ ){ registerSygusType( getArgType( dt[i], j ) ); } } - - //compute the redundant operators - if( options::sygusMinGrammar() ){ - for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - bool nred = true; - Trace("sygus-split-debug") << "Is " << dt[i].getName() << " a redundant operator?" << std::endl; - Kind ck = getConsNumKind( tn, i ); - if( ck!=UNDEFINED_KIND ){ - Kind dk; - if (TermUtil::isAntisymmetric(ck, dk)) - { - int j = getKindConsNum( tn, dk ); - if( j!=-1 ){ - Trace("sygus-split-debug") << "Possible redundant operator : " << ck << " with " << dk << std::endl; - //check for type mismatches - bool success = true; - for( unsigned k=0; k<2; k++ ){ - unsigned ko = k==0 ? 1 : 0; - TypeNode tni = TypeNode::fromType( ((SelectorType)dt[i][k].getType()).getRangeType() ); - TypeNode tnj = TypeNode::fromType( ((SelectorType)dt[j][ko].getType()).getRangeType() ); - if( tni!=tnj ){ - Trace("sygus-split-debug") << "Argument types " << tni << " and " << tnj << " are not equal." << std::endl; - success = false; - break; - } - } - if( success ){ - Trace("sygus-nf") << "* Sygus norm " << dt.getName() << " : do not consider any " << ck << " terms." << std::endl; - nred = false; - } - } - } - } - if( nred ){ - Trace("sygus-split-debug") << "Check " << dt[i].getName() << " based on generic rewriting" << std::endl; - std::map< TypeNode, int > var_count; - std::map< int, Node > pre; - Node g = mkGeneric( dt, i, var_count, pre ); - nred = !computeGenericRedundant( tn, g ); - Trace("sygus-split-debug") << "...done check " << dt[i].getName() << " based on generic rewriting" << std::endl; - } - d_sygus_red_status[tn].push_back( nred ? 0 : 1 ); - } - // run an enumerator for this type - if( options::sygusMinGrammarAgg() ){ - TypeEnumerator te(tn); - unsigned count = 0; - std::map< Node, std::vector< Node > > builtin_to_orig; - Trace("sygus-static-enum") << "Static enumerate " << dt.getName() << "..." << std::endl; - while( !te.isFinished() && count<1000 ){ - Node n = *te; - Node bn = sygusToBuiltin( n, tn ); - Trace("sygus-static-enum") << " " << bn; - Node bnr = Rewriter::rewrite( bn ); - Trace("sygus-static-enum") << " ..." << bnr << std::endl; - builtin_to_orig[bnr].push_back( n ); - ++te; - count++; - } - std::map< Node, bool > reserved; - for( unsigned i=0; i<=2; i++ ){ - Node rsv = - i == 2 ? d_quantEngine->getTermUtil()->getTypeMaxValue(btn) - : d_quantEngine->getTermUtil()->getTypeValue(btn, i); - if( !rsv.isNull() ){ - reserved[ rsv ] = true; - } - } - Trace("sygus-static-enum") << "...make the reconstruct index data structure..." << std::endl; - ReconstructTrie rt; - std::map< Node, int > rcons; - unsigned max_depth = 0; - for( std::map< Node, std::vector< Node > >::iterator itb = builtin_to_orig.begin(); itb != builtin_to_orig.end(); ++itb ){ - if( itb->second.size()>0 ){ - std::map< Node, std::vector< Node > > clist; - Node single_cons; - for( unsigned j=0; j<itb->second.size(); j++ ){ - Node e = itb->second[j]; - getSygusConstructors( e, clist[e] ); - if( clist[e].size()>max_depth ){ - max_depth = clist[e].size(); - } - for( unsigned k=0; k<clist[e].size(); k++ ){ - /* - unsigned cindex = Datatype::indexOf( clist[e][k].toExpr() ); - if( isGenericRedundant( tn, cindex ) ){ - is_gen_redundant = true; - break; - }else{ - */ - rcons[clist[e][k]] = 0; - } - //if( is_gen_redundant ){ - // clist.erase( e ); - //}else{ - if( clist[e].size()==1 ){ - Trace("sygus-static-enum") << "...single constructor term : " << e << ", builtin is " << itb->first << ", cons is " << clist[e][0] << std::endl; - if( single_cons.isNull() ){ - single_cons = clist[e][0]; - }else{ - Trace("sygus-static-enum") << "*** already can eliminate constructor " << clist[e][0] << std::endl; - unsigned cindex = Datatype::indexOf( clist[e][0].toExpr() ); - d_sygus_red_status[tn][cindex] = 1; - } - } - //} - } - // do not eliminate 0, 1, or max - if( !single_cons.isNull() && reserved.find( itb->first )==reserved.end() ){ - Trace("sygus-static-enum") << "...possibly elim " << single_cons << std::endl; - for( std::map< Node, std::vector< Node > >::iterator itc = clist.begin(); itc != clist.end(); ++itc ){ - if( std::find( itc->second.begin(), itc->second.end(), single_cons )==itc->second.end() ){ - rt.add( itc->second, single_cons ); - } - } - } - } - } - Trace("sygus-static-enum") << "...compute reconstructions..." << std::endl; - Node next_rcons; - do { - unsigned depth = 0; - do{ - next_rcons = rt.getReconstruct( rcons, depth ); - depth++; - }while( next_rcons.isNull() && depth<=max_depth ); - // if we found a constructor to eliminate - if( !next_rcons.isNull() ){ - Trace("sygus-static-enum") << "*** eliminate constructor " << next_rcons << std::endl; - unsigned cindex = Datatype::indexOf( next_rcons.toExpr() ); - d_sygus_red_status[tn][cindex] = 2; - } - }while( !next_rcons.isNull() ); - Trace("sygus-static-enum") << "...finished..." << std::endl; - } - }else{ - // assume all are non-redundant - for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - d_sygus_red_status[tn].push_back( 0 ); - } - } } } } @@ -1212,11 +774,10 @@ unsigned TermDbSygus::getMinTermSize( TypeNode tn ) { if( it==d_min_term_size.end() ){ const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); for( unsigned i=0; i<dt.getNumConstructors(); i++ ){ - if( !isGenericRedundant( tn, i ) ){ - if( dt[i].getNumArgs()==0 ){ - d_min_term_size[tn] = 0; - return 0; - } + if (dt[i].getNumArgs() == 0) + { + d_min_term_size[tn] = 0; + return 0; } } // TODO : improve @@ -1229,7 +790,6 @@ unsigned TermDbSygus::getMinTermSize( TypeNode tn ) { unsigned TermDbSygus::getMinConsTermSize( TypeNode tn, unsigned cindex ) { Assert( isRegistered( tn ) ); - Assert( !isGenericRedundant( tn, cindex ) ); std::map< unsigned, unsigned >::iterator it = d_min_cons_term_size[tn].find( cindex ); if( it==d_min_cons_term_size[tn].end() ){ const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); @@ -1372,16 +932,9 @@ bool TermDbSygus::isConstArg( TypeNode tn, int i ) { } } -unsigned TermDbSygus::getNumIdFuncs( TypeNode tn ) { - return d_id_funcs[tn].size(); -} - -unsigned TermDbSygus::getIdFuncIndex( TypeNode tn, unsigned i ) { - return d_id_funcs[tn][i]; -} - -TypeNode TermDbSygus::getArgType( const DatatypeConstructor& c, int i ) { - Assert( i>=0 && i<(int)c.getNumArgs() ); +TypeNode TermDbSygus::getArgType(const DatatypeConstructor& c, unsigned i) +{ + Assert(i < c.getNumArgs()); return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() ); } @@ -1492,12 +1045,6 @@ Kind TermDbSygus::getPlusKind( TypeNode tn, bool is_neg ) { } } -bool TermDbSygus::doCompare( Node a, Node b, Kind k ) { - Node com = NodeManager::currentNM()->mkNode( k, a, b ); - com = Rewriter::rewrite( com ); - return com==d_true; -} - Node TermDbSygus::getSemanticSkolem( TypeNode tn, Node n, bool doMk ){ std::map< Node, Node >::iterator its = d_semantic_skolem[tn].find( n ); if( its!=d_semantic_skolem[tn].end() ){ @@ -1934,40 +1481,6 @@ Node TermDbSygus::evaluateWithUnfolding( Node n ) { return evaluateWithUnfolding( n, visited ); } -bool TermDbSygus::computeGenericRedundant( TypeNode tn, Node g ) { - //everything added to this cache should be mutually exclusive cases - std::map< Node, bool >::iterator it = d_gen_redundant[tn].find( g ); - if( it==d_gen_redundant[tn].end() ){ - Trace("sygus-gnf") << "Register generic for " << tn << " : " << g << std::endl; - Node gr = getNormalized(tn, g); - Trace("sygus-gnf-debug") << "Generic " << g << " rewrites to " << gr << std::endl; - std::map< Node, Node >::iterator itg = d_gen_terms[tn].find( gr ); - bool red = true; - if( itg==d_gen_terms[tn].end() ){ - red = false; - d_gen_terms[tn][gr] = g; - Trace("sygus-gnf-debug") << "...not redundant." << std::endl; - Trace("sygus-nf-reg") << "*** Sygus (generic) normal form : normal form of " << g << " is " << gr << std::endl; - }else{ - Trace("sygus-gnf-debug") << "...redundant." << std::endl; - Trace("sygus-nf") << "* Sygus normal form : simplify since " << g << " and " << itg->second << " both rewrite to " << gr << std::endl; - } - d_gen_redundant[tn][g] = red; - return red; - }else{ - return it->second; - } -} - -bool TermDbSygus::isGenericRedundant( TypeNode tn, unsigned i ) { - Assert( i<d_sygus_red_status[tn].size() ); - if( options::sygusMinGrammarAgg() ){ - return d_sygus_red_status[tn][i]!=0; - }else{ - return d_sygus_red_status[tn][i]==1; - } -} - }/* CVC4::theory::quantifiers namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/quantifiers/term_database_sygus.h b/src/theory/quantifiers/term_database_sygus.h index 01e518eb1..b9af26b6e 100644 --- a/src/theory/quantifiers/term_database_sygus.h +++ b/src/theory/quantifiers/term_database_sygus.h @@ -68,6 +68,56 @@ class TermDbSygus { SygusExplain* getExplain() { return d_syexp.get(); } /** get the extended rewrite utility */ ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } + //-----------------------------conversion from sygus to builtin + /** get free variable + * + * This class caches a list of free variables for each type, which are + * used, for instance, for constructing canonical forms of terms with free + * variables. This function returns the i^th free variable for type tn. + * If useSygusType is true, then this function returns a variable of the + * analog type for sygus type tn (see d_fv for details). + */ + TNode getFreeVar(TypeNode tn, int i, bool useSygusType = false); + /** get free variable and increment + * + * This function returns the next free variable for type tn, and increments + * the counter in var_count for that type. + */ + TNode getFreeVarInc(TypeNode tn, + std::map<TypeNode, int>& var_count, + bool useSygusType = false); + /** returns true if n is a cached free variable (in d_fv). */ + bool isFreeVar(Node n) { return d_fv_stype.find(n) != d_fv_stype.end(); } + /** returns the index of n in the free variable cache (d_fv). */ + int getVarNum(Node n) { return d_fv_num[n]; } + /** returns true if n has a cached free variable (in d_fv). */ + bool hasFreeVar(Node n); + /** make generic + * + * This function returns a builtin term f( t1, ..., tn ) where f is the + * builtin op of the sygus datatype constructor specified by arguments + * dt and c. The mapping pre maps child indices to the term for that index + * in the term we are constructing. That is, for each i = 1,...,n: + * If i is in the domain of pre, then ti = pre[i]. + * If i is not in the domain of pre, then ti = d_fv[1][ var_count[Ti ] ], + * and var_count[Ti] is incremented. + */ + Node mkGeneric(const Datatype& dt, + unsigned c, + std::map<TypeNode, int>& var_count, + std::map<int, Node>& pre); + /** same as above, but with empty var_count */ + Node mkGeneric(const Datatype& dt, int c, std::map<int, Node>& pre); + /** sygus to builtin + * + * Given a sygus datatype term n of type tn, this function returns its analog, + * that is, the term that n encodes. + */ + Node sygusToBuiltin(Node n, TypeNode tn); + /** same as above, but without tn */ + Node sygusToBuiltin(Node n) { return sygusToBuiltin(n, n.getType()); } + //-----------------------------end conversion from sygus to builtin + private: /** reference to the quantifiers engine */ QuantifiersEngine* d_quantEngine; @@ -88,30 +138,31 @@ class TermDbSygus { */ std::map<Node, Node> d_enum_to_active_guard; + //-----------------------------conversion from sygus to builtin + /** cache for sygusToBuiltin */ + std::map<TypeNode, std::map<Node, Node> > d_sygus_to_builtin; + /** a cache of fresh variables for each type + * + * We store two versions of this list: + * index 0: mapping from builtin types to fresh variables of that type, + * index 1: mapping from sygus types to fresh varaibles of the type they + * encode. + */ + std::map<TypeNode, std::vector<Node> > d_fv[2]; + /** Maps free variables to the domain type they are associated with in d_fv */ + std::map<Node, TypeNode> d_fv_stype; + /** Maps free variables to their index in d_fv. */ + std::map<Node, int> d_fv_num; + /** recursive helper for hasFreeVar, visited stores nodes we have visited. */ + bool hasFreeVar(Node n, std::map<Node, bool>& visited); + //-----------------------------end conversion from sygus to builtin + // TODO :issue #1235 : below here needs refactor public: Node d_true; Node d_false; - private: - std::map< TypeNode, std::vector< Node > > d_fv[2]; - std::map< Node, TypeNode > d_fv_stype; - std::map< Node, int > d_fv_num; - bool hasFreeVar( Node n, std::map< Node, bool >& visited ); -public: - TNode getFreeVar( TypeNode tn, int i, bool useSygusType = false ); - TNode getFreeVarInc( TypeNode tn, std::map< TypeNode, int >& var_count, bool useSygusType = false ); - bool isFreeVar( Node n ) { return d_fv_stype.find( n )!=d_fv_stype.end(); } - int getVarNum( Node n ) { return d_fv_num[n]; } - bool hasFreeVar( Node n ); -private: - std::map< TypeNode, std::map< int, Node > > d_generic_base; - std::map< TypeNode, std::vector< Node > > d_generic_templ; - bool getMatch( Node p, Node n, std::map< int, Node >& s ); - bool getMatch2( Node p, Node n, std::map< int, Node >& s, std::vector< int >& new_s ); -public: - bool getMatch( Node n, TypeNode st, int& index_found, std::vector< Node >& args, int index_exc = -1, int index_start = 0 ); private: void computeMinTypeDepthInternal( TypeNode root_tn, TypeNode tn, unsigned type_depth ); bool involvesDivByZero( Node n, std::map< Node, bool >& visited ); @@ -126,15 +177,7 @@ private: std::map<TypeNode, std::map<Node, int> > d_consts; std::map<TypeNode, std::map<Node, int> > d_ops; std::map<TypeNode, std::map<int, Node> > d_arg_ops; - std::map<TypeNode, std::vector<int> > d_id_funcs; - std::map<TypeNode, std::vector<Node> > - d_const_list; // sorted list of constants for type - std::map<TypeNode, unsigned> d_const_list_pos; std::map<TypeNode, std::map<Node, Node> > d_semantic_skolem; - // normalized map - std::map<TypeNode, std::map<Node, Node> > d_normalized; - std::map<TypeNode, std::map<Node, Node> > d_sygus_to_builtin; - std::map<TypeNode, std::map<Node, Node> > d_builtin_const_to_sygus; // grammar information // root -> type -> _ std::map<TypeNode, std::map<TypeNode, unsigned> > d_min_type_depth; @@ -169,27 +212,18 @@ private: Kind getConsNumKind( TypeNode tn, int i ); bool isKindArg( TypeNode tn, int i ); bool isConstArg( TypeNode tn, int i ); - unsigned getNumIdFuncs( TypeNode tn ); - unsigned getIdFuncIndex( TypeNode tn, unsigned i ); /** get arg type */ - TypeNode getArgType( const DatatypeConstructor& c, int i ); + TypeNode getArgType(const DatatypeConstructor& c, unsigned i); /** get first occurrence */ int getFirstArgOccurrence( const DatatypeConstructor& c, TypeNode tn ); /** is type match */ bool isTypeMatch( const DatatypeConstructor& c1, const DatatypeConstructor& c2 ); TypeNode getSygusTypeForVar( Node v ); - Node getGenericBase( TypeNode tn, const Datatype& dt, int c ); - Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ); - Node sygusToBuiltin( Node n, TypeNode tn ); - Node sygusToBuiltin( Node n ) { return sygusToBuiltin( n, n.getType() ); } Node sygusSubstituted( TypeNode tn, Node n, std::vector< Node >& args ); - Node builtinToSygusConst( Node c, TypeNode tn, int rcons_depth = 0 ); Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ); Node getNormalized(TypeNode t, Node prog); unsigned getSygusTermSize( Node n ); - // returns size - unsigned getSygusConstructors( Node n, std::vector< Node >& cons ); /** given a term, construct an equivalent smaller one that respects syntax */ Node minimizeBuiltinTerm( Node n ); /** given a term, expand it into more basic components */ @@ -197,7 +231,6 @@ private: /** get comparison kind */ Kind getComparisonKind( TypeNode tn ); Kind getPlusKind( TypeNode tn, bool is_neg = false ); - bool doCompare( Node a, Node b, Kind k ); // get semantic skolem for n (a sygus term whose builtin version is n) Node getSemanticSkolem( TypeNode tn, Node n, bool doMk = true ); /** involves div-by-zero */ @@ -244,18 +277,6 @@ public: Node evaluateWithUnfolding( Node n, std::unordered_map<Node, Node, NodeHashFunction>& visited); Node evaluateWithUnfolding( Node n ); -//for calculating redundant operators -private: - //whether each constructor is redundant - // 0 : not redundant, 1 : redundant, 2 : partially redundant - std::map< TypeNode, std::vector< int > > d_sygus_red_status; - // type to (rewritten) to original - std::map< TypeNode, std::map< Node, Node > > d_gen_terms; - std::map< TypeNode, std::map< Node, bool > > d_gen_redundant; - //compute generic redundant - bool computeGenericRedundant( TypeNode tn, Node g ); -public: - bool isGenericRedundant( TypeNode tn, unsigned i ); }; }/* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/quantifiers/theory_quantifiers.h b/src/theory/quantifiers/theory_quantifiers.h index 295a39464..4f87f6aae 100644 --- a/src/theory/quantifiers/theory_quantifiers.h +++ b/src/theory/quantifiers/theory_quantifiers.h @@ -33,37 +33,47 @@ namespace theory { namespace quantifiers { class TheoryQuantifiers : public Theory { -private: - typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap; - /** number of instantiations */ - int d_numInstantiations; - int d_baseDecLevel; -private: - void computeCareGraph(); - -public: + public: TheoryQuantifiers(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheoryQuantifiers(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void addSharedTerm(TNode t); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + void addSharedTerm(TNode t) override; void notifyEq(TNode lhs, TNode rhs); - void preRegisterTerm(TNode n); - void presolve(); - void ppNotifyAssertions(const std::vector<Node>& assertions); - void check(Effort e); - Node getNextDecisionRequest( unsigned& priority ); + void preRegisterTerm(TNode n) override; + void presolve() override; + void ppNotifyAssertions(const std::vector<Node>& assertions) override; + void check(Effort e) override; + Node getNextDecisionRequest(unsigned& priority) override; Node getValue(TNode n); bool collectModelInfo(TheoryModel* m) override; - void shutdown() { } - std::string identify() const { return std::string("TheoryQuantifiers"); } - void setUserAttribute(const std::string& attr, Node n, std::vector<Node> node_values, std::string str_value); - bool ppDontRewriteSubterm(TNode atom) { return atom.getKind() == kind::FORALL || atom.getKind() == kind::EXISTS; } -private: + void shutdown() override {} + std::string identify() const override + { + return std::string("TheoryQuantifiers"); + } + void setUserAttribute(const std::string& attr, + Node n, + std::vector<Node> node_values, + std::string str_value) override; + bool ppDontRewriteSubterm(TNode atom) override + { + return atom.getKind() == kind::FORALL || atom.getKind() == kind::EXISTS; + } + + private: void assertUniversal( Node n ); void assertExistential( Node n ); + void computeCareGraph() override; + + using BoolMap = context::CDHashMap<Node, bool, NodeHashFunction>; + + /** number of instantiations */ + int d_numInstantiations; + int d_baseDecLevel; + };/* class TheoryQuantifiers */ }/* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index 34dde7fc8..a0efe80f9 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -1195,6 +1195,11 @@ Node QuantifiersEngine::getInternalRepresentative( Node a, Node q, int index ){ return ret; } +void QuantifiersEngine::getSynthSolutions(std::map<Node, Node>& sol_map) +{ + d_ceg_inst->getSynthSolutions(sol_map); +} + void QuantifiersEngine::debugPrintEqualityEngine( const char * c ) { eq::EqualityEngine* ee = getActiveEqualityEngine(); eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee ); diff --git a/src/theory/quantifiers_engine.h b/src/theory/quantifiers_engine.h index 3cd4e8ef9..3716fc497 100644 --- a/src/theory/quantifiers_engine.h +++ b/src/theory/quantifiers_engine.h @@ -399,6 +399,18 @@ public: void getExplanationForInstLemmas(const std::vector<Node>& lems, std::map<Node, Node>& quant, std::map<Node, std::vector<Node> >& tvec); + + /** get synth solutions + * + * This function adds entries to sol_map that map functions-to-synthesize with + * their solutions, for all active conjectures. This should be called + * immediately after the solver answers unsat for sygus input. + * + * For details on what is added to sol_map, see + * CegConjecture::getSynthSolutions. + */ + void getSynthSolutions(std::map<Node, Node>& sol_map); + //----------end user interface for instantiations /** statistics class */ diff --git a/src/theory/rep_set.cpp b/src/theory/rep_set.cpp index bff5e36cd..04c39c897 100644 --- a/src/theory/rep_set.cpp +++ b/src/theory/rep_set.cpp @@ -43,12 +43,8 @@ bool RepSet::hasRep(TypeNode tn, Node n) const unsigned RepSet::getNumRepresentatives(TypeNode tn) const { - std::map< TypeNode, std::vector< Node > >::const_iterator it = d_type_reps.find( tn ); - if( it!=d_type_reps.end() ){ - return it->second.size(); - }else{ - return 0; - } + const std::vector<Node>* reps = getTypeRepsOrNull(tn); + return (reps != nullptr) ? reps->size() : 0; } Node RepSet::getRepresentative(TypeNode tn, unsigned i) const @@ -60,14 +56,18 @@ Node RepSet::getRepresentative(TypeNode tn, unsigned i) const return it->second[i]; } -void RepSet::getRepresentatives(TypeNode tn, std::vector<Node>& reps) const +const std::vector<Node>* RepSet::getTypeRepsOrNull(TypeNode tn) const { - std::map<TypeNode, std::vector<Node> >::const_iterator it = - d_type_reps.find(tn); - Assert(it != d_type_reps.end()); - reps.insert(reps.end(), it->second.begin(), it->second.end()); + auto it = d_type_reps.find(tn); + if (it == d_type_reps.end()) + { + return nullptr; + } + return &(it->second); } +namespace { + bool containsStoreAll(Node n, std::unordered_set<Node, NodeHashFunction>& cache) { if( std::find( cache.begin(), cache.end(), n )==cache.end() ){ @@ -85,6 +85,8 @@ bool containsStoreAll(Node n, std::unordered_set<Node, NodeHashFunction>& cache) return false; } +} // namespace + void RepSet::add( TypeNode tn, Node n ){ //for now, do not add array constants FIXME if( tn.isArray() ){ @@ -264,7 +266,12 @@ bool RepSetIterator::initialize() if (d_rs->hasType(tn)) { d_enum_type.push_back( ENUM_DEFAULT ); - d_rs->getRepresentatives(tn, d_domain_elements[v]); + if (const auto* type_reps = d_rs->getTypeRepsOrNull(tn)) + { + std::vector<Node>& v_domain_elements = d_domain_elements[v]; + v_domain_elements.insert(v_domain_elements.end(), + type_reps->begin(), type_reps->end()); + } }else{ Assert( d_incomplete ); return false; diff --git a/src/theory/rep_set.h b/src/theory/rep_set.h index 5b75fa943..a75918b5a 100644 --- a/src/theory/rep_set.h +++ b/src/theory/rep_set.h @@ -57,9 +57,9 @@ class QuantifiersEngine; * finite types. */ class RepSet { -public: + public: RepSet(){} - ~RepSet(){} + /** map from types to the list of representatives * TODO : as part of #1199, encapsulate this */ @@ -67,15 +67,19 @@ public: /** clear the set */ void clear(); /** does this set have representatives of type tn? */ - bool hasType( TypeNode tn ) const { return d_type_reps.find( tn )!=d_type_reps.end(); } + bool hasType(TypeNode tn) const { return d_type_reps.count(tn) > 0; } /** does this set have representative n of type tn? */ bool hasRep(TypeNode tn, Node n) const; /** get the number of representatives for type */ unsigned getNumRepresentatives(TypeNode tn) const; /** get representative at index */ Node getRepresentative(TypeNode tn, unsigned i) const; - /** get representatives of type tn, appends them to reps */ - void getRepresentatives(TypeNode tn, std::vector<Node>& reps) const; + /** + * Returns the representatives of a type for a `type_node` if one exists. + * Otherwise, returns nullptr. + */ + const std::vector<Node>* getTypeRepsOrNull(TypeNode type_node) const; + /** add representative n for type tn, where n has type tn */ void add( TypeNode tn, Node n ); /** returns index in d_type_reps for node n */ diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index fe58f658d..d13003581 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -251,7 +251,6 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { } Unreachable(); - return Node::null(); }/* Rewriter::rewriteTo() */ void Rewriter::clearCaches() { diff --git a/src/theory/sep/theory_sep.h b/src/theory/sep/theory_sep.h index 65f076631..7468d2778 100644 --- a/src/theory/sep/theory_sep.h +++ b/src/theory/sep/theory_sep.h @@ -43,7 +43,7 @@ class TheorySep : public Theory { // MISC ///////////////////////////////////////////////////////////////////////////// - private: + private: /** all lemmas sent */ NodeSet d_lemmas_produced_c; @@ -62,119 +62,137 @@ class TheorySep : public Theory { std::map< int, std::map< Node, std::vector< Node > > >& references, std::map< int, std::map< Node, bool > >& references_strict, bool pol, bool hasPol, bool underSpatial ); - public: - + public: TheorySep(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheorySep(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - std::string identify() const { return std::string("TheorySep"); } + std::string identify() const override { return std::string("TheorySep"); } ///////////////////////////////////////////////////////////////////////////// // PREPROCESSING ///////////////////////////////////////////////////////////////////////////// - public: - - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); - Node ppRewrite(TNode atom); + public: + PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; + Node ppRewrite(TNode atom) override; - void ppNotifyAssertions(const std::vector<Node>& assertions); + void ppNotifyAssertions(const std::vector<Node>& assertions) override; ///////////////////////////////////////////////////////////////////////////// // T-PROPAGATION / REGISTRATION ///////////////////////////////////////////////////////////////////////////// - private: - + private: /** Should be called to propagate the literal. */ bool propagate(TNode literal); /** Explain why this literal is true by adding assumptions */ void explain(TNode literal, std::vector<TNode>& assumptions); - public: + public: + void propagate(Effort e) override; + Node explain(TNode n) override; - void propagate(Effort e); - Node explain(TNode n); - - public: - - void addSharedTerm(TNode t); - EqualityStatus getEqualityStatus(TNode a, TNode b); - void computeCareGraph(); + public: + void addSharedTerm(TNode t) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; + void computeCareGraph() override; ///////////////////////////////////////////////////////////////////////////// // MODEL GENERATION ///////////////////////////////////////////////////////////////////////////// - public: - bool collectModelInfo(TheoryModel* m) override; - void postProcessModel(TheoryModel* m); - - ///////////////////////////////////////////////////////////////////////////// - // NOTIFICATIONS - ///////////////////////////////////////////////////////////////////////////// + public: + bool collectModelInfo(TheoryModel* m) override; + void postProcessModel(TheoryModel* m) override; - private: - public: + ///////////////////////////////////////////////////////////////////////////// + // NOTIFICATIONS + ///////////////////////////////////////////////////////////////////////////// - Node getNextDecisionRequest( unsigned& priority ); + public: + Node getNextDecisionRequest(unsigned& priority) override; - void presolve(); - void shutdown() { } + void presolve() override; + void shutdown() override {} ///////////////////////////////////////////////////////////////////////////// // MAIN SOLVER ///////////////////////////////////////////////////////////////////////////// - public: - - void check(Effort e); + public: + void check(Effort e) override; - bool needsCheckLastEffort(); + bool needsCheckLastEffort() override; - private: - - // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module - class NotifyClass : public eq::EqualityEngineNotify { + private: + // NotifyClass: template helper class for d_equalityEngine - handles + // call-back from congruence closure module + class NotifyClass : public eq::EqualityEngineNotify + { TheorySep& d_sep; - public: - NotifyClass(TheorySep& sep): d_sep(sep) {} - bool eqNotifyTriggerEquality(TNode equality, bool value) { - Debug("sep::propagate") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; + public: + NotifyClass(TheorySep& sep) : d_sep(sep) {} + + bool eqNotifyTriggerEquality(TNode equality, bool value) + { + Debug("sep::propagate") + << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " + << (value ? "true" : "false") << ")" << std::endl; // Just forward to sep - if (value) { + if (value) + { return d_sep.propagate(equality); - } else { + } + else + { return d_sep.propagate(equality.notNode()); } } - bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + bool eqNotifyTriggerPredicate(TNode predicate, bool value) + { Unreachable(); } - bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { - Debug("sep::propagate") << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl; - if (value) { + bool eqNotifyTriggerTermEquality(TheoryId tag, + TNode t1, + TNode t2, + bool value) + { + Debug("sep::propagate") + << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 + << ", " << (value ? "true" : "false") << ")" << std::endl; + if (value) + { // Propagate equality between shared terms return d_sep.propagate(t1.eqNode(t2)); - } else { + } + else + { return d_sep.propagate(t1.eqNode(t2).notNode()); } return true; } - void eqNotifyConstantTermMerge(TNode t1, TNode t2) { - Debug("sep::propagate") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + void eqNotifyConstantTermMerge(TNode t1, TNode t2) + { + Debug("sep::propagate") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 + << ", " << t2 << ")" << std::endl; d_sep.conflict(t1, t2); } - void eqNotifyNewClass(TNode t) { } - void eqNotifyPreMerge(TNode t1, TNode t2) { d_sep.eqNotifyPreMerge( t1, t2 ); } - void eqNotifyPostMerge(TNode t1, TNode t2) { d_sep.eqNotifyPostMerge( t1, t2 ); } - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) { } + void eqNotifyNewClass(TNode t) {} + void eqNotifyPreMerge(TNode t1, TNode t2) + { + d_sep.eqNotifyPreMerge(t1, t2); + } + void eqNotifyPostMerge(TNode t1, TNode t2) + { + d_sep.eqNotifyPostMerge(t1, t2); + } + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {} }; /** The notify class for d_equalityEngine */ @@ -289,7 +307,8 @@ class TheorySep : public Theory { void setInactiveAssertionRec( Node fact, std::map< Node, std::vector< Node > >& lbl_to_assertions, std::map< Node, bool >& assert_active ); Node mkUnion( TypeNode tn, std::vector< Node >& locs ); -private: + + private: Node getRepresentative( Node t ); bool hasTerm( Node a ); bool areEqual( Node a, Node b ); @@ -299,10 +318,9 @@ private: void sendLemma( std::vector< Node >& ant, Node conc, const char * c, bool infer = false ); void doPendingFacts(); -public: - eq::EqualityEngine* getEqualityEngine() { - return &d_equalityEngine; - } + + public: + eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } void initializeBounds(); };/* class TheorySep */ diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index 1f0fbdd1f..a246903a1 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -46,36 +46,36 @@ public: ~TheorySets(); - void addSharedTerm(TNode); + void addSharedTerm(TNode) override; - void check(Effort); - - bool needsCheckLastEffort(); + void check(Effort) override; + + bool needsCheckLastEffort() override; bool collectModelInfo(TheoryModel* m) override; - void computeCareGraph(); + void computeCareGraph() override; + + Node explain(TNode) override; - Node explain(TNode); + EqualityStatus getEqualityStatus(TNode a, TNode b) override; - EqualityStatus getEqualityStatus(TNode a, TNode b); + Node getModelValue(TNode) override; - Node getModelValue(TNode); + std::string identify() const override { return "THEORY_SETS"; } - std::string identify() const { return "THEORY_SETS"; } + void preRegisterTerm(TNode node) override; - void preRegisterTerm(TNode node); + Node expandDefinition(LogicRequest& logicRequest, Node n) override; - Node expandDefinition(LogicRequest &logicRequest, Node n); + PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; - PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions); + void presolve() override; - void presolve(); + void propagate(Effort) override; - void propagate(Effort); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq); - bool isEntailed( Node n, bool pol ); };/* class TheorySets */ diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index f07057444..e07cc6b5e 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -54,25 +54,28 @@ class TheoryStrings : public Theory { typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap; typedef context::CDHashSet<Node, NodeHashFunction> NodeSet; -public: + public: TheoryStrings(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo); ~TheoryStrings(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - std::string identify() const { return std::string("TheoryStrings"); } + std::string identify() const override { return std::string("TheoryStrings"); } -public: - void propagate(Effort e); + public: + void propagate(Effort e) override; bool propagate(TNode literal); void explain( TNode literal, std::vector<TNode>& assumptions ); - Node explain( TNode literal ); - eq::EqualityEngine * getEqualityEngine() { return &d_equalityEngine; } - bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ); - int getReduction( int effort, Node n, Node& nr ); - + Node explain(TNode literal) override; + eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } + bool getCurrentSubstitution(int effort, + std::vector<Node>& vars, + std::vector<Node>& subs, + std::map<Node, std::vector<Node> >& exp) override; + int getReduction(int effort, Node n, Node& nr) override; + // NotifyClass for equality engine class NotifyClass : public eq::EqualityEngineNotify { TheoryStrings& d_str; @@ -213,24 +216,24 @@ private: ///////////////////////////////////////////////////////////////////////////// // MODEL GENERATION ///////////////////////////////////////////////////////////////////////////// -public: - bool collectModelInfo(TheoryModel* m) override; + public: + bool collectModelInfo(TheoryModel* m) override; - ///////////////////////////////////////////////////////////////////////////// - // NOTIFICATIONS - ///////////////////////////////////////////////////////////////////////////// -public: - void presolve(); - void shutdown() { } + ///////////////////////////////////////////////////////////////////////////// + // NOTIFICATIONS + ///////////////////////////////////////////////////////////////////////////// + public: + void presolve() override; + void shutdown() override {} ///////////////////////////////////////////////////////////////////////////// // MAIN SOLVER ///////////////////////////////////////////////////////////////////////////// -private: - void addSharedTerm(TNode n); - EqualityStatus getEqualityStatus(TNode a, TNode b); + private: + void addSharedTerm(TNode n) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; -private: + private: class EqcInfo { public: EqcInfo( context::Context* c ); @@ -367,17 +370,18 @@ private: //cardinality check void checkCardinality(); -private: + private: void addCarePairs( quantifiers::TermArgTrie * t1, quantifiers::TermArgTrie * t2, unsigned arity, unsigned depth ); -public: + + public: /** preregister term */ - void preRegisterTerm(TNode n); + void preRegisterTerm(TNode n) override; /** Expand definition */ - Node expandDefinition(LogicRequest &logicRequest, Node n); + Node expandDefinition(LogicRequest& logicRequest, Node n) override; /** Check at effort e */ - void check(Effort e); + void check(Effort e) override; /** needs check last effort */ - bool needsCheckLastEffort(); + bool needsCheckLastEffort() override; /** Conflict when merging two constants */ void conflict(TNode a, TNode b); /** called when a new equivalence class is created */ @@ -389,39 +393,48 @@ public: /** called when two equivalence classes are made disequal */ void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); /** get preprocess */ - StringsPreprocess * getPreprocess() { return &d_preproc; } -protected: + StringsPreprocess* getPreprocess() { return &d_preproc; } + + protected: /** compute care graph */ - void computeCareGraph(); + void computeCareGraph() override; - //do pending merges + // do pending merges void assertPendingFact(Node atom, bool polarity, Node exp); void doPendingFacts(); void doPendingLemmas(); bool hasProcessed(); - void addToExplanation( Node a, Node b, std::vector< Node >& exp ); - void addToExplanation( Node lit, std::vector< Node >& exp ); - - //register term - void registerTerm( Node n, int effort ); - //send lemma - void sendInference( std::vector< Node >& exp, std::vector< Node >& exp_n, Node eq, const char * c, bool asLemma = false ); - void sendInference( std::vector< Node >& exp, Node eq, const char * c, bool asLemma = false ); - void sendLemma( Node ant, Node conc, const char * c ); - void sendInfer( Node eq_exp, Node eq, const char * c ); + void addToExplanation(Node a, Node b, std::vector<Node>& exp); + void addToExplanation(Node lit, std::vector<Node>& exp); + + // register term + void registerTerm(Node n, int effort); + // send lemma + void sendInference(std::vector<Node>& exp, + std::vector<Node>& exp_n, + Node eq, + const char* c, + bool asLemma = false); + void sendInference(std::vector<Node>& exp, + Node eq, + const char* c, + bool asLemma = false); + void sendLemma(Node ant, Node conc, const char* c); + void sendInfer(Node eq_exp, Node eq, const char* c); bool sendSplit(Node a, Node b, const char* c, bool preq = true); - void sendLengthLemma( Node n ); + void sendLengthLemma(Node n); /** mkConcat **/ - inline Node mkConcat( Node n1, Node n2 ); - inline Node mkConcat( Node n1, Node n2, Node n3 ); - inline Node mkConcat( const std::vector< Node >& c ); - inline Node mkLength( Node n ); - //mkSkolem - enum { + inline Node mkConcat(Node n1, Node n2); + inline Node mkConcat(Node n1, Node n2, Node n3); + inline Node mkConcat(const std::vector<Node>& c); + inline Node mkLength(Node n); + // mkSkolem + enum + { sk_id_c_spt, sk_id_vc_spt, sk_id_vc_bin_spt, - sk_id_v_spt, + sk_id_v_spt, sk_id_c_spt_rev, sk_id_vc_spt_rev, sk_id_vc_bin_spt_rev, @@ -434,30 +447,36 @@ protected: sk_id_deq_y, sk_id_deq_z, }; - std::map< Node, std::map< Node, std::map< int, Node > > > d_skolem_cache; - Node mkSkolemCached( Node a, Node b, int id, const char * c, int isLenSplit = 0 ); - inline Node mkSkolemS(const char * c, int isLenSplit = 0); - void registerNonEmptySkolem( Node sk ); - //inline Node mkSkolemI(const char * c); + std::map<Node, std::map<Node, std::map<int, Node> > > d_skolem_cache; + Node mkSkolemCached( + Node a, Node b, int id, const char* c, int isLenSplit = 0); + inline Node mkSkolemS(const char* c, int isLenSplit = 0); + void registerNonEmptySkolem(Node sk); + // inline Node mkSkolemI(const char * c); /** mkExplain **/ - Node mkExplain( std::vector< Node >& a ); - Node mkExplain( std::vector< Node >& a, std::vector< Node >& an ); + Node mkExplain(std::vector<Node>& a); + Node mkExplain(std::vector<Node>& a, std::vector<Node>& an); /** mkAnd **/ - Node mkAnd( std::vector< Node >& a ); + Node mkAnd(std::vector<Node>& a); /** get concat vector */ - void getConcatVec( Node n, std::vector< Node >& c ); + void getConcatVec(Node n, std::vector<Node>& c); - //get equivalence classes - void getEquivalenceClasses( std::vector< Node >& eqcs ); + // get equivalence classes + void getEquivalenceClasses(std::vector<Node>& eqcs); - //separate into collections with equal length - void separateByLength( std::vector< Node >& n, std::vector< std::vector< Node > >& col, std::vector< Node >& lts ); - void printConcat( std::vector< Node >& n, const char * c ); + // separate into collections with equal length + void separateByLength(std::vector<Node>& n, + std::vector<std::vector<Node> >& col, + std::vector<Node>& lts); + void printConcat(std::vector<Node>& n, const char* c); - void inferSubstitutionProxyVars( Node n, std::vector< Node >& vars, std::vector< Node >& subs, std::vector< Node >& unproc ); + void inferSubstitutionProxyVars(Node n, + std::vector<Node>& vars, + std::vector<Node>& subs, + std::vector<Node>& unproc); // Symbolic Regular Expression -private: + private: // regular expression memberships NodeList d_regexp_memberships; NodeSet d_regexp_ucached; @@ -492,18 +511,20 @@ private: // Finite Model Finding -private: + private: NodeSet d_input_vars; context::CDO< Node > d_input_var_lsum; context::CDHashMap< int, Node > d_cardinality_lits; context::CDO< int > d_curr_cardinality; -public: + + public: //for finite model finding - Node getNextDecisionRequest( unsigned& priority ); - //ppRewrite - Node ppRewrite(TNode atom); -public: -/** statistics class */ + Node getNextDecisionRequest(unsigned& priority) override; + // ppRewrite + Node ppRewrite(TNode atom) override; + + public: + /** statistics class */ class Statistics { public: IntStat d_splits; diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index a478667e9..f79922a53 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -1926,40 +1926,59 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) { if( node[1]==node[2] ){ return returnRewrite(node, node[0], "rpl-id"); } - else if (node[0] == node[1]) + else if (node[0].isConst() && node[0].getConst<String>().isEmptyString()) { - return returnRewrite(node, node[2], "rpl-replace"); + return returnRewrite(node, node[0], "rpl-empty"); } - else if (node[1].isConst()) + else if (node[1].isConst() && node[1].getConst<String>().isEmptyString()) { - if (node[1].getConst<String>().isEmptyString()) - { - return returnRewrite(node, node[0], "rpl-empty"); - } - else if (node[0].isConst()) + return returnRewrite(node, node[0], "rpl-rpl-empty"); + } + + std::vector<Node> children0; + getConcat(node[0], children0); + + if (node[1].isConst() && children0[0].isConst()) + { + CVC4::String s = children0[0].getConst<String>(); + CVC4::String t = node[1].getConst<String>(); + std::size_t p = s.find(t); + if (p == std::string::npos) { - CVC4::String s = node[0].getConst<String>(); - CVC4::String t = node[1].getConst<String>(); - std::size_t p = s.find(t); - if (p == std::string::npos) + if (children0.size() == 1) { return returnRewrite(node, node[0], "rpl-const-nfind"); } - else + // if no overlap, we can pull the first child + if (s.overlap(t) == 0) { - CVC4::String s1 = s.substr(0, (int)p); - CVC4::String s3 = s.substr((int)p + (int)t.size()); - Node ns1 = NodeManager::currentNM()->mkConst(::CVC4::String(s1)); - Node ns3 = NodeManager::currentNM()->mkConst(::CVC4::String(s3)); + std::vector<Node> spl(children0.begin() + 1, children0.end()); Node ret = NodeManager::currentNM()->mkNode( - kind::STRING_CONCAT, ns1, node[2], ns3); - return returnRewrite(node, ret, "rpl-const-find"); + kind::STRING_CONCAT, + children0[0], + NodeManager::currentNM()->mkNode(kind::STRING_STRREPL, + mkConcat(kind::STRING_CONCAT, spl), + node[1], + node[2])); + return returnRewrite(node, ret, "rpl-prefix-nfind"); } } + else + { + CVC4::String s1 = s.substr(0, (int)p); + CVC4::String s3 = s.substr((int)p + (int)t.size()); + Node ns1 = NodeManager::currentNM()->mkConst(::CVC4::String(s1)); + Node ns3 = NodeManager::currentNM()->mkConst(::CVC4::String(s3)); + std::vector<Node> children; + children.push_back(ns1); + children.push_back(node[2]); + children.push_back(ns3); + children.insert(children.end(), children0.begin() + 1, children0.end()); + Node ret = mkConcat(kind::STRING_CONCAT, children); + return returnRewrite(node, ret, "rpl-const-find"); + } } - std::vector<Node> children0; - getConcat(node[0], children0); std::vector<Node> children1; getConcat(node[1], children1); @@ -1971,13 +1990,26 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) { { if (cmp_conr.getConst<bool>()) { + // currently by the semantics of replace, if the second argument is + // empty, then we return the first argument. + // hence, we test whether the second argument must be non-empty here. + // if it definitely non-empty, we can use rules that successfully replace + // node[1]->node[2] among those below. + Node l1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]); + Node zero = NodeManager::currentNM()->mkConst(CVC4::Rational(0)); + bool is_non_empty = checkEntailArith(l1, zero, true); + + if (node[0] == node[1] && is_non_empty) + { + return returnRewrite(node, node[2], "rpl-replace"); + } // component-wise containment std::vector<Node> cb; std::vector<Node> ce; int cc = componentContains(children0, children1, cb, ce, true, 1); if (cc != -1) { - if (cc == 0 && children0[0] == children1[0]) + if (cc == 0 && children0[0] == children1[0] && is_non_empty) { // definitely a prefix, can do the replace // for example, @@ -1995,6 +2027,7 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) { // for example, // str.replace( str.++( x, "ab" ), "a", y ) ---> // str.++( str.replace( str.++( x, "a" ), "a", y ), "b" ) + // this is independent of whether the second argument may be empty std::vector<Node> cc; cc.push_back(NodeManager::currentNM()->mkNode( kind::STRING_STRREPL, @@ -2599,7 +2632,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1, else if (n2[index1].getKind() == kind::STRING_ITOS) { const std::vector<unsigned>& svec = s.getVec(); - // can remove up to the first occurrence of a non-digit + // can remove up to the first occurrence of a digit for (unsigned i = 0; i < svec.size(); i++) { unsigned sindex = r == 0 ? i : svec.size() - i; @@ -2609,8 +2642,8 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1, } else { - // e.g. str.contains( str.++( "a", x ), str.to.int(y) ) --> - // str.contains( x, str.to.int(y) ) + // e.g. str.contains( str.++( "a", x ), int.to.str(y) ) --> + // str.contains( x, int.to.str(y) ) overlap--; } } @@ -2656,7 +2689,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1, { // if n1.size()==1, then if n2[index1] is not a number, we can drop // the entire component - // e.g. str.contains( str.to.int(x), "123a45") --> false + // e.g. str.contains( int.to.str(x), "123a45") --> false if (!t.isNumber()) { removeComponent = true; @@ -2670,9 +2703,9 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1, // if n1.size()>1, then if the first (resp. last) character of // n2[index1] // is not a digit, we can drop the entire component, e.g.: - // str.contains( str.++( str.to.int(x), y ), "a12") --> + // str.contains( str.++( int.to.str(x), y ), "a12") --> // str.contains( y, "a12" ) - // str.contains( str.++( y, str.to.int(x) ), "a0b") --> + // str.contains( str.++( y, int.to.str(x) ), "a0b") --> // str.contains( y, "a0b" ) unsigned i = r == 0 ? 0 : (tvec.size() - 1); if (!String::isDigit(tvec[i])) diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 435dadce7..edbd768d7 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -902,6 +902,11 @@ TheoryModel* TheoryEngine::getModel() { return d_curr_model; } +void TheoryEngine::getSynthSolutions(std::map<Node, Node>& sol_map) +{ + d_quantEngine->getSynthSolutions(sol_map); +} + bool TheoryEngine::presolve() { // Reset the interrupt flag d_interrupted = false; diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 22e269409..7bc95b097 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -715,6 +715,17 @@ public: */ theory::TheoryModel* getModel(); + /** get synth solutions + * + * This function adds entries to sol_map that map functions-to-synthesize with + * their solutions, for all active conjectures. This should be called + * immediately after the solver answers unsat for sygus input. + * + * For details on what is added to sol_map, see + * CegConjecture::getSynthSolutions. + */ + void getSynthSolutions(std::map<Node, Node>& sol_map); + /** * Get the model builder */ diff --git a/src/theory/theory_model.h b/src/theory/theory_model.h index 934a09a8e..1f9fd92d4 100644 --- a/src/theory/theory_model.h +++ b/src/theory/theory_model.h @@ -162,13 +162,13 @@ public: */ Node getValue(TNode n, bool useDontCares = false) const; /** get comments */ - void getComments(std::ostream& out) const; + void getComments(std::ostream& out) const override; //---------------------------- separation logic /** set the heap and value sep.nil is equal to */ void setHeapModel(Node h, Node neq); /** get the heap and value sep.nil is equal to */ - bool getHeapModel(Expr& h, Expr& neq) const; + bool getHeapModel(Expr& h, Expr& neq) const override; //---------------------------- end separation logic /** get the representative set object */ @@ -176,11 +176,11 @@ public: /** get the representative set object (FIXME: remove this, see #1199) */ RepSet* getRepSetPtr() { return &d_rep_set; } /** return whether this node is a don't-care */ - bool isDontCare(Expr expr) const; + bool isDontCare(Expr expr) const override; /** get value function for Exprs. */ - Expr getValue( Expr expr ) const; + Expr getValue(Expr expr) const override; /** get cardinality for sort */ - Cardinality getCardinality( Type t ) const; + Cardinality getCardinality(Type t) const override; /** print representative debug function */ void printRepresentativeDebug( const char* c, Node r ); /** print representative function */ diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 269aa63db..6fde4a9af 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -234,34 +234,30 @@ public: ~TheoryUF(); - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void finishInit(); + void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + void finishInit() override; - void check(Effort); - Node expandDefinition(LogicRequest &logicRequest, Node node); - void preRegisterTerm(TNode term); - Node explain(TNode n); + void check(Effort) override; + Node expandDefinition(LogicRequest& logicRequest, Node node) override; + void preRegisterTerm(TNode term) override; + Node explain(TNode n) override; bool collectModelInfo(TheoryModel* m) override; - void ppStaticLearn(TNode in, NodeBuilder<>& learned); - void presolve(); + void ppStaticLearn(TNode in, NodeBuilder<>& learned) override; + void presolve() override; - void addSharedTerm(TNode n); - void computeCareGraph(); + void addSharedTerm(TNode n) override; + void computeCareGraph() override; - void propagate(Effort effort); - Node getNextDecisionRequest( unsigned& priority ); + void propagate(Effort effort) override; + Node getNextDecisionRequest(unsigned& priority) override; - EqualityStatus getEqualityStatus(TNode a, TNode b); + EqualityStatus getEqualityStatus(TNode a, TNode b) override; - std::string identify() const { - return "THEORY_UF"; - } + std::string identify() const override { return "THEORY_UF"; } - eq::EqualityEngine* getEqualityEngine() { - return &d_equalityEngine; - } + eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } StrongSolverTheoryUF* getStrongSolver() { return d_thss; diff --git a/src/util/Makefile.am b/src/util/Makefile.am index 33218dbe9..ddee2e72b 100644 --- a/src/util/Makefile.am +++ b/src/util/Makefile.am @@ -38,6 +38,8 @@ libutil_la_SOURCES = \ index.h \ maybe.h \ ntuple.h \ + ostream_util.cpp \ + ostream_util.h \ proof.h \ regexp.cpp \ regexp.h \ diff --git a/src/util/ostream_util.cpp b/src/util/ostream_util.cpp new file mode 100644 index 000000000..3d6eeea01 --- /dev/null +++ b/src/util/ostream_util.cpp @@ -0,0 +1,31 @@ +/********************* */ +/*! \file result.cpp + ** \verbatim + ** Top contributors (to current version): + ** Tim King, Morgan Deters, Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Utilities for using ostreams. + ** + ** Utilities for using ostreams. + **/ +#include "util/ostream_util.h" + +namespace CVC4 { + +StreamFormatScope::StreamFormatScope(std::ostream& out) + : d_out(out), d_format_flags(out.flags()), d_precision(out.precision()) +{ +} + +StreamFormatScope::~StreamFormatScope() +{ + d_out.precision(d_precision); + d_out.flags(d_format_flags); +} + +} // namespace CVC4 diff --git a/src/util/ostream_util.h b/src/util/ostream_util.h new file mode 100644 index 000000000..e047caa17 --- /dev/null +++ b/src/util/ostream_util.h @@ -0,0 +1,49 @@ +/********************* */ +/*! \file ostream_util.h + ** \verbatim + ** Top contributors (to current version): + ** Tim King + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Utilities for using ostreams. + ** + ** Utilities for using ostreams. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__UTIL__OSTREAM_UTIL_H +#define __CVC4__UTIL__OSTREAM_UTIL_H + +#include <ios> +#include <ostream> + +namespace CVC4 { + +// Saves the formatting of an ostream and restores the previous settings on +// destruction. Example usage: +// void Foo::Print(std::ostream& out) { +// StreamFormatScope format_scope(out); +// out << std::setprecision(6) << bar(); +// } +class StreamFormatScope +{ + public: + // `out` must outlive StreamFormatScope. + StreamFormatScope(std::ostream& out); + ~StreamFormatScope(); + + private: + // Does not own the memory of d_out + std::ostream& d_out; + std::ios_base::fmtflags d_format_flags; + std::streamsize d_precision; +}; + +} // namespace CVC4 + +#endif /* __CVC4__UTIL__OSTREAM_UTIL_H */ diff --git a/src/util/sexpr.cpp b/src/util/sexpr.cpp index 61dbccbee..504d58b0e 100644 --- a/src/util/sexpr.cpp +++ b/src/util/sexpr.cpp @@ -30,6 +30,7 @@ #include "base/cvc4_assert.h" #include "options/set_language.h" +#include "util/ostream_util.h" #include "util/smt2_quote_string.h" namespace CVC4 { @@ -219,6 +220,8 @@ void SExpr::toStream(std::ostream& out, const SExpr& sexpr, void SExpr::toStreamRec(std::ostream& out, const SExpr& sexpr, OutputLanguage language, int indent) { + StreamFormatScope scope(out); + if (sexpr.isInteger()) { out << sexpr.getIntegerValue(); } else if (sexpr.isRational()) { diff --git a/src/util/statistics_registry.cpp b/src/util/statistics_registry.cpp index 11afb99ed..2dd1eddd7 100644 --- a/src/util/statistics_registry.cpp +++ b/src/util/statistics_registry.cpp @@ -17,9 +17,13 @@ #include "util/statistics_registry.h" +#include <cstdlib> +#include <iostream> + #include "base/cvc4_assert.h" +#include "base/cvc4_check.h" #include "lib/clock_gettime.h" - +#include "util/ostream_util.h" #ifdef CVC4_STATISTICS_ON # define __CVC4_USE_STATISTICS true @@ -73,7 +77,7 @@ inline timespec& operator-=(timespec& a, const timespec& b) { nsec -= nsec_per_sec; ++a.tv_sec; } - Assert(nsec >= 0 && nsec < nsec_per_sec); + DCHECK(nsec >= 0 && nsec < nsec_per_sec); a.tv_nsec = nsec; return a; } @@ -134,6 +138,7 @@ inline bool operator>=(const timespec& a, const timespec& b) { /** Output a timespec on an output stream. */ std::ostream& operator<<(std::ostream& os, const timespec& t) { // assumes t.tv_nsec is in range + StreamFormatScope format_scope(os); return os << t.tv_sec << "." << std::setfill('0') << std::setw(9) << std::right << t.tv_nsec; } @@ -163,12 +168,11 @@ void StatisticsRegistry::registerStat(Stat* s) void StatisticsRegistry::unregisterStat(Stat* s) { #ifdef CVC4_STATISTICS_ON - PrettyCheckArgument(d_stats.find(s) != d_stats.end(), s, - "Statistic `%s' was not registered with this registry.", - s->getName().c_str()); - d_stats.erase(s); + CHECK(s != nullptr); + CHECK(d_stats.erase(s) > 0) << "Statistic `" << s->getName() + << "' was not registered with this registry."; #endif /* CVC4_STATISTICS_ON */ -}/* StatisticsRegistry::unregisterStat_() */ +} /* StatisticsRegistry::unregisterStat() */ void StatisticsRegistry::flushStat(std::ostream &out) const { #ifdef CVC4_STATISTICS_ON @@ -198,7 +202,7 @@ void TimerStat::start() { void TimerStat::stop() { if(__CVC4_USE_STATISTICS) { - PrettyCheckArgument(d_running, *this, "timer not running"); + CHECK(d_running) << "timer not running"; ::timespec end; clock_gettime(CLOCK_MONOTONIC, &end); d_data += end - d_start; diff --git a/test/regress/regress0/bv/Makefile.am b/test/regress/regress0/bv/Makefile.am index 68a5f791c..912f6871d 100644 --- a/test/regress/regress0/bv/Makefile.am +++ b/test/regress/regress0/bv/Makefile.am @@ -105,7 +105,8 @@ SMT_TESTS = \ divtest_2_5.smt2 \ divtest_2_6.smt2 \ mul-neg-unsat.smt2 \ - mul-negpow2.smt2 + mul-negpow2.smt2 \ + bvmul-pow2-only.smt2 # This benchmark is currently disabled as it uses --check-proof # bench_38.delta.smt2 diff --git a/test/regress/regress0/bv/bvmul-pow2-only.smt2 b/test/regress/regress0/bv/bvmul-pow2-only.smt2 new file mode 100644 index 000000000..d4f085046 --- /dev/null +++ b/test/regress/regress0/bv/bvmul-pow2-only.smt2 @@ -0,0 +1,9 @@ +(set-logic QF_BV) +(set-info :status sat) +(declare-fun x () (_ BitVec 4)) + +(assert (= x #b1000)) + +(assert (= (bvmul (bvneg x) x) #b0000)) +(assert (= (bvmul (bvneg #b0100) #b0100) #b0000)) +(check-sat) diff --git a/test/regress/regress0/nl/Makefile.am b/test/regress/regress0/nl/Makefile.am index 4f7c2172b..e770ca9ba 100644 --- a/test/regress/regress0/nl/Makefile.am +++ b/test/regress/regress0/nl/Makefile.am @@ -79,7 +79,10 @@ TESTS = \ nta/exp-n0.5-lb.smt2 \ nta/dumortier_llibre_artes_ex_5_13.transcendental.k2.smt2 \ nta/NAVIGATION2.smt2 \ - nta/sin1-sat.smt2 + nta/sin1-sat.smt2 \ + nta/sugar-ident.smt2 \ + nta/sugar-ident-2.smt2 \ + nta/sugar-ident-3.smt2 # unsolved : garbage_collect.cvc diff --git a/test/regress/regress0/nl/nta/sugar-ident-2.smt2 b/test/regress/regress0/nl/nta/sugar-ident-2.smt2 new file mode 100644 index 000000000..84c224715 --- /dev/null +++ b/test/regress/regress0/nl/nta/sugar-ident-2.smt2 @@ -0,0 +1,27 @@ +; COMMAND-LINE: --nl-ext-tf-tplanes +; EXPECT: unsat +(set-logic QF_NRA) +(set-info :status unsat) +(declare-fun x1 () Real) +(declare-fun x2 () Real) +(declare-fun x3 () Real) +(declare-fun x4 () Real) +(declare-fun x5 () Real) + +(declare-fun a1 () Bool) +(declare-fun a2 () Bool) +(declare-fun a3 () Bool) +(declare-fun a4 () Bool) +(declare-fun a5 () Bool) +(declare-fun a6 () Bool) +(declare-fun a7 () Bool) + +(assert (= a2 (and (> (sin 1.0) 0.0) (> (cot 1.0) (/ (cos 1.0) (sin 1.0)))))) +(assert (= a7 (> (* (sec 1.0) (cos 1.0)) 1.0))) + +(assert (or +a2 +a7 +)) + +(check-sat) diff --git a/test/regress/regress0/nl/nta/sugar-ident-3.smt2 b/test/regress/regress0/nl/nta/sugar-ident-3.smt2 new file mode 100644 index 000000000..ab50bcb1d --- /dev/null +++ b/test/regress/regress0/nl/nta/sugar-ident-3.smt2 @@ -0,0 +1,8 @@ +; COMMAND-LINE: --nl-ext-tf-tplanes +; EXPECT: unsat +(set-logic QF_NRA) +(set-info :status unsat) +(declare-fun a6 () Bool) +(assert (= a6 (> (* (csc 1.0) (sin 1.0)) 1.0))) +(assert a6) +(check-sat) diff --git a/test/regress/regress0/nl/nta/sugar-ident.smt2 b/test/regress/regress0/nl/nta/sugar-ident.smt2 new file mode 100644 index 000000000..95dbbc5fc --- /dev/null +++ b/test/regress/regress0/nl/nta/sugar-ident.smt2 @@ -0,0 +1,23 @@ +; COMMAND-LINE: --nl-ext-tf-tplanes +; EXPECT: unsat +(set-logic QF_NRA) +(set-info :status unsat) +(declare-fun x1 () Real) +(declare-fun x2 () Real) +(declare-fun x3 () Real) +(declare-fun x4 () Real) +(declare-fun x5 () Real) + +(declare-fun a1 () Bool) +(declare-fun a3 () Bool) +(declare-fun a4 () Bool) +(declare-fun a5 () Bool) +(declare-fun a6 () Bool) + +(assert (= a1 (not (= (sin (arcsin x1)) x1)))) +(assert (= a3 (< (arccos x3) 0))) +(assert (= a4 (> (arctan x4) 1.8))) + +(assert (or a1 a3 a4)) + +(check-sat) diff --git a/test/regress/regress0/strings/Makefile.am b/test/regress/regress0/strings/Makefile.am index 18b07b91d..7f7511e74 100644 --- a/test/regress/regress0/strings/Makefile.am +++ b/test/regress/regress0/strings/Makefile.am @@ -95,7 +95,9 @@ TESTS = \ substr-rewrites.smt2 \ norn-ab.smt2 \ type002.smt2 \ - strip-endpt-sound.smt2 + strip-endpt-sound.smt2 \ + repl-rewrites2.smt2 \ + repl-soundness-sem.smt2 FAILING_TESTS = diff --git a/test/regress/regress0/strings/repl-rewrites2.smt2 b/test/regress/regress0/strings/repl-rewrites2.smt2 new file mode 100644 index 000000000..42699bc8b --- /dev/null +++ b/test/regress/regress0/strings/repl-rewrites2.smt2 @@ -0,0 +1,14 @@ +; COMMAND-LINE: --strings-exp +; EXPECT: unsat +(set-logic ALL) +(set-info :status unsat) +(declare-fun x () String) +(declare-fun y () String) +(assert (or +(not (= (str.replace "" "" "c") "")) +(not (= (str.replace (str.++ "abc" y) "b" x) (str.++ "a" x "c" y))) +(not (= (str.replace "" "abc" "de") "")) +(not (= (str.replace "ab" "ab" "de") "de")) +(not (= (str.replace "ab" "" "de") "ab")) +)) +(check-sat) diff --git a/test/regress/regress0/strings/repl-soundness-sem.smt2 b/test/regress/regress0/strings/repl-soundness-sem.smt2 new file mode 100644 index 000000000..d56d7945f --- /dev/null +++ b/test/regress/regress0/strings/repl-soundness-sem.smt2 @@ -0,0 +1,12 @@ +; COMMAND-LINE: --strings-exp +; EXPECT: sat +(set-logic ALL) +(set-info :status sat) +(declare-fun x () String) +(declare-fun y () String) +(assert (and +(= (str.replace x x "abc") "") +(= (str.replace (str.++ x y) x "abc") (str.++ x y)) +(= (str.replace (str.++ x y) (str.substr x 0 2) "abc") y) +)) +(check-sat) diff --git a/test/regress/regress0/sygus/Makefile.am b/test/regress/regress0/sygus/Makefile.am index dc721248c..9e7427eb0 100644 --- a/test/regress/regress0/sygus/Makefile.am +++ b/test/regress/regress0/sygus/Makefile.am @@ -71,15 +71,18 @@ TESTS = commutative.sy \ process-10-vars-2fun.sy \ inv-unused.sy \ ccp16.lus.sy \ - Base16_1.sy \ icfp_14.12-flip-args.sy \ strings-template-infer-unused.sy \ strings-trivial-two-type.sy \ strings-double-rec.sy \ hd-19-d1-prog-dup-op.sy \ real-grammar-neg.sy \ - real-si-all.sy + real-si-all.sy \ + c100.sy \ + check-generic-red.sy +# disabled, takes too long with and without CBQI BV +# Base16_1.sy # sygus tests currently taking too long for make regress EXTRA_DIST = $(TESTS) \ diff --git a/test/regress/regress0/sygus/c100.sy b/test/regress/regress0/sygus/c100.sy new file mode 100644 index 000000000..ef124c953 --- /dev/null +++ b/test/regress/regress0/sygus/c100.sy @@ -0,0 +1,18 @@ +; EXPECT: unsat +; COMMAND-LINE: --cegqi-si=all --sygus-out=status + +(set-logic LIA) + +(synth-fun constant ((x Int)) Int + ((Start Int (0 + 2 + 3 + 5 + (+ Start Start) + (- Start Start) + )) + )) +(declare-var x Int) +(constraint (= (constant x) 100)) +(check-synth) + diff --git a/test/regress/regress0/sygus/check-generic-red.sy b/test/regress/regress0/sygus/check-generic-red.sy new file mode 100644 index 000000000..917c1473a --- /dev/null +++ b/test/regress/regress0/sygus/check-generic-red.sy @@ -0,0 +1,19 @@ +; EXPECT: unsat +; COMMAND-LINE: --cegqi-si=all --sygus-out=status +(set-logic LIA) + +(synth-fun P ((x Int) (y Int)) Bool + ((Start Bool ((and Start Start) + (not Start) + (<= StartInt StartIntC) + (<= StartInt StartInt) + (>= StartInt StartInt) + (<= StartIntC StartInt) + (>= StartIntC StartInt) + (<= StartIntC StartIntC) + )) + (StartIntC Int (0 0 1)) + (StartInt Int (x y 0 1)))) + +(constraint (P 0 2)) +(check-synth) diff --git a/test/regress/run_regression b/test/regress/run_regression index 5d4165597..e236234e1 100755 --- a/test/regress/run_regression +++ b/test/regress/run_regression @@ -262,6 +262,13 @@ else echo "$expected_output" >"$expoutfile" fi +if [ "$lang" = sygus ]; then + if ! expr "$CVC4_REGRESSION_ARGS $command_line" : '.*--check-synth-sol' &>/dev/null && + ! expr "$CVC4_REGRESSION_ARGS $command_line" : '.*--no-check-synth-sol' &>/dev/null; then + # later on, we'll run another test with --check-models on + command_line="$command_line --check-synth-sol" + fi +fi check_models=false if grep '^sat$' "$expoutfile" &>/dev/null || grep '^invalid$' "$expoutfile" &>/dev/null || grep '^unknown$' "$expoptfile" &>/dev/null; then if ! expr "$CVC4_REGRESSION_ARGS $command_line" : '.*--check-models' &>/dev/null && @@ -407,4 +414,5 @@ if $check_models || $check_proofs || $check_unsat_cores; then fi fi + exit $exitcode diff --git a/test/unit/Makefile.am b/test/unit/Makefile.am index 9f61ef031..167100ff0 100644 --- a/test/unit/Makefile.am +++ b/test/unit/Makefile.am @@ -41,6 +41,7 @@ UNIT_TESTS += \ context/cdvector_black \ util/array_store_all_black \ util/assert_white \ + util/check_white \ util/binary_heap_black \ util/bitvector_black \ util/datatype_black \ diff --git a/test/unit/theory/theory_quantifiers_bv_inverter_white.h b/test/unit/theory/theory_quantifiers_bv_inverter_white.h index ba8dd1668..a9cd7b8d6 100644 --- a/test/unit/theory/theory_quantifiers_bv_inverter_white.h +++ b/test/unit/theory/theory_quantifiers_bv_inverter_white.h @@ -43,6 +43,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite void runTestPred(bool pol, Kind k, + Node x, Node (*getsc)(bool, Kind, Node, Node)) { Assert(k == BITVECTOR_ULT || k == BITVECTOR_SLT || k == EQUAL @@ -81,7 +82,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite k = DISTINCT; } } - Node body = d_nm->mkNode(k, d_x, d_t); + Node body = d_nm->mkNode(k, x, d_t); Node scr = d_nm->mkNode(EXISTS, d_bvarlist, body); Expr a = d_nm->mkNode(DISTINCT, scl, scr).toExpr(); Result res = d_smt->checkSat(a); @@ -126,6 +127,100 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite TS_ASSERT(res.d_sat == Result::UNSAT); } + void runTestConcat(bool pol, Kind litk, unsigned idx) + { + Node s1, s2, sv_t; + Node x, t, sk; + Node sc; + + if (idx == 0) + { + s2 = d_nm->mkVar("s2", d_nm->mkBitVectorType(4)); + x = d_nm->mkBoundVar(s2.getType()); + sk = d_nm->mkSkolem("sk", s2.getType()); + t = d_nm->mkVar("t", d_nm->mkBitVectorType(8)); + sv_t = d_nm->mkNode(BITVECTOR_CONCAT, x, s2); + sc = getScBvConcat(pol, litk, 0, sk, sv_t, t); + } + else if (idx == 1) + { + s1 = d_nm->mkVar("s1", d_nm->mkBitVectorType(4)); + x = d_nm->mkBoundVar(s1.getType()); + sk = d_nm->mkSkolem("sk", s1.getType()); + t = d_nm->mkVar("t", d_nm->mkBitVectorType(8)); + sv_t = d_nm->mkNode(BITVECTOR_CONCAT, s1, x); + sc = getScBvConcat(pol, litk, 1, sk, sv_t, t); + } + else + { + Assert(idx == 2); + s1 = d_nm->mkVar("s1", d_nm->mkBitVectorType(4)); + s2 = d_nm->mkVar("s2", d_nm->mkBitVectorType(4)); + x = d_nm->mkBoundVar(s2.getType()); + sk = d_nm->mkSkolem("sk", s1.getType()); + t = d_nm->mkVar("t", d_nm->mkBitVectorType(12)); + sv_t = d_nm->mkNode(BITVECTOR_CONCAT, s1, x, s2); + sc = getScBvConcat(pol, litk, 1, sk, sv_t, t); + } + + TS_ASSERT(!sc.isNull()); + Kind ksc = sc.getKind(); + TS_ASSERT((litk == kind::EQUAL && pol == false) + || ksc == IMPLIES); + Node scl = ksc == IMPLIES ? sc[0] : bv::utils::mkTrue(); + Node body = d_nm->mkNode(litk, sv_t, t); + Node bvarlist = d_nm->mkNode(BOUND_VAR_LIST, { x }); + Node scr = d_nm->mkNode(EXISTS, bvarlist, pol ? body : body.notNode()); + Expr a = d_nm->mkNode(DISTINCT, scl, scr).toExpr(); + Result res = d_smt->checkSat(a); + if (res.d_sat == Result::SAT) + { + std::cout << std::endl; + if (!s1.isNull()) + std::cout << "s1 " << d_smt->getValue(s1.toExpr()) << std::endl; + if (!s2.isNull()) + std::cout << "s2 " << d_smt->getValue(s2.toExpr()) << std::endl; + std::cout << "t " << d_smt->getValue(t.toExpr()) << std::endl; + std::cout << "x " << d_smt->getValue(x.toExpr()) << std::endl; + } + TS_ASSERT(res.d_sat == Result::UNSAT); + } + + void runTestSext(bool pol, Kind litk) + { + unsigned ws = 3; + unsigned wx = 5; + unsigned w = 8; + + Node x = d_nm->mkVar("x", d_nm->mkBitVectorType(wx)); + Node sk = d_nm->mkSkolem("sk", x.getType()); + x = d_nm->mkBoundVar(x.getType()); + + Node t = d_nm->mkVar("t", d_nm->mkBitVectorType(w)); + Node sv_t = bv::utils::mkSignExtend(x, ws); + Node sc = getScBvSext(pol, litk, 0, sk, sv_t, t); + + TS_ASSERT(!sc.isNull()); + Kind ksc = sc.getKind(); + TS_ASSERT((litk == kind::EQUAL && pol == false) + || (litk == kind::BITVECTOR_ULT && pol == false) + || (litk == kind::BITVECTOR_UGT && pol == false) + || ksc == IMPLIES); + Node scl = ksc == IMPLIES ? sc[0] : bv::utils::mkTrue(); + Node body = d_nm->mkNode(litk, sv_t, t); + Node bvarlist = d_nm->mkNode(BOUND_VAR_LIST, { x }); + Node scr = d_nm->mkNode(EXISTS, bvarlist, pol ? body : body.notNode()); + Expr a = d_nm->mkNode(DISTINCT, scl, scr).toExpr(); + Result res = d_smt->checkSat(a); + if (res.d_sat == Result::SAT) + { + std::cout << std::endl; + std::cout << "t " << d_smt->getValue(t.toExpr()) << std::endl; + std::cout << "x " << d_smt->getValue(x.toExpr()) << std::endl; + } + TS_ASSERT(res.d_sat == Result::UNSAT); + } + public: TheoryQuantifiersBvInverter() {} @@ -134,7 +229,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite d_em = new ExprManager(); d_nm = NodeManager::fromExprManager(d_em); d_smt = new SmtEngine(d_em); - d_smt->setOption("cbqi-bv", CVC4::SExpr(false)); + d_smt->setOption("cbqi-full", CVC4::SExpr(true)); d_smt->setOption("produce-models", CVC4::SExpr(true)); d_scope = new SmtScope(d_smt); @@ -161,42 +256,42 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite void testGetScBvUltTrue() { - runTestPred(true, BITVECTOR_ULT, getScBvUltUgt); + runTestPred(true, BITVECTOR_ULT, d_x, getScBvUltUgt); } void testGetScBvUltFalse() { - runTestPred(false, BITVECTOR_ULT, getScBvUltUgt); + runTestPred(false, BITVECTOR_ULT, d_x, getScBvUltUgt); } void testGetScBvUgtTrue() { - runTestPred(true, BITVECTOR_UGT, getScBvUltUgt); + runTestPred(true, BITVECTOR_UGT, d_x, getScBvUltUgt); } void testGetScBvUgtFalse() { - runTestPred(false, BITVECTOR_UGT, getScBvUltUgt); + runTestPred(false, BITVECTOR_UGT, d_x, getScBvUltUgt); } void testGetScBvSltTrue() { - runTestPred(true, BITVECTOR_SLT, getScBvSltSgt); + runTestPred(true, BITVECTOR_SLT, d_x, getScBvSltSgt); } void testGetScBvSltFalse() { - runTestPred(false, BITVECTOR_SLT, getScBvSltSgt); + runTestPred(false, BITVECTOR_SLT, d_x, getScBvSltSgt); } void testGetScBvSgtTrue() { - runTestPred(true, BITVECTOR_SGT, getScBvSltSgt); + runTestPred(true, BITVECTOR_SGT, d_x, getScBvSltSgt); } void testGetScBvSgtFalse() { - runTestPred(false, BITVECTOR_SGT, getScBvSltSgt); + runTestPred(false, BITVECTOR_SGT, d_x, getScBvSltSgt); } /* Equality and Disequality ---------------------------------------------- */ @@ -376,8 +471,346 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite runTest(false, EQUAL, BITVECTOR_SHL, 1, getScBvShl); } + /* Concat */ + + void testGetScBvConcatEqTrue0() + { + runTestConcat(true, EQUAL, 0); + } + + void testGetScBvConcatEqTrue1() + { + runTestConcat(true, EQUAL, 1); + } + + void testGetScBvConcatEqTrue2() + { + runTestConcat(true, EQUAL, 2); + } + + void testGetScBvConcatEqFalse0() + { + runTestConcat(false, EQUAL, 0); + } + + void testGetScBvConcatEqFalse1() + { + runTestConcat(false, EQUAL, 1); + } + + void testGetScBvConcatEqFalse2() + { + runTestConcat(false, EQUAL, 2); + } + + /* Sext */ + + void testGetScBvSextEqTrue() + { + runTestSext(true, EQUAL); + } + + void testGetScBvSextEqFalse() + { + runTestSext(false, EQUAL); + } + /* Inequality ------------------------------------------------------------ */ + /* Not */ + + void testGetScBvNotUltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNotUltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNotUltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNotUltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNotUgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNotUgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNotUgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNotUgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNotSltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNotSltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNotSltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNotSltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNotSgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNotSgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNotSgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNotSgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NOT, d_x); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + + /* Neg */ + + void testGetScBvNegUltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNegUltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNegUltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNegUltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvNegUgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNegUgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNegUgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNegUgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvNegSltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNegSltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNegSltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNegSltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvNegSgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNegSgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNegSgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvNegSgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_NEG, d_x); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + + /* Add */ + + void testGetScBvPlusUltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvPlusUltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(true, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvPlusUltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvPlusUltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(false, BITVECTOR_ULT, x, getScBvUltUgt); + } + + void testGetScBvPlusUgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvPlusUgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(true, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvPlusUgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvPlusUgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(false, BITVECTOR_UGT, x, getScBvUltUgt); + } + + void testGetScBvPlusSltTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvPlusSltTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(true, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvPlusSltFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvPlusSltFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(false, BITVECTOR_SLT, x, getScBvSltSgt); + } + + void testGetScBvPlusSgtTrue0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvPlusSgtTrue1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(true, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvPlusSgtFalse0() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_x, d_s); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + + void testGetScBvPlusSgtFalse1() + { + Node x = d_nm->mkNode(BITVECTOR_PLUS, d_s, d_x); + runTestPred(false, BITVECTOR_SGT, x, getScBvSltSgt); + } + /* Mult */ void testGetScBvMultUltTrue0() @@ -1033,4 +1466,169 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite { runTest(false, BITVECTOR_SGT, BITVECTOR_SHL, 1, getScBvShl); } + + /* Concat */ + + void testGetScBvConcatUltTrue0() + { + runTestConcat(true, BITVECTOR_ULT, 0); + } + + void testGetScBvConcatUltTrue1() + { + runTestConcat(true, BITVECTOR_ULT, 1); + } + + void testGetScBvConcatUltTrue2() + { + runTestConcat(true, BITVECTOR_ULT, 2); + } + + void testGetScBvConcatUltFalse0() + { + runTestConcat(false, BITVECTOR_ULT, 0); + } + + void testGetScBvConcatUltFalse1() + { + runTestConcat(false, BITVECTOR_ULT, 1); + } + + void testGetScBvConcatUltFalse2() + { + runTestConcat(false, BITVECTOR_ULT, 2); + } + + void testGetScBvConcatUgtTrue0() + { + runTestConcat(true, BITVECTOR_UGT, 0); + } + + void testGetScBvConcatUgtTrue1() + { + runTestConcat(true, BITVECTOR_UGT, 1); + } + + void testGetScBvConcatUgtTrue2() + { + runTestConcat(true, BITVECTOR_UGT, 2); + } + + void testGetScBvConcatUgtFalse0() + { + runTestConcat(false, BITVECTOR_UGT, 0); + } + + void testGetScBvConcatUgtFalse1() + { + runTestConcat(false, BITVECTOR_UGT, 1); + } + + void testGetScBvConcatUgtFalse2() + { + runTestConcat(false, BITVECTOR_UGT, 2); + } + + void testGetScBvConcatSltTrue0() + { + runTestConcat(true, BITVECTOR_SLT, 0); + } + + void testGetScBvConcatSltTrue1() + { + runTestConcat(true, BITVECTOR_SLT, 1); + } + + void testGetScBvConcatSltTrue2() + { + runTestConcat(true, BITVECTOR_SLT, 2); + } + + void testGetScBvConcatSltFalse0() + { + runTestConcat(false, BITVECTOR_SLT, 0); + } + + void testGetScBvConcatSltFalse1() + { + runTestConcat(false, BITVECTOR_SLT, 1); + } + + void testGetScBvConcatSltFalse2() + { + runTestConcat(false, BITVECTOR_SLT, 2); + } + + void testGetScBvConcatSgtTrue0() + { + runTestConcat(true, BITVECTOR_SGT, 0); + } + + void testGetScBvConcatSgtTrue1() + { + runTestConcat(true, BITVECTOR_SGT, 1); + } + + void testGetScBvConcatSgtTrue2() + { + runTestConcat(true, BITVECTOR_SGT, 2); + } + + void testGetScBvConcatSgtFalse0() + { + runTestConcat(false, BITVECTOR_SGT, 0); + } + + void testGetScBvConcatSgtFalse1() + { + runTestConcat(false, BITVECTOR_SGT, 1); + } + + void testGetScBvConcatSgtFalse2() + { + runTestConcat(false, BITVECTOR_SGT, 2); + } + + /* Sext */ + + void testGetScBvSextUltTrue() + { + runTestSext(true, BITVECTOR_ULT); + } + + void testGetScBvSextUltFalse() + { + runTestSext(false, BITVECTOR_ULT); + } + + void testGetScBvSextUgtTrue() + { + runTestSext(true, BITVECTOR_UGT); + } + + void testGetScBvSextUgtFalse() + { + runTestSext(false, BITVECTOR_UGT); + } + + void testGetScBvSextSltTrue() + { + runTestSext(true, BITVECTOR_SLT); + } + + void testGetScBvSextSltFalse() + { + runTestSext(false, BITVECTOR_SLT); + } + + void testGetScBvSextSgtTrue() + { + runTestSext(true, BITVECTOR_SGT); + } + + void testGetScBvSextSgtFalse() + { + runTestSext(false, BITVECTOR_SGT); + } + }; diff --git a/test/unit/util/check_white.h b/test/unit/util/check_white.h new file mode 100644 index 000000000..e57afa6c7 --- /dev/null +++ b/test/unit/util/check_white.h @@ -0,0 +1,58 @@ +/********************* */ +/*! \file check_white.h + ** \verbatim + ** Top contributors (to current version): + ** Tim King + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief White box testing of check utilities. + ** + ** White box testing of check utilities. + **/ + +#include <cxxtest/TestSuite.h> + +#include <cstring> +#include <string> + +#include "base/cvc4_check.h" + +using namespace std; +using namespace CVC4; + +namespace { + +class CheckWhite : public CxxTest::TestSuite +{ + public: + const int kOne = 1; + + // This test just checks that this statement compiles. + std::string TerminalCvc4Fatal() const + { + CVC4_FATAL() << "This is a test that confirms that CVC4_FATAL can be a " + "terminal statement in a function that has a non-void " + "return type."; + } + + void testCheck() { CHECK(kOne >= 0) << kOne << " must be positive"; } + void testDCheck() + { + DCHECK(kOne == 1) << "always passes"; +#ifndef CVC4_ASSERTIONS + DCHECK(false) << "Will not be compiled in when CVC4_ASSERTIONS off."; +#endif /* CVC4_ASSERTIONS */ + } + + void testPointerTypeCanBeTheCondition() + { + const int* one_pointer = &kOne; + CHECK(one_pointer); + } +}; + +} // namespace |