diff options
-rw-r--r-- | src/expr/datatype.cpp | 4 | ||||
-rw-r--r-- | src/expr/datatype.h | 2 | ||||
-rw-r--r-- | src/printer/smt2/smt2_printer.cpp | 2 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus_grammar_norm.cpp | 24 | ||||
-rw-r--r-- | src/theory/quantifiers/sygus_grammar_norm.h | 6 |
5 files changed, 27 insertions, 11 deletions
diff --git a/src/expr/datatype.cpp b/src/expr/datatype.cpp index 513cb2170..8b6384dcc 100644 --- a/src/expr/datatype.cpp +++ b/src/expr/datatype.cpp @@ -879,11 +879,11 @@ bool DatatypeConstructor::isSygusIdFunc() const { && d_sygus_op[0][0] == d_sygus_op[1]); } -SygusPrintCallback* DatatypeConstructor::getSygusPrintCallback() const +std::shared_ptr<SygusPrintCallback> DatatypeConstructor::getSygusPrintCallback() const { PrettyCheckArgument( isResolved(), this, "this datatype constructor is not yet resolved"); - return d_sygus_pc.get(); + return d_sygus_pc; } Cardinality DatatypeConstructor::getCardinality( Type t ) const throw(IllegalArgumentException) { diff --git a/src/expr/datatype.h b/src/expr/datatype.h index 85ecfb946..b899b0099 100644 --- a/src/expr/datatype.h +++ b/src/expr/datatype.h @@ -300,7 +300,7 @@ class CVC4_PUBLIC DatatypeConstructor { * to handle defined or let expressions that * appear in user-provided grammars. */ - SygusPrintCallback* getSygusPrintCallback() const; + std::shared_ptr<SygusPrintCallback> getSygusPrintCallback() const; /** * Get the tester name for this Datatype constructor. diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 82871a1d5..c029c0824 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1324,7 +1324,7 @@ void Smt2Printer::toStreamSygus(std::ostream& out, TNode n) const throw() { int cIndex = Datatype::indexOf(n.getOperator().toExpr()); Assert(!dt[cIndex].getSygusOp().isNull()); - SygusPrintCallback* spc = dt[cIndex].getSygusPrintCallback(); + SygusPrintCallback* spc = dt[cIndex].getSygusPrintCallback().get(); if (spc != nullptr && options::sygusPrintCallbacks()) { spc->toStreamSygus(this, out, n.toExpr()); diff --git a/src/theory/quantifiers/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus_grammar_norm.cpp index ea7722df0..9037566b7 100644 --- a/src/theory/quantifiers/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus_grammar_norm.cpp @@ -17,6 +17,8 @@ #include "expr/datatype.h" #include "options/quantifiers_options.h" +#include "smt/smt_engine.h" +#include "smt/smt_engine_scope.h" #include "theory/quantifiers/ce_guided_conjecture.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_util.h" @@ -40,6 +42,7 @@ TypeObject::TypeObject(TypeNode src_tn, std::string type_name) : d_dt(Datatype(type_name)) { d_tn = src_tn; + d_t = src_tn.toType(); /* Create an unresolved type */ d_unres_t = NodeManager::currentNM() ->mkSort(type_name, ExprManager::SORT_FLAG_PLACEHOLDER) @@ -69,7 +72,7 @@ void SygusGrammarNorm::collectInfoFor(TypeNode tn, std::string type_name = ss.str(); /* Add to global accumulators */ tos.push_back(TypeObject(tn, type_name)); - const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype(); + const Datatype& dt = static_cast<DatatypeType>(tos.back().d_t).getDatatype(); tn_to_unres[tn] = tos.back().d_unres_t; /* Visit types of constructor arguments */ for (const DatatypeConstructor& cons : dt) @@ -92,7 +95,7 @@ void SygusGrammarNorm::normalizeSygusInt(unsigned ind, Node sygus_vars) { const Datatype& dt = - static_cast<DatatypeType>(tos[ind].d_tn.toType()).getDatatype(); + static_cast<DatatypeType>(tos[ind].d_t).getDatatype(); Trace("sygus-grammar-normalize") << "Normalizing integer type " << tos[ind].d_tn << " from datatype\n" << dt << std::endl; @@ -113,7 +116,7 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars) for (unsigned i = 0, size = tos.size(); i < size; ++i) { const Datatype& dt = - static_cast<DatatypeType>(tos[i].d_tn.toType()).getDatatype(); + static_cast<DatatypeType>(tos[i].d_t).getDatatype(); Trace("sygus-grammar-normalize") << "Rebuild " << tos[i].d_tn << " from " << dt << std::endl; /* Collect information to rebuild constructors */ @@ -123,8 +126,15 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars) << "...for " << cons.getName() << std::endl; /* Recover the sygus operator to not lose reference to the original * operator (NOT, ITE, etc) */ - tos[i].d_ops.push_back(cons.getSygusOp()); + Node exp_sop_n = Node::fromExpr( + smt::currentSmtEngine()->expandDefinitions(cons.getSygusOp())); + tos[i].d_ops.push_back(Rewriter::rewrite(exp_sop_n)); + Trace("sygus-grammar-normalize") + << "\tOriginal op: " << cons.getSygusOp() + << "\n\tExpanded one: " << exp_sop_n + << "\n\tRewritten one: " << tos[i].d_ops.back() << std::endl; tos[i].d_cons_names.push_back(cons.getName()); + tos[i].d_pcb.push_back(cons.getSygusPrintCallback()); tos[i].d_cons_args_t.push_back(std::vector<Type>()); for (const DatatypeConstructorArg& arg : cons) { @@ -142,8 +152,10 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars) dt.getSygusAllowAll()); for (unsigned j = 0, size_d_ops = tos[i].d_ops.size(); j < size_d_ops; ++j) { - tos[i].d_dt.addSygusConstructor( - tos[i].d_ops[j], tos[i].d_cons_names[j], tos[i].d_cons_args_t[j]); + tos[i].d_dt.addSygusConstructor(tos[i].d_ops[j].toExpr(), + tos[i].d_cons_names[j], + tos[i].d_cons_args_t[j], + tos[i].d_pcb[j]); } Trace("sygus-grammar-normalize") << "...built datatype " << tos[i].d_dt << std::endl; diff --git a/src/theory/quantifiers/sygus_grammar_norm.h b/src/theory/quantifiers/sygus_grammar_norm.h index bd63f5fdb..15b4502d3 100644 --- a/src/theory/quantifiers/sygus_grammar_norm.h +++ b/src/theory/quantifiers/sygus_grammar_norm.h @@ -47,10 +47,14 @@ struct TypeObject /* The original typenode this TypeObject is built from */ TypeNode d_tn; + /* The type represented by d_tn */ + Type d_t; /* Operators for each constructor. */ - std::vector<Expr> d_ops; + std::vector<Node> d_ops; /* Names for each constructor. */ std::vector<std::string> d_cons_names; + /* Print callbacks for each constructor */ + std::vector<std::shared_ptr<SygusPrintCallback>> d_pcb; /* List of argument types for each constructor */ std::vector<std::vector<Type>> d_cons_args_t; /* Unresolved type placeholder */ |