diff options
author | Andrew Reynolds <andrew.j.reynolds@gmail.com> | 2017-09-28 23:58:03 -0500 |
---|---|---|
committer | Andres Noetzli <andres.noetzli@gmail.com> | 2017-09-28 21:58:03 -0700 |
commit | 821a9d90914fca4a13bc29f8ff15fb4220cbd1d4 (patch) | |
tree | 740dfcd00ef961c73029268bd2abd438a186f5f2 /src/expr/symbol_table.cpp | |
parent | fabc9849e7d9ab31b3622487f74235a065852caf (diff) |
Update symbol table to support operator overloading (#1154)
Diffstat (limited to 'src/expr/symbol_table.cpp')
-rw-r--r-- | src/expr/symbol_table.cpp | 289 |
1 files changed, 274 insertions, 15 deletions
diff --git a/src/expr/symbol_table.cpp b/src/expr/symbol_table.cpp index c760b3a80..b411d8dfb 100644 --- a/src/expr/symbol_table.cpp +++ b/src/expr/symbol_table.cpp @@ -21,6 +21,7 @@ #include <ostream> #include <string> #include <utility> +#include <unordered_map> #include "context/cdhashmap.h" #include "context/cdhashset.h" @@ -41,23 +42,204 @@ using ::std::pair; using ::std::string; using ::std::vector; +// This data structure stores a trie of expressions with +// the same name, and must be distinguished by their argument types. +// It is context-dependent. +class OverloadedTypeTrie +{ +public: + OverloadedTypeTrie(Context * c ) : + d_overloaded_symbols(new (true) CDHashSet<Expr, ExprHashFunction>(c)) { + } + ~OverloadedTypeTrie() { + d_overloaded_symbols->deleteSelf(); + } + /** is this function overloaded? */ + bool isOverloadedFunction(Expr fun) const; + + /** Get overloaded constant for type. + * If possible, it returns a defined symbol with name + * that has type t. Otherwise returns null expression. + */ + Expr getOverloadedConstantForType(const std::string& name, Type t) const; + + /** + * If possible, returns a defined function for a name + * and a vector of expected argument types. Otherwise returns + * null expression. + */ + Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const; + /** called when obj is bound to name, and prev_bound_obj was already bound to name + * Returns false if the binding is invalid. + */ + bool bind(const string& name, Expr prev_bound_obj, Expr obj); +private: + /** Marks expression obj with name as overloaded. + * Adds relevant information to the type arg trie data structure. + * It returns false if there is already an expression bound to that name + * whose type expects the same arguments as the type of obj but is not identical + * to the type of obj. For example, if we declare : + * + * (declare-datatypes () ((List (cons (hd Int) (tl List)) (nil)))) + * (declare-fun cons (Int List) List) + * + * cons : constructor_type( Int, List, List ) + * cons : function_type( Int, List, List ) + * + * These are put in the same place in the trie but do not have identical type, + * hence we return false. + */ + bool markOverloaded(const string& name, Expr obj); + /** the null expression */ + Expr d_nullExpr; + // The (context-independent) trie storing that maps expected argument + // vectors to symbols. All expressions stored in d_symbols are only + // interpreted as active if they also appear in the context-dependent + // set d_overloaded_symbols. + class TypeArgTrie { + public: + // children of this node + std::map< Type, TypeArgTrie > d_children; + // symbols at this node + std::map< Type, Expr > d_symbols; + }; + /** for each string with operator overloading, this stores the data structure above. */ + std::unordered_map< std::string, TypeArgTrie > d_overload_type_arg_trie; + /** The set of overloaded symbols. */ + CDHashSet<Expr, ExprHashFunction>* d_overloaded_symbols; +}; + +bool OverloadedTypeTrie::isOverloadedFunction(Expr fun) const { + return d_overloaded_symbols->find(fun)!=d_overloaded_symbols->end(); +} + +Expr OverloadedTypeTrie::getOverloadedConstantForType(const std::string& name, Type t) const { + std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name); + if(it!=d_overload_type_arg_trie.end()) { + std::map< Type, Expr >::const_iterator its = it->second.d_symbols.find(t); + if(its!=it->second.d_symbols.end()) { + Expr expr = its->second; + // must be an active symbol + if(isOverloadedFunction(expr)) { + return expr; + } + } + } + return d_nullExpr; +} + +Expr OverloadedTypeTrie::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + std::unordered_map< std::string, TypeArgTrie >::const_iterator it = d_overload_type_arg_trie.find(name); + if(it!=d_overload_type_arg_trie.end()) { + const TypeArgTrie * tat = &it->second; + for(unsigned i=0; i<argTypes.size(); i++) { + std::map< Type, TypeArgTrie >::const_iterator itc = tat->d_children.find(argTypes[i]); + if(itc!=tat->d_children.end()) { + tat = &itc->second; + }else{ + // no functions match + return d_nullExpr; + } + } + // now, we must ensure that there is *only* one active symbol at this node + Expr retExpr; + for(std::map< Type, Expr >::const_iterator its = tat->d_symbols.begin(); its != tat->d_symbols.end(); ++its) { + Expr expr = its->second; + if(isOverloadedFunction(expr)) { + if(retExpr.isNull()) { + retExpr = expr; + }else{ + // multiple functions match + return d_nullExpr; + } + } + } + return retExpr; + } + return d_nullExpr; +} + +bool OverloadedTypeTrie::bind(const string& name, Expr prev_bound_obj, Expr obj) { + bool retprev = true; + if(!isOverloadedFunction(prev_bound_obj)) { + // mark previous as overloaded + retprev = markOverloaded(name, prev_bound_obj); + } + // mark this as overloaded + bool retobj = markOverloaded(name, obj); + return retprev && retobj; +} + +bool OverloadedTypeTrie::markOverloaded(const string& name, Expr obj) { + Trace("parser-overloading") << "Overloaded function : " << name; + Trace("parser-overloading") << " with type " << obj.getType() << std::endl; + // get the argument types + Type t = obj.getType(); + Type rangeType = t; + std::vector< Type > argTypes; + if(t.isFunction()) { + argTypes = static_cast<FunctionType>(t).getArgTypes(); + rangeType = static_cast<FunctionType>(t).getRangeType(); + }else if(t.isConstructor()) { + argTypes = static_cast<ConstructorType>(t).getArgTypes(); + rangeType = static_cast<ConstructorType>(t).getRangeType(); + }else if(t.isTester()) { + argTypes.push_back( static_cast<TesterType>(t).getDomain() ); + rangeType = static_cast<TesterType>(t).getRangeType(); + }else if(t.isSelector()) { + argTypes.push_back( static_cast<SelectorType>(t).getDomain() ); + rangeType = static_cast<SelectorType>(t).getRangeType(); + } + // add to the trie + TypeArgTrie * tat = &d_overload_type_arg_trie[name]; + for(unsigned i=0; i<argTypes.size(); i++) { + tat = &(tat->d_children[argTypes[i]]); + } + + // types can be identical but vary on the kind of the type, thus we must distinguish based on this + std::map< Type, Expr >::iterator it = tat->d_symbols.find( rangeType ); + if( it!=tat->d_symbols.end() ){ + Expr prev_obj = it->second; + // if there is already an active function with the same name and expects the same argument types + if( isOverloadedFunction(prev_obj) ){ + if( prev_obj.getType()==obj.getType() ){ + //types are identical, simply ignore it + return true; + }else{ + //otherwise there is no way to distinguish these types, we return an error + return false; + } + } + } + + // otherwise, update the symbols + d_overloaded_symbols->insert(obj); + tat->d_symbols[rangeType] = obj; + return true; +} + + class SymbolTable::Implementation { public: Implementation() : d_context(), d_exprMap(new (true) CDHashMap<string, Expr>(&d_context)), d_typeMap(new (true) TypeMap(&d_context)), - d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)) {} + d_functions(new (true) CDHashSet<Expr, ExprHashFunction>(&d_context)){ + d_overload_trie = new OverloadedTypeTrie(&d_context); + } ~Implementation() { d_exprMap->deleteSelf(); d_typeMap->deleteSelf(); d_functions->deleteSelf(); + delete d_overload_trie; } - void bind(const string& name, Expr obj, bool levelZero) throw(); - void bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw(); + bool bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw(); + bool bindDefinedFunction(const string& name, Expr obj, + bool levelZero, bool doOverload) throw(); void bindType(const string& name, Type t, bool levelZero = false) throw(); void bindType(const string& name, const vector<Type>& params, Type t, bool levelZero = false) throw(); @@ -73,7 +255,16 @@ class SymbolTable::Implementation { void pushScope() throw(); size_t getLevel() const throw(); void reset(); - + //------------------------ operator overloading + /** implementation of function from header */ + bool isOverloadedFunction(Expr fun) const; + + /** implementation of function from header */ + Expr getOverloadedConstantForType(const std::string& name, Type t) const; + + /** implementation of function from header */ + Expr getOverloadedFunctionForTypes(const std::string& name, const std::vector< Type >& argTypes) const; + //------------------------ end operator overloading private: /** The context manager for the scope maps. */ Context d_context; @@ -87,24 +278,49 @@ class SymbolTable::Implementation { /** A set of defined functions. */ CDHashSet<Expr, ExprHashFunction>* d_functions; + + //------------------------ operator overloading + // the null expression + Expr d_nullExpr; + // overloaded type trie, stores all information regarding overloading + OverloadedTypeTrie * d_overload_trie; + /** bind with overloading + * This is called whenever obj is bound to name where overloading symbols is allowed. + * If a symbol is previously bound to that name, it marks both as overloaded. + * Returns false if the binding was invalid. + */ + bool bindWithOverloading(const string& name, Expr obj); + //------------------------ end operator overloading }; /* SymbolTable::Implementation */ -void SymbolTable::Implementation::bind(const string& name, Expr obj, - bool levelZero) throw() { +bool SymbolTable::Implementation::bind(const string& name, Expr obj, + bool levelZero, bool doOverload) throw() { PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr"); ExprManagerScope ems(obj); + if (doOverload) { + if( !bindWithOverloading(name, obj) ){ + return false; + } + } if (levelZero) { d_exprMap->insertAtContextLevelZero(name, obj); } else { d_exprMap->insert(name, obj); } + return true; } -void SymbolTable::Implementation::bindDefinedFunction(const string& name, +bool SymbolTable::Implementation::bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw() { + bool levelZero, + bool doOverload) throw() { PrettyCheckArgument(!obj.isNull(), obj, "cannot bind to a null Expr"); ExprManagerScope ems(obj); + if (doOverload) { + if( !bindWithOverloading(name, obj) ){ + return false; + } + } if (levelZero) { d_exprMap->insertAtContextLevelZero(name, obj); d_functions->insertAtContextLevelZero(obj); @@ -112,6 +328,7 @@ void SymbolTable::Implementation::bindDefinedFunction(const string& name, d_exprMap->insert(name, obj); d_functions->insert(obj); } + return true; } bool SymbolTable::Implementation::isBound(const string& name) const throw() { @@ -130,7 +347,12 @@ bool SymbolTable::Implementation::isBoundDefinedFunction(Expr func) const } Expr SymbolTable::Implementation::lookup(const string& name) const throw() { - return (*d_exprMap->find(name)).second; + Expr expr = (*d_exprMap->find(name)).second; + if(isOverloadedFunction(expr)) { + return d_nullExpr; + }else{ + return expr; + } } void SymbolTable::Implementation::bindType(const string& name, Type t, @@ -255,18 +477,55 @@ void SymbolTable::Implementation::reset() { new (this) SymbolTable::Implementation(); } +bool SymbolTable::Implementation::isOverloadedFunction(Expr fun) const { + return d_overload_trie->isOverloadedFunction(fun); +} + +Expr SymbolTable::Implementation::getOverloadedConstantForType(const std::string& name, Type t) const { + return d_overload_trie->getOverloadedConstantForType(name, t); +} + +Expr SymbolTable::Implementation::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + return d_overload_trie->getOverloadedFunctionForTypes(name, argTypes); +} + +bool SymbolTable::Implementation::bindWithOverloading(const string& name, Expr obj){ + CDHashMap<string, Expr>::const_iterator it = d_exprMap->find(name); + if(it != d_exprMap->end()) { + const Expr& prev_bound_obj = (*it).second; + if(prev_bound_obj!=obj) { + return d_overload_trie->bind(name, prev_bound_obj, obj); + } + } + return true; +} + +bool SymbolTable::isOverloadedFunction(Expr fun) const { + return d_implementation->isOverloadedFunction(fun); +} + +Expr SymbolTable::getOverloadedConstantForType(const std::string& name, Type t) const { + return d_implementation->getOverloadedConstantForType(name, t); +} + +Expr SymbolTable::getOverloadedFunctionForTypes(const std::string& name, + const std::vector< Type >& argTypes) const { + return d_implementation->getOverloadedFunctionForTypes(name, argTypes); +} + SymbolTable::SymbolTable() : d_implementation(new SymbolTable::Implementation()) {} SymbolTable::~SymbolTable() {} -void SymbolTable::bind(const string& name, Expr obj, bool levelZero) throw() { - d_implementation->bind(name, obj, levelZero); +bool SymbolTable::bind(const string& name, Expr obj, bool levelZero, bool doOverload) throw() { + return d_implementation->bind(name, obj, levelZero, doOverload); } -void SymbolTable::bindDefinedFunction(const string& name, Expr obj, - bool levelZero) throw() { - d_implementation->bindDefinedFunction(name, obj, levelZero); +bool SymbolTable::bindDefinedFunction(const string& name, Expr obj, + bool levelZero, bool doOverload) throw() { + return d_implementation->bindDefinedFunction(name, obj, levelZero, doOverload); } void SymbolTable::bindType(const string& name, Type t, bool levelZero) throw() { |