summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>2017-09-28 23:58:03 -0500
committerAndres Noetzli <andres.noetzli@gmail.com>2017-09-28 21:58:03 -0700
commit821a9d90914fca4a13bc29f8ff15fb4220cbd1d4 (patch)
tree740dfcd00ef961c73029268bd2abd438a186f5f2
parentfabc9849e7d9ab31b3622487f74235a065852caf (diff)
Update symbol table to support operator overloading (#1154)
-rw-r--r--src/expr/symbol_table.cpp289
-rw-r--r--src/expr/symbol_table.h77
2 files changed, 338 insertions, 28 deletions
diff --git a/src/expr/symbol_table.cpp b/src/expr/symbol_table.cpp
index c760b3a80..b411d8dfb 100644
--- a/src/expr/symbol_table.cpp
+++ b/src/expr/symbol_table.cpp
@@ -21,6 +21,7 @@
#include <ostream>
#include <string>
#include <utility>
+#include <unordered_map>
#include "context/cdhashmap.h"
#include "context/cdhashset.h"
@@ -41,23 +42,204 @@ using ::std::pair;
using ::std::string;
using ::std::vector;
+// This data structure stores a trie of expressions with
+// the same name, and must be distinguished by their argument types.
+// It is context-dependent.
+class OverloadedTypeTrie
+{
+public:
+ OverloadedTypeTrie(Context * c ) :
+ d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)) {
+ }
+ ~OverloadedTypeTrie() {
+ d_overloaded_symbols->deleteSelf();
+ }
+ /** is this function overloaded? */
+ bool isOverloadedFunction(Expr fun) const;
+
+ /** Get overloaded constant for type.
+ * If possible, it returns a defined symbol with name
+ * that has type t. Otherwise returns null expression.
+ */
+ Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+
+ /**
+ * If possible, returns a defined function for a name
+ * and a vector of expected argument types. Otherwise returns
+ * null expression.
+ */
+ Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+ /** called when obj is bound to name, and prev_bound_obj was already bound to name
+ * Returns false if the binding is invalid.
+ */
+ bool bind(const string& name, Expr prev_bound_obj, Expr obj);
+private:
+ /** Marks expression obj with name as overloaded.
+ * Adds relevant information to the type arg trie data structure.
+ * It returns false if there is already an expression bound to that name
+ * whose type expects the same arguments as the type of obj but is not identical
+ * to the type of obj. For example, if we declare :
+ *
+ * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil))))
+ * (declare-fun cons (Int List) List)
+ *
+ * cons : constructor_type( Int, List, List )
+ * cons : function_type( Int, List, List )
+ *
+ * These are put in the same place in the trie but do not have identical type,
+ * hence we return false.
+ */
+ bool markOverloaded(const string& name, Expr obj);
+ /** the null expression */
+ Expr d_nullExpr;
+ // The (context-independent) trie storing that maps expected argument
+ // vectors to symbols. All expressions stored in d_symbols are only
+ // interpreted as active if they also appear in the context-dependent
+ // set d_overloaded_symbols.
+ class TypeArgTrie {
+ public:
+ // children of this node
+ std::map< Type, TypeArgTrie > d_children;
+ // symbols at this node
+ std::map< Type, Expr > d_symbols;
+ };
+ /** for each string with operator overloading, this stores the data structure above. */
+ std::unordered_map< std::string, TypeArgTrie > d_overload_type_arg_trie;
+ /** The set of overloaded symbols. */
+ CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols;
+};
+
+bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const {
+ return d_overloaded_symbols->find(fun)!=d_overloaded_symbols->end();
+}
+
+Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name, Type t) const {
+ std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
+ if(it!=d_overload_type_arg_trie.end()) {
+ std::map< Type, Expr >::const_iterator its = it->second.d_symbols.find(t);
+ if(its!=it->second.d_symbols.end()) {
+ Expr expr = its->second;
+ // must be an active symbol
+ if(isOverloadedFunction(expr)) {
+ return expr;
+ }
+ }
+ }
+ return d_nullExpr;
+}
+
+Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(const std::string& name,
+ const std::vector< Type >& argTypes) const {
+ std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name);
+ if(it!=d_overload_type_arg_trie.end()) {
+ const TypeArgTrie * tat = &it->second;
+ for(unsigned i=0; i<argTypes.size(); i++) {
+ std::map< Type, TypeArgTrie >::const_iterator itc = tat->d_children.find(argTypes[i]);
+ if(itc!=tat->d_children.end()) {
+ tat = &itc->second;
+ }else{
+ // no functions match
+ return d_nullExpr;
+ }
+ }
+ // now, we must ensure that there is *only* one active symbol at this node
+ Expr retExpr;
+ for(std::map< Type, Expr >::const_iterator its = tat->d_symbols.begin(); its != tat->d_symbols.end(); ++its) {
+ Expr expr = its->second;
+ if(isOverloadedFunction(expr)) {
+ if(retExpr.isNull()) {
+ retExpr = expr;
+ }else{
+ // multiple functions match
+ return d_nullExpr;
+ }
+ }
+ }
+ return retExpr;
+ }
+ return d_nullExpr;
+}
+
+bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj, Expr obj) {
+ bool retprev = true;
+ if(!isOverloadedFunction(prev_bound_obj)) {
+ // mark previous as overloaded
+ retprev = markOverloaded(name, prev_bound_obj);
+ }
+ // mark this as overloaded
+ bool retobj = markOverloaded(name, obj);
+ return retprev && retobj;
+}
+
+bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) {
+ Trace("parser-overloading") << "Overloaded function : " << name;
+ Trace("parser-overloading") << " with type " << obj.getType() << std::endl;
+ // get the argument types
+ Type t = obj.getType();
+ Type rangeType = t;
+ std::vector< Type > argTypes;
+ if(t.isFunction()) {
+ argTypes = static_cast<FunctionType>(t).getArgTypes();
+ rangeType = static_cast<FunctionType>(t).getRangeType();
+ }else if(t.isConstructor()) {
+ argTypes = static_cast<ConstructorType>(t).getArgTypes();
+ rangeType = static_cast<ConstructorType>(t).getRangeType();
+ }else if(t.isTester()) {
+ argTypes.push_back( static_cast<TesterType>(t).getDomain() );
+ rangeType = static_cast<TesterType>(t).getRangeType();
+ }else if(t.isSelector()) {
+ argTypes.push_back( static_cast<SelectorType>(t).getDomain() );
+ rangeType = static_cast<SelectorType>(t).getRangeType();
+ }
+ // add to the trie
+ TypeArgTrie * tat = &d_overload_type_arg_trie[name];
+ for(unsigned i=0; i<argTypes.size(); i++) {
+ tat = &(tat->d_children[argTypes[i]]);
+ }
+
+ // types can be identical but vary on the kind of the type, thus we must distinguish based on this
+ std::map< Type, Expr >::iterator it = tat->d_symbols.find( rangeType );
+ if( it!=tat->d_symbols.end() ){
+ Expr prev_obj = it->second;
+ // if there is already an active function with the same name and expects the same argument types
+ if( isOverloadedFunction(prev_obj) ){
+ if( prev_obj.getType()==obj.getType() ){
+ //types are identical, simply ignore it
+ return true;
+ }else{
+ //otherwise there is no way to distinguish these types, we return an error
+ return false;
+ }
+ }
+ }
+
+ // otherwise, update the symbols
+ d_overloaded_symbols->insert(obj);
+ tat->d_symbols[rangeType] = obj;
+ return true;
+}
+
+
class SymbolTable::Implementation {
public:
Implementation()
: d_context(),
d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)),
d_typeMap(new (true) TypeMap(&d_context)),
- d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)) {}
+ d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)){
+ d_overload_trie = new OverloadedTypeTrie(&d_context);
+ }
~Implementation() {
d_exprMap->deleteSelf();
d_typeMap->deleteSelf();
d_functions->deleteSelf();
+ delete d_overload_trie;
}
- void bind(const string& name, Expr obj, bool levelZero) throw();
- void bindDefinedFunction(const string& name, Expr obj,
- bool levelZero) throw();
+ bool bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw();
+ bool bindDefinedFunction(const string& name, Expr obj,
+ bool levelZero, bool doOverload) throw();
void bindType(const string& name, Type t, bool levelZero = false) throw();
void bindType(const string& name, const vector<Type>& params, Type t,
bool levelZero = false) throw();
@@ -73,7 +255,16 @@ class SymbolTable::Implementation {
void pushScope() throw();
size_t getLevel() const throw();
void reset();
-
+ //------------------------ operator overloading
+ /** implementation of function from header */
+ bool isOverloadedFunction(Expr fun) const;
+
+ /** implementation of function from header */
+ Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+
+ /** implementation of function from header */
+ Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+ //------------------------ end operator overloading
private:
/** The context manager for the scope maps. */
Context d_context;
@@ -87,24 +278,49 @@ class SymbolTable::Implementation {
/** A set of defined functions. */
CDHashSet<Expr, ExprHashFunction>* d_functions;
+
+ //------------------------ operator overloading
+ // the null expression
+ Expr d_nullExpr;
+ // overloaded type trie, stores all information regarding overloading
+ OverloadedTypeTrie * d_overload_trie;
+ /** bind with overloading
+ * This is called whenever obj is bound to name where overloading symbols is allowed.
+ * If a symbol is previously bound to that name, it marks both as overloaded.
+ * Returns false if the binding was invalid.
+ */
+ bool bindWithOverloading(const string& name, Expr obj);
+ //------------------------ end operator overloading
}; /* SymbolTable::Implementation */
-void SymbolTable::Implementation::bind(const string& name, Expr obj,
- bool levelZero) throw() {
+bool SymbolTable::Implementation::bind(const string& name, Expr obj,
+ bool levelZero, bool doOverload) throw() {
PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
ExprManagerScope ems(obj);
+ if (doOverload) {
+ if( !bindWithOverloading(name, obj) ){
+ return false;
+ }
+ }
if (levelZero) {
d_exprMap->insertAtContextLevelZero(name, obj);
} else {
d_exprMap->insert(name, obj);
}
+ return true;
}
-void SymbolTable::Implementation::bindDefinedFunction(const string& name,
+bool SymbolTable::Implementation::bindDefinedFunction(const string& name,
Expr obj,
- bool levelZero) throw() {
+ bool levelZero,
+ bool doOverload) throw() {
PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr");
ExprManagerScope ems(obj);
+ if (doOverload) {
+ if( !bindWithOverloading(name, obj) ){
+ return false;
+ }
+ }
if (levelZero) {
d_exprMap->insertAtContextLevelZero(name, obj);
d_functions->insertAtContextLevelZero(obj);
@@ -112,6 +328,7 @@ void SymbolTable::Implementation::bindDefinedFunction(const string& name,
d_exprMap->insert(name, obj);
d_functions->insert(obj);
}
+ return true;
}
bool SymbolTable::Implementation::isBound(const string& name) const throw() {
@@ -130,7 +347,12 @@ bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const
}
Expr SymbolTable::Implementation::lookup(const string& name) const throw() {
- return (*d_exprMap->find(name)).second;
+ Expr expr = (*d_exprMap->find(name)).second;
+ if(isOverloadedFunction(expr)) {
+ return d_nullExpr;
+ }else{
+ return expr;
+ }
}
void SymbolTable::Implementation::bindType(const string& name, Type t,
@@ -255,18 +477,55 @@ void SymbolTable::Implementation::reset() {
new (this) SymbolTable::Implementation();
}
+bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const {
+ return d_overload_trie->isOverloadedFunction(fun);
+}
+
+Expr SymbolTable::Implementation::getOverloadedConstantForType(const std::string& name, Type t) const {
+ return d_overload_trie->getOverloadedConstantForType(name, t);
+}
+
+Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(const std::string& name,
+ const std::vector< Type >& argTypes) const {
+ return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes);
+}
+
+bool SymbolTable::Implementation::bindWithOverloading(const string& name, Expr obj){
+ CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name);
+ if(it != d_exprMap->end()) {
+ const Expr& prev_bound_obj = (*it).second;
+ if(prev_bound_obj!=obj) {
+ return d_overload_trie->bind(name, prev_bound_obj, obj);
+ }
+ }
+ return true;
+}
+
+bool SymbolTable::isOverloadedFunction(Expr fun) const {
+ return d_implementation->isOverloadedFunction(fun);
+}
+
+Expr SymbolTable::getOverloadedConstantForType(const std::string& name, Type t) const {
+ return d_implementation->getOverloadedConstantForType(name, t);
+}
+
+Expr SymbolTable::getOverloadedFunctionForTypes(const std::string& name,
+ const std::vector< Type >& argTypes) const {
+ return d_implementation->getOverloadedFunctionForTypes(name, argTypes);
+}
+
SymbolTable::SymbolTable()
: d_implementation(new SymbolTable::Implementation()) {}
SymbolTable::~SymbolTable() {}
-void SymbolTable::bind(const string& name, Expr obj, bool levelZero) throw() {
- d_implementation->bind(name, obj, levelZero);
+bool SymbolTable::bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw() {
+ return d_implementation->bind(name, obj, levelZero, doOverload);
}
-void SymbolTable::bindDefinedFunction(const string& name, Expr obj,
- bool levelZero) throw() {
- d_implementation->bindDefinedFunction(name, obj, levelZero);
+bool SymbolTable::bindDefinedFunction(const string& name, Expr obj,
+ bool levelZero, bool doOverload) throw() {
+ return d_implementation->bindDefinedFunction(name, obj, levelZero, doOverload);
}
void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() {
diff --git a/src/expr/symbol_table.h b/src/expr/symbol_table.h
index e64488563..b6ca7a76f 100644
--- a/src/expr/symbol_table.h
+++ b/src/expr/symbol_table.h
@@ -43,33 +43,58 @@ class CVC4_PUBLIC SymbolTable {
~SymbolTable();
/**
- * Bind an expression to a name in the current scope level. If
- * <code>name</code> is already bound to an expression in the current
+ * Bind an expression to a name in the current scope level.
+ *
+ * When doOverload is false:
+ * if <code>name</code> is already bound to an expression in the current
* level, then the binding is replaced. If <code>name</code> is bound
* in a previous level, then the binding is "covered" by this one
- * until the current scope is popped. If levelZero is true the name
- * shouldn't be already bound.
+ * until the current scope is popped.
+ * If levelZero is true the name shouldn't be already bound.
+ *
+ * When doOverload is true:
+ * if <code>name</code> is already bound to an expression in the current
+ * level, then we mark the previous bound expression and obj as overloaded
+ * functions.
*
* @param name an identifier
* @param obj the expression to bind to <code>name</code>
* @param levelZero set if the binding must be done at level 0
+ * @param doOverload set if the binding can overload the function name.
+ *
+ * Returns false if the binding was invalid.
*/
- void bind(const std::string& name, Expr obj, bool levelZero = false) throw();
+ bool bind(const std::string& name, Expr obj, bool levelZero = false,
+ bool doOverload = false) throw();
/**
- * Bind a function body to a name in the current scope. If
- * <code>name</code> is already bound to an expression in the current
+ * Bind a function body to a name in the current scope.
+ *
+ * When doOverload is false:
+ * if <code>name</code> is already bound to an expression in the current
* level, then the binding is replaced. If <code>name</code> is bound
* in a previous level, then the binding is "covered" by this one
- * until the current scope is popped. Same as bind() but registers
- * this as a function (so that isBoundDefinedFunction() returns true).
+ * until the current scope is popped.
+ * If levelZero is true the name shouldn't be already bound.
+ *
+ * When doOverload is true:
+ * if <code>name</code> is already bound to an expression in the current
+ * level, then we mark the previous bound expression and obj as overloaded
+ * functions.
+ *
+ * Same as bind() but registers this as a function (so that
+ * isBoundDefinedFunction() returns true).
*
* @param name an identifier
* @param obj the expression to bind to <code>name</code>
* @param levelZero set if the binding must be done at level 0
+ * @param doOverload set if the binding can overload the function name.
+ *
+ * Returns false if the binding was invalid.
*/
- void bindDefinedFunction(const std::string& name, Expr obj,
- bool levelZero = false) throw();
+ bool bindDefinedFunction(const std::string& name, Expr obj,
+ bool levelZero = false,
+ bool doOverload = false) throw();
/**
* Bind a type to a name in the current scope. If <code>name</code>
@@ -133,7 +158,9 @@ class CVC4_PUBLIC SymbolTable {
* Lookup a bound expression.
*
* @param name the identifier to lookup
- * @returns the expression bound to <code>name</code> in the current scope.
+ * @returns the unique expression bound to <code>name</code> in the current scope.
+ * It returns the null expression if there is not a unique expression bound to
+ * <code>name</code> in the current scope (i.e. if there is not exactly one).
*/
Expr lookup(const std::string& name) const throw();
@@ -178,7 +205,31 @@ class CVC4_PUBLIC SymbolTable {
/** Reset everything. */
void reset();
-
+
+ //------------------------ operator overloading
+ /** is this function overloaded? */
+ bool isOverloadedFunction(Expr fun) const;
+
+ /** Get overloaded constant for type.
+ * If possible, it returns the defined symbol with name
+ * that has type t. Otherwise returns null expression.
+ */
+ Expr getOverloadedConstantForType(const std::string& name, Type t) const;
+
+ /**
+ * If possible, returns the unique defined function for a name
+ * that expects arguments with types "argTypes".
+ * For example, if argTypes = ( T1, ..., Tn ), then this may return
+ * an expression with type function( T1, ..., Tn ), or constructor( T1, ...., Tn ).
+ *
+ * If there is not a unique defined function for the name and argTypes,
+ * this returns the null expression. This can happen either if there are
+ * no functions with name and expected argTypes, or alternatively there is
+ * more than one function with name and expected argTypes.
+ */
+ Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const;
+ //------------------------ end operator overloading
+
private:
// Copying and assignment have not yet been implemented.
SymbolTable(const SymbolTable&);
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback