summaryrefslogtreecommitdiff
path: root/src/python/m5/util
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/m5/util')
-rw-r--r--src/python/m5/util/__init__.py149
-rw-r--r--src/python/m5/util/convert.py250
-rw-r--r--src/python/m5/util/jobfile.py10
-rw-r--r--src/python/m5/util/misc.py87
-rw-r--r--src/python/m5/util/smartdict.py154
5 files changed, 554 insertions, 96 deletions
diff --git a/src/python/m5/util/__init__.py b/src/python/m5/util/__init__.py
index 3930c8b6f..7a674dd2d 100644
--- a/src/python/m5/util/__init__.py
+++ b/src/python/m5/util/__init__.py
@@ -1,4 +1,5 @@
-# Copyright (c) 2008 The Hewlett-Packard Development Company
+# Copyright (c) 2008-2009 The Hewlett-Packard Development Company
+# Copyright (c) 2004-2006 The Regents of The University of Michigan
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -26,14 +27,130 @@
#
# Authors: Nathan Binkert
+import os
+import re
+import sys
+
+import convert
+import jobfile
+
from attrdict import attrdict, optiondict
from code_formatter import code_formatter
-from misc import *
from multidict import multidict
from orderdict import orderdict
-import jobfile
+from smartdict import SmartDict
+
+# define this here so we can use it right away if necessary
+def errorURL(prefix, s):
+ try:
+ import zlib
+ hashstr = "%x" % zlib.crc32(s)
+ except:
+ hashstr = "UnableToHash"
+ return "For more information see: http://www.m5sim.org/%s/%s" % \
+ (prefix, hashstr)
+
+# panic() should be called when something happens that should never
+# ever happen regardless of what the user does (i.e., an acutal m5
+# bug).
+def panic(fmt, *args):
+ print >>sys.stderr, 'panic:', fmt % args
+ print >>sys.stderr, errorURL('panic',fmt)
+ sys.exit(1)
+
+# fatal() should be called when the simulation cannot continue due to
+# some condition that is the user's fault (bad configuration, invalid
+# arguments, etc.) and not a simulator bug.
+def fatal(fmt, *args):
+ print >>sys.stderr, 'fatal:', fmt % args
+ print >>sys.stderr, errorURL('fatal',fmt)
+ sys.exit(1)
+
+class Singleton(type):
+ def __call__(cls, *args, **kwargs):
+ if hasattr(cls, '_instance'):
+ return cls._instance
+
+ cls._instance = super(Singleton, cls).__call__(*args, **kwargs)
+ return cls._instance
+
+def addToPath(path):
+ """Prepend given directory to system module search path. We may not
+ need this anymore if we can structure our config library more like a
+ Python package."""
+
+ # if it's a relative path and we know what directory the current
+ # python script is in, make the path relative to that directory.
+ if not os.path.isabs(path) and sys.path[0]:
+ path = os.path.join(sys.path[0], path)
+ path = os.path.realpath(path)
+ # sys.path[0] should always refer to the current script's directory,
+ # so place the new dir right after that.
+ sys.path.insert(1, path)
+
+# Apply method to object.
+# applyMethod(obj, 'meth', <args>) is equivalent to obj.meth(<args>)
+def applyMethod(obj, meth, *args, **kwargs):
+ return getattr(obj, meth)(*args, **kwargs)
-def print_list(items, indent=4):
+# If the first argument is an (non-sequence) object, apply the named
+# method with the given arguments. If the first argument is a
+# sequence, apply the method to each element of the sequence (a la
+# 'map').
+def applyOrMap(objOrSeq, meth, *args, **kwargs):
+ if not isinstance(objOrSeq, (list, tuple)):
+ return applyMethod(objOrSeq, meth, *args, **kwargs)
+ else:
+ return [applyMethod(o, meth, *args, **kwargs) for o in objOrSeq]
+
+def compareVersions(v1, v2):
+ """helper function: compare arrays or strings of version numbers.
+ E.g., compare_version((1,3,25), (1,4,1)')
+ returns -1, 0, 1 if v1 is <, ==, > v2
+ """
+ def make_version_list(v):
+ if isinstance(v, (list,tuple)):
+ return v
+ elif isinstance(v, str):
+ return map(lambda x: int(re.match('\d+', x).group()), v.split('.'))
+ else:
+ raise TypeError
+
+ v1 = make_version_list(v1)
+ v2 = make_version_list(v2)
+ # Compare corresponding elements of lists
+ for n1,n2 in zip(v1, v2):
+ if n1 < n2: return -1
+ if n1 > n2: return 1
+ # all corresponding values are equal... see if one has extra values
+ if len(v1) < len(v2): return -1
+ if len(v1) > len(v2): return 1
+ return 0
+
+def crossproduct(items):
+ if len(items) == 1:
+ for i in items[0]:
+ yield (i,)
+ else:
+ for i in items[0]:
+ for j in crossproduct(items[1:]):
+ yield (i,) + j
+
+def flatten(items):
+ while items:
+ item = items.pop(0)
+ if isinstance(item, (list, tuple)):
+ items[0:0] = item
+ else:
+ yield item
+
+# force scalars to one-element lists for uniformity
+def makeList(objOrList):
+ if isinstance(objOrList, list):
+ return objOrList
+ return [objOrList]
+
+def printList(items, indent=4):
line = ' ' * indent
for i,item in enumerate(items):
if len(line) + len(item) > 76:
@@ -45,3 +162,27 @@ def print_list(items, indent=4):
else:
line += item
print line
+
+def readCommand(cmd, **kwargs):
+ """run the command cmd, read the results and return them
+ this is sorta like `cmd` in shell"""
+ from subprocess import Popen, PIPE, STDOUT
+
+ if isinstance(cmd, str):
+ cmd = cmd.split()
+
+ no_exception = 'exception' in kwargs
+ exception = kwargs.pop('exception', None)
+
+ kwargs.setdefault('shell', False)
+ kwargs.setdefault('stdout', PIPE)
+ kwargs.setdefault('stderr', STDOUT)
+ kwargs.setdefault('close_fds', True)
+ try:
+ subp = Popen(cmd, **kwargs)
+ except Exception, e:
+ if no_exception:
+ return exception
+ raise
+
+ return subp.communicate()[0]
diff --git a/src/python/m5/util/convert.py b/src/python/m5/util/convert.py
new file mode 100644
index 000000000..bb9e3e1f1
--- /dev/null
+++ b/src/python/m5/util/convert.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2005 The Regents of The University of Michigan
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met: redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer;
+# redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution;
+# neither the name of the copyright holders nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Authors: Nathan Binkert
+
+# metric prefixes
+exa = 1.0e18
+peta = 1.0e15
+tera = 1.0e12
+giga = 1.0e9
+mega = 1.0e6
+kilo = 1.0e3
+
+milli = 1.0e-3
+micro = 1.0e-6
+nano = 1.0e-9
+pico = 1.0e-12
+femto = 1.0e-15
+atto = 1.0e-18
+
+# power of 2 prefixes
+kibi = 1024
+mebi = kibi * 1024
+gibi = mebi * 1024
+tebi = gibi * 1024
+pebi = tebi * 1024
+exbi = pebi * 1024
+
+# memory size configuration stuff
+def toFloat(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('Ei'):
+ return float(value[:-2]) * exbi
+ elif value.endswith('Pi'):
+ return float(value[:-2]) * pebi
+ elif value.endswith('Ti'):
+ return float(value[:-2]) * tebi
+ elif value.endswith('Gi'):
+ return float(value[:-2]) * gibi
+ elif value.endswith('Mi'):
+ return float(value[:-2]) * mebi
+ elif value.endswith('ki'):
+ return float(value[:-2]) * kibi
+ elif value.endswith('E'):
+ return float(value[:-1]) * exa
+ elif value.endswith('P'):
+ return float(value[:-1]) * peta
+ elif value.endswith('T'):
+ return float(value[:-1]) * tera
+ elif value.endswith('G'):
+ return float(value[:-1]) * giga
+ elif value.endswith('M'):
+ return float(value[:-1]) * mega
+ elif value.endswith('k'):
+ return float(value[:-1]) * kilo
+ elif value.endswith('m'):
+ return float(value[:-1]) * milli
+ elif value.endswith('u'):
+ return float(value[:-1]) * micro
+ elif value.endswith('n'):
+ return float(value[:-1]) * nano
+ elif value.endswith('p'):
+ return float(value[:-1]) * pico
+ elif value.endswith('f'):
+ return float(value[:-1]) * femto
+ else:
+ return float(value)
+
+def toInteger(value):
+ value = toFloat(value)
+ result = long(value)
+ if value != result:
+ raise ValueError, "cannot convert '%s' to integer" % value
+
+ return result
+
+_bool_dict = {
+ 'true' : True, 't' : True, 'yes' : True, 'y' : True, '1' : True,
+ 'false' : False, 'f' : False, 'no' : False, 'n' : False, '0' : False
+ }
+
+def toBool(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ value = value.lower()
+ result = _bool_dict.get(value, None)
+ if result == None:
+ raise ValueError, "cannot convert '%s' to bool" % value
+ return result
+
+def toFrequency(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('THz'):
+ return float(value[:-3]) * tera
+ elif value.endswith('GHz'):
+ return float(value[:-3]) * giga
+ elif value.endswith('MHz'):
+ return float(value[:-3]) * mega
+ elif value.endswith('kHz'):
+ return float(value[:-3]) * kilo
+ elif value.endswith('Hz'):
+ return float(value[:-2])
+
+ raise ValueError, "cannot convert '%s' to frequency" % value
+
+def toLatency(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('ps'):
+ return float(value[:-2]) * pico
+ elif value.endswith('ns'):
+ return float(value[:-2]) * nano
+ elif value.endswith('us'):
+ return float(value[:-2]) * micro
+ elif value.endswith('ms'):
+ return float(value[:-2]) * milli
+ elif value.endswith('s'):
+ return float(value[:-1])
+
+ raise ValueError, "cannot convert '%s' to latency" % value
+
+def anyToLatency(value):
+ """result is a clock period"""
+
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ try:
+ val = toFrequency(value)
+ if val != 0:
+ val = 1 / val
+ return val
+ except ValueError:
+ pass
+
+ try:
+ val = toLatency(value)
+ return val
+ except ValueError:
+ pass
+
+ raise ValueError, "cannot convert '%s' to clock period" % value
+
+def anyToFrequency(value):
+ """result is a clock period"""
+
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ try:
+ val = toFrequency(value)
+ return val
+ except ValueError:
+ pass
+
+ try:
+ val = toLatency(value)
+ if val != 0:
+ val = 1 / val
+ return val
+ except ValueError:
+ pass
+
+ raise ValueError, "cannot convert '%s' to clock period" % value
+
+def toNetworkBandwidth(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('Tbps'):
+ return float(value[:-4]) * tera
+ elif value.endswith('Gbps'):
+ return float(value[:-4]) * giga
+ elif value.endswith('Mbps'):
+ return float(value[:-4]) * mega
+ elif value.endswith('kbps'):
+ return float(value[:-4]) * kilo
+ elif value.endswith('bps'):
+ return float(value[:-3])
+ else:
+ return float(value)
+
+ raise ValueError, "cannot convert '%s' to network bandwidth" % value
+
+def toMemoryBandwidth(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('PB/s'):
+ return float(value[:-4]) * pebi
+ elif value.endswith('TB/s'):
+ return float(value[:-4]) * tebi
+ elif value.endswith('GB/s'):
+ return float(value[:-4]) * gibi
+ elif value.endswith('MB/s'):
+ return float(value[:-4]) * mebi
+ elif value.endswith('kB/s'):
+ return float(value[:-4]) * kibi
+ elif value.endswith('B/s'):
+ return float(value[:-3])
+
+ raise ValueError, "cannot convert '%s' to memory bandwidth" % value
+
+def toMemorySize(value):
+ if not isinstance(value, str):
+ raise TypeError, "wrong type '%s' should be str" % type(value)
+
+ if value.endswith('PB'):
+ return long(value[:-2]) * pebi
+ elif value.endswith('TB'):
+ return long(value[:-2]) * tebi
+ elif value.endswith('GB'):
+ return long(value[:-2]) * gibi
+ elif value.endswith('MB'):
+ return long(value[:-2]) * mebi
+ elif value.endswith('kB'):
+ return long(value[:-2]) * kibi
+ elif value.endswith('B'):
+ return long(value[:-1])
+
+ raise ValueError, "cannot convert '%s' to memory size" % value
diff --git a/src/python/m5/util/jobfile.py b/src/python/m5/util/jobfile.py
index c830895f6..9c59778e5 100644
--- a/src/python/m5/util/jobfile.py
+++ b/src/python/m5/util/jobfile.py
@@ -28,9 +28,6 @@
import sys
-from attrdict import optiondict
-from misc import crossproduct
-
class Data(object):
def __init__(self, name, desc, **kwargs):
self.name = name
@@ -108,7 +105,8 @@ class Data(object):
yield key
def optiondict(self):
- result = optiondict()
+ import m5.util
+ result = m5.util.optiondict()
for key in self:
result[key] = self[key]
return result
@@ -328,7 +326,9 @@ class Configuration(Data):
optgroups = [ g.subopts() for g in groups ]
if not optgroups:
return
- for options in crossproduct(optgroups):
+
+ import m5.util
+ for options in m5.util.crossproduct(optgroups):
for opt in options:
cpt = opt._group._checkpoint
if not isinstance(cpt, bool) and cpt != opt:
diff --git a/src/python/m5/util/misc.py b/src/python/m5/util/misc.py
deleted file mode 100644
index 094e3ed9a..000000000
--- a/src/python/m5/util/misc.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# Copyright (c) 2004-2006 The Regents of The University of Michigan
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met: redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer;
-# redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution;
-# neither the name of the copyright holders nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-# Authors: Steve Reinhardt
-# Nathan Binkert
-
-#############################
-#
-# Utility classes & methods
-#
-#############################
-
-class Singleton(type):
- def __call__(cls, *args, **kwargs):
- if hasattr(cls, '_instance'):
- return cls._instance
-
- cls._instance = super(Singleton, cls).__call__(*args, **kwargs)
- return cls._instance
-
-# Apply method to object.
-# applyMethod(obj, 'meth', <args>) is equivalent to obj.meth(<args>)
-def applyMethod(obj, meth, *args, **kwargs):
- return getattr(obj, meth)(*args, **kwargs)
-
-# If the first argument is an (non-sequence) object, apply the named
-# method with the given arguments. If the first argument is a
-# sequence, apply the method to each element of the sequence (a la
-# 'map').
-def applyOrMap(objOrSeq, meth, *args, **kwargs):
- if not isinstance(objOrSeq, (list, tuple)):
- return applyMethod(objOrSeq, meth, *args, **kwargs)
- else:
- return [applyMethod(o, meth, *args, **kwargs) for o in objOrSeq]
-
-def crossproduct(items):
- if not isinstance(items, (list, tuple)):
- raise AttributeError, 'crossproduct works only on sequences'
-
- if not items:
- yield None
- return
-
- current = items[0]
- remainder = items[1:]
-
- if not hasattr(current, '__iter__'):
- current = [ current ]
-
- for item in current:
- for rem in crossproduct(remainder):
- data = [ item ]
- if rem:
- data += rem
- yield data
-
-def flatten(items):
- if not isinstance(items, (list, tuple)):
- yield items
- return
-
- for item in items:
- for flat in flatten(item):
- yield flat
diff --git a/src/python/m5/util/smartdict.py b/src/python/m5/util/smartdict.py
new file mode 100644
index 000000000..d85dbd517
--- /dev/null
+++ b/src/python/m5/util/smartdict.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2005 The Regents of The University of Michigan
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met: redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer;
+# redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution;
+# neither the name of the copyright holders nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Authors: Nathan Binkert
+
+# The SmartDict class fixes a couple of issues with using the content
+# of os.environ or similar dicts of strings as Python variables:
+#
+# 1) Undefined variables should return False rather than raising KeyError.
+#
+# 2) String values of 'False', '0', etc., should evaluate to False
+# (not just the empty string).
+#
+# #1 is solved by overriding __getitem__, and #2 is solved by using a
+# proxy class for values and overriding __nonzero__ on the proxy.
+# Everything else is just to (a) make proxies behave like normal
+# values otherwise, (b) make sure any dict operation returns a proxy
+# rather than a normal value, and (c) coerce values written to the
+# dict to be strings.
+
+
+from convert import *
+
+class Variable(str):
+ """Intelligent proxy class for SmartDict. Variable will use the
+ various convert functions to attempt to convert values to useable
+ types"""
+ def __int__(self):
+ return toInteger(str(self))
+ def __long__(self):
+ return toLong(str(self))
+ def __float__(self):
+ return toFloat(str(self))
+ def __nonzero__(self):
+ return toBool(str(self))
+ def convert(self, other):
+ t = type(other)
+ if t == bool:
+ return bool(self)
+ if t == int:
+ return int(self)
+ if t == long:
+ return long(self)
+ if t == float:
+ return float(self)
+ return str(self)
+ def __lt__(self, other):
+ return self.convert(other) < other
+ def __le__(self, other):
+ return self.convert(other) <= other
+ def __eq__(self, other):
+ return self.convert(other) == other
+ def __ne__(self, other):
+ return self.convert(other) != other
+ def __gt__(self, other):
+ return self.convert(other) > other
+ def __ge__(self, other):
+ return self.convert(other) >= other
+
+ def __add__(self, other):
+ return self.convert(other) + other
+ def __sub__(self, other):
+ return self.convert(other) - other
+ def __mul__(self, other):
+ return self.convert(other) * other
+ def __div__(self, other):
+ return self.convert(other) / other
+ def __truediv__(self, other):
+ return self.convert(other) / other
+
+ def __radd__(self, other):
+ return other + self.convert(other)
+ def __rsub__(self, other):
+ return other - self.convert(other)
+ def __rmul__(self, other):
+ return other * self.convert(other)
+ def __rdiv__(self, other):
+ return other / self.convert(other)
+ def __rtruediv__(self, other):
+ return other / self.convert(other)
+
+class UndefinedVariable(object):
+ """Placeholder class to represent undefined variables. Will
+ generally cause an exception whenever it is used, but evaluates to
+ zero for boolean truth testing such as in an if statement"""
+ def __nonzero__(self):
+ return False
+
+class SmartDict(dict):
+ """Dictionary class that holds strings, but intelligently converts
+ those strings to other types depending on their usage"""
+
+ def __getitem__(self, key):
+ """returns a Variable proxy if the values exists in the database and
+ returns an UndefinedVariable otherwise"""
+
+ if key in self:
+ return Variable(dict.get(self, key))
+ else:
+ # Note that this does *not* change the contents of the dict,
+ # so that even after we call env['foo'] we still get a
+ # meaningful answer from "'foo' in env" (which
+ # calls dict.__contains__, which we do not override).
+ return UndefinedVariable()
+
+ def __setitem__(self, key, item):
+ """intercept the setting of any variable so that we always
+ store strings in the dict"""
+ dict.__setitem__(self, key, str(item))
+
+ def values(self):
+ return [ Variable(v) for v in dict.values(self) ]
+
+ def itervalues(self):
+ for value in dict.itervalues(self):
+ yield Variable(value)
+
+ def items(self):
+ return [ (k, Variable(v)) for k,v in dict.items(self) ]
+
+ def iteritems(self):
+ for key,value in dict.iteritems(self):
+ yield key, Variable(value)
+
+ def get(self, key, default='False'):
+ return Variable(dict.get(self, key, str(default)))
+
+ def setdefault(self, key, default='False'):
+ return Variable(dict.setdefault(self, key, str(default)))
+
+__all__ = [ 'SmartDict' ]