From da3eff9ba6c632e290c9af990dc5750f65d78820 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 2 Apr 2021 11:55:16 -0500 Subject: Cleaning up friend relationships for commands (#6254) --- src/api/cvc4cpp.h | 63 +++------------------------- src/smt/command.cpp | 116 +++++++++++++++++++++++++++++++--------------------- src/smt/command.h | 19 +++++++++ 3 files changed, 93 insertions(+), 105 deletions(-) (limited to 'src') diff --git a/src/api/cvc4cpp.h b/src/api/cvc4cpp.h index c446fcaf5..8f4977b28 100644 --- a/src/api/cvc4cpp.h +++ b/src/api/cvc4cpp.h @@ -36,38 +36,14 @@ template class NodeTemplate; typedef NodeTemplate Node; -class AssertCommand; -class BlockModelValuesCommand; -class CheckSatCommand; -class CheckSatAssumingCommand; -class DatatypeDeclarationCommand; -class DeclareFunctionCommand; -class DeclareHeapCommand; -class DeclareSortCommand; -class DeclareSygusVarCommand; -class DefineFunctionCommand; -class DefineFunctionRecCommand; -class DefineSortCommand; +class Command; class DType; class DTypeConstructor; class DTypeSelector; -class GetAbductCommand; -class GetInterpolCommand; -class GetModelCommand; -class GetQuantifierEliminationCommand; -class GetUnsatCoreCommand; -class GetValueCommand; class NodeManager; -class ResetCommand; -class SetUserAttributeCommand; -class SimplifyCommand; class SmtEngine; -class SygusConstraintCommand; -class SygusInvConstraintCommand; -class SynthFunCommand; class TypeNode; class Options; -class QueryCommand; class Random; class Result; @@ -240,16 +216,7 @@ class Datatype; */ class CVC4_EXPORT Sort { - friend class cvc5::DatatypeDeclarationCommand; - friend class cvc5::DeclareFunctionCommand; - friend class cvc5::DeclareHeapCommand; - friend class cvc5::DeclareSortCommand; - friend class cvc5::DeclareSygusVarCommand; - friend class cvc5::DefineSortCommand; - friend class cvc5::GetAbductCommand; - friend class cvc5::GetInterpolCommand; - friend class cvc5::GetModelCommand; - friend class cvc5::SynthFunCommand; + friend class cvc5::Command; friend class DatatypeConstructor; friend class DatatypeConstructorDecl; friend class DatatypeSelector; @@ -890,25 +857,7 @@ class CVC4_EXPORT Op */ class CVC4_EXPORT Term { - friend class cvc5::AssertCommand; - friend class cvc5::BlockModelValuesCommand; - friend class cvc5::CheckSatCommand; - friend class cvc5::CheckSatAssumingCommand; - friend class cvc5::DeclareSygusVarCommand; - friend class cvc5::DefineFunctionCommand; - friend class cvc5::DefineFunctionRecCommand; - friend class cvc5::GetAbductCommand; - friend class cvc5::GetInterpolCommand; - friend class cvc5::GetModelCommand; - friend class cvc5::GetQuantifierEliminationCommand; - friend class cvc5::GetUnsatCoreCommand; - friend class cvc5::GetValueCommand; - friend class cvc5::SetUserAttributeCommand; - friend class cvc5::SimplifyCommand; - friend class cvc5::SygusConstraintCommand; - friend class cvc5::SygusInvConstraintCommand; - friend class cvc5::SynthFunCommand; - friend class cvc5::QueryCommand; + friend class cvc5::Command; friend class Datatype; friend class DatatypeConstructor; friend class DatatypeSelector; @@ -2136,9 +2085,7 @@ std::ostream& operator<<(std::ostream& out, */ class CVC4_EXPORT Grammar { - friend class cvc5::GetAbductCommand; - friend class cvc5::GetInterpolCommand; - friend class cvc5::SynthFunCommand; + friend class cvc5::Command; friend class Solver; public: @@ -2323,7 +2270,7 @@ class CVC4_EXPORT Solver friend class DatatypeSelector; friend class Grammar; friend class Op; - friend class cvc5::ResetCommand; + friend class cvc5::Command; friend class Sort; friend class Term; diff --git a/src/smt/command.cpp b/src/smt/command.cpp index e4b179cf4..4a6efe713 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -237,6 +237,36 @@ void Command::printResult(std::ostream& out, uint32_t verbosity) const } } +Node Command::termToNode(const api::Term& term) { return term.getNode(); } + +std::vector Command::termVectorToNodes( + const std::vector& terms) +{ + return api::Term::termVectorToNodes(terms); +} + +TypeNode Command::sortToTypeNode(const api::Sort& sort) +{ + return sort.getTypeNode(); +} + +std::vector Command::sortVectorToTypeNodes( + const std::vector& sorts) +{ + return api::Sort::sortVectorToTypeNodes(sorts); +} + +TypeNode Command::grammarToTypeNode(api::Grammar* grammar) +{ + return grammar == nullptr ? TypeNode::null() + : sortToTypeNode(grammar->resolve()); +} + +Options& Command::getOriginalOptionsFrom(api::Solver* s) +{ + return *s->d_originalOptions.get(); +} + /* -------------------------------------------------------------------------- */ /* class EmptyCommand */ /* -------------------------------------------------------------------------- */ @@ -310,7 +340,7 @@ void AssertCommand::invoke(api::Solver* solver, SymbolManager* sm) { try { - solver->getSmtEngine()->assertFormula(d_term.getNode(), d_inUnsatCore); + solver->getSmtEngine()->assertFormula(termToNode(d_term), d_inUnsatCore); d_commandStatus = CommandSuccess::instance(); } catch (UnsafeInterruptException& e) @@ -335,7 +365,7 @@ void AssertCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdAssert(out, d_term.getNode()); + Printer::getPrinter(language)->toStreamCmdAssert(out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -456,7 +486,7 @@ void CheckSatCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdCheckSat(out, d_term.getNode()); + Printer::getPrinter(language)->toStreamCmdCheckSat(out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -531,7 +561,7 @@ void CheckSatAssumingCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdCheckSatAssuming( - out, api::Term::termVectorToNodes(d_terms)); + out, termVectorToNodes(d_terms)); } /* -------------------------------------------------------------------------- */ @@ -584,7 +614,7 @@ void QueryCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdQuery(out, d_term.getNode()); + Printer::getPrinter(language)->toStreamCmdQuery(out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -622,7 +652,7 @@ void DeclareSygusVarCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareVar( - out, d_var.getNode(), d_sort.getTypeNode()); + out, termToNode(d_var), sortToTypeNode(d_sort)); } /* -------------------------------------------------------------------------- */ @@ -677,14 +707,13 @@ void SynthFunCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - std::vector nodeVars = api::Term::termVectorToNodes(d_vars); + std::vector nodeVars = termVectorToNodes(d_vars); Printer::getPrinter(language)->toStreamCmdSynthFun( out, - d_fun.getNode(), + termToNode(d_fun), nodeVars, d_isInv, - d_grammar == nullptr ? TypeNode::null() - : d_grammar->resolve().getTypeNode()); + d_grammar == nullptr ? TypeNode::null() : grammarToTypeNode(d_grammar)); } /* -------------------------------------------------------------------------- */ @@ -725,7 +754,7 @@ void SygusConstraintCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdConstraint(out, d_term.getNode()); + Printer::getPrinter(language)->toStreamCmdConstraint(out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -782,10 +811,10 @@ void SygusInvConstraintCommand::toStream(std::ostream& out, { Printer::getPrinter(language)->toStreamCmdInvConstraint( out, - d_predicates[0].getNode(), - d_predicates[1].getNode(), - d_predicates[2].getNode(), - d_predicates[3].getNode()); + termToNode(d_predicates[0]), + termToNode(d_predicates[1]), + termToNode(d_predicates[2]), + termToNode(d_predicates[3])); } /* -------------------------------------------------------------------------- */ @@ -866,7 +895,7 @@ void ResetCommand::invoke(api::Solver* solver, SymbolManager* sm) { sm->reset(); Options opts; - opts.copyValues(*solver->d_originalOptions); + opts.copyValues(getOriginalOptionsFrom(solver)); // This reconstructs a new solver object at the same memory location as the // current one. Note that this command does not own the solver object! // It may be safer to instead make the ResetCommand a special case in the @@ -1136,7 +1165,7 @@ void DeclareFunctionCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareFunction( - out, d_func.toString(), d_sort.getTypeNode()); + out, d_func.toString(), sortToTypeNode(d_sort)); } /* -------------------------------------------------------------------------- */ @@ -1175,7 +1204,7 @@ void DeclareSortCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareType(out, - d_sort.getTypeNode()); + sortToTypeNode(d_sort)); } /* -------------------------------------------------------------------------- */ @@ -1218,10 +1247,7 @@ void DefineSortCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDefineType( - out, - d_symbol, - api::Sort::sortVectorToTypeNodes(d_params), - d_sort.getTypeNode()); + out, d_symbol, sortVectorToTypeNodes(d_params), sortToTypeNode(d_sort)); } /* -------------------------------------------------------------------------- */ @@ -1296,9 +1322,9 @@ void DefineFunctionCommand::toStream(std::ostream& out, Printer::getPrinter(language)->toStreamCmdDefineFunction( out, d_func.toString(), - api::Term::termVectorToNodes(d_formals), - d_func.getNode().getType().getRangeType(), - d_formula.getNode()); + termVectorToNodes(d_formals), + termToNode(d_func).getType().getRangeType(), + termToNode(d_formula)); } /* -------------------------------------------------------------------------- */ @@ -1376,14 +1402,11 @@ void DefineFunctionRecCommand::toStream(std::ostream& out, formals.reserve(d_formals.size()); for (const std::vector& formal : d_formals) { - formals.push_back(api::Term::termVectorToNodes(formal)); + formals.push_back(termVectorToNodes(formal)); } Printer::getPrinter(language)->toStreamCmdDefineFunctionRec( - out, - api::Term::termVectorToNodes(d_funcs), - formals, - api::Term::termVectorToNodes(d_formulas)); + out, termVectorToNodes(d_funcs), formals, termVectorToNodes(d_formulas)); } /* -------------------------------------------------------------------------- */ /* class DeclareHeapCommand */ @@ -1417,7 +1440,7 @@ void DeclareHeapCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareHeap( - out, d_locSort.getTypeNode(), d_dataSort.getTypeNode()); + out, sortToTypeNode(d_locSort), sortToTypeNode(d_dataSort)); } /* -------------------------------------------------------------------------- */ @@ -1460,11 +1483,10 @@ void SetUserAttributeCommand::invoke(api::Solver* solver, SymbolManager* sm) { if (!d_term.isNull()) { - solver->getSmtEngine()->setUserAttribute( - d_attr, - d_term.getNode(), - api::Term::termVectorToNodes(d_termValues), - d_strValue); + solver->getSmtEngine()->setUserAttribute(d_attr, + termToNode(d_term), + termVectorToNodes(d_termValues), + d_strValue); } d_commandStatus = CommandSuccess::instance(); } @@ -1490,7 +1512,7 @@ void SetUserAttributeCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdSetUserAttribute( - out, d_attr, d_term.getNode()); + out, d_attr, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -1543,7 +1565,7 @@ void SimplifyCommand::toStream(std::ostream& out, size_t dag, OutputLanguage language) const { - Printer::getPrinter(language)->toStreamCmdSimplify(out, d_term.getNode()); + Printer::getPrinter(language)->toStreamCmdSimplify(out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -1624,7 +1646,7 @@ void GetValueCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetValue( - out, api::Term::termVectorToNodes(d_terms)); + out, termVectorToNodes(d_terms)); } /* -------------------------------------------------------------------------- */ @@ -1722,12 +1744,12 @@ void GetModelCommand::invoke(api::Solver* solver, SymbolManager* sm) std::vector declareSorts = sm->getModelDeclareSorts(); for (const api::Sort& s : declareSorts) { - d_result->addDeclarationSort(s.getTypeNode()); + d_result->addDeclarationSort(sortToTypeNode(s)); } std::vector declareTerms = sm->getModelDeclareTerms(); for (const api::Term& t : declareTerms) { - d_result->addDeclarationTerm(t.getNode()); + d_result->addDeclarationTerm(termToNode(t)); } d_commandStatus = CommandSuccess::instance(); } @@ -1877,7 +1899,7 @@ void BlockModelValuesCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdBlockModelValues( - out, api::Term::termVectorToNodes(d_terms)); + out, termVectorToNodes(d_terms)); } /* -------------------------------------------------------------------------- */ @@ -2120,7 +2142,7 @@ void GetInterpolCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetInterpol( - out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode()); + out, d_name, termToNode(d_conj), grammarToTypeNode(d_sygus_grammar)); } /* -------------------------------------------------------------------------- */ @@ -2205,7 +2227,7 @@ void GetAbductCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetAbduct( - out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode()); + out, d_name, termToNode(d_conj), grammarToTypeNode(d_sygus_grammar)); } /* -------------------------------------------------------------------------- */ @@ -2281,7 +2303,7 @@ void GetQuantifierEliminationCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetQuantifierElimination( - out, d_term.getNode()); + out, termToNode(d_term)); } /* -------------------------------------------------------------------------- */ @@ -2381,7 +2403,7 @@ void GetUnsatCoreCommand::printResult(std::ostream& out, if (options::dumpUnsatCoresFull()) { // use the assertions - UnsatCore ucr(api::Term::termVectorToNodes(d_result)); + UnsatCore ucr(termVectorToNodes(d_result)); ucr.toStream(out); } else @@ -2816,7 +2838,7 @@ void DatatypeDeclarationCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDatatypeDeclaration( - out, api::Sort::sortVectorToTypeNodes(d_datatypes)); + out, sortVectorToTypeNodes(d_datatypes)); } } // namespace cvc5 diff --git a/src/smt/command.h b/src/smt/command.h index 2d13a2246..6c3b4f0e4 100644 --- a/src/smt/command.h +++ b/src/smt/command.h @@ -277,6 +277,25 @@ class CVC4_EXPORT Command * successful execution. */ bool d_muted; + + protected: + // These methods rely on Command being a friend of classes in the API. + // Subclasses of command should use these methods for conversions, + // which is currently necessary for e.g. printing commands. + /** Helper to convert a Term to an internal Node */ + static Node termToNode(const api::Term& term); + /** Helper to convert a vector of Terms to internal Nodes. */ + static std::vector termVectorToNodes( + const std::vector& terms); + /** Helper to convert a Sort to an internal TypeNode */ + static TypeNode sortToTypeNode(const api::Sort& sort); + /** Helper to convert a vector of Sorts to internal TypeNodes. */ + static std::vector sortVectorToTypeNodes( + const std::vector& sorts); + /** Helper to convert a Grammar to an internal TypeNode */ + static TypeNode grammarToTypeNode(api::Grammar* grammar); + /** Get original options from the solver (for ResetCommand) */ + Options& getOriginalOptionsFrom(api::Solver* s); }; /* class Command */ /** -- cgit v1.2.3