summaryrefslogtreecommitdiff
path: root/src/smt/print_benchmark.cpp
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2021-10-01 08:58:14 -0500
committerGitHub <noreply@github.com>2021-10-01 13:58:14 +0000
commit971ae785a4789776e6fe36121c80b69162c2fd27 (patch)
treece487004cc4002056a656fc269079de7855af1ec /src/smt/print_benchmark.cpp
parentbb0e6dcde2e7267e391a46b868b990d7cb7e42bd (diff)
Add the print benchmark utility (#7196)
This utility is capable of printing a vector of Node as a valid (SMT-LIB) benchmark with no prior bookkeeping. It also optionally allows for taking a vector Node corresponding to define-fun. It will be used to replace the old internal benchmark dumping infrastructure which was error prone.
Diffstat (limited to 'src/smt/print_benchmark.cpp')
-rw-r--r--src/smt/print_benchmark.cpp278
1 files changed, 278 insertions, 0 deletions
diff --git a/src/smt/print_benchmark.cpp b/src/smt/print_benchmark.cpp
new file mode 100644
index 000000000..c1913e209
--- /dev/null
+++ b/src/smt/print_benchmark.cpp
@@ -0,0 +1,278 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Print benchmark utility.
+ */
+
+#include "smt/print_benchmark.h"
+
+#include "expr/dtype.h"
+#include "expr/node_algorithm.h"
+#include "printer/printer.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace smt {
+
+void PrintBenchmark::printAssertions(std::ostream& out,
+ const std::vector<Node>& defs,
+ const std::vector<Node>& assertions)
+{
+ std::unordered_set<TypeNode> types;
+ std::unordered_set<TNode> typeVisited;
+ for (const Node& a : defs)
+ {
+ expr::getTypes(a, types, typeVisited);
+ }
+ for (const Node& a : assertions)
+ {
+ expr::getTypes(a, types, typeVisited);
+ }
+ // print the declared types first
+ std::unordered_set<TypeNode> alreadyPrintedDeclSorts;
+ for (const TypeNode& st : types)
+ {
+ // note that we must get all "component types" of a type, so that
+ // e.g. U is printed as a sort declaration when we have type (Array U Int).
+ std::unordered_set<TypeNode> ctypes;
+ expr::getComponentTypes(st, ctypes);
+ for (const TypeNode& stc : ctypes)
+ {
+ // get all connected datatypes to this one
+ std::vector<TypeNode> connectedTypes;
+ getConnectedSubfieldTypes(stc, connectedTypes, alreadyPrintedDeclSorts);
+ // now, separate into sorts and datatypes
+ std::vector<TypeNode> datatypeBlock;
+ for (const TypeNode& ctn : connectedTypes)
+ {
+ if (stc.isSort())
+ {
+ d_printer->toStreamCmdDeclareType(out, stc);
+ }
+ else if (stc.isDatatype())
+ {
+ datatypeBlock.push_back(ctn);
+ }
+ }
+ // print the mutually recursive datatype block if necessary
+ if (!datatypeBlock.empty())
+ {
+ d_printer->toStreamCmdDatatypeDeclaration(out, datatypeBlock);
+ }
+ }
+ }
+
+ // global visited cache for expr::getSymbols calls
+ std::unordered_set<TNode> visited;
+
+ // print the definitions
+ std::unordered_map<Node, std::pair<bool, Node>> defMap;
+ std::vector<Node> defSyms;
+ // first, record all the defined symbols
+ for (const Node& a : defs)
+ {
+ bool isRec;
+ Node defSym;
+ Node defBody;
+ decomposeDefinition(a, isRec, defSym, defBody);
+ if (!defSym.isNull())
+ {
+ Assert(defMap.find(defSym) == defMap.end());
+ defMap[defSym] = std::pair<bool, Node>(isRec, defBody);
+ defSyms.push_back(defSym);
+ }
+ }
+ // go back and print the definitions
+ std::unordered_set<Node> alreadyPrintedDecl;
+ std::unordered_set<Node> alreadyPrintedDef;
+
+ std::unordered_map<Node, std::pair<bool, Node>>::const_iterator itd;
+ for (const Node& s : defSyms)
+ {
+ std::vector<Node> recDefs;
+ std::vector<Node> ordinaryDefs;
+ std::unordered_set<Node> syms;
+ getConnectedDefinitions(
+ s, recDefs, ordinaryDefs, syms, defMap, alreadyPrintedDef, visited);
+ // print the declarations that are encountered for the first time in this
+ // block
+ printDeclaredFuns(out, syms, alreadyPrintedDecl);
+ // print the ordinary definitions
+ for (const Node& f : ordinaryDefs)
+ {
+ itd = defMap.find(f);
+ Assert(itd != defMap.end());
+ Assert(!itd->second.first);
+ d_printer->toStreamCmdDefineFunction(out, f, itd->second.second);
+ // a definition is also a declaration
+ alreadyPrintedDecl.insert(f);
+ }
+ // print a recursive function definition block
+ if (!recDefs.empty())
+ {
+ std::vector<Node> lambdas;
+ for (const Node& f : recDefs)
+ {
+ lambdas.push_back(defMap[f].second);
+ // a recursive definition is also a declaration
+ alreadyPrintedDecl.insert(f);
+ }
+ d_printer->toStreamCmdDefineFunctionRec(out, recDefs, lambdas);
+ }
+ }
+
+ // print the remaining declared symbols
+ std::unordered_set<Node> syms;
+ for (const Node& a : assertions)
+ {
+ expr::getSymbols(a, syms, visited);
+ }
+ printDeclaredFuns(out, syms, alreadyPrintedDecl);
+
+ // print the assertions
+ for (const Node& a : assertions)
+ {
+ d_printer->toStreamCmdAssert(out, a);
+ }
+}
+void PrintBenchmark::printAssertions(std::ostream& out,
+ const std::vector<Node>& assertions)
+{
+ std::vector<Node> defs;
+ printAssertions(out, defs, assertions);
+}
+
+void PrintBenchmark::printDeclaredFuns(std::ostream& out,
+ const std::unordered_set<Node>& funs,
+ std::unordered_set<Node>& alreadyPrinted)
+{
+ for (const Node& f : funs)
+ {
+ Assert(f.isVar());
+ if (alreadyPrinted.find(f) == alreadyPrinted.end())
+ {
+ d_printer->toStreamCmdDeclareFunction(out, f);
+ }
+ }
+ alreadyPrinted.insert(funs.begin(), funs.end());
+}
+
+void PrintBenchmark::getConnectedSubfieldTypes(
+ TypeNode tn,
+ std::vector<TypeNode>& connectedTypes,
+ std::unordered_set<TypeNode>& processed)
+{
+ if (processed.find(tn) != processed.end())
+ {
+ return;
+ }
+ processed.insert(tn);
+ if (tn.isSort())
+ {
+ connectedTypes.push_back(tn);
+ }
+ else if (tn.isDatatype())
+ {
+ connectedTypes.push_back(tn);
+ std::unordered_set<TypeNode> subfieldTypes =
+ tn.getDType().getSubfieldTypes();
+ for (const TypeNode& ctn : subfieldTypes)
+ {
+ getConnectedSubfieldTypes(ctn, connectedTypes, processed);
+ }
+ }
+}
+
+void PrintBenchmark::getConnectedDefinitions(
+ Node n,
+ std::vector<Node>& recDefs,
+ std::vector<Node>& ordinaryDefs,
+ std::unordered_set<Node>& syms,
+ const std::unordered_map<Node, std::pair<bool, Node>>& defMap,
+ std::unordered_set<Node>& processedDefs,
+ std::unordered_set<TNode>& visited)
+{
+ // does it have a definition?
+ std::unordered_map<Node, std::pair<bool, Node>>::const_iterator it =
+ defMap.find(n);
+ if (it == defMap.end())
+ {
+ // an ordinary declared symbol
+ syms.insert(n);
+ return;
+ }
+ if (processedDefs.find(n) != processedDefs.end())
+ {
+ return;
+ }
+ processedDefs.insert(n);
+ if (!it->second.first)
+ {
+ // an ordinary define-fun symbol
+ ordinaryDefs.push_back(n);
+ }
+ else
+ {
+ // a recursively defined symbol
+ recDefs.push_back(n);
+ // get the symbols in the body
+ std::unordered_set<Node> symsBody;
+ expr::getSymbols(it->second.second, symsBody, visited);
+ for (const Node& s : symsBody)
+ {
+ getConnectedDefinitions(
+ s, recDefs, ordinaryDefs, syms, defMap, processedDefs, visited);
+ }
+ }
+}
+
+bool PrintBenchmark::decomposeDefinition(Node a,
+ bool& isRecDef,
+ Node& sym,
+ Node& body)
+{
+ if (a.getKind() == EQUAL && a[0].isVar())
+ {
+ // an ordinary define-fun
+ isRecDef = false;
+ sym = a[0];
+ body = a[1];
+ return true;
+ }
+ else if (a.getKind() == FORALL && a[1].getKind() == EQUAL
+ && a[1][0].getKind() == APPLY_UF)
+ {
+ isRecDef = true;
+ sym = a[1][0].getOperator();
+ body = NodeManager::currentNM()->mkNode(LAMBDA, a[0], a[1][1]);
+ return true;
+ }
+ else
+ {
+ Warning() << "Unhandled definition: " << a << std::endl;
+ }
+ return false;
+}
+
+void PrintBenchmark::printBenchmark(std::ostream& out,
+ const std::string& logic,
+ const std::vector<Node>& defs,
+ const std::vector<Node>& assertions)
+{
+ d_printer->toStreamCmdSetBenchmarkLogic(out, logic);
+ printAssertions(out, defs, assertions);
+ d_printer->toStreamCmdCheckSat(out);
+}
+
+} // namespace smt
+} // namespace cvc5
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback