summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/expr/datatype.cpp4
-rw-r--r--src/expr/datatype.h2
-rw-r--r--src/printer/smt2/smt2_printer.cpp2
-rw-r--r--src/theory/quantifiers/sygus_grammar_norm.cpp24
-rw-r--r--src/theory/quantifiers/sygus_grammar_norm.h6
5 files changed, 27 insertions, 11 deletions
diff --git a/src/expr/datatype.cpp b/src/expr/datatype.cpp
index 513cb2170..8b6384dcc 100644
--- a/src/expr/datatype.cpp
+++ b/src/expr/datatype.cpp
@@ -879,11 +879,11 @@ bool DatatypeConstructor::isSygusIdFunc() const {
&& d_sygus_op[0][0] == d_sygus_op[1]);
}
-SygusPrintCallback* DatatypeConstructor::getSygusPrintCallback() const
+std::shared_ptr<SygusPrintCallback> DatatypeConstructor::getSygusPrintCallback() const
{
PrettyCheckArgument(
isResolved(), this, "this datatype constructor is not yet resolved");
- return d_sygus_pc.get();
+ return d_sygus_pc;
}
Cardinality DatatypeConstructor::getCardinality( Type t ) const throw(IllegalArgumentException) {
diff --git a/src/expr/datatype.h b/src/expr/datatype.h
index 85ecfb946..b899b0099 100644
--- a/src/expr/datatype.h
+++ b/src/expr/datatype.h
@@ -300,7 +300,7 @@ class CVC4_PUBLIC DatatypeConstructor {
* to handle defined or let expressions that
* appear in user-provided grammars.
*/
- SygusPrintCallback* getSygusPrintCallback() const;
+ std::shared_ptr<SygusPrintCallback> getSygusPrintCallback() const;
/**
* Get the tester name for this Datatype constructor.
diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp
index 82871a1d5..c029c0824 100644
--- a/src/printer/smt2/smt2_printer.cpp
+++ b/src/printer/smt2/smt2_printer.cpp
@@ -1324,7 +1324,7 @@ void Smt2Printer::toStreamSygus(std::ostream& out, TNode n) const throw()
{
int cIndex = Datatype::indexOf(n.getOperator().toExpr());
Assert(!dt[cIndex].getSygusOp().isNull());
- SygusPrintCallback* spc = dt[cIndex].getSygusPrintCallback();
+ SygusPrintCallback* spc = dt[cIndex].getSygusPrintCallback().get();
if (spc != nullptr && options::sygusPrintCallbacks())
{
spc->toStreamSygus(this, out, n.toExpr());
diff --git a/src/theory/quantifiers/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus_grammar_norm.cpp
index ea7722df0..9037566b7 100644
--- a/src/theory/quantifiers/sygus_grammar_norm.cpp
+++ b/src/theory/quantifiers/sygus_grammar_norm.cpp
@@ -17,6 +17,8 @@
#include "expr/datatype.h"
#include "options/quantifiers_options.h"
+#include "smt/smt_engine.h"
+#include "smt/smt_engine_scope.h"
#include "theory/quantifiers/ce_guided_conjecture.h"
#include "theory/quantifiers/term_database_sygus.h"
#include "theory/quantifiers/term_util.h"
@@ -40,6 +42,7 @@ TypeObject::TypeObject(TypeNode src_tn, std::string type_name)
: d_dt(Datatype(type_name))
{
d_tn = src_tn;
+ d_t = src_tn.toType();
/* Create an unresolved type */
d_unres_t = NodeManager::currentNM()
->mkSort(type_name, ExprManager::SORT_FLAG_PLACEHOLDER)
@@ -69,7 +72,7 @@ void SygusGrammarNorm::collectInfoFor(TypeNode tn,
std::string type_name = ss.str();
/* Add to global accumulators */
tos.push_back(TypeObject(tn, type_name));
- const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
+ const Datatype& dt = static_cast<DatatypeType>(tos.back().d_t).getDatatype();
tn_to_unres[tn] = tos.back().d_unres_t;
/* Visit types of constructor arguments */
for (const DatatypeConstructor& cons : dt)
@@ -92,7 +95,7 @@ void SygusGrammarNorm::normalizeSygusInt(unsigned ind,
Node sygus_vars)
{
const Datatype& dt =
- static_cast<DatatypeType>(tos[ind].d_tn.toType()).getDatatype();
+ static_cast<DatatypeType>(tos[ind].d_t).getDatatype();
Trace("sygus-grammar-normalize")
<< "Normalizing integer type " << tos[ind].d_tn << " from datatype\n"
<< dt << std::endl;
@@ -113,7 +116,7 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars)
for (unsigned i = 0, size = tos.size(); i < size; ++i)
{
const Datatype& dt =
- static_cast<DatatypeType>(tos[i].d_tn.toType()).getDatatype();
+ static_cast<DatatypeType>(tos[i].d_t).getDatatype();
Trace("sygus-grammar-normalize")
<< "Rebuild " << tos[i].d_tn << " from " << dt << std::endl;
/* Collect information to rebuild constructors */
@@ -123,8 +126,15 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars)
<< "...for " << cons.getName() << std::endl;
/* Recover the sygus operator to not lose reference to the original
* operator (NOT, ITE, etc) */
- tos[i].d_ops.push_back(cons.getSygusOp());
+ Node exp_sop_n = Node::fromExpr(
+ smt::currentSmtEngine()->expandDefinitions(cons.getSygusOp()));
+ tos[i].d_ops.push_back(Rewriter::rewrite(exp_sop_n));
+ Trace("sygus-grammar-normalize")
+ << "\tOriginal op: " << cons.getSygusOp()
+ << "\n\tExpanded one: " << exp_sop_n
+ << "\n\tRewritten one: " << tos[i].d_ops.back() << std::endl;
tos[i].d_cons_names.push_back(cons.getName());
+ tos[i].d_pcb.push_back(cons.getSygusPrintCallback());
tos[i].d_cons_args_t.push_back(std::vector<Type>());
for (const DatatypeConstructorArg& arg : cons)
{
@@ -142,8 +152,10 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars)
dt.getSygusAllowAll());
for (unsigned j = 0, size_d_ops = tos[i].d_ops.size(); j < size_d_ops; ++j)
{
- tos[i].d_dt.addSygusConstructor(
- tos[i].d_ops[j], tos[i].d_cons_names[j], tos[i].d_cons_args_t[j]);
+ tos[i].d_dt.addSygusConstructor(tos[i].d_ops[j].toExpr(),
+ tos[i].d_cons_names[j],
+ tos[i].d_cons_args_t[j],
+ tos[i].d_pcb[j]);
}
Trace("sygus-grammar-normalize")
<< "...built datatype " << tos[i].d_dt << std::endl;
diff --git a/src/theory/quantifiers/sygus_grammar_norm.h b/src/theory/quantifiers/sygus_grammar_norm.h
index bd63f5fdb..15b4502d3 100644
--- a/src/theory/quantifiers/sygus_grammar_norm.h
+++ b/src/theory/quantifiers/sygus_grammar_norm.h
@@ -47,10 +47,14 @@ struct TypeObject
/* The original typenode this TypeObject is built from */
TypeNode d_tn;
+ /* The type represented by d_tn */
+ Type d_t;
/* Operators for each constructor. */
- std::vector<Expr> d_ops;
+ std::vector<Node> d_ops;
/* Names for each constructor. */
std::vector<std::string> d_cons_names;
+ /* Print callbacks for each constructor */
+ std::vector<std::shared_ptr<SygusPrintCallback>> d_pcb;
/* List of argument types for each constructor */
std::vector<std::vector<Type>> d_cons_args_t;
/* Unresolved type placeholder */
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback