diff options
-rw-r--r-- | src/python/m5/SimObject.py | 87 |
1 files changed, 80 insertions, 7 deletions
diff --git a/src/python/m5/SimObject.py b/src/python/m5/SimObject.py index 97f684739..b74e93a87 100644 --- a/src/python/m5/SimObject.py +++ b/src/python/m5/SimObject.py @@ -415,6 +415,7 @@ class MetaSimObject(type): 'cxx_extra_bases' : list, 'cxx_exports' : list, 'cxx_param_exports' : list, + 'cxx_template_params' : list, } # Attributes that can be set any time keywords = { 'check' : FunctionType } @@ -454,6 +455,8 @@ class MetaSimObject(type): value_dict['cxx_exports'] += cxx_exports if 'cxx_param_exports' not in value_dict: value_dict['cxx_param_exports'] = [] + if 'cxx_template_params' not in value_dict: + value_dict['cxx_template_params'] = [] cls_dict['_value_dict'] = value_dict cls = super(MetaSimObject, mcls).__new__(mcls, name, bases, cls_dict) if 'type' in value_dict: @@ -773,6 +776,7 @@ module_init(py::module &m_internal) code('static EmbeddedPyBind embed_obj("${0}", module_init, "${1}");', cls, cls._base.type if cls._base else "") + _warned_about_nested_templates = False # Generate the C++ declaration (.hh file) for this SimObject's # param struct. Called from src/SConscript. @@ -790,7 +794,78 @@ module_init(py::module &m_internal) print(params) raise - class_path = cls._value_dict['cxx_class'].split('::') + class CxxClass(object): + def __init__(self, sig, template_params=[]): + # Split the signature into its constituent parts. This could + # potentially be done with regular expressions, but + # it's simple enough to pick appart a class signature + # manually. + parts = sig.split('<', 1) + base = parts[0] + t_args = [] + if len(parts) > 1: + # The signature had template arguments. + text = parts[1].rstrip(' \t\n>') + arg = '' + # Keep track of nesting to avoid splitting on ","s embedded + # in the arguments themselves. + depth = 0 + for c in text: + if c == '<': + depth = depth + 1 + if depth > 0 and not \ + self._warned_about_nested_templates: + self._warned_about_nested_templates = True + print('Nested template argument in cxx_class.' + ' This feature is largely untested and ' + ' may not work.') + elif c == '>': + depth = depth - 1 + elif c == ',' and depth == 0: + t_args.append(arg.strip()) + arg = '' + else: + arg = arg + c + if arg: + t_args.append(arg.strip()) + # Split the non-template part on :: boundaries. + class_path = base.split('::') + + # The namespaces are everything except the last part of the + # class path. + self.namespaces = class_path[:-1] + # And the class name is the last part. + self.name = class_path[-1] + + self.template_params = template_params + self.template_arguments = [] + # Iterate through the template arguments and their values. This + # will likely break if parameter packs are used. + for arg, param in zip(t_args, template_params): + type_keys = ('class', 'typename') + # If a parameter is a type, parse it recursively. Otherwise + # assume it's a constant, and store it verbatim. + if any(param.strip().startswith(kw) for kw in type_keys): + self.template_arguments.append(CxxClass(arg)) + else: + self.template_arguments.append(arg) + + def declare(self, code): + # First declare any template argument types. + for arg in self.template_arguments: + if isinstance(arg, CxxClass): + arg.declare(code) + # Re-open the target namespace. + for ns in self.namespaces: + code('namespace $ns {') + # If this is a class template... + if self.template_params: + code('template <${{", ".join(self.template_params)}}>') + # The actual class declaration. + code('class ${{self.name}};') + # Close the target namespaces. + for ns in reversed(self.namespaces): + code('} // namespace $ns') code('''\ #ifndef __PARAMS__${cls}__ @@ -806,14 +881,12 @@ module_init(py::module &m_internal) if cls == SimObject: code('''#include <string>''') + cxx_class = CxxClass(cls._value_dict['cxx_class'], + cls._value_dict['cxx_template_params']) + # A forward class declaration is sufficient since we are just # declaring a pointer. - for ns in class_path[:-1]: - code('namespace $ns {') - code('class $0;', class_path[-1]) - for ns in reversed(class_path[:-1]): - code('} // namespace $ns') - code() + cxx_class.declare(code) for param in params: param.cxx_predecls(code) |