diff options
Diffstat (limited to 'module/lib')
101 files changed, 14462 insertions, 5019 deletions
diff --git a/module/lib/BeautifulSoup.py b/module/lib/BeautifulSoup.py index 55567f588..7278215ca 100644 --- a/module/lib/BeautifulSoup.py +++ b/module/lib/BeautifulSoup.py @@ -79,8 +79,8 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE, DAMMIT. from __future__ import generators __author__ = "Leonard Richardson (leonardr@segfault.org)" -__version__ = "3.0.8.1" -__copyright__ = "Copyright (c) 2004-2010 Leonard Richardson" +__version__ = "3.2.1" +__copyright__ = "Copyright (c) 2004-2012 Leonard Richardson" __license__ = "New-style BSD" from sgmllib import SGMLParser, SGMLParseError @@ -114,6 +114,21 @@ class PageElement(object): """Contains the navigational information for some part of the page (either a tag or a piece of text)""" + def _invert(h): + "Cheap function to invert a hash." + i = {} + for k,v in h.items(): + i[v] = k + return i + + XML_ENTITIES_TO_SPECIAL_CHARS = { "apos" : "'", + "quot" : '"', + "amp" : "&", + "lt" : "<", + "gt" : ">" } + + XML_SPECIAL_CHARS_TO_ENTITIES = _invert(XML_ENTITIES_TO_SPECIAL_CHARS) + def setup(self, parent=None, previous=None): """Sets up the initial relations between this element and other elements.""" @@ -421,6 +436,16 @@ class PageElement(object): s = unicode(s) return s + BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|" + + "&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)" + + ")") + + def _sub_entity(self, x): + """Used with a regular expression to substitute the + appropriate XML entity for an XML special character.""" + return "&" + self.XML_SPECIAL_CHARS_TO_ENTITIES[x.group(0)[0]] + ";" + + class NavigableString(unicode, PageElement): def __new__(cls, value): @@ -451,10 +476,12 @@ class NavigableString(unicode, PageElement): return str(self).decode(DEFAULT_OUTPUT_ENCODING) def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING): + # Substitute outgoing XML entities. + data = self.BARE_AMPERSAND_OR_BRACKET.sub(self._sub_entity, self) if encoding: - return self.encode(encoding) + return data.encode(encoding) else: - return self + return data class CData(NavigableString): @@ -480,21 +507,6 @@ class Tag(PageElement): """Represents a found HTML tag with its attributes and contents.""" - def _invert(h): - "Cheap function to invert a hash." - i = {} - for k,v in h.items(): - i[v] = k - return i - - XML_ENTITIES_TO_SPECIAL_CHARS = { "apos" : "'", - "quot" : '"', - "amp" : "&", - "lt" : "<", - "gt" : ">" } - - XML_SPECIAL_CHARS_TO_ENTITIES = _invert(XML_ENTITIES_TO_SPECIAL_CHARS) - def _convertEntities(self, match): """Used in a call to re.sub to replace HTML, XML, and numeric entities with the appropriate Unicode characters. If HTML @@ -531,6 +543,8 @@ class Tag(PageElement): self.name = name if attrs is None: attrs = [] + elif isinstance(attrs, dict): + attrs = attrs.items() self.attrs = attrs self.contents = [] self.setup(parent, previous) @@ -679,15 +693,6 @@ class Tag(PageElement): def __unicode__(self): return self.__str__(None) - BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|" - + "&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)" - + ")") - - def _sub_entity(self, x): - """Used with a regular expression to substitute the - appropriate XML entity for an XML special character.""" - return "&" + self.XML_SPECIAL_CHARS_TO_ENTITIES[x.group(0)[0]] + ";" - def __str__(self, encoding=DEFAULT_OUTPUT_ENCODING, prettyPrint=False, indentLevel=0): """Returns a string or Unicode representation of this tag and @@ -1295,7 +1300,7 @@ class BeautifulStoneSoup(Tag, SGMLParser): """ nestingResetTriggers = self.NESTABLE_TAGS.get(name) - isNestable = nestingResetTriggers is not None + isNestable = nestingResetTriggers != None isResetNesting = self.RESET_NESTING_TAGS.has_key(name) popTo = None inclusive = True diff --git a/module/lib/Unzip.py b/module/lib/Unzip.py deleted file mode 100644 index f56fbe751..000000000 --- a/module/lib/Unzip.py +++ /dev/null @@ -1,50 +0,0 @@ -import zipfile -import os - -class Unzip: - def __init__(self): - pass - - def extract(self, file, dir): - if not dir.endswith(':') and not os.path.exists(dir): - os.mkdir(dir) - - zf = zipfile.ZipFile(file) - - # create directory structure to house files - self._createstructure(file, dir) - - # extract files to directory structure - for i, name in enumerate(zf.namelist()): - - if not name.endswith('/') and not name.endswith("config"): - print "extracting", name.replace("pyload/","") - outfile = open(os.path.join(dir, name.replace("pyload/","")), 'wb') - outfile.write(zf.read(name)) - outfile.flush() - outfile.close() - - def _createstructure(self, file, dir): - self._makedirs(self._listdirs(file), dir) - - def _makedirs(self, directories, basedir): - """ Create any directories that don't currently exist """ - for dir in directories: - curdir = os.path.join(basedir, dir) - if not os.path.exists(curdir): - os.mkdir(curdir) - - def _listdirs(self, file): - """ Grabs all the directories in the zip structure - This is necessary to create the structure before trying - to extract the file to it. """ - zf = zipfile.ZipFile(file) - - dirs = [] - - for name in zf.namelist(): - if name.endswith('/'): - dirs.append(name.replace("pyload/","")) - - dirs.sort() - return dirs diff --git a/module/lib/beaker/__init__.py b/module/lib/beaker/__init__.py index 792d60054..d07785c52 100644 --- a/module/lib/beaker/__init__.py +++ b/module/lib/beaker/__init__.py @@ -1 +1 @@ -# +__version__ = '1.6.4' diff --git a/module/lib/beaker/cache.py b/module/lib/beaker/cache.py index 4a96537ff..0ae96e020 100644 --- a/module/lib/beaker/cache.py +++ b/module/lib/beaker/cache.py @@ -1,156 +1,248 @@ -"""Cache object +"""This package contains the "front end" classes and functions +for Beaker caching. -The Cache object is used to manage a set of cache files and their -associated backend. The backends can be rotated on the fly by -specifying an alternate type when used. - -Advanced users can add new backends in beaker.backends +Included are the :class:`.Cache` and :class:`.CacheManager` classes, +as well as the function decorators :func:`.region_decorate`, +:func:`.region_invalidate`. """ - import warnings import beaker.container as container import beaker.util as util +from beaker.crypto.util import sha1 from beaker.exceptions import BeakerException, InvalidCacheBackendError +from beaker.synchronization import _threading import beaker.ext.memcached as memcached import beaker.ext.database as database import beaker.ext.sqla as sqla import beaker.ext.google as google -# Initialize the basic available backends -clsmap = { - 'memory':container.MemoryNamespaceManager, - 'dbm':container.DBMNamespaceManager, - 'file':container.FileNamespaceManager, - 'ext:memcached':memcached.MemcachedNamespaceManager, - 'ext:database':database.DatabaseNamespaceManager, - 'ext:sqla': sqla.SqlaNamespaceManager, - 'ext:google': google.GoogleNamespaceManager, - } - # Initialize the cache region dict cache_regions = {} +"""Dictionary of 'region' arguments. + +A "region" is a string name that refers to a series of cache +configuration arguments. An application may have multiple +"regions" - one which stores things in a memory cache, one +which writes data to files, etc. + +The dictionary stores string key names mapped to dictionaries +of configuration arguments. Example:: + + from beaker.cache import cache_regions + cache_regions.update({ + 'short_term':{ + 'expire':'60', + 'type':'memory' + }, + 'long_term':{ + 'expire':'1800', + 'type':'dbm', + 'data_dir':'/tmp', + } + }) +""" + + cache_managers = {} -try: - import pkg_resources - # Load up the additional entry point defined backends - for entry_point in pkg_resources.iter_entry_points('beaker.backends'): +class _backends(object): + initialized = False + + def __init__(self, clsmap): + self._clsmap = clsmap + self._mutex = _threading.Lock() + + def __getitem__(self, key): try: - NamespaceManager = entry_point.load() - name = entry_point.name - if name in clsmap: - raise BeakerException("NamespaceManager name conflict,'%s' " - "already loaded" % name) - clsmap[name] = NamespaceManager - except (InvalidCacheBackendError, SyntaxError): - # Ignore invalid backends + return self._clsmap[key] + except KeyError, e: + if not self.initialized: + self._mutex.acquire() + try: + if not self.initialized: + self._init() + self.initialized = True + + return self._clsmap[key] + finally: + self._mutex.release() + + raise e + + def _init(self): + try: + import pkg_resources + + # Load up the additional entry point defined backends + for entry_point in pkg_resources.iter_entry_points('beaker.backends'): + try: + namespace_manager = entry_point.load() + name = entry_point.name + if name in self._clsmap: + raise BeakerException("NamespaceManager name conflict,'%s' " + "already loaded" % name) + self._clsmap[name] = namespace_manager + except (InvalidCacheBackendError, SyntaxError): + # Ignore invalid backends + pass + except: + import sys + from pkg_resources import DistributionNotFound + # Warn when there's a problem loading a NamespaceManager + if not isinstance(sys.exc_info()[1], DistributionNotFound): + import traceback + from StringIO import StringIO + tb = StringIO() + traceback.print_exc(file=tb) + warnings.warn( + "Unable to load NamespaceManager " + "entry point: '%s': %s" % ( + entry_point, + tb.getvalue()), + RuntimeWarning, 2) + except ImportError: pass - except: - import sys - from pkg_resources import DistributionNotFound - # Warn when there's a problem loading a NamespaceManager - if not isinstance(sys.exc_info()[1], DistributionNotFound): - import traceback - from StringIO import StringIO - tb = StringIO() - traceback.print_exc(file=tb) - warnings.warn("Unable to load NamespaceManager entry point: '%s': " - "%s" % (entry_point, tb.getvalue()), RuntimeWarning, - 2) -except ImportError: - pass - - - - -def cache_region(region, *deco_args): - """Decorate a function to cache itself using a cache region - - The region decorator requires arguments if there are more than - 2 of the same named function, in the same module. This is - because the namespace used for the functions cache is based on - the functions name and the module. - - + +# Initialize the basic available backends +clsmap = _backends({ + 'memory': container.MemoryNamespaceManager, + 'dbm': container.DBMNamespaceManager, + 'file': container.FileNamespaceManager, + 'ext:memcached': memcached.MemcachedNamespaceManager, + 'ext:database': database.DatabaseNamespaceManager, + 'ext:sqla': sqla.SqlaNamespaceManager, + 'ext:google': google.GoogleNamespaceManager, + }) + + +def cache_region(region, *args): + """Decorate a function such that its return result is cached, + using a "region" to indicate the cache arguments. + Example:: - - # Add cache region settings to beaker: - beaker.cache.cache_regions.update(dict_of_config_region_options)) - - @cache_region('short_term', 'some_data') - def populate_things(search_term, limit, offset): - return load_the_data(search_term, limit, offset) - - return load('rabbits', 20, 0) - + + from beaker.cache import cache_regions, cache_region + + # configure regions + cache_regions.update({ + 'short_term':{ + 'expire':'60', + 'type':'memory' + } + }) + + @cache_region('short_term', 'load_things') + def load(search_term, limit, offset): + '''Load from a database given a search term, limit, offset.''' + return database.query(search_term)[offset:offset + limit] + + The decorator can also be used with object methods. The ``self`` + argument is not part of the cache key. This is based on the + actual string name ``self`` being in the first argument + position (new in 1.6):: + + class MyThing(object): + @cache_region('short_term', 'load_things') + def load(self, search_term, limit, offset): + '''Load from a database given a search term, limit, offset.''' + return database.query(search_term)[offset:offset + limit] + + Classmethods work as well - use ``cls`` as the name of the class argument, + and place the decorator around the function underneath ``@classmethod`` + (new in 1.6):: + + class MyThing(object): + @classmethod + @cache_region('short_term', 'load_things') + def load(cls, search_term, limit, offset): + '''Load from a database given a search term, limit, offset.''' + return database.query(search_term)[offset:offset + limit] + + :param region: String name of the region corresponding to the desired + caching arguments, established in :attr:`.cache_regions`. + + :param \*args: Optional ``str()``-compatible arguments which will uniquely + identify the key used by this decorated function, in addition + to the positional arguments passed to the function itself at call time. + This is recommended as it is needed to distinguish between any two functions + or methods that have the same name (regardless of parent class or not). + .. note:: - + The function being decorated must only be called with - positional arguments. - + positional arguments, and the arguments must support + being stringified with ``str()``. The concatenation + of the ``str()`` version of each argument, combined + with that of the ``*args`` sent to the decorator, + forms the unique cache key. + + .. note:: + + When a method on a class is decorated, the ``self`` or ``cls`` + argument in the first position is + not included in the "key" used for caching. New in 1.6. + """ - cache = [None] - - def decorate(func): - namespace = util.func_namespace(func) - def cached(*args): - reg = cache_regions[region] - if not reg.get('enabled', True): - return func(*args) - - if not cache[0]: - if region not in cache_regions: - raise BeakerException('Cache region not configured: %s' % region) - cache[0] = Cache._get_cache(namespace, reg) - - cache_key = " ".join(map(str, deco_args + args)) - def go(): - return func(*args) - - return cache[0].get_value(cache_key, createfunc=go) - cached._arg_namespace = namespace - cached._arg_region = region - return cached - return decorate + return _cache_decorate(args, None, None, region) def region_invalidate(namespace, region, *args): - """Invalidate a cache region namespace or decorated function - - This function only invalidates cache spaces created with the - cache_region decorator. - - :param namespace: Either the namespace of the result to invalidate, or the - cached function reference - - :param region: The region the function was cached to. If the function was - cached to a single region then this argument can be None - - :param args: Arguments that were used to differentiate the cached - function as well as the arguments passed to the decorated - function + """Invalidate a cache region corresponding to a function + decorated with :func:`.cache_region`. + + :param namespace: The namespace of the cache to invalidate. This is typically + a reference to the original function (as returned by the :func:`.cache_region` + decorator), where the :func:`.cache_region` decorator applies a "memo" to + the function in order to locate the string name of the namespace. + + :param region: String name of the region used with the decorator. This can be + ``None`` in the usual case that the decorated function itself is passed, + not the string name of the namespace. + + :param args: Stringifyable arguments that are used to locate the correct + key. This consists of the ``*args`` sent to the :func:`.cache_region` + decorator itself, plus the ``*args`` sent to the function itself + at runtime. Example:: - - # Add cache region settings to beaker: - beaker.cache.cache_regions.update(dict_of_config_region_options)) - - def populate_things(invalidate=False): - + + from beaker.cache import cache_regions, cache_region, region_invalidate + + # configure regions + cache_regions.update({ + 'short_term':{ + 'expire':'60', + 'type':'memory' + } + }) + + @cache_region('short_term', 'load_data') + def load(search_term, limit, offset): + '''Load from a database given a search term, limit, offset.''' + return database.query(search_term)[offset:offset + limit] + + def invalidate_search(search_term, limit, offset): + '''Invalidate the cached storage for a given search term, limit, offset.''' + region_invalidate(load, 'short_term', 'load_data', search_term, limit, offset) + + Note that when a method on a class is decorated, the first argument ``cls`` + or ``self`` is not included in the cache key. This means you don't send + it to :func:`.region_invalidate`:: + + class MyThing(object): @cache_region('short_term', 'some_data') - def load(search_term, limit, offset): - return load_the_data(search_term, limit, offset) - - # If the results should be invalidated first - if invalidate: - region_invalidate(load, None, 'some_data', - 'rabbits', 20, 0) - return load('rabbits', 20, 0) - + def load(self, search_term, limit, offset): + '''Load from a database given a search term, limit, offset.''' + return database.query(search_term)[offset:offset + limit] + + def invalidate_search(self, search_term, limit, offset): + '''Invalidate the cached storage for a given search term, limit, offset.''' + region_invalidate(self.load, 'short_term', 'some_data', search_term, limit, offset) + """ if callable(namespace): if not region: @@ -162,10 +254,9 @@ def region_invalidate(namespace, region, *args): "namespace is required") else: region = cache_regions[region] - + cache = Cache._get_cache(namespace, region) - cache_key = " ".join(str(x) for x in args) - cache.remove_value(cache_key) + _cache_decorator_invalidate(cache, region['key_length'], args) class Cache(object): @@ -180,7 +271,7 @@ class Cache(object): :param expiretime: seconds to keep cached data (legacy support) :param starttime: time when cache was cache was - + """ def __init__(self, namespace, type='memory', expiretime=None, starttime=None, expire=None, **nsargs): @@ -190,12 +281,12 @@ class Cache(object): raise cls except KeyError: raise TypeError("Unknown cache implementation %r" % type) - + self.namespace_name = namespace self.namespace = cls(namespace, **nsargs) self.expiretime = expiretime or expire self.starttime = starttime self.nsargs = nsargs - + @classmethod def _get_cache(cls, namespace, kw): key = namespace + str(kw) @@ -204,20 +295,19 @@ class Cache(object): except KeyError: cache_managers[key] = cache = cls(namespace, **kw) return cache - + def put(self, key, value, **kw): self._get_value(key, **kw).set_value(value) set_value = put - + def get(self, key, **kw): """Retrieve a cached value from the container""" return self._get_value(key, **kw).get_value() get_value = get - + def remove_value(self, key, **kw): mycontainer = self._get_value(key, **kw) - if mycontainer.has_current_value(): - mycontainer.clear_value() + mycontainer.clear_value() remove = remove_value def _get_value(self, key, **kw): @@ -229,9 +319,9 @@ class Cache(object): kw.setdefault('expiretime', self.expiretime) kw.setdefault('starttime', self.starttime) - + return container.Value(key, self.namespace, **kw) - + @util.deprecated("Specifying a " "'type' and other namespace configuration with cache.get()/put()/etc. " "is deprecated. Specify 'type' and other namespace configuration to " @@ -243,26 +333,26 @@ class Cache(object): kwargs = self.nsargs.copy() kwargs.update(kw) c = Cache(self.namespace.namespace, type=type, **kwargs) - return c._get_value(key, expiretime=expiretime, createfunc=createfunc, + return c._get_value(key, expiretime=expiretime, createfunc=createfunc, starttime=starttime) - + def clear(self): """Clear all the values from the namespace""" self.namespace.remove() - + # dict interface def __getitem__(self, key): return self.get(key) - + def __contains__(self, key): return self._get_value(key).has_current_value() - + def has_key(self, key): return key in self - + def __delitem__(self, key): self.remove_value(key) - + def __setitem__(self, key, value): self.put(key, value) @@ -270,110 +360,96 @@ class Cache(object): class CacheManager(object): def __init__(self, **kwargs): """Initialize a CacheManager object with a set of options - + Options should be parsed with the :func:`~beaker.util.parse_cache_config_options` function to ensure only valid options are used. - + """ self.kwargs = kwargs self.regions = kwargs.pop('cache_regions', {}) - + # Add these regions to the module global cache_regions.update(self.regions) - + def get_cache(self, name, **kwargs): kw = self.kwargs.copy() kw.update(kwargs) return Cache._get_cache(name, kw) - + def get_cache_region(self, name, region): if region not in self.regions: raise BeakerException('Cache region not configured: %s' % region) kw = self.regions[region] return Cache._get_cache(name, kw) - + def region(self, region, *args): """Decorate a function to cache itself using a cache region - + The region decorator requires arguments if there are more than - 2 of the same named function, in the same module. This is + two of the same named function, in the same module. This is because the namespace used for the functions cache is based on the functions name and the module. - - + + Example:: - + # Assuming a cache object is available like: cache = CacheManager(dict_of_config_options) - - + + def populate_things(): - + @cache.region('short_term', 'some_data') def load(search_term, limit, offset): return load_the_data(search_term, limit, offset) - + return load('rabbits', 20, 0) - + .. note:: - + The function being decorated must only be called with positional arguments. - + """ return cache_region(region, *args) def region_invalidate(self, namespace, region, *args): """Invalidate a cache region namespace or decorated function - + This function only invalidates cache spaces created with the cache_region decorator. - + :param namespace: Either the namespace of the result to invalidate, or the - name of the cached function - + cached function + :param region: The region the function was cached to. If the function was cached to a single region then this argument can be None - + :param args: Arguments that were used to differentiate the cached function as well as the arguments passed to the decorated function Example:: - + # Assuming a cache object is available like: cache = CacheManager(dict_of_config_options) - + def populate_things(invalidate=False): - + @cache.region('short_term', 'some_data') def load(search_term, limit, offset): return load_the_data(search_term, limit, offset) - + # If the results should be invalidated first if invalidate: cache.region_invalidate(load, None, 'some_data', 'rabbits', 20, 0) return load('rabbits', 20, 0) - - + + """ return region_invalidate(namespace, region, *args) - if callable(namespace): - if not region: - region = namespace._arg_region - namespace = namespace._arg_namespace - - if not region: - raise BeakerException("Region or callable function " - "namespace is required") - else: - region = self.regions[region] - - cache = self.get_cache(namespace, **region) - cache_key = " ".join(str(x) for x in args) - cache.remove_value(cache_key) def cache(self, *args, **kwargs): """Decorate a function to cache itself with supplied parameters @@ -387,46 +463,32 @@ class CacheManager(object): # Assuming a cache object is available like: cache = CacheManager(dict_of_config_options) - - + + def populate_things(): - + @cache.cache('mycache', expire=15) def load(search_term, limit, offset): return load_the_data(search_term, limit, offset) - + return load('rabbits', 20, 0) - + .. note:: - + The function being decorated must only be called with - positional arguments. + positional arguments. """ - cache = [None] - key = " ".join(str(x) for x in args) - - def decorate(func): - namespace = util.func_namespace(func) - def cached(*args): - if not cache[0]: - cache[0] = self.get_cache(namespace, **kwargs) - cache_key = key + " " + " ".join(str(x) for x in args) - def go(): - return func(*args) - return cache[0].get_value(cache_key, createfunc=go) - cached._arg_namespace = namespace - return cached - return decorate + return _cache_decorate(args, self, kwargs, None) def invalidate(self, func, *args, **kwargs): """Invalidate a cache decorated function - + This function only invalidates cache spaces created with the cache decorator. - + :param func: Decorated function to invalidate - + :param args: Used to make the key unique for this function, as in region() above. @@ -435,25 +497,93 @@ class CacheManager(object): function Example:: - + # Assuming a cache object is available like: cache = CacheManager(dict_of_config_options) - - + + def populate_things(invalidate=False): - + @cache.cache('mycache', type="file", expire=15) def load(search_term, limit, offset): return load_the_data(search_term, limit, offset) - + # If the results should be invalidated first if invalidate: cache.invalidate(load, 'mycache', 'rabbits', 20, 0, type="file") return load('rabbits', 20, 0) - + """ namespace = func._arg_namespace cache = self.get_cache(namespace, **kwargs) - cache_key = " ".join(str(x) for x in args) - cache.remove_value(cache_key) + if hasattr(func, '_arg_region'): + key_length = cache_regions[func._arg_region]['key_length'] + else: + key_length = kwargs.pop('key_length', 250) + _cache_decorator_invalidate(cache, key_length, args) + + +def _cache_decorate(deco_args, manager, kwargs, region): + """Return a caching function decorator.""" + + cache = [None] + + def decorate(func): + namespace = util.func_namespace(func) + skip_self = util.has_self_arg(func) + + def cached(*args): + if not cache[0]: + if region is not None: + if region not in cache_regions: + raise BeakerException( + 'Cache region not configured: %s' % region) + reg = cache_regions[region] + if not reg.get('enabled', True): + return func(*args) + cache[0] = Cache._get_cache(namespace, reg) + elif manager: + cache[0] = manager.get_cache(namespace, **kwargs) + else: + raise Exception("'manager + kwargs' or 'region' " + "argument is required") + + if skip_self: + try: + cache_key = " ".join(map(str, deco_args + args[1:])) + except UnicodeEncodeError: + cache_key = " ".join(map(unicode, deco_args + args[1:])) + else: + try: + cache_key = " ".join(map(str, deco_args + args)) + except UnicodeEncodeError: + cache_key = " ".join(map(unicode, deco_args + args)) + if region: + key_length = cache_regions[region]['key_length'] + else: + key_length = kwargs.pop('key_length', 250) + if len(cache_key) + len(namespace) > int(key_length): + cache_key = sha1(cache_key).hexdigest() + + def go(): + return func(*args) + + return cache[0].get_value(cache_key, createfunc=go) + cached._arg_namespace = namespace + if region is not None: + cached._arg_region = region + return cached + return decorate + + +def _cache_decorator_invalidate(cache, key_length, args): + """Invalidate a cache key based on function arguments.""" + + try: + cache_key = " ".join(map(str, args)) + except UnicodeEncodeError: + cache_key = " ".join(map(unicode, args)) + if len(cache_key) + len(cache.namespace_name) > key_length: + cache_key = sha1(cache_key).hexdigest() + cache.remove_value(cache_key) diff --git a/module/lib/beaker/container.py b/module/lib/beaker/container.py index 515e97af6..5a2e8e75c 100644 --- a/module/lib/beaker/container.py +++ b/module/lib/beaker/container.py @@ -1,11 +1,18 @@ """Container and Namespace classes""" -import anydbm + +import beaker.util as util +if util.py3k: + try: + import dbm as anydbm + except: + import dumbdbm as anydbm +else: + import anydbm import cPickle import logging import os import time -import beaker.util as util from beaker.exceptions import CreationAbortedError, MissingCacheParameter from beaker.synchronization import _threading, file_synchronizer, \ mutex_synchronizer, NameLock, null_synchronizer @@ -28,88 +35,162 @@ else: class NamespaceManager(object): """Handles dictionary operations and locking for a namespace of values. - + + :class:`.NamespaceManager` provides a dictionary-like interface, + implementing ``__getitem__()``, ``__setitem__()``, and + ``__contains__()``, as well as functions related to lock + acquisition. + The implementation for setting and retrieving the namespace data is handled by subclasses. - - NamespaceManager may be used alone, or may be privately accessed by - one or more Container objects. Container objects provide per-key + + NamespaceManager may be used alone, or may be accessed by + one or more :class:`.Value` objects. :class:`.Value` objects provide per-key services like expiration times and automatic recreation of values. - + Multiple NamespaceManagers created with a particular name will all share access to the same underlying datasource and will attempt to synchronize against a common mutex object. The scope of this sharing may be within a single process or across multiple processes, depending on the type of NamespaceManager used. - + The NamespaceManager itself is generally threadsafe, except in the case of the DBMNamespaceManager in conjunction with the gdbm dbm implementation. """ - + @classmethod def _init_dependencies(cls): - pass - + """Initialize module-level dependent libraries required + by this :class:`.NamespaceManager`.""" + def __init__(self, namespace): self._init_dependencies() self.namespace = namespace - + def get_creation_lock(self, key): + """Return a locking object that is used to synchronize + multiple threads or processes which wish to generate a new + cache value. + + This function is typically an instance of + :class:`.FileSynchronizer`, :class:`.ConditionSynchronizer`, + or :class:`.null_synchronizer`. + + The creation lock is only used when a requested value + does not exist, or has been expired, and is only used + by the :class:`.Value` key-management object in conjunction + with a "createfunc" value-creation function. + + """ raise NotImplementedError() def do_remove(self): + """Implement removal of the entire contents of this + :class:`.NamespaceManager`. + + e.g. for a file-based namespace, this would remove + all the files. + + The front-end to this method is the + :meth:`.NamespaceManager.remove` method. + + """ raise NotImplementedError() def acquire_read_lock(self): - pass + """Establish a read lock. + + This operation is called before a key is read. By + default the function does nothing. + + """ def release_read_lock(self): - pass + """Release a read lock. + + This operation is called after a key is read. By + default the function does nothing. + + """ + + def acquire_write_lock(self, wait=True, replace=False): + """Establish a write lock. + + This operation is called before a key is written. + A return value of ``True`` indicates the lock has + been acquired. - def acquire_write_lock(self, wait=True): + By default the function returns ``True`` unconditionally. + + 'replace' is a hint indicating the full contents + of the namespace may be safely discarded. Some backends + may implement this (i.e. file backend won't unpickle the + current contents). + + """ return True def release_write_lock(self): - pass + """Release a write lock. + + This operation is called after a new value is written. + By default this function does nothing. + + """ def has_key(self, key): + """Return ``True`` if the given key is present in this + :class:`.Namespace`. + """ return self.__contains__(key) def __getitem__(self, key): raise NotImplementedError() - + def __setitem__(self, key, value): raise NotImplementedError() - + def set_value(self, key, value, expiretime=None): - """Optional set_value() method called by Value. - - Allows an expiretime to be passed, for namespace - implementations which can prune their collections - using expiretime. - + """Sets a value in this :class:`.NamespaceManager`. + + This is the same as ``__setitem__()``, but + also allows an expiration time to be passed + at the same time. + """ self[key] = value - + def __contains__(self, key): raise NotImplementedError() def __delitem__(self, key): raise NotImplementedError() - + def keys(self): + """Return the list of all keys. + + This method may not be supported by all + :class:`.NamespaceManager` implementations. + + """ raise NotImplementedError() - + def remove(self): + """Remove the entire contents of this + :class:`.NamespaceManager`. + + e.g. for a file-based namespace, this would remove + all the files. + """ self.do_remove() - + class OpenResourceNamespaceManager(NamespaceManager): """A NamespaceManager where read/write operations require opening/ closing of a resource which is possibly mutexed. - + """ def __init__(self, namespace): NamespaceManager.__init__(self, namespace) @@ -120,51 +201,51 @@ class OpenResourceNamespaceManager(NamespaceManager): def get_access_lock(self): raise NotImplementedError() - def do_open(self, flags): + def do_open(self, flags, replace): raise NotImplementedError() - def do_close(self): + def do_close(self): raise NotImplementedError() - def acquire_read_lock(self): + def acquire_read_lock(self): self.access_lock.acquire_read_lock() try: - self.open('r', checkcount = True) + self.open('r', checkcount=True) except: self.access_lock.release_read_lock() raise - + def release_read_lock(self): try: - self.close(checkcount = True) + self.close(checkcount=True) finally: self.access_lock.release_read_lock() - - def acquire_write_lock(self, wait=True): + + def acquire_write_lock(self, wait=True, replace=False): r = self.access_lock.acquire_write_lock(wait) try: - if (wait or r): - self.open('c', checkcount = True) + if (wait or r): + self.open('c', checkcount=True, replace=replace) return r except: self.access_lock.release_write_lock() raise - - def release_write_lock(self): + + def release_write_lock(self): try: self.close(checkcount=True) finally: self.access_lock.release_write_lock() - def open(self, flags, checkcount=False): + def open(self, flags, checkcount=False, replace=False): self.mutex.acquire() try: if checkcount: - if self.openers == 0: - self.do_open(flags) + if self.openers == 0: + self.do_open(flags, replace) self.openers += 1 else: - self.do_open(flags) + self.do_open(flags, replace) self.openers = 1 finally: self.mutex.release() @@ -174,7 +255,7 @@ class OpenResourceNamespaceManager(NamespaceManager): try: if checkcount: self.openers -= 1 - if self.openers == 0: + if self.openers == 0: self.do_close() else: if self.openers > 0: @@ -191,7 +272,13 @@ class OpenResourceNamespaceManager(NamespaceManager): finally: self.access_lock.release_write_lock() + class Value(object): + """Implements synchronization, expiration, and value-creation logic + for a single value stored in a :class:`.NamespaceManager`. + + """ + __slots__ = 'key', 'createfunc', 'expiretime', 'expire_argument', 'starttime', 'storedtime',\ 'namespace' @@ -210,18 +297,18 @@ class Value(object): """ self.namespace.acquire_read_lock() - try: - return self.namespace.has_key(self.key) + try: + return self.key in self.namespace finally: self.namespace.release_read_lock() def can_have_value(self): - return self.has_current_value() or self.createfunc is not None + return self.has_current_value() or self.createfunc is not None def has_current_value(self): self.namespace.acquire_read_lock() - try: - has_value = self.namespace.has_key(self.key) + try: + has_value = self.key in self.namespace if has_value: try: stored, expired, value = self._get_value() @@ -258,7 +345,7 @@ class Value(object): except KeyError: # guard against un-mutexed backends raising KeyError has_value = False - + if not self.createfunc: raise KeyError(self.key) finally: @@ -317,10 +404,10 @@ class Value(object): self.set_value(value, stored) self.namespace.acquire_read_lock() except TypeError: - # occurs when the value is None. memcached - # may yank the rug from under us in which case + # occurs when the value is None. memcached + # may yank the rug from under us in which case # that's the result - raise KeyError(self.key) + raise KeyError(self.key) return stored, expired, value def set_value(self, value, storedtime=None): @@ -337,7 +424,7 @@ class Value(object): self.namespace.acquire_write_lock() try: debug("clear_value") - if self.namespace.has_key(self.key): + if self.key in self.namespace: try: del self.namespace[self.key] except KeyError: @@ -347,71 +434,80 @@ class Value(object): finally: self.namespace.release_write_lock() + class AbstractDictionaryNSManager(NamespaceManager): """A subclassable NamespaceManager that places data in a dictionary. - + Subclasses should provide a "dictionary" attribute or descriptor which returns a dict-like object. The dictionary will store keys that are local to the "namespace" attribute of this manager, so ensure that the dictionary will not be used by any other namespace. e.g.:: - + import collections cached_data = collections.defaultdict(dict) - + class MyDictionaryManager(AbstractDictionaryNSManager): def __init__(self, namespace): AbstractDictionaryNSManager.__init__(self, namespace) self.dictionary = cached_data[self.namespace] - + The above stores data in a global dictionary called "cached_data", which is structured as a dictionary of dictionaries, keyed first on namespace name to a sub-dictionary, then on actual cache key to value. - + """ - + def get_creation_lock(self, key): return NameLock( - identifier="memorynamespace/funclock/%s/%s" % (self.namespace, key), + identifier="memorynamespace/funclock/%s/%s" % + (self.namespace, key), reentrant=True ) - def __getitem__(self, key): + def __getitem__(self, key): return self.dictionary[key] - def __contains__(self, key): + def __contains__(self, key): return self.dictionary.__contains__(key) - def has_key(self, key): + def has_key(self, key): return self.dictionary.__contains__(key) - + def __setitem__(self, key, value): self.dictionary[key] = value - + def __delitem__(self, key): del self.dictionary[key] def do_remove(self): self.dictionary.clear() - + def keys(self): return self.dictionary.keys() - + + class MemoryNamespaceManager(AbstractDictionaryNSManager): + """:class:`.NamespaceManager` that uses a Python dictionary for storage.""" + namespaces = util.SyncDict() def __init__(self, namespace, **kwargs): AbstractDictionaryNSManager.__init__(self, namespace) - self.dictionary = MemoryNamespaceManager.namespaces.get(self.namespace, - dict) + self.dictionary = MemoryNamespaceManager.\ + namespaces.get(self.namespace, dict) + class DBMNamespaceManager(OpenResourceNamespaceManager): - def __init__(self, namespace, dbmmodule=None, data_dir=None, - dbm_dir=None, lock_dir=None, digest_filenames=True, **kwargs): + """:class:`.NamespaceManager` that uses ``dbm`` files for storage.""" + + def __init__(self, namespace, dbmmodule=None, data_dir=None, + dbm_dir=None, lock_dir=None, + digest_filenames=True, **kwargs): self.digest_filenames = digest_filenames - + if not dbm_dir and not data_dir: raise MissingCacheParameter("data_dir or dbm_dir is required") elif dbm_dir: @@ -419,7 +515,7 @@ class DBMNamespaceManager(OpenResourceNamespaceManager): else: self.dbm_dir = data_dir + "/container_dbm" util.verify_directory(self.dbm_dir) - + if not lock_dir and not data_dir: raise MissingCacheParameter("data_dir or lock_dir is required") elif lock_dir: @@ -433,50 +529,52 @@ class DBMNamespaceManager(OpenResourceNamespaceManager): self.dbm = None OpenResourceNamespaceManager.__init__(self, namespace) - self.file = util.encoded_path(root= self.dbm_dir, + self.file = util.encoded_path(root=self.dbm_dir, identifiers=[self.namespace], extension='.dbm', digest_filenames=self.digest_filenames) - + debug("data file %s", self.file) self._checkfile() def get_access_lock(self): return file_synchronizer(identifier=self.namespace, lock_dir=self.lock_dir) - + def get_creation_lock(self, key): return file_synchronizer( - identifier = "dbmcontainer/funclock/%s" % self.namespace, + identifier="dbmcontainer/funclock/%s/%s" % ( + self.namespace, key + ), lock_dir=self.lock_dir ) def file_exists(self, file): - if os.access(file, os.F_OK): + if os.access(file, os.F_OK): return True else: for ext in ('db', 'dat', 'pag', 'dir'): if os.access(file + os.extsep + ext, os.F_OK): return True - + return False - + def _checkfile(self): if not self.file_exists(self.file): - g = self.dbmmodule.open(self.file, 'c') + g = self.dbmmodule.open(self.file, 'c') g.close() - + def get_filenames(self): list = [] if os.access(self.file, os.F_OK): list.append(self.file) - + for ext in ('pag', 'dir', 'db', 'dat'): if os.access(self.file + os.extsep + ext, os.F_OK): list.append(self.file + os.extsep + ext) return list - def do_open(self, flags): + def do_open(self, flags, replace): debug("opening dbm file %s", self.file) try: self.dbm = self.dbmmodule.open(self.file, flags) @@ -488,17 +586,17 @@ class DBMNamespaceManager(OpenResourceNamespaceManager): if self.dbm is not None: debug("closing dbm file %s", self.file) self.dbm.close() - + def do_remove(self): for f in self.get_filenames(): os.remove(f) - - def __getitem__(self, key): + + def __getitem__(self, key): return cPickle.loads(self.dbm[key]) - def __contains__(self, key): - return self.dbm.has_key(key) - + def __contains__(self, key): + return key in self.dbm + def __setitem__(self, key, value): self.dbm[key] = cPickle.dumps(value) @@ -510,10 +608,17 @@ class DBMNamespaceManager(OpenResourceNamespaceManager): class FileNamespaceManager(OpenResourceNamespaceManager): + """:class:`.NamespaceManager` that uses binary files for storage. + + Each namespace is implemented as a single file storing a + dictionary of key/value pairs, serialized using the Python + ``pickle`` module. + + """ def __init__(self, namespace, data_dir=None, file_dir=None, lock_dir=None, digest_filenames=True, **kwargs): self.digest_filenames = digest_filenames - + if not file_dir and not data_dir: raise MissingCacheParameter("data_dir or file_dir is required") elif file_dir: @@ -531,38 +636,37 @@ class FileNamespaceManager(OpenResourceNamespaceManager): util.verify_directory(self.lock_dir) OpenResourceNamespaceManager.__init__(self, namespace) - self.file = util.encoded_path(root=self.file_dir, + self.file = util.encoded_path(root=self.file_dir, identifiers=[self.namespace], extension='.cache', digest_filenames=self.digest_filenames) self.hash = {} - + debug("data file %s", self.file) def get_access_lock(self): return file_synchronizer(identifier=self.namespace, lock_dir=self.lock_dir) - + def get_creation_lock(self, key): return file_synchronizer( - identifier = "filecontainer/funclock/%s" % self.namespace, - lock_dir = self.lock_dir + identifier="dbmcontainer/funclock/%s/%s" % ( + self.namespace, key + ), + lock_dir=self.lock_dir ) - + def file_exists(self, file): return os.access(file, os.F_OK) - def do_open(self, flags): - if self.file_exists(self.file): + def do_open(self, flags, replace): + if not replace and self.file_exists(self.file): fh = open(self.file, 'rb') - try: - self.hash = cPickle.load(fh) - except (IOError, OSError, EOFError, cPickle.PickleError, ValueError): - pass + self.hash = cPickle.load(fh) fh.close() self.flags = flags - + def do_close(self): if self.flags == 'c' or self.flags == 'w': fh = open(self.file, 'wb') @@ -571,22 +675,22 @@ class FileNamespaceManager(OpenResourceNamespaceManager): self.hash = {} self.flags = None - + def do_remove(self): try: os.remove(self.file) - except OSError, err: + except OSError: # for instance, because we haven't yet used this cache, # but client code has asked for a clear() operation... pass self.hash = {} - - def __getitem__(self, key): + + def __getitem__(self, key): return self.hash[key] - def __contains__(self, key): - return self.hash.has_key(key) - + def __contains__(self, key): + return key in self.hash + def __setitem__(self, key, value): self.hash[key] = value @@ -602,11 +706,13 @@ class FileNamespaceManager(OpenResourceNamespaceManager): namespace_classes = {} ContainerContext = dict - + + class ContainerMeta(type): def __init__(cls, classname, bases, dict_): namespace_classes[cls] = cls.namespace_class return type.__init__(cls, classname, bases, dict_) + def __call__(self, key, context, namespace, createfunc=None, expiretime=None, starttime=None, **kwargs): if namespace in context: @@ -617,16 +723,27 @@ class ContainerMeta(type): return Value(key, ns, createfunc=createfunc, expiretime=expiretime, starttime=starttime) + class Container(object): + """Implements synchronization and value-creation logic + for a 'value' stored in a :class:`.NamespaceManager`. + + :class:`.Container` and its subclasses are deprecated. The + :class:`.Value` class is now used for this purpose. + + """ __metaclass__ = ContainerMeta namespace_class = NamespaceManager + class FileContainer(Container): namespace_class = FileNamespaceManager + class MemoryContainer(Container): namespace_class = MemoryNamespaceManager + class DBMContainer(Container): namespace_class = DBMNamespaceManager diff --git a/module/lib/beaker/converters.py b/module/lib/beaker/converters.py index f0ad34963..3fb80692f 100644 --- a/module/lib/beaker/converters.py +++ b/module/lib/beaker/converters.py @@ -1,3 +1,5 @@ + + # (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org) # Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php def asbool(obj): @@ -12,6 +14,7 @@ def asbool(obj): "String is not true/false: %r" % obj) return bool(obj) + def aslist(obj, sep=None, strip=True): if isinstance(obj, (str, unicode)): lst = obj.split(sep) diff --git a/module/lib/beaker/crypto/__init__.py b/module/lib/beaker/crypto/__init__.py index 3e26b0c13..ac13da527 100644 --- a/module/lib/beaker/crypto/__init__.py +++ b/module/lib/beaker/crypto/__init__.py @@ -14,10 +14,14 @@ if util.jython: pass else: try: - from beaker.crypto.pycrypto import getKeyLength, aesEncrypt, aesDecrypt + from beaker.crypto.nsscrypto import getKeyLength, aesEncrypt, aesDecrypt keyLength = getKeyLength() except ImportError: - pass + try: + from beaker.crypto.pycrypto import getKeyLength, aesEncrypt, aesDecrypt + keyLength = getKeyLength() + except ImportError: + pass if not keyLength: has_aes = False diff --git a/module/lib/beaker/crypto/jcecrypto.py b/module/lib/beaker/crypto/jcecrypto.py index 4062d513e..ce313d6e1 100644 --- a/module/lib/beaker/crypto/jcecrypto.py +++ b/module/lib/beaker/crypto/jcecrypto.py @@ -16,6 +16,7 @@ import jarray # Initialization vector filled with zeros _iv = IvParameterSpec(jarray.zeros(16, 'b')) + def aesEncrypt(data, key): cipher = Cipher.getInstance('AES/CTR/NoPadding') skeySpec = SecretKeySpec(key, 'AES') @@ -25,6 +26,7 @@ def aesEncrypt(data, key): # magic. aesDecrypt = aesEncrypt + def getKeyLength(): maxlen = Cipher.getMaxAllowedKeyLength('AES/CTR/NoPadding') return min(maxlen, 256) / 8 diff --git a/module/lib/beaker/crypto/nsscrypto.py b/module/lib/beaker/crypto/nsscrypto.py new file mode 100644 index 000000000..3a7797877 --- /dev/null +++ b/module/lib/beaker/crypto/nsscrypto.py @@ -0,0 +1,45 @@ +"""Encryption module that uses nsscrypto""" +import nss.nss + +nss.nss.nss_init_nodb() + +# Apparently the rest of beaker doesn't care about the particluar cipher, +# mode and padding used. +# NOTE: A constant IV!!! This is only secure if the KEY is never reused!!! +_mech = nss.nss.CKM_AES_CBC_PAD +_iv = '\0' * nss.nss.get_iv_length(_mech) + +def aesEncrypt(data, key): + slot = nss.nss.get_best_slot(_mech) + + key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated, + nss.nss.CKA_ENCRYPT, nss.nss.SecItem(key)) + + param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv)) + ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_ENCRYPT, key_obj, + param) + l1 = ctx.cipher_op(data) + # Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including + # me :( ) cares enough. + l2 = ctx.digest_final() + + return l1 + l2 + +def aesDecrypt(data, key): + slot = nss.nss.get_best_slot(_mech) + + key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated, + nss.nss.CKA_DECRYPT, nss.nss.SecItem(key)) + + param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv)) + ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_DECRYPT, key_obj, + param) + l1 = ctx.cipher_op(data) + # Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including + # me :( ) cares enough. + l2 = ctx.digest_final() + + return l1 + l2 + +def getKeyLength(): + return 32 diff --git a/module/lib/beaker/crypto/pbkdf2.py b/module/lib/beaker/crypto/pbkdf2.py index 96dc5fbb2..71df22198 100644 --- a/module/lib/beaker/crypto/pbkdf2.py +++ b/module/lib/beaker/crypto/pbkdf2.py @@ -5,22 +5,22 @@ # # Copyright (C) 2007 Dwayne C. Litzenberger <dlitz@dlitz.net> # All rights reserved. -# +# # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose and without fee is hereby granted, # provided that the above copyright notice appear in all copies and that # both that copyright notice and this permission notice appear in # supporting documentation. -# -# THE AUTHOR PROVIDES THIS SOFTWARE ``AS IS'' AND ANY EXPRESSED OR -# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES -# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. -# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +# +# THE AUTHOR PROVIDES THIS SOFTWARE ``AS IS'' AND ANY EXPRESSED OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT -# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # Country of origin: Canada @@ -74,12 +74,14 @@ from base64 import b64encode from beaker.crypto.util import hmac as HMAC, hmac_sha1 as SHA1 + def strxor(a, b): return "".join([chr(ord(x) ^ ord(y)) for (x, y) in zip(a, b)]) + class PBKDF2(object): """PBKDF2.py : PKCS#5 v2.0 Password-Based Key Derivation - + This implementation takes a passphrase and a salt (and optionally an iteration count, a digest module, and a MAC module) and provides a file-like object from which an arbitrarily-sized key can be read. @@ -89,10 +91,10 @@ class PBKDF2(object): The idea behind PBKDF2 is to derive a cryptographic key from a passphrase and a salt. - + PBKDF2 may also be used as a strong salted password hash. The 'crypt' function is provided for that purpose. - + Remember: Keys generated using PBKDF2 are only as strong as the passphrases they are derived from. """ @@ -109,7 +111,7 @@ class PBKDF2(object): """Pseudorandom function. e.g. HMAC-SHA1""" return self.__macmodule(key=key, msg=msg, digestmod=self.__digestmodule).digest() - + def read(self, bytes): """Read the specified number of key bytes.""" if self.closed: @@ -121,7 +123,7 @@ class PBKDF2(object): while size < bytes: i += 1 if i > 0xffffffff: - # We could return "" here, but + # We could return "" here, but raise OverflowError("derived key too long") block = self.__f(i) blocks.append(block) @@ -131,17 +133,17 @@ class PBKDF2(object): self.__buf = buf[bytes:] self.__blockNum = i return retval - + def __f(self, i): # i must fit within 32 bits - assert (1 <= i <= 0xffffffff) + assert (1 <= i and i <= 0xffffffff) U = self.__prf(self.__passphrase, self.__salt + pack("!L", i)) result = U - for j in xrange(2, 1+self.__iterations): + for j in xrange(2, 1 + self.__iterations): U = self.__prf(self.__passphrase, U) result = strxor(result, U) return result - + def hexread(self, octets): """Read the specified number of octets. Return them as hexadecimal. @@ -151,7 +153,7 @@ class PBKDF2(object): def _setup(self, passphrase, salt, iterations, prf): # Sanity checks: - + # passphrase and salt must be str or unicode (in the latter # case, we convert to UTF-8) if isinstance(passphrase, unicode): @@ -168,7 +170,7 @@ class PBKDF2(object): raise TypeError("iterations must be an integer") if iterations < 1: raise ValueError("iterations must be at least 1") - + # prf must be callable if not callable(prf): raise TypeError("prf must be callable") @@ -180,7 +182,7 @@ class PBKDF2(object): self.__blockNum = 0 self.__buf = "" self.closed = False - + def close(self): """Close the stream.""" if not self.closed: @@ -192,15 +194,16 @@ class PBKDF2(object): del self.__buf self.closed = True + def crypt(word, salt=None, iterations=None): """PBKDF2-based unix crypt(3) replacement. - + The number of iterations specified in the salt overrides the 'iterations' parameter. The effective hash length is 192 bits. """ - + # Generate a (pseudo-)random salt if the user hasn't provided one. if salt is None: salt = _makesalt() @@ -229,7 +232,7 @@ def crypt(word, salt=None, iterations=None): iterations = converted if not (iterations >= 1): raise ValueError("Invalid salt") - + # Make sure the salt matches the allowed character set allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789./" for ch in salt: @@ -249,18 +252,20 @@ def crypt(word, salt=None, iterations=None): # crypt. PBKDF2.crypt = staticmethod(crypt) + def _makesalt(): """Return a 48-bit pseudorandom salt for crypt(). - + This function is not suitable for generating cryptographic secrets. """ binarysalt = "".join([pack("@H", randint(0, 0xffff)) for i in range(3)]) return b64encode(binarysalt, "./") + def test_pbkdf2(): """Module self-test""" from binascii import a2b_hex - + # # Test vectors from RFC 3962 # @@ -279,23 +284,23 @@ def test_pbkdf2(): raise RuntimeError("self-test failed") # Test 3 - result = PBKDF2("X"*64, "pass phrase equals block size", 1200).hexread(32) + result = PBKDF2("X" * 64, "pass phrase equals block size", 1200).hexread(32) expected = ("139c30c0966bc32ba55fdbf212530ac9" "c5ec59f1a452f5cc9ad940fea0598ed1") if result != expected: raise RuntimeError("self-test failed") - + # Test 4 - result = PBKDF2("X"*65, "pass phrase exceeds block size", 1200).hexread(32) + result = PBKDF2("X" * 65, "pass phrase exceeds block size", 1200).hexread(32) expected = ("9ccad6d468770cd51b10e6a68721be61" "1a8b4d282601db3b36be9246915ec82a") if result != expected: raise RuntimeError("self-test failed") - + # # Other test vectors # - + # Chunked read f = PBKDF2("kickstart", "workbench", 256) result = f.read(17) @@ -306,7 +311,7 @@ def test_pbkdf2(): expected = PBKDF2("kickstart", "workbench", 256).read(40) if result != expected: raise RuntimeError("self-test failed") - + # # crypt() test vectors # @@ -316,7 +321,7 @@ def test_pbkdf2(): expected = '$p5k2$$exec$r1EWMCMk7Rlv3L/RNcFXviDefYa0hlql' if result != expected: raise RuntimeError("self-test failed") - + # crypt 2 result = crypt("gnu", '$p5k2$c$u9HvcT4d$.....') expected = '$p5k2$c$u9HvcT4d$Sd1gwSVCLZYAuqZ25piRnbBEoAesaa/g' @@ -328,7 +333,7 @@ def test_pbkdf2(): expected = "$p5k2$d$tUsch7fU$nqDkaxMDOFBeJsTSfABsyn.PYUXilHwL" if result != expected: raise RuntimeError("self-test failed") - + # crypt 4 (unicode) result = crypt(u'\u0399\u03c9\u03b1\u03bd\u03bd\u03b7\u03c2', '$p5k2$$KosHgqNo$9mjN8gqjt02hDoP0c2J0ABtLIwtot8cQ') diff --git a/module/lib/beaker/crypto/pycrypto.py b/module/lib/beaker/crypto/pycrypto.py index a3eb4d9db..6657bff56 100644 --- a/module/lib/beaker/crypto/pycrypto.py +++ b/module/lib/beaker/crypto/pycrypto.py @@ -9,23 +9,26 @@ try: def aesEncrypt(data, key): cipher = aes.AES(key) return cipher.process(data) - + # magic. aesDecrypt = aesEncrypt - + except ImportError: from Crypto.Cipher import AES + from Crypto.Util import Counter def aesEncrypt(data, key): - cipher = AES.new(key) - - data = data + (" " * (16 - (len(data) % 16))) + cipher = AES.new(key, AES.MODE_CTR, + counter=Counter.new(128, initial_value=0)) + return cipher.encrypt(data) def aesDecrypt(data, key): - cipher = AES.new(key) + cipher = AES.new(key, AES.MODE_CTR, + counter=Counter.new(128, initial_value=0)) + return cipher.decrypt(data) + - return cipher.decrypt(data).rstrip() def getKeyLength(): return 32 diff --git a/module/lib/beaker/crypto/util.py b/module/lib/beaker/crypto/util.py index d97e8ce6f..7f96ac856 100644 --- a/module/lib/beaker/crypto/util.py +++ b/module/lib/beaker/crypto/util.py @@ -6,9 +6,9 @@ try: # Use PyCrypto (if available) from Crypto.Hash import HMAC as hmac, SHA as hmac_sha1 sha1 = hmac_sha1.new - + except ImportError: - + # PyCrypto not available. Use the Python standard library. import hmac diff --git a/module/lib/beaker/exceptions.py b/module/lib/beaker/exceptions.py index cc0eed286..4f81e456d 100644 --- a/module/lib/beaker/exceptions.py +++ b/module/lib/beaker/exceptions.py @@ -1,9 +1,14 @@ """Beaker exception classes""" + class BeakerException(Exception): pass +class BeakerWarning(RuntimeWarning): + """Issued at runtime.""" + + class CreationAbortedError(Exception): """Deprecated.""" diff --git a/module/lib/beaker/ext/database.py b/module/lib/beaker/ext/database.py index 701e6f7d2..462fb8de4 100644 --- a/module/lib/beaker/ext/database.py +++ b/module/lib/beaker/ext/database.py @@ -14,6 +14,7 @@ sa = None pool = None types = None + class DatabaseNamespaceManager(OpenResourceNamespaceManager): metadatas = SyncDict() tables = SyncDict() @@ -30,12 +31,12 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): except ImportError: raise InvalidCacheBackendError("Database cache backend requires " "the 'sqlalchemy' library") - + def __init__(self, namespace, url=None, sa_opts=None, optimistic=False, table_name='beaker_cache', data_dir=None, lock_dir=None, - **params): + schema_name=None, **params): """Creates a database namespace manager - + ``url`` SQLAlchemy compliant db url ``sa_opts`` @@ -47,9 +48,11 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): numbers. ``table_name`` The table name to use in the database for the cache. + ``schema_name`` + The schema name to use in the database for the cache. """ OpenResourceNamespaceManager.__init__(self, namespace) - + if sa_opts is None: sa_opts = params @@ -58,14 +61,16 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): elif data_dir: self.lock_dir = data_dir + "/container_db_lock" if self.lock_dir: - verify_directory(self.lock_dir) - + verify_directory(self.lock_dir) + # Check to see if the table's been created before url = url or sa_opts['sa.url'] table_key = url + table_name + def make_cache(): # Check to see if we have a connection pool open already meta_key = url + table_name + def make_meta(): # SQLAlchemy pops the url, this ensures it sticks around # later @@ -82,7 +87,8 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): sa.Column('accessed', types.DateTime, nullable=False), sa.Column('created', types.DateTime, nullable=False), sa.Column('data', types.PickleType, nullable=False), - sa.UniqueConstraint('namespace') + sa.UniqueConstraint('namespace'), + schema=schema_name if schema_name else meta.schema ) cache.create(checkfirst=True) return cache @@ -90,24 +96,26 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): self._is_new = False self.loaded = False self.cache = DatabaseNamespaceManager.tables.get(table_key, make_cache) - + def get_access_lock(self): return null_synchronizer() def get_creation_lock(self, key): return file_synchronizer( - identifier ="databasecontainer/funclock/%s" % self.namespace, - lock_dir = self.lock_dir) + identifier="databasecontainer/funclock/%s/%s" % ( + self.namespace, key + ), + lock_dir=self.lock_dir) - def do_open(self, flags): + def do_open(self, flags, replace): # If we already loaded the data, don't bother loading it again if self.loaded: self.flags = flags return - + cache = self.cache - result = sa.select([cache.c.data], - cache.c.namespace==self.namespace + result = sa.select([cache.c.data], + cache.c.namespace == self.namespace ).execute().fetchone() if not result: self._is_new = True @@ -123,7 +131,7 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): self._is_new = True self.flags = flags self.loaded = True - + def do_close(self): if self.flags is not None and (self.flags == 'c' or self.flags == 'w'): cache = self.cache @@ -133,25 +141,25 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): created=datetime.now()) self._is_new = False else: - cache.update(cache.c.namespace==self.namespace).execute( + cache.update(cache.c.namespace == self.namespace).execute( data=self.hash, accessed=datetime.now()) self.flags = None - + def do_remove(self): cache = self.cache - cache.delete(cache.c.namespace==self.namespace).execute() + cache.delete(cache.c.namespace == self.namespace).execute() self.hash = {} - + # We can retain the fact that we did a load attempt, but since the # file is gone this will be a new namespace should it be saved. self._is_new = True - def __getitem__(self, key): + def __getitem__(self, key): return self.hash[key] - def __contains__(self, key): - return self.hash.has_key(key) - + def __contains__(self, key): + return key in self.hash + def __setitem__(self, key, value): self.hash[key] = value @@ -161,5 +169,6 @@ class DatabaseNamespaceManager(OpenResourceNamespaceManager): def keys(self): return self.hash.keys() + class DatabaseContainer(Container): namespace_manager = DatabaseNamespaceManager diff --git a/module/lib/beaker/ext/google.py b/module/lib/beaker/ext/google.py index dd8380d7f..d0a6205f4 100644 --- a/module/lib/beaker/ext/google.py +++ b/module/lib/beaker/ext/google.py @@ -10,6 +10,7 @@ log = logging.getLogger(__name__) db = None + class GoogleNamespaceManager(OpenResourceNamespaceManager): tables = {} @@ -23,11 +24,11 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): except ImportError: raise InvalidCacheBackendError("Datastore cache backend requires the " "'google.appengine.ext' library") - + def __init__(self, namespace, table_name='beaker_cache', **params): """Creates a datastore namespace manager""" OpenResourceNamespaceManager.__init__(self, namespace) - + def make_cache(): table_dict = dict(created=db.DateTimeProperty(), accessed=db.DateTimeProperty(), @@ -40,11 +41,11 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): self._is_new = False self.loaded = False self.log_debug = logging.DEBUG >= log.getEffectiveLevel() - + # Google wants namespaces to start with letters, change the namespace # to start with a letter self.namespace = 'p%s' % self.namespace - + def get_access_lock(self): return null_synchronizer() @@ -52,14 +53,14 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): # this is weird, should probably be present return null_synchronizer() - def do_open(self, flags): + def do_open(self, flags, replace): # If we already loaded the data, don't bother loading it again if self.loaded: self.flags = flags return - + item = self.cache.get_by_key_name(self.namespace) - + if not item: self._is_new = True self.hash = {} @@ -74,7 +75,7 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): self._is_new = True self.flags = flags self.loaded = True - + def do_close(self): if self.flags is not None and (self.flags == 'c' or self.flags == 'w'): if self._is_new: @@ -90,12 +91,12 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): item.accessed = datetime.now() item.put() self.flags = None - + def do_remove(self): item = self.cache.get_by_key_name(self.namespace) item.delete() self.hash = {} - + # We can retain the fact that we did a load attempt, but since the # file is gone this will be a new namespace should it be saved. self._is_new = True @@ -103,9 +104,9 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): def __getitem__(self, key): return self.hash[key] - def __contains__(self, key): - return self.hash.has_key(key) - + def __contains__(self, key): + return key in self.hash + def __setitem__(self, key, value): self.hash[key] = value @@ -114,7 +115,7 @@ class GoogleNamespaceManager(OpenResourceNamespaceManager): def keys(self): return self.hash.keys() - + class GoogleContainer(Container): namespace_class = GoogleNamespaceManager diff --git a/module/lib/beaker/ext/memcached.py b/module/lib/beaker/ext/memcached.py index 96516953f..94e3da3c9 100644 --- a/module/lib/beaker/ext/memcached.py +++ b/module/lib/beaker/ext/memcached.py @@ -1,54 +1,118 @@ +from __future__ import with_statement from beaker.container import NamespaceManager, Container +from beaker.crypto.util import sha1 from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter -from beaker.synchronization import file_synchronizer, null_synchronizer -from beaker.util import verify_directory, SyncDict +from beaker.synchronization import file_synchronizer +from beaker.util import verify_directory, SyncDict, parse_memcached_behaviors import warnings -memcache = None +MAX_KEY_LENGTH = 250 -class MemcachedNamespaceManager(NamespaceManager): - clients = SyncDict() - - @classmethod - def _init_dependencies(cls): +_client_libs = {} + + +def _load_client(name='auto'): + if name in _client_libs: + return _client_libs[name] + + def _pylibmc(): + global pylibmc + import pylibmc + return pylibmc + + def _cmemcache(): + global cmemcache + import cmemcache + warnings.warn("cmemcache is known to have serious " + "concurrency issues; consider using 'memcache' " + "or 'pylibmc'") + return cmemcache + + def _memcache(): global memcache - if memcache is not None: - return - try: - import pylibmc as memcache - except ImportError: + import memcache + return memcache + + def _auto(): + for _client in (_pylibmc, _cmemcache, _memcache): try: - import cmemcache as memcache - warnings.warn("cmemcache is known to have serious " - "concurrency issues; consider using 'memcache' or 'pylibmc'") + return _client() except ImportError: - try: - import memcache - except ImportError: - raise InvalidCacheBackendError("Memcached cache backend requires either " - "the 'memcache' or 'cmemcache' library") - - def __init__(self, namespace, url=None, data_dir=None, lock_dir=None, **params): + pass + else: + raise InvalidCacheBackendError( + "Memcached cache backend requires one " + "of: 'pylibmc' or 'memcache' to be installed.") + + clients = { + 'pylibmc': _pylibmc, + 'cmemcache': _cmemcache, + 'memcache': _memcache, + 'auto': _auto + } + _client_libs[name] = clib = clients[name]() + return clib + + +def _is_configured_for_pylibmc(memcache_module_config, memcache_client): + return memcache_module_config == 'pylibmc' or \ + memcache_client.__name__.startswith('pylibmc') + + +class MemcachedNamespaceManager(NamespaceManager): + """Provides the :class:`.NamespaceManager` API over a memcache client library.""" + + clients = SyncDict() + + def __new__(cls, *args, **kw): + memcache_module = kw.pop('memcache_module', 'auto') + + memcache_client = _load_client(memcache_module) + + if _is_configured_for_pylibmc(memcache_module, memcache_client): + return object.__new__(PyLibMCNamespaceManager) + else: + return object.__new__(MemcachedNamespaceManager) + + def __init__(self, namespace, url, + memcache_module='auto', + data_dir=None, lock_dir=None, + **kw): NamespaceManager.__init__(self, namespace) - + + _memcache_module = _client_libs[memcache_module] + if not url: - raise MissingCacheParameter("url is required") - + raise MissingCacheParameter("url is required") + if lock_dir: self.lock_dir = lock_dir elif data_dir: self.lock_dir = data_dir + "/container_mcd_lock" if self.lock_dir: - verify_directory(self.lock_dir) - - self.mc = MemcachedNamespaceManager.clients.get(url, memcache.Client, url.split(';')) + verify_directory(self.lock_dir) + + # Check for pylibmc namespace manager, in which case client will be + # instantiated by subclass __init__, to handle behavior passing to the + # pylibmc client + if not _is_configured_for_pylibmc(memcache_module, _memcache_module): + self.mc = MemcachedNamespaceManager.clients.get( + (memcache_module, url), + _memcache_module.Client, + url.split(';')) def get_creation_lock(self, key): return file_synchronizer( - identifier="memcachedcontainer/funclock/%s" % self.namespace,lock_dir = self.lock_dir) + identifier="memcachedcontainer/funclock/%s/%s" % + (self.namespace, key), lock_dir=self.lock_dir) def _format_key(self, key): - return self.namespace + '_' + key.replace(' ', '\302\267') + if not isinstance(key, str): + key = key.decode('ascii') + formated_key = (self.namespace + '_' + key).replace(' ', '\302\267') + if len(formated_key) > MAX_KEY_LENGTH: + formated_key = sha1(formated_key).hexdigest() + return formated_key def __getitem__(self, key): return self.mc.get(self._format_key(key)) @@ -68,15 +132,72 @@ class MemcachedNamespaceManager(NamespaceManager): def __setitem__(self, key, value): self.set_value(key, value) - + def __delitem__(self, key): self.mc.delete(self._format_key(key)) def do_remove(self): self.mc.flush_all() - + def keys(self): - raise NotImplementedError("Memcache caching does not support iteration of all cache keys") + raise NotImplementedError( + "Memcache caching does not " + "support iteration of all cache keys") + + +class PyLibMCNamespaceManager(MemcachedNamespaceManager): + """Provide thread-local support for pylibmc.""" + + def __init__(self, *arg, **kw): + super(PyLibMCNamespaceManager, self).__init__(*arg, **kw) + + memcache_module = kw.get('memcache_module', 'auto') + _memcache_module = _client_libs[memcache_module] + protocol = kw.get('protocol', 'text') + username = kw.get('username', None) + password = kw.get('password', None) + url = kw.get('url') + behaviors = parse_memcached_behaviors(kw) + + self.mc = MemcachedNamespaceManager.clients.get( + (memcache_module, url), + _memcache_module.Client, + servers=url.split(';'), behaviors=behaviors, + binary=(protocol == 'binary'), username=username, + password=password) + self.pool = pylibmc.ThreadMappedPool(self.mc) + + def __getitem__(self, key): + with self.pool.reserve() as mc: + return mc.get(self._format_key(key)) + + def __contains__(self, key): + with self.pool.reserve() as mc: + value = mc.get(self._format_key(key)) + return value is not None + + def has_key(self, key): + return key in self + + def set_value(self, key, value, expiretime=None): + with self.pool.reserve() as mc: + if expiretime: + mc.set(self._format_key(key), value, time=expiretime) + else: + mc.set(self._format_key(key), value) + + def __setitem__(self, key, value): + self.set_value(key, value) + + def __delitem__(self, key): + with self.pool.reserve() as mc: + mc.delete(self._format_key(key)) + + def do_remove(self): + with self.pool.reserve() as mc: + mc.flush_all() + class MemcachedContainer(Container): + """Container class which invokes :class:`.MemcacheNamespaceManager`.""" namespace_class = MemcachedNamespaceManager diff --git a/module/lib/beaker/ext/sqla.py b/module/lib/beaker/ext/sqla.py index 8c79633c1..6405c2919 100644 --- a/module/lib/beaker/ext/sqla.py +++ b/module/lib/beaker/ext/sqla.py @@ -13,6 +13,7 @@ log = logging.getLogger(__name__) sa = None + class SqlaNamespaceManager(OpenResourceNamespaceManager): binds = SyncDict() tables = SyncDict() @@ -47,7 +48,7 @@ class SqlaNamespaceManager(OpenResourceNamespaceManager): elif data_dir: self.lock_dir = data_dir + "/container_db_lock" if self.lock_dir: - verify_directory(self.lock_dir) + verify_directory(self.lock_dir) self.bind = self.__class__.binds.get(str(bind.url), lambda: bind) self.table = self.__class__.tables.get('%s:%s' % (bind.url, table.name), @@ -61,10 +62,10 @@ class SqlaNamespaceManager(OpenResourceNamespaceManager): def get_creation_lock(self, key): return file_synchronizer( - identifier ="databasecontainer/funclock/%s" % self.namespace, + identifier="databasecontainer/funclock/%s" % self.namespace, lock_dir=self.lock_dir) - def do_open(self, flags): + def do_open(self, flags, replace): if self.loaded: self.flags = flags return @@ -108,7 +109,7 @@ class SqlaNamespaceManager(OpenResourceNamespaceManager): return self.hash[key] def __contains__(self, key): - return self.hash.has_key(key) + return key in self.hash def __setitem__(self, key, value): self.hash[key] = value @@ -123,11 +124,13 @@ class SqlaNamespaceManager(OpenResourceNamespaceManager): class SqlaContainer(Container): namespace_manager = SqlaNamespaceManager -def make_cache_table(metadata, table_name='beaker_cache'): + +def make_cache_table(metadata, table_name='beaker_cache', schema_name=None): """Return a ``Table`` object suitable for storing cached values for the namespace manager. Do not create the table.""" return sa.Table(table_name, metadata, sa.Column('namespace', sa.String(255), primary_key=True), sa.Column('accessed', sa.DateTime, nullable=False), sa.Column('created', sa.DateTime, nullable=False), - sa.Column('data', sa.PickleType, nullable=False)) + sa.Column('data', sa.PickleType, nullable=False), + schema=schema_name if schema_name else metadata.schema) diff --git a/module/lib/beaker/middleware.py b/module/lib/beaker/middleware.py index 7ba88b37d..803398584 100644 --- a/module/lib/beaker/middleware.py +++ b/module/lib/beaker/middleware.py @@ -16,15 +16,15 @@ from beaker.util import coerce_cache_params, coerce_session_params, \ class CacheMiddleware(object): cache = beaker_cache - + def __init__(self, app, config=None, environ_key='beaker.cache', **kwargs): """Initialize the Cache Middleware - - The Cache middleware will make a Cache instance available + + The Cache middleware will make a CacheManager instance available every request under the ``environ['beaker.cache']`` key by default. The location in environ can be changed by setting ``environ_key``. - + ``config`` dict All settings should be prefixed by 'cache.'. This method of passing variables is intended for Paste and other @@ -32,11 +32,11 @@ class CacheMiddleware(object): single dictionary. If config contains *no cache. prefixed args*, then *all* of the config options will be used to intialize the Cache objects. - + ``environ_key`` Location where the Cache instance will keyed in the WSGI environ - + ``**kwargs`` All keyword arguments are assumed to be cache settings and will override any settings found in ``config`` @@ -44,26 +44,26 @@ class CacheMiddleware(object): """ self.app = app config = config or {} - + self.options = {} - + # Update the options with the parsed config self.options.update(parse_cache_config_options(config)) - + # Add any options from kwargs, but leave out the defaults this # time self.options.update( parse_cache_config_options(kwargs, include_defaults=False)) - + # Assume all keys are intended for cache if none are prefixed with # 'cache.' if not self.options and config: self.options = config - + self.options.update(kwargs) self.cache_manager = CacheManager(**self.options) self.environ_key = environ_key - + def __call__(self, environ, start_response): if environ.get('paste.registry'): if environ['paste.registry'].reglist: @@ -75,16 +75,16 @@ class CacheMiddleware(object): class SessionMiddleware(object): session = beaker_session - + def __init__(self, wrap_app, config=None, environ_key='beaker.session', **kwargs): """Initialize the Session Middleware - + The Session middleware will make a lazy session instance available every request under the ``environ['beaker.session']`` key by default. The location in environ can be changed by setting ``environ_key``. - + ``config`` dict All settings should be prefixed by 'session.'. This method of passing variables is intended for Paste and other @@ -92,21 +92,21 @@ class SessionMiddleware(object): single dictionary. If config contains *no cache. prefixed args*, then *all* of the config options will be used to intialize the Cache objects. - + ``environ_key`` Location where the Session instance will keyed in the WSGI environ - + ``**kwargs`` All keyword arguments are assumed to be session settings and will override any settings found in ``config`` """ config = config or {} - + # Load up the default params - self.options = dict(invalidate_corrupt=True, type=None, - data_dir=None, key='beaker.session.id', + self.options = dict(invalidate_corrupt=True, type=None, + data_dir=None, key='beaker.session.id', timeout=None, secret=None, log_file=None) # Pull out any config args meant for beaker session. if there are any @@ -120,19 +120,19 @@ class SessionMiddleware(object): warnings.warn('Session options should start with session. ' 'instead of session_.', DeprecationWarning, 2) self.options[key[8:]] = val - + # Coerce and validate session params coerce_session_params(self.options) - + # Assume all keys are intended for cache if none are prefixed with # 'cache.' if not self.options and config: self.options = config - + self.options.update(kwargs) - self.wrap_app = wrap_app + self.wrap_app = self.app = wrap_app self.environ_key = environ_key - + def __call__(self, environ, start_response): session = SessionObject(environ, **self.options) if environ.get('paste.registry'): @@ -140,8 +140,11 @@ class SessionMiddleware(object): environ['paste.registry'].register(self.session, session) environ[self.environ_key] = session environ['beaker.get_session'] = self._get_session - - def session_start_response(status, headers, exc_info = None): + + if 'paste.testing_variables' in environ and 'webtest_varname' in self.options: + environ['paste.testing_variables'][self.options['webtest_varname']] = session + + def session_start_response(status, headers, exc_info=None): if session.accessed(): session.persist() if session.__dict__['_headers']['set_cookie']: @@ -150,7 +153,7 @@ class SessionMiddleware(object): headers.append(('Set-cookie', cookie)) return start_response(status, headers, exc_info) return self.wrap_app(environ, session_start_response) - + def _get_session(self): return Session({}, use_cookies=False, **self.options) diff --git a/module/lib/beaker/session.py b/module/lib/beaker/session.py index 7d465530b..d70a670eb 100644 --- a/module/lib/beaker/session.py +++ b/module/lib/beaker/session.py @@ -1,13 +1,9 @@ import Cookie import os -import random -import time from datetime import datetime, timedelta - +import time from beaker.crypto import hmac as HMAC, hmac_sha1 as SHA1, md5 -from beaker.util import pickle - -from beaker import crypto +from beaker import crypto, util from beaker.cache import clsmap from beaker.exceptions import BeakerException, InvalidCryptoBackendError from base64 import b64encode, b64decode @@ -15,56 +11,102 @@ from base64 import b64encode, b64decode __all__ = ['SignedCookie', 'Session'] -getpid = hasattr(os, 'getpid') and os.getpid or (lambda : '') + +try: + import uuid + + def _session_id(): + return uuid.uuid4().hex +except ImportError: + import random + if hasattr(os, 'getpid'): + getpid = os.getpid + else: + def getpid(): + return '' + + def _session_id(): + id_str = "%f%s%f%s" % ( + time.time(), + id({}), + random.random(), + getpid() + ) + if util.py3k: + return md5( + md5( + id_str.encode('ascii') + ).hexdigest().encode('ascii') + ).hexdigest() + else: + return md5(md5(id_str).hexdigest()).hexdigest() + class SignedCookie(Cookie.BaseCookie): """Extends python cookie to give digital signature support""" def __init__(self, secret, input=None): - self.secret = secret + self.secret = secret.encode('UTF-8') Cookie.BaseCookie.__init__(self, input) - + def value_decode(self, val): val = val.strip('"') - sig = HMAC.new(self.secret, val[40:], SHA1).hexdigest() - + sig = HMAC.new(self.secret, val[40:].encode('UTF-8'), SHA1).hexdigest() + # Avoid timing attacks invalid_bits = 0 input_sig = val[:40] if len(sig) != len(input_sig): return None, val - + for a, b in zip(sig, input_sig): invalid_bits += a != b - + if invalid_bits: return None, val else: return val[40:], val - + def value_encode(self, val): - sig = HMAC.new(self.secret, val, SHA1).hexdigest() + sig = HMAC.new(self.secret, val.encode('UTF-8'), SHA1).hexdigest() return str(val), ("%s%s" % (sig, val)) class Session(dict): """Session object that uses container package for storage. - - ``key`` - The name the cookie should be set to. - ``timeout`` - How long session data is considered valid. This is used - regardless of the cookie being present or not to determine - whether session data is still valid. - ``cookie_domain`` - Domain to use for the cookie. - ``secure`` - Whether or not the cookie should only be sent over SSL. + + :param invalidate_corrupt: How to handle corrupt data when loading. When + set to True, then corrupt data will be silently + invalidated and a new session created, + otherwise invalid data will cause an exception. + :type invalidate_corrupt: bool + :param use_cookies: Whether or not cookies should be created. When set to + False, it is assumed the user will handle storing the + session on their own. + :type use_cookies: bool + :param type: What data backend type should be used to store the underlying + session data + :param key: The name the cookie should be set to. + :param timeout: How long session data is considered valid. This is used + regardless of the cookie being present or not to determine + whether session data is still valid. + :type timeout: int + :param cookie_expires: Expiration date for cookie + :param cookie_domain: Domain to use for the cookie. + :param cookie_path: Path to use for the cookie. + :param secure: Whether or not the cookie should only be sent over SSL. + :param httponly: Whether or not the cookie should only be accessible by + the browser not by JavaScript. + :param encrypt_key: The key to use for the local session encryption, if not + provided the session will not be encrypted. + :param validate_key: The key used to sign the local encrypted session + """ def __init__(self, request, id=None, invalidate_corrupt=False, use_cookies=True, type=None, data_dir=None, key='beaker.session.id', timeout=None, cookie_expires=True, - cookie_domain=None, secret=None, secure=False, - namespace_class=None, **namespace_args): + cookie_domain=None, cookie_path='/', secret=None, + secure=False, namespace_class=None, httponly=False, + encrypt_key=None, validate_key=None, **namespace_args): if not type: if data_dir: self.type = 'file' @@ -76,24 +118,28 @@ class Session(dict): self.namespace_class = namespace_class or clsmap[self.type] self.namespace_args = namespace_args - + self.request = request self.data_dir = data_dir self.key = key - + self.timeout = timeout self.use_cookies = use_cookies self.cookie_expires = cookie_expires - + # Default cookie domain/path self._domain = cookie_domain - self._path = '/' + self._path = cookie_path self.was_invalidated = False self.secret = secret self.secure = secure + self.httponly = httponly + self.encrypt_key = encrypt_key + self.validate_key = validate_key self.id = id self.accessed_dict = {} - + self.invalidate_corrupt = invalidate_corrupt + if self.use_cookies: cookieheader = request.get('cookie', '') if secret: @@ -103,10 +149,10 @@ class Session(dict): self.cookie = SignedCookie(secret, input=None) else: self.cookie = Cookie.SimpleCookie(input=cookieheader) - + if not self.id and self.key in self.cookie: self.id = self.cookie[self.key].value - + self.is_new = self.id is None if self.is_new: self._create_id() @@ -114,80 +160,146 @@ class Session(dict): else: try: self.load() - except: + except Exception, e: if invalidate_corrupt: + util.warn( + "Invalidating corrupt session %s; " + "error was: %s. Set invalidate_corrupt=False " + "to propagate this exception." % (self.id, e)) self.invalidate() else: raise - - def _create_id(self): - self.id = md5( - md5("%f%s%f%s" % (time.time(), id({}), random.random(), - getpid())).hexdigest(), - ).hexdigest() - self.is_new = True - self.last_accessed = None - if self.use_cookies: - self.cookie[self.key] = self.id - if self._domain: - self.cookie[self.key]['domain'] = self._domain - if self.secure: - self.cookie[self.key]['secure'] = True - self.cookie[self.key]['path'] = self._path + + def has_key(self, name): + return name in self + + def _set_cookie_values(self, expires=None): + self.cookie[self.key] = self.id + if self._domain: + self.cookie[self.key]['domain'] = self._domain + if self.secure: + self.cookie[self.key]['secure'] = True + self._set_cookie_http_only() + self.cookie[self.key]['path'] = self._path + + self._set_cookie_expires(expires) + + def _set_cookie_expires(self, expires): + if expires is None: if self.cookie_expires is not True: if self.cookie_expires is False: - expires = datetime.fromtimestamp( 0x7FFFFFFF ) + expires = datetime.fromtimestamp(0x7FFFFFFF) elif isinstance(self.cookie_expires, timedelta): - expires = datetime.today() + self.cookie_expires + expires = datetime.utcnow() + self.cookie_expires elif isinstance(self.cookie_expires, datetime): expires = self.cookie_expires else: raise ValueError("Invalid argument for cookie_expires: %s" % repr(self.cookie_expires)) - self.cookie[self.key]['expires'] = \ - expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) - self.request['cookie_out'] = self.cookie[self.key].output(header='') - self.request['set_cookie'] = False - + else: + expires = None + if expires is not None: + if not self.cookie or self.key not in self.cookie: + self.cookie[self.key] = self.id + self.cookie[self.key]['expires'] = \ + expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT") + return expires + + def _update_cookie_out(self, set_cookie=True): + self.request['cookie_out'] = self.cookie[self.key].output(header='') + self.request['set_cookie'] = set_cookie + + def _set_cookie_http_only(self): + try: + if self.httponly: + self.cookie[self.key]['httponly'] = True + except Cookie.CookieError, e: + if 'Invalid Attribute httponly' not in str(e): + raise + util.warn('Python 2.6+ is required to use httponly') + + def _create_id(self, set_new=True): + self.id = _session_id() + + if set_new: + self.is_new = True + self.last_accessed = None + if self.use_cookies: + self._set_cookie_values() + sc = set_new == False + self._update_cookie_out(set_cookie=sc) + + @property def created(self): return self['_creation_time'] - created = property(created) - + def _set_domain(self, domain): self['_domain'] = domain self.cookie[self.key]['domain'] = domain - self.request['cookie_out'] = self.cookie[self.key].output(header='') - self.request['set_cookie'] = True - + self._update_cookie_out() + def _get_domain(self): return self._domain - + domain = property(_get_domain, _set_domain) - + def _set_path(self, path): - self['_path'] = path + self['_path'] = self._path = path self.cookie[self.key]['path'] = path - self.request['cookie_out'] = self.cookie[self.key].output(header='') - self.request['set_cookie'] = True - + self._update_cookie_out() + def _get_path(self): return self._path - + path = property(_get_path, _set_path) + def _encrypt_data(self, session_data=None): + """Serialize, encipher, and base64 the session dict""" + session_data = session_data or self.copy() + if self.encrypt_key: + nonce = b64encode(os.urandom(6))[:8] + encrypt_key = crypto.generateCryptoKeys(self.encrypt_key, + self.validate_key + nonce, 1) + data = util.pickle.dumps(session_data, 2) + return nonce + b64encode(crypto.aesEncrypt(data, encrypt_key)) + else: + data = util.pickle.dumps(session_data, 2) + return b64encode(data) + + def _decrypt_data(self, session_data): + """Bas64, decipher, then un-serialize the data for the session + dict""" + if self.encrypt_key: + try: + nonce = session_data[:8] + encrypt_key = crypto.generateCryptoKeys(self.encrypt_key, + self.validate_key + nonce, 1) + payload = b64decode(session_data[8:]) + data = crypto.aesDecrypt(payload, encrypt_key) + except: + # As much as I hate a bare except, we get some insane errors + # here that get tossed when crypto fails, so we raise the + # 'right' exception + if self.invalidate_corrupt: + return None + else: + raise + try: + return util.pickle.loads(data) + except: + if self.invalidate_corrupt: + return None + else: + raise + else: + data = b64decode(session_data) + return util.pickle.loads(data) + def _delete_cookie(self): self.request['set_cookie'] = True - self.cookie[self.key] = self.id - if self._domain: - self.cookie[self.key]['domain'] = self._domain - if self.secure: - self.cookie[self.key]['secure'] = True - self.cookie[self.key]['path'] = '/' - expires = datetime.today().replace(year=2003) - self.cookie[self.key]['expires'] = \ - expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) - self.request['cookie_out'] = self.cookie[self.key].output(header='') - self.request['set_cookie'] = True + expires = datetime.utcnow() - timedelta(365) + self._set_cookie_values(expires) + self._update_cookie_out() def delete(self): """Deletes the session from the persistent storage, and sends @@ -203,15 +315,17 @@ class Session(dict): self.was_invalidated = True self._create_id() self.load() - + def load(self): "Loads the data from this session from persistent storage" self.namespace = self.namespace_class(self.id, - data_dir=self.data_dir, digest_filenames=False, + data_dir=self.data_dir, + digest_filenames=False, **self.namespace_args) now = time.time() - self.request['set_cookie'] = True - + if self.use_cookies: + self.request['set_cookie'] = True + self.namespace.acquire_read_lock() timed_out = False try: @@ -219,24 +333,34 @@ class Session(dict): try: session_data = self.namespace['session'] + if (session_data is not None and self.encrypt_key): + session_data = self._decrypt_data(session_data) + # Memcached always returns a key, its None when its not # present if session_data is None: session_data = { - '_creation_time':now, - '_accessed_time':now + '_creation_time': now, + '_accessed_time': now } self.is_new = True except (KeyError, TypeError): session_data = { - '_creation_time':now, - '_accessed_time':now + '_creation_time': now, + '_accessed_time': now + } + self.is_new = True + + if session_data is None or len(session_data) == 0: + session_data = { + '_creation_time': now, + '_accessed_time': now } self.is_new = True - + if self.timeout is not None and \ now - session_data['_accessed_time'] > self.timeout: - timed_out= True + timed_out = True else: # Properly set the last_accessed time, which is different # than the *currently* _accessed_time @@ -244,43 +368,52 @@ class Session(dict): self.last_accessed = None else: self.last_accessed = session_data['_accessed_time'] - + # Update the current _accessed_time session_data['_accessed_time'] = now + + # Set the path if applicable + if '_path' in session_data: + self._path = session_data['_path'] self.update(session_data) - self.accessed_dict = session_data.copy() + self.accessed_dict = session_data.copy() finally: self.namespace.release_read_lock() if timed_out: self.invalidate() - + def save(self, accessed_only=False): """Saves the data for this session to persistent storage - + If accessed_only is True, then only the original data loaded at the beginning of the request will be saved, with the updated last accessed time. - + """ # Look to see if its a new session that was only accessed # Don't save it under that case if accessed_only and self.is_new: return None - - if not hasattr(self, 'namespace'): + + # this session might not have a namespace yet or the session id + # might have been regenerated + if not hasattr(self, 'namespace') or self.namespace.namespace != self.id: self.namespace = self.namespace_class( - self.id, + self.id, data_dir=self.data_dir, - digest_filenames=False, + digest_filenames=False, **self.namespace_args) - - self.namespace.acquire_write_lock() + + self.namespace.acquire_write_lock(replace=True) try: if accessed_only: data = dict(self.accessed_dict.items()) else: data = dict(self.items()) - + + if self.encrypt_key: + data = self._encrypt_data(data) + # Save the data if not data and 'session' in self.namespace: del self.namespace['session'] @@ -288,22 +421,32 @@ class Session(dict): self.namespace['session'] = data finally: self.namespace.release_write_lock() - if self.is_new: + if self.use_cookies and self.is_new: self.request['set_cookie'] = True - + def revert(self): """Revert the session to its original state from its first access in the request""" self.clear() self.update(self.accessed_dict) - + + def regenerate_id(self): + """ + creates a new session id, retains all session data + + Its a good security practice to regnerate the id after a client + elevates priviliges. + + """ + self._create_id(set_new=False) + # TODO: I think both these methods should be removed. They're from # the original mod_python code i was ripping off but they really # have no use here. def lock(self): """Locks this session against other processes/threads. This is automatic when load/save is called. - + ***use with caution*** and always with a corresponding 'unlock' inside a "finally:" block, as a stray lock typically cannot be unlocked without shutting down the whole application. @@ -322,37 +465,38 @@ class Session(dict): """ self.namespace.release_write_lock() + class CookieSession(Session): """Pure cookie-based session - + Options recognized when using cookie-based sessions are slightly more restricted than general sessions. - - ``key`` - The name the cookie should be set to. - ``timeout`` - How long session data is considered valid. This is used - regardless of the cookie being present or not to determine - whether session data is still valid. - ``encrypt_key`` - The key to use for the session encryption, if not provided the - session will not be encrypted. - ``validate_key`` - The key used to sign the encrypted session - ``cookie_domain`` - Domain to use for the cookie. - ``secure`` - Whether or not the cookie should only be sent over SSL. - + + :param key: The name the cookie should be set to. + :param timeout: How long session data is considered valid. This is used + regardless of the cookie being present or not to determine + whether session data is still valid. + :type timeout: int + :param cookie_expires: Expiration date for cookie + :param cookie_domain: Domain to use for the cookie. + :param cookie_path: Path to use for the cookie. + :param secure: Whether or not the cookie should only be sent over SSL. + :param httponly: Whether or not the cookie should only be accessible by + the browser not by JavaScript. + :param encrypt_key: The key to use for the local session encryption, if not + provided the session will not be encrypted. + :param validate_key: The key used to sign the local encrypted session + """ def __init__(self, request, key='beaker.session.id', timeout=None, - cookie_expires=True, cookie_domain=None, encrypt_key=None, - validate_key=None, secure=False, **kwargs): - + cookie_expires=True, cookie_domain=None, cookie_path='/', + encrypt_key=None, validate_key=None, secure=False, + httponly=False, **kwargs): + if not crypto.has_aes and encrypt_key: raise InvalidCryptoBackendError("No AES library is installed, can't generate " "encrypted cookie-only Session.") - + self.request = request self.key = key self.timeout = timeout @@ -361,31 +505,34 @@ class CookieSession(Session): self.validate_key = validate_key self.request['set_cookie'] = False self.secure = secure + self.httponly = httponly self._domain = cookie_domain - self._path = '/' - + self._path = cookie_path + try: cookieheader = request['cookie'] except KeyError: cookieheader = '' - + if validate_key is None: raise BeakerException("No validate_key specified for Cookie only " "Session.") - + try: self.cookie = SignedCookie(validate_key, input=cookieheader) except Cookie.CookieError: self.cookie = SignedCookie(validate_key, input=None) - - self['_id'] = self._make_id() + + self['_id'] = _session_id() self.is_new = True - + # If we have a cookie, load it if self.key in self.cookie and self.cookie[self.key].value is not None: self.is_new = False try: - self.update(self._decrypt_data()) + cookie_data = self.cookie[self.key].value + self.update(self._decrypt_data(cookie_data)) + self._path = self.get('_path', '/') except: pass if self.timeout is not None and time.time() - \ @@ -393,11 +540,11 @@ class CookieSession(Session): self.clear() self.accessed_dict = self.copy() self._create_cookie() - + def created(self): return self['_creation_time'] created = property(created) - + def id(self): return self['_id'] id = property(id) @@ -405,53 +552,20 @@ class CookieSession(Session): def _set_domain(self, domain): self['_domain'] = domain self._domain = domain - + def _get_domain(self): return self._domain - + domain = property(_get_domain, _set_domain) - + def _set_path(self, path): - self['_path'] = path - self._path = path - + self['_path'] = self._path = path + def _get_path(self): return self._path - + path = property(_get_path, _set_path) - def _encrypt_data(self): - """Serialize, encipher, and base64 the session dict""" - if self.encrypt_key: - nonce = b64encode(os.urandom(40))[:8] - encrypt_key = crypto.generateCryptoKeys(self.encrypt_key, - self.validate_key + nonce, 1) - data = pickle.dumps(self.copy(), 2) - return nonce + b64encode(crypto.aesEncrypt(data, encrypt_key)) - else: - data = pickle.dumps(self.copy(), 2) - return b64encode(data) - - def _decrypt_data(self): - """Bas64, decipher, then un-serialize the data for the session - dict""" - if self.encrypt_key: - nonce = self.cookie[self.key].value[:8] - encrypt_key = crypto.generateCryptoKeys(self.encrypt_key, - self.validate_key + nonce, 1) - payload = b64decode(self.cookie[self.key].value[8:]) - data = crypto.aesDecrypt(payload, encrypt_key) - return pickle.loads(data) - else: - data = b64decode(self.cookie[self.key].value) - return pickle.loads(data) - - def _make_id(self): - return md5(md5( - "%f%s%f%s" % (time.time(), id({}), random.random(), getpid()) - ).hexdigest() - ).hexdigest() - def save(self, accessed_only=False): """Saves the data for this session to persistent storage""" if accessed_only and self.is_new: @@ -460,88 +574,79 @@ class CookieSession(Session): self.clear() self.update(self.accessed_dict) self._create_cookie() - + def expire(self): """Delete the 'expires' attribute on this Session, if any.""" - + self.pop('_expires', None) - + def _create_cookie(self): if '_creation_time' not in self: self['_creation_time'] = time.time() if '_id' not in self: - self['_id'] = self._make_id() + self['_id'] = _session_id() self['_accessed_time'] = time.time() - - if self.cookie_expires is not True: - if self.cookie_expires is False: - expires = datetime.fromtimestamp( 0x7FFFFFFF ) - elif isinstance(self.cookie_expires, timedelta): - expires = datetime.today() + self.cookie_expires - elif isinstance(self.cookie_expires, datetime): - expires = self.cookie_expires - else: - raise ValueError("Invalid argument for cookie_expires: %s" - % repr(self.cookie_expires)) - self['_expires'] = expires - elif '_expires' in self: - expires = self['_expires'] - else: - expires = None val = self._encrypt_data() if len(val) > 4064: raise BeakerException("Cookie value is too long to store") - + self.cookie[self.key] = val + + if '_expires' in self: + expires = self['_expires'] + else: + expires = None + expires = self._set_cookie_expires(expires) + if expires is not None: + self['_expires'] = expires + if '_domain' in self: self.cookie[self.key]['domain'] = self['_domain'] elif self._domain: self.cookie[self.key]['domain'] = self._domain if self.secure: self.cookie[self.key]['secure'] = True - + self._set_cookie_http_only() + self.cookie[self.key]['path'] = self.get('_path', '/') - - if expires: - self.cookie[self.key]['expires'] = \ - expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) + self.request['cookie_out'] = self.cookie[self.key].output(header='') self.request['set_cookie'] = True - + def delete(self): """Delete the cookie, and clear the session""" # Send a delete cookie request self._delete_cookie() self.clear() - + def invalidate(self): """Clear the contents and start a new session""" - self.delete() - self['_id'] = self._make_id() + self.clear() + self['_id'] = _session_id() class SessionObject(object): """Session proxy/lazy creator - + This object proxies access to the actual session object, so that in the case that the session hasn't been used before, it will be setup. This avoid creating and loading the session from persistent storage unless its actually used during the request. - + """ def __init__(self, environ, **params): self.__dict__['_params'] = params self.__dict__['_environ'] = environ self.__dict__['_sess'] = None - self.__dict__['_headers'] = [] - + self.__dict__['_headers'] = {} + def _session(self): """Lazy initial creation of session object""" if self.__dict__['_sess'] is None: params = self.__dict__['_params'] environ = self.__dict__['_environ'] - self.__dict__['_headers'] = req = {'cookie_out':None} + self.__dict__['_headers'] = req = {'cookie_out': None} req['cookie'] = environ.get('HTTP_COOKIE') if params.get('type') == 'cookie': self.__dict__['_sess'] = CookieSession(req, **params) @@ -549,35 +654,38 @@ class SessionObject(object): self.__dict__['_sess'] = Session(req, use_cookies=True, **params) return self.__dict__['_sess'] - + def __getattr__(self, attr): return getattr(self._session(), attr) - + def __setattr__(self, attr, value): setattr(self._session(), attr, value) - + def __delattr__(self, name): self._session().__delattr__(name) - + def __getitem__(self, key): return self._session()[key] - + def __setitem__(self, key, value): self._session()[key] = value - + def __delitem__(self, key): self._session().__delitem__(key) - + def __repr__(self): return self._session().__repr__() - + def __iter__(self): """Only works for proxying to a dict""" return iter(self._session().keys()) - + def __contains__(self, key): - return self._session().has_key(key) - + return key in self._session() + + def has_key(self, key): + return key in self._session() + def get_by_id(self, id): """Loads a session given a session ID""" params = self.__dict__['_params'] @@ -585,22 +693,22 @@ class SessionObject(object): if session.is_new: return None return session - + def save(self): self.__dict__['_dirty'] = True - + def delete(self): self.__dict__['_dirty'] = True self._session().delete() - + def persist(self): """Persist the session to the storage - + If its set to autosave, then the entire session will be saved regardless of if save() has been called. Otherwise, just the accessed time will be updated if save() was not called, or the session will be saved if save() was called. - + """ if self.__dict__['_params'].get('auto'): self._session().save() @@ -609,10 +717,10 @@ class SessionObject(object): self._session().save() else: self._session().save(accessed_only=True) - + def dirty(self): return self.__dict__.get('_dirty', False) - + def accessed(self): """Returns whether or not the session has been accessed""" return self.__dict__['_sess'] is not None diff --git a/module/lib/beaker/synchronization.py b/module/lib/beaker/synchronization.py index 761303707..f236b8cfe 100644 --- a/module/lib/beaker/synchronization.py +++ b/module/lib/beaker/synchronization.py @@ -29,18 +29,18 @@ except: from beaker import util from beaker.exceptions import LockError -__all__ = ["file_synchronizer", "mutex_synchronizer", "null_synchronizer", +__all__ = ["file_synchronizer", "mutex_synchronizer", "null_synchronizer", "NameLock", "_threading"] class NameLock(object): """a proxy for an RLock object that is stored in a name based - registry. - + registry. + Multiple threads can get a reference to the same RLock based on the name alone, and synchronize operations related to that name. - """ + """ locks = util.WeakValuedRegistry() class NLContainer(object): @@ -49,17 +49,18 @@ class NameLock(object): self.lock = _threading.RLock() else: self.lock = _threading.Lock() + def __call__(self): return self.lock - def __init__(self, identifier = None, reentrant = False): + def __init__(self, identifier=None, reentrant=False): if identifier is None: self._lock = NameLock.NLContainer(reentrant) else: self._lock = NameLock.locks.get(identifier, NameLock.NLContainer, reentrant) - def acquire(self, wait = True): + def acquire(self, wait=True): return self._lock().acquire(wait) def release(self): @@ -67,6 +68,8 @@ class NameLock(object): _synchronizers = util.WeakValuedRegistry() + + def _synchronizer(identifier, cls, **kwargs): return _synchronizers.sync_get((identifier, cls), cls, identifier, **kwargs) @@ -83,12 +86,19 @@ def mutex_synchronizer(identifier, **kwargs): class null_synchronizer(object): + """A 'null' synchronizer, which provides the :class:`.SynchronizerImpl` interface + without any locking. + + """ def acquire_write_lock(self, wait=True): return True + def acquire_read_lock(self): pass + def release_write_lock(self): pass + def release_read_lock(self): pass acquire = acquire_write_lock @@ -96,6 +106,10 @@ class null_synchronizer(object): class SynchronizerImpl(object): + """Base class for a synchronization object that allows + multiple readers, single writers. + + """ def __init__(self): self._state = util.ThreadLocal() @@ -115,27 +129,27 @@ class SynchronizerImpl(object): else: return self._state.get() state = property(state) - + def release_read_lock(self): state = self.state - if state.writing: + if state.writing: raise LockError("lock is in writing state") - if not state.reading: + if not state.reading: raise LockError("lock is not in reading state") - + if state.reentrantcount == 1: self.do_release_read_lock() state.reading = False state.reentrantcount -= 1 - - def acquire_read_lock(self, wait = True): + + def acquire_read_lock(self, wait=True): state = self.state - if state.writing: + if state.writing: raise LockError("lock is in writing state") - + if state.reentrantcount == 0: x = self.do_acquire_read_lock(wait) if (wait or x): @@ -145,13 +159,13 @@ class SynchronizerImpl(object): elif state.reading: state.reentrantcount += 1 return True - + def release_write_lock(self): state = self.state - if state.reading: + if state.reading: raise LockError("lock is in reading state") - if not state.writing: + if not state.writing: raise LockError("lock is not in writing state") if state.reentrantcount == 1: @@ -159,18 +173,18 @@ class SynchronizerImpl(object): state.writing = False state.reentrantcount -= 1 - + release = release_write_lock - - def acquire_write_lock(self, wait = True): + + def acquire_write_lock(self, wait=True): state = self.state - if state.reading: + if state.reading: raise LockError("lock is in reading state") - + if state.reentrantcount == 0: x = self.do_acquire_write_lock(wait) - if (wait or x): + if (wait or x): state.reentrantcount += 1 state.writing = True return x @@ -182,56 +196,47 @@ class SynchronizerImpl(object): def do_release_read_lock(self): raise NotImplementedError() - + def do_acquire_read_lock(self): raise NotImplementedError() - + def do_release_write_lock(self): raise NotImplementedError() - + def do_acquire_write_lock(self): raise NotImplementedError() class FileSynchronizer(SynchronizerImpl): - """a synchronizer which locks using flock(). - - Adapted for Python/multithreads from Apache::Session::Lock::File, - http://search.cpan.org/src/CWEST/Apache-Session-1.81/Session/Lock/File.pm - - This module does not unlink temporary files, - because it interferes with proper locking. This can cause - problems on certain systems (Linux) whose file systems (ext2) do not - perform well with lots of files in one directory. To prevent this - you should use a script to clean out old files from your lock directory. - + """A synchronizer which locks using flock(). + """ def __init__(self, identifier, lock_dir): super(FileSynchronizer, self).__init__() self._filedescriptor = util.ThreadLocal() - + if lock_dir is None: lock_dir = tempfile.gettempdir() else: lock_dir = lock_dir self.filename = util.encoded_path( - lock_dir, - [identifier], + lock_dir, + [identifier], extension='.lock' ) def _filedesc(self): return self._filedescriptor.get() _filedesc = property(_filedesc) - + def _open(self, mode): filedescriptor = self._filedesc if filedescriptor is None: filedescriptor = os.open(self.filename, mode) self._filedescriptor.put(filedescriptor) return filedescriptor - + def do_acquire_read_lock(self, wait): filedescriptor = self._open(os.O_CREAT | os.O_RDONLY) if not wait: @@ -259,13 +264,13 @@ class FileSynchronizer(SynchronizerImpl): else: fcntl.flock(filedescriptor, fcntl.LOCK_EX) return True - + def do_release_read_lock(self): self._release_all_locks() - + def do_release_write_lock(self): self._release_all_locks() - + def _release_all_locks(self): filedescriptor = self._filedesc if filedescriptor is not None: @@ -276,7 +281,7 @@ class FileSynchronizer(SynchronizerImpl): class ConditionSynchronizer(SynchronizerImpl): """a synchronizer using a Condition.""" - + def __init__(self, identifier): super(ConditionSynchronizer, self).__init__() @@ -289,7 +294,7 @@ class ConditionSynchronizer(SynchronizerImpl): # condition object to lock on self.condition = _threading.Condition(_threading.Lock()) - def do_acquire_read_lock(self, wait = True): + def do_acquire_read_lock(self, wait=True): self.condition.acquire() try: # see if a synchronous operation is waiting to start @@ -306,15 +311,15 @@ class ConditionSynchronizer(SynchronizerImpl): finally: self.condition.release() - if not wait: + if not wait: return True - + def do_release_read_lock(self): self.condition.acquire() try: self.async -= 1 - - # check if we are the last asynchronous reader thread + + # check if we are the last asynchronous reader thread # out the door. if self.async == 0: # yes. so if a sync operation is waiting, notifyAll to wake @@ -326,13 +331,13 @@ class ConditionSynchronizer(SynchronizerImpl): "release_read_locks called") finally: self.condition.release() - - def do_acquire_write_lock(self, wait = True): + + def do_acquire_write_lock(self, wait=True): self.condition.acquire() try: # here, we are not a synchronous reader, and after returning, # assuming waiting or immediate availability, we will be. - + if wait: # if another sync is working, wait while self.current_sync_operation is not None: @@ -342,8 +347,8 @@ class ConditionSynchronizer(SynchronizerImpl): # we dont want to wait, so forget it if self.current_sync_operation is not None: return False - - # establish ourselves as the current sync + + # establish ourselves as the current sync # this indicates to other read/write operations # that they should wait until this is None again self.current_sync_operation = _threading.currentThread() @@ -359,8 +364,8 @@ class ConditionSynchronizer(SynchronizerImpl): return False finally: self.condition.release() - - if not wait: + + if not wait: return True def do_release_write_lock(self): @@ -370,7 +375,7 @@ class ConditionSynchronizer(SynchronizerImpl): raise LockError("Synchronizer error - current thread doesnt " "have the write lock") - # reset the current sync operation so + # reset the current sync operation so # another can get it self.current_sync_operation = None diff --git a/module/lib/beaker/util.py b/module/lib/beaker/util.py index 04c9617c5..c7002cd92 100644 --- a/module/lib/beaker/util.py +++ b/module/lib/beaker/util.py @@ -9,14 +9,16 @@ except ImportError: from datetime import datetime, timedelta import os +import re import string import types import weakref import warnings import sys +import inspect py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) -py24 = sys.version_info < (2,5) +py24 = sys.version_info < (2, 5) jython = sys.platform.startswith('java') if py3k or jython: @@ -25,11 +27,56 @@ else: import cPickle as pickle from beaker.converters import asbool +from beaker import exceptions from threading import local as _tlocal -__all__ = ["ThreadLocal", "Registry", "WeakValuedRegistry", "SyncDict", - "encoded_path", "verify_directory"] +__all__ = ["ThreadLocal", "WeakValuedRegistry", "SyncDict", "encoded_path", + "verify_directory"] + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + """ + fn.__name__ = name + return fn + + +def skip_if(predicate, reason=None): + """Skip a test if predicate is true.""" + reason = reason or predicate.__name__ + + from nose import SkipTest + + def decorate(fn): + fn_name = fn.__name__ + + def maybe(*args, **kw): + if predicate(): + msg = "'%s' skipped: %s" % ( + fn_name, reason) + raise SkipTest(msg) + else: + return fn(*args, **kw) + return function_named(maybe, fn_name) + return decorate + + +def assert_raises(except_cls, callable_, *args, **kw): + """Assert the given exception is raised by the given function + arguments.""" + + try: + callable_(*args, **kw) + success = False + except except_cls: + success = True + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" def verify_directory(dir): @@ -45,7 +92,24 @@ def verify_directory(dir): if tries > 5: raise - + +def has_self_arg(func): + """Return True if the given function has a 'self' argument.""" + args = inspect.getargspec(func) + if args and args[0] and args[0][0] in ('self', 'cls'): + return True + else: + return False + + +def warn(msg, stacklevel=3): + """Issue a warning.""" + if isinstance(msg, basestring): + warnings.warn(msg, exceptions.BeakerWarning, stacklevel=stacklevel) + else: + warnings.warn(msg, stacklevel=stacklevel) + + def deprecated(message): def wrapper(fn): def deprecated_method(*args, **kargs): @@ -56,7 +120,8 @@ def deprecated(message): deprecated_method.__doc__ = "%s\n\n%s" % (message, fn.__doc__) return deprecated_method return wrapper - + + class ThreadLocal(object): """stores a value on a per-thread basis""" @@ -64,42 +129,43 @@ class ThreadLocal(object): def __init__(self): self._tlocal = _tlocal() - + def put(self, value): self._tlocal.value = value - + def has(self): return hasattr(self._tlocal, 'value') - + def get(self, default=None): return getattr(self._tlocal, 'value', default) - + def remove(self): del self._tlocal.value - + + class SyncDict(object): """ An efficient/threadsafe singleton map algorithm, a.k.a. "get a value based on this key, and create if not found or not valid" paradigm: - + exists && isvalid ? get : create Designed to work with weakref dictionaries to expect items - to asynchronously disappear from the dictionary. + to asynchronously disappear from the dictionary. Use python 2.3.3 or greater ! a major bug was just fixed in Nov. 2003 that was driving me nuts with garbage collection/weakrefs in this section. - """ + """ def __init__(self): self.mutex = _thread.allocate_lock() self.dict = {} - + def get(self, key, createfunc, *args, **kwargs): try: - if self.has_key(key): + if key in self.dict: return self.dict[key] else: return self.sync_get(key, createfunc, *args, **kwargs) @@ -110,7 +176,7 @@ class SyncDict(object): self.mutex.acquire() try: try: - if self.has_key(key): + if key in self.dict: return self.dict[key] else: return self._create(key, createfunc, *args, **kwargs) @@ -124,16 +190,20 @@ class SyncDict(object): return obj def has_key(self, key): - return self.dict.has_key(key) - + return key in self.dict + def __contains__(self, key): return self.dict.__contains__(key) + def __getitem__(self, key): return self.dict.__getitem__(key) + def __setitem__(self, key, value): self.dict.__setitem__(key, value) + def __delitem__(self, key): return self.dict.__delitem__(key) + def clear(self): self.dict.clear() @@ -143,36 +213,47 @@ class WeakValuedRegistry(SyncDict): self.mutex = _threading.RLock() self.dict = weakref.WeakValueDictionary() -sha1 = None -def encoded_path(root, identifiers, extension = ".enc", depth = 3, +sha1 = None + + +def encoded_path(root, identifiers, extension=".enc", depth=3, digest_filenames=True): - + """Generate a unique file-accessible path from the given list of identifiers starting at the given root directory.""" ident = "_".join(identifiers) - + global sha1 if sha1 is None: from beaker.crypto import sha1 - + if digest_filenames: if py3k: ident = sha1(ident.encode('utf-8')).hexdigest() else: ident = sha1(ident).hexdigest() - + ident = os.path.basename(ident) tokens = [] for d in range(1, depth): tokens.append(ident[0:d]) - + dir = os.path.join(root, *tokens) verify_directory(dir) - + return os.path.join(dir, ident + extension) +def asint(obj): + if isinstance(obj, int): + return obj + elif isinstance(obj, basestring) and re.match(r'^\d+$', obj): + return int(obj) + else: + raise Exception("This is not a proper int") + + def verify_options(opt, types, error): if not isinstance(opt, types): if not isinstance(types, tuple): @@ -185,6 +266,11 @@ def verify_options(opt, types, error): else: if typ == bool: typ = asbool + elif typ == int: + typ = asint + elif typ in (timedelta, datetime): + if not isinstance(opt, typ): + raise Exception("%s requires a timedelta type", typ) opt = typ(opt) coerced = True except: @@ -212,10 +298,12 @@ def coerce_session_params(params): ('lock_dir', (str, types.NoneType), "lock_dir must be a string referring to a " "directory."), ('type', (str, types.NoneType), "Session type must be a string."), - ('cookie_expires', (bool, datetime, timedelta), "Cookie expires was " - "not a boolean, datetime, or timedelta instance."), + ('cookie_expires', (bool, datetime, timedelta, int), "Cookie expires was " + "not a boolean, datetime, int, or timedelta instance."), ('cookie_domain', (str, types.NoneType), "Cookie domain must be a " "string."), + ('cookie_path', (str, types.NoneType), "Cookie path must be a " + "string."), ('id', (str,), "Session id must be a string."), ('key', (str,), "Session key must be a string."), ('secret', (str, types.NoneType), "Session secret must be a string."), @@ -224,11 +312,19 @@ def coerce_session_params(params): ('encrypt_key', (str, types.NoneType), "Session validate_key must be " "a string."), ('secure', (bool, types.NoneType), "Session secure must be a boolean."), + ('httponly', (bool, types.NoneType), "Session httponly must be a boolean."), ('timeout', (int, types.NoneType), "Session timeout must be an " "integer."), ('auto', (bool, types.NoneType), "Session is created if accessed."), + ('webtest_varname', (str, types.NoneType), "Session varname must be " + "a string."), ] - return verify_rules(params, rules) + opts = verify_rules(params, rules) + cookie_expires = opts.get('cookie_expires') + if cookie_expires and isinstance(cookie_expires, int) and \ + not isinstance(cookie_expires, bool): + opts['cookie_expires'] = timedelta(seconds=cookie_expires) + return opts def coerce_cache_params(params): @@ -243,18 +339,63 @@ def coerce_cache_params(params): ('expire', (int, types.NoneType), "expire must be an integer representing " "how many seconds the cache is valid for"), ('regions', (list, tuple, types.NoneType), "Regions must be a " - "comma seperated list of valid regions") + "comma seperated list of valid regions"), + ('key_length', (int, types.NoneType), "key_length must be an integer " + "which indicates the longest a key can be before hashing"), ] return verify_rules(params, rules) +def coerce_memcached_behaviors(behaviors): + rules = [ + ('cas', (bool, int), 'cas must be a boolean or an integer'), + ('no_block', (bool, int), 'no_block must be a boolean or an integer'), + ('receive_timeout', (int,), 'receive_timeout must be an integer'), + ('send_timeout', (int,), 'send_timeout must be an integer'), + ('ketama_hash', (str,), 'ketama_hash must be a string designating ' + 'a valid hashing strategy option'), + ('_poll_timeout', (int,), '_poll_timeout must be an integer'), + ('auto_eject', (bool, int), 'auto_eject must be an integer'), + ('retry_timeout', (int,), 'retry_timeout must be an integer'), + ('_sort_hosts', (bool, int), '_sort_hosts must be an integer'), + ('_io_msg_watermark', (int,), '_io_msg_watermark must be an integer'), + ('ketama', (bool, int), 'ketama must be a boolean or an integer'), + ('ketama_weighted', (bool, int), 'ketama_weighted must be a boolean or ' + 'an integer'), + ('_io_key_prefetch', (int, bool), '_io_key_prefetch must be a boolean ' + 'or an integer'), + ('_hash_with_prefix_key', (bool, int), '_hash_with_prefix_key must be ' + 'a boolean or an integer'), + ('tcp_nodelay', (bool, int), 'tcp_nodelay must be a boolean or an ' + 'integer'), + ('failure_limit', (int,), 'failure_limit must be an integer'), + ('buffer_requests', (bool, int), 'buffer_requests must be a boolean ' + 'or an integer'), + ('_socket_send_size', (int,), '_socket_send_size must be an integer'), + ('num_replicas', (int,), 'num_replicas must be an integer'), + ('remove_failed', (int,), 'remove_failed must be an integer'), + ('_noreply', (bool, int), '_noreply must be a boolean or an integer'), + ('_io_bytes_watermark', (int,), '_io_bytes_watermark must be an ' + 'integer'), + ('_socket_recv_size', (int,), '_socket_recv_size must be an integer'), + ('distribution', (str,), 'distribution must be a string designating ' + 'a valid distribution option'), + ('connect_timeout', (int,), 'connect_timeout must be an integer'), + ('hash', (str,), 'hash must be a string designating a valid hashing ' + 'option'), + ('verify_keys', (bool, int), 'verify_keys must be a boolean or an integer'), + ('dead_timeout', (int,), 'dead_timeout must be an integer') + ] + return verify_rules(behaviors, rules) + + def parse_cache_config_options(config, include_defaults=True): """Parse configuration options and validate for use with the CacheManager""" - + # Load default cache options if include_defaults: - options= dict(type='memory', data_dir=None, expire=None, + options = dict(type='memory', data_dir=None, expire=None, log_file=None) else: options = {} @@ -264,39 +405,58 @@ def parse_cache_config_options(config, include_defaults=True): if key.startswith('cache.'): options[key[6:]] = val coerce_cache_params(options) - + # Set cache to enabled if not turned off - if 'enabled' not in options: + if 'enabled' not in options and include_defaults: options['enabled'] = True - + # Configure region dict if regions are available regions = options.pop('regions', None) if regions: region_configs = {} for region in regions: + if not region: # ensure region name is valid + continue # Setup the default cache options region_options = dict(data_dir=options.get('data_dir'), lock_dir=options.get('lock_dir'), type=options.get('type'), enabled=options['enabled'], - expire=options.get('expire')) - region_len = len(region) + 1 + expire=options.get('expire'), + key_length=options.get('key_length', 250)) + region_prefix = '%s.' % region + region_len = len(region_prefix) for key in options.keys(): - if key.startswith('%s.' % region): + if key.startswith(region_prefix): region_options[key[region_len:]] = options.pop(key) coerce_cache_params(region_options) region_configs[region] = region_options options['cache_regions'] = region_configs return options + +def parse_memcached_behaviors(config): + """Parse behavior options and validate for use with pylibmc + client/PylibMCNamespaceManager, or potentially other memcached + NamespaceManagers that support behaviors""" + behaviors = {} + + for key, val in config.iteritems(): + if key.startswith('behavior.'): + behaviors[key[9:]] = val + + coerce_memcached_behaviors(behaviors) + return behaviors + + def func_namespace(func): """Generates a unique namespace for a function""" kls = None if hasattr(func, 'im_func'): kls = func.im_class func = func.im_func - + if kls: return '%s.%s' % (kls.__module__, kls.__name__) else: - return '%s.%s' % (func.__module__, func.__name__) + return '%s|%s' % (inspect.getsourcefile(func), func.__name__) diff --git a/module/lib/bottle.py b/module/lib/bottle.py index 2c243278e..bcc284c1d 100644 --- a/module/lib/bottle.py +++ b/module/lib/bottle.py @@ -9,14 +9,14 @@ Python Standard Library. Homepage and documentation: http://bottlepy.org/ -Copyright (c) 2011, Marcel Hellkamp. -License: MIT (see LICENSE.txt for details) +Copyright (c) 2013, Marcel Hellkamp. +License: MIT (see LICENSE for details) """ from __future__ import with_statement __author__ = 'Marcel Hellkamp' -__version__ = '0.10.2' +__version__ = '0.12.7' __license__ = 'MIT' # The gevent server adapter needs to patch some modules before they are imported @@ -35,102 +35,116 @@ if __name__ == '__main__': if _cmd_options.server and _cmd_options.server.startswith('gevent'): import gevent.monkey; gevent.monkey.patch_all() -import sys -import base64 -import cgi -import email.utils -import functools -import hmac -import httplib -import imp -import itertools -import mimetypes -import os -import re -import subprocess -import tempfile -import thread -import threading -import time -import warnings - -from Cookie import SimpleCookie +import base64, cgi, email.utils, functools, hmac, imp, itertools, mimetypes,\ + os, re, subprocess, sys, tempfile, threading, time, warnings + from datetime import date as datedate, datetime, timedelta from tempfile import TemporaryFile from traceback import format_exc, print_exc -from urlparse import urljoin, SplitResult as UrlSplitResult - -# Workaround for a bug in some versions of lib2to3 (fixed on CPython 2.7 and 3.2) -import urllib -urlencode = urllib.urlencode -urlquote = urllib.quote -urlunquote = urllib.unquote - -try: from collections import MutableMapping as DictMixin -except ImportError: # pragma: no cover - from UserDict import DictMixin +from inspect import getargspec +from unicodedata import normalize -try: from urlparse import parse_qs -except ImportError: # pragma: no cover - from cgi import parse_qs -try: import cPickle as pickle +try: from simplejson import dumps as json_dumps, loads as json_lds except ImportError: # pragma: no cover - import pickle - -try: from json import dumps as json_dumps, loads as json_lds -except ImportError: # pragma: no cover - try: from simplejson import dumps as json_dumps, loads as json_lds - except ImportError: # pragma: no cover + try: from json import dumps as json_dumps, loads as json_lds + except ImportError: try: from django.utils.simplejson import dumps as json_dumps, loads as json_lds - except ImportError: # pragma: no cover + except ImportError: def json_dumps(data): raise ImportError("JSON support requires Python 2.6 or simplejson.") json_lds = json_dumps -py3k = sys.version_info >= (3,0,0) -NCTextIOWrapper = None -if py3k: # pragma: no cover - json_loads = lambda s: json_lds(touni(s)) - # See Request.POST + +# We now try to fix 2.5/2.6/3.1/3.2 incompatibilities. +# It ain't pretty but it works... Sorry for the mess. + +py = sys.version_info +py3k = py >= (3, 0, 0) +py25 = py < (2, 6, 0) +py31 = (3, 1, 0) <= py < (3, 2, 0) + +# Workaround for the missing "as" keyword in py3k. +def _e(): return sys.exc_info()[1] + +# Workaround for the "print is a keyword/function" Python 2/3 dilemma +# and a fallback for mod_wsgi (resticts stdout/err attribute access) +try: + _stdout, _stderr = sys.stdout.write, sys.stderr.write +except IOError: + _stdout = lambda x: sys.stdout.write(x) + _stderr = lambda x: sys.stderr.write(x) + +# Lots of stdlib and builtin differences. +if py3k: + import http.client as httplib + import _thread as thread + from urllib.parse import urljoin, SplitResult as UrlSplitResult + from urllib.parse import urlencode, quote as urlquote, unquote as urlunquote + urlunquote = functools.partial(urlunquote, encoding='latin1') + from http.cookies import SimpleCookie + from collections import MutableMapping as DictMixin + import pickle from io import BytesIO - def touni(x, enc='utf8', err='strict'): - """ Convert anything to unicode """ - return str(x, enc, err) if isinstance(x, bytes) else str(x) - if sys.version_info < (3,2,0): - from io import TextIOWrapper - class NCTextIOWrapper(TextIOWrapper): - ''' Garbage collecting an io.TextIOWrapper(buffer) instance closes - the wrapped buffer. This subclass keeps it open. ''' - def close(self): pass -else: - json_loads = json_lds + from configparser import ConfigParser + basestring = str + unicode = str + json_loads = lambda s: json_lds(touni(s)) + callable = lambda x: hasattr(x, '__call__') + imap = map + def _raise(*a): raise a[0](a[1]).with_traceback(a[2]) +else: # 2.x + import httplib + import thread + from urlparse import urljoin, SplitResult as UrlSplitResult + from urllib import urlencode, quote as urlquote, unquote as urlunquote + from Cookie import SimpleCookie + from itertools import imap + import cPickle as pickle from StringIO import StringIO as BytesIO - bytes = str - def touni(x, enc='utf8', err='strict'): - """ Convert anything to unicode """ - return x if isinstance(x, unicode) else unicode(str(x), enc, err) - -def tob(data, enc='utf8'): - """ Convert anything to bytes """ - return data.encode(enc) if isinstance(data, unicode) else bytes(data) + from ConfigParser import SafeConfigParser as ConfigParser + if py25: + msg = "Python 2.5 support may be dropped in future versions of Bottle." + warnings.warn(msg, DeprecationWarning) + from UserDict import DictMixin + def next(it): return it.next() + bytes = str + else: # 2.6, 2.7 + from collections import MutableMapping as DictMixin + unicode = unicode + json_loads = json_lds + eval(compile('def _raise(*a): raise a[0], a[1], a[2]', '<py3fix>', 'exec')) +# Some helpers for string/byte handling +def tob(s, enc='utf8'): + return s.encode(enc) if isinstance(s, unicode) else bytes(s) +def touni(s, enc='utf8', err='strict'): + return s.decode(enc, err) if isinstance(s, bytes) else unicode(s) tonat = touni if py3k else tob -tonat.__doc__ = """ Convert anything to native strings """ -def try_update_wrapper(wrapper, wrapped, *a, **ka): - try: # Bug: functools breaks if wrapper is an instane method - functools.update_wrapper(wrapper, wrapped, *a, **ka) +# 3.2 fixes cgi.FieldStorage to accept bytes (which makes a lot of sense). +# 3.1 needs a workaround. +if py31: + from io import TextIOWrapper + class NCTextIOWrapper(TextIOWrapper): + def close(self): pass # Keep wrapped buffer open. + + +# A bug in functools causes it to break if the wrapper is an instance method +def update_wrapper(wrapper, wrapped, *a, **ka): + try: functools.update_wrapper(wrapper, wrapped, *a, **ka) except AttributeError: pass -# Backward compatibility -def depr(message): - warnings.warn(message, DeprecationWarning, stacklevel=3) -# Small helpers -def makelist(data): +# These helpers are used at module level and need to be defined first. +# And yes, I know PEP-8, but sometimes a lower-case classname makes more sense. + +def depr(message, hard=False): + warnings.warn(message, DeprecationWarning, stacklevel=3) + +def makelist(data): # This is just to handy if isinstance(data, (tuple, list, set, dict)): return list(data) elif data: return [data] else: return [] @@ -161,12 +175,13 @@ class DictProperty(object): del getattr(obj, self.attr)[self.key] -class CachedProperty(object): +class cached_property(object): ''' A property that is only computed once per instance and then replaces itself with an ordinary attribute. Deleting the attribute resets the property. ''' def __init__(self, func): + self.__doc__ = getattr(func, '__doc__') self.func = func def __get__(self, obj, cls): @@ -174,10 +189,8 @@ class CachedProperty(object): value = obj.__dict__[self.func.__name__] = self.func(obj) return value -cached_property = CachedProperty - -class lazy_attribute(object): # Does not need configuration -> lower-case name +class lazy_attribute(object): ''' A property that caches itself to the class object. ''' def __init__(self, func): functools.update_wrapper(self, func, updated=[]) @@ -203,35 +216,6 @@ class BottleException(Exception): pass -#TODO: These should subclass BaseRequest - -class HTTPResponse(BottleException): - """ Used to break execution and immediately finish the response """ - def __init__(self, output='', status=200, header=None): - super(BottleException, self).__init__("HTTP Response %d" % status) - self.status = int(status) - self.output = output - self.headers = HeaderDict(header) if header else None - - def apply(self, response): - if self.headers: - for key, value in self.headers.iterallitems(): - response.headers[key] = value - response.status = self.status - - -class HTTPError(HTTPResponse): - """ Used to generate an error page """ - def __init__(self, code=500, output='Unknown Error', exception=None, - traceback=None, header=None): - super(HTTPError, self).__init__(output, code, header) - self.exception = exception - self.traceback = traceback - - def __repr__(self): - return template(ERROR_PAGE_TEMPLATE, e=self) - - @@ -251,11 +235,22 @@ class RouteReset(BottleException): class RouterUnknownModeError(RouteError): pass + class RouteSyntaxError(RouteError): - """ The route parser found something not supported by this router """ + """ The route parser found something not supported by this router. """ + class RouteBuildError(RouteError): - """ The route could not been built """ + """ The route could not be built. """ + + +def _re_flatten(p): + ''' Turn all capturing groups in a regular expression pattern into + non-capturing groups. ''' + if '(' not in p: return p + return re.sub(r'(\\*)(\(\?P<[^>]+>|\((?!\?))', + lambda m: m.group(0) if len(m.group(1)) % 2 else m.group(1) + '(?:', p) + class Router(object): ''' A Router is an ordered collection of route->target pairs. It is used to @@ -270,44 +265,40 @@ class Router(object): ''' default_pattern = '[^/]+' - default_filter = 're' - #: Sorry for the mess. It works. Trust me. - rule_syntax = re.compile('(\\\\*)'\ - '(?:(?::([a-zA-Z_][a-zA-Z_0-9]*)?()(?:#(.*?)#)?)'\ - '|(?:<([a-zA-Z_][a-zA-Z_0-9]*)?(?::([a-zA-Z_]*)'\ - '(?::((?:\\\\.|[^\\\\>]+)+)?)?)?>))') + default_filter = 're' + + #: The current CPython regexp implementation does not allow more + #: than 99 matching groups per regular expression. + _MAX_GROUPS_PER_PATTERN = 99 def __init__(self, strict=False): - self.rules = {} # A {rule: Rule} mapping - self.builder = {} # A rule/name->build_info mapping - self.static = {} # Cache for static routes: {path: {method: target}} - self.dynamic = [] # Cache for dynamic routes. See _compile() + self.rules = [] # All rules in order + self._groups = {} # index of regexes to find them in dyna_routes + self.builder = {} # Data structure for the url builder + self.static = {} # Search structure for static routes + self.dyna_routes = {} + self.dyna_regexes = {} # Search structure for dynamic routes #: If true, static routes are no longer checked first. self.strict_order = strict - self.filters = {'re': self.re_filter, 'int': self.int_filter, - 'float': self.re_filter, 'path': self.path_filter} - - def re_filter(self, conf): - return conf or self.default_pattern, None, None + self.filters = { + 're': lambda conf: + (_re_flatten(conf or self.default_pattern), None, None), + 'int': lambda conf: (r'-?\d+', int, lambda x: str(int(x))), + 'float': lambda conf: (r'-?[\d.]+', float, lambda x: str(float(x))), + 'path': lambda conf: (r'.+?', None, None)} - def int_filter(self, conf): - return r'-?\d+', int, lambda x: str(int(x)) - - def float_filter(self, conf): - return r'-?\d*\.\d+', float, lambda x: str(float(x)) - - def path_filter(self, conf): - return r'.*?', None, None - def add_filter(self, name, func): ''' Add a filter. The provided function is called with the configuration string as parameter and must return a (regexp, to_python, to_url) tuple. The first element is a string, the last two are callables or None. ''' self.filters[name] = func - - def parse_rule(self, rule): - ''' Parses a rule into a (name, filter, conf) token stream. If mode is - None, name contains a static rule part. ''' + + rule_syntax = re.compile('(\\\\*)'\ + '(?:(?::([a-zA-Z_][a-zA-Z_0-9]*)?()(?:#(.*?)#)?)'\ + '|(?:<([a-zA-Z_][a-zA-Z_0-9]*)?(?::([a-zA-Z_]*)'\ + '(?::((?:\\\\.|[^\\\\>]+)+)?)?)?>))') + + def _itertokens(self, rule): offset, prefix = 0, '' for match in self.rule_syntax.finditer(rule): prefix += rule[offset:match.start()] @@ -316,77 +307,95 @@ class Router(object): prefix += match.group(0)[len(g[0]):] offset = match.end() continue - if prefix: yield prefix, None, None - name, filtr, conf = g[1:4] if not g[2] is None else g[4:7] - if not filtr: filtr = self.default_filter - yield name, filtr, conf or None + if prefix: + yield prefix, None, None + name, filtr, conf = g[4:7] if g[2] is None else g[1:4] + yield name, filtr or 'default', conf or None offset, prefix = match.end(), '' if offset <= len(rule) or prefix: yield prefix+rule[offset:], None, None def add(self, rule, method, target, name=None): - ''' Add a new route or replace the target for an existing route. ''' - if rule in self.rules: - self.rules[rule][method] = target - if name: self.builder[name] = self.builder[rule] - return - - target = self.rules[rule] = {method: target} - - # Build pattern and other structures for dynamic routes - anons = 0 # Number of anonymous wildcards - pattern = '' # Regular expression pattern - filters = [] # Lists of wildcard input filters - builder = [] # Data structure for the URL builder + ''' Add a new rule or replace the target for an existing rule. ''' + anons = 0 # Number of anonymous wildcards found + keys = [] # Names of keys + pattern = '' # Regular expression pattern with named groups + filters = [] # Lists of wildcard input filters + builder = [] # Data structure for the URL builder is_static = True - for key, mode, conf in self.parse_rule(rule): + + for key, mode, conf in self._itertokens(rule): if mode: is_static = False + if mode == 'default': mode = self.default_filter mask, in_filter, out_filter = self.filters[mode](conf) - if key: - pattern += '(?P<%s>%s)' % (key, mask) - else: + if not key: pattern += '(?:%s)' % mask - key = 'anon%d' % anons; anons += 1 + key = 'anon%d' % anons + anons += 1 + else: + pattern += '(?P<%s>%s)' % (key, mask) + keys.append(key) if in_filter: filters.append((key, in_filter)) builder.append((key, out_filter or str)) elif key: pattern += re.escape(key) builder.append((None, key)) + self.builder[rule] = builder if name: self.builder[name] = builder if is_static and not self.strict_order: - self.static[self.build(rule)] = target + self.static.setdefault(method, {}) + self.static[method][self.build(rule)] = (target, None) return - def fpat_sub(m): - return m.group(0) if len(m.group(1)) % 2 else m.group(1) + '(?:' - flat_pattern = re.sub(r'(\\*)(\(\?P<[^>]*>|\((?!\?))', fpat_sub, pattern) - try: - re_match = re.compile('^(%s)$' % pattern).match - except re.error, e: - raise RouteSyntaxError("Could not add Route: %s (%s)" % (rule, e)) - - def match(path): - """ Return an url-argument dictionary. """ - url_args = re_match(path).groupdict() - for name, wildcard_filter in filters: - try: - url_args[name] = wildcard_filter(url_args[name]) - except ValueError: - raise HTTPError(400, 'Path has wrong format.') - return url_args + re_pattern = re.compile('^(%s)$' % pattern) + re_match = re_pattern.match + except re.error: + raise RouteSyntaxError("Could not add Route: %s (%s)" % (rule, _e())) + + if filters: + def getargs(path): + url_args = re_match(path).groupdict() + for name, wildcard_filter in filters: + try: + url_args[name] = wildcard_filter(url_args[name]) + except ValueError: + raise HTTPError(400, 'Path has wrong format.') + return url_args + elif re_pattern.groupindex: + def getargs(path): + return re_match(path).groupdict() + else: + getargs = None - try: - combined = '%s|(^%s$)' % (self.dynamic[-1][0].pattern, flat_pattern) - self.dynamic[-1] = (re.compile(combined), self.dynamic[-1][1]) - self.dynamic[-1][1].append((match, target)) - except (AssertionError, IndexError), e: # AssertionError: Too many groups - self.dynamic.append((re.compile('(^%s$)' % flat_pattern), - [(match, target)])) - return match + flatpat = _re_flatten(pattern) + whole_rule = (rule, flatpat, target, getargs) + + if (flatpat, method) in self._groups: + if DEBUG: + msg = 'Route <%s %s> overwrites a previously defined route' + warnings.warn(msg % (method, rule), RuntimeWarning) + self.dyna_routes[method][self._groups[flatpat, method]] = whole_rule + else: + self.dyna_routes.setdefault(method, []).append(whole_rule) + self._groups[flatpat, method] = len(self.dyna_routes[method]) - 1 + + self._compile(method) + + def _compile(self, method): + all_rules = self.dyna_routes[method] + comborules = self.dyna_regexes[method] = [] + maxgroups = self._MAX_GROUPS_PER_PATTERN + for x in range(0, len(all_rules), maxgroups): + some = all_rules[x:x+maxgroups] + combined = (flatpat for (_, flatpat, _, _) in some) + combined = '|'.join('(^%s$)' % flatpat for flatpat in combined) + combined = re.compile(combined).match + rules = [(target, getargs) for (_, _, target, getargs) in some] + comborules.append((combined, rules)) def build(self, _name, *anons, **query): ''' Build an URL by filling the wildcards in a rule. ''' @@ -396,36 +405,50 @@ class Router(object): for i, value in enumerate(anons): query['anon%d'%i] = value url = ''.join([f(query.pop(n)) if n else f for (n,f) in builder]) return url if not query else url+'?'+urlencode(query) - except KeyError, e: - raise RouteBuildError('Missing URL argument: %r' % e.args[0]) + except KeyError: + raise RouteBuildError('Missing URL argument: %r' % _e().args[0]) def match(self, environ): ''' Return a (target, url_agrs) tuple or raise HTTPError(400/404/405). ''' - path, targets, urlargs = environ['PATH_INFO'] or '/', None, {} - if path in self.static: - targets = self.static[path] + verb = environ['REQUEST_METHOD'].upper() + path = environ['PATH_INFO'] or '/' + target = None + if verb == 'HEAD': + methods = ['PROXY', verb, 'GET', 'ANY'] else: - for combined, rules in self.dynamic: - match = combined.match(path) - if not match: continue - getargs, targets = rules[match.lastindex - 1] - urlargs = getargs(path) if getargs else {} - break - - if not targets: - raise HTTPError(404, "Not found: " + repr(environ['PATH_INFO'])) - method = environ['REQUEST_METHOD'].upper() - if method in targets: - return targets[method], urlargs - if method == 'HEAD' and 'GET' in targets: - return targets['GET'], urlargs - if 'ANY' in targets: - return targets['ANY'], urlargs - allowed = [verb for verb in targets if verb != 'ANY'] - if 'GET' in allowed and 'HEAD' not in allowed: - allowed.append('HEAD') - raise HTTPError(405, "Method not allowed.", - header=[('Allow',",".join(allowed))]) + methods = ['PROXY', verb, 'ANY'] + + for method in methods: + if method in self.static and path in self.static[method]: + target, getargs = self.static[method][path] + return target, getargs(path) if getargs else {} + elif method in self.dyna_regexes: + for combined, rules in self.dyna_regexes[method]: + match = combined(path) + if match: + target, getargs = rules[match.lastindex - 1] + return target, getargs(path) if getargs else {} + + # No matching route found. Collect alternative methods for 405 response + allowed = set([]) + nocheck = set(methods) + for method in set(self.static) - nocheck: + if path in self.static[method]: + allowed.add(verb) + for method in set(self.dyna_regexes) - allowed - nocheck: + for combined, rules in self.dyna_regexes[method]: + match = combined(path) + if match: + allowed.add(method) + if allowed: + allow_header = ",".join(sorted(allowed)) + raise HTTPError(405, "Method not allowed.", Allow=allow_header) + + # No matching route and no alternative method found. We give up + raise HTTPError(404, "Not found: " + repr(path)) + + + @@ -435,7 +458,6 @@ class Route(object): turing an URL path rule into a regular expression usable by the Router. ''' - def __init__(self, app, rule, method, callback, name=None, plugins=None, skiplist=None, **config): #: The application this route is installed to. @@ -455,12 +477,12 @@ class Route(object): #: Additional keyword arguments passed to the :meth:`Bottle.route` #: decorator are stored in this dictionary. Used for route-specific #: plugin configuration and meta-data. - self.config = ConfigDict(config) + self.config = ConfigDict().load_dict(config, make_namespaces=True) def __call__(self, *a, **ka): depr("Some APIs changed to return Route() instances instead of"\ " callables. Make sure to use the Route.call method and not to"\ - " call Route instances directly.") + " call Route instances directly.") #0.12 return self.call(*a, **ka) @cached_property @@ -480,7 +502,7 @@ class Route(object): @property def _context(self): - depr('Switch to Plugin API v2 and access the Route object directly.') + depr('Switch to Plugin API v2 and access the Route object directly.') #0.12 return dict(rule=self.rule, method=self.method, callback=self.callback, name=self.name, app=self.app, config=self.config, apply=self.plugins, skip=self.skiplist) @@ -509,9 +531,36 @@ class Route(object): except RouteReset: # Try again with changed configuration. return self._make_callback() if not callback is self.callback: - try_update_wrapper(callback, self.callback) + update_wrapper(callback, self.callback) return callback + def get_undecorated_callback(self): + ''' Return the callback. If the callback is a decorated function, try to + recover the original function. ''' + func = self.callback + func = getattr(func, '__func__' if py3k else 'im_func', func) + closure_attr = '__closure__' if py3k else 'func_closure' + while hasattr(func, closure_attr) and getattr(func, closure_attr): + func = getattr(func, closure_attr)[0].cell_contents + return func + + def get_callback_args(self): + ''' Return a list of argument names the callback (most likely) accepts + as keyword arguments. If the callback is a decorated function, try + to recover the original function before inspection. ''' + return getargspec(self.get_undecorated_callback())[0] + + def get_config(self, key, default=None): + ''' Lookup a config field and return its value, first checking the + route.config, then route.app.config.''' + for conf in (self.config, self.app.conifg): + if key in conf: return conf[key] + return default + + def __repr__(self): + cb = self.get_undecorated_callback() + return '<%s %r %r>' % (self.method, self.rule, cb) + @@ -523,27 +572,81 @@ class Route(object): class Bottle(object): - """ WSGI application """ + """ Each Bottle object represents a single, distinct web application and + consists of routes, callbacks, plugins, resources and configuration. + Instances are callable WSGI applications. + + :param catchall: If true (default), handle all exceptions. Turn off to + let debugging middleware handle exceptions. + """ + + def __init__(self, catchall=True, autojson=True): + + #: A :class:`ConfigDict` for app specific configuration. + self.config = ConfigDict() + self.config._on_change = functools.partial(self.trigger_hook, 'config') + self.config.meta_set('autojson', 'validate', bool) + self.config.meta_set('catchall', 'validate', bool) + self.config['catchall'] = catchall + self.config['autojson'] = autojson + + #: A :class:`ResourceManager` for application files + self.resources = ResourceManager() - def __init__(self, catchall=True, autojson=True, config=None): - """ Create a new bottle instance. - You usually don't do that. Use `bottle.app.push()` instead. - """ self.routes = [] # List of installed :class:`Route` instances. self.router = Router() # Maps requests to :class:`Route` instances. - self.plugins = [] # List of installed plugins. - self.error_handler = {} - #: If true, most exceptions are catched and returned as :exc:`HTTPError` - self.config = ConfigDict(config or {}) - self.catchall = catchall - #: An instance of :class:`HooksPlugin`. Empty by default. - self.hooks = HooksPlugin() - self.install(self.hooks) - if autojson: + + # Core plugins + self.plugins = [] # List of installed plugins. + if self.config['autojson']: self.install(JSONPlugin()) self.install(TemplatePlugin()) + #: If true, most exceptions are caught and returned as :exc:`HTTPError` + catchall = DictProperty('config', 'catchall') + + __hook_names = 'before_request', 'after_request', 'app_reset', 'config' + __hook_reversed = 'after_request' + + @cached_property + def _hooks(self): + return dict((name, []) for name in self.__hook_names) + + def add_hook(self, name, func): + ''' Attach a callback to a hook. Three hooks are currently implemented: + + before_request + Executed once before each request. The request context is + available, but no routing has happened yet. + after_request + Executed once after each request regardless of its outcome. + app_reset + Called whenever :meth:`Bottle.reset` is called. + ''' + if name in self.__hook_reversed: + self._hooks[name].insert(0, func) + else: + self._hooks[name].append(func) + + def remove_hook(self, name, func): + ''' Remove a callback from a hook. ''' + if name in self._hooks and func in self._hooks[name]: + self._hooks[name].remove(func) + return True + + def trigger_hook(self, __name, *args, **kwargs): + ''' Trigger a hook and return a list of results. ''' + return [hook(*args, **kwargs) for hook in self._hooks[__name][:]] + + def hook(self, name): + """ Return a decorator that attaches a callback to a hook. See + :meth:`add_hook` for details.""" + def decorator(func): + self.add_hook(name, func) + return func + return decorator + def mount(self, prefix, app, **options): ''' Mount an application (:class:`Bottle` or plain WSGI) to a specific URL prefix. Example:: @@ -557,31 +660,50 @@ class Bottle(object): All other parameters are passed to the underlying :meth:`route` call. ''' if isinstance(app, basestring): - prefix, app = app, prefix - depr('Parameter order of Bottle.mount() changed.') # 0.10 + depr('Parameter order of Bottle.mount() changed.', True) # 0.10 - parts = filter(None, prefix.split('/')) - if not parts: raise ValueError('Empty path prefix.') - path_depth = len(parts) - options.setdefault('skip', True) - options.setdefault('method', 'ANY') + segments = [p for p in prefix.split('/') if p] + if not segments: raise ValueError('Empty path prefix.') + path_depth = len(segments) - @self.route('/%s/:#.*#' % '/'.join(parts), **options) - def mountpoint(): + def mountpoint_wrapper(): try: request.path_shift(path_depth) - rs = BaseResponse([], 200) - def start_response(status, header): + rs = HTTPResponse([]) + def start_response(status, headerlist, exc_info=None): + if exc_info: + try: + _raise(*exc_info) + finally: + exc_info = None rs.status = status - for name, value in header: rs.add_header(name, value) + for name, value in headerlist: rs.add_header(name, value) return rs.body.append - rs.body = itertools.chain(rs.body, app(request.environ, start_response)) - return HTTPResponse(rs.body, rs.status, rs.headers) + body = app(request.environ, start_response) + if body and rs.body: body = itertools.chain(rs.body, body) + rs.body = body or rs.body + return rs finally: request.path_shift(-path_depth) + options.setdefault('skip', True) + options.setdefault('method', 'PROXY') + options.setdefault('mountpoint', {'prefix': prefix, 'target': app}) + options['callback'] = mountpoint_wrapper + + self.route('/%s/<:re:.*>' % '/'.join(segments), **options) if not prefix.endswith('/'): - self.route('/' + '/'.join(parts), callback=mountpoint, **options) + self.route('/' + '/'.join(segments), **options) + + def merge(self, routes): + ''' Merge the routes of another :class:`Bottle` application or a list of + :class:`Route` objects into this application. The routes keep their + 'owner', meaning that the :data:`Route.app` attribute is not + changed. ''' + if isinstance(routes, Bottle): + routes = routes.routes + for route in routes: + self.add_route(route) def install(self, plugin): ''' Add a plugin to the list of plugins and prepare it for being @@ -620,7 +742,7 @@ class Bottle(object): for route in routes: route.reset() if DEBUG: for route in routes: route.prepare() - self.hooks.trigger('app_reset') + self.trigger_hook('app_reset') def close(self): ''' Close the application and all installed plugins. ''' @@ -628,6 +750,10 @@ class Bottle(object): if hasattr(plugin, 'close'): plugin.close() self.stopped = True + def run(self, **kwargs): + ''' Calls :func:`run` with the same parameters. ''' + run(self, **kwargs) + def match(self, environ): """ Search for a matching route and return a (:class:`Route` , urlargs) tuple. The second value is a dictionary with parameters extracted @@ -640,6 +766,13 @@ class Bottle(object): location = self.router.build(routename, **kargs).lstrip('/') return urljoin(urljoin('/', scriptname), location) + def add_route(self, route): + ''' Add a route object, but do not change the :data:`Route.app` + attribute.''' + self.routes.append(route) + self.router.add(route.rule, route.method, route, name=route.name) + if DEBUG: route.prepare() + def route(self, path=None, method='GET', callback=None, name=None, apply=None, skip=None, **config): """ A decorator to bind a function to a request URL. Example:: @@ -678,9 +811,7 @@ class Bottle(object): verb = verb.upper() route = Route(self, rule, verb, callback, name=name, plugins=plugins, skiplist=skiplist, **config) - self.routes.append(route) - self.router.add(rule, verb, route, name=name) - if DEBUG: route.prepare() + self.add_route(route) return callback return decorator(callback) if callback else decorator @@ -707,44 +838,45 @@ class Bottle(object): return handler return wrapper - def hook(self, name): - """ Return a decorator that attaches a callback to a hook. """ - def wrapper(func): - self.hooks.add(name, func) - return func - return wrapper - - def handle(self, path, method='GET'): - """ (deprecated) Execute the first matching route callback and return - the result. :exc:`HTTPResponse` exceptions are catched and returned. - If :attr:`Bottle.catchall` is true, other exceptions are catched as - well and returned as :exc:`HTTPError` instances (500). - """ - depr("This method will change semantics in 0.10. Try to avoid it.") - if isinstance(path, dict): - return self._handle(path) - return self._handle({'PATH_INFO': path, 'REQUEST_METHOD': method.upper()}) + def default_error_handler(self, res): + return tob(template(ERROR_PAGE_TEMPLATE, e=res)) def _handle(self, environ): + path = environ['bottle.raw_path'] = environ['PATH_INFO'] + if py3k: + try: + environ['PATH_INFO'] = path.encode('latin1').decode('utf8') + except UnicodeError: + return HTTPError(400, 'Invalid path string. Expected UTF-8') + try: - route, args = self.router.match(environ) - environ['route.handle'] = environ['bottle.route'] = route - environ['route.url_args'] = args - return route.call(**args) - except HTTPResponse, r: - return r + environ['bottle.app'] = self + request.bind(environ) + response.bind() + try: + self.trigger_hook('before_request') + route, args = self.router.match(environ) + environ['route.handle'] = route + environ['bottle.route'] = route + environ['route.url_args'] = args + return route.call(**args) + finally: + self.trigger_hook('after_request') + + except HTTPResponse: + return _e() except RouteReset: route.reset() return self._handle(environ) except (KeyboardInterrupt, SystemExit, MemoryError): raise - except Exception, e: + except Exception: if not self.catchall: raise - stacktrace = format_exc(10) + stacktrace = format_exc() environ['wsgi.errors'].write(stacktrace) - return HTTPError(500, "Internal Server Error", e, stacktrace) + return HTTPError(500, "Internal Server Error", _e(), stacktrace) - def _cast(self, out, request, response, peek=None): + def _cast(self, out, peek=None): """ Try to convert the parameter into something WSGI compatible and set correct HTTP headers when possible. Support: False, str, unicode, dict, HTTPResponse, HTTPError, file-like, @@ -753,7 +885,8 @@ class Bottle(object): # Empty output is done here if not out: - response['Content-Length'] = 0 + if 'Content-Length' not in response: + response['Content-Length'] = 0 return [] # Join lists of byte or unicode strings. Mixed lists are NOT supported if isinstance(out, (tuple, list))\ @@ -764,19 +897,18 @@ class Bottle(object): out = out.encode(response.charset) # Byte Strings are just returned if isinstance(out, bytes): - response['Content-Length'] = len(out) + if 'Content-Length' not in response: + response['Content-Length'] = len(out) return [out] # HTTPError or HTTPException (recursive, because they may wrap anything) # TODO: Handle these explicitly in handle() or make them iterable. if isinstance(out, HTTPError): out.apply(response) - out = self.error_handler.get(out.status, repr)(out) - if isinstance(out, HTTPResponse): - depr('Error handlers must not return :exc:`HTTPResponse`.') #0.9 - return self._cast(out, request, response) + out = self.error_handler.get(out.status_code, self.default_error_handler)(out) + return self._cast(out) if isinstance(out, HTTPResponse): out.apply(response) - return self._cast(out.output, request, response) + return self._cast(out.body) # File-like objects. if hasattr(out, 'read'): @@ -787,55 +919,59 @@ class Bottle(object): # Handle Iterables. We peek into them to detect their inner type. try: - out = iter(out) - first = out.next() + iout = iter(out) + first = next(iout) while not first: - first = out.next() + first = next(iout) except StopIteration: - return self._cast('', request, response) - except HTTPResponse, e: - first = e - except Exception, e: - first = HTTPError(500, 'Unhandled exception', e, format_exc(10)) - if isinstance(e, (KeyboardInterrupt, SystemExit, MemoryError))\ - or not self.catchall: - raise + return self._cast('') + except HTTPResponse: + first = _e() + except (KeyboardInterrupt, SystemExit, MemoryError): + raise + except Exception: + if not self.catchall: raise + first = HTTPError(500, 'Unhandled exception', _e(), format_exc()) + # These are the inner types allowed in iterator or generator objects. if isinstance(first, HTTPResponse): - return self._cast(first, request, response) - if isinstance(first, bytes): - return itertools.chain([first], out) - if isinstance(first, unicode): - return itertools.imap(lambda x: x.encode(response.charset), - itertools.chain([first], out)) - return self._cast(HTTPError(500, 'Unsupported response type: %s'\ - % type(first)), request, response) + return self._cast(first) + elif isinstance(first, bytes): + new_iter = itertools.chain([first], iout) + elif isinstance(first, unicode): + encoder = lambda x: x.encode(response.charset) + new_iter = imap(encoder, itertools.chain([first], iout)) + else: + msg = 'Unsupported response type: %s' % type(first) + return self._cast(HTTPError(500, msg)) + if hasattr(out, 'close'): + new_iter = _closeiter(new_iter, out.close) + return new_iter def wsgi(self, environ, start_response): """ The bottle WSGI-interface. """ try: - environ['bottle.app'] = self - request.bind(environ) - response.bind() - out = self._cast(self._handle(environ), request, response) + out = self._cast(self._handle(environ)) # rfc2616 section 4.3 if response._status_code in (100, 101, 204, 304)\ - or request.method == 'HEAD': + or environ['REQUEST_METHOD'] == 'HEAD': if hasattr(out, 'close'): out.close() out = [] - start_response(response._status_line, list(response.iter_headers())) + start_response(response._status_line, response.headerlist) return out except (KeyboardInterrupt, SystemExit, MemoryError): raise - except Exception, e: + except Exception: if not self.catchall: raise err = '<h1>Critical error while processing request: %s</h1>' \ - % environ.get('PATH_INFO', '/') + % html_escape(environ.get('PATH_INFO', '/')) if DEBUG: - err += '<h2>Error:</h2>\n<pre>%s</pre>\n' % repr(e) - err += '<h2>Traceback:</h2>\n<pre>%s</pre>\n' % format_exc(10) - environ['wsgi.errors'].write(err) #TODO: wsgi.error should not get html - start_response('500 INTERNAL SERVER ERROR', [('Content-Type', 'text/html')]) + err += '<h2>Error:</h2>\n<pre>\n%s\n</pre>\n' \ + '<h2>Traceback:</h2>\n<pre>\n%s\n</pre>\n' \ + % (html_escape(repr(_e())), html_escape(format_exc())) + environ['wsgi.errors'].write(err) + headers = [('Content-Type', 'text/html; charset=UTF-8')] + start_response('500 INTERNAL SERVER ERROR', headers, sys.exc_info()) return [tob(err)] def __call__(self, environ, start_response): @@ -851,20 +987,41 @@ class Bottle(object): # HTTP and WSGI Tools ########################################################## ############################################################################### - -class BaseRequest(DictMixin): +class BaseRequest(object): """ A wrapper for WSGI environment dictionaries that adds a lot of - convenient access methods and properties. Most of them are read-only.""" + convenient access methods and properties. Most of them are read-only. + + Adding new attributes to a request actually adds them to the environ + dictionary (as 'bottle.request.ext.<name>'). This is the recommended + way to store and access request-specific data. + """ + + __slots__ = ('environ') #: Maximum size of memory buffer for :attr:`body` in bytes. MEMFILE_MAX = 102400 - def __init__(self, environ): + def __init__(self, environ=None): """ Wrap a WSGI environ dictionary. """ #: The wrapped WSGI environ dictionary. This is the only real attribute. #: All other attributes actually are read-only properties. - self.environ = environ - environ['bottle.request'] = self + self.environ = {} if environ is None else environ + self.environ['bottle.request'] = self + + @DictProperty('environ', 'bottle.app', read_only=True) + def app(self): + ''' Bottle application handling this request. ''' + raise RuntimeError('This request is not connected to an application.') + + @DictProperty('environ', 'bottle.route', read_only=True) + def route(self): + """ The bottle :class:`Route` object that matches this request. """ + raise RuntimeError('This request is not connected to a route.') + + @DictProperty('environ', 'route.url_args', read_only=True) + def url_args(self): + """ The arguments extracted from the URL. """ + raise RuntimeError('This request is not connected to a route.') @property def path(self): @@ -891,8 +1048,8 @@ class BaseRequest(DictMixin): def cookies(self): """ Cookies parsed into a :class:`FormsDict`. Signed cookies are NOT decoded. Use :meth:`get_cookie` if you expect signed cookies. """ - cookies = SimpleCookie(self.environ.get('HTTP_COOKIE','')) - return FormsDict((c.key, c.value) for c in cookies.itervalues()) + cookies = SimpleCookie(self.environ.get('HTTP_COOKIE','')).values() + return FormsDict((c.key, c.value) for c in cookies) def get_cookie(self, key, default=None, secret=None): """ Return the content of a cookie. To read a `Signed Cookie`, the @@ -911,22 +1068,21 @@ class BaseRequest(DictMixin): values are sometimes called "URL arguments" or "GET parameters", but not to be confused with "URL wildcards" as they are provided by the :class:`Router`. ''' - data = parse_qs(self.query_string, keep_blank_values=True) get = self.environ['bottle.get'] = FormsDict() - for key, values in data.iteritems(): - for value in values: - get[key] = value + pairs = _parse_qsl(self.environ.get('QUERY_STRING', '')) + for key, value in pairs: + get[key] = value return get @DictProperty('environ', 'bottle.request.forms', read_only=True) def forms(self): """ Form values parsed from an `url-encoded` or `multipart/form-data` - encoded POST or PUT request body. The result is retuned as a + encoded POST or PUT request body. The result is returned as a :class:`FormsDict`. All keys and values are strings. File uploads are stored separately in :attr:`files`. """ forms = FormsDict() - for name, item in self.POST.iterallitems(): - if not hasattr(item, 'filename'): + for name, item in self.POST.allitems(): + if not isinstance(item, FileUpload): forms[name] = item return forms @@ -935,32 +1091,21 @@ class BaseRequest(DictMixin): """ A :class:`FormsDict` with the combined values of :attr:`query` and :attr:`forms`. File uploads are stored in :attr:`files`. """ params = FormsDict() - for key, value in self.query.iterallitems(): + for key, value in self.query.allitems(): params[key] = value - for key, value in self.forms.iterallitems(): + for key, value in self.forms.allitems(): params[key] = value return params @DictProperty('environ', 'bottle.request.files', read_only=True) def files(self): - """ File uploads parsed from an `url-encoded` or `multipart/form-data` - encoded POST or PUT request body. The values are instances of - :class:`cgi.FieldStorage`. The most important attributes are: - - filename - The filename, if specified; otherwise None; this is the client - side filename, *not* the file name on which it is stored (that's - a temporary file you don't deal with) - file - The file(-like) object from which you can read the data. - value - The value as a *string*; for file uploads, this transparently - reads the file every time you request the value. Do not do this - on big files. + """ File uploads parsed from `multipart/form-data` encoded POST or PUT + request body. The values are instances of :class:`FileUpload`. + """ files = FormsDict() - for name, item in self.POST.iterallitems(): - if hasattr(item, 'filename'): + for name, item in self.POST.allitems(): + if isinstance(item, FileUpload): files[name] = item return files @@ -970,25 +1115,75 @@ class BaseRequest(DictMixin): property holds the parsed content of the request body. Only requests smaller than :attr:`MEMFILE_MAX` are processed to avoid memory exhaustion. ''' - if 'application/json' in self.environ.get('CONTENT_TYPE', '') \ - and 0 < self.content_length < self.MEMFILE_MAX: - return json_loads(self.body.read(self.MEMFILE_MAX)) + ctype = self.environ.get('CONTENT_TYPE', '').lower().split(';')[0] + if ctype == 'application/json': + return json_loads(self._get_body_string()) return None - @DictProperty('environ', 'bottle.request.body', read_only=True) - def _body(self): + def _iter_body(self, read, bufsize): maxread = max(0, self.content_length) - stream = self.environ['wsgi.input'] - body = BytesIO() if maxread < self.MEMFILE_MAX else TemporaryFile(mode='w+b') - while maxread > 0: - part = stream.read(min(maxread, self.MEMFILE_MAX)) + while maxread: + part = read(min(maxread, bufsize)) if not part: break - body.write(part) + yield part maxread -= len(part) + + def _iter_chunked(self, read, bufsize): + err = HTTPError(400, 'Error while parsing chunked transfer body.') + rn, sem, bs = tob('\r\n'), tob(';'), tob('') + while True: + header = read(1) + while header[-2:] != rn: + c = read(1) + header += c + if not c: raise err + if len(header) > bufsize: raise err + size, _, _ = header.partition(sem) + try: + maxread = int(tonat(size.strip()), 16) + except ValueError: + raise err + if maxread == 0: break + buff = bs + while maxread > 0: + if not buff: + buff = read(min(maxread, bufsize)) + part, buff = buff[:maxread], buff[maxread:] + if not part: raise err + yield part + maxread -= len(part) + if read(2) != rn: + raise err + + @DictProperty('environ', 'bottle.request.body', read_only=True) + def _body(self): + body_iter = self._iter_chunked if self.chunked else self._iter_body + read_func = self.environ['wsgi.input'].read + body, body_size, is_temp_file = BytesIO(), 0, False + for part in body_iter(read_func, self.MEMFILE_MAX): + body.write(part) + body_size += len(part) + if not is_temp_file and body_size > self.MEMFILE_MAX: + body, tmp = TemporaryFile(mode='w+b'), body + body.write(tmp.getvalue()) + del tmp + is_temp_file = True self.environ['wsgi.input'] = body body.seek(0) return body + def _get_body_string(self): + ''' read body until content-length or MEMFILE_MAX into a string. Raise + HTTPError(413) on requests that are to large. ''' + clen = self.content_length + if clen > self.MEMFILE_MAX: + raise HTTPError(413, 'Request to large') + if clen < 0: clen = self.MEMFILE_MAX + 1 + data = self.body.read(clen) + if len(data) > self.MEMFILE_MAX: # Fail fast + raise HTTPError(413, 'Request to large') + return data + @property def body(self): """ The HTTP request body as a seek-able file-like object. Depending on @@ -999,6 +1194,11 @@ class BaseRequest(DictMixin): self._body.seek(0) return self._body + @property + def chunked(self): + ''' True if Chunked transfer encoding was. ''' + return 'chunked' in self.environ.get('HTTP_TRANSFER_ENCODING', '').lower() + #: An alias for :attr:`query`. GET = query @@ -1009,25 +1209,35 @@ class BaseRequest(DictMixin): instances of :class:`cgi.FieldStorage` (file uploads). """ post = FormsDict() + # We default to application/x-www-form-urlencoded for everything that + # is not multipart and take the fast path (also: 3.1 workaround) + if not self.content_type.startswith('multipart/'): + pairs = _parse_qsl(tonat(self._get_body_string(), 'latin1')) + for key, value in pairs: + post[key] = value + return post + safe_env = {'QUERY_STRING':''} # Build a safe environment for cgi for key in ('REQUEST_METHOD', 'CONTENT_TYPE', 'CONTENT_LENGTH'): if key in self.environ: safe_env[key] = self.environ[key] - if NCTextIOWrapper: - fb = NCTextIOWrapper(self.body, encoding='ISO-8859-1', newline='\n') - else: - fb = self.body - data = cgi.FieldStorage(fp=fb, environ=safe_env, keep_blank_values=True) - for item in data.list or []: - post[item.name] = item if item.filename else item.value + args = dict(fp=self.body, environ=safe_env, keep_blank_values=True) + if py31: + args['fp'] = NCTextIOWrapper(args['fp'], encoding='utf8', + newline='\n') + elif py3k: + args['encoding'] = 'utf8' + data = cgi.FieldStorage(**args) + self['_cgi.FieldStorage'] = data #http://bugs.python.org/issue18394#msg207958 + data = data.list or [] + for item in data: + if item.filename: + post[item.name] = FileUpload(item.file, item.name, + item.filename, item.headers) + else: + post[item.name] = item.value return post @property - def COOKIES(self): - ''' Alias for :attr:`cookies` (deprecated). ''' - depr('BaseRequest.COOKIES was renamed to BaseRequest.cookies (lowercase).') - return self.cookies - - @property def url(self): """ The full request URI including hostname and scheme. If your app lives behind a reverse proxy or load balancer and you get confusing @@ -1042,7 +1252,7 @@ class BaseRequest(DictMixin): but the fragment is always empty because it is not visible to the server. ''' env = self.environ - http = env.get('wsgi.url_scheme', 'http') + http = env.get('HTTP_X_FORWARDED_PROTO') or env.get('wsgi.url_scheme', 'http') host = env.get('HTTP_X_FORWARDED_HOST') or env.get('HTTP_HOST') if not host: # HTTP 1.1 requires a Host-header. This is for HTTP/1.0 clients. @@ -1091,6 +1301,11 @@ class BaseRequest(DictMixin): return int(self.environ.get('CONTENT_LENGTH') or -1) @property + def content_type(self): + ''' The Content-Type header as a lowercase-string (default: empty). ''' + return self.environ.get('CONTENT_TYPE', '').lower() + + @property def is_xhr(self): ''' True if the request was triggered by a XMLHttpRequest. This only works with JavaScript libraries that support the `X-Requested-With` @@ -1139,6 +1354,7 @@ class BaseRequest(DictMixin): """ Return a new :class:`Request` with a shallow :attr:`environ` copy. """ return Request(self.environ.copy()) + def get(self, value, default=None): return self.environ.get(value, default) def __getitem__(self, key): return self.environ[key] def __delitem__(self, key): self[key] = ""; del(self.environ[key]) def __iter__(self): return iter(self.environ) @@ -1166,27 +1382,41 @@ class BaseRequest(DictMixin): def __repr__(self): return '<%s: %s %s>' % (self.__class__.__name__, self.method, self.url) + def __getattr__(self, name): + ''' Search in self.environ for additional user defined attributes. ''' + try: + var = self.environ['bottle.request.ext.%s'%name] + return var.__get__(self) if hasattr(var, '__get__') else var + except KeyError: + raise AttributeError('Attribute %r not defined.' % name) + + def __setattr__(self, name, value): + if name == 'environ': return object.__setattr__(self, name, value) + self.environ['bottle.request.ext.%s'%name] = value + + + + def _hkey(s): return s.title().replace('_','-') class HeaderProperty(object): def __init__(self, name, reader=None, writer=str, default=''): - self.name, self.reader, self.writer, self.default = name, reader, writer, default + self.name, self.default = name, default + self.reader, self.writer = reader, writer self.__doc__ = 'Current value of the %r header.' % name.title() def __get__(self, obj, cls): if obj is None: return self - value = obj.headers.get(self.name) - return self.reader(value) if (value and self.reader) else (value or self.default) + value = obj.headers.get(self.name, self.default) + return self.reader(value) if self.reader else value def __set__(self, obj, value): - if self.writer: value = self.writer(value) - obj.headers[self.name] = value + obj.headers[self.name] = self.writer(value) def __delete__(self, obj): - if self.name in obj.headers: - del obj.headers[self.name] + del obj.headers[self.name] class BaseResponse(object): @@ -1195,6 +1425,14 @@ class BaseResponse(object): This class does support dict-like case-insensitive item-access to headers, but is NOT a dict. Most notably, iterating over a response yields parts of the body and not the headers. + + :param body: The response body as one of the supported types. + :param status: Either an HTTP status code (e.g. 200) or a status line + including the reason phrase (e.g. '200 OK'). + :param headers: A dictionary or a list of name-value pairs. + + Additional keyword arguments are added to the list of headers. + Underscores in the header name are replaced with dashes. """ default_status = 200 @@ -1208,22 +1446,30 @@ class BaseResponse(object): 'Content-Length', 'Content-Range', 'Content-Type', 'Content-Md5', 'Last-Modified'))} - def __init__(self, body='', status=None, **headers): - self._status_line = None - self._status_code = None - self.body = body + def __init__(self, body='', status=None, headers=None, **more_headers): self._cookies = None - self._headers = {'Content-Type': [self.default_content_type]} + self._headers = {} + self.body = body self.status = status or self.default_status if headers: - for name, value in headers.items(): - self[name] = value - - def copy(self): + if isinstance(headers, dict): + headers = headers.items() + for name, value in headers: + self.add_header(name, value) + if more_headers: + for name, value in more_headers.items(): + self.add_header(name, value) + + def copy(self, cls=None): ''' Returns a copy of self. ''' - copy = Response() + cls = cls or BaseResponse + assert issubclass(cls, BaseResponse) + copy = cls() copy.status = self.status copy._headers = dict((k, v[:]) for (k, v) in self._headers.items()) + if self._cookies: + copy._cookies = SimpleCookie() + copy._cookies.load(self._cookies.output()) return copy def __iter__(self): @@ -1253,26 +1499,24 @@ class BaseResponse(object): raise ValueError('String status line without a reason phrase.') if not 100 <= code <= 999: raise ValueError('Status code out of range.') self._status_code = code - self._status_line = status or ('%d Unknown' % code) + self._status_line = str(status or ('%d Unknown' % code)) def _get_status(self): - depr('BaseReuqest.status will change to return a string in 0.11. Use'\ - ' status_line and status_code to make sure.') #0.10 - return self._status_code + return self._status_line status = property(_get_status, _set_status, None, ''' A writeable property to change the HTTP response status. It accepts either a numeric code (100-999) or a string with a custom reason phrase (e.g. "404 Brain not found"). Both :data:`status_line` and - :data:`status_code` are updates accordingly. The return value is - always a numeric code. ''') + :data:`status_code` are updated accordingly. The return value is + always a status string. ''') del _get_status, _set_status @property def headers(self): ''' An instance of :class:`HeaderDict`, a case-insensitive dict-like view on the response headers. ''' - self.__dict__['headers'] = hdict = HeaderDict() + hdict = HeaderDict() hdict.dict = self._headers return hdict @@ -1286,13 +1530,10 @@ class BaseResponse(object): header with that name, return a default value. ''' return self._headers.get(_hkey(name), [default])[-1] - def set_header(self, name, value, append=False): + def set_header(self, name, value): ''' Create a new response header, replacing any previously defined headers with the same name. ''' - if append: - self.add_header(name, value) - else: - self._headers[_hkey(name)] = [str(value)] + self._headers[_hkey(name)] = [str(value)] def add_header(self, name, value): ''' Add an additional response header, not removing duplicates. ''' @@ -1301,44 +1542,36 @@ class BaseResponse(object): def iter_headers(self): ''' Yield (header, value) tuples, skipping headers that are not allowed with the current response status code. ''' - headers = self._headers.iteritems() - bad_headers = self.bad_headers.get(self.status_code) - if bad_headers: - headers = [h for h in headers if h[0] not in bad_headers] - for name, values in headers: - for value in values: - yield name, value - if self._cookies: - for c in self._cookies.values(): - yield 'Set-Cookie', c.OutputString() - - def wsgiheader(self): - depr('The wsgiheader method is deprecated. See headerlist.') #0.10 return self.headerlist @property def headerlist(self): ''' WSGI conform list of (header, value) tuples. ''' - return list(self.iter_headers()) + out = [] + headers = list(self._headers.items()) + if 'Content-Type' not in self._headers: + headers.append(('Content-Type', [self.default_content_type])) + if self._status_code in self.bad_headers: + bad_headers = self.bad_headers[self._status_code] + headers = [h for h in headers if h[0] not in bad_headers] + out += [(name, val) for name, vals in headers for val in vals] + if self._cookies: + for c in self._cookies.values(): + out.append(('Set-Cookie', c.OutputString())) + return out content_type = HeaderProperty('Content-Type') content_length = HeaderProperty('Content-Length', reader=int) + expires = HeaderProperty('Expires', + reader=lambda x: datetime.utcfromtimestamp(parse_date(x)), + writer=lambda x: http_date(x)) @property - def charset(self): + def charset(self, default='UTF-8'): """ Return the charset specified in the content-type header (default: utf8). """ if 'charset=' in self.content_type: return self.content_type.split('charset=')[-1].split(';')[0].strip() - return 'UTF-8' - - @property - def COOKIES(self): - """ A dict-like SimpleCookie instance. This should not be used directly. - See :meth:`set_cookie`. """ - depr('The COOKIES dict is deprecated. Use `set_cookie()` instead.') # 0.10 - if not self._cookies: - self._cookies = SimpleCookie() - return self._cookies + return default def set_cookie(self, name, value, secret=None, **options): ''' Create a new cookie or replace an old one. If the `secret` parameter is @@ -1384,7 +1617,7 @@ class BaseResponse(object): if len(value) > 4096: raise ValueError('Cookie value to long.') self._cookies[name] = value - for key, value in options.iteritems(): + for key, value in options.items(): if key == 'max_age': if isinstance(value, timedelta): value = value.seconds + value.days * 24 * 3600 @@ -1410,19 +1643,65 @@ class BaseResponse(object): return out -class LocalRequest(BaseRequest, threading.local): - ''' A thread-local subclass of :class:`BaseRequest`. ''' - def __init__(self): pass +def local_property(name=None): + if name: depr('local_property() is deprecated and will be removed.') #0.12 + ls = threading.local() + def fget(self): + try: return ls.var + except AttributeError: + raise RuntimeError("Request context not initialized.") + def fset(self, value): ls.var = value + def fdel(self): del ls.var + return property(fget, fset, fdel, 'Thread-local property') + + +class LocalRequest(BaseRequest): + ''' A thread-local subclass of :class:`BaseRequest` with a different + set of attributes for each thread. There is usually only one global + instance of this class (:data:`request`). If accessed during a + request/response cycle, this instance always refers to the *current* + request (even on a multithreaded server). ''' bind = BaseRequest.__init__ + environ = local_property() -class LocalResponse(BaseResponse, threading.local): - ''' A thread-local subclass of :class:`BaseResponse`. ''' +class LocalResponse(BaseResponse): + ''' A thread-local subclass of :class:`BaseResponse` with a different + set of attributes for each thread. There is usually only one global + instance of this class (:data:`response`). Its attributes are used + to build the HTTP response at the end of the request/response cycle. + ''' bind = BaseResponse.__init__ + _status_line = local_property() + _status_code = local_property() + _cookies = local_property() + _headers = local_property() + body = local_property() -Response = LocalResponse # BC 0.9 -Request = LocalRequest # BC 0.9 +Request = BaseRequest +Response = BaseResponse + + +class HTTPResponse(Response, BottleException): + def __init__(self, body='', status=None, headers=None, **more_headers): + super(HTTPResponse, self).__init__(body, status, headers, **more_headers) + + def apply(self, response): + response._status_code = self._status_code + response._status_line = self._status_line + response._headers = self._headers + response._cookies = self._cookies + response.body = self.body + + +class HTTPError(HTTPResponse): + default_status = 500 + def __init__(self, status=None, body=None, exception=None, traceback=None, + **options): + self.exception = exception + self.traceback = traceback + super(HTTPError, self).__init__(body, status, **options) @@ -1434,6 +1713,7 @@ Request = LocalRequest # BC 0.9 class PluginError(BottleException): pass + class JSONPlugin(object): name = 'json' api = 2 @@ -1441,63 +1721,26 @@ class JSONPlugin(object): def __init__(self, json_dumps=json_dumps): self.json_dumps = json_dumps - def apply(self, callback, context): + def apply(self, callback, route): dumps = self.json_dumps if not dumps: return callback def wrapper(*a, **ka): - rv = callback(*a, **ka) + try: + rv = callback(*a, **ka) + except HTTPError: + rv = _e() + if isinstance(rv, dict): #Attempt to serialize, raises exception on failure json_response = dumps(rv) #Set content type only if serialization succesful response.content_type = 'application/json' return json_response + elif isinstance(rv, HTTPResponse) and isinstance(rv.body, dict): + rv.body = dumps(rv.body) + rv.content_type = 'application/json' return rv - return wrapper - -class HooksPlugin(object): - name = 'hooks' - api = 2 - - _names = 'before_request', 'after_request', 'app_reset' - - def __init__(self): - self.hooks = dict((name, []) for name in self._names) - self.app = None - - def _empty(self): - return not (self.hooks['before_request'] or self.hooks['after_request']) - - def setup(self, app): - self.app = app - - def add(self, name, func): - ''' Attach a callback to a hook. ''' - was_empty = self._empty() - self.hooks.setdefault(name, []).append(func) - if self.app and was_empty and not self._empty(): self.app.reset() - - def remove(self, name, func): - ''' Remove a callback from a hook. ''' - was_empty = self._empty() - if name in self.hooks and func in self.hooks[name]: - self.hooks[name].remove(func) - if self.app and not was_empty and self._empty(): self.app.reset() - - def trigger(self, name, *a, **ka): - ''' Trigger a hook and return a list of results. ''' - hooks = self.hooks[name] - if ka.pop('reversed', False): hooks = hooks[::-1] - return [hook(*a, **ka) for hook in hooks] - - def apply(self, callback, context): - if self._empty(): return callback - def wrapper(*a, **ka): - self.trigger('before_request') - rv = callback(*a, **ka) - self.trigger('after_request', reversed=True) - return rv return wrapper @@ -1513,9 +1756,6 @@ class TemplatePlugin(object): conf = route.config.get('template') if isinstance(conf, (tuple, list)) and len(conf) == 2: return view(conf[0], **conf[1])(callback) - elif isinstance(conf, str) and 'template_opts' in route.config: - depr('The `template_opts` parameter is deprecated.') #0.9 - return view(conf, **route.config['template_opts'])(callback) elif isinstance(conf, str): return view(conf)(callback) else: @@ -1535,13 +1775,13 @@ class _ImportRedirect(object): def find_module(self, fullname, path=None): if '.' not in fullname: return - packname, modname = fullname.rsplit('.', 1) + packname = fullname.rsplit('.', 1)[0] if packname != self.name: return return self def load_module(self, fullname): if fullname in sys.modules: return sys.modules[fullname] - packname, modname = fullname.rsplit('.', 1) + modname = fullname.rsplit('.', 1)[1] realname = self.impmask % modname __import__(realname) module = sys.modules[fullname] = sys.modules[realname] @@ -1566,26 +1806,37 @@ class MultiDict(DictMixin): """ def __init__(self, *a, **k): - self.dict = dict((k, [v]) for k, v in dict(*a, **k).iteritems()) + self.dict = dict((k, [v]) for (k, v) in dict(*a, **k).items()) + def __len__(self): return len(self.dict) def __iter__(self): return iter(self.dict) def __contains__(self, key): return key in self.dict def __delitem__(self, key): del self.dict[key] def __getitem__(self, key): return self.dict[key][-1] def __setitem__(self, key, value): self.append(key, value) - def iterkeys(self): return self.dict.iterkeys() - def itervalues(self): return (v[-1] for v in self.dict.itervalues()) - def iteritems(self): return ((k, v[-1]) for (k, v) in self.dict.iteritems()) - def iterallitems(self): - for key, values in self.dict.iteritems(): - for value in values: - yield key, value - - # 2to3 is not able to fix these automatically. - keys = iterkeys if py3k else lambda self: list(self.iterkeys()) - values = itervalues if py3k else lambda self: list(self.itervalues()) - items = iteritems if py3k else lambda self: list(self.iteritems()) - allitems = iterallitems if py3k else lambda self: list(self.iterallitems()) + def keys(self): return self.dict.keys() + + if py3k: + def values(self): return (v[-1] for v in self.dict.values()) + def items(self): return ((k, v[-1]) for k, v in self.dict.items()) + def allitems(self): + return ((k, v) for k, vl in self.dict.items() for v in vl) + iterkeys = keys + itervalues = values + iteritems = items + iterallitems = allitems + + else: + def values(self): return [v[-1] for v in self.dict.values()] + def items(self): return [(k, v[-1]) for k, v in self.dict.items()] + def iterkeys(self): return self.dict.iterkeys() + def itervalues(self): return (v[-1] for v in self.dict.itervalues()) + def iteritems(self): + return ((k, v[-1]) for k, v in self.dict.iteritems()) + def iterallitems(self): + return ((k, v) for k, vl in self.dict.iteritems() for v in vl) + def allitems(self): + return [(k, v) for k, vl in self.dict.iteritems() for v in vl] def get(self, key, default=None, index=-1, type=None): ''' Return the most recent value for a key. @@ -1600,7 +1851,7 @@ class MultiDict(DictMixin): try: val = self.dict[key][index] return type(val) if type else val - except Exception, e: + except Exception: pass return default @@ -1621,30 +1872,51 @@ class MultiDict(DictMixin): getlist = getall - class FormsDict(MultiDict): ''' This :class:`MultiDict` subclass is used to store request form data. Additionally to the normal dict-like item access methods (which return unmodified data as native strings), this container also supports - attribute-like access to its values. Attribues are automatiically de- or - recoded to match :attr:`input_encoding` (default: 'utf8'). Missing + attribute-like access to its values. Attributes are automatically de- + or recoded to match :attr:`input_encoding` (default: 'utf8'). Missing attributes default to an empty string. ''' #: Encoding used for attribute values. input_encoding = 'utf8' + #: If true (default), unicode strings are first encoded with `latin1` + #: and then decoded to match :attr:`input_encoding`. + recode_unicode = True + + def _fix(self, s, encoding=None): + if isinstance(s, unicode) and self.recode_unicode: # Python 3 WSGI + return s.encode('latin1').decode(encoding or self.input_encoding) + elif isinstance(s, bytes): # Python 2 WSGI + return s.decode(encoding or self.input_encoding) + else: + return s + + def decode(self, encoding=None): + ''' Returns a copy with all keys and values de- or recoded to match + :attr:`input_encoding`. Some libraries (e.g. WTForms) want a + unicode dictionary. ''' + copy = FormsDict() + enc = copy.input_encoding = encoding or self.input_encoding + copy.recode_unicode = False + for key, value in self.allitems(): + copy.append(self._fix(key, enc), self._fix(value, enc)) + return copy def getunicode(self, name, default=None, encoding=None): - value, enc = self.get(name, default), encoding or self.input_encoding + ''' Return the value as a unicode string, or the default. ''' try: - if isinstance(value, bytes): # Python 2 WSGI - return value.decode(enc) - elif isinstance(value, unicode): # Python 3 WSGI - return value.encode('latin1').decode(enc) - return value - except UnicodeError, e: + return self._fix(self[name], encoding) + except (UnicodeError, KeyError): return default - def __getattr__(self, name): return self.getunicode(name, default=u'') + def __getattr__(self, name, default=unicode()): + # Without this guard, pickle generates a cryptic TypeError: + if name.startswith('__') and name.endswith('__'): + return super(FormsDict, self).__getattr__(name) + return self.getunicode(name, default=default) class HeaderDict(MultiDict): @@ -1666,7 +1938,7 @@ class HeaderDict(MultiDict): def get(self, key, default=None, index=-1): return MultiDict.get(self, _hkey(key), default, index) def filter(self, names): - for name in map(_hkey, names): + for name in [_hkey(n) for n in names]: if name in self.dict: del self.dict[name] @@ -1682,7 +1954,7 @@ class WSGIHeaderDict(DictMixin): Currently PEP 333, 444 and 3333 are supported. (PEP 444 is the only one that uses non-native strings.) ''' - #: List of keys that do not have a 'HTTP_' prefix. + #: List of keys that do not have a ``HTTP_`` prefix. cgikeys = ('CONTENT_TYPE', 'CONTENT_LENGTH') def __init__(self, environ): @@ -1720,39 +1992,212 @@ class WSGIHeaderDict(DictMixin): def __contains__(self, key): return self._ekey(key) in self.environ + class ConfigDict(dict): - ''' A dict-subclass with some extras: You can access keys like attributes. - Uppercase attributes create new ConfigDicts and act as name-spaces. - Other missing attributes return None. Calling a ConfigDict updates its - values and returns itself. - - >>> cfg = ConfigDict() - >>> cfg.Namespace.value = 5 - >>> cfg.OtherNamespace(a=1, b=2) - >>> cfg - {'Namespace': {'value': 5}, 'OtherNamespace': {'a': 1, 'b': 2}} + ''' A dict-like configuration storage with additional support for + namespaces, validators, meta-data, on_change listeners and more. + + This storage is optimized for fast read access. Retrieving a key + or using non-altering dict methods (e.g. `dict.get()`) has no overhead + compared to a native dict. ''' + __slots__ = ('_meta', '_on_change') + + class Namespace(DictMixin): + + def __init__(self, config, namespace): + self._config = config + self._prefix = namespace + + def __getitem__(self, key): + depr('Accessing namespaces as dicts is discouraged. ' + 'Only use flat item access: ' + 'cfg["names"]["pace"]["key"] -> cfg["name.space.key"]') #0.12 + return self._config[self._prefix + '.' + key] + + def __setitem__(self, key, value): + self._config[self._prefix + '.' + key] = value + + def __delitem__(self, key): + del self._config[self._prefix + '.' + key] + + def __iter__(self): + ns_prefix = self._prefix + '.' + for key in self._config: + ns, dot, name = key.rpartition('.') + if ns == self._prefix and name: + yield name + + def keys(self): return [x for x in self] + def __len__(self): return len(self.keys()) + def __contains__(self, key): return self._prefix + '.' + key in self._config + def __repr__(self): return '<Config.Namespace %s.*>' % self._prefix + def __str__(self): return '<Config.Namespace %s.*>' % self._prefix + + # Deprecated ConfigDict features + def __getattr__(self, key): + depr('Attribute access is deprecated.') #0.12 + if key not in self and key[0].isupper(): + self[key] = ConfigDict.Namespace(self._config, self._prefix + '.' + key) + if key not in self and key.startswith('__'): + raise AttributeError(key) + return self.get(key) + + def __setattr__(self, key, value): + if key in ('_config', '_prefix'): + self.__dict__[key] = value + return + depr('Attribute assignment is deprecated.') #0.12 + if hasattr(DictMixin, key): + raise AttributeError('Read-only attribute.') + if key in self and self[key] and isinstance(self[key], self.__class__): + raise AttributeError('Non-empty namespace attribute.') + self[key] = value + + def __delattr__(self, key): + if key in self: + val = self.pop(key) + if isinstance(val, self.__class__): + prefix = key + '.' + for key in self: + if key.startswith(prefix): + del self[prefix+key] + + def __call__(self, *a, **ka): + depr('Calling ConfDict is deprecated. Use the update() method.') #0.12 + self.update(*a, **ka) + return self + + def __init__(self, *a, **ka): + self._meta = {} + self._on_change = lambda name, value: None + if a or ka: + depr('Constructor does no longer accept parameters.') #0.12 + self.update(*a, **ka) + + def load_config(self, filename): + ''' Load values from an *.ini style config file. + + If the config file contains sections, their names are used as + namespaces for the values within. The two special sections + ``DEFAULT`` and ``bottle`` refer to the root namespace (no prefix). + ''' + conf = ConfigParser() + conf.read(filename) + for section in conf.sections(): + for key, value in conf.items(section): + if section not in ('DEFAULT', 'bottle'): + key = section + '.' + key + self[key] = value + return self + + def load_dict(self, source, namespace='', make_namespaces=False): + ''' Import values from a dictionary structure. Nesting can be used to + represent namespaces. + + >>> ConfigDict().load_dict({'name': {'space': {'key': 'value'}}}) + {'name.space.key': 'value'} + ''' + stack = [(namespace, source)] + while stack: + prefix, source = stack.pop() + if not isinstance(source, dict): + raise TypeError('Source is not a dict (r)' % type(key)) + for key, value in source.items(): + if not isinstance(key, str): + raise TypeError('Key is not a string (%r)' % type(key)) + full_key = prefix + '.' + key if prefix else key + if isinstance(value, dict): + stack.append((full_key, value)) + if make_namespaces: + self[full_key] = self.Namespace(self, full_key) + else: + self[full_key] = value + return self + def update(self, *a, **ka): + ''' If the first parameter is a string, all keys are prefixed with this + namespace. Apart from that it works just as the usual dict.update(). + Example: ``update('some.namespace', key='value')`` ''' + prefix = '' + if a and isinstance(a[0], str): + prefix = a[0].strip('.') + '.' + a = a[1:] + for key, value in dict(*a, **ka).items(): + self[prefix+key] = value + + def setdefault(self, key, value): + if key not in self: + self[key] = value + return self[key] + + def __setitem__(self, key, value): + if not isinstance(key, str): + raise TypeError('Key has type %r (not a string)' % type(key)) + + value = self.meta_get(key, 'filter', lambda x: x)(value) + if key in self and self[key] is value: + return + self._on_change(key, value) + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + dict.__delitem__(self, key) + + def clear(self): + for key in self: + del self[key] + + def meta_get(self, key, metafield, default=None): + ''' Return the value of a meta field for a key. ''' + return self._meta.get(key, {}).get(metafield, default) + + def meta_set(self, key, metafield, value): + ''' Set the meta field for a key to a new value. This triggers the + on-change handler for existing keys. ''' + self._meta.setdefault(key, {})[metafield] = value + if key in self: + self[key] = self[key] + + def meta_list(self, key): + ''' Return an iterable of meta field names defined for a key. ''' + return self._meta.get(key, {}).keys() + + # Deprecated ConfigDict features def __getattr__(self, key): + depr('Attribute access is deprecated.') #0.12 if key not in self and key[0].isupper(): - self[key] = ConfigDict() + self[key] = self.Namespace(self, key) + if key not in self and key.startswith('__'): + raise AttributeError(key) return self.get(key) def __setattr__(self, key, value): + if key in self.__slots__: + return dict.__setattr__(self, key, value) + depr('Attribute assignment is deprecated.') #0.12 if hasattr(dict, key): raise AttributeError('Read-only attribute.') - if key in self and self[key] and isinstance(self[key], ConfigDict): + if key in self and self[key] and isinstance(self[key], self.Namespace): raise AttributeError('Non-empty namespace attribute.') self[key] = value def __delattr__(self, key): - if key in self: del self[key] + if key in self: + val = self.pop(key) + if isinstance(val, self.Namespace): + prefix = key + '.' + for key in self: + if key.startswith(prefix): + del self[prefix+key] def __call__(self, *a, **ka): - for key, value in dict(*a, **ka).iteritems(): setattr(self, key, value) + depr('Calling ConfDict is deprecated. Use the update() method.') #0.12 + self.update(*a, **ka) return self + class AppStack(list): """ A stack-like list. Calling it returns the head of the stack. """ @@ -1770,17 +2215,182 @@ class AppStack(list): class WSGIFileWrapper(object): - def __init__(self, fp, buffer_size=1024*64): - self.fp, self.buffer_size = fp, buffer_size - for attr in ('fileno', 'close', 'read', 'readlines'): - if hasattr(fp, attr): setattr(self, attr, getattr(fp, attr)) + def __init__(self, fp, buffer_size=1024*64): + self.fp, self.buffer_size = fp, buffer_size + for attr in ('fileno', 'close', 'read', 'readlines', 'tell', 'seek'): + if hasattr(fp, attr): setattr(self, attr, getattr(fp, attr)) + + def __iter__(self): + buff, read = self.buffer_size, self.read + while True: + part = read(buff) + if not part: return + yield part + + +class _closeiter(object): + ''' This only exists to be able to attach a .close method to iterators that + do not support attribute assignment (most of itertools). ''' + + def __init__(self, iterator, close=None): + self.iterator = iterator + self.close_callbacks = makelist(close) + + def __iter__(self): + return iter(self.iterator) + + def close(self): + for func in self.close_callbacks: + func() + + +class ResourceManager(object): + ''' This class manages a list of search paths and helps to find and open + application-bound resources (files). + + :param base: default value for :meth:`add_path` calls. + :param opener: callable used to open resources. + :param cachemode: controls which lookups are cached. One of 'all', + 'found' or 'none'. + ''' + + def __init__(self, base='./', opener=open, cachemode='all'): + self.opener = open + self.base = base + self.cachemode = cachemode + + #: A list of search paths. See :meth:`add_path` for details. + self.path = [] + #: A cache for resolved paths. ``res.cache.clear()`` clears the cache. + self.cache = {} + + def add_path(self, path, base=None, index=None, create=False): + ''' Add a new path to the list of search paths. Return False if the + path does not exist. + + :param path: The new search path. Relative paths are turned into + an absolute and normalized form. If the path looks like a file + (not ending in `/`), the filename is stripped off. + :param base: Path used to absolutize relative search paths. + Defaults to :attr:`base` which defaults to ``os.getcwd()``. + :param index: Position within the list of search paths. Defaults + to last index (appends to the list). + + The `base` parameter makes it easy to reference files installed + along with a python module or package:: + + res.add_path('./resources/', __file__) + ''' + base = os.path.abspath(os.path.dirname(base or self.base)) + path = os.path.abspath(os.path.join(base, os.path.dirname(path))) + path += os.sep + if path in self.path: + self.path.remove(path) + if create and not os.path.isdir(path): + os.makedirs(path) + if index is None: + self.path.append(path) + else: + self.path.insert(index, path) + self.cache.clear() + return os.path.exists(path) + + def __iter__(self): + ''' Iterate over all existing files in all registered paths. ''' + search = self.path[:] + while search: + path = search.pop() + if not os.path.isdir(path): continue + for name in os.listdir(path): + full = os.path.join(path, name) + if os.path.isdir(full): search.append(full) + else: yield full + + def lookup(self, name): + ''' Search for a resource and return an absolute file path, or `None`. + + The :attr:`path` list is searched in order. The first match is + returend. Symlinks are followed. The result is cached to speed up + future lookups. ''' + if name not in self.cache or DEBUG: + for path in self.path: + fpath = os.path.join(path, name) + if os.path.isfile(fpath): + if self.cachemode in ('all', 'found'): + self.cache[name] = fpath + return fpath + if self.cachemode == 'all': + self.cache[name] = None + return self.cache[name] + + def open(self, name, mode='r', *args, **kwargs): + ''' Find a resource and return a file object, or raise IOError. ''' + fname = self.lookup(name) + if not fname: raise IOError("Resource %r not found." % name) + return self.opener(fname, mode=mode, *args, **kwargs) + + +class FileUpload(object): + + def __init__(self, fileobj, name, filename, headers=None): + ''' Wrapper for file uploads. ''' + #: Open file(-like) object (BytesIO buffer or temporary file) + self.file = fileobj + #: Name of the upload form field + self.name = name + #: Raw filename as sent by the client (may contain unsafe characters) + self.raw_filename = filename + #: A :class:`HeaderDict` with additional headers (e.g. content-type) + self.headers = HeaderDict(headers) if headers else HeaderDict() - def __iter__(self): - read, buff = self.fp.read, self.buffer_size - while True: - part = read(buff) - if not part: break - yield part + content_type = HeaderProperty('Content-Type') + content_length = HeaderProperty('Content-Length', reader=int, default=-1) + + @cached_property + def filename(self): + ''' Name of the file on the client file system, but normalized to ensure + file system compatibility. An empty filename is returned as 'empty'. + + Only ASCII letters, digits, dashes, underscores and dots are + allowed in the final filename. Accents are removed, if possible. + Whitespace is replaced by a single dash. Leading or tailing dots + or dashes are removed. The filename is limited to 255 characters. + ''' + fname = self.raw_filename + if not isinstance(fname, unicode): + fname = fname.decode('utf8', 'ignore') + fname = normalize('NFKD', fname).encode('ASCII', 'ignore').decode('ASCII') + fname = os.path.basename(fname.replace('\\', os.path.sep)) + fname = re.sub(r'[^a-zA-Z0-9-_.\s]', '', fname).strip() + fname = re.sub(r'[-\s]+', '-', fname).strip('.-') + return fname[:255] or 'empty' + + def _copy_file(self, fp, chunk_size=2**16): + read, write, offset = self.file.read, fp.write, self.file.tell() + while 1: + buf = read(chunk_size) + if not buf: break + write(buf) + self.file.seek(offset) + + def save(self, destination, overwrite=False, chunk_size=2**16): + ''' Save file to disk or copy its content to an open file(-like) object. + If *destination* is a directory, :attr:`filename` is added to the + path. Existing files are not overwritten by default (IOError). + + :param destination: File path, directory or file(-like) object. + :param overwrite: If True, replace existing files. (default: False) + :param chunk_size: Bytes to read at a time. (default: 64kb) + ''' + if isinstance(destination, basestring): # Except file-likes here + if os.path.isdir(destination): + destination = os.path.join(destination, self.filename) + if not overwrite and os.path.exists(destination): + raise IOError('File exists.') + with open(destination, 'wb') as fp: + self._copy_file(fp, chunk_size) + else: + self._copy_file(destination, chunk_size) @@ -1792,7 +2402,7 @@ class WSGIFileWrapper(object): ############################################################################### -def abort(code=500, text='Unknown Error: Application stopped.'): +def abort(code=500, text='Unknown Error.'): """ Aborts execution and causes a HTTP error. """ raise HTTPError(code, text) @@ -1800,21 +2410,48 @@ def abort(code=500, text='Unknown Error: Application stopped.'): def redirect(url, code=None): """ Aborts execution and causes a 303 or 302 redirect, depending on the HTTP protocol version. """ - if code is None: + if not code: code = 303 if request.get('SERVER_PROTOCOL') == "HTTP/1.1" else 302 - location = urljoin(request.url, url) - raise HTTPResponse("", status=code, header=dict(Location=location)) + res = response.copy(cls=HTTPResponse) + res.status = code + res.body = "" + res.set_header('Location', urljoin(request.url, url)) + raise res -def static_file(filename, root, mimetype='auto', download=False): +def _file_iter_range(fp, offset, bytes, maxread=1024*1024): + ''' Yield chunks from a range in a file. No chunk is bigger than maxread.''' + fp.seek(offset) + while bytes > 0: + part = fp.read(min(bytes, maxread)) + if not part: break + bytes -= len(part) + yield part + + +def static_file(filename, root, mimetype='auto', download=False, charset='UTF-8'): """ Open a file in a safe way and return :exc:`HTTPResponse` with status - code 200, 305, 401 or 404. Set Content-Type, Content-Encoding, - Content-Length and Last-Modified header. Obey If-Modified-Since header - and HEAD requests. + code 200, 305, 403 or 404. The ``Content-Type``, ``Content-Encoding``, + ``Content-Length`` and ``Last-Modified`` headers are set if possible. + Special support for ``If-Modified-Since``, ``Range`` and ``HEAD`` + requests. + + :param filename: Name or path of the file to send. + :param root: Root path for file lookups. Should be an absolute directory + path. + :param mimetype: Defines the content-type header (default: guess from + file extension) + :param download: If True, ask the browser to open a `Save as...` dialog + instead of opening the file with the associated program. You can + specify a custom filename as a string. If not specified, the + original filename is used (default: False). + :param charset: The charset to use for files with a ``text/*`` + mime-type. (default: UTF-8) """ + root = os.path.abspath(root) + os.sep filename = os.path.abspath(os.path.join(root, filename.strip('/\\'))) - header = dict() + headers = dict() if not filename.startswith(root): return HTTPError(403, "Access denied.") @@ -1825,29 +2462,43 @@ def static_file(filename, root, mimetype='auto', download=False): if mimetype == 'auto': mimetype, encoding = mimetypes.guess_type(filename) - if mimetype: header['Content-Type'] = mimetype - if encoding: header['Content-Encoding'] = encoding - elif mimetype: - header['Content-Type'] = mimetype + if encoding: headers['Content-Encoding'] = encoding + + if mimetype: + if mimetype[:5] == 'text/' and charset and 'charset' not in mimetype: + mimetype += '; charset=%s' % charset + headers['Content-Type'] = mimetype if download: download = os.path.basename(filename if download == True else download) - header['Content-Disposition'] = 'attachment; filename="%s"' % download + headers['Content-Disposition'] = 'attachment; filename="%s"' % download stats = os.stat(filename) - header['Content-Length'] = stats.st_size + headers['Content-Length'] = clen = stats.st_size lm = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(stats.st_mtime)) - header['Last-Modified'] = lm + headers['Last-Modified'] = lm ims = request.environ.get('HTTP_IF_MODIFIED_SINCE') if ims: ims = parse_date(ims.split(";")[0].strip()) if ims is not None and ims >= int(stats.st_mtime): - header['Date'] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) - return HTTPResponse(status=304, header=header) + headers['Date'] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) + return HTTPResponse(status=304, **headers) body = '' if request.method == 'HEAD' else open(filename, 'rb') - return HTTPResponse(body, header=header) + + headers["Accept-Ranges"] = "bytes" + ranges = request.environ.get('HTTP_RANGE') + if 'HTTP_RANGE' in request.environ: + ranges = list(parse_range_header(request.environ['HTTP_RANGE'], clen)) + if not ranges: + return HTTPError(416, "Requested Range Not Satisfiable") + offset, end = ranges[0] + headers["Content-Range"] = "bytes %d-%d/%d" % (offset, end-1, clen) + headers["Content-Length"] = str(end-offset) + if body: body = _file_iter_range(body, offset, end-offset) + return HTTPResponse(body, status=206, **headers) + return HTTPResponse(body, **headers) @@ -1863,8 +2514,17 @@ def debug(mode=True): """ Change the debug level. There is only one debug level supported at the moment.""" global DEBUG + if mode: warnings.simplefilter('default') DEBUG = bool(mode) +def http_date(value): + if isinstance(value, (datedate, datetime)): + value = value.utctimetuple() + elif isinstance(value, (int, float)): + value = time.gmtime(value) + if not isinstance(value, basestring): + value = time.strftime("%a, %d %b %Y %H:%M:%S GMT", value) + return value def parse_date(ims): """ Parse rfc1123, rfc850 and asctime timestamps and return UTC epoch. """ @@ -1874,21 +2534,47 @@ def parse_date(ims): except (TypeError, ValueError, IndexError, OverflowError): return None - def parse_auth(header): """ Parse rfc2617 HTTP authentication header string (basic) and return (user,pass) tuple or None""" try: method, data = header.split(None, 1) if method.lower() == 'basic': - #TODO: Add 2to3 save base64[encode/decode] functions. user, pwd = touni(base64.b64decode(tob(data))).split(':',1) return user, pwd except (KeyError, ValueError): return None +def parse_range_header(header, maxlen=0): + ''' Yield (start, end) ranges parsed from a HTTP Range header. Skip + unsatisfiable ranges. The end index is non-inclusive.''' + if not header or header[:6] != 'bytes=': return + ranges = [r.split('-', 1) for r in header[6:].split(',') if '-' in r] + for start, end in ranges: + try: + if not start: # bytes=-100 -> last 100 bytes + start, end = max(0, maxlen-int(end)), maxlen + elif not end: # bytes=100- -> all but the first 99 bytes + start, end = int(start), maxlen + else: # bytes=100-200 -> bytes 100-200 (inclusive) + start, end = int(start), min(int(end)+1, maxlen) + if 0 <= start < end <= maxlen: + yield start, end + except ValueError: + pass + +def _parse_qsl(qs): + r = [] + for pair in qs.replace(';','&').split('&'): + if not pair: continue + nv = pair.split('=', 1) + if len(nv) != 2: nv.append('') + key = urlunquote(nv[0].replace('+', ' ')) + value = urlunquote(nv[1].replace('+', ' ')) + r.append((key, value)) + return r def _lscmp(a, b): - ''' Compares two strings in a cryptographically save way: + ''' Compares two strings in a cryptographically safe way: Runtime is not affected by length of common prefix. ''' return not sum(0 if x==y else 1 for x, y in zip(a, b)) and len(a) == len(b) @@ -1923,7 +2609,7 @@ def html_escape(string): def html_quote(string): ''' Escape and quote a string to be used as an HTTP attribute.''' - return '"%s"' % html_escape(string).replace('\n','%#10;')\ + return '"%s"' % html_escape(string).replace('\n',' ')\ .replace('\r',' ').replace('\t','	') @@ -1933,18 +2619,17 @@ def yieldroutes(func): takes optional keyword arguments. The output is best described by example:: a() -> '/a' - b(x, y) -> '/b/:x/:y' - c(x, y=5) -> '/c/:x' and '/c/:x/:y' - d(x=5, y=6) -> '/d' and '/d/:x' and '/d/:x/:y' + b(x, y) -> '/b/<x>/<y>' + c(x, y=5) -> '/c/<x>' and '/c/<x>/<y>' + d(x=5, y=6) -> '/d' and '/d/<x>' and '/d/<x>/<y>' """ - import inspect # Expensive module. Only import if necessary. path = '/' + func.__name__.replace('__','/').lstrip('/') - spec = inspect.getargspec(func) + spec = getargspec(func) argc = len(spec[0]) - len(spec[3] or []) - path += ('/:%s' * argc) % tuple(spec[0][:argc]) + path += ('/<%s>' * argc) % tuple(spec[0][:argc]) yield path for arg in spec[0][argc:]: - path += '/:%s' % arg + path += '/<%s>' % arg yield path @@ -1979,41 +2664,24 @@ def path_shift(script_name, path_info, shift=1): return new_script_name, new_path_info -def validate(**vkargs): - """ - Validates and manipulates keyword arguments by user defined callables. - Handles ValueError and missing arguments by raising HTTPError(403). - """ - depr('Use route wildcard filters instead.') - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kargs): - for key, value in vkargs.iteritems(): - if key not in kargs: - abort(403, 'Missing parameter: %s' % key) - try: - kargs[key] = value(kargs[key]) - except ValueError: - abort(403, 'Wrong parameter format for: %s' % key) - return func(*args, **kargs) - return wrapper - return decorator - - def auth_basic(check, realm="private", text="Access denied"): ''' Callback decorator to require HTTP auth (basic). TODO: Add route(check_auth=...) parameter. ''' def decorator(func): - def wrapper(*a, **ka): - user, password = request.auth or (None, None) - if user is None or not check(user, password): - response.headers['WWW-Authenticate'] = 'Basic realm="%s"' % realm - return HTTPError(401, text) - return func(*a, **ka) - return wrapper + def wrapper(*a, **ka): + user, password = request.auth or (None, None) + if user is None or not check(user, password): + err = HTTPError(401, text) + err.add_header('WWW-Authenticate', 'Basic realm="%s"' % realm) + return err + return func(*a, **ka) + return wrapper return decorator +# Shortcuts for common Bottle methods. +# They all refer to the current default application. + def make_default_app_wrapper(name): ''' Return a callable that relays calls to the current default app. ''' @functools.wraps(getattr(Bottle, name)) @@ -2021,12 +2689,18 @@ def make_default_app_wrapper(name): return getattr(app(), name)(*a, **ka) return wrapper +route = make_default_app_wrapper('route') +get = make_default_app_wrapper('get') +post = make_default_app_wrapper('post') +put = make_default_app_wrapper('put') +delete = make_default_app_wrapper('delete') +error = make_default_app_wrapper('error') +mount = make_default_app_wrapper('mount') +hook = make_default_app_wrapper('hook') +install = make_default_app_wrapper('install') +uninstall = make_default_app_wrapper('uninstall') +url = make_default_app_wrapper('get_url') -for name in '''route get post put delete error mount - hook install uninstall'''.split(): - globals()[name] = make_default_app_wrapper(name) -url = make_default_app_wrapper('get_url') -del name @@ -2040,8 +2714,8 @@ del name class ServerAdapter(object): quiet = False - def __init__(self, host='127.0.0.1', port=8080, **config): - self.options = config + def __init__(self, host='127.0.0.1', port=8080, **options): + self.options = options self.host = host self.port = int(port) @@ -2071,32 +2745,66 @@ class FlupFCGIServer(ServerAdapter): class WSGIRefServer(ServerAdapter): - def run(self, handler): # pragma: no cover - from wsgiref.simple_server import make_server, WSGIRequestHandler - if self.quiet: - class QuietHandler(WSGIRequestHandler): - def log_request(*args, **kw): pass - self.options['handler_class'] = QuietHandler - srv = make_server(self.host, self.port, handler, **self.options) + def run(self, app): # pragma: no cover + from wsgiref.simple_server import WSGIRequestHandler, WSGIServer + from wsgiref.simple_server import make_server + import socket + + class FixedHandler(WSGIRequestHandler): + def address_string(self): # Prevent reverse DNS lookups please. + return self.client_address[0] + def log_request(*args, **kw): + if not self.quiet: + return WSGIRequestHandler.log_request(*args, **kw) + + handler_cls = self.options.get('handler_class', FixedHandler) + server_cls = self.options.get('server_class', WSGIServer) + + if ':' in self.host: # Fix wsgiref for IPv6 addresses. + if getattr(server_cls, 'address_family') == socket.AF_INET: + class server_cls(server_cls): + address_family = socket.AF_INET6 + + srv = make_server(self.host, self.port, app, server_cls, handler_cls) srv.serve_forever() class CherryPyServer(ServerAdapter): def run(self, handler): # pragma: no cover from cherrypy import wsgiserver - server = wsgiserver.CherryPyWSGIServer((self.host, self.port), handler) + self.options['bind_addr'] = (self.host, self.port) + self.options['wsgi_app'] = handler + + certfile = self.options.get('certfile') + if certfile: + del self.options['certfile'] + keyfile = self.options.get('keyfile') + if keyfile: + del self.options['keyfile'] + + server = wsgiserver.CherryPyWSGIServer(**self.options) + if certfile: + server.ssl_certificate = certfile + if keyfile: + server.ssl_private_key = keyfile + try: server.start() finally: server.stop() +class WaitressServer(ServerAdapter): + def run(self, handler): + from waitress import serve + serve(handler, host=self.host, port=self.port) + + class PasteServer(ServerAdapter): def run(self, handler): # pragma: no cover from paste import httpserver - if not self.quiet: - from paste.translogger import TransLogger - handler = TransLogger(handler) + from paste.translogger import TransLogger + handler = TransLogger(handler, setup_console_handler=(not self.quiet)) httpserver.serve(handler, host=self.host, port=str(self.port), **self.options) @@ -2120,8 +2828,8 @@ class FapwsServer(ServerAdapter): evwsgi.start(self.host, port) # fapws3 never releases the GIL. Complain upstream. I tried. No luck. if 'BOTTLE_CHILD' in os.environ and not self.quiet: - print "WARNING: Auto-reloading does not work with Fapws3." - print " (Fapws3 breaks python thread support)" + _stderr("WARNING: Auto-reloading does not work with Fapws3.\n") + _stderr(" (Fapws3 breaks python thread support)\n") evwsgi.set_base_module(base) def app(environ, start_response): environ['wsgi.multiprocess'] = False @@ -2136,7 +2844,7 @@ class TornadoServer(ServerAdapter): import tornado.wsgi, tornado.httpserver, tornado.ioloop container = tornado.wsgi.WSGIContainer(handler) server = tornado.httpserver.HTTPServer(container) - server.listen(port=self.port) + server.listen(port=self.port,address=self.host) tornado.ioloop.IOLoop.instance().start() @@ -2178,16 +2886,30 @@ class DieselServer(ServerAdapter): class GeventServer(ServerAdapter): """ Untested. Options: - * `monkey` (default: True) fixes the stdlib to use greenthreads. * `fast` (default: False) uses libevent's http server, but has some issues: No streaming, no pipelining, no SSL. + * See gevent.wsgi.WSGIServer() documentation for more options. """ def run(self, handler): - from gevent import wsgi as wsgi_fast, pywsgi, monkey, local - if self.options.get('monkey', True): - if not threading.local is local.local: monkey.patch_all() - wsgi = wsgi_fast if self.options.get('fast') else pywsgi - wsgi.WSGIServer((self.host, self.port), handler).serve_forever() + from gevent import wsgi, pywsgi, local + if not isinstance(threading.local(), local.local): + msg = "Bottle requires gevent.monkey.patch_all() (before import)" + raise RuntimeError(msg) + if not self.options.pop('fast', None): wsgi = pywsgi + self.options['log'] = None if self.quiet else 'default' + address = (self.host, self.port) + server = wsgi.WSGIServer(address, handler, **self.options) + if 'BOTTLE_CHILD' in os.environ: + import signal + signal.signal(signal.SIGINT, lambda s, f: server.stop()) + server.serve_forever() + + +class GeventSocketIOServer(ServerAdapter): + def run(self,handler): + from socketio import server + address = (self.host, self.port) + server.SocketIOServer(address, handler, **self.options).serve_forever() class GunicornServer(ServerAdapter): @@ -2212,7 +2934,12 @@ class EventletServer(ServerAdapter): """ Untested """ def run(self, handler): from eventlet import wsgi, listen - wsgi.server(listen((self.host, self.port)), handler) + try: + wsgi.server(listen((self.host, self.port)), handler, + log_output=(not self.quiet)) + except TypeError: + # Fallback, if we have old version of eventlet + wsgi.server(listen((self.host, self.port)), handler) class RocketServer(ServerAdapter): @@ -2232,7 +2959,7 @@ class BjoernServer(ServerAdapter): class AutoServer(ServerAdapter): """ Untested. """ - adapters = [PasteServer, CherryPyServer, TwistedServer, WSGIRefServer] + adapters = [WaitressServer, PasteServer, TwistedServer, CherryPyServer, WSGIRefServer] def run(self, handler): for sa in self.adapters: try: @@ -2244,6 +2971,7 @@ server_names = { 'cgi': CGIServer, 'flup': FlupFCGIServer, 'wsgiref': WSGIRefServer, + 'waitress': WaitressServer, 'cherrypy': CherryPyServer, 'paste': PasteServer, 'fapws3': FapwsServer, @@ -2255,6 +2983,7 @@ server_names = { 'gunicorn': GunicornServer, 'eventlet': EventletServer, 'gevent': GeventServer, + 'geventSocketIO':GeventSocketIOServer, 'rocket': RocketServer, 'bjoern' : BjoernServer, 'auto': AutoServer, @@ -2303,8 +3032,10 @@ def load_app(target): default_app.remove(tmp) # Remove the temporary added default application NORUN = nr_old +_debug = debug def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, - interval=1, reloader=False, quiet=False, plugins=None, **kargs): + interval=1, reloader=False, quiet=False, plugins=None, + debug=None, **kargs): """ Start a server instance. This method blocks until the server terminates. :param app: WSGI application or target string supported by @@ -2324,6 +3055,7 @@ def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, if NORUN: return if reloader and not os.environ.get('BOTTLE_CHILD'): try: + lockfile = None fd, lockfile = tempfile.mkstemp(prefix='bottle.', suffix='.lock') os.close(fd) # We only need this file to exist. We never write to it while os.path.exists(lockfile): @@ -2345,9 +3077,8 @@ def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, os.unlink(lockfile) return - stderr = sys.stderr.write - try: + if debug is not None: _debug(debug) app = app or default_app() if isinstance(app, basestring): app = load_app(app) @@ -2368,9 +3099,9 @@ def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, server.quiet = server.quiet or quiet if not server.quiet: - stderr("Bottle server starting up (using %s)...\n" % repr(server)) - stderr("Listening on http://%s:%d/\n" % (server.host, server.port)) - stderr("Hit Ctrl-C to quit.\n\n") + _stderr("Bottle v%s server starting up (using %s)...\n" % (__version__, repr(server))) + _stderr("Listening on http://%s:%d/\n" % (server.host, server.port)) + _stderr("Hit Ctrl-C to quit.\n\n") if reloader: lockfile = os.environ.get('BOTTLE_LOCKFILE') @@ -2383,12 +3114,15 @@ def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, server.run(app) except KeyboardInterrupt: pass - except (SyntaxError, ImportError): + except (SystemExit, MemoryError): + raise + except: if not reloader: raise - if not getattr(server, 'quiet', False): print_exc() + if not getattr(server, 'quiet', quiet): + print_exc() + time.sleep(interval) sys.exit(3) - finally: - if not getattr(server, 'quiet', False): stderr('Shutdown...\n') + class FileCheckerThread(threading.Thread): @@ -2406,7 +3140,7 @@ class FileCheckerThread(threading.Thread): mtime = lambda path: os.stat(path).st_mtime files = dict() - for module in sys.modules.values(): + for module in list(sys.modules.values()): path = getattr(module, '__file__', '') if path[-4:] in ('.pyo', '.pyc'): path = path[:-1] if path and exists(path): files[path] = mtime(path) @@ -2416,20 +3150,20 @@ class FileCheckerThread(threading.Thread): or mtime(self.lockfile) < time.time() - self.interval - 5: self.status = 'error' thread.interrupt_main() - for path, lmtime in files.iteritems(): + for path, lmtime in list(files.items()): if not exists(path) or mtime(path) > lmtime: self.status = 'reload' thread.interrupt_main() break time.sleep(self.interval) - + def __enter__(self): self.start() - + def __exit__(self, exc_type, exc_val, exc_tb): if not self.status: self.status = 'exit' # silent exit self.join() - return issubclass(exc_type, KeyboardInterrupt) + return exc_type is not None and issubclass(exc_type, KeyboardInterrupt) @@ -2465,7 +3199,7 @@ class BaseTemplate(object): self.name = name self.source = source.read() if hasattr(source, 'read') else source self.filename = source.filename if hasattr(source, 'filename') else None - self.lookup = map(os.path.abspath, lookup) + self.lookup = [os.path.abspath(x) for x in lookup] self.encoding = encoding self.settings = self.settings.copy() # Copy from class variable self.settings.update(settings) # Apply @@ -2481,11 +3215,19 @@ class BaseTemplate(object): def search(cls, name, lookup=[]): """ Search name in all directories specified in lookup. First without, then with common extensions. Return first hit. """ - if os.path.isfile(name): return name + if not lookup: + depr('The template lookup path list should not be empty.') #0.12 + lookup = ['.'] + + if os.path.isabs(name) and os.path.isfile(name): + depr('Absolute template path names are deprecated.') #0.12 + return os.path.abspath(name) + for spath in lookup: - fname = os.path.join(spath, name) - if os.path.isfile(fname): - return fname + spath = os.path.abspath(spath) + os.sep + fname = os.path.abspath(os.path.join(spath, name)) + if not fname.startswith(spath): continue + if os.path.isfile(fname): return fname for ext in cls.extensions: if os.path.isfile('%s.%s' % (fname, ext)): return '%s.%s' % (fname, ext) @@ -2510,8 +3252,8 @@ class BaseTemplate(object): """ Render the template with the specified local variables and return a single byte or unicode string. If it is a byte string, the encoding must match self.encoding. This method must be thread-safe! - Local variables may be provided in dictionaries (*args) - or directly, as keywords (**kwargs). + Local variables may be provided in dictionaries (args) + or directly, as keywords (kwargs). """ raise NotImplementedError @@ -2556,7 +3298,7 @@ class CheetahTemplate(BaseTemplate): class Jinja2Template(BaseTemplate): - def prepare(self, filters=None, tests=None, **kwargs): + def prepare(self, filters=None, tests=None, globals={}, **kwargs): from jinja2 import Environment, FunctionLoader if 'prefix' in kwargs: # TODO: to be removed after a while raise RuntimeError('The keyword argument `prefix` has been removed. ' @@ -2564,6 +3306,7 @@ class Jinja2Template(BaseTemplate): self.env = Environment(loader=FunctionLoader(self.loader), **kwargs) if filters: self.env.filters.update(filters) if tests: self.env.tests.update(tests) + if globals: self.env.globals.update(globals) if self.source: self.tpl = self.env.from_string(self.source) else: @@ -2577,189 +3320,252 @@ class Jinja2Template(BaseTemplate): def loader(self, name): fname = self.search(name, self.lookup) - if fname: - with open(fname, "rb") as f: - return f.read().decode(self.encoding) - - -class SimpleTALTemplate(BaseTemplate): - ''' Untested! ''' - def prepare(self, **options): - from simpletal import simpleTAL - # TODO: add option to load METAL files during render - if self.source: - self.tpl = simpleTAL.compileHTMLTemplate(self.source) - else: - with open(self.filename, 'rb') as fp: - self.tpl = simpleTAL.compileHTMLTemplate(tonat(fp.read())) - - def render(self, *args, **kwargs): - from simpletal import simpleTALES - for dictarg in args: kwargs.update(dictarg) - # TODO: maybe reuse a context instead of always creating one - context = simpleTALES.Context() - for k,v in self.defaults.items(): - context.addGlobal(k, v) - for k,v in kwargs.items(): - context.addGlobal(k, v) - output = StringIO() - self.tpl.expand(context, output) - return output.getvalue() + if not fname: return + with open(fname, "rb") as f: + return f.read().decode(self.encoding) class SimpleTemplate(BaseTemplate): - blocks = ('if', 'elif', 'else', 'try', 'except', 'finally', 'for', 'while', - 'with', 'def', 'class') - dedent_blocks = ('elif', 'else', 'except', 'finally') - - @lazy_attribute - def re_pytokens(cls): - ''' This matches comments and all kinds of quoted strings but does - NOT match comments (#...) within quoted strings. (trust me) ''' - return re.compile(r''' - (''(?!')|""(?!")|'{6}|"{6} # Empty strings (all 4 types) - |'(?:[^\\']|\\.)+?' # Single quotes (') - |"(?:[^\\"]|\\.)+?" # Double quotes (") - |'{3}(?:[^\\]|\\.|\n)+?'{3} # Triple-quoted strings (') - |"{3}(?:[^\\]|\\.|\n)+?"{3} # Triple-quoted strings (") - |\#.* # Comments - )''', re.VERBOSE) - - def prepare(self, escape_func=html_escape, noescape=False, **kwargs): + + def prepare(self, escape_func=html_escape, noescape=False, syntax=None, **ka): self.cache = {} enc = self.encoding self._str = lambda x: touni(x, enc) self._escape = lambda x: escape_func(touni(x, enc)) + self.syntax = syntax if noescape: self._str, self._escape = self._escape, self._str - @classmethod - def split_comment(cls, code): - """ Removes comments (#...) from python code. """ - if '#' not in code: return code - #: Remove comments only (leave quoted strings as they are) - subf = lambda m: '' if m.group(0)[0]=='#' else m.group(0) - return re.sub(cls.re_pytokens, subf, code) - @cached_property def co(self): return compile(self.code, self.filename or '<string>', 'exec') @cached_property def code(self): - stack = [] # Current Code indentation - lineno = 0 # Current line of code - ptrbuffer = [] # Buffer for printable strings and token tuple instances - codebuffer = [] # Buffer for generated python code - multiline = dedent = oneline = False - template = self.source or open(self.filename, 'rb').read() - - def yield_tokens(line): - for i, part in enumerate(re.split(r'\{\{(.*?)\}\}', line)): - if i % 2: - if part.startswith('!'): yield 'RAW', part[1:] - else: yield 'CMD', part - else: yield 'TXT', part - - def flush(): # Flush the ptrbuffer - if not ptrbuffer: return - cline = '' - for line in ptrbuffer: - for token, value in line: - if token == 'TXT': cline += repr(value) - elif token == 'RAW': cline += '_str(%s)' % value - elif token == 'CMD': cline += '_escape(%s)' % value - cline += ', ' - cline = cline[:-2] + '\\\n' - cline = cline[:-2] - if cline[:-1].endswith('\\\\\\\\\\n'): - cline = cline[:-7] + cline[-1] # 'nobr\\\\\n' --> 'nobr' - cline = '_printlist([' + cline + '])' - del ptrbuffer[:] # Do this before calling code() again - code(cline) - - def code(stmt): - for line in stmt.splitlines(): - codebuffer.append(' ' * len(stack) + line.strip()) - - for line in template.splitlines(True): - lineno += 1 - line = line if isinstance(line, unicode)\ - else unicode(line, encoding=self.encoding) - if lineno <= 2: - m = re.search(r"%.*coding[:=]\s*([-\w\.]+)", line) - if m: self.encoding = m.group(1) - if m: line = line.replace('coding','coding (removed)') - if line.strip()[:2].count('%') == 1: - line = line.split('%',1)[1].lstrip() # Full line following the % - cline = self.split_comment(line).strip() - cmd = re.split(r'[^a-zA-Z0-9_]', cline)[0] - flush() # You are actually reading this? Good luck, it's a mess :) - if cmd in self.blocks or multiline: - cmd = multiline or cmd - dedent = cmd in self.dedent_blocks # "else:" - if dedent and not oneline and not multiline: - cmd = stack.pop() - code(line) - oneline = not cline.endswith(':') # "if 1: pass" - multiline = cmd if cline.endswith('\\') else False - if not oneline and not multiline: - stack.append(cmd) - elif cmd == 'end' and stack: - code('#end(%s) %s' % (stack.pop(), line.strip()[3:])) - elif cmd == 'include': - p = cline.split(None, 2)[1:] - if len(p) == 2: - code("_=_include(%s, _stdout, %s)" % (repr(p[0]), p[1])) - elif p: - code("_=_include(%s, _stdout)" % repr(p[0])) - else: # Empty %include -> reverse of %rebase - code("_printlist(_base)") - elif cmd == 'rebase': - p = cline.split(None, 2)[1:] - if len(p) == 2: - code("globals()['_rebase']=(%s, dict(%s))" % (repr(p[0]), p[1])) - elif p: - code("globals()['_rebase']=(%s, {})" % repr(p[0])) - else: - code(line) - else: # Line starting with text (not '%') or '%%' (escaped) - if line.strip().startswith('%%'): - line = line.replace('%%', '%', 1) - ptrbuffer.append(yield_tokens(line)) - flush() - return '\n'.join(codebuffer) + '\n' - - def subtemplate(self, _name, _stdout, *args, **kwargs): - for dictarg in args: kwargs.update(dictarg) + source = self.source + if not source: + with open(self.filename, 'rb') as f: + source = f.read() + try: + source, encoding = touni(source), 'utf8' + except UnicodeError: + depr('Template encodings other than utf8 are no longer supported.') #0.11 + source, encoding = touni(source, 'latin1'), 'latin1' + parser = StplParser(source, encoding=encoding, syntax=self.syntax) + code = parser.translate() + self.encoding = parser.encoding + return code + + def _rebase(self, _env, _name=None, **kwargs): + if _name is None: + depr('Rebase function called without arguments.' + ' You were probably looking for {{base}}?', True) #0.12 + _env['_rebase'] = (_name, kwargs) + + def _include(self, _env, _name=None, **kwargs): + if _name is None: + depr('Rebase function called without arguments.' + ' You were probably looking for {{base}}?', True) #0.12 + env = _env.copy() + env.update(kwargs) if _name not in self.cache: self.cache[_name] = self.__class__(name=_name, lookup=self.lookup) - return self.cache[_name].execute(_stdout, kwargs) + return self.cache[_name].execute(env['_stdout'], env) - def execute(self, _stdout, *args, **kwargs): - for dictarg in args: kwargs.update(dictarg) + def execute(self, _stdout, kwargs): env = self.defaults.copy() - env.update({'_stdout': _stdout, '_printlist': _stdout.extend, - '_include': self.subtemplate, '_str': self._str, - '_escape': self._escape, 'get': env.get, - 'setdefault': env.setdefault, 'defined': env.__contains__}) env.update(kwargs) + env.update({'_stdout': _stdout, '_printlist': _stdout.extend, + 'include': functools.partial(self._include, env), + 'rebase': functools.partial(self._rebase, env), '_rebase': None, + '_str': self._str, '_escape': self._escape, 'get': env.get, + 'setdefault': env.setdefault, 'defined': env.__contains__ }) eval(self.co, env) - if '_rebase' in env: - subtpl, rargs = env['_rebase'] - rargs['_base'] = _stdout[:] #copy stdout + if env.get('_rebase'): + subtpl, rargs = env.pop('_rebase') + rargs['base'] = ''.join(_stdout) #copy stdout del _stdout[:] # clear stdout - return self.subtemplate(subtpl,_stdout,rargs) + return self._include(env, subtpl, **rargs) return env def render(self, *args, **kwargs): """ Render the template using keyword arguments as local variables. """ - for dictarg in args: kwargs.update(dictarg) - stdout = [] - self.execute(stdout, kwargs) + env = {}; stdout = [] + for dictarg in args: env.update(dictarg) + env.update(kwargs) + self.execute(stdout, env) return ''.join(stdout) +class StplSyntaxError(TemplateError): pass + + +class StplParser(object): + ''' Parser for stpl templates. ''' + _re_cache = {} #: Cache for compiled re patterns + # This huge pile of voodoo magic splits python code into 8 different tokens. + # 1: All kinds of python strings (trust me, it works) + _re_tok = '((?m)[urbURB]?(?:\'\'(?!\')|""(?!")|\'{6}|"{6}' \ + '|\'(?:[^\\\\\']|\\\\.)+?\'|"(?:[^\\\\"]|\\\\.)+?"' \ + '|\'{3}(?:[^\\\\]|\\\\.|\\n)+?\'{3}' \ + '|"{3}(?:[^\\\\]|\\\\.|\\n)+?"{3}))' + _re_inl = _re_tok.replace('|\\n','') # We re-use this string pattern later + # 2: Comments (until end of line, but not the newline itself) + _re_tok += '|(#.*)' + # 3,4: Keywords that start or continue a python block (only start of line) + _re_tok += '|^([ \\t]*(?:if|for|while|with|try|def|class)\\b)' \ + '|^([ \\t]*(?:elif|else|except|finally)\\b)' + # 5: Our special 'end' keyword (but only if it stands alone) + _re_tok += '|((?:^|;)[ \\t]*end[ \\t]*(?=(?:%(block_close)s[ \\t]*)?\\r?$|;|#))' + # 6: A customizable end-of-code-block template token (only end of line) + _re_tok += '|(%(block_close)s[ \\t]*(?=$))' + # 7: And finally, a single newline. The 8th token is 'everything else' + _re_tok += '|(\\r?\\n)' + # Match the start tokens of code areas in a template + _re_split = '(?m)^[ \t]*(\\\\?)((%(line_start)s)|(%(block_start)s))(%%?)' + # Match inline statements (may contain python strings) + _re_inl = '%%(inline_start)s((?:%s|[^\'"\n]*?)+)%%(inline_end)s' % _re_inl + + default_syntax = '<% %> % {{ }}' + + def __init__(self, source, syntax=None, encoding='utf8'): + self.source, self.encoding = touni(source, encoding), encoding + self.set_syntax(syntax or self.default_syntax) + self.code_buffer, self.text_buffer = [], [] + self.lineno, self.offset = 1, 0 + self.indent, self.indent_mod = 0, 0 + + def get_syntax(self): + ''' Tokens as a space separated string (default: <% %> % {{ }}) ''' + return self._syntax + + def set_syntax(self, syntax): + self._syntax = syntax + self._tokens = syntax.split() + if not syntax in self._re_cache: + names = 'block_start block_close line_start inline_start inline_end' + etokens = map(re.escape, self._tokens) + pattern_vars = dict(zip(names.split(), etokens)) + patterns = (self._re_split, self._re_tok, self._re_inl) + patterns = [re.compile(p%pattern_vars) for p in patterns] + self._re_cache[syntax] = patterns + self.re_split, self.re_tok, self.re_inl = self._re_cache[syntax] + + syntax = property(get_syntax, set_syntax) + + def translate(self): + if self.offset: raise RuntimeError('Parser is a one time instance.') + while True: + m = self.re_split.search(self.source[self.offset:]) + if m: + text = self.source[self.offset:self.offset+m.start()] + self.text_buffer.append(text) + self.offset += m.end() + if m.group(1): # New escape syntax + line, sep, _ = self.source[self.offset:].partition('\n') + self.text_buffer.append(m.group(2)+m.group(5)+line+sep) + self.offset += len(line+sep)+1 + continue + elif m.group(5): # Old escape syntax + depr('Escape code lines with a backslash.') #0.12 + line, sep, _ = self.source[self.offset:].partition('\n') + self.text_buffer.append(m.group(2)+line+sep) + self.offset += len(line+sep)+1 + continue + self.flush_text() + self.read_code(multiline=bool(m.group(4))) + else: break + self.text_buffer.append(self.source[self.offset:]) + self.flush_text() + return ''.join(self.code_buffer) + + def read_code(self, multiline): + code_line, comment = '', '' + while True: + m = self.re_tok.search(self.source[self.offset:]) + if not m: + code_line += self.source[self.offset:] + self.offset = len(self.source) + self.write_code(code_line.strip(), comment) + return + code_line += self.source[self.offset:self.offset+m.start()] + self.offset += m.end() + _str, _com, _blk1, _blk2, _end, _cend, _nl = m.groups() + if code_line and (_blk1 or _blk2): # a if b else c + code_line += _blk1 or _blk2 + continue + if _str: # Python string + code_line += _str + elif _com: # Python comment (up to EOL) + comment = _com + if multiline and _com.strip().endswith(self._tokens[1]): + multiline = False # Allow end-of-block in comments + elif _blk1: # Start-block keyword (if/for/while/def/try/...) + code_line, self.indent_mod = _blk1, -1 + self.indent += 1 + elif _blk2: # Continue-block keyword (else/elif/except/...) + code_line, self.indent_mod = _blk2, -1 + elif _end: # The non-standard 'end'-keyword (ends a block) + self.indent -= 1 + elif _cend: # The end-code-block template token (usually '%>') + if multiline: multiline = False + else: code_line += _cend + else: # \n + self.write_code(code_line.strip(), comment) + self.lineno += 1 + code_line, comment, self.indent_mod = '', '', 0 + if not multiline: + break + + def flush_text(self): + text = ''.join(self.text_buffer) + del self.text_buffer[:] + if not text: return + parts, pos, nl = [], 0, '\\\n'+' '*self.indent + for m in self.re_inl.finditer(text): + prefix, pos = text[pos:m.start()], m.end() + if prefix: + parts.append(nl.join(map(repr, prefix.splitlines(True)))) + if prefix.endswith('\n'): parts[-1] += nl + parts.append(self.process_inline(m.group(1).strip())) + if pos < len(text): + prefix = text[pos:] + lines = prefix.splitlines(True) + if lines[-1].endswith('\\\\\n'): lines[-1] = lines[-1][:-3] + elif lines[-1].endswith('\\\\\r\n'): lines[-1] = lines[-1][:-4] + parts.append(nl.join(map(repr, lines))) + code = '_printlist((%s,))' % ', '.join(parts) + self.lineno += code.count('\n')+1 + self.write_code(code) + + def process_inline(self, chunk): + if chunk[0] == '!': return '_str(%s)' % chunk[1:] + return '_escape(%s)' % chunk + + def write_code(self, line, comment=''): + line, comment = self.fix_backward_compatibility(line, comment) + code = ' ' * (self.indent+self.indent_mod) + code += line.lstrip() + comment + '\n' + self.code_buffer.append(code) + + def fix_backward_compatibility(self, line, comment): + parts = line.strip().split(None, 2) + if parts and parts[0] in ('include', 'rebase'): + depr('The include and rebase keywords are functions now.') #0.12 + if len(parts) == 1: return "_printlist([base])", comment + elif len(parts) == 2: return "_=%s(%r)" % tuple(parts), comment + else: return "_=%s(%r, %s)" % tuple(parts), comment + if self.lineno <= 2 and not line.strip() and 'coding' in comment: + m = re.match(r"#.*coding[:=]\s*([-\w.]+)", comment) + if m: + depr('PEP263 encoding strings in templates are deprecated.') #0.12 + enc = m.group(1) + self.source = self.source.encode(self.encoding).decode(enc) + self.encoding = enc + return line, comment.replace('coding','coding*') + return line, comment + + def template(*args, **kwargs): ''' Get a rendered template as a string iterator. @@ -2768,26 +3574,26 @@ def template(*args, **kwargs): or directly (as keyword arguments). ''' tpl = args[0] if args else None - template_adapter = kwargs.pop('template_adapter', SimpleTemplate) - if tpl not in TEMPLATES or DEBUG: + adapter = kwargs.pop('template_adapter', SimpleTemplate) + lookup = kwargs.pop('template_lookup', TEMPLATE_PATH) + tplid = (id(lookup), tpl) + if tplid not in TEMPLATES or DEBUG: settings = kwargs.pop('template_settings', {}) - lookup = kwargs.pop('template_lookup', TEMPLATE_PATH) - if isinstance(tpl, template_adapter): - TEMPLATES[tpl] = tpl - if settings: TEMPLATES[tpl].prepare(**settings) + if isinstance(tpl, adapter): + TEMPLATES[tplid] = tpl + if settings: TEMPLATES[tplid].prepare(**settings) elif "\n" in tpl or "{" in tpl or "%" in tpl or '$' in tpl: - TEMPLATES[tpl] = template_adapter(source=tpl, lookup=lookup, **settings) + TEMPLATES[tplid] = adapter(source=tpl, lookup=lookup, **settings) else: - TEMPLATES[tpl] = template_adapter(name=tpl, lookup=lookup, **settings) - if not TEMPLATES[tpl]: + TEMPLATES[tplid] = adapter(name=tpl, lookup=lookup, **settings) + if not TEMPLATES[tplid]: abort(500, 'Template (%s) not found' % tpl) for dictarg in args[1:]: kwargs.update(dictarg) - return TEMPLATES[tpl].render(kwargs) + return TEMPLATES[tplid].render(kwargs) mako_template = functools.partial(template, template_adapter=MakoTemplate) cheetah_template = functools.partial(template, template_adapter=CheetahTemplate) jinja2_template = functools.partial(template, template_adapter=Jinja2Template) -simpletal_template = functools.partial(template, template_adapter=SimpleTALTemplate) def view(tpl_name, **defaults): @@ -2808,6 +3614,8 @@ def view(tpl_name, **defaults): tplvars = defaults.copy() tplvars.update(result) return template(tpl_name, **tplvars) + elif result is None: + return template(tpl_name, defaults) return result return wrapper return decorator @@ -2815,7 +3623,6 @@ def view(tpl_name, **defaults): mako_view = functools.partial(view, template_adapter=MakoTemplate) cheetah_view = functools.partial(view, template_adapter=CheetahTemplate) jinja2_view = functools.partial(view, template_adapter=Jinja2Template) -simpletal_view = functools.partial(view, template_adapter=SimpleTALTemplate) @@ -2839,17 +3646,16 @@ HTTP_CODES[428] = "Precondition Required" HTTP_CODES[429] = "Too Many Requests" HTTP_CODES[431] = "Request Header Fields Too Large" HTTP_CODES[511] = "Network Authentication Required" -_HTTP_STATUS_LINES = dict((k, '%d %s'%(k,v)) for (k,v) in HTTP_CODES.iteritems()) +_HTTP_STATUS_LINES = dict((k, '%d %s'%(k,v)) for (k,v) in HTTP_CODES.items()) #: The default template used for error pages. Override with @error() ERROR_PAGE_TEMPLATE = """ -%try: - %from bottle import DEBUG, HTTP_CODES, request, touni - %status_name = HTTP_CODES.get(e.status, 'Unknown').title() +%%try: + %%from %s import DEBUG, HTTP_CODES, request, touni <!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN"> <html> <head> - <title>Error {{e.status}}: {{status_name}}</title> + <title>Error: {{e.status}}</title> <style type="text/css"> html {background-color: #eee; font-family: sans;} body {background-color: #fff; border: 1px solid #ddd; @@ -2858,31 +3664,34 @@ ERROR_PAGE_TEMPLATE = """ </style> </head> <body> - <h1>Error {{e.status}}: {{status_name}}</h1> + <h1>Error: {{e.status}}</h1> <p>Sorry, the requested URL <tt>{{repr(request.url)}}</tt> caused an error:</p> - <pre>{{e.output}}</pre> - %if DEBUG and e.exception: + <pre>{{e.body}}</pre> + %%if DEBUG and e.exception: <h2>Exception:</h2> <pre>{{repr(e.exception)}}</pre> - %end - %if DEBUG and e.traceback: + %%end + %%if DEBUG and e.traceback: <h2>Traceback:</h2> <pre>{{e.traceback}}</pre> - %end + %%end </body> </html> -%except ImportError: +%%except ImportError: <b>ImportError:</b> Could not generate the error page. Please add bottle to the import path. -%end -""" +%%end +""" % __name__ -#: A thread-safe instance of :class:`Request` representing the `current` request. -request = Request() +#: A thread-safe instance of :class:`LocalRequest`. If accessed from within a +#: request callback, this instance always refers to the *current* request +#: (even on a multithreaded server). +request = LocalRequest() -#: A thread-safe instance of :class:`Response` used to build the HTTP response. -response = Response() +#: A thread-safe instance of :class:`LocalResponse`. It is used to change the +#: HTTP response for the *current* request. +response = LocalResponse() #: A thread-safe namespace. Not used by Bottle. local = threading.local() @@ -2894,29 +3703,30 @@ app.push() #: A virtual package that redirects import statements. #: Example: ``import bottle.ext.sqlite`` actually imports `bottle_sqlite`. -ext = _ImportRedirect(__name__+'.ext', 'bottle_%s').module +ext = _ImportRedirect('bottle.ext' if __name__ == '__main__' else __name__+".ext", 'bottle_%s').module if __name__ == '__main__': opt, args, parser = _cmd_options, _cmd_args, _cmd_parser if opt.version: - print 'Bottle', __version__; sys.exit(0) + _stdout('Bottle %s\n'%__version__) + sys.exit(0) if not args: parser.print_help() - print '\nError: No application specified.\n' + _stderr('\nError: No application specified.\n') sys.exit(1) - try: - sys.path.insert(0, '.') - sys.modules.setdefault('bottle', sys.modules['__main__']) - except (AttributeError, ImportError), e: - parser.error(e.args[0]) + sys.path.insert(0, '.') + sys.modules.setdefault('bottle', sys.modules['__main__']) + + host, port = (opt.bind or 'localhost'), 8080 + if ':' in host and host.rfind(']') < host.rfind(':'): + host, port = host.rsplit(':', 1) + host = host.strip('[]') + + run(args[0], host=host, port=int(port), server=opt.server, + reloader=opt.reload, plugins=opt.plugin, debug=opt.debug) + - if opt.bind and ':' in opt.bind: - host, port = opt.bind.rsplit(':', 1) - else: - host, port = (opt.bind or 'localhost'), 8080 - debug(opt.debug) - run(args[0], host=host, port=port, server=opt.server, reloader=opt.reload, plugins=opt.plugin) # THE END diff --git a/module/lib/feedparser.py b/module/lib/feedparser.py index a746ed8f5..c78e6a39b 100644 --- a/module/lib/feedparser.py +++ b/module/lib/feedparser.py @@ -1,17 +1,19 @@ -#!/usr/bin/env python """Universal feed parser Handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds -Visit http://feedparser.org/ for the latest version -Visit http://feedparser.org/docs/ for the latest documentation +Visit https://code.google.com/p/feedparser/ for the latest version +Visit http://packages.python.org/feedparser/ for the latest documentation Required: Python 2.4 or later -Recommended: CJKCodecs and iconv_codec <http://cjkpython.i18n.org/> +Recommended: iconv_codec <http://cjkpython.i18n.org/> """ -__version__ = "5.0" -__license__ = """Copyright (c) 2002-2008, Mark Pilgrim, All rights reserved. +__version__ = "5.1.3" +__license__ = """ +Copyright (c) 2010-2012 Kurt McKee <contactme@kurtmckee.org> +Copyright (c) 2002-2008 Mark Pilgrim +All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,13 +44,13 @@ __contributors__ = ["Jason Diamond <http://injektilo.org/>", "Sam Ruby <http://intertwingly.net/>", "Ade Oshineye <http://blog.oshineye.com/>", "Martin Pool <http://sourcefrog.net/>", - "Kurt McKee <http://kurtmckee.org/>"] -_debug = 0 + "Kurt McKee <http://kurtmckee.org/>", + "Bernd Schlapsi <https://github.com/brot>",] # HTTP "User-Agent" header to send to servers when downloading feeds. # If you are embedding feedparser in a larger application, you should # change this to your application name and URL. -USER_AGENT = "UniversalFeedParser/%s +http://feedparser.org/" % __version__ +USER_AGENT = "UniversalFeedParser/%s +https://code.google.com/p/feedparser/" % __version__ # HTTP "Accept" header to send to servers when downloading feeds. If you don't # want to send an Accept header, set this to None. @@ -76,6 +78,10 @@ RESOLVE_RELATIVE_URIS = 1 # HTML content, set this to 1. SANITIZE_HTML = 1 +# If you want feedparser to automatically parse microformat content embedded +# in entry contents, set this to 1 +PARSE_MICROFORMATS = 1 + # ---------- Python 3 modules (make it work if possible) ---------- try: import rfc822 @@ -89,34 +95,35 @@ try: except (NameError, AttributeError): import string _maketrans = string.maketrans - + # base64 support for Atom feeds that contain embedded binary data try: import base64, binascii +except ImportError: + base64 = binascii = None +else: # Python 3.1 deprecates decodestring in favor of decodebytes _base64decode = getattr(base64, 'decodebytes', base64.decodestring) -except: - base64 = binascii = None -def _s2bytes(s): - # Convert a UTF-8 str to bytes if the interpreter is Python 3 - try: - return bytes(s, 'utf8') - except (NameError, TypeError): - # In Python 2.5 and below, bytes doesn't exist (NameError) - # In Python 2.6 and above, bytes and str are the same (TypeError) - return s - -def _l2bytes(l): - # Convert a list of ints to bytes if the interpreter is Python 3 - try: - if bytes is not str: - # In Python 2.6 and above, this call won't raise an exception - # but it will return bytes([65]) as '[65]' instead of 'A' - return bytes(l) - raise NameError - except NameError: - return ''.join(map(chr, l)) +# _s2bytes: convert a UTF-8 str to bytes if the interpreter is Python 3 +# _l2bytes: convert a list of ints to bytes if the interpreter is Python 3 +try: + if bytes is str: + # In Python 2.5 and below, bytes doesn't exist (NameError) + # In Python 2.6 and above, bytes and str are the same type + raise NameError +except NameError: + # Python 2 + def _s2bytes(s): + return s + def _l2bytes(l): + return ''.join(map(chr, l)) +else: + # Python 3 + def _s2bytes(s): + return bytes(s, 'utf8') + def _l2bytes(l): + return bytes(l) # If you want feedparser to allow all URL schemes, set this to () # List culled from Python's urlparse documentation at: @@ -125,9 +132,10 @@ def _l2bytes(l): # https://secure.wikimedia.org/wikipedia/en/wiki/URI_scheme # Many more will likely need to be added! ACCEPTABLE_URI_SCHEMES = ( - 'file', 'ftp', 'gopher', 'h323', 'hdl', 'http', 'https', 'imap', 'mailto', - 'mms', 'news', 'nntp', 'prospero', 'rsync', 'rtsp', 'rtspu', 'sftp', - 'shttp', 'sip', 'sips', 'snews', 'svn', 'svn+ssh', 'telnet', 'wais', + 'file', 'ftp', 'gopher', 'h323', 'hdl', 'http', 'https', 'imap', 'magnet', + 'mailto', 'mms', 'news', 'nntp', 'prospero', 'rsync', 'rtsp', 'rtspu', + 'sftp', 'shttp', 'sip', 'sips', 'snews', 'svn', 'svn+ssh', 'telnet', + 'wais', # Additional common-but-unofficial schemes 'aim', 'callto', 'cvs', 'facetime', 'feed', 'git', 'gtalk', 'irc', 'ircs', 'irc6', 'itms', 'mms', 'msnim', 'skype', 'ssh', 'smb', 'svn', 'ymsg', @@ -135,13 +143,27 @@ ACCEPTABLE_URI_SCHEMES = ( #ACCEPTABLE_URI_SCHEMES = () # ---------- required modules (should come with any Python distribution) ---------- -import sgmllib, re, sys, copy, urlparse, time, types, cgi, urllib, urllib2, datetime +import cgi +import codecs +import copy +import datetime +import re +import struct +import time +import types +import urllib +import urllib2 +import urlparse +import warnings + +from htmlentitydefs import name2codepoint, codepoint2name, entitydefs + try: from io import BytesIO as _StringIO except ImportError: try: from cStringIO import StringIO as _StringIO - except: + except ImportError: from StringIO import StringIO as _StringIO # ---------- optional modules (feedparser will work without these, but with reduced functionality) ---------- @@ -149,23 +171,21 @@ except ImportError: # gzip is included with most Python distributions, but may not be available if you compiled your own try: import gzip -except: +except ImportError: gzip = None try: import zlib -except: +except ImportError: zlib = None # If a real XML parser is available, feedparser will attempt to use it. feedparser has -# been tested with the built-in SAX parser, PyXML, and libxml2. On platforms where the +# been tested with the built-in SAX parser and libxml2. On platforms where the # Python distribution does not come with an XML parser (such as Mac OS X 10.2 and some # versions of FreeBSD), feedparser will quietly fall back on regex-based parsing. try: import xml.sax - xml.sax.make_parser(PREFERRED_XML_PARSERS) # test for valid parsers from xml.sax.saxutils import escape as _xmlescape - _XML_AVAILABLE = 1 -except: +except ImportError: _XML_AVAILABLE = 0 def _xmlescape(data,entities={}): data = data.replace('&', '&') @@ -174,49 +194,96 @@ except: for char, entity in entities: data = data.replace(char, entity) return data +else: + try: + xml.sax.make_parser(PREFERRED_XML_PARSERS) # test for valid parsers + except xml.sax.SAXReaderNotAvailable: + _XML_AVAILABLE = 0 + else: + _XML_AVAILABLE = 1 -# cjkcodecs and iconv_codec provide support for more character encodings. -# Both are available from http://cjkpython.i18n.org/ +# sgmllib is not available by default in Python 3; if the end user doesn't have +# it available then we'll lose illformed XML parsing, content santizing, and +# microformat support (at least while feedparser depends on BeautifulSoup). try: - import cjkcodecs.aliases -except: - pass + import sgmllib +except ImportError: + # This is probably Python 3, which doesn't include sgmllib anymore + _SGML_AVAILABLE = 0 + + # Mock sgmllib enough to allow subclassing later on + class sgmllib(object): + class SGMLParser(object): + def goahead(self, i): + pass + def parse_starttag(self, i): + pass +else: + _SGML_AVAILABLE = 1 + + # sgmllib defines a number of module-level regular expressions that are + # insufficient for the XML parsing feedparser needs. Rather than modify + # the variables directly in sgmllib, they're defined here using the same + # names, and the compiled code objects of several sgmllib.SGMLParser + # methods are copied into _BaseHTMLProcessor so that they execute in + # feedparser's scope instead of sgmllib's scope. + charref = re.compile('&#(\d+|[xX][0-9a-fA-F]+);') + tagfind = re.compile('[a-zA-Z][-_.:a-zA-Z0-9]*') + attrfind = re.compile( + r'\s*([a-zA-Z_][-:.a-zA-Z_0-9]*)[$]?(\s*=\s*' + r'(\'[^\']*\'|"[^"]*"|[][\-a-zA-Z0-9./,:;+*%?!&$\(\)_#=~\'"@]*))?' + ) + + # Unfortunately, these must be copied over to prevent NameError exceptions + entityref = sgmllib.entityref + incomplete = sgmllib.incomplete + interesting = sgmllib.interesting + shorttag = sgmllib.shorttag + shorttagopen = sgmllib.shorttagopen + starttagopen = sgmllib.starttagopen + + class _EndBracketRegEx: + def __init__(self): + # Overriding the built-in sgmllib.endbracket regex allows the + # parser to find angle brackets embedded in element attributes. + self.endbracket = re.compile('''([^'"<>]|"[^"]*"(?=>|/|\s|\w+=)|'[^']*'(?=>|/|\s|\w+=))*(?=[<>])|.*?(?=[<>])''') + def search(self, target, index=0): + match = self.endbracket.match(target, index) + if match is not None: + # Returning a new object in the calling thread's context + # resolves a thread-safety. + return EndBracketMatch(match) + return None + class EndBracketMatch: + def __init__(self, match): + self.match = match + def start(self, n): + return self.match.end(n) + endbracket = _EndBracketRegEx() + + +# iconv_codec provides support for more character encodings. +# It's available from http://cjkpython.i18n.org/ try: import iconv_codec -except: +except ImportError: pass # chardet library auto-detects character encodings # Download from http://chardet.feedparser.org/ try: import chardet - if _debug: - import chardet.constants - chardet.constants._debug = 1 -except: +except ImportError: chardet = None -# reversable htmlentitydefs mappings for Python 2.2 -try: - from htmlentitydefs import name2codepoint, codepoint2name -except: - import htmlentitydefs - name2codepoint={} - codepoint2name={} - for (name,codepoint) in htmlentitydefs.entitydefs.iteritems(): - if codepoint.startswith('&#'): codepoint=unichr(int(codepoint[2:-1])) - name2codepoint[name]=ord(codepoint) - codepoint2name[ord(codepoint)]=name - -# BeautifulSoup parser used for parsing microformats from embedded HTML content +# BeautifulSoup is used to extract microformat content from HTML +# feedparser is tested using BeautifulSoup 3.2.0 # http://www.crummy.com/software/BeautifulSoup/ -# feedparser is tested with BeautifulSoup 3.0.x, but it might work with the -# older 2.x series. If it doesn't, and you can figure out why, I'll accept a -# patch and modify the compatibility statement accordingly. try: import BeautifulSoup -except: +except ImportError: BeautifulSoup = None + PARSE_MICROFORMATS = False # ---------- don't touch these ---------- class ThingsNobodyCaresAboutButMe(Exception): pass @@ -225,67 +292,32 @@ class CharacterEncodingUnknown(ThingsNobodyCaresAboutButMe): pass class NonXMLContentType(ThingsNobodyCaresAboutButMe): pass class UndeclaredNamespace(Exception): pass -sgmllib.tagfind = re.compile('[a-zA-Z][-_.:a-zA-Z0-9]*') -sgmllib.special = re.compile('<!') -sgmllib.charref = re.compile('&#(\d+|[xX][0-9a-fA-F]+);') - -if sgmllib.endbracket.search(' <').start(0): - class EndBracketRegEx: - def __init__(self): - # Overriding the built-in sgmllib.endbracket regex allows the - # parser to find angle brackets embedded in element attributes. - self.endbracket = re.compile('''([^'"<>]|"[^"]*"(?=>|/|\s|\w+=)|'[^']*'(?=>|/|\s|\w+=))*(?=[<>])|.*?(?=[<>])''') - def search(self,string,index=0): - match = self.endbracket.match(string,index) - if match is not None: - # Returning a new object in the calling thread's context - # resolves a thread-safety. - return EndBracketMatch(match) - return None - class EndBracketMatch: - def __init__(self, match): - self.match = match - def start(self, n): - return self.match.end(n) - sgmllib.endbracket = EndBracketRegEx() - -SUPPORTED_VERSIONS = {'': 'unknown', - 'rss090': 'RSS 0.90', - 'rss091n': 'RSS 0.91 (Netscape)', - 'rss091u': 'RSS 0.91 (Userland)', - 'rss092': 'RSS 0.92', - 'rss093': 'RSS 0.93', - 'rss094': 'RSS 0.94', - 'rss20': 'RSS 2.0', - 'rss10': 'RSS 1.0', - 'rss': 'RSS (unknown version)', - 'atom01': 'Atom 0.1', - 'atom02': 'Atom 0.2', - 'atom03': 'Atom 0.3', - 'atom10': 'Atom 1.0', - 'atom': 'Atom (unknown version)', - 'cdf': 'CDF', - 'hotrss': 'Hot RSS' +SUPPORTED_VERSIONS = {'': u'unknown', + 'rss090': u'RSS 0.90', + 'rss091n': u'RSS 0.91 (Netscape)', + 'rss091u': u'RSS 0.91 (Userland)', + 'rss092': u'RSS 0.92', + 'rss093': u'RSS 0.93', + 'rss094': u'RSS 0.94', + 'rss20': u'RSS 2.0', + 'rss10': u'RSS 1.0', + 'rss': u'RSS (unknown version)', + 'atom01': u'Atom 0.1', + 'atom02': u'Atom 0.2', + 'atom03': u'Atom 0.3', + 'atom10': u'Atom 1.0', + 'atom': u'Atom (unknown version)', + 'cdf': u'CDF', } -try: - UserDict = dict -except NameError: - # Python 2.1 does not have dict - from UserDict import UserDict - def dict(aList): - rc = {} - for k, v in aList: - rc[k] = v - return rc - -class FeedParserDict(UserDict): +class FeedParserDict(dict): keymap = {'channel': 'feed', 'items': 'entries', 'guid': 'id', 'date': 'updated', 'date_parsed': 'updated_parsed', 'description': ['summary', 'subtitle'], + 'description_detail': ['summary_detail', 'subtitle_detail'], 'url': ['href'], 'modified': 'updated', 'modified_parsed': 'updated_parsed', @@ -297,227 +329,220 @@ class FeedParserDict(UserDict): 'tagline_detail': 'subtitle_detail'} def __getitem__(self, key): if key == 'category': - return UserDict.__getitem__(self, 'tags')[0]['term'] - if key == 'enclosures': + try: + return dict.__getitem__(self, 'tags')[0]['term'] + except IndexError: + raise KeyError, "object doesn't have key 'category'" + elif key == 'enclosures': norel = lambda link: FeedParserDict([(name,value) for (name,value) in link.items() if name!='rel']) - return [norel(link) for link in UserDict.__getitem__(self, 'links') if link['rel']=='enclosure'] - if key == 'license': - for link in UserDict.__getitem__(self, 'links'): - if link['rel']=='license' and link.has_key('href'): + return [norel(link) for link in dict.__getitem__(self, 'links') if link['rel']==u'enclosure'] + elif key == 'license': + for link in dict.__getitem__(self, 'links'): + if link['rel']==u'license' and 'href' in link: return link['href'] - if key == 'categories': - return [(tag['scheme'], tag['term']) for tag in UserDict.__getitem__(self, 'tags')] - realkey = self.keymap.get(key, key) - if type(realkey) == types.ListType: - for k in realkey: - if UserDict.__contains__(self, k): - return UserDict.__getitem__(self, k) - if UserDict.__contains__(self, key): - return UserDict.__getitem__(self, key) - return UserDict.__getitem__(self, realkey) + elif key == 'updated': + # Temporarily help developers out by keeping the old + # broken behavior that was reported in issue 310. + # This fix was proposed in issue 328. + if not dict.__contains__(self, 'updated') and \ + dict.__contains__(self, 'published'): + warnings.warn("To avoid breaking existing software while " + "fixing issue 310, a temporary mapping has been created " + "from `updated` to `published` if `updated` doesn't " + "exist. This fallback will be removed in a future version " + "of feedparser.", DeprecationWarning) + return dict.__getitem__(self, 'published') + return dict.__getitem__(self, 'updated') + elif key == 'updated_parsed': + if not dict.__contains__(self, 'updated_parsed') and \ + dict.__contains__(self, 'published_parsed'): + warnings.warn("To avoid breaking existing software while " + "fixing issue 310, a temporary mapping has been created " + "from `updated_parsed` to `published_parsed` if " + "`updated_parsed` doesn't exist. This fallback will be " + "removed in a future version of feedparser.", + DeprecationWarning) + return dict.__getitem__(self, 'published_parsed') + return dict.__getitem__(self, 'updated_parsed') + else: + realkey = self.keymap.get(key, key) + if isinstance(realkey, list): + for k in realkey: + if dict.__contains__(self, k): + return dict.__getitem__(self, k) + elif dict.__contains__(self, realkey): + return dict.__getitem__(self, realkey) + return dict.__getitem__(self, key) - def __setitem__(self, key, value): - for k in self.keymap.keys(): - if key == k: - key = self.keymap[k] - if type(key) == types.ListType: - key = key[0] - return UserDict.__setitem__(self, key, value) + def __contains__(self, key): + if key in ('updated', 'updated_parsed'): + # Temporarily help developers out by keeping the old + # broken behavior that was reported in issue 310. + # This fix was proposed in issue 328. + return dict.__contains__(self, key) + try: + self.__getitem__(key) + except KeyError: + return False + else: + return True + + has_key = __contains__ def get(self, key, default=None): - if self.has_key(key): - return self[key] - else: + try: + return self.__getitem__(key) + except KeyError: return default + def __setitem__(self, key, value): + key = self.keymap.get(key, key) + if isinstance(key, list): + key = key[0] + return dict.__setitem__(self, key, value) + def setdefault(self, key, value): - if not self.has_key(key): + if key not in self: self[key] = value + return value return self[key] - - def has_key(self, key): - try: - return hasattr(self, key) or UserDict.__contains__(self, key) - except AttributeError: - return False - # This alias prevents the 2to3 tool from changing the semantics of the - # __contains__ function below and exhausting the maximum recursion depth - __has_key = has_key - + def __getattr__(self, key): + # __getattribute__() is called first; this will be called + # only if an attribute was not already found try: - return self.__dict__[key] - except KeyError: - pass - try: - assert not key.startswith('_') return self.__getitem__(key) - except: + except KeyError: raise AttributeError, "object has no attribute '%s'" % key - def __setattr__(self, key, value): - if key.startswith('_') or key == 'data': - self.__dict__[key] = value - else: - return self.__setitem__(key, value) + def __hash__(self): + return id(self) - def __contains__(self, key): - return self.__has_key(key) - -def zopeCompatibilityHack(): - global FeedParserDict - del FeedParserDict - def FeedParserDict(aDict=None): - rc = {} - if aDict: - rc.update(aDict) - return rc - -_ebcdic_to_ascii_map = None -def _ebcdic_to_ascii(s): - global _ebcdic_to_ascii_map - if not _ebcdic_to_ascii_map: - emap = ( - 0,1,2,3,156,9,134,127,151,141,142,11,12,13,14,15, - 16,17,18,19,157,133,8,135,24,25,146,143,28,29,30,31, - 128,129,130,131,132,10,23,27,136,137,138,139,140,5,6,7, - 144,145,22,147,148,149,150,4,152,153,154,155,20,21,158,26, - 32,160,161,162,163,164,165,166,167,168,91,46,60,40,43,33, - 38,169,170,171,172,173,174,175,176,177,93,36,42,41,59,94, - 45,47,178,179,180,181,182,183,184,185,124,44,37,95,62,63, - 186,187,188,189,190,191,192,193,194,96,58,35,64,39,61,34, - 195,97,98,99,100,101,102,103,104,105,196,197,198,199,200,201, - 202,106,107,108,109,110,111,112,113,114,203,204,205,206,207,208, - 209,126,115,116,117,118,119,120,121,122,210,211,212,213,214,215, - 216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231, - 123,65,66,67,68,69,70,71,72,73,232,233,234,235,236,237, - 125,74,75,76,77,78,79,80,81,82,238,239,240,241,242,243, - 92,159,83,84,85,86,87,88,89,90,244,245,246,247,248,249, - 48,49,50,51,52,53,54,55,56,57,250,251,252,253,254,255 - ) - _ebcdic_to_ascii_map = _maketrans( \ - _l2bytes(range(256)), _l2bytes(emap)) - return s.translate(_ebcdic_to_ascii_map) - _cp1252 = { - unichr(128): unichr(8364), # euro sign - unichr(130): unichr(8218), # single low-9 quotation mark - unichr(131): unichr( 402), # latin small letter f with hook - unichr(132): unichr(8222), # double low-9 quotation mark - unichr(133): unichr(8230), # horizontal ellipsis - unichr(134): unichr(8224), # dagger - unichr(135): unichr(8225), # double dagger - unichr(136): unichr( 710), # modifier letter circumflex accent - unichr(137): unichr(8240), # per mille sign - unichr(138): unichr( 352), # latin capital letter s with caron - unichr(139): unichr(8249), # single left-pointing angle quotation mark - unichr(140): unichr( 338), # latin capital ligature oe - unichr(142): unichr( 381), # latin capital letter z with caron - unichr(145): unichr(8216), # left single quotation mark - unichr(146): unichr(8217), # right single quotation mark - unichr(147): unichr(8220), # left double quotation mark - unichr(148): unichr(8221), # right double quotation mark - unichr(149): unichr(8226), # bullet - unichr(150): unichr(8211), # en dash - unichr(151): unichr(8212), # em dash - unichr(152): unichr( 732), # small tilde - unichr(153): unichr(8482), # trade mark sign - unichr(154): unichr( 353), # latin small letter s with caron - unichr(155): unichr(8250), # single right-pointing angle quotation mark - unichr(156): unichr( 339), # latin small ligature oe - unichr(158): unichr( 382), # latin small letter z with caron - unichr(159): unichr( 376)} # latin capital letter y with diaeresis + 128: unichr(8364), # euro sign + 130: unichr(8218), # single low-9 quotation mark + 131: unichr( 402), # latin small letter f with hook + 132: unichr(8222), # double low-9 quotation mark + 133: unichr(8230), # horizontal ellipsis + 134: unichr(8224), # dagger + 135: unichr(8225), # double dagger + 136: unichr( 710), # modifier letter circumflex accent + 137: unichr(8240), # per mille sign + 138: unichr( 352), # latin capital letter s with caron + 139: unichr(8249), # single left-pointing angle quotation mark + 140: unichr( 338), # latin capital ligature oe + 142: unichr( 381), # latin capital letter z with caron + 145: unichr(8216), # left single quotation mark + 146: unichr(8217), # right single quotation mark + 147: unichr(8220), # left double quotation mark + 148: unichr(8221), # right double quotation mark + 149: unichr(8226), # bullet + 150: unichr(8211), # en dash + 151: unichr(8212), # em dash + 152: unichr( 732), # small tilde + 153: unichr(8482), # trade mark sign + 154: unichr( 353), # latin small letter s with caron + 155: unichr(8250), # single right-pointing angle quotation mark + 156: unichr( 339), # latin small ligature oe + 158: unichr( 382), # latin small letter z with caron + 159: unichr( 376), # latin capital letter y with diaeresis +} _urifixer = re.compile('^([A-Za-z][A-Za-z0-9+-.]*://)(/*)(.*?)') def _urljoin(base, uri): uri = _urifixer.sub(r'\1\3', uri) - try: - return urlparse.urljoin(base, uri) - except: - uri = urlparse.urlunparse([urllib.quote(part) for part in urlparse.urlparse(uri)]) - return urlparse.urljoin(base, uri) + #try: + if not isinstance(uri, unicode): + uri = uri.decode('utf-8', 'ignore') + uri = urlparse.urljoin(base, uri) + if not isinstance(uri, unicode): + return uri.decode('utf-8', 'ignore') + return uri + #except: + # uri = urlparse.urlunparse([urllib.quote(part) for part in urlparse.urlparse(uri)]) + # return urlparse.urljoin(base, uri) class _FeedParserMixin: - namespaces = {'': '', - 'http://backend.userland.com/rss': '', - 'http://blogs.law.harvard.edu/tech/rss': '', - 'http://purl.org/rss/1.0/': '', - 'http://my.netscape.com/rdf/simple/0.9/': '', - 'http://example.com/newformat#': '', - 'http://example.com/necho': '', - 'http://purl.org/echo/': '', - 'uri/of/echo/namespace#': '', - 'http://purl.org/pie/': '', - 'http://purl.org/atom/ns#': '', - 'http://www.w3.org/2005/Atom': '', - 'http://purl.org/rss/1.0/modules/rss091#': '', - - 'http://webns.net/mvcb/': 'admin', - 'http://purl.org/rss/1.0/modules/aggregation/': 'ag', - 'http://purl.org/rss/1.0/modules/annotate/': 'annotate', - 'http://media.tangent.org/rss/1.0/': 'audio', - 'http://backend.userland.com/blogChannelModule': 'blogChannel', - 'http://web.resource.org/cc/': 'cc', - 'http://backend.userland.com/creativeCommonsRssModule': 'creativeCommons', - 'http://purl.org/rss/1.0/modules/company': 'co', - 'http://purl.org/rss/1.0/modules/content/': 'content', - 'http://my.theinfo.org/changed/1.0/rss/': 'cp', - 'http://purl.org/dc/elements/1.1/': 'dc', - 'http://purl.org/dc/terms/': 'dcterms', - 'http://purl.org/rss/1.0/modules/email/': 'email', - 'http://purl.org/rss/1.0/modules/event/': 'ev', - 'http://rssnamespace.org/feedburner/ext/1.0': 'feedburner', - 'http://freshmeat.net/rss/fm/': 'fm', - 'http://xmlns.com/foaf/0.1/': 'foaf', - 'http://www.w3.org/2003/01/geo/wgs84_pos#': 'geo', - 'http://postneo.com/icbm/': 'icbm', - 'http://purl.org/rss/1.0/modules/image/': 'image', - 'http://www.itunes.com/DTDs/PodCast-1.0.dtd': 'itunes', - 'http://example.com/DTDs/PodCast-1.0.dtd': 'itunes', - 'http://purl.org/rss/1.0/modules/link/': 'l', - 'http://search.yahoo.com/mrss': 'media', - #Version 1.1.2 of the Media RSS spec added the trailing slash on the namespace - 'http://search.yahoo.com/mrss/': 'media', - 'http://madskills.com/public/xml/rss/module/pingback/': 'pingback', - 'http://prismstandard.org/namespaces/1.2/basic/': 'prism', - 'http://www.w3.org/1999/02/22-rdf-syntax-ns#': 'rdf', - 'http://www.w3.org/2000/01/rdf-schema#': 'rdfs', - 'http://purl.org/rss/1.0/modules/reference/': 'ref', - 'http://purl.org/rss/1.0/modules/richequiv/': 'reqv', - 'http://purl.org/rss/1.0/modules/search/': 'search', - 'http://purl.org/rss/1.0/modules/slash/': 'slash', - 'http://schemas.xmlsoap.org/soap/envelope/': 'soap', - 'http://purl.org/rss/1.0/modules/servicestatus/': 'ss', - 'http://hacks.benhammersley.com/rss/streaming/': 'str', - 'http://purl.org/rss/1.0/modules/subscription/': 'sub', - 'http://purl.org/rss/1.0/modules/syndication/': 'sy', - 'http://schemas.pocketsoap.com/rss/myDescModule/': 'szf', - 'http://purl.org/rss/1.0/modules/taxonomy/': 'taxo', - 'http://purl.org/rss/1.0/modules/threading/': 'thr', - 'http://purl.org/rss/1.0/modules/textinput/': 'ti', - 'http://madskills.com/public/xml/rss/module/trackback/':'trackback', - 'http://wellformedweb.org/commentAPI/': 'wfw', - 'http://purl.org/rss/1.0/modules/wiki/': 'wiki', - 'http://www.w3.org/1999/xhtml': 'xhtml', - 'http://www.w3.org/1999/xlink': 'xlink', - 'http://www.w3.org/XML/1998/namespace': 'xml' -} + namespaces = { + '': '', + 'http://backend.userland.com/rss': '', + 'http://blogs.law.harvard.edu/tech/rss': '', + 'http://purl.org/rss/1.0/': '', + 'http://my.netscape.com/rdf/simple/0.9/': '', + 'http://example.com/newformat#': '', + 'http://example.com/necho': '', + 'http://purl.org/echo/': '', + 'uri/of/echo/namespace#': '', + 'http://purl.org/pie/': '', + 'http://purl.org/atom/ns#': '', + 'http://www.w3.org/2005/Atom': '', + 'http://purl.org/rss/1.0/modules/rss091#': '', + + 'http://webns.net/mvcb/': 'admin', + 'http://purl.org/rss/1.0/modules/aggregation/': 'ag', + 'http://purl.org/rss/1.0/modules/annotate/': 'annotate', + 'http://media.tangent.org/rss/1.0/': 'audio', + 'http://backend.userland.com/blogChannelModule': 'blogChannel', + 'http://web.resource.org/cc/': 'cc', + 'http://backend.userland.com/creativeCommonsRssModule': 'creativeCommons', + 'http://purl.org/rss/1.0/modules/company': 'co', + 'http://purl.org/rss/1.0/modules/content/': 'content', + 'http://my.theinfo.org/changed/1.0/rss/': 'cp', + 'http://purl.org/dc/elements/1.1/': 'dc', + 'http://purl.org/dc/terms/': 'dcterms', + 'http://purl.org/rss/1.0/modules/email/': 'email', + 'http://purl.org/rss/1.0/modules/event/': 'ev', + 'http://rssnamespace.org/feedburner/ext/1.0': 'feedburner', + 'http://freshmeat.net/rss/fm/': 'fm', + 'http://xmlns.com/foaf/0.1/': 'foaf', + 'http://www.w3.org/2003/01/geo/wgs84_pos#': 'geo', + 'http://postneo.com/icbm/': 'icbm', + 'http://purl.org/rss/1.0/modules/image/': 'image', + 'http://www.itunes.com/DTDs/PodCast-1.0.dtd': 'itunes', + 'http://example.com/DTDs/PodCast-1.0.dtd': 'itunes', + 'http://purl.org/rss/1.0/modules/link/': 'l', + 'http://search.yahoo.com/mrss': 'media', + # Version 1.1.2 of the Media RSS spec added the trailing slash on the namespace + 'http://search.yahoo.com/mrss/': 'media', + 'http://madskills.com/public/xml/rss/module/pingback/': 'pingback', + 'http://prismstandard.org/namespaces/1.2/basic/': 'prism', + 'http://www.w3.org/1999/02/22-rdf-syntax-ns#': 'rdf', + 'http://www.w3.org/2000/01/rdf-schema#': 'rdfs', + 'http://purl.org/rss/1.0/modules/reference/': 'ref', + 'http://purl.org/rss/1.0/modules/richequiv/': 'reqv', + 'http://purl.org/rss/1.0/modules/search/': 'search', + 'http://purl.org/rss/1.0/modules/slash/': 'slash', + 'http://schemas.xmlsoap.org/soap/envelope/': 'soap', + 'http://purl.org/rss/1.0/modules/servicestatus/': 'ss', + 'http://hacks.benhammersley.com/rss/streaming/': 'str', + 'http://purl.org/rss/1.0/modules/subscription/': 'sub', + 'http://purl.org/rss/1.0/modules/syndication/': 'sy', + 'http://schemas.pocketsoap.com/rss/myDescModule/': 'szf', + 'http://purl.org/rss/1.0/modules/taxonomy/': 'taxo', + 'http://purl.org/rss/1.0/modules/threading/': 'thr', + 'http://purl.org/rss/1.0/modules/textinput/': 'ti', + 'http://madskills.com/public/xml/rss/module/trackback/': 'trackback', + 'http://wellformedweb.org/commentAPI/': 'wfw', + 'http://purl.org/rss/1.0/modules/wiki/': 'wiki', + 'http://www.w3.org/1999/xhtml': 'xhtml', + 'http://www.w3.org/1999/xlink': 'xlink', + 'http://www.w3.org/XML/1998/namespace': 'xml', + } _matchnamespaces = {} - can_be_relative_uri = ['link', 'id', 'wfw_comment', 'wfw_commentrss', 'docs', 'url', 'href', 'comments', 'icon', 'logo'] - can_contain_relative_uris = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] - can_contain_dangerous_markup = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] - html_types = ['text/html', 'application/xhtml+xml'] - - def __init__(self, baseuri=None, baselang=None, encoding='utf-8'): - if _debug: sys.stderr.write('initializing FeedParser\n') + can_be_relative_uri = set(['link', 'id', 'wfw_comment', 'wfw_commentrss', 'docs', 'url', 'href', 'comments', 'icon', 'logo']) + can_contain_relative_uris = set(['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description']) + can_contain_dangerous_markup = set(['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description']) + html_types = [u'text/html', u'application/xhtml+xml'] + + def __init__(self, baseuri=None, baselang=None, encoding=u'utf-8'): if not self._matchnamespaces: for k, v in self.namespaces.items(): self._matchnamespaces[k.lower()] = v self.feeddata = FeedParserDict() # feed-level data self.encoding = encoding # character encoding self.entries = [] # list of entry-level data - self.version = '' # feed type/version, see SUPPORTED_VERSIONS + self.version = u'' # feed type/version, see SUPPORTED_VERSIONS self.namespacesInUse = {} # dictionary of namespaces defined by the feed # the following are used internally to track state; @@ -538,31 +563,47 @@ class _FeedParserMixin: self.elementstack = [] self.basestack = [] self.langstack = [] - self.baseuri = baseuri or '' + self.baseuri = baseuri or u'' self.lang = baselang or None self.svgOK = 0 - self.hasTitle = 0 + self.title_depth = -1 + self.depth = 0 if baselang: self.feeddata['language'] = baselang.replace('_','-') - def unknown_starttag(self, tag, attrs): - if _debug: sys.stderr.write('start %s with %s\n' % (tag, attrs)) - # normalize attrs - attrs = [(k.lower(), v) for k, v in attrs] - attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs] - # the sgml parser doesn't handle entities in attributes, but + # A map of the following form: + # { + # object_that_value_is_set_on: { + # property_name: depth_of_node_property_was_extracted_from, + # other_property: depth_of_node_property_was_extracted_from, + # }, + # } + self.property_depth_map = {} + + def _normalize_attributes(self, kv): + k = kv[0].lower() + v = k in ('rel', 'type') and kv[1].lower() or kv[1] + # the sgml parser doesn't handle entities in attributes, nor + # does it pass the attribute values through as unicode, while # strict xml parsers do -- account for this difference if isinstance(self, _LooseFeedParser): - attrs = [(k, v.replace('&', '&')) for k, v in attrs] - + v = v.replace('&', '&') + if not isinstance(v, unicode): + v = v.decode('utf-8') + return (k, v) + + def unknown_starttag(self, tag, attrs): + # increment depth counter + self.depth += 1 + + # normalize attrs + attrs = map(self._normalize_attributes, attrs) + # track xml:base and xml:lang attrsD = dict(attrs) baseuri = attrsD.get('xml:base', attrsD.get('base')) or self.baseuri - if type(baseuri) != type(u''): - try: - baseuri = unicode(baseuri, self.encoding) - except: - baseuri = unicode(baseuri, 'iso-8859-1') + if not isinstance(baseuri, unicode): + baseuri = baseuri.decode(self.encoding, 'ignore') # ensure that self.baseuri is always an absolute URI that # uses a whitelisted URI scheme (e.g. not `javscript:`) if self.baseuri: @@ -582,7 +623,7 @@ class _FeedParserMixin: self.lang = lang self.basestack.append(self.baseuri) self.langstack.append(lang) - + # track namespaces for prefix, uri in attrs: if prefix.startswith('xmlns:'): @@ -591,11 +632,12 @@ class _FeedParserMixin: self.trackNamespace(None, uri) # track inline content - if self.incontent and self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): - if tag in ['xhtml:div', 'div']: return # typepad does this 10/2007 + if self.incontent and not self.contentparams.get('type', u'xml').endswith(u'xml'): + if tag in ('xhtml:div', 'div'): + return # typepad does this 10/2007 # element declared itself as escaped markup, but it isn't really - self.contentparams['type'] = 'application/xhtml+xml' - if self.incontent and self.contentparams.get('type') == 'application/xhtml+xml': + self.contentparams['type'] = u'application/xhtml+xml' + if self.incontent and self.contentparams.get('type') == u'application/xhtml+xml': if tag.find(':') <> -1: prefix, tag = tag.split(':', 1) namespace = self.namespacesInUse.get(prefix, '') @@ -603,7 +645,8 @@ class _FeedParserMixin: attrs.append(('xmlns',namespace)) if tag=='svg' and namespace=='http://www.w3.org/2000/svg': attrs.append(('xmlns',namespace)) - if tag == 'svg': self.svgOK += 1 + if tag == 'svg': + self.svgOK += 1 return self.handle_data('<%s%s>' % (tag, self.strattrs(attrs)), escape=0) # match namespaces @@ -620,7 +663,7 @@ class _FeedParserMixin: self.intextinput = 0 if (not prefix) and tag not in ('title', 'link', 'description', 'url', 'href', 'width', 'height'): self.inimage = 0 - + # call special handler (if defined) or default handler methodname = '_start_' + prefix + suffix try: @@ -638,7 +681,6 @@ class _FeedParserMixin: context[unknown_tag] = attrsD def unknown_endtag(self, tag): - if _debug: sys.stderr.write('end %s\n' % tag) # match namespaces if tag.find(':') <> -1: prefix, suffix = tag.split(':', 1) @@ -647,23 +689,26 @@ class _FeedParserMixin: prefix = self.namespacemap.get(prefix, prefix) if prefix: prefix = prefix + '_' - if suffix == 'svg' and self.svgOK: self.svgOK -= 1 + if suffix == 'svg' and self.svgOK: + self.svgOK -= 1 # call special handler (if defined) or default handler methodname = '_end_' + prefix + suffix try: - if self.svgOK: raise AttributeError() + if self.svgOK: + raise AttributeError() method = getattr(self, methodname) method() except AttributeError: self.pop(prefix + suffix) # track inline content - if self.incontent and self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): + if self.incontent and not self.contentparams.get('type', u'xml').endswith(u'xml'): # element declared itself as escaped markup, but it isn't really - if tag in ['xhtml:div', 'div']: return # typepad does this 10/2007 - self.contentparams['type'] = 'application/xhtml+xml' - if self.incontent and self.contentparams.get('type') == 'application/xhtml+xml': + if tag in ('xhtml:div', 'div'): + return # typepad does this 10/2007 + self.contentparams['type'] = u'application/xhtml+xml' + if self.incontent and self.contentparams.get('type') == u'application/xhtml+xml': tag = tag.split(':')[-1] self.handle_data('</%s>' % tag, escape=0) @@ -677,9 +722,12 @@ class _FeedParserMixin: if self.langstack: # and (self.langstack[-1] is not None): self.lang = self.langstack[-1] + self.depth -= 1 + def handle_charref(self, ref): # called for each character reference, e.g. for ' ', ref will be '160' - if not self.elementstack: return + if not self.elementstack: + return ref = ref.lower() if ref in ('34', '38', '39', '60', '62', 'x22', 'x26', 'x27', 'x3c', 'x3e'): text = '&#%s;' % ref @@ -693,25 +741,29 @@ class _FeedParserMixin: def handle_entityref(self, ref): # called for each entity reference, e.g. for '©', ref will be 'copy' - if not self.elementstack: return - if _debug: sys.stderr.write('entering handle_entityref with %s\n' % ref) + if not self.elementstack: + return if ref in ('lt', 'gt', 'quot', 'amp', 'apos'): text = '&%s;' % ref - elif ref in self.entities.keys(): + elif ref in self.entities: text = self.entities[ref] if text.startswith('&#') and text.endswith(';'): return self.handle_entityref(text) else: - try: name2codepoint[ref] - except KeyError: text = '&%s;' % ref - else: text = unichr(name2codepoint[ref]).encode('utf-8') + try: + name2codepoint[ref] + except KeyError: + text = '&%s;' % ref + else: + text = unichr(name2codepoint[ref]).encode('utf-8') self.elementstack[-1][2].append(text) def handle_data(self, text, escape=1): # called for each block of plain text, i.e. outside of any tag and # not containing any character or entity references - if not self.elementstack: return - if escape and self.contentparams.get('type') == 'application/xhtml+xml': + if not self.elementstack: + return + if escape and self.contentparams.get('type') == u'application/xhtml+xml': text = _xmlescape(text) self.elementstack[-1][2].append(text) @@ -728,7 +780,6 @@ class _FeedParserMixin: def parse_declaration(self, i): # override internal declaration handler to handle CDATA blocks - if _debug: sys.stderr.write('entering parse_declaration\n') if self.rawdata[i:i+9] == '<![CDATA[': k = self.rawdata.find(']]>', i) if k == -1: @@ -747,35 +798,36 @@ class _FeedParserMixin: def mapContentType(self, contentType): contentType = contentType.lower() - if contentType == 'text': - contentType = 'text/plain' + if contentType == 'text' or contentType == 'plain': + contentType = u'text/plain' elif contentType == 'html': - contentType = 'text/html' + contentType = u'text/html' elif contentType == 'xhtml': - contentType = 'application/xhtml+xml' + contentType = u'application/xhtml+xml' return contentType - + def trackNamespace(self, prefix, uri): loweruri = uri.lower() - if (prefix, loweruri) == (None, 'http://my.netscape.com/rdf/simple/0.9/') and not self.version: - self.version = 'rss090' - if loweruri == 'http://purl.org/rss/1.0/' and not self.version: - self.version = 'rss10' - if loweruri == 'http://www.w3.org/2005/atom' and not self.version: - self.version = 'atom10' - if loweruri.find('backend.userland.com/rss') <> -1: + if not self.version: + if (prefix, loweruri) == (None, 'http://my.netscape.com/rdf/simple/0.9/'): + self.version = u'rss090' + elif loweruri == 'http://purl.org/rss/1.0/': + self.version = u'rss10' + elif loweruri == 'http://www.w3.org/2005/atom': + self.version = u'atom10' + if loweruri.find(u'backend.userland.com/rss') <> -1: # match any backend.userland.com namespace - uri = 'http://backend.userland.com/rss' + uri = u'http://backend.userland.com/rss' loweruri = uri - if self._matchnamespaces.has_key(loweruri): + if loweruri in self._matchnamespaces: self.namespacemap[prefix] = self._matchnamespaces[loweruri] self.namespacesInUse[self._matchnamespaces[loweruri]] = uri else: self.namespacesInUse[prefix or ''] = uri def resolveURI(self, uri): - return _urljoin(self.baseuri or '', uri) - + return _urljoin(self.baseuri or u'', uri) + def decodeEntities(self, element, data): return data @@ -786,12 +838,14 @@ class _FeedParserMixin: self.elementstack.append([element, expectingText, []]) def pop(self, element, stripWhitespace=1): - if not self.elementstack: return - if self.elementstack[-1][0] != element: return - + if not self.elementstack: + return + if self.elementstack[-1][0] != element: + return + element, expectingText, pieces = self.elementstack.pop() - if self.version == 'atom10' and self.contentparams.get('type','text') == 'application/xhtml+xml': + if self.version == u'atom10' and self.contentparams.get('type', u'text') == u'application/xhtml+xml': # remove enclosing child element, but only if it is a <div> and # only if all the remaining content is nested underneath it. # This means that the divs would be retained in the following: @@ -805,7 +859,8 @@ class _FeedParserMixin: for piece in pieces[:-1]: if piece.startswith('</'): depth -= 1 - if depth == 0: break + if depth == 0: + break elif piece.startswith('<') and not piece.endswith('/>'): depth += 1 else: @@ -813,13 +868,14 @@ class _FeedParserMixin: # Ensure each piece is a str for Python 3 for (i, v) in enumerate(pieces): - if not isinstance(v, basestring): + if not isinstance(v, unicode): pieces[i] = v.decode('utf-8') - output = ''.join(pieces) + output = u''.join(pieces) if stripWhitespace: output = output.strip() - if not expectingText: return output + if not expectingText: + return output # decode base64 content if base64 and self.contentparams.get('base64', 0): @@ -833,17 +889,20 @@ class _FeedParserMixin: # In Python 3, base64 takes and outputs bytes, not str # This may not be the most correct way to accomplish this output = _base64decode(output.encode('utf-8')).decode('utf-8') - + # resolve relative URIs if (element in self.can_be_relative_uri) and output: output = self.resolveURI(output) - + # decode entities within embedded markup if not self.contentparams.get('base64', 0): output = self.decodeEntities(element, output) - if self.lookslikehtml(output): - self.contentparams['type']='text/html' + # some feed formats require consumers to guess + # whether the content is html or plain text + if not self.version.startswith(u'atom') and self.contentparams.get('type') == u'text/plain': + if self.lookslikehtml(output): + self.contentparams['type'] = u'text/html' # remove temporary cruft from contentparams try: @@ -855,16 +914,16 @@ class _FeedParserMixin: except KeyError: pass - is_htmlish = self.mapContentType(self.contentparams.get('type', 'text/html')) in self.html_types + is_htmlish = self.mapContentType(self.contentparams.get('type', u'text/html')) in self.html_types # resolve relative URIs within embedded markup if is_htmlish and RESOLVE_RELATIVE_URIS: if element in self.can_contain_relative_uris: - output = _resolveRelativeURIs(output, self.baseuri, self.encoding, self.contentparams.get('type', 'text/html')) - + output = _resolveRelativeURIs(output, self.baseuri, self.encoding, self.contentparams.get('type', u'text/html')) + # parse microformats # (must do this before sanitizing because some microformats # rely on elements that we sanitize) - if is_htmlish and element in ['content', 'description', 'summary']: + if PARSE_MICROFORMATS and is_htmlish and element in ['content', 'description', 'summary']: mfresults = _parseMicroformats(output, self.baseuri, self.encoding) if mfresults: for tag in mfresults.get('tags', []): @@ -876,37 +935,34 @@ class _FeedParserMixin: vcard = mfresults.get('vcard') if vcard: self._getContext()['vcard'] = vcard - + # sanitize embedded markup if is_htmlish and SANITIZE_HTML: if element in self.can_contain_dangerous_markup: - output = _sanitizeHTML(output, self.encoding, self.contentparams.get('type', 'text/html')) + output = _sanitizeHTML(output, self.encoding, self.contentparams.get('type', u'text/html')) - if self.encoding and type(output) != type(u''): - try: - output = unicode(output, self.encoding) - except: - pass + if self.encoding and not isinstance(output, unicode): + output = output.decode(self.encoding, 'ignore') # address common error where people take data that is already # utf-8, presume that it is iso-8859-1, and re-encode it. - if self.encoding in ('utf-8', 'utf-8_INVALID_PYTHON_3') and type(output) == type(u''): + if self.encoding in (u'utf-8', u'utf-8_INVALID_PYTHON_3') and isinstance(output, unicode): try: - output = unicode(output.encode('iso-8859-1'), 'utf-8') - except: + output = output.encode('iso-8859-1').decode('utf-8') + except (UnicodeEncodeError, UnicodeDecodeError): pass # map win-1252 extensions to the proper code points - if type(output) == type(u''): - output = u''.join([c in _cp1252.keys() and _cp1252[c] or c for c in output]) + if isinstance(output, unicode): + output = output.translate(_cp1252) # categories/tags/keywords/whatever are handled in _end_category if element == 'category': return output - if element == 'title' and self.hasTitle: + if element == 'title' and -1 < self.title_depth <= self.depth: return output - + # store output in appropriate place(s) if self.inentry and not self.insource: if element == 'content': @@ -926,7 +982,10 @@ class _FeedParserMixin: else: if element == 'description': element = 'summary' - self.entries[-1][element] = output + old_value_depth = self.property_depth_map.setdefault(self.entries[-1], {}).get(element) + if old_value_depth is None or self.depth <= old_value_depth: + self.property_depth_map[self.entries[-1]][element] = self.depth + self.entries[-1][element] = output if self.incontent: contentparams = copy.deepcopy(self.contentparams) contentparams['value'] = output @@ -949,7 +1008,8 @@ class _FeedParserMixin: def pushContent(self, tag, attrsD, defaultContentType, expectingText): self.incontent += 1 - if self.lang: self.lang=self.lang.replace('_','-') + if self.lang: + self.lang=self.lang.replace('_','-') self.contentparams = FeedParserDict({ 'type': self.mapContentType(attrsD.get('type', defaultContentType)), 'language': self.lang, @@ -962,26 +1022,25 @@ class _FeedParserMixin: self.incontent -= 1 self.contentparams.clear() return value - + # a number of elements in a number of RSS variants are nominally plain # text, but this is routinely ignored. This is an attempt to detect # the most common cases. As false positives often result in silent # data loss, this function errs on the conservative side. - def lookslikehtml(self, s): - if self.version.startswith('atom'): return - if self.contentparams.get('type','text/html') != 'text/plain': return - - # must have a close tag or a entity reference to qualify - if not (re.search(r'</(\w+)>',s) or re.search("&#?\w+;",s)): return + @staticmethod + def lookslikehtml(s): + # must have a close tag or an entity reference to qualify + if not (re.search(r'</(\w+)>',s) or re.search("&#?\w+;",s)): + return # all tags must be in a restricted subset of valid HTML tags if filter(lambda t: t.lower() not in _HTMLSanitizer.acceptable_elements, - re.findall(r'</?(\w+)',s)): return + re.findall(r'</?(\w+)',s)): + return # all entities must have been defined as valid HTML entities - from htmlentitydefs import entitydefs - if filter(lambda e: e not in entitydefs.keys(), - re.findall(r'&(\w+);',s)): return + if filter(lambda e: e not in entitydefs.keys(), re.findall(r'&(\w+);', s)): + return return 1 @@ -993,18 +1052,18 @@ class _FeedParserMixin: prefix = self.namespacemap.get(prefix, prefix) name = prefix + ':' + suffix return name - + def _getAttribute(self, attrsD, name): return attrsD.get(self._mapToStandardPrefix(name)) def _isBase64(self, attrsD, contentparams): if attrsD.get('mode', '') == 'base64': return 1 - if self.contentparams['type'].startswith('text/'): + if self.contentparams['type'].startswith(u'text/'): return 0 - if self.contentparams['type'].endswith('+xml'): + if self.contentparams['type'].endswith(u'+xml'): return 0 - if self.contentparams['type'].endswith('/xml'): + if self.contentparams['type'].endswith(u'/xml'): return 0 return 1 @@ -1021,7 +1080,7 @@ class _FeedParserMixin: pass attrsD['href'] = href return attrsD - + def _save(self, key, value, overwrite=False): context = self._getContext() if overwrite: @@ -1030,66 +1089,62 @@ class _FeedParserMixin: context.setdefault(key, value) def _start_rss(self, attrsD): - versionmap = {'0.91': 'rss091u', - '0.92': 'rss092', - '0.93': 'rss093', - '0.94': 'rss094'} + versionmap = {'0.91': u'rss091u', + '0.92': u'rss092', + '0.93': u'rss093', + '0.94': u'rss094'} #If we're here then this is an RSS feed. #If we don't have a version or have a version that starts with something #other than RSS then there's been a mistake. Correct it. - if not self.version or not self.version.startswith('rss'): + if not self.version or not self.version.startswith(u'rss'): attr_version = attrsD.get('version', '') version = versionmap.get(attr_version) if version: self.version = version elif attr_version.startswith('2.'): - self.version = 'rss20' + self.version = u'rss20' else: - self.version = 'rss' - - def _start_dlhottitles(self, attrsD): - self.version = 'hotrss' + self.version = u'rss' def _start_channel(self, attrsD): self.infeed = 1 self._cdf_common(attrsD) - _start_feedinfo = _start_channel def _cdf_common(self, attrsD): - if attrsD.has_key('lastmod'): + if 'lastmod' in attrsD: self._start_modified({}) self.elementstack[-1][-1] = attrsD['lastmod'] self._end_modified() - if attrsD.has_key('href'): + if 'href' in attrsD: self._start_link({}) self.elementstack[-1][-1] = attrsD['href'] self._end_link() - + def _start_feed(self, attrsD): self.infeed = 1 - versionmap = {'0.1': 'atom01', - '0.2': 'atom02', - '0.3': 'atom03'} + versionmap = {'0.1': u'atom01', + '0.2': u'atom02', + '0.3': u'atom03'} if not self.version: attr_version = attrsD.get('version') version = versionmap.get(attr_version) if version: self.version = version else: - self.version = 'atom' + self.version = u'atom' def _end_channel(self): self.infeed = 0 _end_feed = _end_channel - + def _start_image(self, attrsD): context = self._getContext() if not self.inentry: context.setdefault('image', FeedParserDict()) self.inimage = 1 - self.hasTitle = 0 + self.title_depth = -1 self.push('image', 0) - + def _end_image(self): self.pop('image') self.inimage = 0 @@ -1098,10 +1153,10 @@ class _FeedParserMixin: context = self._getContext() context.setdefault('textinput', FeedParserDict()) self.intextinput = 1 - self.hasTitle = 0 + self.title_depth = -1 self.push('textinput', 0) _start_textInput = _start_textinput - + def _end_textinput(self): self.pop('textinput') self.intextinput = 0 @@ -1183,7 +1238,7 @@ class _FeedParserMixin: value = self.pop('width') try: value = int(value) - except: + except ValueError: value = 0 if self.inimage: context = self._getContext() @@ -1196,7 +1251,7 @@ class _FeedParserMixin: value = self.pop('height') try: value = int(value) - except: + except ValueError: value = 0 if self.inimage: context = self._getContext() @@ -1233,7 +1288,7 @@ class _FeedParserMixin: def _getContext(self): if self.insource: context = self.sourcedata - elif self.inimage and self.feeddata.has_key('image'): + elif self.inimage and 'image' in self.feeddata: context = self.feeddata['image'] elif self.intextinput: context = self.feeddata['textinput'] @@ -1263,26 +1318,27 @@ class _FeedParserMixin: name = detail.get('name') email = detail.get('email') if name and email: - context[key] = '%s (%s)' % (name, email) + context[key] = u'%s (%s)' % (name, email) elif name: context[key] = name elif email: context[key] = email else: author, email = context.get(key), None - if not author: return - emailmatch = re.search(r'''(([a-zA-Z0-9\_\-\.\+]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-zA-Z0-9\-]+\.)+))([a-zA-Z]{2,4}|[0-9]{1,3})(\]?))(\?subject=\S+)?''', author) + if not author: + return + emailmatch = re.search(ur'''(([a-zA-Z0-9\_\-\.\+]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-zA-Z0-9\-]+\.)+))([a-zA-Z]{2,4}|[0-9]{1,3})(\]?))(\?subject=\S+)?''', author) if emailmatch: email = emailmatch.group(0) # probably a better way to do the following, but it passes all the tests - author = author.replace(email, '') - author = author.replace('()', '') - author = author.replace('<>', '') - author = author.replace('<>', '') + author = author.replace(email, u'') + author = author.replace(u'()', u'') + author = author.replace(u'<>', u'') + author = author.replace(u'<>', u'') author = author.strip() - if author and (author[0] == '('): + if author and (author[0] == u'('): author = author[1:] - if author and (author[-1] == ')'): + if author and (author[-1] == u')'): author = author[:-1] author = author.strip() if author or email: @@ -1293,7 +1349,7 @@ class _FeedParserMixin: context['%s_detail' % key]['email'] = email def _start_subtitle(self, attrsD): - self.pushContent('subtitle', attrsD, 'text/plain', 1) + self.pushContent('subtitle', attrsD, u'text/plain', 1) _start_tagline = _start_subtitle _start_itunes_subtitle = _start_subtitle @@ -1301,9 +1357,9 @@ class _FeedParserMixin: self.popContent('subtitle') _end_tagline = _end_subtitle _end_itunes_subtitle = _end_subtitle - + def _start_rights(self, attrsD): - self.pushContent('rights', attrsD, 'text/plain', 1) + self.pushContent('rights', attrsD, u'text/plain', 1) _start_dc_rights = _start_rights _start_copyright = _start_rights @@ -1317,14 +1373,13 @@ class _FeedParserMixin: self.push('item', 0) self.inentry = 1 self.guidislink = 0 - self.hasTitle = 0 + self.title_depth = -1 id = self._getAttribute(attrsD, 'rdf:about') if id: context = self._getContext() context['id'] = id self._cdf_common(attrsD) _start_entry = _start_item - _start_product = _start_item def _end_item(self): self.pop('item') @@ -1352,18 +1407,19 @@ class _FeedParserMixin: self.push('published', 1) _start_dcterms_issued = _start_published _start_issued = _start_published + _start_pubdate = _start_published def _end_published(self): value = self.pop('published') self._save('published_parsed', _parse_date(value), overwrite=True) _end_dcterms_issued = _end_published _end_issued = _end_published + _end_pubdate = _end_published def _start_updated(self, attrsD): self.push('updated', 1) _start_modified = _start_updated _start_dcterms_modified = _start_updated - _start_pubdate = _start_updated _start_dc_date = _start_updated _start_lastbuilddate = _start_updated @@ -1373,7 +1429,6 @@ class _FeedParserMixin: self._save('updated_parsed', parsed_value, overwrite=True) _end_modified = _end_updated _end_dcterms_modified = _end_updated - _end_pubdate = _end_updated _end_dc_date = _end_updated _end_lastbuilddate = _end_updated @@ -1396,10 +1451,11 @@ class _FeedParserMixin: context = self._getContext() value = self._getAttribute(attrsD, 'rdf:resource') attrsD = FeedParserDict() - attrsD['rel']='license' - if value: attrsD['href']=value + attrsD['rel'] = u'license' + if value: + attrsD['href']=value context.setdefault('links', []).append(attrsD) - + def _start_creativecommons_license(self, attrsD): self.push('license', 1) _start_creativeCommons_license = _start_creativecommons_license @@ -1408,8 +1464,9 @@ class _FeedParserMixin: value = self.pop('license') context = self._getContext() attrsD = FeedParserDict() - attrsD['rel']='license' - if value: attrsD['href']=value + attrsD['rel'] = u'license' + if value: + attrsD['href'] = value context.setdefault('links', []).append(attrsD) del context['license'] _end_creativeCommons_license = _end_creativecommons_license @@ -1420,17 +1477,17 @@ class _FeedParserMixin: value = FeedParserDict({'relationships': relationships, 'href': href, 'name': name}) if value not in xfn: xfn.append(value) - + def _addTag(self, term, scheme, label): context = self._getContext() tags = context.setdefault('tags', []) - if (not term) and (not scheme) and (not label): return + if (not term) and (not scheme) and (not label): + return value = FeedParserDict({'term': term, 'scheme': scheme, 'label': label}) if value not in tags: tags.append(value) def _start_category(self, attrsD): - if _debug: sys.stderr.write('entering _start_category with %s\n' % repr(attrsD)) term = attrsD.get('term') scheme = attrsD.get('scheme', attrsD.get('domain')) label = attrsD.get('label') @@ -1438,22 +1495,24 @@ class _FeedParserMixin: self.push('category', 1) _start_dc_subject = _start_category _start_keywords = _start_category - + def _start_media_category(self, attrsD): - attrsD.setdefault('scheme', 'http://search.yahoo.com/mrss/category_schema') + attrsD.setdefault('scheme', u'http://search.yahoo.com/mrss/category_schema') self._start_category(attrsD) def _end_itunes_keywords(self): - for term in self.pop('itunes_keywords').split(): - self._addTag(term, 'http://www.itunes.com/', None) - + for term in self.pop('itunes_keywords').split(','): + if term.strip(): + self._addTag(term.strip(), u'http://www.itunes.com/', None) + def _start_itunes_category(self, attrsD): - self._addTag(attrsD.get('text'), 'http://www.itunes.com/', None) + self._addTag(attrsD.get('text'), u'http://www.itunes.com/', None) self.push('category', 1) - + def _end_category(self): value = self.pop('category') - if not value: return + if not value: + return context = self._getContext() tags = context['tags'] if value and len(tags) and not tags[-1]['term']: @@ -1467,76 +1526,77 @@ class _FeedParserMixin: def _start_cloud(self, attrsD): self._getContext()['cloud'] = FeedParserDict(attrsD) - + def _start_link(self, attrsD): - attrsD.setdefault('rel', 'alternate') - if attrsD['rel'] == 'self': - attrsD.setdefault('type', 'application/atom+xml') + attrsD.setdefault('rel', u'alternate') + if attrsD['rel'] == u'self': + attrsD.setdefault('type', u'application/atom+xml') else: - attrsD.setdefault('type', 'text/html') + attrsD.setdefault('type', u'text/html') context = self._getContext() attrsD = self._itsAnHrefDamnIt(attrsD) - if attrsD.has_key('href'): + if 'href' in attrsD: attrsD['href'] = self.resolveURI(attrsD['href']) expectingText = self.infeed or self.inentry or self.insource context.setdefault('links', []) if not (self.inentry and self.inimage): context['links'].append(FeedParserDict(attrsD)) - if attrsD.has_key('href'): + if 'href' in attrsD: expectingText = 0 - if (attrsD.get('rel') == 'alternate') and (self.mapContentType(attrsD.get('type')) in self.html_types): + if (attrsD.get('rel') == u'alternate') and (self.mapContentType(attrsD.get('type')) in self.html_types): context['link'] = attrsD['href'] else: self.push('link', expectingText) - _start_producturl = _start_link def _end_link(self): value = self.pop('link') - context = self._getContext() - _end_producturl = _end_link def _start_guid(self, attrsD): self.guidislink = (attrsD.get('ispermalink', 'true') == 'true') self.push('id', 1) + _start_id = _start_guid def _end_guid(self): value = self.pop('id') - self._save('guidislink', self.guidislink and not self._getContext().has_key('link')) + self._save('guidislink', self.guidislink and 'link' not in self._getContext()) if self.guidislink: # guid acts as link, but only if 'ispermalink' is not present or is 'true', # and only if the item doesn't already have a link element self._save('link', value) + _end_id = _end_guid def _start_title(self, attrsD): - if self.svgOK: return self.unknown_starttag('title', attrsD.items()) - self.pushContent('title', attrsD, 'text/plain', self.infeed or self.inentry or self.insource) + if self.svgOK: + return self.unknown_starttag('title', attrsD.items()) + self.pushContent('title', attrsD, u'text/plain', self.infeed or self.inentry or self.insource) _start_dc_title = _start_title _start_media_title = _start_title def _end_title(self): - if self.svgOK: return + if self.svgOK: + return value = self.popContent('title') - if not value: return - context = self._getContext() - self.hasTitle = 1 + if not value: + return + self.title_depth = self.depth _end_dc_title = _end_title def _end_media_title(self): - hasTitle = self.hasTitle + title_depth = self.title_depth self._end_title() - self.hasTitle = hasTitle + self.title_depth = title_depth def _start_description(self, attrsD): context = self._getContext() - if context.has_key('summary'): + if 'summary' in context: self._summaryKey = 'content' self._start_content(attrsD) else: - self.pushContent('description', attrsD, 'text/html', self.infeed or self.inentry or self.insource) + self.pushContent('description', attrsD, u'text/html', self.infeed or self.inentry or self.insource) _start_dc_description = _start_description def _start_abstract(self, attrsD): - self.pushContent('description', attrsD, 'text/plain', self.infeed or self.inentry or self.insource) + self.pushContent('description', attrsD, u'text/plain', self.infeed or self.inentry or self.insource) def _end_description(self): if self._summaryKey == 'content': @@ -1548,7 +1608,7 @@ class _FeedParserMixin: _end_dc_description = _end_description def _start_info(self, attrsD): - self.pushContent('info', attrsD, 'text/plain', 1) + self.pushContent('info', attrsD, u'text/plain', 1) _start_feedburner_browserfriendly = _start_info def _end_info(self): @@ -1558,7 +1618,7 @@ class _FeedParserMixin: def _start_generator(self, attrsD): if attrsD: attrsD = self._itsAnHrefDamnIt(attrsD) - if attrsD.has_key('href'): + if 'href' in attrsD: attrsD['href'] = self.resolveURI(attrsD['href']) self._getContext()['generator_detail'] = FeedParserDict(attrsD) self.push('generator', 1) @@ -1566,9 +1626,9 @@ class _FeedParserMixin: def _end_generator(self): value = self.pop('generator') context = self._getContext() - if context.has_key('generator_detail'): + if 'generator_detail' in context: context['generator_detail']['name'] = value - + def _start_admin_generatoragent(self, attrsD): self.push('generator', 1) value = self._getAttribute(attrsD, 'rdf:resource') @@ -1583,15 +1643,15 @@ class _FeedParserMixin: if value: self.elementstack[-1][2].append(value) self.pop('errorreportsto') - + def _start_summary(self, attrsD): context = self._getContext() - if context.has_key('summary'): + if 'summary' in context: self._summaryKey = 'content' self._start_content(attrsD) else: self._summaryKey = 'summary' - self.pushContent(self._summaryKey, attrsD, 'text/plain', 1) + self.pushContent(self._summaryKey, attrsD, u'text/plain', 1) _start_itunes_summary = _start_summary def _end_summary(self): @@ -1601,49 +1661,46 @@ class _FeedParserMixin: self.popContent(self._summaryKey or 'summary') self._summaryKey = None _end_itunes_summary = _end_summary - + def _start_enclosure(self, attrsD): attrsD = self._itsAnHrefDamnIt(attrsD) context = self._getContext() - attrsD['rel']='enclosure' + attrsD['rel'] = u'enclosure' context.setdefault('links', []).append(FeedParserDict(attrsD)) - + def _start_source(self, attrsD): if 'url' in attrsD: - # This means that we're processing a source element from an RSS 2.0 feed - self.sourcedata['href'] = attrsD[u'url'] + # This means that we're processing a source element from an RSS 2.0 feed + self.sourcedata['href'] = attrsD[u'url'] self.push('source', 1) self.insource = 1 - self.hasTitle = 0 + self.title_depth = -1 def _end_source(self): self.insource = 0 value = self.pop('source') if value: - self.sourcedata['title'] = value + self.sourcedata['title'] = value self._getContext()['source'] = copy.deepcopy(self.sourcedata) self.sourcedata.clear() def _start_content(self, attrsD): - self.pushContent('content', attrsD, 'text/plain', 1) + self.pushContent('content', attrsD, u'text/plain', 1) src = attrsD.get('src') if src: self.contentparams['src'] = src self.push('content', 1) - def _start_prodlink(self, attrsD): - self.pushContent('content', attrsD, 'text/html', 1) - def _start_body(self, attrsD): - self.pushContent('content', attrsD, 'application/xhtml+xml', 1) + self.pushContent('content', attrsD, u'application/xhtml+xml', 1) _start_xhtml_body = _start_body def _start_content_encoded(self, attrsD): - self.pushContent('content', attrsD, 'text/html', 1) + self.pushContent('content', attrsD, u'text/html', 1) _start_fullitem = _start_content_encoded def _end_content(self): - copyToSummary = self.mapContentType(self.contentparams.get('type')) in (['text/plain'] + self.html_types) + copyToSummary = self.mapContentType(self.contentparams.get('type')) in ([u'text/plain'] + self.html_types) value = self.popContent('content') if copyToSummary: self._save('summary', value) @@ -1652,14 +1709,15 @@ class _FeedParserMixin: _end_xhtml_body = _end_content _end_content_encoded = _end_content _end_fullitem = _end_content - _end_prodlink = _end_content def _start_itunes_image(self, attrsD): self.push('itunes_image', 0) if attrsD.get('href'): self._getContext()['image'] = FeedParserDict({'href': attrsD.get('href')}) + elif attrsD.get('url'): + self._getContext()['image'] = FeedParserDict({'href': attrsD.get('url')}) _start_itunes_link = _start_itunes_image - + def _end_itunes_block(self): value = self.pop('itunes_block', 0) self._getContext()['itunes_block'] = (value == 'yes') and 1 or 0 @@ -1685,8 +1743,8 @@ class _FeedParserMixin: def _end_media_thumbnail(self): url = self.pop('url') context = self._getContext() - if url is not None and len(url.strip()) != 0: - if not context['media_thumbnail'][-1].has_key('url'): + if url != None and len(url.strip()) != 0: + if 'url' not in context['media_thumbnail'][-1]: context['media_thumbnail'][-1]['url'] = url def _start_media_player(self, attrsD): @@ -1712,32 +1770,35 @@ class _FeedParserMixin: if _XML_AVAILABLE: class _StrictFeedParser(_FeedParserMixin, xml.sax.handler.ContentHandler): def __init__(self, baseuri, baselang, encoding): - if _debug: sys.stderr.write('trying StrictFeedParser\n') xml.sax.handler.ContentHandler.__init__(self) _FeedParserMixin.__init__(self, baseuri, baselang, encoding) self.bozo = 0 self.exc = None self.decls = {} - + def startPrefixMapping(self, prefix, uri): + if not uri: + return + # Jython uses '' instead of None; standardize on None + prefix = prefix or None self.trackNamespace(prefix, uri) - if uri == 'http://www.w3.org/1999/xlink': - self.decls['xmlns:'+prefix] = uri - + if prefix and uri == 'http://www.w3.org/1999/xlink': + self.decls['xmlns:' + prefix] = uri + def startElementNS(self, name, qname, attrs): namespace, localname = name lowernamespace = str(namespace or '').lower() - if lowernamespace.find('backend.userland.com/rss') <> -1: + if lowernamespace.find(u'backend.userland.com/rss') <> -1: # match any backend.userland.com namespace - namespace = 'http://backend.userland.com/rss' + namespace = u'http://backend.userland.com/rss' lowernamespace = namespace if qname and qname.find(':') > 0: givenprefix = qname.split(':')[0] else: givenprefix = None prefix = self._matchnamespaces.get(lowernamespace, givenprefix) - if givenprefix and (prefix is None or (prefix == '' and lowernamespace == '')) and not self.namespacesInUse.has_key(givenprefix): - raise UndeclaredNamespace, "'%s' is not associated with a namespace" % givenprefix + if givenprefix and (prefix == None or (prefix == '' and lowernamespace == '')) and givenprefix not in self.namespacesInUse: + raise UndeclaredNamespace, "'%s' is not associated with a namespace" % givenprefix localname = str(localname).lower() # qname implementation is horribly broken in Python 2.1 (it @@ -1757,12 +1818,11 @@ if _XML_AVAILABLE: localname = prefix.lower() + ':' + localname elif namespace and not qname: #Expat for name,value in self.namespacesInUse.items(): - if name and value == namespace: - localname = name + ':' + localname - break - if _debug: sys.stderr.write('startElementNS: qname = %s, namespace = %s, givenprefix = %s, prefix = %s, attrs = %s, localname = %s\n' % (qname, namespace, givenprefix, prefix, attrs.items(), localname)) + if name and value == namespace: + localname = name + ':' + localname + break - for (namespace, attrlocalname), attrvalue in attrs._attrs.items(): + for (namespace, attrlocalname), attrvalue in attrs.items(): lowernamespace = (namespace or '').lower() prefix = self._matchnamespaces.get(lowernamespace, '') if prefix: @@ -1787,9 +1847,9 @@ if _XML_AVAILABLE: localname = prefix + ':' + localname elif namespace and not qname: #Expat for name,value in self.namespacesInUse.items(): - if name and value == namespace: - localname = name + ':' + localname - break + if name and value == namespace: + localname = name + ':' + localname + break localname = str(localname).lower() self.unknown_endtag(localname) @@ -1797,6 +1857,9 @@ if _XML_AVAILABLE: self.bozo = 1 self.exc = exc + # drv_libxml2 calls warning() in some cases + warning = error + def fatalError(self, exc): self.error(exc) raise exc @@ -1804,16 +1867,15 @@ if _XML_AVAILABLE: class _BaseHTMLProcessor(sgmllib.SGMLParser): special = re.compile('''[<>'"]''') bare_ampersand = re.compile("&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)") - elements_no_end_tag = [ - 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame', + elements_no_end_tag = set([ + 'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame', 'hr', 'img', 'input', 'isindex', 'keygen', 'link', 'meta', 'param', 'source', 'track', 'wbr' - ] + ]) def __init__(self, encoding, _type): self.encoding = encoding self._type = _type - if _debug: sys.stderr.write('entering BaseHTMLProcessor, encoding=%s\n' % self.encoding) sgmllib.SGMLParser.__init__(self) def reset(self): @@ -1827,8 +1889,21 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): else: return '<' + tag + '></' + tag + '>' + # By declaring these methods and overriding their compiled code + # with the code from sgmllib, the original code will execute in + # feedparser's scope instead of sgmllib's. This means that the + # `tagfind` and `charref` regular expressions will be found as + # they're declared above, not as they're declared in sgmllib. + def goahead(self, i): + pass + goahead.func_code = sgmllib.SGMLParser.goahead.func_code + + def __parse_starttag(self, i): + pass + __parse_starttag.func_code = sgmllib.SGMLParser.parse_starttag.func_code + def parse_starttag(self,i): - j=sgmllib.SGMLParser.parse_starttag(self, i) + j = self.__parse_starttag(i) if self._type == 'application/xhtml+xml': if j>2 and self.rawdata[j-2:j]=='/>': self.unknown_endtag(self.lasttag) @@ -1836,23 +1911,23 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): def feed(self, data): data = re.compile(r'<!((?!DOCTYPE|--|\[))', re.IGNORECASE).sub(r'<!\1', data) - #data = re.sub(r'<(\S+?)\s*?/>', self._shorttag_replace, data) # bug [ 1399464 ] Bad regexp for _shorttag_replace - data = re.sub(r'<([^<>\s]+?)\s*/>', self._shorttag_replace, data) + data = re.sub(r'<([^<>\s]+?)\s*/>', self._shorttag_replace, data) data = data.replace(''', "'") data = data.replace('"', '"') try: bytes if bytes is str: raise NameError - self.encoding = self.encoding + '_INVALID_PYTHON_3' + self.encoding = self.encoding + u'_INVALID_PYTHON_3' except NameError: - if self.encoding and type(data) == type(u''): + if self.encoding and isinstance(data, unicode): data = data.encode(self.encoding) sgmllib.SGMLParser.feed(self, data) sgmllib.SGMLParser.close(self) def normalize_attrs(self, attrs): - if not attrs: return attrs + if not attrs: + return attrs # utility method to be called by descendants attrs = dict([(k.lower(), v) for k, v in attrs]).items() attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs] @@ -1863,7 +1938,6 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): # called for each start tag # attrs is a list of (attr, value) tuples # e.g. for <pre class='screen'>, tag='pre', attrs=[('class', 'screen')] - if _debug: sys.stderr.write('_BaseHTMLProcessor, unknown_starttag, tag=%s\n' % tag) uattrs = [] strattrs='' if attrs: @@ -1871,11 +1945,8 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): value=value.replace('>','>').replace('<','<').replace('"','"') value = self.bare_ampersand.sub("&", value) # thanks to Kevin Marks for this breathtaking hack to deal with (valid) high-bit attribute values in UTF-8 feeds - if type(value) != type(u''): - try: - value = unicode(value, self.encoding) - except: - value = unicode(value, 'iso-8859-1') + if not isinstance(value, unicode): + value = value.decode(self.encoding, 'ignore') try: # Currently, in Python 3 the key is already a str, and cannot be decoded again uattrs.append((unicode(key, self.encoding), value)) @@ -1884,65 +1955,65 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): strattrs = u''.join([u' %s="%s"' % (key, value) for key, value in uattrs]) if self.encoding: try: - strattrs=strattrs.encode(self.encoding) - except: + strattrs = strattrs.encode(self.encoding) + except (UnicodeEncodeError, LookupError): pass if tag in self.elements_no_end_tag: - self.pieces.append('<%(tag)s%(strattrs)s />' % locals()) + self.pieces.append('<%s%s />' % (tag, strattrs)) else: - self.pieces.append('<%(tag)s%(strattrs)s>' % locals()) + self.pieces.append('<%s%s>' % (tag, strattrs)) def unknown_endtag(self, tag): # called for each end tag, e.g. for </pre>, tag will be 'pre' # Reconstruct the original end tag. if tag not in self.elements_no_end_tag: - self.pieces.append("</%(tag)s>" % locals()) + self.pieces.append("</%s>" % tag) def handle_charref(self, ref): # called for each character reference, e.g. for ' ', ref will be '160' # Reconstruct the original character reference. + ref = ref.lower() if ref.startswith('x'): - value = unichr(int(ref[1:],16)) + value = int(ref[1:], 16) else: - value = unichr(int(ref)) + value = int(ref) - if value in _cp1252.keys(): + if value in _cp1252: self.pieces.append('&#%s;' % hex(ord(_cp1252[value]))[1:]) else: - self.pieces.append('&#%(ref)s;' % locals()) - + self.pieces.append('&#%s;' % ref) + def handle_entityref(self, ref): # called for each entity reference, e.g. for '©', ref will be 'copy' # Reconstruct the original entity reference. - if name2codepoint.has_key(ref): - self.pieces.append('&%(ref)s;' % locals()) + if ref in name2codepoint or ref == 'apos': + self.pieces.append('&%s;' % ref) else: - self.pieces.append('&%(ref)s' % locals()) + self.pieces.append('&%s' % ref) def handle_data(self, text): # called for each block of plain text, i.e. outside of any tag and # not containing any character or entity references # Store the original text verbatim. - if _debug: sys.stderr.write('_BaseHTMLProcessor, handle_data, text=%s\n' % text) self.pieces.append(text) - + def handle_comment(self, text): # called for each HTML comment, e.g. <!-- insert Javascript code here --> # Reconstruct the original comment. - self.pieces.append('<!--%(text)s-->' % locals()) - + self.pieces.append('<!--%s-->' % text) + def handle_pi(self, text): # called for each processing instruction, e.g. <?instruction> # Reconstruct original processing instruction. - self.pieces.append('<?%(text)s>' % locals()) + self.pieces.append('<?%s>' % text) def handle_decl(self, text): # called for the DOCTYPE, if present, e.g. # <!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" # "http://www.w3.org/TR/html4/loose.dtd"> # Reconstruct original DOCTYPE - self.pieces.append('<!%(text)s>' % locals()) - + self.pieces.append('<!%s>' % text) + _new_declname_match = re.compile(r'[a-zA-Z][-_.a-zA-Z0-9:]*\s*').match def _scan_name(self, i, declstartpos): rawdata = self.rawdata @@ -1971,6 +2042,14 @@ class _BaseHTMLProcessor(sgmllib.SGMLParser): '''Return processed HTML as a single string''' return ''.join([str(p) for p in self.pieces]) + def parse_declaration(self, i): + try: + return sgmllib.SGMLParser.parse_declaration(self, i) + except sgmllib.SGMLParseError: + # escape the doctype declaration and continue parsing + self.handle_data('<') + return i+1 + class _LooseFeedParser(_FeedParserMixin, _BaseHTMLProcessor): def __init__(self, baseuri, baselang, encoding, entities): sgmllib.SGMLParser.__init__(self) @@ -1991,14 +2070,14 @@ class _LooseFeedParser(_FeedParserMixin, _BaseHTMLProcessor): data = data.replace('"', '"') data = data.replace(''', ''') data = data.replace(''', ''') - if self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'): + if not self.contentparams.get('type', u'xml').endswith(u'xml'): data = data.replace('<', '<') data = data.replace('>', '>') data = data.replace('&', '&') data = data.replace('"', '"') data = data.replace(''', "'") return data - + def strattrs(self, attrs): return ''.join([' %s="%s"' % (n,v.replace('"','"')) for n,v in attrs]) @@ -2009,25 +2088,25 @@ class _MicroformatsParser: NODE = 4 EMAIL = 5 - known_xfn_relationships = ['contact', 'acquaintance', 'friend', 'met', 'co-worker', 'coworker', 'colleague', 'co-resident', 'coresident', 'neighbor', 'child', 'parent', 'sibling', 'brother', 'sister', 'spouse', 'wife', 'husband', 'kin', 'relative', 'muse', 'crush', 'date', 'sweetheart', 'me'] - known_binary_extensions = ['zip','rar','exe','gz','tar','tgz','tbz2','bz2','z','7z','dmg','img','sit','sitx','hqx','deb','rpm','bz2','jar','rar','iso','bin','msi','mp2','mp3','ogg','ogm','mp4','m4v','m4a','avi','wma','wmv'] + known_xfn_relationships = set(['contact', 'acquaintance', 'friend', 'met', 'co-worker', 'coworker', 'colleague', 'co-resident', 'coresident', 'neighbor', 'child', 'parent', 'sibling', 'brother', 'sister', 'spouse', 'wife', 'husband', 'kin', 'relative', 'muse', 'crush', 'date', 'sweetheart', 'me']) + known_binary_extensions = set(['zip','rar','exe','gz','tar','tgz','tbz2','bz2','z','7z','dmg','img','sit','sitx','hqx','deb','rpm','bz2','jar','rar','iso','bin','msi','mp2','mp3','ogg','ogm','mp4','m4v','m4a','avi','wma','wmv']) def __init__(self, data, baseuri, encoding): self.document = BeautifulSoup.BeautifulSoup(data) self.baseuri = baseuri self.encoding = encoding - if type(data) == type(u''): + if isinstance(data, unicode): data = data.encode(encoding) self.tags = [] self.enclosures = [] self.xfn = [] self.vcard = None - + def vcardEscape(self, s): - if type(s) in (type(''), type(u'')): + if isinstance(s, basestring): s = s.replace(',', '\\,').replace(';', '\\;').replace('\n', '\\n') return s - + def vcardFold(self, s): s = re.sub(';+$', '', s) sFolded = '' @@ -2043,14 +2122,14 @@ class _MicroformatsParser: def normalize(self, s): return re.sub(r'\s+', ' ', s).strip() - + def unique(self, aList): results = [] for element in aList: if element not in results: results.append(element) return results - + def toISO8601(self, dt): return time.strftime('%Y-%m-%dT%H:%M:%SZ', dt) @@ -2088,12 +2167,18 @@ class _MicroformatsParser: arResults.append(node) bFound = (len(arResults) != 0) if not bFound: - if bAllowMultiple: return [] - elif iPropertyType == self.STRING: return '' - elif iPropertyType == self.DATE: return None - elif iPropertyType == self.URI: return '' - elif iPropertyType == self.NODE: return None - else: return None + if bAllowMultiple: + return [] + elif iPropertyType == self.STRING: + return '' + elif iPropertyType == self.DATE: + return None + elif iPropertyType == self.URI: + return '' + elif iPropertyType == self.NODE: + return None + else: + return None arValues = [] for elmResult in arResults: sValue = None @@ -2113,9 +2198,12 @@ class _MicroformatsParser: if sValue: sValue = bNormalize and self.normalize(sValue) or sValue.strip() if (not sValue) and (iPropertyType == self.URI): - if sNodeName == 'a': sValue = elmResult.get('href') - elif sNodeName == 'img': sValue = elmResult.get('src') - elif sNodeName == 'object': sValue = elmResult.get('data') + if sNodeName == 'a': + sValue = elmResult.get('href') + elif sNodeName == 'img': + sValue = elmResult.get('src') + elif sNodeName == 'object': + sValue = elmResult.get('data') if sValue: sValue = bNormalize and self.normalize(sValue) or sValue.strip() if (not sValue) and (sNodeName == 'img'): @@ -2129,7 +2217,8 @@ class _MicroformatsParser: sValue = sValue.replace('\r', '\n') if sValue: sValue = bNormalize and self.normalize(sValue) or sValue.strip() - if not sValue: continue + if not sValue: + continue if iPropertyType == self.DATE: sValue = _parse_date_iso8601(sValue) if bAllowMultiple: @@ -2140,21 +2229,21 @@ class _MicroformatsParser: def findVCards(self, elmRoot, bAgentParsing=0): sVCards = '' - + if not bAgentParsing: arCards = self.getPropertyValue(elmRoot, 'vcard', bAllowMultiple=1) else: arCards = [elmRoot] - + for elmCard in arCards: arLines = [] - + def processSingleString(sProperty): sValue = self.getPropertyValue(elmCard, sProperty, self.STRING, bAutoEscape=1).decode(self.encoding) if sValue: arLines.append(self.vcardFold(sProperty.upper() + ':' + sValue)) return sValue or u'' - + def processSingleURI(sProperty): sValue = self.getPropertyValue(elmCard, sProperty, self.URI) if sValue: @@ -2177,7 +2266,7 @@ class _MicroformatsParser: if sContentType: sContentType = ';TYPE=' + sContentType.upper() arLines.append(self.vcardFold(sProperty.upper() + sEncoding + sContentType + sValueKey + ':' + sValue)) - + def processTypeValue(sProperty, arDefaultType, arForceType=None): arResults = self.getPropertyValue(elmCard, sProperty, bAllowMultiple=1) for elmResult in arResults: @@ -2189,7 +2278,7 @@ class _MicroformatsParser: sValue = self.getPropertyValue(elmResult, 'value', self.EMAIL, 0) if sValue: arLines.append(self.vcardFold(sProperty.upper() + ';TYPE=' + ','.join(arType) + ':' + sValue)) - + # AGENT # must do this before all other properties because it is destructive # (removes nested class="vcard" nodes so they don't interfere with @@ -2208,10 +2297,10 @@ class _MicroformatsParser: sAgentValue = self.getPropertyValue(elmAgent, 'value', self.URI, bAutoEscape=1); if sAgentValue: arLines.append(self.vcardFold('AGENT;VALUE=uri:' + sAgentValue)) - + # FN (full name) sFN = processSingleString('fn') - + # N (name) elmName = self.getPropertyValue(elmCard, 'n') if elmName: @@ -2220,7 +2309,7 @@ class _MicroformatsParser: arAdditionalNames = self.getPropertyValue(elmName, 'additional-name', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'additional-names', self.STRING, 1, 1) arHonorificPrefixes = self.getPropertyValue(elmName, 'honorific-prefix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-prefixes', self.STRING, 1, 1) arHonorificSuffixes = self.getPropertyValue(elmName, 'honorific-suffix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-suffixes', self.STRING, 1, 1) - arLines.append(self.vcardFold('N:' + sFamilyName + ';' + + arLines.append(self.vcardFold('N:' + sFamilyName + ';' + sGivenName + ';' + ','.join(arAdditionalNames) + ';' + ','.join(arHonorificPrefixes) + ';' + @@ -2237,25 +2326,25 @@ class _MicroformatsParser: arLines.append(self.vcardFold('N:' + arNames[0] + ';' + arNames[1])) else: arLines.append(self.vcardFold('N:' + arNames[1] + ';' + arNames[0])) - + # SORT-STRING sSortString = self.getPropertyValue(elmCard, 'sort-string', self.STRING, bAutoEscape=1) if sSortString: arLines.append(self.vcardFold('SORT-STRING:' + sSortString)) - + # NICKNAME arNickname = self.getPropertyValue(elmCard, 'nickname', self.STRING, 1, 1) if arNickname: arLines.append(self.vcardFold('NICKNAME:' + ','.join(arNickname))) - + # PHOTO processSingleURI('photo') - + # BDAY dtBday = self.getPropertyValue(elmCard, 'bday', self.DATE) if dtBday: arLines.append(self.vcardFold('BDAY:' + self.toISO8601(dtBday))) - + # ADR (address) arAdr = self.getPropertyValue(elmCard, 'adr', bAllowMultiple=1) for elmAdr in arAdr: @@ -2277,38 +2366,38 @@ class _MicroformatsParser: sRegion + ';' + sPostalCode + ';' + sCountryName)) - + # LABEL processTypeValue('label', ['intl','postal','parcel','work']) - + # TEL (phone number) processTypeValue('tel', ['voice']) - + # EMAIL processTypeValue('email', ['internet'], ['internet']) - + # MAILER processSingleString('mailer') - + # TZ (timezone) processSingleString('tz') - + # GEO (geographical information) elmGeo = self.getPropertyValue(elmCard, 'geo') if elmGeo: sLatitude = self.getPropertyValue(elmGeo, 'latitude', self.STRING, 0, 1) sLongitude = self.getPropertyValue(elmGeo, 'longitude', self.STRING, 0, 1) arLines.append(self.vcardFold('GEO:' + sLatitude + ';' + sLongitude)) - + # TITLE processSingleString('title') - + # ROLE processSingleString('role') # LOGO processSingleURI('logo') - + # ORG (organization) elmOrg = self.getPropertyValue(elmCard, 'org') if elmOrg: @@ -2322,49 +2411,58 @@ class _MicroformatsParser: else: arOrganizationUnit = self.getPropertyValue(elmOrg, 'organization-unit', self.STRING, 1, 1) arLines.append(self.vcardFold('ORG:' + sOrganizationName + ';' + ';'.join(arOrganizationUnit))) - + # CATEGORY arCategory = self.getPropertyValue(elmCard, 'category', self.STRING, 1, 1) + self.getPropertyValue(elmCard, 'categories', self.STRING, 1, 1) if arCategory: arLines.append(self.vcardFold('CATEGORIES:' + ','.join(arCategory))) - + # NOTE processSingleString('note') - + # REV processSingleString('rev') - + # SOUND processSingleURI('sound') - + # UID processSingleString('uid') - + # URL processSingleURI('url') - + # CLASS processSingleString('class') - + # KEY processSingleURI('key') - + if arLines: arLines = [u'BEGIN:vCard',u'VERSION:3.0'] + arLines + [u'END:vCard'] + # XXX - this is super ugly; properly fix this with issue 148 + for i, s in enumerate(arLines): + if not isinstance(s, unicode): + arLines[i] = s.decode('utf-8', 'ignore') sVCards += u'\n'.join(arLines) + u'\n' - + return sVCards.strip() - + def isProbablyDownloadable(self, elm): attrsD = elm.attrMap - if not attrsD.has_key('href'): return 0 + if 'href' not in attrsD: + return 0 linktype = attrsD.get('type', '').strip() if linktype.startswith('audio/') or \ linktype.startswith('video/') or \ (linktype.startswith('application/') and not linktype.endswith('xml')): return 1 - path = urlparse.urlparse(attrsD['href'])[2] - if path.find('.') == -1: return 0 + try: + path = urlparse.urlparse(attrsD['href'])[2] + except ValueError: + return 0 + if path.find('.') == -1: + return 0 fileext = path.split('.').pop().lower() return fileext in self.known_binary_extensions @@ -2372,13 +2470,18 @@ class _MicroformatsParser: all = lambda x: 1 for elm in self.document(all, {'rel': re.compile(r'\btag\b')}): href = elm.get('href') - if not href: continue + if not href: + continue urlscheme, domain, path, params, query, fragment = \ urlparse.urlparse(_urljoin(self.baseuri, href)) segments = path.split('/') tag = segments.pop() if not tag: - tag = segments.pop() + if segments: + tag = segments.pop() + else: + # there are no tags + continue tagscheme = urlparse.urlunparse((urlscheme, domain, '/'.join(segments), '', '', '')) if not tagscheme.endswith('/'): tagscheme += '/' @@ -2388,7 +2491,8 @@ class _MicroformatsParser: all = lambda x: 1 enclosure_match = re.compile(r'\benclosure\b') for elm in self.document(all, {'href': re.compile(r'.+')}): - if not enclosure_match.search(elm.get('rel', '')) and not self.isProbablyDownloadable(elm): continue + if not enclosure_match.search(elm.get('rel', u'')) and not self.isProbablyDownloadable(elm): + continue if elm.attrMap not in self.enclosures: self.enclosures.append(elm.attrMap) if elm.string and not elm.get('title'): @@ -2397,17 +2501,14 @@ class _MicroformatsParser: def findXFN(self): all = lambda x: 1 for elm in self.document(all, {'rel': re.compile('.+'), 'href': re.compile('.+')}): - rels = elm.get('rel', '').split() - xfn_rels = [] - for rel in rels: - if rel in self.known_xfn_relationships: - xfn_rels.append(rel) + rels = elm.get('rel', u'').split() + xfn_rels = [r for r in rels if r in self.known_xfn_relationships] if xfn_rels: self.xfn.append({"relationships": xfn_rels, "href": elm.get('href', ''), "name": elm.string}) def _parseMicroformats(htmlSource, baseURI, encoding): - if not BeautifulSoup: return - if _debug: sys.stderr.write('entering _parseMicroformats\n') + if not BeautifulSoup: + return try: p = _MicroformatsParser(htmlSource, baseURI, encoding) except UnicodeEncodeError: @@ -2421,7 +2522,7 @@ def _parseMicroformats(htmlSource, baseURI, encoding): return {"tags": p.tags, "enclosures": p.enclosures, "xfn": p.xfn, "vcard": p.vcard} class _RelativeURIResolver(_BaseHTMLProcessor): - relative_uris = [('a', 'href'), + relative_uris = set([('a', 'href'), ('applet', 'codebase'), ('area', 'href'), ('blockquote', 'cite'), @@ -2445,25 +2546,24 @@ class _RelativeURIResolver(_BaseHTMLProcessor): ('object', 'data'), ('object', 'usemap'), ('q', 'cite'), - ('script', 'src')] + ('script', 'src'), + ('video', 'poster')]) def __init__(self, baseuri, encoding, _type): _BaseHTMLProcessor.__init__(self, encoding, _type) self.baseuri = baseuri def resolveURI(self, uri): - return _makeSafeAbsoluteURI(_urljoin(self.baseuri, uri.strip())) - + return _makeSafeAbsoluteURI(self.baseuri, uri.strip()) + def unknown_starttag(self, tag, attrs): - if _debug: - sys.stderr.write('tag: [%s] with attributes: [%s]\n' % (tag, str(attrs))) attrs = self.normalize_attrs(attrs) attrs = [(key, ((tag, key) in self.relative_uris) and self.resolveURI(value) or value) for key, value in attrs] _BaseHTMLProcessor.unknown_starttag(self, tag, attrs) def _resolveRelativeURIs(htmlSource, baseURI, encoding, _type): - if _debug: - sys.stderr.write('entering _resolveRelativeURIs\n') + if not _SGML_AVAILABLE: + return htmlSource p = _RelativeURIResolver(baseURI, encoding, _type) p.feed(htmlSource) @@ -2472,20 +2572,30 @@ def _resolveRelativeURIs(htmlSource, baseURI, encoding, _type): def _makeSafeAbsoluteURI(base, rel=None): # bail if ACCEPTABLE_URI_SCHEMES is empty if not ACCEPTABLE_URI_SCHEMES: - return _urljoin(base, rel or u'') + try: + return _urljoin(base, rel or u'') + except ValueError: + return u'' if not base: return rel or u'' if not rel: - if base.strip().split(':', 1)[0] not in ACCEPTABLE_URI_SCHEMES: + try: + scheme = urlparse.urlparse(base)[0] + except ValueError: return u'' - return base - uri = _urljoin(base, rel) + if not scheme or scheme in ACCEPTABLE_URI_SCHEMES: + return base + return u'' + try: + uri = _urljoin(base, rel) + except ValueError: + return u'' if uri.strip().split(':', 1)[0] not in ACCEPTABLE_URI_SCHEMES: return u'' return uri class _HTMLSanitizer(_BaseHTMLProcessor): - acceptable_elements = ['a', 'abbr', 'acronym', 'address', 'area', + acceptable_elements = set(['a', 'abbr', 'acronym', 'address', 'area', 'article', 'aside', 'audio', 'b', 'big', 'blockquote', 'br', 'button', 'canvas', 'caption', 'center', 'cite', 'code', 'col', 'colgroup', 'command', 'datagrid', 'datalist', 'dd', 'del', 'details', 'dfn', @@ -2497,9 +2607,9 @@ class _HTMLSanitizer(_BaseHTMLProcessor): 'p', 'pre', 'progress', 'q', 's', 'samp', 'section', 'select', 'small', 'sound', 'source', 'spacer', 'span', 'strike', 'strong', 'sub', 'sup', 'table', 'tbody', 'td', 'textarea', 'time', 'tfoot', - 'th', 'thead', 'tr', 'tt', 'u', 'ul', 'var', 'video', 'noscript'] + 'th', 'thead', 'tr', 'tt', 'u', 'ul', 'var', 'video', 'noscript']) - acceptable_attributes = ['abbr', 'accept', 'accept-charset', 'accesskey', + acceptable_attributes = set(['abbr', 'accept', 'accept-charset', 'accesskey', 'action', 'align', 'alt', 'autocomplete', 'autofocus', 'axis', 'background', 'balance', 'bgcolor', 'bgproperties', 'border', 'bordercolor', 'bordercolordark', 'bordercolorlight', 'bottompadding', @@ -2514,17 +2624,17 @@ class _HTMLSanitizer(_BaseHTMLProcessor): 'loop', 'loopcount', 'loopend', 'loopstart', 'low', 'lowsrc', 'max', 'maxlength', 'media', 'method', 'min', 'multiple', 'name', 'nohref', 'noshade', 'nowrap', 'open', 'optimum', 'pattern', 'ping', 'point-size', - 'prompt', 'pqg', 'radiogroup', 'readonly', 'rel', 'repeat-max', - 'repeat-min', 'replace', 'required', 'rev', 'rightspacing', 'rows', - 'rowspan', 'rules', 'scope', 'selected', 'shape', 'size', 'span', 'src', - 'start', 'step', 'summary', 'suppress', 'tabindex', 'target', 'template', - 'title', 'toppadding', 'type', 'unselectable', 'usemap', 'urn', 'valign', - 'value', 'variable', 'volume', 'vspace', 'vrml', 'width', 'wrap', - 'xml:lang'] + 'poster', 'pqg', 'preload', 'prompt', 'radiogroup', 'readonly', 'rel', + 'repeat-max', 'repeat-min', 'replace', 'required', 'rev', 'rightspacing', + 'rows', 'rowspan', 'rules', 'scope', 'selected', 'shape', 'size', 'span', + 'src', 'start', 'step', 'summary', 'suppress', 'tabindex', 'target', + 'template', 'title', 'toppadding', 'type', 'unselectable', 'usemap', + 'urn', 'valign', 'value', 'variable', 'volume', 'vspace', 'vrml', + 'width', 'wrap', 'xml:lang']) - unacceptable_elements_with_end_tag = ['script', 'applet', 'style'] + unacceptable_elements_with_end_tag = set(['script', 'applet', 'style']) - acceptable_css_properties = ['azimuth', 'background-color', + acceptable_css_properties = set(['azimuth', 'background-color', 'border-bottom-color', 'border-collapse', 'border-color', 'border-left-color', 'border-right-color', 'border-top-color', 'clear', 'color', 'cursor', 'direction', 'display', 'elevation', 'float', 'font', @@ -2534,26 +2644,26 @@ class _HTMLSanitizer(_BaseHTMLProcessor): 'speak', 'speak-header', 'speak-numeral', 'speak-punctuation', 'speech-rate', 'stress', 'text-align', 'text-decoration', 'text-indent', 'unicode-bidi', 'vertical-align', 'voice-family', 'volume', - 'white-space', 'width'] + 'white-space', 'width']) # survey of common keywords found in feeds - acceptable_css_keywords = ['auto', 'aqua', 'black', 'block', 'blue', + acceptable_css_keywords = set(['auto', 'aqua', 'black', 'block', 'blue', 'bold', 'both', 'bottom', 'brown', 'center', 'collapse', 'dashed', 'dotted', 'fuchsia', 'gray', 'green', '!important', 'italic', 'left', 'lime', 'maroon', 'medium', 'none', 'navy', 'normal', 'nowrap', 'olive', 'pointer', 'purple', 'red', 'right', 'solid', 'silver', 'teal', 'top', - 'transparent', 'underline', 'white', 'yellow'] + 'transparent', 'underline', 'white', 'yellow']) valid_css_values = re.compile('^(#[0-9a-f]+|rgb\(\d+%?,\d*%?,?\d*%?\)?|' + '\d{0,2}\.?\d{0,2}(cm|em|ex|in|mm|pc|pt|px|%|,|\))?)$') - mathml_elements = ['annotation', 'annotation-xml', 'maction', 'math', + mathml_elements = set(['annotation', 'annotation-xml', 'maction', 'math', 'merror', 'mfenced', 'mfrac', 'mi', 'mmultiscripts', 'mn', 'mo', 'mover', 'mpadded', 'mphantom', 'mprescripts', 'mroot', 'mrow', 'mspace', 'msqrt', 'mstyle', 'msub', 'msubsup', 'msup', 'mtable', 'mtd', 'mtext', 'mtr', 'munder', - 'munderover', 'none', 'semantics'] + 'munderover', 'none', 'semantics']) - mathml_attributes = ['actiontype', 'align', 'columnalign', 'columnalign', + mathml_attributes = set(['actiontype', 'align', 'columnalign', 'columnalign', 'columnalign', 'close', 'columnlines', 'columnspacing', 'columnspan', 'depth', 'display', 'displaystyle', 'encoding', 'equalcolumns', 'equalrows', 'fence', 'fontstyle', 'fontweight', 'frame', 'height', 'linethickness', @@ -2561,18 +2671,18 @@ class _HTMLSanitizer(_BaseHTMLProcessor): 'maxsize', 'minsize', 'open', 'other', 'rowalign', 'rowalign', 'rowalign', 'rowlines', 'rowspacing', 'rowspan', 'rspace', 'scriptlevel', 'selection', 'separator', 'separators', 'stretchy', 'width', 'width', 'xlink:href', - 'xlink:show', 'xlink:type', 'xmlns', 'xmlns:xlink'] + 'xlink:show', 'xlink:type', 'xmlns', 'xmlns:xlink']) # svgtiny - foreignObject + linearGradient + radialGradient + stop - svg_elements = ['a', 'animate', 'animateColor', 'animateMotion', + svg_elements = set(['a', 'animate', 'animateColor', 'animateMotion', 'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'foreignObject', - 'font-face', 'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern', + 'font-face', 'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern', 'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph', 'mpath', 'path', 'polygon', 'polyline', 'radialGradient', 'rect', 'set', 'stop', - 'svg', 'switch', 'text', 'title', 'tspan', 'use'] + 'svg', 'switch', 'text', 'title', 'tspan', 'use']) # svgtiny + class + opacity + offset + xmlns + xmlns:xlink - svg_attributes = ['accent-height', 'accumulate', 'additive', 'alphabetic', + svg_attributes = set(['accent-height', 'accumulate', 'additive', 'alphabetic', 'arabic-form', 'ascent', 'attributeName', 'attributeType', 'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height', 'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx', @@ -2598,21 +2708,21 @@ class _HTMLSanitizer(_BaseHTMLProcessor): 'widths', 'x', 'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole', 'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title', 'xlink:type', 'xml:base', 'xml:lang', 'xml:space', 'xmlns', 'xmlns:xlink', 'y', 'y1', - 'y2', 'zoomAndPan'] + 'y2', 'zoomAndPan']) svg_attr_map = None svg_elem_map = None - acceptable_svg_properties = [ 'fill', 'fill-opacity', 'fill-rule', + acceptable_svg_properties = set([ 'fill', 'fill-opacity', 'fill-rule', 'stroke', 'stroke-width', 'stroke-linecap', 'stroke-linejoin', - 'stroke-opacity'] + 'stroke-opacity']) def reset(self): _BaseHTMLProcessor.reset(self) self.unacceptablestack = 0 self.mathmlOK = 0 self.svgOK = 0 - + def unknown_starttag(self, tag, attrs): acceptable_attributes = self.acceptable_attributes keymap = {} @@ -2666,21 +2776,27 @@ class _HTMLSanitizer(_BaseHTMLProcessor): for key, value in self.normalize_attrs(attrs): if key in acceptable_attributes: key=keymap.get(key,key) + # make sure the uri uses an acceptable uri scheme + if key == u'href': + value = _makeSafeAbsoluteURI(value) clean_attrs.append((key,value)) elif key=='style': clean_value = self.sanitize_style(value) - if clean_value: clean_attrs.append((key,clean_value)) + if clean_value: + clean_attrs.append((key,clean_value)) _BaseHTMLProcessor.unknown_starttag(self, tag, clean_attrs) - + def unknown_endtag(self, tag): if not tag in self.acceptable_elements: if tag in self.unacceptable_elements_with_end_tag: self.unacceptablestack -= 1 if self.mathmlOK and tag in self.mathml_elements: - if tag == 'math' and self.mathmlOK: self.mathmlOK -= 1 + if tag == 'math' and self.mathmlOK: + self.mathmlOK -= 1 elif self.svgOK and tag in self.svg_elements: tag = self.svg_elem_map.get(tag,tag) - if tag == 'svg' and self.svgOK: self.svgOK -= 1 + if tag == 'svg' and self.svgOK: + self.svgOK -= 1 else: return _BaseHTMLProcessor.unknown_endtag(self, tag) @@ -2700,29 +2816,46 @@ class _HTMLSanitizer(_BaseHTMLProcessor): style=re.compile('url\s*\(\s*[^\s)]+?\s*\)\s*').sub(' ',style) # gauntlet - if not re.match("""^([:,;#%.\sa-zA-Z0-9!]|\w-\w|'[\s\w]+'|"[\s\w]+"|\([\d,\s]+\))*$""", style): return '' + if not re.match("""^([:,;#%.\sa-zA-Z0-9!]|\w-\w|'[\s\w]+'|"[\s\w]+"|\([\d,\s]+\))*$""", style): + return '' # This replaced a regexp that used re.match and was prone to pathological back-tracking. - if re.sub("\s*[-\w]+\s*:\s*[^:;]*;?", '', style).strip(): return '' + if re.sub("\s*[-\w]+\s*:\s*[^:;]*;?", '', style).strip(): + return '' clean = [] for prop,value in re.findall("([-\w]+)\s*:\s*([^:;]*)",style): - if not value: continue - if prop.lower() in self.acceptable_css_properties: - clean.append(prop + ': ' + value + ';') - elif prop.split('-')[0].lower() in ['background','border','margin','padding']: - for keyword in value.split(): - if not keyword in self.acceptable_css_keywords and \ - not self.valid_css_values.match(keyword): - break - else: - clean.append(prop + ': ' + value + ';') - elif self.svgOK and prop.lower() in self.acceptable_svg_properties: - clean.append(prop + ': ' + value + ';') + if not value: + continue + if prop.lower() in self.acceptable_css_properties: + clean.append(prop + ': ' + value + ';') + elif prop.split('-')[0].lower() in ['background','border','margin','padding']: + for keyword in value.split(): + if not keyword in self.acceptable_css_keywords and \ + not self.valid_css_values.match(keyword): + break + else: + clean.append(prop + ': ' + value + ';') + elif self.svgOK and prop.lower() in self.acceptable_svg_properties: + clean.append(prop + ': ' + value + ';') return ' '.join(clean) + def parse_comment(self, i, report=1): + ret = _BaseHTMLProcessor.parse_comment(self, i, report) + if ret >= 0: + return ret + # if ret == -1, this may be a malicious attempt to circumvent + # sanitization, or a page-destroying unclosed comment + match = re.compile(r'--[^>]*>').search(self.rawdata, i+4) + if match: + return match.end() + # unclosed comment; deliberately fail to handle_data() + return len(self.rawdata) + def _sanitizeHTML(htmlSource, encoding, _type): + if not _SGML_AVAILABLE: + return htmlSource p = _HTMLSanitizer(encoding, _type) htmlSource = htmlSource.replace('<![CDATA[', '<![CDATA[') p.feed(htmlSource) @@ -2747,7 +2880,7 @@ def _sanitizeHTML(htmlSource, encoding, _type): except: pass if _tidy: - utf8 = type(data) == type(u'') + utf8 = isinstance(data, unicode) if utf8: data = data.encode('utf-8') data = _tidy(data, output_xhtml=1, numeric_entities=1, wrap=0, char_encoding="utf8") @@ -2764,39 +2897,29 @@ def _sanitizeHTML(htmlSource, encoding, _type): class _FeedURLHandler(urllib2.HTTPDigestAuthHandler, urllib2.HTTPRedirectHandler, urllib2.HTTPDefaultErrorHandler): def http_error_default(self, req, fp, code, msg, headers): - if ((code / 100) == 3) and (code != 304): - return self.http_error_302(req, fp, code, msg, headers) - infourl = urllib.addinfourl(fp, headers, req.get_full_url()) - infourl.status = code - return infourl - - def http_error_302(self, req, fp, code, msg, headers): - if headers.dict.has_key('location'): - infourl = urllib2.HTTPRedirectHandler.http_error_302(self, req, fp, code, msg, headers) - else: - infourl = urllib.addinfourl(fp, headers, req.get_full_url()) - if not hasattr(infourl, 'status'): - infourl.status = code - return infourl - - def http_error_301(self, req, fp, code, msg, headers): - if headers.dict.has_key('location'): - infourl = urllib2.HTTPRedirectHandler.http_error_301(self, req, fp, code, msg, headers) - else: - infourl = urllib.addinfourl(fp, headers, req.get_full_url()) - if not hasattr(infourl, 'status'): - infourl.status = code - return infourl - - http_error_300 = http_error_302 - http_error_303 = http_error_302 - http_error_307 = http_error_302 - + # The default implementation just raises HTTPError. + # Forget that. + fp.status = code + return fp + + def http_error_301(self, req, fp, code, msg, hdrs): + result = urllib2.HTTPRedirectHandler.http_error_301(self, req, fp, + code, msg, hdrs) + result.status = code + result.newurl = result.geturl() + return result + # The default implementations in urllib2.HTTPRedirectHandler + # are identical, so hardcoding a http_error_301 call above + # won't affect anything + http_error_300 = http_error_301 + http_error_302 = http_error_301 + http_error_303 = http_error_301 + http_error_307 = http_error_301 + def http_error_401(self, req, fp, code, msg, headers): # Check if # - server requires digest auth, AND # - we tried (unsuccessfully) with basic auth, AND - # - we're using Python 2.3.3 or later (digest auth is irreparably broken in earlier versions) # If all conditions hold, parse authentication information # out of the Authorization header we sent the first time # (for the username and password) and the WWW-Authenticate @@ -2804,17 +2927,16 @@ class _FeedURLHandler(urllib2.HTTPDigestAuthHandler, urllib2.HTTPRedirectHandler # the request with the appropriate digest auth headers instead. # This evil genius hack has been brought to you by Aaron Swartz. host = urlparse.urlparse(req.get_full_url())[1] - try: - assert sys.version.split()[0] >= '2.3.3' - assert base64 is not None - user, passw = _base64decode(req.headers['Authorization'].split(' ')[1]).split(':') - realm = re.findall('realm="([^"]*)"', headers['WWW-Authenticate'])[0] - self.add_password(realm, host, user, passw) - retry = self.http_error_auth_reqed('www-authenticate', host, req, headers) - self.reset_retry_count() - return retry - except: + if base64 is None or 'Authorization' not in req.headers \ + or 'WWW-Authenticate' not in headers: return self.http_error_default(req, fp, code, msg, headers) + auth = _base64decode(req.headers['Authorization'].split(' ')[1]) + user, passw = auth.split(':') + realm = re.findall('realm="([^"]*)"', headers['WWW-Authenticate'])[0] + self.add_password(realm, host, user, passw) + retry = self.http_error_auth_reqed('www-authenticate', host, req, headers) + self.reset_retry_count() + return retry def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, handlers, request_headers): """URL, filename, or string --> stream @@ -2851,10 +2973,8 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h if hasattr(url_file_stream_or_string, 'read'): return url_file_stream_or_string - if url_file_stream_or_string == '-': - return sys.stdin - - if urlparse.urlparse(url_file_stream_or_string)[0] in ('http', 'https', 'ftp', 'file', 'feed'): + if isinstance(url_file_stream_or_string, basestring) \ + and urlparse.urlparse(url_file_stream_or_string)[0] in ('http', 'https', 'ftp', 'file', 'feed'): # Deal with the feed URI scheme if url_file_stream_or_string.startswith('feed:http'): url_file_stream_or_string = url_file_stream_or_string[5:] @@ -2862,9 +2982,9 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h url_file_stream_or_string = 'http:' + url_file_stream_or_string[5:] if not agent: agent = USER_AGENT - # test for inline user:password for basic auth + # Test for inline user:password credentials for HTTP basic auth auth = None - if base64: + if base64 and not url_file_stream_or_string.startswith('ftp:'): urltype, rest = urllib.splittype(url_file_stream_or_string) realhost, rest = urllib.splithost(rest) if realhost: @@ -2874,38 +2994,66 @@ def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, h auth = base64.standard_b64encode(user_passwd).strip() # iri support - try: - if isinstance(url_file_stream_or_string,unicode): - url_file_stream_or_string = url_file_stream_or_string.encode('idna').decode('utf-8') - else: - url_file_stream_or_string = url_file_stream_or_string.decode('utf-8').encode('idna').decode('utf-8') - except: - pass + if isinstance(url_file_stream_or_string, unicode): + url_file_stream_or_string = _convert_to_idn(url_file_stream_or_string) # try to open with urllib2 (to use optional headers) request = _build_urllib2_request(url_file_stream_or_string, agent, etag, modified, referrer, auth, request_headers) - opener = apply(urllib2.build_opener, tuple(handlers + [_FeedURLHandler()])) + opener = urllib2.build_opener(*tuple(handlers + [_FeedURLHandler()])) opener.addheaders = [] # RMK - must clear so we only send our custom User-Agent try: return opener.open(request) finally: opener.close() # JohnD - + # try to open with native open function (if url_file_stream_or_string is a filename) try: return open(url_file_stream_or_string, 'rb') - except: + except (IOError, UnicodeEncodeError, TypeError): + # if url_file_stream_or_string is a unicode object that + # cannot be converted to the encoding returned by + # sys.getfilesystemencoding(), a UnicodeEncodeError + # will be thrown + # If url_file_stream_or_string is a string that contains NULL + # (such as an XML document encoded in UTF-32), TypeError will + # be thrown. pass # treat url_file_stream_or_string as string - return _StringIO(str(url_file_stream_or_string)) + if isinstance(url_file_stream_or_string, unicode): + return _StringIO(url_file_stream_or_string.encode('utf-8')) + return _StringIO(url_file_stream_or_string) + +def _convert_to_idn(url): + """Convert a URL to IDN notation""" + # this function should only be called with a unicode string + # strategy: if the host cannot be encoded in ascii, then + # it'll be necessary to encode it in idn form + parts = list(urlparse.urlsplit(url)) + try: + parts[1].encode('ascii') + except UnicodeEncodeError: + # the url needs to be converted to idn notation + host = parts[1].rsplit(':', 1) + newhost = [] + port = u'' + if len(host) == 2: + port = host.pop() + for h in host[0].split('.'): + newhost.append(h.encode('idna').decode('utf-8')) + parts[1] = '.'.join(newhost) + if port: + parts[1] += ':' + port + return urlparse.urlunsplit(parts) + else: + return url def _build_urllib2_request(url, agent, etag, modified, referrer, auth, request_headers): request = urllib2.Request(url) request.add_header('User-Agent', agent) if etag: request.add_header('If-None-Match', etag) - if type(modified) == type(''): + if isinstance(modified, basestring): modified = _parse_date(modified) elif isinstance(modified, datetime.datetime): modified = modified.utctimetuple() @@ -2942,7 +3090,7 @@ _date_handlers = [] def registerDateHandler(func): '''Register a date handler function (takes string, returns 9-tuple date in GMT)''' _date_handlers.insert(0, func) - + # ISO-8601 date parsing routines written by Fazal Majid. # The ISO 8601 standard is very convoluted and irregular - a full ISO 8601 # parser is beyond the scope of feedparser and would be a worthwhile addition @@ -2953,7 +3101,7 @@ def registerDateHandler(func): # Please note the order in templates is significant because we need a # greedy match. _iso8601_tmpl = ['YYYY-?MM-?DD', 'YYYY-0MM?-?DD', 'YYYY-MM', 'YYYY-?OOO', - 'YY-?MM-?DD', 'YY-?OOO', 'YYYY', + 'YY-?MM-?DD', 'YY-?OOO', 'YYYY', '-YY-?MM', '-OOO', '-YY', '--MM-?DD', '--MM', '---DD', @@ -2985,9 +3133,12 @@ def _parse_date_iso8601(dateString): m = None for _iso8601_match in _iso8601_matches: m = _iso8601_match(dateString) - if m: break - if not m: return - if m.span() == (0, 0): return + if m: + break + if not m: + return + if m.span() == (0, 0): + return params = m.groupdict() ordinal = params.get('ordinal', 0) if ordinal: @@ -3025,7 +3176,7 @@ def _parse_date_iso8601(dateString): day = int(day) # special case of the century - is the first year of the 21st century # 2000 or 2001 ? The debate goes on... - if 'century' in params.keys(): + if 'century' in params: year = (int(params['century']) - 1) * 100 + 1 # in ISO 8601 most fields are optional for field in ['hour', 'minute', 'second', 'tzhour', 'tzmin']: @@ -3055,7 +3206,7 @@ def _parse_date_iso8601(dateString): # Many implementations have bugs, but we'll pretend they don't. return time.localtime(time.mktime(tuple(tm))) registerDateHandler(_parse_date_iso8601) - + # 8-bit date handling routines written by ytrewq1. _korean_year = u'\ub144' # b3e2 in euc-kr _korean_month = u'\uc6d4' # bff9 in euc-kr @@ -3072,19 +3223,20 @@ _korean_nate_date_re = \ def _parse_date_onblog(dateString): '''Parse a string according to the OnBlog 8-bit date format''' m = _korean_onblog_date_re.match(dateString) - if not m: return + if not m: + return w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ 'hour': m.group(4), 'minute': m.group(5), 'second': m.group(6),\ 'zonediff': '+09:00'} - if _debug: sys.stderr.write('OnBlog date parsed as: %s\n' % w3dtfdate) return _parse_date_w3dtf(w3dtfdate) registerDateHandler(_parse_date_onblog) def _parse_date_nate(dateString): '''Parse a string according to the Nate 8-bit date format''' m = _korean_nate_date_re.match(dateString) - if not m: return + if not m: + return hour = int(m.group(5)) ampm = m.group(4) if (ampm == _korean_pm): @@ -3096,24 +3248,9 @@ def _parse_date_nate(dateString): {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ 'hour': hour, 'minute': m.group(6), 'second': m.group(7),\ 'zonediff': '+09:00'} - if _debug: sys.stderr.write('Nate date parsed as: %s\n' % w3dtfdate) return _parse_date_w3dtf(w3dtfdate) registerDateHandler(_parse_date_nate) -_mssql_date_re = \ - re.compile('(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})(\.\d+)?') -def _parse_date_mssql(dateString): - '''Parse a string according to the MS SQL date format''' - m = _mssql_date_re.match(dateString) - if not m: return - w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \ - {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),\ - 'hour': m.group(4), 'minute': m.group(5), 'second': m.group(6),\ - 'zonediff': '+09:00'} - if _debug: sys.stderr.write('MS SQL date parsed as: %s\n' % w3dtfdate) - return _parse_date_w3dtf(w3dtfdate) -registerDateHandler(_parse_date_mssql) - # Unicode strings for Greek date strings _greek_months = \ { \ @@ -3146,7 +3283,7 @@ _greek_wdays = \ u'\u03a4\u03b5\u03c4': u'Wed', # d4e5f4 in iso-8859-7 u'\u03a0\u03b5\u03bc': u'Thu', # d0e5ec in iso-8859-7 u'\u03a0\u03b1\u03c1': u'Fri', # d0e1f1 in iso-8859-7 - u'\u03a3\u03b1\u03b2': u'Sat', # d3e1e2 in iso-8859-7 + u'\u03a3\u03b1\u03b2': u'Sat', # d3e1e2 in iso-8859-7 } _greek_date_format_re = \ @@ -3155,17 +3292,14 @@ _greek_date_format_re = \ def _parse_date_greek(dateString): '''Parse a string according to a Greek 8-bit date format.''' m = _greek_date_format_re.match(dateString) - if not m: return - try: - wday = _greek_wdays[m.group(1)] - month = _greek_months[m.group(3)] - except: + if not m: return + wday = _greek_wdays[m.group(1)] + month = _greek_months[m.group(3)] rfc822date = '%(wday)s, %(day)s %(month)s %(year)s %(hour)s:%(minute)s:%(second)s %(zonediff)s' % \ {'wday': wday, 'day': m.group(2), 'month': month, 'year': m.group(4),\ 'hour': m.group(5), 'minute': m.group(6), 'second': m.group(7),\ 'zonediff': m.group(8)} - if _debug: sys.stderr.write('Greek date parsed as: %s\n' % rfc822date) return _parse_date_rfc822(rfc822date) registerDateHandler(_parse_date_greek) @@ -3192,22 +3326,19 @@ _hungarian_date_format_re = \ def _parse_date_hungarian(dateString): '''Parse a string according to a Hungarian 8-bit date format.''' m = _hungarian_date_format_re.match(dateString) - if not m: return - try: - month = _hungarian_months[m.group(2)] - day = m.group(3) - if len(day) == 1: - day = '0' + day - hour = m.group(4) - if len(hour) == 1: - hour = '0' + hour - except: - return + if not m or m.group(2) not in _hungarian_months: + return None + month = _hungarian_months[m.group(2)] + day = m.group(3) + if len(day) == 1: + day = '0' + day + hour = m.group(4) + if len(hour) == 1: + hour = '0' + hour w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s%(zonediff)s' % \ {'year': m.group(1), 'month': month, 'day': day,\ 'hour': hour, 'minute': m.group(5),\ 'zonediff': m.group(6)} - if _debug: sys.stderr.write('Hungarian date parsed as: %s\n' % w3dtfdate) return _parse_date_w3dtf(w3dtfdate) registerDateHandler(_parse_date_hungarian) @@ -3215,6 +3346,9 @@ registerDateHandler(_parse_date_hungarian) # Drake and licensed under the Python license. Removed all range checking # for month, day, hour, minute, and second, since mktime will normalize # these later +# Modified to also support MSSQL-style datetimes as defined at: +# http://msdn.microsoft.com/en-us/library/ms186724.aspx +# (which basically means allowing a space as a date/time/timezone separator) def _parse_date_w3dtf(dateString): def __extract_date(m): year = int(m.group('year')) @@ -3240,7 +3374,7 @@ def _parse_date_w3dtf(dateString): day = 31 elif jday < julian: if day + diff < 28: - day = day + diff + day = day + diff else: month = month + 1 return year, month, day @@ -3296,315 +3430,455 @@ def _parse_date_w3dtf(dateString): '(?:(?P<dsep>-|)' '(?:(?P<month>\d\d)(?:(?P=dsep)(?P<day>\d\d))?' '|(?P<julian>\d\d\d)))?') - __tzd_re = '(?P<tzd>[-+](?P<tzdhours>\d\d)(?::?(?P<tzdminutes>\d\d))|Z)' - __tzd_rx = re.compile(__tzd_re) + __tzd_re = ' ?(?P<tzd>[-+](?P<tzdhours>\d\d)(?::?(?P<tzdminutes>\d\d))|Z)?' __time_re = ('(?P<hours>\d\d)(?P<tsep>:|)(?P<minutes>\d\d)' '(?:(?P=tsep)(?P<seconds>\d\d)(?:[.,]\d+)?)?' + __tzd_re) - __datetime_re = '%s(?:T%s)?' % (__date_re, __time_re) + __datetime_re = '%s(?:[T ]%s)?' % (__date_re, __time_re) __datetime_rx = re.compile(__datetime_re) m = __datetime_rx.match(dateString) - if (m is None) or (m.group() != dateString): return + if (m is None) or (m.group() != dateString): + return gmt = __extract_date(m) + __extract_time(m) + (0, 0, 0) - if gmt[0] == 0: return + if gmt[0] == 0: + return return time.gmtime(time.mktime(gmt) + __extract_tzd(m) - time.timezone) registerDateHandler(_parse_date_w3dtf) -def _parse_date_rfc822(dateString): - '''Parse an RFC822, RFC1123, RFC2822, or asctime-style date''' - data = dateString.split() - if data[0][-1] in (',', '.') or data[0].lower() in rfc822._daynames: - del data[0] - if len(data) == 4: - s = data[3] - i = s.find('+') - if i > 0: - data[3:] = [s[:i], s[i+1:]] - else: - data.append('') - dateString = " ".join(data) - # Account for the Etc/GMT timezone by stripping 'Etc/' - elif len(data) == 5 and data[4].lower().startswith('etc/'): - data[4] = data[4][4:] - dateString = " ".join(data) - if len(data) < 5: - dateString += ' 00:00:00 GMT' +# Define the strings used by the RFC822 datetime parser +_rfc822_months = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', + 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'] +_rfc822_daynames = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'] + +# Only the first three letters of the month name matter +_rfc822_month = "(?P<month>%s)(?:[a-z]*,?)" % ('|'.join(_rfc822_months)) +# The year may be 2 or 4 digits; capture the century if it exists +_rfc822_year = "(?P<year>(?:\d{2})?\d{2})" +_rfc822_day = "(?P<day> *\d{1,2})" +_rfc822_date = "%s %s %s" % (_rfc822_day, _rfc822_month, _rfc822_year) + +_rfc822_hour = "(?P<hour>\d{2}):(?P<minute>\d{2})(?::(?P<second>\d{2}))?" +_rfc822_tz = "(?P<tz>ut|gmt(?:[+-]\d{2}:\d{2})?|[aecmp][sd]?t|[zamny]|[+-]\d{4})" +_rfc822_tznames = { + 'ut': 0, 'gmt': 0, 'z': 0, + 'adt': -3, 'ast': -4, 'at': -4, + 'edt': -4, 'est': -5, 'et': -5, + 'cdt': -5, 'cst': -6, 'ct': -6, + 'mdt': -6, 'mst': -7, 'mt': -7, + 'pdt': -7, 'pst': -8, 'pt': -8, + 'a': -1, 'n': 1, + 'm': -12, 'y': 12, + } +# The timezone may be prefixed by 'Etc/' +_rfc822_time = "%s (?:etc/)?%s" % (_rfc822_hour, _rfc822_tz) + +_rfc822_dayname = "(?P<dayname>%s)" % ('|'.join(_rfc822_daynames)) +_rfc822_match = re.compile( + "(?:%s, )?%s(?: %s)?" % (_rfc822_dayname, _rfc822_date, _rfc822_time) +).match + +def _parse_date_group_rfc822(m): + # Calculate a date and timestamp + for k in ('year', 'day', 'hour', 'minute', 'second'): + m[k] = int(m[k]) + m['month'] = _rfc822_months.index(m['month']) + 1 + # If the year is 2 digits, assume everything in the 90's is the 1990's + if m['year'] < 100: + m['year'] += (1900, 2000)[m['year'] < 90] + stamp = datetime.datetime(*[m[i] for i in + ('year', 'month', 'day', 'hour', 'minute', 'second')]) + + # Use the timezone information to calculate the difference between + # the given date and timestamp and Universal Coordinated Time + tzhour = 0 + tzmin = 0 + if m['tz'] and m['tz'].startswith('gmt'): + # Handle GMT and GMT+hh:mm timezone syntax (the trailing + # timezone info will be handled by the next `if` block) + m['tz'] = ''.join(m['tz'][3:].split(':')) or 'gmt' + if not m['tz']: + pass + elif m['tz'].startswith('+'): + tzhour = int(m['tz'][1:3]) + tzmin = int(m['tz'][3:]) + elif m['tz'].startswith('-'): + tzhour = int(m['tz'][1:3]) * -1 + tzmin = int(m['tz'][3:]) * -1 + else: + tzhour = _rfc822_tznames[m['tz']] + delta = datetime.timedelta(0, 0, 0, 0, tzmin, tzhour) + + # Return the date and timestamp in UTC + return (stamp - delta).utctimetuple() + +def _parse_date_rfc822(dt): + """Parse RFC 822 dates and times, with one minor + difference: years may be 4DIGIT or 2DIGIT. + http://tools.ietf.org/html/rfc822#section-5""" + try: + m = _rfc822_match(dt.lower()).groupdict(0) + except AttributeError: + return None + + return _parse_date_group_rfc822(m) +registerDateHandler(_parse_date_rfc822) + +def _parse_date_rfc822_grubby(dt): + """Parse date format similar to RFC 822, but + the comma after the dayname is optional and + month/day are inverted""" + _rfc822_date_grubby = "%s %s %s" % (_rfc822_month, _rfc822_day, _rfc822_year) + _rfc822_match_grubby = re.compile( + "(?:%s[,]? )?%s(?: %s)?" % (_rfc822_dayname, _rfc822_date_grubby, _rfc822_time) + ).match + + try: + m = _rfc822_match_grubby(dt.lower()).groupdict(0) + except AttributeError: + return None + + return _parse_date_group_rfc822(m) +registerDateHandler(_parse_date_rfc822_grubby) + +def _parse_date_asctime(dt): + """Parse asctime-style dates""" + dayname, month, day, remainder = dt.split(None, 3) + # Convert month and day into zero-padded integers + month = '%02i ' % (_rfc822_months.index(month.lower()) + 1) + day = '%02i ' % (int(day),) + dt = month + day + remainder + return time.strptime(dt, '%m %d %H:%M:%S %Y')[:-1] + (0, ) +registerDateHandler(_parse_date_asctime) + +def _parse_date_perforce(aDateString): + """parse a date in yyyy/mm/dd hh:mm:ss TTT format""" + # Fri, 2006/09/15 08:19:53 EDT + _my_date_pattern = re.compile( \ + r'(\w{,3}), (\d{,4})/(\d{,2})/(\d{2}) (\d{,2}):(\d{2}):(\d{2}) (\w{,3})') + + m = _my_date_pattern.search(aDateString) + if m is None: + return None + dow, year, month, day, hour, minute, second, tz = m.groups() + months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + dateString = "%s, %s %s %s %s:%s:%s %s" % (dow, day, months[int(month) - 1], year, hour, minute, second, tz) tm = rfc822.parsedate_tz(dateString) if tm: return time.gmtime(rfc822.mktime_tz(tm)) -# rfc822.py defines several time zones, but we define some extra ones. -# 'ET' is equivalent to 'EST', etc. -_additional_timezones = {'AT': -400, 'ET': -500, 'CT': -600, 'MT': -700, 'PT': -800} -rfc822._timezones.update(_additional_timezones) -registerDateHandler(_parse_date_rfc822) - -def _parse_date_perforce(aDateString): - """parse a date in yyyy/mm/dd hh:mm:ss TTT format""" - # Fri, 2006/09/15 08:19:53 EDT - _my_date_pattern = re.compile( \ - r'(\w{,3}), (\d{,4})/(\d{,2})/(\d{2}) (\d{,2}):(\d{2}):(\d{2}) (\w{,3})') - - dow, year, month, day, hour, minute, second, tz = \ - _my_date_pattern.search(aDateString).groups() - months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - dateString = "%s, %s %s %s %s:%s:%s %s" % (dow, day, months[int(month) - 1], year, hour, minute, second, tz) - tm = rfc822.parsedate_tz(dateString) - if tm: - return time.gmtime(rfc822.mktime_tz(tm)) registerDateHandler(_parse_date_perforce) def _parse_date(dateString): '''Parses a variety of date formats into a 9-tuple in GMT''' + if not dateString: + return None for handler in _date_handlers: try: date9tuple = handler(dateString) - if not date9tuple: continue - if len(date9tuple) != 9: - if _debug: sys.stderr.write('date handler function must return 9-tuple\n') - raise ValueError - map(int, date9tuple) - return date9tuple - except Exception, e: - if _debug: sys.stderr.write('%s raised %s\n' % (handler.__name__, repr(e))) - pass + except (KeyError, OverflowError, ValueError): + continue + if not date9tuple: + continue + if len(date9tuple) != 9: + continue + return date9tuple return None -def _getCharacterEncoding(http_headers, xml_data): - '''Get the character encoding of the XML document +# Each marker represents some of the characters of the opening XML +# processing instruction ('<?xm') in the specified encoding. +EBCDIC_MARKER = _l2bytes([0x4C, 0x6F, 0xA7, 0x94]) +UTF16BE_MARKER = _l2bytes([0x00, 0x3C, 0x00, 0x3F]) +UTF16LE_MARKER = _l2bytes([0x3C, 0x00, 0x3F, 0x00]) +UTF32BE_MARKER = _l2bytes([0x00, 0x00, 0x00, 0x3C]) +UTF32LE_MARKER = _l2bytes([0x3C, 0x00, 0x00, 0x00]) + +ZERO_BYTES = _l2bytes([0x00, 0x00]) + +# Match the opening XML declaration. +# Example: <?xml version="1.0" encoding="utf-8"?> +RE_XML_DECLARATION = re.compile('^<\?xml[^>]*?>') + +# Capture the value of the XML processing instruction's encoding attribute. +# Example: <?xml version="1.0" encoding="utf-8"?> +RE_XML_PI_ENCODING = re.compile(_s2bytes('^<\?.*encoding=[\'"](.*?)[\'"].*\?>')) + +def convert_to_utf8(http_headers, data): + '''Detect and convert the character encoding to UTF-8. http_headers is a dictionary - xml_data is a raw string (not Unicode) - - This is so much trickier than it sounds, it's not even funny. - According to RFC 3023 ('XML Media Types'), if the HTTP Content-Type - is application/xml, application/*+xml, - application/xml-external-parsed-entity, or application/xml-dtd, - the encoding given in the charset parameter of the HTTP Content-Type - takes precedence over the encoding given in the XML prefix within the - document, and defaults to 'utf-8' if neither are specified. But, if - the HTTP Content-Type is text/xml, text/*+xml, or - text/xml-external-parsed-entity, the encoding given in the XML prefix - within the document is ALWAYS IGNORED and only the encoding given in - the charset parameter of the HTTP Content-Type header should be - respected, and it defaults to 'us-ascii' if not specified. - - Furthermore, discussion on the atom-syntax mailing list with the - author of RFC 3023 leads me to the conclusion that any document - served with a Content-Type of text/* and no charset parameter - must be treated as us-ascii. (We now do this.) And also that it - must always be flagged as non-well-formed. (We now do this too.) - - If Content-Type is unspecified (input was local file or non-HTTP source) - or unrecognized (server just got it totally wrong), then go by the - encoding given in the XML prefix of the document and default to - 'iso-8859-1' as per the HTTP specification (RFC 2616). - - Then, assuming we didn't find a character encoding in the HTTP headers - (and the HTTP Content-type allowed us to look in the body), we need - to sniff the first few bytes of the XML data and try to determine - whether the encoding is ASCII-compatible. Section F of the XML - specification shows the way here: - http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info - - If the sniffed encoding is not ASCII-compatible, we need to make it - ASCII compatible so that we can sniff further into the XML declaration - to find the encoding attribute, which will tell us the true encoding. - - Of course, none of this guarantees that we will be able to parse the - feed in the declared character encoding (assuming it was declared - correctly, which many are not). CJKCodecs and iconv_codec help a lot; - you should definitely install them if you can. - http://cjkpython.i18n.org/ - ''' + data is a raw string (not Unicode)''' + + # This is so much trickier than it sounds, it's not even funny. + # According to RFC 3023 ('XML Media Types'), if the HTTP Content-Type + # is application/xml, application/*+xml, + # application/xml-external-parsed-entity, or application/xml-dtd, + # the encoding given in the charset parameter of the HTTP Content-Type + # takes precedence over the encoding given in the XML prefix within the + # document, and defaults to 'utf-8' if neither are specified. But, if + # the HTTP Content-Type is text/xml, text/*+xml, or + # text/xml-external-parsed-entity, the encoding given in the XML prefix + # within the document is ALWAYS IGNORED and only the encoding given in + # the charset parameter of the HTTP Content-Type header should be + # respected, and it defaults to 'us-ascii' if not specified. + + # Furthermore, discussion on the atom-syntax mailing list with the + # author of RFC 3023 leads me to the conclusion that any document + # served with a Content-Type of text/* and no charset parameter + # must be treated as us-ascii. (We now do this.) And also that it + # must always be flagged as non-well-formed. (We now do this too.) + + # If Content-Type is unspecified (input was local file or non-HTTP source) + # or unrecognized (server just got it totally wrong), then go by the + # encoding given in the XML prefix of the document and default to + # 'iso-8859-1' as per the HTTP specification (RFC 2616). + + # Then, assuming we didn't find a character encoding in the HTTP headers + # (and the HTTP Content-type allowed us to look in the body), we need + # to sniff the first few bytes of the XML data and try to determine + # whether the encoding is ASCII-compatible. Section F of the XML + # specification shows the way here: + # http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info - def _parseHTTPContentType(content_type): - '''takes HTTP Content-Type header and returns (content type, charset) - - If no charset is specified, returns (content type, '') - If no content type is specified, returns ('', '') - Both return parameters are guaranteed to be lowercase strings - ''' - content_type = content_type or '' - content_type, params = cgi.parse_header(content_type) - return content_type, params.get('charset', '').replace("'", '') - - sniffed_xml_encoding = '' - xml_encoding = '' - true_encoding = '' - http_content_type, http_encoding = _parseHTTPContentType(http_headers.get('content-type', http_headers.get('Content-type'))) - # Must sniff for non-ASCII-compatible character encodings before - # searching for XML declaration. This heuristic is defined in - # section F of the XML specification: + # If the sniffed encoding is not ASCII-compatible, we need to make it + # ASCII compatible so that we can sniff further into the XML declaration + # to find the encoding attribute, which will tell us the true encoding. + + # Of course, none of this guarantees that we will be able to parse the + # feed in the declared character encoding (assuming it was declared + # correctly, which many are not). iconv_codec can help a lot; + # you should definitely install it if you can. + # http://cjkpython.i18n.org/ + + bom_encoding = u'' + xml_encoding = u'' + rfc3023_encoding = u'' + + # Look at the first few bytes of the document to guess what + # its encoding may be. We only need to decode enough of the + # document that we can use an ASCII-compatible regular + # expression to search for an XML encoding declaration. + # The heuristic follows the XML specification, section F: # http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info + # Check for BOMs first. + if data[:4] == codecs.BOM_UTF32_BE: + bom_encoding = u'utf-32be' + data = data[4:] + elif data[:4] == codecs.BOM_UTF32_LE: + bom_encoding = u'utf-32le' + data = data[4:] + elif data[:2] == codecs.BOM_UTF16_BE and data[2:4] != ZERO_BYTES: + bom_encoding = u'utf-16be' + data = data[2:] + elif data[:2] == codecs.BOM_UTF16_LE and data[2:4] != ZERO_BYTES: + bom_encoding = u'utf-16le' + data = data[2:] + elif data[:3] == codecs.BOM_UTF8: + bom_encoding = u'utf-8' + data = data[3:] + # Check for the characters '<?xm' in several encodings. + elif data[:4] == EBCDIC_MARKER: + bom_encoding = u'cp037' + elif data[:4] == UTF16BE_MARKER: + bom_encoding = u'utf-16be' + elif data[:4] == UTF16LE_MARKER: + bom_encoding = u'utf-16le' + elif data[:4] == UTF32BE_MARKER: + bom_encoding = u'utf-32be' + elif data[:4] == UTF32LE_MARKER: + bom_encoding = u'utf-32le' + + tempdata = data try: - if xml_data[:4] == _l2bytes([0x4c, 0x6f, 0xa7, 0x94]): - # EBCDIC - xml_data = _ebcdic_to_ascii(xml_data) - elif xml_data[:4] == _l2bytes([0x00, 0x3c, 0x00, 0x3f]): - # UTF-16BE - sniffed_xml_encoding = 'utf-16be' - xml_data = unicode(xml_data, 'utf-16be').encode('utf-8') - elif (len(xml_data) >= 4) and (xml_data[:2] == _l2bytes([0xfe, 0xff])) and (xml_data[2:4] != _l2bytes([0x00, 0x00])): - # UTF-16BE with BOM - sniffed_xml_encoding = 'utf-16be' - xml_data = unicode(xml_data[2:], 'utf-16be').encode('utf-8') - elif xml_data[:4] == _l2bytes([0x3c, 0x00, 0x3f, 0x00]): - # UTF-16LE - sniffed_xml_encoding = 'utf-16le' - xml_data = unicode(xml_data, 'utf-16le').encode('utf-8') - elif (len(xml_data) >= 4) and (xml_data[:2] == _l2bytes([0xff, 0xfe])) and (xml_data[2:4] != _l2bytes([0x00, 0x00])): - # UTF-16LE with BOM - sniffed_xml_encoding = 'utf-16le' - xml_data = unicode(xml_data[2:], 'utf-16le').encode('utf-8') - elif xml_data[:4] == _l2bytes([0x00, 0x00, 0x00, 0x3c]): - # UTF-32BE - sniffed_xml_encoding = 'utf-32be' - xml_data = unicode(xml_data, 'utf-32be').encode('utf-8') - elif xml_data[:4] == _l2bytes([0x3c, 0x00, 0x00, 0x00]): - # UTF-32LE - sniffed_xml_encoding = 'utf-32le' - xml_data = unicode(xml_data, 'utf-32le').encode('utf-8') - elif xml_data[:4] == _l2bytes([0x00, 0x00, 0xfe, 0xff]): - # UTF-32BE with BOM - sniffed_xml_encoding = 'utf-32be' - xml_data = unicode(xml_data[4:], 'utf-32be').encode('utf-8') - elif xml_data[:4] == _l2bytes([0xff, 0xfe, 0x00, 0x00]): - # UTF-32LE with BOM - sniffed_xml_encoding = 'utf-32le' - xml_data = unicode(xml_data[4:], 'utf-32le').encode('utf-8') - elif xml_data[:3] == _l2bytes([0xef, 0xbb, 0xbf]): - # UTF-8 with BOM - sniffed_xml_encoding = 'utf-8' - xml_data = unicode(xml_data[3:], 'utf-8').encode('utf-8') - else: - # ASCII-compatible - pass - xml_encoding_match = re.compile(_s2bytes('^<\?.*encoding=[\'"](.*?)[\'"].*\?>')).match(xml_data) - except: + if bom_encoding: + tempdata = data.decode(bom_encoding).encode('utf-8') + except (UnicodeDecodeError, LookupError): + # feedparser recognizes UTF-32 encodings that aren't + # available in Python 2.4 and 2.5, so it's possible to + # encounter a LookupError during decoding. xml_encoding_match = None + else: + xml_encoding_match = RE_XML_PI_ENCODING.match(tempdata) + if xml_encoding_match: xml_encoding = xml_encoding_match.groups()[0].decode('utf-8').lower() - if sniffed_xml_encoding and (xml_encoding in ('iso-10646-ucs-2', 'ucs-2', 'csunicode', 'iso-10646-ucs-4', 'ucs-4', 'csucs4', 'utf-16', 'utf-32', 'utf_16', 'utf_32', 'utf16', 'u16')): - xml_encoding = sniffed_xml_encoding + # Normalize the xml_encoding if necessary. + if bom_encoding and (xml_encoding in ( + u'u16', u'utf-16', u'utf16', u'utf_16', + u'u32', u'utf-32', u'utf32', u'utf_32', + u'iso-10646-ucs-2', u'iso-10646-ucs-4', + u'csucs4', u'csunicode', u'ucs-2', u'ucs-4' + )): + xml_encoding = bom_encoding + + # Find the HTTP Content-Type and, hopefully, a character + # encoding provided by the server. The Content-Type is used + # to choose the "correct" encoding among the BOM encoding, + # XML declaration encoding, and HTTP encoding, following the + # heuristic defined in RFC 3023. + http_content_type = http_headers.get('content-type') or '' + http_content_type, params = cgi.parse_header(http_content_type) + http_encoding = params.get('charset', '').replace("'", "") + if not isinstance(http_encoding, unicode): + http_encoding = http_encoding.decode('utf-8', 'ignore') + acceptable_content_type = 0 - application_content_types = ('application/xml', 'application/xml-dtd', 'application/xml-external-parsed-entity') - text_content_types = ('text/xml', 'text/xml-external-parsed-entity') + application_content_types = (u'application/xml', u'application/xml-dtd', + u'application/xml-external-parsed-entity') + text_content_types = (u'text/xml', u'text/xml-external-parsed-entity') if (http_content_type in application_content_types) or \ - (http_content_type.startswith('application/') and http_content_type.endswith('+xml')): + (http_content_type.startswith(u'application/') and + http_content_type.endswith(u'+xml')): acceptable_content_type = 1 - true_encoding = http_encoding or xml_encoding or 'utf-8' + rfc3023_encoding = http_encoding or xml_encoding or u'utf-8' elif (http_content_type in text_content_types) or \ - (http_content_type.startswith('text/')) and http_content_type.endswith('+xml'): + (http_content_type.startswith(u'text/') and + http_content_type.endswith(u'+xml')): acceptable_content_type = 1 - true_encoding = http_encoding or 'us-ascii' - elif http_content_type.startswith('text/'): - true_encoding = http_encoding or 'us-ascii' - elif http_headers and (not (http_headers.has_key('content-type') or http_headers.has_key('Content-type'))): - true_encoding = xml_encoding or 'iso-8859-1' - else: - true_encoding = xml_encoding or 'utf-8' - # some feeds claim to be gb2312 but are actually gb18030. - # apparently MSIE and Firefox both do the following switch: - if true_encoding.lower() == 'gb2312': - true_encoding = 'gb18030' - return true_encoding, http_encoding, xml_encoding, sniffed_xml_encoding, acceptable_content_type - -def _toUTF8(data, encoding): - '''Changes an XML data stream on the fly to specify a new encoding - - data is a raw sequence of bytes (not Unicode) that is presumed to be in %encoding already - encoding is a string recognized by encodings.aliases - ''' - if _debug: sys.stderr.write('entering _toUTF8, trying encoding %s\n' % encoding) - # strip Byte Order Mark (if present) - if (len(data) >= 4) and (data[:2] == _l2bytes([0xfe, 0xff])) and (data[2:4] != _l2bytes([0x00, 0x00])): - if _debug: - sys.stderr.write('stripping BOM\n') - if encoding != 'utf-16be': - sys.stderr.write('trying utf-16be instead\n') - encoding = 'utf-16be' - data = data[2:] - elif (len(data) >= 4) and (data[:2] == _l2bytes([0xff, 0xfe])) and (data[2:4] != _l2bytes([0x00, 0x00])): - if _debug: - sys.stderr.write('stripping BOM\n') - if encoding != 'utf-16le': - sys.stderr.write('trying utf-16le instead\n') - encoding = 'utf-16le' - data = data[2:] - elif data[:3] == _l2bytes([0xef, 0xbb, 0xbf]): - if _debug: - sys.stderr.write('stripping BOM\n') - if encoding != 'utf-8': - sys.stderr.write('trying utf-8 instead\n') - encoding = 'utf-8' - data = data[3:] - elif data[:4] == _l2bytes([0x00, 0x00, 0xfe, 0xff]): - if _debug: - sys.stderr.write('stripping BOM\n') - if encoding != 'utf-32be': - sys.stderr.write('trying utf-32be instead\n') - encoding = 'utf-32be' - data = data[4:] - elif data[:4] == _l2bytes([0xff, 0xfe, 0x00, 0x00]): - if _debug: - sys.stderr.write('stripping BOM\n') - if encoding != 'utf-32le': - sys.stderr.write('trying utf-32le instead\n') - encoding = 'utf-32le' - data = data[4:] - newdata = unicode(data, encoding) - if _debug: sys.stderr.write('successfully converted %s data to unicode\n' % encoding) - declmatch = re.compile('^<\?xml[^>]*?>') - newdecl = '''<?xml version='1.0' encoding='utf-8'?>''' - if declmatch.search(newdata): - newdata = declmatch.sub(newdecl, newdata) + rfc3023_encoding = http_encoding or u'us-ascii' + elif http_content_type.startswith(u'text/'): + rfc3023_encoding = http_encoding or u'us-ascii' + elif http_headers and 'content-type' not in http_headers: + rfc3023_encoding = xml_encoding or u'iso-8859-1' else: - newdata = newdecl + u'\n' + newdata - return newdata.encode('utf-8') + rfc3023_encoding = xml_encoding or u'utf-8' + # gb18030 is a superset of gb2312, so always replace gb2312 + # with gb18030 for greater compatibility. + if rfc3023_encoding.lower() == u'gb2312': + rfc3023_encoding = u'gb18030' + if xml_encoding.lower() == u'gb2312': + xml_encoding = u'gb18030' + + # there are four encodings to keep track of: + # - http_encoding is the encoding declared in the Content-Type HTTP header + # - xml_encoding is the encoding declared in the <?xml declaration + # - bom_encoding is the encoding sniffed from the first 4 bytes of the XML data + # - rfc3023_encoding is the actual encoding, as per RFC 3023 and a variety of other conflicting specifications + error = None -def _stripDoctype(data): - '''Strips DOCTYPE from XML document, returns (rss_version, stripped_data) + if http_headers and (not acceptable_content_type): + if 'content-type' in http_headers: + msg = '%s is not an XML media type' % http_headers['content-type'] + else: + msg = 'no Content-type specified' + error = NonXMLContentType(msg) + + # determine character encoding + known_encoding = 0 + chardet_encoding = None + tried_encodings = [] + if chardet: + chardet_encoding = unicode(chardet.detect(data)['encoding'] or '', 'ascii', 'ignore') + # try: HTTP encoding, declared XML encoding, encoding sniffed from BOM + for proposed_encoding in (rfc3023_encoding, xml_encoding, bom_encoding, + chardet_encoding, u'utf-8', u'windows-1252', u'iso-8859-2'): + if not proposed_encoding: + continue + if proposed_encoding in tried_encodings: + continue + tried_encodings.append(proposed_encoding) + try: + data = data.decode(proposed_encoding) + except (UnicodeDecodeError, LookupError): + pass + else: + known_encoding = 1 + # Update the encoding in the opening XML processing instruction. + new_declaration = '''<?xml version='1.0' encoding='utf-8'?>''' + if RE_XML_DECLARATION.search(data): + data = RE_XML_DECLARATION.sub(new_declaration, data) + else: + data = new_declaration + u'\n' + data + data = data.encode('utf-8') + break + # if still no luck, give up + if not known_encoding: + error = CharacterEncodingUnknown( + 'document encoding unknown, I tried ' + + '%s, %s, utf-8, windows-1252, and iso-8859-2 but nothing worked' % + (rfc3023_encoding, xml_encoding)) + rfc3023_encoding = u'' + elif proposed_encoding != rfc3023_encoding: + error = CharacterEncodingOverride( + 'document declared as %s, but parsed as %s' % + (rfc3023_encoding, proposed_encoding)) + rfc3023_encoding = proposed_encoding + + return data, rfc3023_encoding, error + +# Match XML entity declarations. +# Example: <!ENTITY copyright "(C)"> +RE_ENTITY_PATTERN = re.compile(_s2bytes(r'^\s*<!ENTITY([^>]*?)>'), re.MULTILINE) + +# Match XML DOCTYPE declarations. +# Example: <!DOCTYPE feed [ ]> +RE_DOCTYPE_PATTERN = re.compile(_s2bytes(r'^\s*<!DOCTYPE([^>]*?)>'), re.MULTILINE) + +# Match safe entity declarations. +# This will allow hexadecimal character references through, +# as well as text, but not arbitrary nested entities. +# Example: cubed "³" +# Example: copyright "(C)" +# Forbidden: explode1 "&explode2;&explode2;" +RE_SAFE_ENTITY_PATTERN = re.compile(_s2bytes('\s+(\w+)\s+"(&#\w+;|[^&"]*)"')) + +def replace_doctype(data): + '''Strips and replaces the DOCTYPE, returns (rss_version, stripped_data) rss_version may be 'rss091n' or None - stripped_data is the same XML document, minus the DOCTYPE + stripped_data is the same XML document with a replaced DOCTYPE ''' + + # Divide the document into two groups by finding the location + # of the first element that doesn't begin with '<?' or '<!'. start = re.search(_s2bytes('<\w'), data) start = start and start.start() or -1 - head,data = data[:start+1], data[start+1:] - - entity_pattern = re.compile(_s2bytes(r'^\s*<!ENTITY([^>]*?)>'), re.MULTILINE) - entity_results=entity_pattern.findall(head) - head = entity_pattern.sub(_s2bytes(''), head) - doctype_pattern = re.compile(_s2bytes(r'^\s*<!DOCTYPE([^>]*?)>'), re.MULTILINE) - doctype_results = doctype_pattern.findall(head) + head, data = data[:start+1], data[start+1:] + + # Save and then remove all of the ENTITY declarations. + entity_results = RE_ENTITY_PATTERN.findall(head) + head = RE_ENTITY_PATTERN.sub(_s2bytes(''), head) + + # Find the DOCTYPE declaration and check the feed type. + doctype_results = RE_DOCTYPE_PATTERN.findall(head) doctype = doctype_results and doctype_results[0] or _s2bytes('') - if doctype.lower().count(_s2bytes('netscape')): - version = 'rss091n' + if _s2bytes('netscape') in doctype.lower(): + version = u'rss091n' else: version = None - # only allow in 'safe' inline entity definitions - replacement=_s2bytes('') - if len(doctype_results)==1 and entity_results: - safe_pattern=re.compile(_s2bytes('\s+(\w+)\s+"(&#\w+;|[^&"]*)"')) - safe_entities=filter(lambda e: safe_pattern.match(e),entity_results) - if safe_entities: - replacement=_s2bytes('<!DOCTYPE feed [\n <!ENTITY') + _s2bytes('>\n <!ENTITY ').join(safe_entities) + _s2bytes('>\n]>') - data = doctype_pattern.sub(replacement, head) + data - - return version, data, dict(replacement and [(k.decode('utf-8'), v.decode('utf-8')) for k, v in safe_pattern.findall(replacement)]) - -def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, referrer=None, handlers=[], request_headers={}, response_headers={}): + # Re-insert the safe ENTITY declarations if a DOCTYPE was found. + replacement = _s2bytes('') + if len(doctype_results) == 1 and entity_results: + match_safe_entities = lambda e: RE_SAFE_ENTITY_PATTERN.match(e) + safe_entities = filter(match_safe_entities, entity_results) + if safe_entities: + replacement = _s2bytes('<!DOCTYPE feed [\n<!ENTITY') \ + + _s2bytes('>\n<!ENTITY ').join(safe_entities) \ + + _s2bytes('>\n]>') + data = RE_DOCTYPE_PATTERN.sub(replacement, head) + data + + # Precompute the safe entities for the loose parser. + safe_entities = dict((k.decode('utf-8'), v.decode('utf-8')) + for k, v in RE_SAFE_ENTITY_PATTERN.findall(replacement)) + return version, data, safe_entities + +def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, referrer=None, handlers=None, request_headers=None, response_headers=None): '''Parse a feed from a URL, file, stream, or string. - + request_headers, if given, is a dict from http header name to value to add to the request; this overrides internally generated values. ''' + + if handlers is None: + handlers = [] + if request_headers is None: + request_headers = {} + if response_headers is None: + response_headers = {} + result = FeedParserDict() result['feed'] = FeedParserDict() result['entries'] = [] - if _XML_AVAILABLE: - result['bozo'] = 0 + result['bozo'] = 0 if not isinstance(handlers, list): handlers = [handlers] try: @@ -3624,148 +3898,88 @@ def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, refer elif response_headers: result['headers'] = copy.deepcopy(response_headers) + # lowercase all of the HTTP headers for comparisons per RFC 2616 + if 'headers' in result: + http_headers = dict((k.lower(), v) for k, v in result['headers'].items()) + else: + http_headers = {} + # if feed is gzip-compressed, decompress it - if f and data and 'headers' in result: - if gzip and result['headers'].get('content-encoding') == 'gzip': + if f and data and http_headers: + if gzip and 'gzip' in http_headers.get('content-encoding', ''): try: data = gzip.GzipFile(fileobj=_StringIO(data)).read() - except Exception, e: - # Some feeds claim to be gzipped but they're not, so - # we get garbage. Ideally, we should re-request the - # feed without the 'Accept-encoding: gzip' header, - # but we don't. + except (IOError, struct.error), e: + # IOError can occur if the gzip header is bad. + # struct.error can occur if the data is damaged. result['bozo'] = 1 result['bozo_exception'] = e - data = '' - elif zlib and result['headers'].get('content-encoding') == 'deflate': + if isinstance(e, struct.error): + # A gzip header was found but the data is corrupt. + # Ideally, we should re-request the feed without the + # 'Accept-encoding: gzip' header, but we don't. + data = None + elif zlib and 'deflate' in http_headers.get('content-encoding', ''): try: - data = zlib.decompress(data, -zlib.MAX_WBITS) - except Exception, e: - result['bozo'] = 1 - result['bozo_exception'] = e - data = '' + data = zlib.decompress(data) + except zlib.error, e: + try: + # The data may have no headers and no checksum. + data = zlib.decompress(data, -15) + except zlib.error, e: + result['bozo'] = 1 + result['bozo_exception'] = e # save HTTP headers - if 'headers' in result: - if 'etag' in result['headers'] or 'ETag' in result['headers']: - etag = result['headers'].get('etag', result['headers'].get('ETag')) + if http_headers: + if 'etag' in http_headers: + etag = http_headers.get('etag', u'') + if not isinstance(etag, unicode): + etag = etag.decode('utf-8', 'ignore') if etag: result['etag'] = etag - if 'last-modified' in result['headers'] or 'Last-Modified' in result['headers']: - modified = result['headers'].get('last-modified', result['headers'].get('Last-Modified')) + if 'last-modified' in http_headers: + modified = http_headers.get('last-modified', u'') if modified: - result['modified'] = _parse_date(modified) + result['modified'] = modified + result['modified_parsed'] = _parse_date(modified) if hasattr(f, 'url'): - result['href'] = f.url + if not isinstance(f.url, unicode): + result['href'] = f.url.decode('utf-8', 'ignore') + else: + result['href'] = f.url result['status'] = 200 if hasattr(f, 'status'): result['status'] = f.status if hasattr(f, 'close'): f.close() - # there are four encodings to keep track of: - # - http_encoding is the encoding declared in the Content-Type HTTP header - # - xml_encoding is the encoding declared in the <?xml declaration - # - sniffed_encoding is the encoding sniffed from the first 4 bytes of the XML data - # - result['encoding'] is the actual encoding, as per RFC 3023 and a variety of other conflicting specifications - http_headers = result.get('headers', {}) - result['encoding'], http_encoding, xml_encoding, sniffed_xml_encoding, acceptable_content_type = \ - _getCharacterEncoding(http_headers, data) - if http_headers and (not acceptable_content_type): - if http_headers.has_key('content-type') or http_headers.has_key('Content-type'): - bozo_message = '%s is not an XML media type' % http_headers.get('content-type', http_headers.get('Content-type')) - else: - bozo_message = 'no Content-type specified' - result['bozo'] = 1 - result['bozo_exception'] = NonXMLContentType(bozo_message) - - if data is not None: - result['version'], data, entities = _stripDoctype(data) - - # ensure that baseuri is an absolute uri using an acceptable URI scheme - contentloc = http_headers.get('content-location', http_headers.get('Content-Location', '')) - href = result.get('href', '') - baseuri = _makeSafeAbsoluteURI(href, contentloc) or _makeSafeAbsoluteURI(contentloc) or href - - baselang = http_headers.get('content-language', http_headers.get('Content-Language', None)) + if data is None: + return result - # if server sent 304, we're done - if result.get('status', 0) == 304: - result['version'] = '' + # Stop processing if the server sent HTTP 304 Not Modified. + if getattr(f, 'code', 0) == 304: + result['version'] = u'' result['debug_message'] = 'The feed has not changed since you last checked, ' + \ 'so the server sent no data. This is a feature, not a bug!' return result - # if there was a problem downloading, we're done - if data is None: - return result - - # determine character encoding - use_strict_parser = 0 - known_encoding = 0 - tried_encodings = [] - # try: HTTP encoding, declared XML encoding, encoding sniffed from BOM - for proposed_encoding in (result['encoding'], xml_encoding, sniffed_xml_encoding): - if not proposed_encoding: continue - if proposed_encoding in tried_encodings: continue - tried_encodings.append(proposed_encoding) - try: - data = _toUTF8(data, proposed_encoding) - known_encoding = use_strict_parser = 1 - break - except: - pass - # if no luck and we have auto-detection library, try that - if (not known_encoding) and chardet: - try: - proposed_encoding = chardet.detect(data)['encoding'] - if proposed_encoding and (proposed_encoding not in tried_encodings): - tried_encodings.append(proposed_encoding) - data = _toUTF8(data, proposed_encoding) - known_encoding = use_strict_parser = 1 - except: - pass - # if still no luck and we haven't tried utf-8 yet, try that - if (not known_encoding) and ('utf-8' not in tried_encodings): - try: - proposed_encoding = 'utf-8' - tried_encodings.append(proposed_encoding) - data = _toUTF8(data, proposed_encoding) - known_encoding = use_strict_parser = 1 - except: - pass - # if still no luck and we haven't tried windows-1252 yet, try that - if (not known_encoding) and ('windows-1252' not in tried_encodings): - try: - proposed_encoding = 'windows-1252' - tried_encodings.append(proposed_encoding) - data = _toUTF8(data, proposed_encoding) - known_encoding = use_strict_parser = 1 - except: - pass - # if still no luck and we haven't tried iso-8859-2 yet, try that. - if (not known_encoding) and ('iso-8859-2' not in tried_encodings): - try: - proposed_encoding = 'iso-8859-2' - tried_encodings.append(proposed_encoding) - data = _toUTF8(data, proposed_encoding) - known_encoding = use_strict_parser = 1 - except: - pass - # if still no luck, give up - if not known_encoding: + data, result['encoding'], error = convert_to_utf8(http_headers, data) + use_strict_parser = result['encoding'] and True or False + if error is not None: result['bozo'] = 1 - result['bozo_exception'] = CharacterEncodingUnknown( \ - 'document encoding unknown, I tried ' + \ - '%s, %s, utf-8, windows-1252, and iso-8859-2 but nothing worked' % \ - (result['encoding'], xml_encoding)) - result['encoding'] = '' - elif proposed_encoding != result['encoding']: - result['bozo'] = 1 - result['bozo_exception'] = CharacterEncodingOverride( \ - 'document declared as %s, but parsed as %s' % \ - (result['encoding'], proposed_encoding)) - result['encoding'] = proposed_encoding + result['bozo_exception'] = error + + result['version'], data, entities = replace_doctype(data) + + # Ensure that baseuri is an absolute URI using an acceptable URI scheme. + contentloc = http_headers.get('content-location', u'') + href = result.get('href', u'') + baseuri = _makeSafeAbsoluteURI(href, contentloc) or _makeSafeAbsoluteURI(contentloc) or href + + baselang = http_headers.get('content-language', None) + if not isinstance(baselang, unicode) and baselang is not None: + baselang = baselang.decode('utf-8', 'ignore') if not _XML_AVAILABLE: use_strict_parser = 0 @@ -3774,26 +3988,22 @@ def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, refer feedparser = _StrictFeedParser(baseuri, baselang, 'utf-8') saxparser = xml.sax.make_parser(PREFERRED_XML_PARSERS) saxparser.setFeature(xml.sax.handler.feature_namespaces, 1) + try: + # disable downloading external doctype references, if possible + saxparser.setFeature(xml.sax.handler.feature_external_ges, 0) + except xml.sax.SAXNotSupportedException: + pass saxparser.setContentHandler(feedparser) saxparser.setErrorHandler(feedparser) source = xml.sax.xmlreader.InputSource() source.setByteStream(_StringIO(data)) - if hasattr(saxparser, '_ns_stack'): - # work around bug in built-in SAX parser (doesn't recognize xml: namespace) - # PyXML doesn't have this problem, and it doesn't have _ns_stack either - saxparser._ns_stack.append({'http://www.w3.org/XML/1998/namespace':'xml'}) try: saxparser.parse(source) - except Exception, e: - if _debug: - import traceback - traceback.print_stack() - traceback.print_exc() - sys.stderr.write('xml parsing failed\n') + except xml.sax.SAXException, e: result['bozo'] = 1 result['bozo_exception'] = feedparser.exc or e use_strict_parser = 0 - if not use_strict_parser: + if not use_strict_parser and _SGML_AVAILABLE: feedparser = _LooseFeedParser(baseuri, baselang, 'utf-8', entities) feedparser.feed(data.decode('utf-8', 'replace')) result['feed'] = feedparser.feeddata @@ -3801,85 +4011,3 @@ def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, refer result['version'] = result['version'] or feedparser.version result['namespaces'] = feedparser.namespacesInUse return result - -class Serializer: - def __init__(self, results): - self.results = results - -class TextSerializer(Serializer): - def write(self, stream=sys.stdout): - self._writer(stream, self.results, '') - - def _writer(self, stream, node, prefix): - if not node: return - if hasattr(node, 'keys'): - keys = node.keys() - keys.sort() - for k in keys: - if k in ('description', 'link'): continue - if node.has_key(k + '_detail'): continue - if node.has_key(k + '_parsed'): continue - self._writer(stream, node[k], prefix + k + '.') - elif type(node) == types.ListType: - index = 0 - for n in node: - self._writer(stream, n, prefix[:-1] + '[' + str(index) + '].') - index += 1 - else: - try: - s = str(node).encode('utf-8') - s = s.replace('\\', '\\\\') - s = s.replace('\r', '') - s = s.replace('\n', r'\n') - stream.write(prefix[:-1]) - stream.write('=') - stream.write(s) - stream.write('\n') - except: - pass - -class PprintSerializer(Serializer): - def write(self, stream=sys.stdout): - if self.results.has_key('href'): - stream.write(self.results['href'] + '\n\n') - from pprint import pprint - pprint(self.results, stream) - stream.write('\n') - -if __name__ == '__main__': - try: - from optparse import OptionParser - except: - OptionParser = None - - if OptionParser: - optionParser = OptionParser(version=__version__, usage="%prog [options] url_or_filename_or_-") - optionParser.set_defaults(format="pprint") - optionParser.add_option("-A", "--user-agent", dest="agent", metavar="AGENT", help="User-Agent for HTTP URLs") - optionParser.add_option("-e", "--referer", "--referrer", dest="referrer", metavar="URL", help="Referrer for HTTP URLs") - optionParser.add_option("-t", "--etag", dest="etag", metavar="TAG", help="ETag/If-None-Match for HTTP URLs") - optionParser.add_option("-m", "--last-modified", dest="modified", metavar="DATE", help="Last-modified/If-Modified-Since for HTTP URLs (any supported date format)") - optionParser.add_option("-f", "--format", dest="format", metavar="FORMAT", help="output results in FORMAT (text, pprint)") - optionParser.add_option("-v", "--verbose", action="store_true", dest="verbose", default=False, help="write debugging information to stderr") - (options, urls) = optionParser.parse_args() - if options.verbose: - _debug = 1 - if not urls: - optionParser.print_help() - sys.exit(0) - else: - if not sys.argv[1:]: - print __doc__ - sys.exit(0) - class _Options: - etag = modified = agent = referrer = None - format = 'pprint' - options = _Options() - urls = sys.argv[1:] - - zopeCompatibilityHack() - - serializer = globals().get(options.format.capitalize() + 'Serializer', Serializer) - for url in urls: - results = parse(url, etag=options.etag, modified=options.modified, agent=options.agent, referrer=options.referrer) - serializer(results).write(sys.stdout) diff --git a/module/lib/jinja2/__init__.py b/module/lib/jinja2/__init__.py index f944e11b6..a4f7e9c4e 100644 --- a/module/lib/jinja2/__init__.py +++ b/module/lib/jinja2/__init__.py @@ -27,11 +27,7 @@ :license: BSD, see LICENSE for more details. """ __docformat__ = 'restructuredtext en' -try: - __version__ = __import__('pkg_resources') \ - .get_distribution('Jinja2').version -except: - __version__ = 'unknown' +__version__ = '2.7.3' # high level interface from jinja2.environment import Environment, Template diff --git a/module/lib/jinja2/_compat.py b/module/lib/jinja2/_compat.py new file mode 100644 index 000000000..8fa8a49a0 --- /dev/null +++ b/module/lib/jinja2/_compat.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +""" + jinja2._compat + ~~~~~~~~~~~~~~ + + Some py2/py3 compatibility support based on a stripped down + version of six so we don't have to depend on a specific version + of it. + + :copyright: Copyright 2013 by the Jinja team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" +import sys + +PY2 = sys.version_info[0] == 2 +PYPY = hasattr(sys, 'pypy_translation_info') +_identity = lambda x: x + + +if not PY2: + unichr = chr + range_type = range + text_type = str + string_types = (str,) + + iterkeys = lambda d: iter(d.keys()) + itervalues = lambda d: iter(d.values()) + iteritems = lambda d: iter(d.items()) + + import pickle + from io import BytesIO, StringIO + NativeStringIO = StringIO + + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + + ifilter = filter + imap = map + izip = zip + intern = sys.intern + + implements_iterator = _identity + implements_to_string = _identity + encode_filename = _identity + get_next = lambda x: x.__next__ + +else: + unichr = unichr + text_type = unicode + range_type = xrange + string_types = (str, unicode) + + iterkeys = lambda d: d.iterkeys() + itervalues = lambda d: d.itervalues() + iteritems = lambda d: d.iteritems() + + import cPickle as pickle + from cStringIO import StringIO as BytesIO, StringIO + NativeStringIO = BytesIO + + exec('def reraise(tp, value, tb=None):\n raise tp, value, tb') + + from itertools import imap, izip, ifilter + intern = intern + + def implements_iterator(cls): + cls.next = cls.__next__ + del cls.__next__ + return cls + + def implements_to_string(cls): + cls.__unicode__ = cls.__str__ + cls.__str__ = lambda x: x.__unicode__().encode('utf-8') + return cls + + get_next = lambda x: x.next + + def encode_filename(filename): + if isinstance(filename, unicode): + return filename.encode('utf-8') + return filename + +try: + next = next +except NameError: + def next(it): + return it.next() + + +def with_metaclass(meta, *bases): + # This requires a bit of explanation: the basic idea is to make a + # dummy metaclass for one level of class instanciation that replaces + # itself with the actual metaclass. Because of internal type checks + # we also need to make sure that we downgrade the custom metaclass + # for one level to something closer to type (that's why __call__ and + # __init__ comes back from type etc.). + # + # This has the advantage over six.with_metaclass in that it does not + # introduce dummy classes into the final MRO. + class metaclass(meta): + __call__ = type.__call__ + __init__ = type.__init__ + def __new__(cls, name, this_bases, d): + if this_bases is None: + return type.__new__(cls, name, (), d) + return meta(name, bases, d) + return metaclass('temporary_class', None, {}) + + +try: + from collections import Mapping as mapping_types +except ImportError: + import UserDict + mapping_types = (UserDict.UserDict, UserDict.DictMixin, dict) + + +# common types. These do exist in the special types module too which however +# does not exist in IronPython out of the box. Also that way we don't have +# to deal with implementation specific stuff here +class _C(object): + def method(self): pass +def _func(): + yield None +function_type = type(_func) +generator_type = type(_func()) +method_type = type(_C().method) +code_type = type(_C.method.__code__) +try: + raise TypeError() +except TypeError: + _tb = sys.exc_info()[2] + traceback_type = type(_tb) + frame_type = type(_tb.tb_frame) + + +try: + from urllib.parse import quote_from_bytes as url_quote +except ImportError: + from urllib import quote as url_quote + + +try: + from thread import allocate_lock +except ImportError: + try: + from threading import Lock as allocate_lock + except ImportError: + from dummy_thread import allocate_lock diff --git a/module/lib/jinja2/_markupsafe/_bundle.py b/module/lib/jinja2/_markupsafe/_bundle.py deleted file mode 100644 index e694faf23..000000000 --- a/module/lib/jinja2/_markupsafe/_bundle.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -""" - jinja2._markupsafe._bundle - ~~~~~~~~~~~~~~~~~~~~~~~~~~ - - This script pulls in markupsafe from a source folder and - bundles it with Jinja2. It does not pull in the speedups - module though. - - :copyright: Copyright 2010 by the Jinja team, see AUTHORS. - :license: BSD, see LICENSE for details. -""" -import sys -import os -import re - - -def rewrite_imports(lines): - for idx, line in enumerate(lines): - new_line = re.sub(r'(import|from)\s+markupsafe\b', - r'\1 jinja2._markupsafe', line) - if new_line != line: - lines[idx] = new_line - - -def main(): - if len(sys.argv) != 2: - print 'error: only argument is path to markupsafe' - sys.exit(1) - basedir = os.path.dirname(__file__) - markupdir = sys.argv[1] - for filename in os.listdir(markupdir): - if filename.endswith('.py'): - f = open(os.path.join(markupdir, filename)) - try: - lines = list(f) - finally: - f.close() - rewrite_imports(lines) - f = open(os.path.join(basedir, filename), 'w') - try: - for line in lines: - f.write(line) - finally: - f.close() - - -if __name__ == '__main__': - main() diff --git a/module/lib/jinja2/_markupsafe/tests.py b/module/lib/jinja2/_markupsafe/tests.py deleted file mode 100644 index c1ce3943a..000000000 --- a/module/lib/jinja2/_markupsafe/tests.py +++ /dev/null @@ -1,80 +0,0 @@ -import gc -import unittest -from jinja2._markupsafe import Markup, escape, escape_silent - - -class MarkupTestCase(unittest.TestCase): - - def test_markup_operations(self): - # adding two strings should escape the unsafe one - unsafe = '<script type="application/x-some-script">alert("foo");</script>' - safe = Markup('<em>username</em>') - assert unsafe + safe == unicode(escape(unsafe)) + unicode(safe) - - # string interpolations are safe to use too - assert Markup('<em>%s</em>') % '<bad user>' == \ - '<em><bad user></em>' - assert Markup('<em>%(username)s</em>') % { - 'username': '<bad user>' - } == '<em><bad user></em>' - - # an escaped object is markup too - assert type(Markup('foo') + 'bar') is Markup - - # and it implements __html__ by returning itself - x = Markup("foo") - assert x.__html__() is x - - # it also knows how to treat __html__ objects - class Foo(object): - def __html__(self): - return '<em>awesome</em>' - def __unicode__(self): - return 'awesome' - assert Markup(Foo()) == '<em>awesome</em>' - assert Markup('<strong>%s</strong>') % Foo() == \ - '<strong><em>awesome</em></strong>' - - # escaping and unescaping - assert escape('"<>&\'') == '"<>&'' - assert Markup("<em>Foo & Bar</em>").striptags() == "Foo & Bar" - assert Markup("<test>").unescape() == "<test>" - - def test_all_set(self): - import jinja2._markupsafe as markup - for item in markup.__all__: - getattr(markup, item) - - def test_escape_silent(self): - assert escape_silent(None) == Markup() - assert escape(None) == Markup(None) - assert escape_silent('<foo>') == Markup(u'<foo>') - - -class MarkupLeakTestCase(unittest.TestCase): - - def test_markup_leaks(self): - counts = set() - for count in xrange(20): - for item in xrange(1000): - escape("foo") - escape("<foo>") - escape(u"foo") - escape(u"<foo>") - counts.add(len(gc.get_objects())) - assert len(counts) == 1, 'ouch, c extension seems to leak objects' - - -def suite(): - suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(MarkupTestCase)) - - # this test only tests the c extension - if not hasattr(escape, 'func_code'): - suite.addTest(unittest.makeSuite(MarkupLeakTestCase)) - - return suite - - -if __name__ == '__main__': - unittest.main(defaultTest='suite') diff --git a/module/lib/jinja2/_stringdefs.py b/module/lib/jinja2/_stringdefs.py index 1161b7f4a..da5830e9f 100644 --- a/module/lib/jinja2/_stringdefs.py +++ b/module/lib/jinja2/_stringdefs.py @@ -13,6 +13,8 @@ :license: BSD, see LICENSE for details. """ +from jinja2._compat import unichr + Cc = u'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f' Cf = u'\xad\u0600\u0601\u0602\u0603\u06dd\u070f\u17b4\u17b5\u200b\u200c\u200d\u200e\u200f\u202a\u202b\u202c\u202d\u202e\u2060\u2061\u2062\u2063\u206a\u206b\u206c\u206d\u206e\u206f\ufeff\ufff9\ufffa\ufffb' diff --git a/module/lib/jinja2/bccache.py b/module/lib/jinja2/bccache.py index 1e2236c3a..2d28ab8b2 100644 --- a/module/lib/jinja2/bccache.py +++ b/module/lib/jinja2/bccache.py @@ -15,20 +15,46 @@ :license: BSD. """ from os import path, listdir +import os +import stat +import sys +import errno import marshal import tempfile -import cPickle as pickle import fnmatch -from cStringIO import StringIO -try: - from hashlib import sha1 -except ImportError: - from sha import new as sha1 +from hashlib import sha1 from jinja2.utils import open_if_exists +from jinja2._compat import BytesIO, pickle, PY2, text_type -bc_version = 1 -bc_magic = 'j2'.encode('ascii') + pickle.dumps(bc_version, 2) +# marshal works better on 3.x, one hack less required +if not PY2: + marshal_dump = marshal.dump + marshal_load = marshal.load +else: + + def marshal_dump(code, f): + if isinstance(f, file): + marshal.dump(code, f) + else: + f.write(marshal.dumps(code)) + + def marshal_load(f): + if isinstance(f, file): + return marshal.load(f) + return marshal.loads(f.read()) + + +bc_version = 2 + +# magic version used to only change with new jinja versions. With 2.6 +# we change this to also take Python version changes into account. The +# reason for this is that Python tends to segfault if fed earlier bytecode +# versions because someone thought it would be a good idea to reuse opcodes +# or make Python incompatible with earlier versions. +bc_magic = 'j2'.encode('ascii') + \ + pickle.dumps(bc_version, 2) + \ + pickle.dumps((sys.version_info[0] << 24) | sys.version_info[1]) class Bucket(object): @@ -62,12 +88,7 @@ class Bucket(object): if self.checksum != checksum: self.reset() return - # now load the code. Because marshal is not able to load - # from arbitrary streams we have to work around that - if isinstance(f, file): - self.code = marshal.load(f) - else: - self.code = marshal.loads(f.read()) + self.code = marshal_load(f) def write_bytecode(self, f): """Dump the bytecode into the file or file like object passed.""" @@ -75,18 +96,15 @@ class Bucket(object): raise TypeError('can\'t write empty bucket') f.write(bc_magic) pickle.dump(self.checksum, f, 2) - if isinstance(f, file): - marshal.dump(self.code, f) - else: - f.write(marshal.dumps(self.code)) + marshal_dump(self.code, f) def bytecode_from_string(self, string): """Load bytecode from a string.""" - self.load_bytecode(StringIO(string)) + self.load_bytecode(BytesIO(string)) def bytecode_to_string(self): """Return the bytecode as string.""" - out = StringIO() + out = BytesIO() self.write_bytecode(out) return out.getvalue() @@ -144,9 +162,10 @@ class BytecodeCache(object): """Returns the unique hash key for this template name.""" hash = sha1(name.encode('utf-8')) if filename is not None: - if isinstance(filename, unicode): + filename = '|' + filename + if isinstance(filename, text_type): filename = filename.encode('utf-8') - hash.update('|' + filename) + hash.update(filename) return hash.hexdigest() def get_source_checksum(self, source): @@ -173,7 +192,9 @@ class FileSystemBytecodeCache(BytecodeCache): two arguments: The directory where the cache items are stored and a pattern string that is used to build the filename. - If no directory is specified the system temporary items folder is used. + If no directory is specified a default cache directory is selected. On + Windows the user's temp directory is used, on UNIX systems a directory + is created for the user in the system temp directory. The pattern can be used to have multiple separate caches operate on the same directory. The default pattern is ``'__jinja2_%s.cache'``. ``%s`` @@ -186,10 +207,38 @@ class FileSystemBytecodeCache(BytecodeCache): def __init__(self, directory=None, pattern='__jinja2_%s.cache'): if directory is None: - directory = tempfile.gettempdir() + directory = self._get_default_cache_dir() self.directory = directory self.pattern = pattern + def _get_default_cache_dir(self): + tmpdir = tempfile.gettempdir() + + # On windows the temporary directory is used specific unless + # explicitly forced otherwise. We can just use that. + if os.name == 'nt': + return tmpdir + if not hasattr(os, 'getuid'): + raise RuntimeError('Cannot determine safe temp directory. You ' + 'need to explicitly provide one.') + + dirname = '_jinja2-cache-%d' % os.getuid() + actual_dir = os.path.join(tmpdir, dirname) + try: + os.mkdir(actual_dir, stat.S_IRWXU) # 0o700 + except OSError as e: + if e.errno != errno.EEXIST: + raise + + actual_dir_stat = os.lstat(actual_dir) + if actual_dir_stat.st_uid != os.getuid() \ + or not stat.S_ISDIR(actual_dir_stat.st_mode) \ + or stat.S_IMODE(actual_dir_stat.st_mode) != stat.S_IRWXU: + raise RuntimeError('Temporary directory \'%s\' has an incorrect ' + 'owner, permissions, or type.' % actual_dir) + + return actual_dir + def _get_cache_filename(self, bucket): return path.join(self.directory, self.pattern % bucket.key) @@ -261,15 +310,26 @@ class MemcachedBytecodeCache(BytecodeCache): This bytecode cache does not support clearing of used items in the cache. The clear method is a no-operation function. + + .. versionadded:: 2.7 + Added support for ignoring memcache errors through the + `ignore_memcache_errors` parameter. """ - def __init__(self, client, prefix='jinja2/bytecode/', timeout=None): + def __init__(self, client, prefix='jinja2/bytecode/', timeout=None, + ignore_memcache_errors=True): self.client = client self.prefix = prefix self.timeout = timeout + self.ignore_memcache_errors = ignore_memcache_errors def load_bytecode(self, bucket): - code = self.client.get(self.prefix + bucket.key) + try: + code = self.client.get(self.prefix + bucket.key) + except Exception: + if not self.ignore_memcache_errors: + raise + code = None if code is not None: bucket.bytecode_from_string(code) @@ -277,4 +337,8 @@ class MemcachedBytecodeCache(BytecodeCache): args = (self.prefix + bucket.key, bucket.bytecode_to_string()) if self.timeout is not None: args += (self.timeout,) - self.client.set(*args) + try: + self.client.set(*args) + except Exception: + if not self.ignore_memcache_errors: + raise diff --git a/module/lib/jinja2/compiler.py b/module/lib/jinja2/compiler.py index 57641596a..75a60b8d2 100644 --- a/module/lib/jinja2/compiler.py +++ b/module/lib/jinja2/compiler.py @@ -8,14 +8,16 @@ :copyright: (c) 2010 by the Jinja Team. :license: BSD, see LICENSE for more details. """ -from cStringIO import StringIO from itertools import chain from copy import deepcopy +from keyword import iskeyword as is_python_keyword from jinja2 import nodes from jinja2.nodes import EvalContext -from jinja2.visitor import NodeVisitor, NodeTransformer +from jinja2.visitor import NodeVisitor from jinja2.exceptions import TemplateAssertionError -from jinja2.utils import Markup, concat, escape, is_python_keyword, next +from jinja2.utils import Markup, concat, escape +from jinja2._compat import range_type, next, text_type, string_types, \ + iteritems, NativeStringIO, imap operators = { @@ -29,14 +31,6 @@ operators = { 'notin': 'not in' } -try: - exec '(0 if 0 else 0)' -except SyntaxError: - have_condexpr = False -else: - have_condexpr = True - - # what method to iterate over items do we want to use for dict iteration # in generated code? on 2.x let's go with iteritems, on 3.x with items if hasattr(dict, 'iteritems'): @@ -51,7 +45,11 @@ def unoptimize_before_dead_code(): def f(): if 0: dummy(x) return f -unoptimize_before_dead_code = bool(unoptimize_before_dead_code().func_closure) + +# The getattr is necessary for pypy which does not set this attribute if +# no closure is on the function +unoptimize_before_dead_code = bool( + getattr(unoptimize_before_dead_code(), '__closure__', None)) def generate(node, environment, name, filename, stream=None, @@ -69,8 +67,8 @@ def has_safe_repr(value): """Does the node have a safe representation?""" if value is None or value is NotImplemented or value is Ellipsis: return True - if isinstance(value, (bool, int, long, float, complex, basestring, - xrange, Markup)): + if isinstance(value, (bool, int, float, complex, range_type, + Markup) + string_types): return True if isinstance(value, (tuple, list, set, frozenset)): for item in value: @@ -78,7 +76,7 @@ def has_safe_repr(value): return False return True elif isinstance(value, dict): - for key, value in value.iteritems(): + for key, value in iteritems(value): if not has_safe_repr(key): return False if not has_safe_repr(value): @@ -127,12 +125,10 @@ class Identifiers(object): self.undeclared.discard(name) self.declared.add(name) - def is_declared(self, name, local_only=False): + def is_declared(self, name): """Check if a name is declared in this or an outer scope.""" if name in self.declared_locally or name in self.declared_parameter: return True - if local_only: - return False return name in self.declared def copy(self): @@ -193,12 +189,12 @@ class Frame(object): rv.identifiers.__dict__.update(self.identifiers.__dict__) return rv - def inspect(self, nodes, hard_scope=False): + def inspect(self, nodes): """Walk the node and check for identifiers. If the scope is hard (eg: enforce on a python level) overrides from outer scopes are tracked differently. """ - visitor = FrameIdentifierVisitor(self.identifiers, hard_scope) + visitor = FrameIdentifierVisitor(self.identifiers) for node in nodes: visitor.visit(node) @@ -275,9 +271,8 @@ class UndeclaredNameVisitor(NodeVisitor): class FrameIdentifierVisitor(NodeVisitor): """A visitor for `Frame.inspect`.""" - def __init__(self, identifiers, hard_scope): + def __init__(self, identifiers): self.identifiers = identifiers - self.hard_scope = hard_scope def visit_Name(self, node): """All assignments to names go through this function.""" @@ -286,7 +281,7 @@ class FrameIdentifierVisitor(NodeVisitor): elif node.ctx == 'param': self.identifiers.declared_parameter.add(node.name) elif node.ctx == 'load' and not \ - self.identifiers.is_declared(node.name, self.hard_scope): + self.identifiers.is_declared(node.name): self.identifiers.undeclared.add(node.name) def visit_If(self, node): @@ -371,7 +366,7 @@ class CodeGenerator(NodeVisitor): def __init__(self, environment, name, filename, stream=None, defer_init=False): if stream is None: - stream = StringIO() + stream = NativeStringIO() self.environment = environment self.name = name self.filename = filename @@ -545,7 +540,7 @@ class CodeGenerator(NodeVisitor): self.write(', ') self.visit(kwarg, frame) if extra_kwargs is not None: - for key, value in extra_kwargs.iteritems(): + for key, value in iteritems(extra_kwargs): self.write(', %s=%s' % (key, value)) if node.dyn_args: self.write(', *') @@ -561,7 +556,7 @@ class CodeGenerator(NodeVisitor): self.visit(kwarg.value, frame) self.write(', ') if extra_kwargs is not None: - for key, value in extra_kwargs.iteritems(): + for key, value in iteritems(extra_kwargs): self.write('%r: %s, ' % (key, value)) if node.dyn_kwargs is not None: self.write('}, **') @@ -628,7 +623,7 @@ class CodeGenerator(NodeVisitor): def pop_scope(self, aliases, frame): """Restore all aliases and delete unused variables.""" - for name, alias in aliases.iteritems(): + for name, alias in iteritems(aliases): self.writeline('l_%s = %s' % (name, alias)) to_delete = set() for name in frame.identifiers.declared_locally: @@ -658,7 +653,7 @@ class CodeGenerator(NodeVisitor): children = node.iter_child_nodes() children = list(children) func_frame = frame.inner() - func_frame.inspect(children, hard_scope=True) + func_frame.inspect(children) # variables that are undeclared (accessed before declaration) and # declared locally *and* part of an outside scope raise a template @@ -666,16 +661,16 @@ class CodeGenerator(NodeVisitor): # it without aliasing all the variables. # this could be fixed in Python 3 where we have the nonlocal # keyword or if we switch to bytecode generation - overriden_closure_vars = ( + overridden_closure_vars = ( func_frame.identifiers.undeclared & func_frame.identifiers.declared & (func_frame.identifiers.declared_locally | func_frame.identifiers.declared_parameter) ) - if overriden_closure_vars: + if overridden_closure_vars: self.fail('It\'s not possible to set and access variables ' 'derived from an outer scope! (affects: %s)' % - ', '.join(sorted(overriden_closure_vars)), node.lineno) + ', '.join(sorted(overridden_closure_vars)), node.lineno) # remove variables from a closure from the frame's undeclared # identifiers. @@ -830,7 +825,7 @@ class CodeGenerator(NodeVisitor): self.outdent(2 + (not self.has_known_extends)) # at this point we now have the blocks collected and can visit them too. - for name, block in self.blocks.iteritems(): + for name, block in iteritems(self.blocks): block_frame = Frame(eval_ctx) block_frame.inspect(block.body) block_frame.block = name @@ -897,12 +892,13 @@ class CodeGenerator(NodeVisitor): self.indent() self.writeline('raise TemplateRuntimeError(%r)' % 'extended multiple times') - self.outdent() # if we have a known extends already we don't need that code here # as we know that the template execution will end here. if self.has_known_extends: raise CompilerExit() + else: + self.outdent() self.writeline('parent_template = environment.get_template(', node) self.visit(node.template, frame) @@ -933,7 +929,7 @@ class CodeGenerator(NodeVisitor): func_name = 'get_or_select_template' if isinstance(node.template, nodes.Const): - if isinstance(node.template.value, basestring): + if isinstance(node.template.value, string_types): func_name = 'get_template' elif isinstance(node.template.value, (tuple, list)): func_name = 'select_template' @@ -1035,7 +1031,7 @@ class CodeGenerator(NodeVisitor): discarded_names[0]) else: self.writeline('context.exported_vars.difference_' - 'update((%s))' % ', '.join(map(repr, discarded_names))) + 'update((%s))' % ', '.join(imap(repr, discarded_names))) def visit_For(self, node, frame): # when calculating the nodes for the inner frame we have to exclude @@ -1063,7 +1059,7 @@ class CodeGenerator(NodeVisitor): # otherwise we set up a buffer and add a function def else: - self.writeline('def loop(reciter, loop_render_func):', node) + self.writeline('def loop(reciter, loop_render_func, depth=0):', node) self.indent() self.buffer(loop_frame) aliases = {} @@ -1071,6 +1067,7 @@ class CodeGenerator(NodeVisitor): # make sure the loop variable is a special one and raise a template # assertion error if a loop tries to write to loop if extended_loop: + self.writeline('l_loop = missing') loop_frame.identifiers.add_special('loop') for name in node.find_all(nodes.Name): if name.ctx == 'store' and name.name == 'loop': @@ -1089,7 +1086,7 @@ class CodeGenerator(NodeVisitor): node.iter_child_nodes(only=('else_', 'test')), ('loop',)): self.writeline("l_loop = environment.undefined(%r, name='loop')" % ("'loop' is undefined. the filter section of a loop as well " - "as the else block doesn't have access to the special 'loop'" + "as the else block don't have access to the special 'loop'" " variable of the current loop. Because there is no parent " "loop it's undefined. Happened in loop on %s" % self.position(node))) @@ -1121,7 +1118,7 @@ class CodeGenerator(NodeVisitor): self.visit(node.iter, loop_frame) if node.recursive: - self.write(', recurse=loop_render_func):') + self.write(', loop_render_func, depth):') else: self.write(extended_loop and '):' or ':') @@ -1219,9 +1216,9 @@ class CodeGenerator(NodeVisitor): return if self.environment.finalize: - finalize = lambda x: unicode(self.environment.finalize(x)) + finalize = lambda x: text_type(self.environment.finalize(x)) else: - finalize = unicode + finalize = text_type # if we are inside a frame that requires output checking, we do so outdent_later = False @@ -1249,7 +1246,7 @@ class CodeGenerator(NodeVisitor): else: const = escape(const) const = finalize(const) - except: + except Exception: # if something goes wrong here we evaluate the node # at runtime for easier debugging body.append(child) @@ -1370,7 +1367,7 @@ class CodeGenerator(NodeVisitor): public_names[0]) else: self.writeline('context.exported_vars.update((%s))' % - ', '.join(map(repr, public_names))) + ', '.join(imap(repr, public_names))) # -- Expression Visitors @@ -1421,19 +1418,31 @@ class CodeGenerator(NodeVisitor): self.visit(item.value, frame) self.write('}') - def binop(operator): + def binop(operator, interceptable=True): def visitor(self, node, frame): - self.write('(') - self.visit(node.left, frame) - self.write(' %s ' % operator) - self.visit(node.right, frame) + if self.environment.sandboxed and \ + operator in self.environment.intercepted_binops: + self.write('environment.call_binop(context, %r, ' % operator) + self.visit(node.left, frame) + self.write(', ') + self.visit(node.right, frame) + else: + self.write('(') + self.visit(node.left, frame) + self.write(' %s ' % operator) + self.visit(node.right, frame) self.write(')') return visitor - def uaop(operator): + def uaop(operator, interceptable=True): def visitor(self, node, frame): - self.write('(' + operator) - self.visit(node.node, frame) + if self.environment.sandboxed and \ + operator in self.environment.intercepted_unops: + self.write('environment.call_unop(context, %r, ' % operator) + self.visit(node.node, frame) + else: + self.write('(' + operator) + self.visit(node.node, frame) self.write(')') return visitor @@ -1444,11 +1453,11 @@ class CodeGenerator(NodeVisitor): visit_FloorDiv = binop('//') visit_Pow = binop('**') visit_Mod = binop('%') - visit_And = binop('and') - visit_Or = binop('or') + visit_And = binop('and', interceptable=False) + visit_Or = binop('or', interceptable=False) visit_Pos = uaop('+') visit_Neg = uaop('-') - visit_Not = uaop('not ') + visit_Not = uaop('not ', interceptable=False) del binop, uaop def visit_Concat(self, node, frame): @@ -1546,22 +1555,13 @@ class CodeGenerator(NodeVisitor): 'expression on %s evaluated to false and ' 'no else section was defined.' % self.position(node))) - if not have_condexpr: - self.write('((') - self.visit(node.test, frame) - self.write(') and (') - self.visit(node.expr1, frame) - self.write(',) or (') - write_expr2() - self.write(',))[0]') - else: - self.write('(') - self.visit(node.expr1, frame) - self.write(' if ') - self.visit(node.test, frame) - self.write(' else ') - write_expr2() - self.write(')') + self.write('(') + self.visit(node.expr1, frame) + self.write(' if ') + self.visit(node.test, frame) + self.write(' else ') + write_expr2() + self.write(')') def visit_Call(self, node, frame, forward_caller=False): if self.environment.sandboxed: diff --git a/module/lib/jinja2/debug.py b/module/lib/jinja2/debug.py index eb15456d1..815cc18a4 100644 --- a/module/lib/jinja2/debug.py +++ b/module/lib/jinja2/debug.py @@ -12,13 +12,21 @@ """ import sys import traceback -from jinja2.utils import CodeType, missing, internal_code +from types import TracebackType +from jinja2.utils import missing, internal_code from jinja2.exceptions import TemplateSyntaxError +from jinja2._compat import iteritems, reraise, code_type + +# on pypy we can take advantage of transparent proxies +try: + from __pypy__ import tproxy +except ImportError: + tproxy = None # how does the raise helper look like? try: - exec "raise TypeError, 'foo'" + exec("raise TypeError, 'foo'") except SyntaxError: raise_helper = 'raise __jinja_exception__[1]' except TypeError: @@ -30,17 +38,22 @@ class TracebackFrameProxy(object): def __init__(self, tb): self.tb = tb + self._tb_next = None - def _set_tb_next(self, next): - if tb_set_next is not None: - tb_set_next(self.tb, next and next.tb or None) - self._tb_next = next - - def _get_tb_next(self): + @property + def tb_next(self): return self._tb_next - tb_next = property(_get_tb_next, _set_tb_next) - del _get_tb_next, _set_tb_next + def set_next(self, next): + if tb_set_next is not None: + try: + tb_set_next(self.tb, next and next.tb or None) + except Exception: + # this function can fail due to all the hackery it does + # on various python implementations. We just catch errors + # down and ignore them if necessary. + pass + self._tb_next = next @property def is_jinja_frame(self): @@ -50,8 +63,22 @@ class TracebackFrameProxy(object): return getattr(self.tb, name) +def make_frame_proxy(frame): + proxy = TracebackFrameProxy(frame) + if tproxy is None: + return proxy + def operation_handler(operation, *args, **kwargs): + if operation in ('__getattribute__', '__getattr__'): + return getattr(proxy, args[0]) + elif operation == '__setattr__': + proxy.__setattr__(*args, **kwargs) + else: + return getattr(proxy, operation)(*args, **kwargs) + return tproxy(TracebackType, operation_handler) + + class ProcessedTraceback(object): - """Holds a Jinja preprocessed traceback for priting or reraising.""" + """Holds a Jinja preprocessed traceback for printing or reraising.""" def __init__(self, exc_type, exc_value, frames): assert frames, 'no frames for this traceback?' @@ -59,14 +86,13 @@ class ProcessedTraceback(object): self.exc_value = exc_value self.frames = frames - def chain_frames(self): - """Chains the frames. Requires ctypes or the debugsupport extension.""" + # newly concatenate the frames (which are proxies) prev_tb = None for tb in self.frames: if prev_tb is not None: - prev_tb.tb_next = tb + prev_tb.set_next(tb) prev_tb = tb - prev_tb.tb_next = None + prev_tb.set_next(None) def render_as_text(self, limit=None): """Return a string with the traceback.""" @@ -95,7 +121,12 @@ class ProcessedTraceback(object): @property def standard_exc_info(self): """Standard python exc_info for re-raising""" - return self.exc_type, self.exc_value, self.frames[0].tb + tb = self.frames[0] + # the frame will be an actual traceback (or transparent proxy) if + # we are on pypy or a python implementation with support for tproxy + if type(tb) is not TracebackType: + tb = tb.tb + return self.exc_type, self.exc_value, tb def make_traceback(exc_info, source_hint=None): @@ -128,7 +159,7 @@ def translate_exception(exc_info, initial_skip=0): frames = [] # skip some internal frames if wanted - for x in xrange(initial_skip): + for x in range(initial_skip): if tb is not None: tb = tb.tb_next initial_tb = tb @@ -152,19 +183,16 @@ def translate_exception(exc_info, initial_skip=0): tb = fake_exc_info(exc_info[:2] + (tb,), template.filename, lineno)[2] - frames.append(TracebackFrameProxy(tb)) + frames.append(make_frame_proxy(tb)) tb = next # if we don't have any exceptions in the frames left, we have to # reraise it unchanged. # XXX: can we backup here? when could this happen? if not frames: - raise exc_info[0], exc_info[1], exc_info[2] + reraise(exc_info[0], exc_info[1], exc_info[2]) - traceback = ProcessedTraceback(exc_info[0], exc_info[1], frames) - if tb_set_next is not None: - traceback.chain_frames() - return traceback + return ProcessedTraceback(exc_info[0], exc_info[1], frames) def fake_exc_info(exc_info, filename, lineno): @@ -179,7 +207,7 @@ def fake_exc_info(exc_info, filename, lineno): locals = ctx.get_all() else: locals = {} - for name, value in real_locals.iteritems(): + for name, value in iteritems(real_locals): if name.startswith('l_') and value is not missing: locals[name[2:]] = value @@ -217,17 +245,17 @@ def fake_exc_info(exc_info, filename, lineno): location = 'block "%s"' % function[6:] else: location = 'template' - code = CodeType(0, code.co_nlocals, code.co_stacksize, - code.co_flags, code.co_code, code.co_consts, - code.co_names, code.co_varnames, filename, - location, code.co_firstlineno, - code.co_lnotab, (), ()) + code = code_type(0, code.co_nlocals, code.co_stacksize, + code.co_flags, code.co_code, code.co_consts, + code.co_names, code.co_varnames, filename, + location, code.co_firstlineno, + code.co_lnotab, (), ()) except: pass # execute the code and catch the new traceback try: - exec code in globals, locals + exec(code, globals, locals) except: exc_info = sys.exc_info() new_tb = exc_info[2].tb_next @@ -239,7 +267,8 @@ def fake_exc_info(exc_info, filename, lineno): def _init_ugly_crap(): """This function implements a few ugly things so that we can patch the traceback objects. The function returned allows resetting `tb_next` on - any python traceback object. + any python traceback object. Do not attempt to use this on non cpython + interpreters """ import ctypes from types import TracebackType @@ -297,12 +326,12 @@ def _init_ugly_crap(): return tb_set_next -# try to get a tb_set_next implementation -try: - from jinja2._debugsupport import tb_set_next -except ImportError: +# try to get a tb_set_next implementation if we don't have transparent +# proxies. +tb_set_next = None +if tproxy is None: try: tb_set_next = _init_ugly_crap() except: - tb_set_next = None -del _init_ugly_crap + pass + del _init_ugly_crap diff --git a/module/lib/jinja2/defaults.py b/module/lib/jinja2/defaults.py index d2d45443a..a27cb80cb 100644 --- a/module/lib/jinja2/defaults.py +++ b/module/lib/jinja2/defaults.py @@ -8,6 +8,7 @@ :copyright: (c) 2010 by the Jinja Team. :license: BSD, see LICENSE for more details. """ +from jinja2._compat import range_type from jinja2.utils import generate_lorem_ipsum, Cycler, Joiner @@ -21,14 +22,16 @@ COMMENT_END_STRING = '#}' LINE_STATEMENT_PREFIX = None LINE_COMMENT_PREFIX = None TRIM_BLOCKS = False +LSTRIP_BLOCKS = False NEWLINE_SEQUENCE = '\n' +KEEP_TRAILING_NEWLINE = False # default filters, tests and namespace from jinja2.filters import FILTERS as DEFAULT_FILTERS from jinja2.tests import TESTS as DEFAULT_TESTS DEFAULT_NAMESPACE = { - 'range': xrange, + 'range': range_type, 'dict': lambda **kw: kw, 'lipsum': generate_lorem_ipsum, 'cycler': Cycler, diff --git a/module/lib/jinja2/environment.py b/module/lib/jinja2/environment.py index ac74a5c68..45fabada2 100644 --- a/module/lib/jinja2/environment.py +++ b/module/lib/jinja2/environment.py @@ -11,16 +11,26 @@ import os import sys from jinja2 import nodes -from jinja2.defaults import * +from jinja2.defaults import BLOCK_START_STRING, \ + BLOCK_END_STRING, VARIABLE_START_STRING, VARIABLE_END_STRING, \ + COMMENT_START_STRING, COMMENT_END_STRING, LINE_STATEMENT_PREFIX, \ + LINE_COMMENT_PREFIX, TRIM_BLOCKS, NEWLINE_SEQUENCE, \ + DEFAULT_FILTERS, DEFAULT_TESTS, DEFAULT_NAMESPACE, \ + KEEP_TRAILING_NEWLINE, LSTRIP_BLOCKS from jinja2.lexer import get_lexer, TokenStream from jinja2.parser import Parser +from jinja2.nodes import EvalContext from jinja2.optimizer import optimize from jinja2.compiler import generate from jinja2.runtime import Undefined, new_context from jinja2.exceptions import TemplateSyntaxError, TemplateNotFound, \ - TemplatesNotFound + TemplatesNotFound, TemplateRuntimeError from jinja2.utils import import_string, LRUCache, Markup, missing, \ - concat, consume, internalcode, _encode_filename + concat, consume, internalcode +from jinja2._compat import imap, ifilter, string_types, iteritems, \ + text_type, reraise, implements_iterator, implements_to_string, \ + get_next, encode_filename, PY2, PYPY +from functools import reduce # for direct template usage we have up to ten living environments @@ -67,11 +77,11 @@ def copy_cache(cache): def load_extensions(environment, extensions): """Load the extensions from the list and bind it to the environment. - Returns a dict of instanciated environments. + Returns a dict of instantiated environments. """ result = {} for extension in extensions: - if isinstance(extension, basestring): + if isinstance(extension, string_types): extension = import_string(extension) result[extension.identifier] = extension(environment) return result @@ -134,12 +144,23 @@ class Environment(object): If this is set to ``True`` the first newline after a block is removed (block, not variable tag!). Defaults to `False`. + `lstrip_blocks` + If this is set to ``True`` leading spaces and tabs are stripped + from the start of a line to a block. Defaults to `False`. + `newline_sequence` The sequence that starts a newline. Must be one of ``'\r'``, ``'\n'`` or ``'\r\n'``. The default is ``'\n'`` which is a useful default for Linux and OS X systems as well as web applications. + `keep_trailing_newline` + Preserve the trailing newline when rendering templates. + The default is ``False``, which causes a single newline, + if present, to be stripped from the end of the template. + + .. versionadded:: 2.7 + `extensions` List of Jinja extensions to use. This can either be import paths as strings or extension classes. For more information have a @@ -196,7 +217,8 @@ class Environment(object): #: if this environment is sandboxed. Modifying this variable won't make #: the environment sandboxed though. For a real sandboxed environment - #: have a look at jinja2.sandbox + #: have a look at jinja2.sandbox. This flag alone controls the code + #: generation by the compiler. sandboxed = False #: True if the environment is just an overlay @@ -223,7 +245,9 @@ class Environment(object): line_statement_prefix=LINE_STATEMENT_PREFIX, line_comment_prefix=LINE_COMMENT_PREFIX, trim_blocks=TRIM_BLOCKS, + lstrip_blocks=LSTRIP_BLOCKS, newline_sequence=NEWLINE_SEQUENCE, + keep_trailing_newline=KEEP_TRAILING_NEWLINE, extensions=(), optimized=True, undefined=Undefined, @@ -238,7 +262,7 @@ class Environment(object): # passed by keyword rather than position. However it's important to # not change the order of arguments because it's used at least # internally in those cases: - # - spontaneus environments (i18n extension and Template) + # - spontaneous environments (i18n extension and Template) # - unittests # If parameter changes are required only add parameters at the end # and don't change the arguments (or the defaults!) of the arguments @@ -254,7 +278,9 @@ class Environment(object): self.line_statement_prefix = line_statement_prefix self.line_comment_prefix = line_comment_prefix self.trim_blocks = trim_blocks + self.lstrip_blocks = lstrip_blocks self.newline_sequence = newline_sequence + self.keep_trailing_newline = keep_trailing_newline # runtime information self.undefined = undefined @@ -269,7 +295,6 @@ class Environment(object): # set the loader provided self.loader = loader - self.bytecode_cache = None self.cache = create_cache(cache_size) self.bytecode_cache = bytecode_cache self.auto_reload = auto_reload @@ -291,7 +316,7 @@ class Environment(object): yet. This is used by :ref:`extensions <writing-extensions>` to register callbacks and configuration values without breaking inheritance. """ - for key, value in attributes.iteritems(): + for key, value in iteritems(attributes): if not hasattr(self, key): setattr(self, key, value) @@ -299,7 +324,8 @@ class Environment(object): variable_start_string=missing, variable_end_string=missing, comment_start_string=missing, comment_end_string=missing, line_statement_prefix=missing, line_comment_prefix=missing, - trim_blocks=missing, extensions=missing, optimized=missing, + trim_blocks=missing, lstrip_blocks=missing, + extensions=missing, optimized=missing, undefined=missing, finalize=missing, autoescape=missing, loader=missing, cache_size=missing, auto_reload=missing, bytecode_cache=missing): @@ -322,7 +348,7 @@ class Environment(object): rv.overlayed = True rv.linked_to = self - for key, value in args.iteritems(): + for key, value in iteritems(args): if value is not missing: setattr(rv, key, value) @@ -332,7 +358,7 @@ class Environment(object): rv.cache = copy_cache(self.cache) rv.extensions = {} - for key, value in self.extensions.iteritems(): + for key, value in iteritems(self.extensions): rv.extensions[key] = value.bind(rv) if extensions is not missing: rv.extensions.update(load_extensions(rv, extensions)) @@ -351,10 +377,10 @@ class Environment(object): try: return obj[argument] except (TypeError, LookupError): - if isinstance(argument, basestring): + if isinstance(argument, string_types): try: attr = str(argument) - except: + except Exception: pass else: try: @@ -376,6 +402,42 @@ class Environment(object): except (TypeError, LookupError, AttributeError): return self.undefined(obj=obj, name=attribute) + def call_filter(self, name, value, args=None, kwargs=None, + context=None, eval_ctx=None): + """Invokes a filter on a value the same way the compiler does it. + + .. versionadded:: 2.7 + """ + func = self.filters.get(name) + if func is None: + raise TemplateRuntimeError('no filter named %r' % name) + args = [value] + list(args or ()) + if getattr(func, 'contextfilter', False): + if context is None: + raise TemplateRuntimeError('Attempted to invoke context ' + 'filter without context') + args.insert(0, context) + elif getattr(func, 'evalcontextfilter', False): + if eval_ctx is None: + if context is not None: + eval_ctx = context.eval_ctx + else: + eval_ctx = EvalContext(self) + args.insert(0, eval_ctx) + elif getattr(func, 'environmentfilter', False): + args.insert(0, self) + return func(*args, **(kwargs or {})) + + def call_test(self, name, value, args=None, kwargs=None): + """Invokes a test on a value the same way the compiler does it. + + .. versionadded:: 2.7 + """ + func = self.tests.get(name) + if func is None: + raise TemplateRuntimeError('no test named %r' % name) + return func(value, *(args or ()), **(kwargs or {})) + @internalcode def parse(self, source, name=None, filename=None): """Parse the sourcecode and return the abstract syntax tree. This @@ -394,7 +456,7 @@ class Environment(object): def _parse(self, source, name, filename): """Internal parsing function used by `parse` and `compile`.""" - return Parser(self, source, name, _encode_filename(filename)).parse() + return Parser(self, source, name, encode_filename(filename)).parse() def lex(self, source, name=None, filename=None): """Lex the given sourcecode and return a generator that yields @@ -406,7 +468,7 @@ class Environment(object): of the extensions to be applied you have to filter source through the :meth:`preprocess` method. """ - source = unicode(source) + source = text_type(source) try: return self.lexer.tokeniter(source, name, filename) except TemplateSyntaxError: @@ -419,7 +481,7 @@ class Environment(object): because there you usually only want the actual source tokenized. """ return reduce(lambda s, e: e.preprocess(s, name, filename), - self.iter_extensions(), unicode(source)) + self.iter_extensions(), text_type(source)) def _tokenize(self, source, name, filename=None, state=None): """Called by the parser to do the preprocessing and filtering @@ -434,7 +496,7 @@ class Environment(object): return stream def _generate(self, source, name, filename, defer_init=False): - """Internal hook that can be overriden to hook a different generate + """Internal hook that can be overridden to hook a different generate method in. .. versionadded:: 2.5 @@ -442,7 +504,7 @@ class Environment(object): return generate(source, self, name, filename, defer_init=defer_init) def _compile(self, source, filename): - """Internal hook that can be overriden to hook a different compile + """Internal hook that can be overridden to hook a different compile method in. .. versionadded:: 2.5 @@ -473,7 +535,7 @@ class Environment(object): """ source_hint = None try: - if isinstance(source, basestring): + if isinstance(source, string_types): source_hint = source source = self._parse(source, name, filename) if self.optimized: @@ -485,7 +547,7 @@ class Environment(object): if filename is None: filename = '<template>' else: - filename = _encode_filename(filename) + filename = encode_filename(filename) return self._compile(source, filename) except TemplateSyntaxError: exc_info = sys.exc_info() @@ -539,7 +601,7 @@ class Environment(object): def compile_templates(self, target, extensions=None, filter_func=None, zip='deflated', log_function=None, ignore_errors=True, py_compile=False): - """Compiles all the templates the loader can find, compiles them + """Finds all the templates the loader can find, compiles them and stores them in `target`. If `zip` is `None`, instead of in a zipfile, the templates will be will be stored in a directory. By default a deflate zip algorithm is used, to switch to @@ -555,7 +617,9 @@ class Environment(object): to `False` and you will get an exception on syntax errors. If `py_compile` is set to `True` .pyc files will be written to the - target instead of standard .py files. + target instead of standard .py files. This flag does not do anything + on pypy and Python 3 where pyc files are not picked up by itself and + don't give much benefit. .. versionadded:: 2.4 """ @@ -565,14 +629,23 @@ class Environment(object): log_function = lambda x: None if py_compile: - import imp, struct, marshal - py_header = imp.get_magic() + \ - u'\xff\xff\xff\xff'.encode('iso-8859-15') + if not PY2 or PYPY: + from warnings import warn + warn(Warning('py_compile has no effect on pypy or Python 3')) + py_compile = False + else: + import imp, marshal + py_header = imp.get_magic() + \ + u'\xff\xff\xff\xff'.encode('iso-8859-15') + + # Python 3.3 added a source filesize to the header + if sys.version_info >= (3, 3): + py_header += u'\x00\x00\x00\x00'.encode('iso-8859-15') def write_file(filename, data, mode): if zip: info = ZipInfo(filename) - info.external_attr = 0755 << 16L + info.external_attr = 0o755 << 16 zip_file.writestr(info, data) else: f = open(os.path.join(target, filename), mode) @@ -596,7 +669,7 @@ class Environment(object): source, filename, _ = self.loader.get_source(self, name) try: code = self.compile(source, name, filename, True, True) - except TemplateSyntaxError, e: + except TemplateSyntaxError as e: if not ignore_errors: raise log_function('Could not compile "%s": %s' % (name, e)) @@ -605,7 +678,7 @@ class Environment(object): filename = ModuleLoader.get_module_filename(name) if py_compile: - c = self._compile(code, _encode_filename(filename)) + c = self._compile(code, encode_filename(filename)) write_file(filename + 'c', py_header + marshal.dumps(c), 'wb') log_function('Byte-compiled "%s" as %s' % @@ -632,6 +705,8 @@ class Environment(object): in the result list. If the loader does not support that, a :exc:`TypeError` is raised. + + .. versionadded:: 2.4 """ x = self.loader.list_templates() if extensions is not None: @@ -641,7 +716,7 @@ class Environment(object): filter_func = lambda x: '.' in x and \ x.rsplit('.', 1)[1] in extensions if filter_func is not None: - x = filter(filter_func, x) + x = ifilter(filter_func, x) return x def handle_exception(self, exc_info=None, rendered=False, source_hint=None): @@ -664,7 +739,7 @@ class Environment(object): if self.exception_handler is not None: self.exception_handler(traceback) exc_type, exc_value, tb = traceback.standard_exc_info - raise exc_type, exc_value, tb + reraise(exc_type, exc_value, tb) def join_path(self, template, parent): """Join a template with the parent. By default all the lookups are @@ -751,7 +826,7 @@ class Environment(object): .. versionadded:: 2.3 """ - if isinstance(template_name_or_list, basestring): + if isinstance(template_name_or_list, string_types): return self.get_template(template_name_or_list, parent, globals) elif isinstance(template_name_or_list, Template): return template_name_or_list @@ -813,7 +888,9 @@ class Template(object): line_statement_prefix=LINE_STATEMENT_PREFIX, line_comment_prefix=LINE_COMMENT_PREFIX, trim_blocks=TRIM_BLOCKS, + lstrip_blocks=LSTRIP_BLOCKS, newline_sequence=NEWLINE_SEQUENCE, + keep_trailing_newline=KEEP_TRAILING_NEWLINE, extensions=(), optimized=True, undefined=Undefined, @@ -823,8 +900,9 @@ class Template(object): block_start_string, block_end_string, variable_start_string, variable_end_string, comment_start_string, comment_end_string, line_statement_prefix, line_comment_prefix, trim_blocks, - newline_sequence, frozenset(extensions), optimized, undefined, - finalize, autoescape, None, 0, False, None) + lstrip_blocks, newline_sequence, keep_trailing_newline, + frozenset(extensions), optimized, undefined, finalize, autoescape, + None, 0, False, None) return env.from_string(source, template_class=cls) @classmethod @@ -836,7 +914,7 @@ class Template(object): 'environment': environment, '__file__': code.co_filename } - exec code in namespace + exec(code, namespace) rv = cls._from_namespace(environment, namespace, globals) rv._uptodate = uptodate return rv @@ -886,7 +964,7 @@ class Template(object): vars = dict(*args, **kwargs) try: return concat(self.root_render_func(self.new_context(vars))) - except: + except Exception: exc_info = sys.exc_info() return self.environment.handle_exception(exc_info, True) @@ -908,7 +986,7 @@ class Template(object): try: for event in self.root_render_func(self.new_context(vars)): yield event - except: + except Exception: exc_info = sys.exc_info() else: return @@ -970,7 +1048,7 @@ class Template(object): @property def debug_info(self): """The debug info mapping.""" - return [tuple(map(int, x.split('='))) for x in + return [tuple(imap(int, x.split('='))) for x in self._debug_info.split('&')] def __repr__(self): @@ -981,6 +1059,7 @@ class Template(object): return '<%s %s>' % (self.__class__.__name__, name) +@implements_to_string class TemplateModule(object): """Represents an imported template. All the exported names of the template are available as attributes on this object. Additionally @@ -996,13 +1075,6 @@ class TemplateModule(object): return Markup(concat(self._body_stream)) def __str__(self): - return unicode(self).encode('utf-8') - - # unicode goes after __str__ because we configured 2to3 to rename - # __unicode__ to __str__. because the 2to3 tree is not designed to - # remove nodes from it, we leave the above __str__ around and let - # it override at runtime. - def __unicode__(self): return concat(self._body_stream) def __repr__(self): @@ -1032,6 +1104,7 @@ class TemplateExpression(object): return rv +@implements_iterator class TemplateStream(object): """A template stream works pretty much like an ordinary python generator but it can buffer multiple items to reduce the number of total iterations. @@ -1050,15 +1123,15 @@ class TemplateStream(object): def dump(self, fp, encoding=None, errors='strict'): """Dump the complete stream into a file or file-like object. Per default unicode strings are written, if you want to encode - before writing specifiy an `encoding`. + before writing specify an `encoding`. Example usage:: Template('Hello {{ name }}!').stream(name='foo').dump('hello.html') """ close = False - if isinstance(fp, basestring): - fp = file(fp, 'w') + if isinstance(fp, string_types): + fp = open(fp, encoding is None and 'w' or 'wb') close = True try: if encoding is not None: @@ -1076,7 +1149,7 @@ class TemplateStream(object): def disable_buffering(self): """Disable the output buffering.""" - self._next = self._gen.next + self._next = get_next(self._gen) self.buffered = False def enable_buffering(self, size=5): @@ -1104,12 +1177,12 @@ class TemplateStream(object): c_size = 0 self.buffered = True - self._next = generator(self._gen.next).next + self._next = get_next(generator(get_next(self._gen))) def __iter__(self): return self - def next(self): + def __next__(self): return self._next() diff --git a/module/lib/jinja2/exceptions.py b/module/lib/jinja2/exceptions.py index 771f6a8d7..c9df6dc7c 100644 --- a/module/lib/jinja2/exceptions.py +++ b/module/lib/jinja2/exceptions.py @@ -8,24 +8,40 @@ :copyright: (c) 2010 by the Jinja Team. :license: BSD, see LICENSE for more details. """ +from jinja2._compat import imap, text_type, PY2, implements_to_string class TemplateError(Exception): """Baseclass for all template errors.""" - def __init__(self, message=None): - if message is not None: - message = unicode(message).encode('utf-8') - Exception.__init__(self, message) - - @property - def message(self): - if self.args: - message = self.args[0] + if PY2: + def __init__(self, message=None): if message is not None: - return message.decode('utf-8', 'replace') - - + message = text_type(message).encode('utf-8') + Exception.__init__(self, message) + + @property + def message(self): + if self.args: + message = self.args[0] + if message is not None: + return message.decode('utf-8', 'replace') + + def __unicode__(self): + return self.message or u'' + else: + def __init__(self, message=None): + Exception.__init__(self, message) + + @property + def message(self): + if self.args: + message = self.args[0] + if message is not None: + return message + + +@implements_to_string class TemplateNotFound(IOError, LookupError, TemplateError): """Raised if a template does not exist.""" @@ -42,13 +58,6 @@ class TemplateNotFound(IOError, LookupError, TemplateError): self.templates = [name] def __str__(self): - return self.message.encode('utf-8') - - # unicode goes after __str__ because we configured 2to3 to rename - # __unicode__ to __str__. because the 2to3 tree is not designed to - # remove nodes from it, we leave the above __str__ around and let - # it override at runtime. - def __unicode__(self): return self.message @@ -62,12 +71,13 @@ class TemplatesNotFound(TemplateNotFound): def __init__(self, names=(), message=None): if message is None: - message = u'non of the templates given were found: ' + \ - u', '.join(map(unicode, names)) + message = u'none of the templates given were found: ' + \ + u', '.join(imap(text_type, names)) TemplateNotFound.__init__(self, names and names[-1] or None, message) self.templates = list(names) +@implements_to_string class TemplateSyntaxError(TemplateError): """Raised to tell the user that there is a problem with the template.""" @@ -83,13 +93,6 @@ class TemplateSyntaxError(TemplateError): self.translated = False def __str__(self): - return unicode(self).encode('utf-8') - - # unicode goes after __str__ because we configured 2to3 to rename - # __unicode__ to __str__. because the 2to3 tree is not designed to - # remove nodes from it, we leave the above __str__ around and let - # it override at runtime. - def __unicode__(self): # for translated errors we only return the message if self.translated: return self.message diff --git a/module/lib/jinja2/ext.py b/module/lib/jinja2/ext.py index ceb38953a..c2df12d55 100644 --- a/module/lib/jinja2/ext.py +++ b/module/lib/jinja2/ext.py @@ -10,13 +10,17 @@ :copyright: (c) 2010 by the Jinja Team. :license: BSD. """ -from collections import deque from jinja2 import nodes -from jinja2.defaults import * +from jinja2.defaults import BLOCK_START_STRING, \ + BLOCK_END_STRING, VARIABLE_START_STRING, VARIABLE_END_STRING, \ + COMMENT_START_STRING, COMMENT_END_STRING, LINE_STATEMENT_PREFIX, \ + LINE_COMMENT_PREFIX, TRIM_BLOCKS, NEWLINE_SEQUENCE, \ + KEEP_TRAILING_NEWLINE, LSTRIP_BLOCKS from jinja2.environment import Environment -from jinja2.runtime import Undefined, concat +from jinja2.runtime import concat from jinja2.exceptions import TemplateAssertionError, TemplateSyntaxError -from jinja2.utils import contextfunction, import_string, Markup, next +from jinja2.utils import contextfunction, import_string, Markup +from jinja2._compat import next, with_metaclass, string_types, iteritems # the only real useful gettext functions for a Jinja template. Note @@ -34,7 +38,7 @@ class ExtensionRegistry(type): return rv -class Extension(object): +class Extension(with_metaclass(ExtensionRegistry, object)): """Extensions can be used to add extra functionality to the Jinja template system at the parser level. Custom extensions are bound to an environment but may not store environment specific data on `self`. The reason for @@ -52,7 +56,6 @@ class Extension(object): is a terrible name, ``fragment_cache_prefix`` on the other hand is a good name as includes the name of the extension (fragment cache). """ - __metaclass__ = ExtensionRegistry #: if this extension parses this is the list of tags it's listening to. tags = set() @@ -103,7 +106,9 @@ class Extension(object): def attr(self, name, lineno=None): """Return an attribute node for the current extension. This is useful - to pass constants on extensions to generated template code:: + to pass constants on extensions to generated template code. + + :: self.attr('_my_attribute', lineno=lineno) """ @@ -203,7 +208,7 @@ class InternationalizationExtension(Extension): self.environment.globals.pop(key, None) def _extract(self, source, gettext_functions=GETTEXT_FUNCTIONS): - if isinstance(source, basestring): + if isinstance(source, string_types): source = self.environment.parse(source) return extract_from_ast(source, gettext_functions) @@ -216,6 +221,7 @@ class InternationalizationExtension(Extension): # defined in the body of the trans block too, but this is checked at # a later state. plural_expr = None + plural_expr_assignment = None variables = {} while parser.stream.current.type != 'block_end': if variables: @@ -239,7 +245,13 @@ class InternationalizationExtension(Extension): variables[name.value] = var = nodes.Name(name.value, 'load') if plural_expr is None: - plural_expr = var + if isinstance(var, nodes.Call): + plural_expr = nodes.Name('_trans', 'load') + variables[name.value] = plural_expr + plural_expr_assignment = nodes.Assign( + nodes.Name('_trans', 'store'), var) + else: + plural_expr = var num_called_num = name.value == 'num' parser.stream.expect('block_end') @@ -289,7 +301,10 @@ class InternationalizationExtension(Extension): bool(referenced), num_called_num and have_plural) node.set_lineno(lineno) - return node + if plural_expr_assignment is not None: + return [plural_expr_assignment, node] + else: + return node def _parse_block(self, parser, allow_pluralize): """Parse until the next block tag with a given name.""" @@ -352,7 +367,7 @@ class InternationalizationExtension(Extension): # enough to handle the variable expansion and autoescape # handling itself if self.environment.newstyle_gettext: - for key, value in variables.iteritems(): + for key, value in iteritems(variables): # the function adds that later anyways in case num was # called num, so just skip it. if num_called_num and key == 'num': @@ -474,7 +489,7 @@ def extract_from_ast(node, gettext_functions=GETTEXT_FUNCTIONS, strings = [] for arg in node.args: if isinstance(arg, nodes.Const) and \ - isinstance(arg.value, basestring): + isinstance(arg.value, string_types): strings.append(arg.value) else: strings.append(None) @@ -550,6 +565,10 @@ def babel_extract(fileobj, keywords, comment_tags, options): The `newstyle_gettext` flag can be set to `True` to enable newstyle gettext calls. + .. versionchanged:: 2.7 + A `silent` option can now be provided. If set to `False` template + syntax errors are propagated instead of being ignored. + :param fileobj: the file-like object the messages should be extracted from :param keywords: a list of keywords (i.e. function names) that should be recognized as translation functions @@ -569,8 +588,10 @@ def babel_extract(fileobj, keywords, comment_tags, options): extensions.add(InternationalizationExtension) def getbool(options, key, default=False): - options.get(key, str(default)).lower() in ('1', 'on', 'yes', 'true') + return options.get(key, str(default)).lower() in \ + ('1', 'on', 'yes', 'true') + silent = getbool(options, 'silent', True) environment = Environment( options.get('block_start_string', BLOCK_START_STRING), options.get('block_end_string', BLOCK_END_STRING), @@ -581,7 +602,10 @@ def babel_extract(fileobj, keywords, comment_tags, options): options.get('line_statement_prefix') or LINE_STATEMENT_PREFIX, options.get('line_comment_prefix') or LINE_COMMENT_PREFIX, getbool(options, 'trim_blocks', TRIM_BLOCKS), - NEWLINE_SEQUENCE, frozenset(extensions), + getbool(options, 'lstrip_blocks', LSTRIP_BLOCKS), + NEWLINE_SEQUENCE, + getbool(options, 'keep_trailing_newline', KEEP_TRAILING_NEWLINE), + frozenset(extensions), cache_size=0, auto_reload=False ) @@ -593,7 +617,9 @@ def babel_extract(fileobj, keywords, comment_tags, options): try: node = environment.parse(source) tokens = list(environment.lex(environment.preprocess(source))) - except TemplateSyntaxError, e: + except TemplateSyntaxError as e: + if not silent: + raise # skip templates with syntax errors return diff --git a/module/lib/jinja2/filters.py b/module/lib/jinja2/filters.py index d1848e434..fd0db04aa 100644 --- a/module/lib/jinja2/filters.py +++ b/module/lib/jinja2/filters.py @@ -10,12 +10,15 @@ """ import re import math + from random import choice from operator import itemgetter -from itertools import imap, groupby -from jinja2.utils import Markup, escape, pformat, urlize, soft_unicode +from itertools import groupby +from jinja2.utils import Markup, escape, pformat, urlize, soft_unicode, \ + unicode_urlencode from jinja2.runtime import Undefined -from jinja2.exceptions import FilterArgumentError, SecurityError +from jinja2.exceptions import FilterArgumentError +from jinja2._compat import next, imap, string_types, text_type, iteritems _word_re = re.compile(r'\w+(?u)') @@ -48,11 +51,50 @@ def environmentfilter(f): return f +def make_attrgetter(environment, attribute): + """Returns a callable that looks up the given attribute from a + passed object with the rules of the environment. Dots are allowed + to access attributes of attributes. Integer parts in paths are + looked up as integers. + """ + if not isinstance(attribute, string_types) \ + or ('.' not in attribute and not attribute.isdigit()): + return lambda x: environment.getitem(x, attribute) + attribute = attribute.split('.') + def attrgetter(item): + for part in attribute: + if part.isdigit(): + part = int(part) + item = environment.getitem(item, part) + return item + return attrgetter + + def do_forceescape(value): """Enforce HTML escaping. This will probably double escape variables.""" if hasattr(value, '__html__'): value = value.__html__() - return escape(unicode(value)) + return escape(text_type(value)) + + +def do_urlencode(value): + """Escape strings for use in URLs (uses UTF-8 encoding). It accepts both + dictionaries and regular strings as well as pairwise iterables. + + .. versionadded:: 2.7 + """ + itemiter = None + if isinstance(value, dict): + itemiter = iteritems(value) + elif not isinstance(value, string_types): + try: + itemiter = iter(value) + except TypeError: + pass + if itemiter is None: + return unicode_urlencode(value) + return u'&'.join(unicode_urlencode(k) + '=' + + unicode_urlencode(v) for k, v in itemiter) @evalcontextfilter @@ -74,7 +116,7 @@ def do_replace(eval_ctx, s, old, new, count=None): if count is None: count = -1 if not eval_ctx.autoescape: - return unicode(s).replace(unicode(old), unicode(new), count) + return text_type(s).replace(text_type(old), text_type(new), count) if hasattr(old, '__html__') or hasattr(new, '__html__') and \ not hasattr(s, '__html__'): s = escape(s) @@ -119,7 +161,7 @@ def do_xmlattr(_eval_ctx, d, autospace=True): """ rv = u' '.join( u'%s="%s"' % (escape(key), escape(value)) - for key, value in d.iteritems() + for key, value in iteritems(d) if value is not None and not isinstance(value, Undefined) ) if autospace and rv: @@ -140,7 +182,12 @@ def do_title(s): """Return a titlecased version of the value. I.e. words will start with uppercase letters, all remaining characters are lowercase. """ - return soft_unicode(s).title() + rv = [] + for item in re.compile(r'([-\s]+)(?u)').split(s): + if not item: + continue + rv.append(item[0].upper() + item[1:].lower()) + return ''.join(rv) def do_dictsort(value, case_sensitive=False, by='key'): @@ -153,7 +200,7 @@ def do_dictsort(value, case_sensitive=False, by='key'): {% for item in mydict|dictsort %} sort the dict by key, case insensitive - {% for item in mydict|dicsort(true) %} + {% for item in mydict|dictsort(true) %} sort the dict by key, case sensitive {% for item in mydict|dictsort(false, 'value') %} @@ -169,14 +216,16 @@ def do_dictsort(value, case_sensitive=False, by='key'): '"key" or "value"') def sort_func(item): value = item[pos] - if isinstance(value, basestring) and not case_sensitive: + if isinstance(value, string_types) and not case_sensitive: value = value.lower() return value return sorted(value.items(), key=sort_func) -def do_sort(value, reverse=False, case_sensitive=False): +@environmentfilter +def do_sort(environment, value, reverse=False, case_sensitive=False, + attribute=None): """Sort an iterable. Per default it sorts ascending, if you pass it true as first argument it will reverse the sorting. @@ -189,14 +238,30 @@ def do_sort(value, reverse=False, case_sensitive=False): {% for item in iterable|sort %} ... {% endfor %} + + It is also possible to sort by an attribute (for example to sort + by the date of an object) by specifying the `attribute` parameter: + + .. sourcecode:: jinja + + {% for item in iterable|sort(attribute='date') %} + ... + {% endfor %} + + .. versionchanged:: 2.6 + The `attribute` parameter was added. """ if not case_sensitive: def sort_func(item): - if isinstance(item, basestring): + if isinstance(item, string_types): item = item.lower() return item else: sort_func = None + if attribute is not None: + getter = make_attrgetter(environment, attribute) + def sort_func(item, processor=sort_func or (lambda x: x)): + return processor(getter(item)) return sorted(value, key=sort_func, reverse=reverse) @@ -217,13 +282,13 @@ def do_default(value, default_value=u'', boolean=False): {{ ''|default('the string was empty', true) }} """ - if (boolean and not value) or isinstance(value, Undefined): + if isinstance(value, Undefined) or (boolean and not value): return default_value return value @evalcontextfilter -def do_join(eval_ctx, value, d=u''): +def do_join(eval_ctx, value, d=u'', attribute=None): """Return a string which is the concatenation of the strings in the sequence. The separator between elements is an empty string per default, you can define it with the optional parameter: @@ -235,10 +300,22 @@ def do_join(eval_ctx, value, d=u''): {{ [1, 2, 3]|join }} -> 123 + + It is also possible to join certain attributes of an object: + + .. sourcecode:: jinja + + {{ users|join(', ', attribute='username') }} + + .. versionadded:: 2.6 + The `attribute` parameter was added. """ + if attribute is not None: + value = imap(make_attrgetter(eval_ctx.environment, attribute), value) + # no automatic escaping? joining is a lot eaiser then if not eval_ctx.autoescape: - return unicode(d).join(imap(unicode, value)) + return text_type(d).join(imap(text_type, value)) # if the delimiter doesn't have an html representation we check # if any of the items has. If yes we do a coercion to Markup @@ -249,11 +326,11 @@ def do_join(eval_ctx, value, d=u''): if hasattr(item, '__html__'): do_escape = True else: - value[idx] = unicode(item) + value[idx] = text_type(item) if do_escape: d = escape(d) else: - d = unicode(d) + d = text_type(d) return d.join(value) # no html involved, to normal joining @@ -262,14 +339,14 @@ def do_join(eval_ctx, value, d=u''): def do_center(value, width=80): """Centers the value in a field of a given width.""" - return unicode(value).center(width) + return text_type(value).center(width) @environmentfilter def do_first(environment, seq): """Return the first item of a sequence.""" try: - return iter(seq).next() + return next(iter(seq)) except StopIteration: return environment.undefined('No first item, sequence was empty.') @@ -278,7 +355,7 @@ def do_first(environment, seq): def do_last(environment, seq): """Return the last item of a sequence.""" try: - return iter(reversed(seq)).next() + return next(iter(reversed(seq))) except StopIteration: return environment.undefined('No last item, sequence was empty.') @@ -293,21 +370,33 @@ def do_random(environment, seq): def do_filesizeformat(value, binary=False): - """Format the value like a 'human-readable' file size (i.e. 13 KB, - 4.1 MB, 102 bytes, etc). Per default decimal prefixes are used (mega, - giga, etc.), if the second parameter is set to `True` the binary - prefixes are used (mebi, gibi). + """Format the value like a 'human-readable' file size (i.e. 13 kB, + 4.1 MB, 102 Bytes, etc). Per default decimal prefixes are used (Mega, + Giga, etc.), if the second parameter is set to `True` the binary + prefixes are used (Mebi, Gibi). """ bytes = float(value) base = binary and 1024 or 1000 - middle = binary and 'i' or '' - if bytes < base: - return "%d Byte%s" % (bytes, bytes != 1 and 's' or '') - elif bytes < base * base: - return "%.1f K%sB" % (bytes / base, middle) - elif bytes < base * base * base: - return "%.1f M%sB" % (bytes / (base * base), middle) - return "%.1f G%sB" % (bytes / (base * base * base), middle) + prefixes = [ + (binary and 'KiB' or 'kB'), + (binary and 'MiB' or 'MB'), + (binary and 'GiB' or 'GB'), + (binary and 'TiB' or 'TB'), + (binary and 'PiB' or 'PB'), + (binary and 'EiB' or 'EB'), + (binary and 'ZiB' or 'ZB'), + (binary and 'YiB' or 'YB') + ] + if bytes == 1: + return '1 Byte' + elif bytes < base: + return '%d Bytes' % bytes + else: + for i, prefix in enumerate(prefixes): + unit = base ** (i + 2) + if bytes < unit: + return '%.1f %s' % ((base * bytes / unit), prefix) + return '%.1f %s' % ((base * bytes / unit), prefix) def do_pprint(value, verbose=False): @@ -360,16 +449,17 @@ def do_truncate(s, length=255, killwords=False, end='...'): """Return a truncated copy of the string. The length is specified with the first parameter which defaults to ``255``. If the second parameter is ``true`` the filter will cut the text at length. Otherwise - it will try to save the last word. If the text was in fact + it will discard the last word. If the text was in fact truncated it will append an ellipsis sign (``"..."``). If you want a different ellipsis sign than ``"..."`` you can specify it using the third parameter. - .. sourcecode jinja:: + .. sourcecode:: jinja - {{ mytext|truncate(300, false, '»') }} - truncate mytext to 300 chars, don't split up words, use a - right pointing double arrow as ellipsis sign. + {{ "foo bar"|truncate(5) }} + -> "foo ..." + {{ "foo bar"|truncate(5, True) }} + -> "foo b..." """ if len(s) <= length: return s @@ -386,16 +476,24 @@ def do_truncate(s, length=255, killwords=False, end='...'): result.append(end) return u' '.join(result) - -def do_wordwrap(s, width=79, break_long_words=True): +@environmentfilter +def do_wordwrap(environment, s, width=79, break_long_words=True, + wrapstring=None): """ Return a copy of the string passed to the filter wrapped after ``79`` characters. You can override this default using the first parameter. If you set the second parameter to `false` Jinja will not - split words apart if they are longer than `width`. + split words apart if they are longer than `width`. By default, the newlines + will be the default newlines for the environment, but this can be changed + using the wrapstring keyword argument. + + .. versionadded:: 2.7 + Added support for the `wrapstring` parameter. """ + if not wrapstring: + wrapstring = environment.newline_sequence import textwrap - return u'\n'.join(textwrap.wrap(s, width=width, expand_tabs=False, + return wrapstring.join(textwrap.wrap(s, width=width, expand_tabs=False, replace_whitespace=False, break_long_words=break_long_words)) @@ -456,7 +554,7 @@ def do_striptags(value): """ if hasattr(value, '__html__'): value = value.__html__() - return Markup(unicode(value)).striptags() + return Markup(text_type(value)).striptags() def do_slice(value, slices, fill_with=None): @@ -484,7 +582,7 @@ def do_slice(value, slices, fill_with=None): items_per_slice = length // slices slices_with_extra = length % slices offset = 0 - for slice_number in xrange(slices): + for slice_number in range(slices): start = offset + slice_number * items_per_slice if slice_number < slices_with_extra: offset += 1 @@ -500,7 +598,7 @@ def do_batch(value, linecount, fill_with=None): A filter that batches items. It works pretty much like `slice` just the other way round. It returns a list of lists with the given number of items. If you provide a second parameter this - is used to fill missing items. See this example: + is used to fill up missing items. See this example: .. sourcecode:: html+jinja @@ -595,8 +693,12 @@ def do_groupby(environment, value, attribute): As you can see the item we're grouping by is stored in the `grouper` attribute and the `list` contains all the objects that have this grouper in common. + + .. versionchanged:: 2.6 + It's now possible to use dotted notation to group by the child + attribute of another attribute. """ - expr = lambda x: environment.getitem(x, attribute) + expr = make_attrgetter(environment, attribute) return sorted(map(_GroupTuple, groupby(sorted(value, key=expr), expr))) @@ -605,10 +707,32 @@ class _GroupTuple(tuple): grouper = property(itemgetter(0)) list = property(itemgetter(1)) - def __new__(cls, (key, value)): + def __new__(cls, xxx_todo_changeme): + (key, value) = xxx_todo_changeme return tuple.__new__(cls, (key, list(value))) +@environmentfilter +def do_sum(environment, iterable, attribute=None, start=0): + """Returns the sum of a sequence of numbers plus the value of parameter + 'start' (which defaults to 0). When the sequence is empty it returns + start. + + It is also possible to sum up only certain attributes: + + .. sourcecode:: jinja + + Total: {{ items|sum(attribute='price') }} + + .. versionchanged:: 2.6 + The `attribute` parameter was added to allow suming up over + attributes. Also the `start` parameter was moved on to the right. + """ + if attribute is not None: + iterable = imap(make_attrgetter(environment, attribute), iterable) + return sum(iterable, start) + + def do_list(value): """Convert the value into a list. If it was a string the returned list will be a list of characters. @@ -625,14 +749,14 @@ def do_mark_safe(value): def do_mark_unsafe(value): """Mark a value as unsafe. This is the reverse operation for :func:`safe`.""" - return unicode(value) + return text_type(value) def do_reverse(value): """Reverse the object or return an iterator the iterates over it the other way round. """ - if isinstance(value, basestring): + if isinstance(value, string_types): return value[::-1] try: return reversed(value) @@ -670,6 +794,144 @@ def do_attr(environment, obj, name): return environment.undefined(obj=obj, name=name) +@contextfilter +def do_map(*args, **kwargs): + """Applies a filter on a sequence of objects or looks up an attribute. + This is useful when dealing with lists of objects but you are really + only interested in a certain value of it. + + The basic usage is mapping on an attribute. Imagine you have a list + of users but you are only interested in a list of usernames: + + .. sourcecode:: jinja + + Users on this page: {{ users|map(attribute='username')|join(', ') }} + + Alternatively you can let it invoke a filter by passing the name of the + filter and the arguments afterwards. A good example would be applying a + text conversion filter on a sequence: + + .. sourcecode:: jinja + + Users on this page: {{ titles|map('lower')|join(', ') }} + + .. versionadded:: 2.7 + """ + context = args[0] + seq = args[1] + + if len(args) == 2 and 'attribute' in kwargs: + attribute = kwargs.pop('attribute') + if kwargs: + raise FilterArgumentError('Unexpected keyword argument %r' % + next(iter(kwargs))) + func = make_attrgetter(context.environment, attribute) + else: + try: + name = args[2] + args = args[3:] + except LookupError: + raise FilterArgumentError('map requires a filter argument') + func = lambda item: context.environment.call_filter( + name, item, args, kwargs, context=context) + + if seq: + for item in seq: + yield func(item) + + +@contextfilter +def do_select(*args, **kwargs): + """Filters a sequence of objects by appying a test to either the object + or the attribute and only selecting the ones with the test succeeding. + + Example usage: + + .. sourcecode:: jinja + + {{ numbers|select("odd") }} + + .. versionadded:: 2.7 + """ + return _select_or_reject(args, kwargs, lambda x: x, False) + + +@contextfilter +def do_reject(*args, **kwargs): + """Filters a sequence of objects by appying a test to either the object + or the attribute and rejecting the ones with the test succeeding. + + Example usage: + + .. sourcecode:: jinja + + {{ numbers|reject("odd") }} + + .. versionadded:: 2.7 + """ + return _select_or_reject(args, kwargs, lambda x: not x, False) + + +@contextfilter +def do_selectattr(*args, **kwargs): + """Filters a sequence of objects by appying a test to either the object + or the attribute and only selecting the ones with the test succeeding. + + Example usage: + + .. sourcecode:: jinja + + {{ users|selectattr("is_active") }} + {{ users|selectattr("email", "none") }} + + .. versionadded:: 2.7 + """ + return _select_or_reject(args, kwargs, lambda x: x, True) + + +@contextfilter +def do_rejectattr(*args, **kwargs): + """Filters a sequence of objects by appying a test to either the object + or the attribute and rejecting the ones with the test succeeding. + + .. sourcecode:: jinja + + {{ users|rejectattr("is_active") }} + {{ users|rejectattr("email", "none") }} + + .. versionadded:: 2.7 + """ + return _select_or_reject(args, kwargs, lambda x: not x, True) + + +def _select_or_reject(args, kwargs, modfunc, lookup_attr): + context = args[0] + seq = args[1] + if lookup_attr: + try: + attr = args[2] + except LookupError: + raise FilterArgumentError('Missing parameter for attribute name') + transfunc = make_attrgetter(context.environment, attr) + off = 1 + else: + off = 0 + transfunc = lambda x: x + + try: + name = args[2 + off] + args = args[3 + off:] + func = lambda item: context.environment.call_test( + name, item, args, kwargs) + except LookupError: + func = bool + + if seq: + for item in seq: + if modfunc(func(transfunc(item))): + yield item + + FILTERS = { 'attr': do_attr, 'replace': do_replace, @@ -694,7 +956,10 @@ FILTERS = { 'capitalize': do_capitalize, 'first': do_first, 'last': do_last, + 'map': do_map, 'random': do_random, + 'reject': do_reject, + 'rejectattr': do_rejectattr, 'filesizeformat': do_filesizeformat, 'pprint': do_pprint, 'truncate': do_truncate, @@ -708,12 +973,15 @@ FILTERS = { 'format': do_format, 'trim': do_trim, 'striptags': do_striptags, + 'select': do_select, + 'selectattr': do_selectattr, 'slice': do_slice, 'batch': do_batch, - 'sum': sum, + 'sum': do_sum, 'abs': abs, 'round': do_round, 'groupby': do_groupby, 'safe': do_mark_safe, - 'xmlattr': do_xmlattr + 'xmlattr': do_xmlattr, + 'urlencode': do_urlencode } diff --git a/module/lib/jinja2/lexer.py b/module/lib/jinja2/lexer.py index 0d3f69617..a50128507 100644 --- a/module/lib/jinja2/lexer.py +++ b/module/lib/jinja2/lexer.py @@ -15,10 +15,13 @@ :license: BSD, see LICENSE for more details. """ import re + from operator import itemgetter from collections import deque from jinja2.exceptions import TemplateSyntaxError -from jinja2.utils import LRUCache, next +from jinja2.utils import LRUCache +from jinja2._compat import next, iteritems, implements_iterator, text_type, \ + intern # cache for the lexers. Exists in order to be able to have multiple @@ -126,7 +129,7 @@ operators = { ';': TOKEN_SEMICOLON } -reverse_operators = dict([(v, k) for k, v in operators.iteritems()]) +reverse_operators = dict([(v, k) for k, v in iteritems(operators)]) assert len(operators) == len(reverse_operators), 'operators dropped' operator_re = re.compile('(%s)' % '|'.join(re.escape(x) for x in sorted(operators, key=lambda x: -len(x)))) @@ -197,7 +200,7 @@ def compile_rules(environment): if environment.line_statement_prefix is not None: rules.append((len(environment.line_statement_prefix), 'linestatement', - r'^\s*' + e(environment.line_statement_prefix))) + r'^[ \t\v]*' + e(environment.line_statement_prefix))) if environment.line_comment_prefix is not None: rules.append((len(environment.line_comment_prefix), 'linecomment', r'(?:^|(?<=\S))[^\S\r\n]*' + @@ -262,6 +265,7 @@ class Token(tuple): ) +@implements_iterator class TokenStreamIterator(object): """The iterator for tokenstreams. Iterate over the stream until the eof token is reached. @@ -273,7 +277,7 @@ class TokenStreamIterator(object): def __iter__(self): return self - def next(self): + def __next__(self): token = self.stream.current if token.type is TOKEN_EOF: self.stream.close() @@ -282,6 +286,7 @@ class TokenStreamIterator(object): return token +@implements_iterator class TokenStream(object): """A token stream is an iterable that yields :class:`Token`\s. The parser however does not iterate over it but calls :meth:`next` to go @@ -289,7 +294,7 @@ class TokenStream(object): """ def __init__(self, generator, name, filename): - self._next = iter(generator).next + self._iter = iter(generator) self._pushed = deque() self.name = name self.filename = filename @@ -300,8 +305,9 @@ class TokenStream(object): def __iter__(self): return TokenStreamIterator(self) - def __nonzero__(self): + def __bool__(self): return bool(self._pushed) or self.current.type is not TOKEN_EOF + __nonzero__ = __bool__ # py2 eos = property(lambda x: not x, doc="Are we at the end of the stream?") @@ -319,7 +325,7 @@ class TokenStream(object): def skip(self, n=1): """Got n tokens ahead.""" - for x in xrange(n): + for x in range(n): next(self) def next_if(self, expr): @@ -333,14 +339,14 @@ class TokenStream(object): """Like :meth:`next_if` but only returns `True` or `False`.""" return self.next_if(expr) is not None - def next(self): + def __next__(self): """Go one token ahead and return the old one""" rv = self.current if self._pushed: self.current = self._pushed.popleft() elif self.current.type is not TOKEN_EOF: try: - self.current = self._next() + self.current = next(self._iter) except StopIteration: self.close() return rv @@ -348,7 +354,7 @@ class TokenStream(object): def close(self): """Close the stream.""" self.current = Token(self.current.lineno, TOKEN_EOF, '') - self._next = None + self._iter = None self.closed = True def expect(self, expr): @@ -383,7 +389,9 @@ def get_lexer(environment): environment.line_statement_prefix, environment.line_comment_prefix, environment.trim_blocks, - environment.newline_sequence) + environment.lstrip_blocks, + environment.newline_sequence, + environment.keep_trailing_newline) lexer = _lexer_cache.get(key) if lexer is None: lexer = Lexer(environment) @@ -414,7 +422,7 @@ class Lexer(object): (operator_re, TOKEN_OPERATOR, None) ] - # assamble the root lexing rule. because "|" is ungreedy + # assemble the root lexing rule. because "|" is ungreedy # we have to sort by length so that the lexer continues working # as expected when we have parsing rules like <% for block and # <%= for variables. (if someone wants asp like syntax) @@ -425,7 +433,44 @@ class Lexer(object): # block suffix if trimming is enabled block_suffix_re = environment.trim_blocks and '\\n?' or '' + # strip leading spaces if lstrip_blocks is enabled + prefix_re = {} + if environment.lstrip_blocks: + # use '{%+' to manually disable lstrip_blocks behavior + no_lstrip_re = e('+') + # detect overlap between block and variable or comment strings + block_diff = c(r'^%s(.*)' % e(environment.block_start_string)) + # make sure we don't mistake a block for a variable or a comment + m = block_diff.match(environment.comment_start_string) + no_lstrip_re += m and r'|%s' % e(m.group(1)) or '' + m = block_diff.match(environment.variable_start_string) + no_lstrip_re += m and r'|%s' % e(m.group(1)) or '' + + # detect overlap between comment and variable strings + comment_diff = c(r'^%s(.*)' % e(environment.comment_start_string)) + m = comment_diff.match(environment.variable_start_string) + no_variable_re = m and r'(?!%s)' % e(m.group(1)) or '' + + lstrip_re = r'^[ \t]*' + block_prefix_re = r'%s%s(?!%s)|%s\+?' % ( + lstrip_re, + e(environment.block_start_string), + no_lstrip_re, + e(environment.block_start_string), + ) + comment_prefix_re = r'%s%s%s|%s\+?' % ( + lstrip_re, + e(environment.comment_start_string), + no_variable_re, + e(environment.comment_start_string), + ) + prefix_re['block'] = block_prefix_re + prefix_re['comment'] = comment_prefix_re + else: + block_prefix_re = '%s' % e(environment.block_start_string) + self.newline_sequence = environment.newline_sequence + self.keep_trailing_newline = environment.keep_trailing_newline # global lexing rules self.rules = { @@ -434,11 +479,11 @@ class Lexer(object): (c('(.*?)(?:%s)' % '|'.join( [r'(?P<raw_begin>(?:\s*%s\-|%s)\s*raw\s*(?:\-%s\s*|%s))' % ( e(environment.block_start_string), - e(environment.block_start_string), + block_prefix_re, e(environment.block_end_string), e(environment.block_end_string) )] + [ - r'(?P<%s_begin>\s*%s\-|%s)' % (n, r, r) + r'(?P<%s_begin>\s*%s\-|%s)' % (n, r, prefix_re.get(n,r)) for n, r in root_tag_rules ])), (TOKEN_DATA, '#bygroup'), '#bygroup'), # data @@ -472,7 +517,7 @@ class Lexer(object): TOKEN_RAW_BEGIN: [ (c('(.*?)((?:\s*%s\-|%s)\s*endraw\s*(?:\-%s\s*|%s%s))' % ( e(environment.block_start_string), - e(environment.block_start_string), + block_prefix_re, e(environment.block_end_string), e(environment.block_end_string), block_suffix_re @@ -491,7 +536,7 @@ class Lexer(object): } def _normalize_newlines(self, value): - """Called for strings and template data to normlize it to unicode.""" + """Called for strings and template data to normalize it to unicode.""" return newline_re.sub(self.newline_sequence, value) def tokenize(self, source, name=None, filename=None, state=None): @@ -526,7 +571,7 @@ class Lexer(object): value = self._normalize_newlines(value[1:-1]) \ .encode('ascii', 'backslashreplace') \ .decode('unicode-escape') - except Exception, e: + except Exception as e: msg = str(e).split(':')[-1].strip() raise TemplateSyntaxError(msg, lineno, name, filename) # if we can express it as bytestring (ascii only) @@ -549,7 +594,14 @@ class Lexer(object): """This method tokenizes the text and returns the tokens in a generator. Use this method if you just want to tokenize a template. """ - source = '\n'.join(unicode(source).splitlines()) + source = text_type(source) + lines = source.splitlines() + if self.keep_trailing_newline and source: + for newline in ('\r\n', '\r', '\n'): + if source.endswith(newline): + lines.append('') + break + source = '\n'.join(lines) pos = 0 lineno = 1 stack = ['root'] @@ -571,7 +623,7 @@ class Lexer(object): if m is None: continue - # we only match blocks and variables if brances / parentheses + # we only match blocks and variables if braces / parentheses # are balanced. continue parsing with the lower rule which # is the operator rule. do this only if the end tags look # like operators @@ -590,7 +642,7 @@ class Lexer(object): # yield for the current token the first named # group that matched elif token == '#bygroup': - for key, value in m.groupdict().iteritems(): + for key, value in iteritems(m.groupdict()): if value is not None: yield lineno, key, value lineno += value.count('\n') @@ -647,7 +699,7 @@ class Lexer(object): stack.pop() # resolve the new state by group checking elif new_state == '#bygroup': - for key, value in m.groupdict().iteritems(): + for key, value in iteritems(m.groupdict()): if value is not None: stack.append(key) break @@ -669,7 +721,7 @@ class Lexer(object): # publish new function and start again pos = pos2 break - # if loop terminated without break we havn't found a single match + # if loop terminated without break we haven't found a single match # either we are at the end of the file or we have a problem else: # end of text diff --git a/module/lib/jinja2/loaders.py b/module/lib/jinja2/loaders.py index bd435e8b0..cc9c6836e 100644 --- a/module/lib/jinja2/loaders.py +++ b/module/lib/jinja2/loaders.py @@ -13,12 +13,10 @@ import sys import weakref from types import ModuleType from os import path -try: - from hashlib import sha1 -except ImportError: - from sha import new as sha1 +from hashlib import sha1 from jinja2.exceptions import TemplateNotFound -from jinja2.utils import LRUCache, open_if_exists, internalcode +from jinja2.utils import open_if_exists, internalcode +from jinja2._compat import string_types, iteritems def split_template_path(template): @@ -153,7 +151,7 @@ class FileSystemLoader(BaseLoader): """ def __init__(self, searchpath, encoding='utf-8'): - if isinstance(searchpath, basestring): + if isinstance(searchpath, string_types): searchpath = [searchpath] self.searchpath = list(searchpath) self.encoding = encoding @@ -251,8 +249,7 @@ class PackageLoader(BaseLoader): for filename in self.provider.resource_listdir(path): fullname = path + '/' + filename if self.provider.resource_isdir(fullname): - for item in _walk(fullname): - results.append(item) + _walk(fullname) else: results.append(fullname[offset:].lstrip('/')) _walk(path) @@ -275,7 +272,7 @@ class DictLoader(BaseLoader): def get_source(self, environment, template): if template in self.mapping: source = self.mapping[template] - return source, None, lambda: source != self.mapping.get(template) + return source, None, lambda: source == self.mapping.get(template) raise TemplateNotFound(template) def list_templates(self): @@ -307,7 +304,7 @@ class FunctionLoader(BaseLoader): rv = self.load_func(template) if rv is None: raise TemplateNotFound(template) - elif isinstance(rv, basestring): + elif isinstance(rv, string_types): return rv, None, None return rv @@ -331,12 +328,16 @@ class PrefixLoader(BaseLoader): self.mapping = mapping self.delimiter = delimiter - def get_source(self, environment, template): + def get_loader(self, template): try: prefix, name = template.split(self.delimiter, 1) loader = self.mapping[prefix] except (ValueError, KeyError): raise TemplateNotFound(template) + return loader, name + + def get_source(self, environment, template): + loader, name = self.get_loader(template) try: return loader.get_source(environment, name) except TemplateNotFound: @@ -344,9 +345,19 @@ class PrefixLoader(BaseLoader): # (the one that includes the prefix) raise TemplateNotFound(template) + @internalcode + def load(self, environment, name, globals=None): + loader, local_name = self.get_loader(name) + try: + return loader.load(environment, local_name, globals) + except TemplateNotFound: + # re-raise the exception with the correct fileame here. + # (the one that includes the prefix) + raise TemplateNotFound(name) + def list_templates(self): result = [] - for prefix, loader in self.mapping.iteritems(): + for prefix, loader in iteritems(self.mapping): for template in loader.list_templates(): result.append(prefix + self.delimiter + template) return result @@ -377,6 +388,15 @@ class ChoiceLoader(BaseLoader): pass raise TemplateNotFound(template) + @internalcode + def load(self, environment, name, globals=None): + for loader in self.loaders: + try: + return loader.load(environment, name, globals) + except TemplateNotFound: + pass + raise TemplateNotFound(name) + def list_templates(self): found = set() for loader in self.loaders: @@ -397,6 +417,8 @@ class ModuleLoader(BaseLoader): ... ModuleLoader('/path/to/compiled/templates'), ... FileSystemLoader('/path/to/templates') ... ]) + + Templates can be precompiled with :meth:`Environment.compile_templates`. """ has_source_access = False @@ -407,7 +429,7 @@ class ModuleLoader(BaseLoader): # create a fake module that looks for the templates in the # path given. mod = _TemplateModule(package_name) - if isinstance(path, basestring): + if isinstance(path, string_types): path = [path] else: path = list(path) diff --git a/module/lib/jinja2/meta.py b/module/lib/jinja2/meta.py index 3a779a5e9..3110cff60 100644 --- a/module/lib/jinja2/meta.py +++ b/module/lib/jinja2/meta.py @@ -11,6 +11,7 @@ """ from jinja2 import nodes from jinja2.compiler import CodeGenerator +from jinja2._compat import string_types class TrackingCodeGenerator(CodeGenerator): @@ -77,7 +78,7 @@ def find_referenced_templates(ast): # something const, only yield the strings and ignore # non-string consts that really just make no sense if isinstance(template_name, nodes.Const): - if isinstance(template_name.value, basestring): + if isinstance(template_name.value, string_types): yield template_name.value # something dynamic in there else: @@ -87,7 +88,7 @@ def find_referenced_templates(ast): yield None continue # constant is a basestring, direct template name - if isinstance(node.template.value, basestring): + if isinstance(node.template.value, string_types): yield node.template.value # a tuple or list (latter *should* not happen) made of consts, # yield the consts that are strings. We could warn here for @@ -95,7 +96,7 @@ def find_referenced_templates(ast): elif isinstance(node, nodes.Include) and \ isinstance(node.template.value, (tuple, list)): for template_name in node.template.value: - if isinstance(template_name, basestring): + if isinstance(template_name, string_types): yield template_name # something else we don't care about, we could warn here else: diff --git a/module/lib/jinja2/nodes.py b/module/lib/jinja2/nodes.py index 6446c70ea..c5697e6b5 100644 --- a/module/lib/jinja2/nodes.py +++ b/module/lib/jinja2/nodes.py @@ -13,13 +13,15 @@ :license: BSD, see LICENSE for more details. """ import operator -from itertools import chain, izip + from collections import deque -from jinja2.utils import Markup, MethodType, FunctionType +from jinja2.utils import Markup +from jinja2._compat import next, izip, with_metaclass, text_type, \ + method_type, function_type #: the types we support for context functions -_context_function_types = (FunctionType, MethodType) +_context_function_types = (function_type, method_type) _binop_to_func = { @@ -77,6 +79,7 @@ class EvalContext(object): """ def __init__(self, environment, template_name=None): + self.environment = environment if callable(environment.autoescape): self.autoescape = environment.autoescape(template_name) else: @@ -101,9 +104,9 @@ def get_eval_context(node, ctx): return ctx -class Node(object): +class Node(with_metaclass(NodeType, object)): """Baseclass for all Jinja2 nodes. There are a number of nodes available - of different types. There are three major types: + of different types. There are four major types: - :class:`Stmt`: statements - :class:`Expr`: expressions @@ -117,7 +120,6 @@ class Node(object): The `environment` attribute is set at the end of the parsing process for all nodes automatically. """ - __metaclass__ = NodeType fields = () attributes = ('lineno', 'environment') abstract = True @@ -141,7 +143,7 @@ class Node(object): setattr(self, attr, attributes.pop(attr, None)) if attributes: raise TypeError('unknown attribute %r' % - iter(attributes).next()) + next(iter(attributes))) def iter_fields(self, exclude=None, only=None): """This method iterates over all fields that are defined and yields @@ -230,6 +232,9 @@ class Node(object): def __ne__(self, other): return not self.__eq__(other) + # Restore Python 2 hashing behavior on Python 3 + __hash__ = object.__hash__ + def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, @@ -372,10 +377,14 @@ class BinExpr(Expr): def as_const(self, eval_ctx=None): eval_ctx = get_eval_context(self, eval_ctx) + # intercepted operators cannot be folded at compile time + if self.environment.sandboxed and \ + self.operator in self.environment.intercepted_binops: + raise Impossible() f = _binop_to_func[self.operator] try: return f(self.left.as_const(eval_ctx), self.right.as_const(eval_ctx)) - except: + except Exception: raise Impossible() @@ -387,10 +396,14 @@ class UnaryExpr(Expr): def as_const(self, eval_ctx=None): eval_ctx = get_eval_context(self, eval_ctx) + # intercepted operators cannot be folded at compile time + if self.environment.sandboxed and \ + self.operator in self.environment.intercepted_unops: + raise Impossible() f = _uaop_to_func[self.operator] try: return f(self.node.as_const(eval_ctx)) - except: + except Exception: raise Impossible() @@ -431,7 +444,7 @@ class Const(Literal): constant value in the generated code, otherwise it will raise an `Impossible` exception. """ - from compiler import has_safe_repr + from .compiler import has_safe_repr if not has_safe_repr(value): raise Impossible() return cls(value, lineno=lineno, environment=environment) @@ -555,16 +568,16 @@ class Filter(Expr): if self.dyn_args is not None: try: args.extend(self.dyn_args.as_const(eval_ctx)) - except: + except Exception: raise Impossible() if self.dyn_kwargs is not None: try: kwargs.update(self.dyn_kwargs.as_const(eval_ctx)) - except: + except Exception: raise Impossible() try: return filter_(obj, *args, **kwargs) - except: + except Exception: raise Impossible() @@ -604,16 +617,16 @@ class Call(Expr): if self.dyn_args is not None: try: args.extend(self.dyn_args.as_const(eval_ctx)) - except: + except Exception: raise Impossible() if self.dyn_kwargs is not None: try: kwargs.update(self.dyn_kwargs.as_const(eval_ctx)) - except: + except Exception: raise Impossible() try: return obj(*args, **kwargs) - except: + except Exception: raise Impossible() @@ -628,7 +641,7 @@ class Getitem(Expr): try: return self.environment.getitem(self.node.as_const(eval_ctx), self.arg.as_const(eval_ctx)) - except: + except Exception: raise Impossible() def can_assign(self): @@ -648,7 +661,7 @@ class Getattr(Expr): eval_ctx = get_eval_context(self, eval_ctx) return self.environment.getattr(self.node.as_const(eval_ctx), self.attr) - except: + except Exception: raise Impossible() def can_assign(self): @@ -678,7 +691,7 @@ class Concat(Expr): def as_const(self, eval_ctx=None): eval_ctx = get_eval_context(self, eval_ctx) - return ''.join(unicode(x.as_const(eval_ctx)) for x in self.nodes) + return ''.join(text_type(x.as_const(eval_ctx)) for x in self.nodes) class Compare(Expr): @@ -695,7 +708,7 @@ class Compare(Expr): new_value = op.expr.as_const(eval_ctx) result = _cmpop_to_func[op.op](value, new_value) value = new_value - except: + except Exception: raise Impossible() return result diff --git a/module/lib/jinja2/parser.py b/module/lib/jinja2/parser.py index d44229ad0..f60cd018c 100644 --- a/module/lib/jinja2/parser.py +++ b/module/lib/jinja2/parser.py @@ -10,8 +10,8 @@ """ from jinja2 import nodes from jinja2.exceptions import TemplateSyntaxError, TemplateAssertionError -from jinja2.utils import next from jinja2.lexer import describe_token, describe_token_expr +from jinja2._compat import next, imap #: statements that callinto @@ -53,7 +53,7 @@ class Parser(object): def _fail_ut_eof(self, name, end_token_stack, lineno): expected = [] for exprs in end_token_stack: - expected.extend(map(describe_token_expr, exprs)) + expected.extend(imap(describe_token_expr, exprs)) if end_token_stack: currently_looking = ' or '.join( "'%s'" % describe_token_expr(expr) @@ -223,7 +223,7 @@ class Parser(object): # raise a nicer error message in that case. if self.stream.current.type == 'sub': self.fail('Block names in Jinja have to be valid Python ' - 'identifiers and may not contain hypens, use an ' + 'identifiers and may not contain hyphens, use an ' 'underscore instead.') node.body = self.parse_statements(('name:endblock',), drop_needle=True) @@ -698,7 +698,6 @@ class Parser(object): arg = nodes.Const(attr_token.value, lineno=attr_token.lineno) return nodes.Getitem(node, arg, 'load', lineno=token.lineno) if token.type == 'lbracket': - priority_on_attribute = False args = [] while self.stream.current.type != 'rbracket': if args: diff --git a/module/lib/jinja2/runtime.py b/module/lib/jinja2/runtime.py index 6fea3aa4f..7791c645a 100644 --- a/module/lib/jinja2/runtime.py +++ b/module/lib/jinja2/runtime.py @@ -8,13 +8,14 @@ :copyright: (c) 2010 by the Jinja Team. :license: BSD. """ -import sys -from itertools import chain, imap +from itertools import chain from jinja2.nodes import EvalContext, _context_function_types -from jinja2.utils import Markup, partial, soft_unicode, escape, missing, \ - concat, internalcode, next, object_type_repr +from jinja2.utils import Markup, soft_unicode, escape, missing, concat, \ + internalcode, object_type_repr from jinja2.exceptions import UndefinedError, TemplateRuntimeError, \ TemplateNotFound +from jinja2._compat import next, imap, text_type, iteritems, \ + implements_iterator, implements_to_string, string_types, PY2 # these variables are exported to the template runtime @@ -24,13 +25,14 @@ __all__ = ['LoopContext', 'TemplateReference', 'Macro', 'Markup', 'TemplateNotFound'] #: the name of the function that is used to convert something into -#: a string. 2to3 will adopt that automatically and the generated -#: code can take advantage of it. -to_string = unicode +#: a string. We can just use the text type here. +to_string = text_type #: the identity function. Useful for certain things in the environment identity = lambda x: x +_last_iteration = object() + def markup_join(seq): """Concatenation that escapes if necessary and converts to unicode.""" @@ -45,7 +47,7 @@ def markup_join(seq): def unicode_join(seq): """Simple args to unicode conversion and concatenation.""" - return concat(imap(unicode, seq)) + return concat(imap(text_type, seq)) def new_context(environment, template_name, blocks, vars=None, @@ -62,7 +64,7 @@ def new_context(environment, template_name, blocks, vars=None, # we don't want to modify the dict passed if shared: parent = dict(parent) - for key, value in locals.iteritems(): + for key, value in iteritems(locals): if key[:2] == 'l_' and value is not missing: parent[key[2:]] = value return Context(environment, parent, template_name, blocks) @@ -76,8 +78,6 @@ class TemplateReference(object): def __getitem__(self, name): blocks = self.__context.blocks[name] - wrap = self.__context.eval_ctx.autoescape and \ - Markup or (lambda x: x) return BlockReference(name, self.__context, blocks, 0) def __repr__(self): @@ -120,7 +120,7 @@ class Context(object): # create the initial mapping of blocks. Whenever template inheritance # takes place the runtime will update this mapping with the new blocks # from the template. - self.blocks = dict((k, [v]) for k, v in blocks.iteritems()) + self.blocks = dict((k, [v]) for k, v in iteritems(blocks)) def super(self, name, current): """Render a parent block.""" @@ -172,6 +172,16 @@ class Context(object): """ if __debug__: __traceback_hide__ = True + + # Allow callable classes to take a context + fn = __obj.__call__ + for fn_type in ('contextfunction', + 'evalcontextfunction', + 'environmentfunction'): + if hasattr(fn, fn_type): + __obj = fn + break + if isinstance(__obj, _context_function_types): if getattr(__obj, 'contextfunction', 0): args = (__self,) + args @@ -190,8 +200,9 @@ class Context(object): """Internal helper function to create a derived context.""" context = new_context(self.environment, self.name, {}, self.parent, True, None, locals) + context.vars.update(self.vars) context.eval_ctx = self.eval_ctx - context.blocks.update((k, list(v)) for k, v in self.blocks.iteritems()) + context.blocks.update((k, list(v)) for k, v in iteritems(self.blocks)) return context def _all(meth): @@ -205,7 +216,7 @@ class Context(object): items = _all('items') # not available on python 3 - if hasattr(dict, 'iterkeys'): + if PY2: iterkeys = _all('iterkeys') itervalues = _all('itervalues') iteritems = _all('iteritems') @@ -269,10 +280,12 @@ class BlockReference(object): class LoopContext(object): """A loop context for dynamic iteration.""" - def __init__(self, iterable, recurse=None): + def __init__(self, iterable, recurse=None, depth0=0): self._iterator = iter(iterable) self._recurse = recurse + self._after = self._safe_next() self.index0 = -1 + self.depth0 = depth0 # try to get the length of the iterable early. This must be done # here because there are some broken iterators around where there @@ -290,10 +303,11 @@ class LoopContext(object): return args[self.index0 % len(args)] first = property(lambda x: x.index0 == 0) - last = property(lambda x: x.index0 + 1 == x.length) + last = property(lambda x: x._after is _last_iteration) index = property(lambda x: x.index0 + 1) revindex = property(lambda x: x.length - x.index0) revindex0 = property(lambda x: x.length - x.index) + depth = property(lambda x: x.depth0 + 1) def __len__(self): return self.length @@ -301,12 +315,18 @@ class LoopContext(object): def __iter__(self): return LoopContextIterator(self) + def _safe_next(self): + try: + return next(self._iterator) + except StopIteration: + return _last_iteration + @internalcode def loop(self, iterable): if self._recurse is None: raise TypeError('Tried to call non recursive loop. Maybe you ' "forgot the 'recursive' modifier.") - return self._recurse(iterable, self._recurse) + return self._recurse(iterable, self._recurse, self.depth0 + 1) # a nifty trick to enhance the error message if someone tried to call # the the loop without or with too many arguments. @@ -333,6 +353,7 @@ class LoopContext(object): ) +@implements_iterator class LoopContextIterator(object): """The iterator for a loop context.""" __slots__ = ('context',) @@ -343,10 +364,14 @@ class LoopContextIterator(object): def __iter__(self): return self - def next(self): + def __next__(self): ctx = self.context ctx.index0 += 1 - return next(ctx._iterator), ctx + if ctx._after is _last_iteration: + raise StopIteration() + next_elem = ctx._after + ctx._after = ctx._safe_next() + return next_elem, ctx class Macro(object): @@ -413,6 +438,7 @@ class Macro(object): ) +@implements_to_string class Undefined(object): """The default undefined type. This undefined type can be printed and iterated over, but every other access will raise an :exc:`UndefinedError`: @@ -444,7 +470,7 @@ class Undefined(object): if self._undefined_hint is None: if self._undefined_obj is missing: hint = '%r is undefined' % self._undefined_name - elif not isinstance(self._undefined_name, basestring): + elif not isinstance(self._undefined_name, string_types): hint = '%s has no element %r' % ( object_type_repr(self._undefined_obj), self._undefined_name @@ -458,21 +484,29 @@ class Undefined(object): hint = self._undefined_hint raise self._undefined_exception(hint) + @internalcode + def __getattr__(self, name): + if name[:2] == '__': + raise AttributeError(name) + return self._fail_with_undefined_error() + __add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = \ __truediv__ = __rtruediv__ = __floordiv__ = __rfloordiv__ = \ __mod__ = __rmod__ = __pos__ = __neg__ = __call__ = \ - __getattr__ = __getitem__ = __lt__ = __le__ = __gt__ = __ge__ = \ - __int__ = __float__ = __complex__ = __pow__ = __rpow__ = \ + __getitem__ = __lt__ = __le__ = __gt__ = __ge__ = __int__ = \ + __float__ = __complex__ = __pow__ = __rpow__ = \ _fail_with_undefined_error - def __str__(self): - return unicode(self).encode('utf-8') + def __eq__(self, other): + return type(self) is type(other) + + def __ne__(self, other): + return not self.__eq__(other) - # unicode goes after __str__ because we configured 2to3 to rename - # __unicode__ to __str__. because the 2to3 tree is not designed to - # remove nodes from it, we leave the above __str__ around and let - # it override at runtime. - def __unicode__(self): + def __hash__(self): + return id(type(self)) + + def __str__(self): return u'' def __len__(self): @@ -489,6 +523,7 @@ class Undefined(object): return 'Undefined' +@implements_to_string class DebugUndefined(Undefined): """An undefined that returns the debug info when printed. @@ -504,7 +539,7 @@ class DebugUndefined(Undefined): """ __slots__ = () - def __unicode__(self): + def __str__(self): if self._undefined_hint is None: if self._undefined_obj is missing: return u'{{ %s }}' % self._undefined_name @@ -515,6 +550,7 @@ class DebugUndefined(Undefined): return u'{{ undefined value printed: %s }}' % self._undefined_hint +@implements_to_string class StrictUndefined(Undefined): """An undefined that barks on print and iteration as well as boolean tests and all kinds of comparisons. In other words: you can do nothing @@ -535,8 +571,9 @@ class StrictUndefined(Undefined): UndefinedError: 'foo' is undefined """ __slots__ = () - __iter__ = __unicode__ = __str__ = __len__ = __nonzero__ = __eq__ = \ - __ne__ = __bool__ = Undefined._fail_with_undefined_error + __iter__ = __str__ = __len__ = __nonzero__ = __eq__ = \ + __ne__ = __bool__ = __hash__ = \ + Undefined._fail_with_undefined_error # remove remaining slots attributes, after the metaclass did the magic they diff --git a/module/lib/jinja2/sandbox.py b/module/lib/jinja2/sandbox.py index 749719548..da479c1ba 100644 --- a/module/lib/jinja2/sandbox.py +++ b/module/lib/jinja2/sandbox.py @@ -13,11 +13,10 @@ :license: BSD. """ import operator -from jinja2.runtime import Undefined from jinja2.environment import Environment from jinja2.exceptions import SecurityError -from jinja2.utils import FunctionType, MethodType, TracebackType, CodeType, \ - FrameType, GeneratorType +from jinja2._compat import string_types, function_type, method_type, \ + traceback_type, code_type, frame_type, generator_type, PY2 #: maximum number of items a range may produce @@ -30,6 +29,13 @@ UNSAFE_FUNCTION_ATTRIBUTES = set(['func_closure', 'func_code', 'func_dict', #: unsafe method attributes. function attributes are unsafe for methods too UNSAFE_METHOD_ATTRIBUTES = set(['im_class', 'im_func', 'im_self']) +#: unsafe generator attirbutes. +UNSAFE_GENERATOR_ATTRIBUTES = set(['gi_frame', 'gi_code']) + +# On versions > python 2 the special attributes on functions are gone, +# but they remain on methods and generators for whatever reason. +if not PY2: + UNSAFE_FUNCTION_ATTRIBUTES = set() import warnings @@ -91,7 +97,7 @@ def safe_range(*args): """A range that can't generate ranges with a length of more than MAX_RANGE items. """ - rng = xrange(*args) + rng = range(*args) if len(rng) > MAX_RANGE: raise OverflowError('range too big, maximum size for range is %d' % MAX_RANGE) @@ -99,8 +105,9 @@ def safe_range(*args): def unsafe(f): - """ - Mark a function or method as unsafe:: + """Marks a function or method as unsafe. + + :: @unsafe def delete(self): @@ -114,7 +121,7 @@ def is_internal_attribute(obj, attr): """Test if the attribute given is an internal python attribute. For example this function returns `True` for the `func_code` attribute of python objects. This is useful if the environment method - :meth:`~SandboxedEnvironment.is_safe_attribute` is overriden. + :meth:`~SandboxedEnvironment.is_safe_attribute` is overridden. >>> from jinja2.sandbox import is_internal_attribute >>> is_internal_attribute(lambda: None, "func_code") @@ -124,20 +131,20 @@ def is_internal_attribute(obj, attr): >>> is_internal_attribute(str, "upper") False """ - if isinstance(obj, FunctionType): + if isinstance(obj, function_type): if attr in UNSAFE_FUNCTION_ATTRIBUTES: return True - elif isinstance(obj, MethodType): + elif isinstance(obj, method_type): if attr in UNSAFE_FUNCTION_ATTRIBUTES or \ attr in UNSAFE_METHOD_ATTRIBUTES: return True elif isinstance(obj, type): if attr == 'mro': return True - elif isinstance(obj, (CodeType, TracebackType, FrameType)): + elif isinstance(obj, (code_type, traceback_type, frame_type)): return True - elif isinstance(obj, GeneratorType): - if attr == 'gi_frame': + elif isinstance(obj, generator_type): + if attr in UNSAFE_GENERATOR_ATTRIBUTES: return True return attr.startswith('__') @@ -182,9 +189,81 @@ class SandboxedEnvironment(Environment): """ sandboxed = True + #: default callback table for the binary operators. A copy of this is + #: available on each instance of a sandboxed environment as + #: :attr:`binop_table` + default_binop_table = { + '+': operator.add, + '-': operator.sub, + '*': operator.mul, + '/': operator.truediv, + '//': operator.floordiv, + '**': operator.pow, + '%': operator.mod + } + + #: default callback table for the unary operators. A copy of this is + #: available on each instance of a sandboxed environment as + #: :attr:`unop_table` + default_unop_table = { + '+': operator.pos, + '-': operator.neg + } + + #: a set of binary operators that should be intercepted. Each operator + #: that is added to this set (empty by default) is delegated to the + #: :meth:`call_binop` method that will perform the operator. The default + #: operator callback is specified by :attr:`binop_table`. + #: + #: The following binary operators are interceptable: + #: ``//``, ``%``, ``+``, ``*``, ``-``, ``/``, and ``**`` + #: + #: The default operation form the operator table corresponds to the + #: builtin function. Intercepted calls are always slower than the native + #: operator call, so make sure only to intercept the ones you are + #: interested in. + #: + #: .. versionadded:: 2.6 + intercepted_binops = frozenset() + + #: a set of unary operators that should be intercepted. Each operator + #: that is added to this set (empty by default) is delegated to the + #: :meth:`call_unop` method that will perform the operator. The default + #: operator callback is specified by :attr:`unop_table`. + #: + #: The following unary operators are interceptable: ``+``, ``-`` + #: + #: The default operation form the operator table corresponds to the + #: builtin function. Intercepted calls are always slower than the native + #: operator call, so make sure only to intercept the ones you are + #: interested in. + #: + #: .. versionadded:: 2.6 + intercepted_unops = frozenset() + + def intercept_unop(self, operator): + """Called during template compilation with the name of a unary + operator to check if it should be intercepted at runtime. If this + method returns `True`, :meth:`call_unop` is excuted for this unary + operator. The default implementation of :meth:`call_unop` will use + the :attr:`unop_table` dictionary to perform the operator with the + same logic as the builtin one. + + The following unary operators are interceptable: ``+`` and ``-`` + + Intercepted calls are always slower than the native operator call, + so make sure only to intercept the ones you are interested in. + + .. versionadded:: 2.6 + """ + return False + + def __init__(self, *args, **kwargs): Environment.__init__(self, *args, **kwargs) self.globals['range'] = safe_range + self.binop_table = self.default_binop_table.copy() + self.unop_table = self.default_unop_table.copy() def is_safe_attribute(self, obj, attr, value): """The sandboxed environment will call this method to check if the @@ -201,18 +280,36 @@ class SandboxedEnvironment(Environment): True. Override this method to alter the behavior, but this won't affect the `unsafe` decorator from this module. """ - return not (getattr(obj, 'unsafe_callable', False) or \ + return not (getattr(obj, 'unsafe_callable', False) or getattr(obj, 'alters_data', False)) + def call_binop(self, context, operator, left, right): + """For intercepted binary operator calls (:meth:`intercepted_binops`) + this function is executed instead of the builtin operator. This can + be used to fine tune the behavior of certain operators. + + .. versionadded:: 2.6 + """ + return self.binop_table[operator](left, right) + + def call_unop(self, context, operator, arg): + """For intercepted unary operator calls (:meth:`intercepted_unops`) + this function is executed instead of the builtin operator. This can + be used to fine tune the behavior of certain operators. + + .. versionadded:: 2.6 + """ + return self.unop_table[operator](arg) + def getitem(self, obj, argument): """Subscribe an object from sandboxed code.""" try: return obj[argument] except (TypeError, LookupError): - if isinstance(argument, basestring): + if isinstance(argument, string_types): try: attr = str(argument) - except: + except Exception: pass else: try: diff --git a/module/lib/jinja2/tests.py b/module/lib/jinja2/tests.py index d257eca0a..48a3e0618 100644 --- a/module/lib/jinja2/tests.py +++ b/module/lib/jinja2/tests.py @@ -10,20 +10,14 @@ """ import re from jinja2.runtime import Undefined - -# nose, nothing here to test -__test__ = False +from jinja2._compat import text_type, string_types, mapping_types number_re = re.compile(r'^-?\d+(\.\d+)?$') regex_type = type(number_re) -try: - test_callable = callable -except NameError: - def test_callable(x): - return hasattr(x, '__call__') +test_callable = callable def test_odd(value): @@ -70,22 +64,30 @@ def test_none(value): def test_lower(value): """Return true if the variable is lowercased.""" - return unicode(value).islower() + return text_type(value).islower() def test_upper(value): """Return true if the variable is uppercased.""" - return unicode(value).isupper() + return text_type(value).isupper() def test_string(value): """Return true if the object is a string.""" - return isinstance(value, basestring) + return isinstance(value, string_types) + + +def test_mapping(value): + """Return true if the object is a mapping (dict etc.). + + .. versionadded:: 2.6 + """ + return isinstance(value, mapping_types) def test_number(value): """Return true if the variable is a number.""" - return isinstance(value, (int, long, float, complex)) + return isinstance(value, (int, float, complex)) def test_sequence(value): @@ -137,6 +139,7 @@ TESTS = { 'lower': test_lower, 'upper': test_upper, 'string': test_string, + 'mapping': test_mapping, 'number': test_number, 'sequence': test_sequence, 'iterable': test_iterable, diff --git a/module/lib/jinja2/testsuite/__init__.py b/module/lib/jinja2/testsuite/__init__.py new file mode 100644 index 000000000..635c83e5d --- /dev/null +++ b/module/lib/jinja2/testsuite/__init__.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite + ~~~~~~~~~~~~~~~~ + + All the unittests of Jinja2. These tests can be executed by + either running run-tests.py using multiple Python versions at + the same time. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import os +import re +import sys +import unittest +from traceback import format_exception +from jinja2 import loaders +from jinja2._compat import PY2 + + +here = os.path.dirname(os.path.abspath(__file__)) + +dict_loader = loaders.DictLoader({ + 'justdict.html': 'FOO' +}) +package_loader = loaders.PackageLoader('jinja2.testsuite.res', 'templates') +filesystem_loader = loaders.FileSystemLoader(here + '/res/templates') +function_loader = loaders.FunctionLoader({'justfunction.html': 'FOO'}.get) +choice_loader = loaders.ChoiceLoader([dict_loader, package_loader]) +prefix_loader = loaders.PrefixLoader({ + 'a': filesystem_loader, + 'b': dict_loader +}) + + +class JinjaTestCase(unittest.TestCase): + + ### use only these methods for testing. If you need standard + ### unittest method, wrap them! + + def setup(self): + pass + + def teardown(self): + pass + + def setUp(self): + self.setup() + + def tearDown(self): + self.teardown() + + def assert_equal(self, a, b): + return self.assertEqual(a, b) + + def assert_raises(self, *args, **kwargs): + return self.assertRaises(*args, **kwargs) + + def assert_traceback_matches(self, callback, expected_tb): + try: + callback() + except Exception as e: + tb = format_exception(*sys.exc_info()) + if re.search(expected_tb.strip(), ''.join(tb)) is None: + raise self.fail('Traceback did not match:\n\n%s\nexpected:\n%s' + % (''.join(tb), expected_tb)) + else: + self.fail('Expected exception') + + +def find_all_tests(suite): + """Yields all the tests and their names from a given suite.""" + suites = [suite] + while suites: + s = suites.pop() + try: + suites.extend(s) + except TypeError: + yield s, '%s.%s.%s' % ( + s.__class__.__module__, + s.__class__.__name__, + s._testMethodName + ) + + +class BetterLoader(unittest.TestLoader): + """A nicer loader that solves two problems. First of all we are setting + up tests from different sources and we're doing this programmatically + which breaks the default loading logic so this is required anyways. + Secondly this loader has a nicer interpolation for test names than the + default one so you can just do ``run-tests.py ViewTestCase`` and it + will work. + """ + + def getRootSuite(self): + return suite() + + def loadTestsFromName(self, name, module=None): + root = self.getRootSuite() + if name == 'suite': + return root + + all_tests = [] + for testcase, testname in find_all_tests(root): + if testname == name or \ + testname.endswith('.' + name) or \ + ('.' + name + '.') in testname or \ + testname.startswith(name + '.'): + all_tests.append(testcase) + + if not all_tests: + raise LookupError('could not find test case for "%s"' % name) + + if len(all_tests) == 1: + return all_tests[0] + rv = unittest.TestSuite() + for test in all_tests: + rv.addTest(test) + return rv + + +def suite(): + from jinja2.testsuite import ext, filters, tests, core_tags, \ + loader, inheritance, imports, lexnparse, security, api, \ + regression, debug, utils, bytecode_cache, doctests + suite = unittest.TestSuite() + suite.addTest(ext.suite()) + suite.addTest(filters.suite()) + suite.addTest(tests.suite()) + suite.addTest(core_tags.suite()) + suite.addTest(loader.suite()) + suite.addTest(inheritance.suite()) + suite.addTest(imports.suite()) + suite.addTest(lexnparse.suite()) + suite.addTest(security.suite()) + suite.addTest(api.suite()) + suite.addTest(regression.suite()) + suite.addTest(debug.suite()) + suite.addTest(utils.suite()) + suite.addTest(bytecode_cache.suite()) + + # doctests will not run on python 3 currently. Too many issues + # with that, do not test that on that platform. + if PY2: + suite.addTest(doctests.suite()) + + return suite + + +def main(): + """Runs the testsuite as command line application.""" + try: + unittest.main(testLoader=BetterLoader(), defaultTest='suite') + except Exception as e: + print('Error: %s' % e) diff --git a/module/lib/jinja2/testsuite/api.py b/module/lib/jinja2/testsuite/api.py new file mode 100644 index 000000000..1b68bf8b3 --- /dev/null +++ b/module/lib/jinja2/testsuite/api.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.api + ~~~~~~~~~~~~~~~~~~~~ + + Tests the public API and related stuff. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest +import os +import tempfile +import shutil + +from jinja2.testsuite import JinjaTestCase +from jinja2._compat import next + +from jinja2 import Environment, Undefined, DebugUndefined, \ + StrictUndefined, UndefinedError, meta, \ + is_undefined, Template, DictLoader +from jinja2.utils import Cycler + +env = Environment() + + +class ExtendedAPITestCase(JinjaTestCase): + + def test_item_and_attribute(self): + from jinja2.sandbox import SandboxedEnvironment + + for env in Environment(), SandboxedEnvironment(): + # the |list is necessary for python3 + tmpl = env.from_string('{{ foo.items()|list }}') + assert tmpl.render(foo={'items': 42}) == "[('items', 42)]" + tmpl = env.from_string('{{ foo|attr("items")()|list }}') + assert tmpl.render(foo={'items': 42}) == "[('items', 42)]" + tmpl = env.from_string('{{ foo["items"] }}') + assert tmpl.render(foo={'items': 42}) == '42' + + def test_finalizer(self): + def finalize_none_empty(value): + if value is None: + value = u'' + return value + env = Environment(finalize=finalize_none_empty) + tmpl = env.from_string('{% for item in seq %}|{{ item }}{% endfor %}') + assert tmpl.render(seq=(None, 1, "foo")) == '||1|foo' + tmpl = env.from_string('<{{ none }}>') + assert tmpl.render() == '<>' + + def test_cycler(self): + items = 1, 2, 3 + c = Cycler(*items) + for item in items + items: + assert c.current == item + assert next(c) == item + next(c) + assert c.current == 2 + c.reset() + assert c.current == 1 + + def test_expressions(self): + expr = env.compile_expression("foo") + assert expr() is None + assert expr(foo=42) == 42 + expr2 = env.compile_expression("foo", undefined_to_none=False) + assert is_undefined(expr2()) + + expr = env.compile_expression("42 + foo") + assert expr(foo=42) == 84 + + def test_template_passthrough(self): + t = Template('Content') + assert env.get_template(t) is t + assert env.select_template([t]) is t + assert env.get_or_select_template([t]) is t + assert env.get_or_select_template(t) is t + + def test_autoescape_autoselect(self): + def select_autoescape(name): + if name is None or '.' not in name: + return False + return name.endswith('.html') + env = Environment(autoescape=select_autoescape, + loader=DictLoader({ + 'test.txt': '{{ foo }}', + 'test.html': '{{ foo }}' + })) + t = env.get_template('test.txt') + assert t.render(foo='<foo>') == '<foo>' + t = env.get_template('test.html') + assert t.render(foo='<foo>') == '<foo>' + t = env.from_string('{{ foo }}') + assert t.render(foo='<foo>') == '<foo>' + + +class MetaTestCase(JinjaTestCase): + + def test_find_undeclared_variables(self): + ast = env.parse('{% set foo = 42 %}{{ bar + foo }}') + x = meta.find_undeclared_variables(ast) + assert x == set(['bar']) + + ast = env.parse('{% set foo = 42 %}{{ bar + foo }}' + '{% macro meh(x) %}{{ x }}{% endmacro %}' + '{% for item in seq %}{{ muh(item) + meh(seq) }}{% endfor %}') + x = meta.find_undeclared_variables(ast) + assert x == set(['bar', 'seq', 'muh']) + + def test_find_refererenced_templates(self): + ast = env.parse('{% extends "layout.html" %}{% include helper %}') + i = meta.find_referenced_templates(ast) + assert next(i) == 'layout.html' + assert next(i) is None + assert list(i) == [] + + ast = env.parse('{% extends "layout.html" %}' + '{% from "test.html" import a, b as c %}' + '{% import "meh.html" as meh %}' + '{% include "muh.html" %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['layout.html', 'test.html', 'meh.html', 'muh.html'] + + def test_find_included_templates(self): + ast = env.parse('{% include ["foo.html", "bar.html"] %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html'] + + ast = env.parse('{% include ("foo.html", "bar.html") %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html'] + + ast = env.parse('{% include ["foo.html", "bar.html", foo] %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html', None] + + ast = env.parse('{% include ("foo.html", "bar.html", foo) %}') + i = meta.find_referenced_templates(ast) + assert list(i) == ['foo.html', 'bar.html', None] + + +class StreamingTestCase(JinjaTestCase): + + def test_basic_streaming(self): + tmpl = env.from_string("<ul>{% for item in seq %}<li>{{ loop.index " + "}} - {{ item }}</li>{%- endfor %}</ul>") + stream = tmpl.stream(seq=list(range(4))) + self.assert_equal(next(stream), '<ul>') + self.assert_equal(next(stream), '<li>1 - 0</li>') + self.assert_equal(next(stream), '<li>2 - 1</li>') + self.assert_equal(next(stream), '<li>3 - 2</li>') + self.assert_equal(next(stream), '<li>4 - 3</li>') + self.assert_equal(next(stream), '</ul>') + + def test_buffered_streaming(self): + tmpl = env.from_string("<ul>{% for item in seq %}<li>{{ loop.index " + "}} - {{ item }}</li>{%- endfor %}</ul>") + stream = tmpl.stream(seq=list(range(4))) + stream.enable_buffering(size=3) + self.assert_equal(next(stream), u'<ul><li>1 - 0</li><li>2 - 1</li>') + self.assert_equal(next(stream), u'<li>3 - 2</li><li>4 - 3</li></ul>') + + def test_streaming_behavior(self): + tmpl = env.from_string("") + stream = tmpl.stream() + assert not stream.buffered + stream.enable_buffering(20) + assert stream.buffered + stream.disable_buffering() + assert not stream.buffered + + def test_dump_stream(self): + tmp = tempfile.mkdtemp() + try: + tmpl = env.from_string(u"\u2713") + stream = tmpl.stream() + stream.dump(os.path.join(tmp, 'dump.txt'), 'utf-8') + with open(os.path.join(tmp, 'dump.txt'), 'rb') as f: + self.assertEqual(f.read(), b'\xe2\x9c\x93') + finally: + shutil.rmtree(tmp) + + +class UndefinedTestCase(JinjaTestCase): + + def test_stopiteration_is_undefined(self): + def test(): + raise StopIteration() + t = Template('A{{ test() }}B') + assert t.render(test=test) == 'AB' + t = Template('A{{ test().missingattribute }}B') + self.assert_raises(UndefinedError, t.render, test=test) + + def test_undefined_and_special_attributes(self): + try: + Undefined('Foo').__dict__ + except AttributeError: + pass + else: + assert False, "Expected actual attribute error" + + def test_default_undefined(self): + env = Environment(undefined=Undefined) + self.assert_equal(env.from_string('{{ missing }}').render(), u'') + self.assert_raises(UndefinedError, + env.from_string('{{ missing.attribute }}').render) + self.assert_equal(env.from_string('{{ missing|list }}').render(), '[]') + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_equal(env.from_string('{{ foo.missing }}').render(foo=42), '') + self.assert_equal(env.from_string('{{ not missing }}').render(), 'True') + + def test_debug_undefined(self): + env = Environment(undefined=DebugUndefined) + self.assert_equal(env.from_string('{{ missing }}').render(), '{{ missing }}') + self.assert_raises(UndefinedError, + env.from_string('{{ missing.attribute }}').render) + self.assert_equal(env.from_string('{{ missing|list }}').render(), '[]') + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_equal(env.from_string('{{ foo.missing }}').render(foo=42), + u"{{ no such element: int object['missing'] }}") + self.assert_equal(env.from_string('{{ not missing }}').render(), 'True') + + def test_strict_undefined(self): + env = Environment(undefined=StrictUndefined) + self.assert_raises(UndefinedError, env.from_string('{{ missing }}').render) + self.assert_raises(UndefinedError, env.from_string('{{ missing.attribute }}').render) + self.assert_raises(UndefinedError, env.from_string('{{ missing|list }}').render) + self.assert_equal(env.from_string('{{ missing is not defined }}').render(), 'True') + self.assert_raises(UndefinedError, env.from_string('{{ foo.missing }}').render, foo=42) + self.assert_raises(UndefinedError, env.from_string('{{ not missing }}').render) + self.assert_equal(env.from_string('{{ missing|default("default", true) }}').render(), 'default') + + def test_indexing_gives_undefined(self): + t = Template("{{ var[42].foo }}") + self.assert_raises(UndefinedError, t.render, var=0) + + def test_none_gives_proper_error(self): + try: + Environment().getattr(None, 'split')() + except UndefinedError as e: + assert e.message == "'None' has no attribute 'split'" + else: + assert False, 'expected exception' + + def test_object_repr(self): + try: + Undefined(obj=42, name='upper')() + except UndefinedError as e: + assert e.message == "'int object' has no attribute 'upper'" + else: + assert False, 'expected exception' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ExtendedAPITestCase)) + suite.addTest(unittest.makeSuite(MetaTestCase)) + suite.addTest(unittest.makeSuite(StreamingTestCase)) + suite.addTest(unittest.makeSuite(UndefinedTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/bytecode_cache.py b/module/lib/jinja2/testsuite/bytecode_cache.py new file mode 100644 index 000000000..9f5c635b8 --- /dev/null +++ b/module/lib/jinja2/testsuite/bytecode_cache.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.bytecode_cache + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Test bytecode caching + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase, package_loader + +from jinja2 import Environment +from jinja2.bccache import FileSystemBytecodeCache +from jinja2.exceptions import TemplateNotFound + +bytecode_cache = FileSystemBytecodeCache() +env = Environment( + loader=package_loader, + bytecode_cache=bytecode_cache, +) + + +class ByteCodeCacheTestCase(JinjaTestCase): + + def test_simple(self): + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ByteCodeCacheTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/core_tags.py b/module/lib/jinja2/testsuite/core_tags.py new file mode 100644 index 000000000..f1a20fd44 --- /dev/null +++ b/module/lib/jinja2/testsuite/core_tags.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.core_tags + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Test the core tags like for and if. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, TemplateSyntaxError, UndefinedError, \ + DictLoader + +env = Environment() + + +class ForLoopTestCase(JinjaTestCase): + + def test_simple(self): + tmpl = env.from_string('{% for item in seq %}{{ item }}{% endfor %}') + assert tmpl.render(seq=list(range(10))) == '0123456789' + + def test_else(self): + tmpl = env.from_string('{% for item in seq %}XXX{% else %}...{% endfor %}') + assert tmpl.render() == '...' + + def test_empty_blocks(self): + tmpl = env.from_string('<{% for item in seq %}{% else %}{% endfor %}>') + assert tmpl.render() == '<>' + + def test_context_vars(self): + tmpl = env.from_string('''{% for item in seq -%} + {{ loop.index }}|{{ loop.index0 }}|{{ loop.revindex }}|{{ + loop.revindex0 }}|{{ loop.first }}|{{ loop.last }}|{{ + loop.length }}###{% endfor %}''') + one, two, _ = tmpl.render(seq=[0, 1]).split('###') + (one_index, one_index0, one_revindex, one_revindex0, one_first, + one_last, one_length) = one.split('|') + (two_index, two_index0, two_revindex, two_revindex0, two_first, + two_last, two_length) = two.split('|') + + assert int(one_index) == 1 and int(two_index) == 2 + assert int(one_index0) == 0 and int(two_index0) == 1 + assert int(one_revindex) == 2 and int(two_revindex) == 1 + assert int(one_revindex0) == 1 and int(two_revindex0) == 0 + assert one_first == 'True' and two_first == 'False' + assert one_last == 'False' and two_last == 'True' + assert one_length == two_length == '2' + + def test_cycling(self): + tmpl = env.from_string('''{% for item in seq %}{{ + loop.cycle('<1>', '<2>') }}{% endfor %}{% + for item in seq %}{{ loop.cycle(*through) }}{% endfor %}''') + output = tmpl.render(seq=list(range(4)), through=('<1>', '<2>')) + assert output == '<1><2>' * 4 + + def test_scope(self): + tmpl = env.from_string('{% for item in seq %}{% endfor %}{{ item }}') + output = tmpl.render(seq=list(range(10))) + assert not output + + def test_varlen(self): + def inner(): + for item in range(5): + yield item + tmpl = env.from_string('{% for item in iter %}{{ item }}{% endfor %}') + output = tmpl.render(iter=inner()) + assert output == '01234' + + def test_noniter(self): + tmpl = env.from_string('{% for item in none %}...{% endfor %}') + self.assert_raises(TypeError, tmpl.render) + + def test_recursive(self): + tmpl = env.from_string('''{% for item in seq recursive -%} + [{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}] + {%- endfor %}''') + assert tmpl.render(seq=[ + dict(a=1, b=[dict(a=1), dict(a=2)]), + dict(a=2, b=[dict(a=1), dict(a=2)]), + dict(a=3, b=[dict(a='a')]) + ]) == '[1<[1][2]>][2<[1][2]>][3<[a]>]' + + def test_recursive_depth0(self): + tmpl = env.from_string('''{% for item in seq recursive -%} + [{{ loop.depth0 }}:{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}] + {%- endfor %}''') + self.assertEqual(tmpl.render(seq=[ + dict(a=1, b=[dict(a=1), dict(a=2)]), + dict(a=2, b=[dict(a=1), dict(a=2)]), + dict(a=3, b=[dict(a='a')]) + ]), '[0:1<[1:1][1:2]>][0:2<[1:1][1:2]>][0:3<[1:a]>]') + + def test_recursive_depth(self): + tmpl = env.from_string('''{% for item in seq recursive -%} + [{{ loop.depth }}:{{ item.a }}{% if item.b %}<{{ loop(item.b) }}>{% endif %}] + {%- endfor %}''') + self.assertEqual(tmpl.render(seq=[ + dict(a=1, b=[dict(a=1), dict(a=2)]), + dict(a=2, b=[dict(a=1), dict(a=2)]), + dict(a=3, b=[dict(a='a')]) + ]), '[1:1<[2:1][2:2]>][1:2<[2:1][2:2]>][1:3<[2:a]>]') + + def test_looploop(self): + tmpl = env.from_string('''{% for row in table %} + {%- set rowloop = loop -%} + {% for cell in row -%} + [{{ rowloop.index }}|{{ loop.index }}] + {%- endfor %} + {%- endfor %}''') + assert tmpl.render(table=['ab', 'cd']) == '[1|1][1|2][2|1][2|2]' + + def test_reversed_bug(self): + tmpl = env.from_string('{% for i in items %}{{ i }}' + '{% if not loop.last %}' + ',{% endif %}{% endfor %}') + assert tmpl.render(items=reversed([3, 2, 1])) == '1,2,3' + + def test_loop_errors(self): + tmpl = env.from_string('''{% for item in [1] if loop.index + == 0 %}...{% endfor %}''') + self.assert_raises(UndefinedError, tmpl.render) + tmpl = env.from_string('''{% for item in [] %}...{% else + %}{{ loop }}{% endfor %}''') + assert tmpl.render() == '' + + def test_loop_filter(self): + tmpl = env.from_string('{% for item in range(10) if item ' + 'is even %}[{{ item }}]{% endfor %}') + assert tmpl.render() == '[0][2][4][6][8]' + tmpl = env.from_string(''' + {%- for item in range(10) if item is even %}[{{ + loop.index }}:{{ item }}]{% endfor %}''') + assert tmpl.render() == '[1:0][2:2][3:4][4:6][5:8]' + + def test_loop_unassignable(self): + self.assert_raises(TemplateSyntaxError, env.from_string, + '{% for loop in seq %}...{% endfor %}') + + def test_scoped_special_var(self): + t = env.from_string('{% for s in seq %}[{{ loop.first }}{% for c in s %}' + '|{{ loop.first }}{% endfor %}]{% endfor %}') + assert t.render(seq=('ab', 'cd')) == '[True|True|False][False|True|False]' + + def test_scoped_loop_var(self): + t = env.from_string('{% for x in seq %}{{ loop.first }}' + '{% for y in seq %}{% endfor %}{% endfor %}') + assert t.render(seq='ab') == 'TrueFalse' + t = env.from_string('{% for x in seq %}{% for y in seq %}' + '{{ loop.first }}{% endfor %}{% endfor %}') + assert t.render(seq='ab') == 'TrueFalseTrueFalse' + + def test_recursive_empty_loop_iter(self): + t = env.from_string(''' + {%- for item in foo recursive -%}{%- endfor -%} + ''') + assert t.render(dict(foo=[])) == '' + + def test_call_in_loop(self): + t = env.from_string(''' + {%- macro do_something() -%} + [{{ caller() }}] + {%- endmacro %} + + {%- for i in [1, 2, 3] %} + {%- call do_something() -%} + {{ i }} + {%- endcall %} + {%- endfor -%} + ''') + assert t.render() == '[1][2][3]' + + def test_scoping_bug(self): + t = env.from_string(''' + {%- for item in foo %}...{{ item }}...{% endfor %} + {%- macro item(a) %}...{{ a }}...{% endmacro %} + {{- item(2) -}} + ''') + assert t.render(foo=(1,)) == '...1......2...' + + def test_unpacking(self): + tmpl = env.from_string('{% for a, b, c in [[1, 2, 3]] %}' + '{{ a }}|{{ b }}|{{ c }}{% endfor %}') + assert tmpl.render() == '1|2|3' + + +class IfConditionTestCase(JinjaTestCase): + + def test_simple(self): + tmpl = env.from_string('''{% if true %}...{% endif %}''') + assert tmpl.render() == '...' + + def test_elif(self): + tmpl = env.from_string('''{% if false %}XXX{% elif true + %}...{% else %}XXX{% endif %}''') + assert tmpl.render() == '...' + + def test_else(self): + tmpl = env.from_string('{% if false %}XXX{% else %}...{% endif %}') + assert tmpl.render() == '...' + + def test_empty(self): + tmpl = env.from_string('[{% if true %}{% else %}{% endif %}]') + assert tmpl.render() == '[]' + + def test_complete(self): + tmpl = env.from_string('{% if a %}A{% elif b %}B{% elif c == d %}' + 'C{% else %}D{% endif %}') + assert tmpl.render(a=0, b=False, c=42, d=42.0) == 'C' + + def test_no_scope(self): + tmpl = env.from_string('{% if a %}{% set foo = 1 %}{% endif %}{{ foo }}') + assert tmpl.render(a=True) == '1' + tmpl = env.from_string('{% if true %}{% set foo = 1 %}{% endif %}{{ foo }}') + assert tmpl.render() == '1' + + +class MacrosTestCase(JinjaTestCase): + env = Environment(trim_blocks=True) + + def test_simple(self): + tmpl = self.env.from_string('''\ +{% macro say_hello(name) %}Hello {{ name }}!{% endmacro %} +{{ say_hello('Peter') }}''') + assert tmpl.render() == 'Hello Peter!' + + def test_scoping(self): + tmpl = self.env.from_string('''\ +{% macro level1(data1) %} +{% macro level2(data2) %}{{ data1 }}|{{ data2 }}{% endmacro %} +{{ level2('bar') }}{% endmacro %} +{{ level1('foo') }}''') + assert tmpl.render() == 'foo|bar' + + def test_arguments(self): + tmpl = self.env.from_string('''\ +{% macro m(a, b, c='c', d='d') %}{{ a }}|{{ b }}|{{ c }}|{{ d }}{% endmacro %} +{{ m() }}|{{ m('a') }}|{{ m('a', 'b') }}|{{ m(1, 2, 3) }}''') + assert tmpl.render() == '||c|d|a||c|d|a|b|c|d|1|2|3|d' + + def test_varargs(self): + tmpl = self.env.from_string('''\ +{% macro test() %}{{ varargs|join('|') }}{% endmacro %}\ +{{ test(1, 2, 3) }}''') + assert tmpl.render() == '1|2|3' + + def test_simple_call(self): + tmpl = self.env.from_string('''\ +{% macro test() %}[[{{ caller() }}]]{% endmacro %}\ +{% call test() %}data{% endcall %}''') + assert tmpl.render() == '[[data]]' + + def test_complex_call(self): + tmpl = self.env.from_string('''\ +{% macro test() %}[[{{ caller('data') }}]]{% endmacro %}\ +{% call(data) test() %}{{ data }}{% endcall %}''') + assert tmpl.render() == '[[data]]' + + def test_caller_undefined(self): + tmpl = self.env.from_string('''\ +{% set caller = 42 %}\ +{% macro test() %}{{ caller is not defined }}{% endmacro %}\ +{{ test() }}''') + assert tmpl.render() == 'True' + + def test_include(self): + self.env = Environment(loader=DictLoader({'include': + '{% macro test(foo) %}[{{ foo }}]{% endmacro %}'})) + tmpl = self.env.from_string('{% from "include" import test %}{{ test("foo") }}') + assert tmpl.render() == '[foo]' + + def test_macro_api(self): + tmpl = self.env.from_string('{% macro foo(a, b) %}{% endmacro %}' + '{% macro bar() %}{{ varargs }}{{ kwargs }}{% endmacro %}' + '{% macro baz() %}{{ caller() }}{% endmacro %}') + assert tmpl.module.foo.arguments == ('a', 'b') + assert tmpl.module.foo.defaults == () + assert tmpl.module.foo.name == 'foo' + assert not tmpl.module.foo.caller + assert not tmpl.module.foo.catch_kwargs + assert not tmpl.module.foo.catch_varargs + assert tmpl.module.bar.arguments == () + assert tmpl.module.bar.defaults == () + assert not tmpl.module.bar.caller + assert tmpl.module.bar.catch_kwargs + assert tmpl.module.bar.catch_varargs + assert tmpl.module.baz.caller + + def test_callself(self): + tmpl = self.env.from_string('{% macro foo(x) %}{{ x }}{% if x > 1 %}|' + '{{ foo(x - 1) }}{% endif %}{% endmacro %}' + '{{ foo(5) }}') + assert tmpl.render() == '5|4|3|2|1' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ForLoopTestCase)) + suite.addTest(unittest.makeSuite(IfConditionTestCase)) + suite.addTest(unittest.makeSuite(MacrosTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/debug.py b/module/lib/jinja2/testsuite/debug.py new file mode 100644 index 000000000..2588a83ea --- /dev/null +++ b/module/lib/jinja2/testsuite/debug.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.debug + ~~~~~~~~~~~~~~~~~~~~~~ + + Tests the debug system. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase, filesystem_loader + +from jinja2 import Environment, TemplateSyntaxError + +env = Environment(loader=filesystem_loader) + + +class DebugTestCase(JinjaTestCase): + + def test_runtime_error(self): + def test(): + tmpl.render(fail=lambda: 1 / 0) + tmpl = env.get_template('broken.html') + self.assert_traceback_matches(test, r''' + File ".*?broken.html", line 2, in (top-level template code|<module>) + \{\{ fail\(\) \}\} + File ".*?debug.pyc?", line \d+, in <lambda> + tmpl\.render\(fail=lambda: 1 / 0\) +ZeroDivisionError: (int(eger)? )?division (or modulo )?by zero +''') + + def test_syntax_error(self): + # XXX: the .*? is necessary for python3 which does not hide + # some of the stack frames we don't want to show. Not sure + # what's up with that, but that is not that critical. Should + # be fixed though. + self.assert_traceback_matches(lambda: env.get_template('syntaxerror.html'), r'''(?sm) + File ".*?syntaxerror.html", line 4, in (template|<module>) + \{% endif %\}.*? +(jinja2\.exceptions\.)?TemplateSyntaxError: Encountered unknown tag 'endif'. Jinja was looking for the following tags: 'endfor' or 'else'. The innermost block that needs to be closed is 'for'. + ''') + + def test_regular_syntax_error(self): + def test(): + raise TemplateSyntaxError('wtf', 42) + self.assert_traceback_matches(test, r''' + File ".*debug.pyc?", line \d+, in test + raise TemplateSyntaxError\('wtf', 42\) +(jinja2\.exceptions\.)?TemplateSyntaxError: wtf + line 42''') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(DebugTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/doctests.py b/module/lib/jinja2/testsuite/doctests.py new file mode 100644 index 000000000..616d3b6ee --- /dev/null +++ b/module/lib/jinja2/testsuite/doctests.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.doctests + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + The doctests. Collects all tests we want to test from + the Jinja modules. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest +import doctest + + +def suite(): + from jinja2 import utils, sandbox, runtime, meta, loaders, \ + ext, environment, bccache, nodes + suite = unittest.TestSuite() + suite.addTest(doctest.DocTestSuite(utils)) + suite.addTest(doctest.DocTestSuite(sandbox)) + suite.addTest(doctest.DocTestSuite(runtime)) + suite.addTest(doctest.DocTestSuite(meta)) + suite.addTest(doctest.DocTestSuite(loaders)) + suite.addTest(doctest.DocTestSuite(ext)) + suite.addTest(doctest.DocTestSuite(environment)) + suite.addTest(doctest.DocTestSuite(bccache)) + suite.addTest(doctest.DocTestSuite(nodes)) + return suite diff --git a/module/lib/jinja2/testsuite/ext.py b/module/lib/jinja2/testsuite/ext.py new file mode 100644 index 000000000..0f93be945 --- /dev/null +++ b/module/lib/jinja2/testsuite/ext.py @@ -0,0 +1,459 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.ext + ~~~~~~~~~~~~~~~~~~~~ + + Tests for the extensions. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import re +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader, contextfunction, nodes +from jinja2.exceptions import TemplateAssertionError +from jinja2.ext import Extension +from jinja2.lexer import Token, count_newlines +from jinja2._compat import next, BytesIO, itervalues, text_type + +importable_object = 23 + +_gettext_re = re.compile(r'_\((.*?)\)(?s)') + + +i18n_templates = { + 'master.html': '<title>{{ page_title|default(_("missing")) }}</title>' + '{% block body %}{% endblock %}', + 'child.html': '{% extends "master.html" %}{% block body %}' + '{% trans %}watch out{% endtrans %}{% endblock %}', + 'plural.html': '{% trans user_count %}One user online{% pluralize %}' + '{{ user_count }} users online{% endtrans %}', + 'plural2.html': '{% trans user_count=get_user_count() %}{{ user_count }}s' + '{% pluralize %}{{ user_count }}p{% endtrans %}', + 'stringformat.html': '{{ _("User: %(num)s")|format(num=user_count) }}' +} + +newstyle_i18n_templates = { + 'master.html': '<title>{{ page_title|default(_("missing")) }}</title>' + '{% block body %}{% endblock %}', + 'child.html': '{% extends "master.html" %}{% block body %}' + '{% trans %}watch out{% endtrans %}{% endblock %}', + 'plural.html': '{% trans user_count %}One user online{% pluralize %}' + '{{ user_count }} users online{% endtrans %}', + 'stringformat.html': '{{ _("User: %(num)s", num=user_count) }}', + 'ngettext.html': '{{ ngettext("%(num)s apple", "%(num)s apples", apples) }}', + 'ngettext_long.html': '{% trans num=apples %}{{ num }} apple{% pluralize %}' + '{{ num }} apples{% endtrans %}', + 'transvars1.html': '{% trans %}User: {{ num }}{% endtrans %}', + 'transvars2.html': '{% trans num=count %}User: {{ num }}{% endtrans %}', + 'transvars3.html': '{% trans count=num %}User: {{ count }}{% endtrans %}', + 'novars.html': '{% trans %}%(hello)s{% endtrans %}', + 'vars.html': '{% trans %}{{ foo }}%(foo)s{% endtrans %}', + 'explicitvars.html': '{% trans foo="42" %}%(foo)s{% endtrans %}' +} + + +languages = { + 'de': { + 'missing': u'fehlend', + 'watch out': u'pass auf', + 'One user online': u'Ein Benutzer online', + '%(user_count)s users online': u'%(user_count)s Benutzer online', + 'User: %(num)s': u'Benutzer: %(num)s', + 'User: %(count)s': u'Benutzer: %(count)s', + '%(num)s apple': u'%(num)s Apfel', + '%(num)s apples': u'%(num)s Ãpfel' + } +} + + +@contextfunction +def gettext(context, string): + language = context.get('LANGUAGE', 'en') + return languages.get(language, {}).get(string, string) + + +@contextfunction +def ngettext(context, s, p, n): + language = context.get('LANGUAGE', 'en') + if n != 1: + return languages.get(language, {}).get(p, p) + return languages.get(language, {}).get(s, s) + + +i18n_env = Environment( + loader=DictLoader(i18n_templates), + extensions=['jinja2.ext.i18n'] +) +i18n_env.globals.update({ + '_': gettext, + 'gettext': gettext, + 'ngettext': ngettext +}) + +newstyle_i18n_env = Environment( + loader=DictLoader(newstyle_i18n_templates), + extensions=['jinja2.ext.i18n'] +) +newstyle_i18n_env.install_gettext_callables(gettext, ngettext, newstyle=True) + +class TestExtension(Extension): + tags = set(['test']) + ext_attr = 42 + + def parse(self, parser): + return nodes.Output([self.call_method('_dump', [ + nodes.EnvironmentAttribute('sandboxed'), + self.attr('ext_attr'), + nodes.ImportedName(__name__ + '.importable_object'), + nodes.ContextReference() + ])]).set_lineno(next(parser.stream).lineno) + + def _dump(self, sandboxed, ext_attr, imported_object, context): + return '%s|%s|%s|%s' % ( + sandboxed, + ext_attr, + imported_object, + context.blocks + ) + + +class PreprocessorExtension(Extension): + + def preprocess(self, source, name, filename=None): + return source.replace('[[TEST]]', '({{ foo }})') + + +class StreamFilterExtension(Extension): + + def filter_stream(self, stream): + for token in stream: + if token.type == 'data': + for t in self.interpolate(token): + yield t + else: + yield token + + def interpolate(self, token): + pos = 0 + end = len(token.value) + lineno = token.lineno + while 1: + match = _gettext_re.search(token.value, pos) + if match is None: + break + value = token.value[pos:match.start()] + if value: + yield Token(lineno, 'data', value) + lineno += count_newlines(token.value) + yield Token(lineno, 'variable_begin', None) + yield Token(lineno, 'name', 'gettext') + yield Token(lineno, 'lparen', None) + yield Token(lineno, 'string', match.group(1)) + yield Token(lineno, 'rparen', None) + yield Token(lineno, 'variable_end', None) + pos = match.end() + if pos < end: + yield Token(lineno, 'data', token.value[pos:]) + + +class ExtensionsTestCase(JinjaTestCase): + + def test_extend_late(self): + env = Environment() + env.add_extension('jinja2.ext.autoescape') + t = env.from_string('{% autoescape true %}{{ "<test>" }}{% endautoescape %}') + assert t.render() == '<test>' + + def test_loop_controls(self): + env = Environment(extensions=['jinja2.ext.loopcontrols']) + + tmpl = env.from_string(''' + {%- for item in [1, 2, 3, 4] %} + {%- if item % 2 == 0 %}{% continue %}{% endif -%} + {{ item }} + {%- endfor %}''') + assert tmpl.render() == '13' + + tmpl = env.from_string(''' + {%- for item in [1, 2, 3, 4] %} + {%- if item > 2 %}{% break %}{% endif -%} + {{ item }} + {%- endfor %}''') + assert tmpl.render() == '12' + + def test_do(self): + env = Environment(extensions=['jinja2.ext.do']) + tmpl = env.from_string(''' + {%- set items = [] %} + {%- for char in "foo" %} + {%- do items.append(loop.index0 ~ char) %} + {%- endfor %}{{ items|join(', ') }}''') + assert tmpl.render() == '0f, 1o, 2o' + + def test_with(self): + env = Environment(extensions=['jinja2.ext.with_']) + tmpl = env.from_string('''\ + {% with a=42, b=23 -%} + {{ a }} = {{ b }} + {% endwith -%} + {{ a }} = {{ b }}\ + ''') + assert [x.strip() for x in tmpl.render(a=1, b=2).splitlines()] \ + == ['42 = 23', '1 = 2'] + + def test_extension_nodes(self): + env = Environment(extensions=[TestExtension]) + tmpl = env.from_string('{% test %}') + assert tmpl.render() == 'False|42|23|{}' + + def test_identifier(self): + assert TestExtension.identifier == __name__ + '.TestExtension' + + def test_rebinding(self): + original = Environment(extensions=[TestExtension]) + overlay = original.overlay() + for env in original, overlay: + for ext in itervalues(env.extensions): + assert ext.environment is env + + def test_preprocessor_extension(self): + env = Environment(extensions=[PreprocessorExtension]) + tmpl = env.from_string('{[[TEST]]}') + assert tmpl.render(foo=42) == '{(42)}' + + def test_streamfilter_extension(self): + env = Environment(extensions=[StreamFilterExtension]) + env.globals['gettext'] = lambda x: x.upper() + tmpl = env.from_string('Foo _(bar) Baz') + out = tmpl.render() + assert out == 'Foo BAR Baz' + + def test_extension_ordering(self): + class T1(Extension): + priority = 1 + class T2(Extension): + priority = 2 + env = Environment(extensions=[T1, T2]) + ext = list(env.iter_extensions()) + assert ext[0].__class__ is T1 + assert ext[1].__class__ is T2 + + +class InternationalizationTestCase(JinjaTestCase): + + def test_trans(self): + tmpl = i18n_env.get_template('child.html') + assert tmpl.render(LANGUAGE='de') == '<title>fehlend</title>pass auf' + + def test_trans_plural(self): + tmpl = i18n_env.get_template('plural.html') + assert tmpl.render(LANGUAGE='de', user_count=1) == 'Ein Benutzer online' + assert tmpl.render(LANGUAGE='de', user_count=2) == '2 Benutzer online' + + def test_trans_plural_with_functions(self): + tmpl = i18n_env.get_template('plural2.html') + def get_user_count(): + get_user_count.called += 1 + return 1 + get_user_count.called = 0 + assert tmpl.render(LANGUAGE='de', get_user_count=get_user_count) == '1s' + assert get_user_count.called == 1 + + def test_complex_plural(self): + tmpl = i18n_env.from_string('{% trans foo=42, count=2 %}{{ count }} item{% ' + 'pluralize count %}{{ count }} items{% endtrans %}') + assert tmpl.render() == '2 items' + self.assert_raises(TemplateAssertionError, i18n_env.from_string, + '{% trans foo %}...{% pluralize bar %}...{% endtrans %}') + + def test_trans_stringformatting(self): + tmpl = i18n_env.get_template('stringformat.html') + assert tmpl.render(LANGUAGE='de', user_count=5) == 'Benutzer: 5' + + def test_extract(self): + from jinja2.ext import babel_extract + source = BytesIO(''' + {{ gettext('Hello World') }} + {% trans %}Hello World{% endtrans %} + {% trans %}{{ users }} user{% pluralize %}{{ users }} users{% endtrans %} + '''.encode('ascii')) # make python 3 happy + assert list(babel_extract(source, ('gettext', 'ngettext', '_'), [], {})) == [ + (2, 'gettext', u'Hello World', []), + (3, 'gettext', u'Hello World', []), + (4, 'ngettext', (u'%(users)s user', u'%(users)s users', None), []) + ] + + def test_comment_extract(self): + from jinja2.ext import babel_extract + source = BytesIO(''' + {# trans first #} + {{ gettext('Hello World') }} + {% trans %}Hello World{% endtrans %}{# trans second #} + {#: third #} + {% trans %}{{ users }} user{% pluralize %}{{ users }} users{% endtrans %} + '''.encode('utf-8')) # make python 3 happy + assert list(babel_extract(source, ('gettext', 'ngettext', '_'), ['trans', ':'], {})) == [ + (3, 'gettext', u'Hello World', ['first']), + (4, 'gettext', u'Hello World', ['second']), + (6, 'ngettext', (u'%(users)s user', u'%(users)s users', None), ['third']) + ] + + +class NewstyleInternationalizationTestCase(JinjaTestCase): + + def test_trans(self): + tmpl = newstyle_i18n_env.get_template('child.html') + assert tmpl.render(LANGUAGE='de') == '<title>fehlend</title>pass auf' + + def test_trans_plural(self): + tmpl = newstyle_i18n_env.get_template('plural.html') + assert tmpl.render(LANGUAGE='de', user_count=1) == 'Ein Benutzer online' + assert tmpl.render(LANGUAGE='de', user_count=2) == '2 Benutzer online' + + def test_complex_plural(self): + tmpl = newstyle_i18n_env.from_string('{% trans foo=42, count=2 %}{{ count }} item{% ' + 'pluralize count %}{{ count }} items{% endtrans %}') + assert tmpl.render() == '2 items' + self.assert_raises(TemplateAssertionError, i18n_env.from_string, + '{% trans foo %}...{% pluralize bar %}...{% endtrans %}') + + def test_trans_stringformatting(self): + tmpl = newstyle_i18n_env.get_template('stringformat.html') + assert tmpl.render(LANGUAGE='de', user_count=5) == 'Benutzer: 5' + + def test_newstyle_plural(self): + tmpl = newstyle_i18n_env.get_template('ngettext.html') + assert tmpl.render(LANGUAGE='de', apples=1) == '1 Apfel' + assert tmpl.render(LANGUAGE='de', apples=5) == u'5 Ãpfel' + + def test_autoescape_support(self): + env = Environment(extensions=['jinja2.ext.autoescape', + 'jinja2.ext.i18n']) + env.install_gettext_callables(lambda x: u'<strong>Wert: %(name)s</strong>', + lambda s, p, n: s, newstyle=True) + t = env.from_string('{% autoescape ae %}{{ gettext("foo", name=' + '"<test>") }}{% endautoescape %}') + assert t.render(ae=True) == '<strong>Wert: <test></strong>' + assert t.render(ae=False) == '<strong>Wert: <test></strong>' + + def test_num_used_twice(self): + tmpl = newstyle_i18n_env.get_template('ngettext_long.html') + assert tmpl.render(apples=5, LANGUAGE='de') == u'5 Ãpfel' + + def test_num_called_num(self): + source = newstyle_i18n_env.compile(''' + {% trans num=3 %}{{ num }} apple{% pluralize + %}{{ num }} apples{% endtrans %} + ''', raw=True) + # quite hacky, but the only way to properly test that. The idea is + # that the generated code does not pass num twice (although that + # would work) for better performance. This only works on the + # newstyle gettext of course + assert re.search(r"l_ngettext, u?'\%\(num\)s apple', u?'\%\(num\)s " + r"apples', 3", source) is not None + + def test_trans_vars(self): + t1 = newstyle_i18n_env.get_template('transvars1.html') + t2 = newstyle_i18n_env.get_template('transvars2.html') + t3 = newstyle_i18n_env.get_template('transvars3.html') + assert t1.render(num=1, LANGUAGE='de') == 'Benutzer: 1' + assert t2.render(count=23, LANGUAGE='de') == 'Benutzer: 23' + assert t3.render(num=42, LANGUAGE='de') == 'Benutzer: 42' + + def test_novars_vars_escaping(self): + t = newstyle_i18n_env.get_template('novars.html') + assert t.render() == '%(hello)s' + t = newstyle_i18n_env.get_template('vars.html') + assert t.render(foo='42') == '42%(foo)s' + t = newstyle_i18n_env.get_template('explicitvars.html') + assert t.render() == '%(foo)s' + + +class AutoEscapeTestCase(JinjaTestCase): + + def test_scoped_setting(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string(''' + {{ "<HelloWorld>" }} + {% autoescape false %} + {{ "<HelloWorld>" }} + {% endautoescape %} + {{ "<HelloWorld>" }} + ''') + assert tmpl.render().split() == \ + [u'<HelloWorld>', u'<HelloWorld>', u'<HelloWorld>'] + + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=False) + tmpl = env.from_string(''' + {{ "<HelloWorld>" }} + {% autoescape true %} + {{ "<HelloWorld>" }} + {% endautoescape %} + {{ "<HelloWorld>" }} + ''') + assert tmpl.render().split() == \ + [u'<HelloWorld>', u'<HelloWorld>', u'<HelloWorld>'] + + def test_nonvolatile(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string('{{ {"foo": "<test>"}|xmlattr|escape }}') + assert tmpl.render() == ' foo="<test>"' + tmpl = env.from_string('{% autoescape false %}{{ {"foo": "<test>"}' + '|xmlattr|escape }}{% endautoescape %}') + assert tmpl.render() == ' foo="&lt;test&gt;"' + + def test_volatile(self): + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + tmpl = env.from_string('{% autoescape foo %}{{ {"foo": "<test>"}' + '|xmlattr|escape }}{% endautoescape %}') + assert tmpl.render(foo=False) == ' foo="&lt;test&gt;"' + assert tmpl.render(foo=True) == ' foo="<test>"' + + def test_scoping(self): + env = Environment(extensions=['jinja2.ext.autoescape']) + tmpl = env.from_string('{% autoescape true %}{% set x = "<x>" %}{{ x }}' + '{% endautoescape %}{{ x }}{{ "<y>" }}') + assert tmpl.render(x=1) == '<x>1<y>' + + def test_volatile_scoping(self): + env = Environment(extensions=['jinja2.ext.autoescape']) + tmplsource = ''' + {% autoescape val %} + {% macro foo(x) %} + [{{ x }}] + {% endmacro %} + {{ foo().__class__.__name__ }} + {% endautoescape %} + {{ '<testing>' }} + ''' + tmpl = env.from_string(tmplsource) + assert tmpl.render(val=True).split()[0] == 'Markup' + assert tmpl.render(val=False).split()[0] == text_type.__name__ + + # looking at the source we should see <testing> there in raw + # (and then escaped as well) + env = Environment(extensions=['jinja2.ext.autoescape']) + pysource = env.compile(tmplsource, raw=True) + assert '<testing>\\n' in pysource + + env = Environment(extensions=['jinja2.ext.autoescape'], + autoescape=True) + pysource = env.compile(tmplsource, raw=True) + assert '<testing>\\n' in pysource + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ExtensionsTestCase)) + suite.addTest(unittest.makeSuite(InternationalizationTestCase)) + suite.addTest(unittest.makeSuite(NewstyleInternationalizationTestCase)) + suite.addTest(unittest.makeSuite(AutoEscapeTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/filters.py b/module/lib/jinja2/testsuite/filters.py new file mode 100644 index 000000000..282dd2d85 --- /dev/null +++ b/module/lib/jinja2/testsuite/filters.py @@ -0,0 +1,515 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.filters + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests for the jinja filters. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Markup, Environment +from jinja2._compat import text_type, implements_to_string + +env = Environment() + + +class FilterTestCase(JinjaTestCase): + + def test_filter_calling(self): + rv = env.call_filter('sum', [1, 2, 3]) + self.assert_equal(rv, 6) + + def test_capitalize(self): + tmpl = env.from_string('{{ "foo bar"|capitalize }}') + assert tmpl.render() == 'Foo bar' + + def test_center(self): + tmpl = env.from_string('{{ "foo"|center(9) }}') + assert tmpl.render() == ' foo ' + + def test_default(self): + tmpl = env.from_string( + "{{ missing|default('no') }}|{{ false|default('no') }}|" + "{{ false|default('no', true) }}|{{ given|default('no') }}" + ) + assert tmpl.render(given='yes') == 'no|False|no|yes' + + def test_dictsort(self): + tmpl = env.from_string( + '{{ foo|dictsort }}|' + '{{ foo|dictsort(true) }}|' + '{{ foo|dictsort(false, "value") }}' + ) + out = tmpl.render(foo={"aa": 0, "b": 1, "c": 2, "AB": 3}) + assert out == ("[('aa', 0), ('AB', 3), ('b', 1), ('c', 2)]|" + "[('AB', 3), ('aa', 0), ('b', 1), ('c', 2)]|" + "[('aa', 0), ('b', 1), ('c', 2), ('AB', 3)]") + + def test_batch(self): + tmpl = env.from_string("{{ foo|batch(3)|list }}|" + "{{ foo|batch(3, 'X')|list }}") + out = tmpl.render(foo=list(range(10))) + assert out == ("[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]|" + "[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 'X', 'X']]") + + def test_slice(self): + tmpl = env.from_string('{{ foo|slice(3)|list }}|' + '{{ foo|slice(3, "X")|list }}') + out = tmpl.render(foo=list(range(10))) + assert out == ("[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]|" + "[[0, 1, 2, 3], [4, 5, 6, 'X'], [7, 8, 9, 'X']]") + + def test_escape(self): + tmpl = env.from_string('''{{ '<">&'|escape }}''') + out = tmpl.render() + assert out == '<">&' + + def test_striptags(self): + tmpl = env.from_string('''{{ foo|striptags }}''') + out = tmpl.render(foo=' <p>just a small \n <a href="#">' + 'example</a> link</p>\n<p>to a webpage</p> ' + '<!-- <p>and some commented stuff</p> -->') + assert out == 'just a small example link to a webpage' + + def test_filesizeformat(self): + tmpl = env.from_string( + '{{ 100|filesizeformat }}|' + '{{ 1000|filesizeformat }}|' + '{{ 1000000|filesizeformat }}|' + '{{ 1000000000|filesizeformat }}|' + '{{ 1000000000000|filesizeformat }}|' + '{{ 100|filesizeformat(true) }}|' + '{{ 1000|filesizeformat(true) }}|' + '{{ 1000000|filesizeformat(true) }}|' + '{{ 1000000000|filesizeformat(true) }}|' + '{{ 1000000000000|filesizeformat(true) }}' + ) + out = tmpl.render() + self.assert_equal(out, ( + '100 Bytes|1.0 kB|1.0 MB|1.0 GB|1.0 TB|100 Bytes|' + '1000 Bytes|976.6 KiB|953.7 MiB|931.3 GiB' + )) + + def test_filesizeformat_issue59(self): + tmpl = env.from_string( + '{{ 300|filesizeformat }}|' + '{{ 3000|filesizeformat }}|' + '{{ 3000000|filesizeformat }}|' + '{{ 3000000000|filesizeformat }}|' + '{{ 3000000000000|filesizeformat }}|' + '{{ 300|filesizeformat(true) }}|' + '{{ 3000|filesizeformat(true) }}|' + '{{ 3000000|filesizeformat(true) }}' + ) + out = tmpl.render() + self.assert_equal(out, ( + '300 Bytes|3.0 kB|3.0 MB|3.0 GB|3.0 TB|300 Bytes|' + '2.9 KiB|2.9 MiB' + )) + + + def test_first(self): + tmpl = env.from_string('{{ foo|first }}') + out = tmpl.render(foo=list(range(10))) + assert out == '0' + + def test_float(self): + tmpl = env.from_string('{{ "42"|float }}|' + '{{ "ajsghasjgd"|float }}|' + '{{ "32.32"|float }}') + out = tmpl.render() + assert out == '42.0|0.0|32.32' + + def test_format(self): + tmpl = env.from_string('''{{ "%s|%s"|format("a", "b") }}''') + out = tmpl.render() + assert out == 'a|b' + + def test_indent(self): + tmpl = env.from_string('{{ foo|indent(2) }}|{{ foo|indent(2, true) }}') + text = '\n'.join([' '.join(['foo', 'bar'] * 2)] * 2) + out = tmpl.render(foo=text) + assert out == ('foo bar foo bar\n foo bar foo bar| ' + 'foo bar foo bar\n foo bar foo bar') + + def test_int(self): + tmpl = env.from_string('{{ "42"|int }}|{{ "ajsghasjgd"|int }}|' + '{{ "32.32"|int }}') + out = tmpl.render() + assert out == '42|0|32' + + def test_join(self): + tmpl = env.from_string('{{ [1, 2, 3]|join("|") }}') + out = tmpl.render() + assert out == '1|2|3' + + env2 = Environment(autoescape=True) + tmpl = env2.from_string('{{ ["<foo>", "<span>foo</span>"|safe]|join }}') + assert tmpl.render() == '<foo><span>foo</span>' + + def test_join_attribute(self): + class User(object): + def __init__(self, username): + self.username = username + tmpl = env.from_string('''{{ users|join(', ', 'username') }}''') + assert tmpl.render(users=map(User, ['foo', 'bar'])) == 'foo, bar' + + def test_last(self): + tmpl = env.from_string('''{{ foo|last }}''') + out = tmpl.render(foo=list(range(10))) + assert out == '9' + + def test_length(self): + tmpl = env.from_string('''{{ "hello world"|length }}''') + out = tmpl.render() + assert out == '11' + + def test_lower(self): + tmpl = env.from_string('''{{ "FOO"|lower }}''') + out = tmpl.render() + assert out == 'foo' + + def test_pprint(self): + from pprint import pformat + tmpl = env.from_string('''{{ data|pprint }}''') + data = list(range(1000)) + assert tmpl.render(data=data) == pformat(data) + + def test_random(self): + tmpl = env.from_string('''{{ seq|random }}''') + seq = list(range(100)) + for _ in range(10): + assert int(tmpl.render(seq=seq)) in seq + + def test_reverse(self): + tmpl = env.from_string('{{ "foobar"|reverse|join }}|' + '{{ [1, 2, 3]|reverse|list }}') + assert tmpl.render() == 'raboof|[3, 2, 1]' + + def test_string(self): + x = [1, 2, 3, 4, 5] + tmpl = env.from_string('''{{ obj|string }}''') + assert tmpl.render(obj=x) == text_type(x) + + def test_title(self): + tmpl = env.from_string('''{{ "foo bar"|title }}''') + assert tmpl.render() == "Foo Bar" + tmpl = env.from_string('''{{ "foo's bar"|title }}''') + assert tmpl.render() == "Foo's Bar" + tmpl = env.from_string('''{{ "foo bar"|title }}''') + assert tmpl.render() == "Foo Bar" + tmpl = env.from_string('''{{ "f bar f"|title }}''') + assert tmpl.render() == "F Bar F" + tmpl = env.from_string('''{{ "foo-bar"|title }}''') + assert tmpl.render() == "Foo-Bar" + tmpl = env.from_string('''{{ "foo\tbar"|title }}''') + assert tmpl.render() == "Foo\tBar" + tmpl = env.from_string('''{{ "FOO\tBAR"|title }}''') + assert tmpl.render() == "Foo\tBar" + + def test_truncate(self): + tmpl = env.from_string( + '{{ data|truncate(15, true, ">>>") }}|' + '{{ data|truncate(15, false, ">>>") }}|' + '{{ smalldata|truncate(15) }}' + ) + out = tmpl.render(data='foobar baz bar' * 1000, + smalldata='foobar baz bar') + assert out == 'foobar baz barf>>>|foobar baz >>>|foobar baz bar' + + def test_upper(self): + tmpl = env.from_string('{{ "foo"|upper }}') + assert tmpl.render() == 'FOO' + + def test_urlize(self): + tmpl = env.from_string('{{ "foo http://www.example.com/ bar"|urlize }}') + assert tmpl.render() == 'foo <a href="http://www.example.com/">'\ + 'http://www.example.com/</a> bar' + + def test_wordcount(self): + tmpl = env.from_string('{{ "foo bar baz"|wordcount }}') + assert tmpl.render() == '3' + + def test_block(self): + tmpl = env.from_string('{% filter lower|escape %}<HEHE>{% endfilter %}') + assert tmpl.render() == '<hehe>' + + def test_chaining(self): + tmpl = env.from_string('''{{ ['<foo>', '<bar>']|first|upper|escape }}''') + assert tmpl.render() == '<FOO>' + + def test_sum(self): + tmpl = env.from_string('''{{ [1, 2, 3, 4, 5, 6]|sum }}''') + assert tmpl.render() == '21' + + def test_sum_attributes(self): + tmpl = env.from_string('''{{ values|sum('value') }}''') + assert tmpl.render(values=[ + {'value': 23}, + {'value': 1}, + {'value': 18}, + ]) == '42' + + def test_sum_attributes_nested(self): + tmpl = env.from_string('''{{ values|sum('real.value') }}''') + assert tmpl.render(values=[ + {'real': {'value': 23}}, + {'real': {'value': 1}}, + {'real': {'value': 18}}, + ]) == '42' + + def test_sum_attributes_tuple(self): + tmpl = env.from_string('''{{ values.items()|sum('1') }}''') + assert tmpl.render(values={ + 'foo': 23, + 'bar': 1, + 'baz': 18, + }) == '42' + + def test_abs(self): + tmpl = env.from_string('''{{ -1|abs }}|{{ 1|abs }}''') + assert tmpl.render() == '1|1', tmpl.render() + + def test_round_positive(self): + tmpl = env.from_string('{{ 2.7|round }}|{{ 2.1|round }}|' + "{{ 2.1234|round(3, 'floor') }}|" + "{{ 2.1|round(0, 'ceil') }}") + assert tmpl.render() == '3.0|2.0|2.123|3.0', tmpl.render() + + def test_round_negative(self): + tmpl = env.from_string('{{ 21.3|round(-1)}}|' + "{{ 21.3|round(-1, 'ceil')}}|" + "{{ 21.3|round(-1, 'floor')}}") + assert tmpl.render() == '20.0|30.0|20.0',tmpl.render() + + def test_xmlattr(self): + tmpl = env.from_string("{{ {'foo': 42, 'bar': 23, 'fish': none, " + "'spam': missing, 'blub:blub': '<?>'}|xmlattr }}") + out = tmpl.render().split() + assert len(out) == 3 + assert 'foo="42"' in out + assert 'bar="23"' in out + assert 'blub:blub="<?>"' in out + + def test_sort1(self): + tmpl = env.from_string('{{ [2, 3, 1]|sort }}|{{ [2, 3, 1]|sort(true) }}') + assert tmpl.render() == '[1, 2, 3]|[3, 2, 1]' + + def test_sort2(self): + tmpl = env.from_string('{{ "".join(["c", "A", "b", "D"]|sort) }}') + assert tmpl.render() == 'AbcD' + + def test_sort3(self): + tmpl = env.from_string('''{{ ['foo', 'Bar', 'blah']|sort }}''') + assert tmpl.render() == "['Bar', 'blah', 'foo']" + + def test_sort4(self): + @implements_to_string + class Magic(object): + def __init__(self, value): + self.value = value + def __str__(self): + return text_type(self.value) + tmpl = env.from_string('''{{ items|sort(attribute='value')|join }}''') + assert tmpl.render(items=map(Magic, [3, 2, 4, 1])) == '1234' + + def test_groupby(self): + tmpl = env.from_string(''' + {%- for grouper, list in [{'foo': 1, 'bar': 2}, + {'foo': 2, 'bar': 3}, + {'foo': 1, 'bar': 1}, + {'foo': 3, 'bar': 4}]|groupby('foo') -%} + {{ grouper }}{% for x in list %}: {{ x.foo }}, {{ x.bar }}{% endfor %}| + {%- endfor %}''') + assert tmpl.render().split('|') == [ + "1: 1, 2: 1, 1", + "2: 2, 3", + "3: 3, 4", + "" + ] + + def test_groupby_tuple_index(self): + tmpl = env.from_string(''' + {%- for grouper, list in [('a', 1), ('a', 2), ('b', 1)]|groupby(0) -%} + {{ grouper }}{% for x in list %}:{{ x.1 }}{% endfor %}| + {%- endfor %}''') + assert tmpl.render() == 'a:1:2|b:1|' + + def test_groupby_multidot(self): + class Date(object): + def __init__(self, day, month, year): + self.day = day + self.month = month + self.year = year + class Article(object): + def __init__(self, title, *date): + self.date = Date(*date) + self.title = title + articles = [ + Article('aha', 1, 1, 1970), + Article('interesting', 2, 1, 1970), + Article('really?', 3, 1, 1970), + Article('totally not', 1, 1, 1971) + ] + tmpl = env.from_string(''' + {%- for year, list in articles|groupby('date.year') -%} + {{ year }}{% for x in list %}[{{ x.title }}]{% endfor %}| + {%- endfor %}''') + assert tmpl.render(articles=articles).split('|') == [ + '1970[aha][interesting][really?]', + '1971[totally not]', + '' + ] + + def test_filtertag(self): + tmpl = env.from_string("{% filter upper|replace('FOO', 'foo') %}" + "foobar{% endfilter %}") + assert tmpl.render() == 'fooBAR' + + def test_replace(self): + env = Environment() + tmpl = env.from_string('{{ string|replace("o", 42) }}') + assert tmpl.render(string='<foo>') == '<f4242>' + env = Environment(autoescape=True) + tmpl = env.from_string('{{ string|replace("o", 42) }}') + assert tmpl.render(string='<foo>') == '<f4242>' + tmpl = env.from_string('{{ string|replace("<", 42) }}') + assert tmpl.render(string='<foo>') == '42foo>' + tmpl = env.from_string('{{ string|replace("o", ">x<") }}') + assert tmpl.render(string=Markup('foo')) == 'f>x<>x<' + + def test_forceescape(self): + tmpl = env.from_string('{{ x|forceescape }}') + assert tmpl.render(x=Markup('<div />')) == u'<div />' + + def test_safe(self): + env = Environment(autoescape=True) + tmpl = env.from_string('{{ "<div>foo</div>"|safe }}') + assert tmpl.render() == '<div>foo</div>' + tmpl = env.from_string('{{ "<div>foo</div>" }}') + assert tmpl.render() == '<div>foo</div>' + + def test_urlencode(self): + env = Environment(autoescape=True) + tmpl = env.from_string('{{ "Hello, world!"|urlencode }}') + assert tmpl.render() == 'Hello%2C%20world%21' + tmpl = env.from_string('{{ o|urlencode }}') + assert tmpl.render(o=u"Hello, world\u203d") == "Hello%2C%20world%E2%80%BD" + assert tmpl.render(o=(("f", 1),)) == "f=1" + assert tmpl.render(o=(('f', 1), ("z", 2))) == "f=1&z=2" + assert tmpl.render(o=((u"\u203d", 1),)) == "%E2%80%BD=1" + assert tmpl.render(o={u"\u203d": 1}) == "%E2%80%BD=1" + assert tmpl.render(o={0: 1}) == "0=1" + + def test_simple_map(self): + env = Environment() + tmpl = env.from_string('{{ ["1", "2", "3"]|map("int")|sum }}') + self.assertEqual(tmpl.render(), '6') + + def test_attribute_map(self): + class User(object): + def __init__(self, name): + self.name = name + env = Environment() + users = [ + User('john'), + User('jane'), + User('mike'), + ] + tmpl = env.from_string('{{ users|map(attribute="name")|join("|") }}') + self.assertEqual(tmpl.render(users=users), 'john|jane|mike') + + def test_empty_map(self): + env = Environment() + tmpl = env.from_string('{{ none|map("upper")|list }}') + self.assertEqual(tmpl.render(), '[]') + + def test_simple_select(self): + env = Environment() + tmpl = env.from_string('{{ [1, 2, 3, 4, 5]|select("odd")|join("|") }}') + self.assertEqual(tmpl.render(), '1|3|5') + + def test_bool_select(self): + env = Environment() + tmpl = env.from_string('{{ [none, false, 0, 1, 2, 3, 4, 5]|select|join("|") }}') + self.assertEqual(tmpl.render(), '1|2|3|4|5') + + def test_simple_reject(self): + env = Environment() + tmpl = env.from_string('{{ [1, 2, 3, 4, 5]|reject("odd")|join("|") }}') + self.assertEqual(tmpl.render(), '2|4') + + def test_bool_reject(self): + env = Environment() + tmpl = env.from_string('{{ [none, false, 0, 1, 2, 3, 4, 5]|reject|join("|") }}') + self.assertEqual(tmpl.render(), 'None|False|0') + + def test_simple_select_attr(self): + class User(object): + def __init__(self, name, is_active): + self.name = name + self.is_active = is_active + env = Environment() + users = [ + User('john', True), + User('jane', True), + User('mike', False), + ] + tmpl = env.from_string('{{ users|selectattr("is_active")|' + 'map(attribute="name")|join("|") }}') + self.assertEqual(tmpl.render(users=users), 'john|jane') + + def test_simple_reject_attr(self): + class User(object): + def __init__(self, name, is_active): + self.name = name + self.is_active = is_active + env = Environment() + users = [ + User('john', True), + User('jane', True), + User('mike', False), + ] + tmpl = env.from_string('{{ users|rejectattr("is_active")|' + 'map(attribute="name")|join("|") }}') + self.assertEqual(tmpl.render(users=users), 'mike') + + def test_func_select_attr(self): + class User(object): + def __init__(self, id, name): + self.id = id + self.name = name + env = Environment() + users = [ + User(1, 'john'), + User(2, 'jane'), + User(3, 'mike'), + ] + tmpl = env.from_string('{{ users|selectattr("id", "odd")|' + 'map(attribute="name")|join("|") }}') + self.assertEqual(tmpl.render(users=users), 'john|mike') + + def test_func_reject_attr(self): + class User(object): + def __init__(self, id, name): + self.id = id + self.name = name + env = Environment() + users = [ + User(1, 'john'), + User(2, 'jane'), + User(3, 'mike'), + ] + tmpl = env.from_string('{{ users|rejectattr("id", "odd")|' + 'map(attribute="name")|join("|") }}') + self.assertEqual(tmpl.render(users=users), 'jane') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(FilterTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/imports.py b/module/lib/jinja2/testsuite/imports.py new file mode 100644 index 000000000..3db9008de --- /dev/null +++ b/module/lib/jinja2/testsuite/imports.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.imports + ~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests the import features (with includes). + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader +from jinja2.exceptions import TemplateNotFound, TemplatesNotFound + + +test_env = Environment(loader=DictLoader(dict( + module='{% macro test() %}[{{ foo }}|{{ bar }}]{% endmacro %}', + header='[{{ foo }}|{{ 23 }}]', + o_printer='({{ o }})' +))) +test_env.globals['bar'] = 23 + + +class ImportsTestCase(JinjaTestCase): + + def test_context_imports(self): + t = test_env.from_string('{% import "module" as m %}{{ m.test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% import "module" as m without context %}{{ m.test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% import "module" as m with context %}{{ m.test() }}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% from "module" import test %}{{ test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% from "module" import test without context %}{{ test() }}') + assert t.render(foo=42) == '[|23]' + t = test_env.from_string('{% from "module" import test with context %}{{ test() }}') + assert t.render(foo=42) == '[42|23]' + + def test_trailing_comma(self): + test_env.from_string('{% from "foo" import bar, baz with context %}') + test_env.from_string('{% from "foo" import bar, baz, with context %}') + test_env.from_string('{% from "foo" import bar, with context %}') + test_env.from_string('{% from "foo" import bar, with, context %}') + test_env.from_string('{% from "foo" import bar, with with context %}') + + def test_exports(self): + m = test_env.from_string(''' + {% macro toplevel() %}...{% endmacro %} + {% macro __private() %}...{% endmacro %} + {% set variable = 42 %} + {% for item in [1] %} + {% macro notthere() %}{% endmacro %} + {% endfor %} + ''').module + assert m.toplevel() == '...' + assert not hasattr(m, '__missing') + assert m.variable == 42 + assert not hasattr(m, 'notthere') + + +class IncludesTestCase(JinjaTestCase): + + def test_context_include(self): + t = test_env.from_string('{% include "header" %}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% include "header" with context %}') + assert t.render(foo=42) == '[42|23]' + t = test_env.from_string('{% include "header" without context %}') + assert t.render(foo=42) == '[|23]' + + def test_choice_includes(self): + t = test_env.from_string('{% include ["missing", "header"] %}') + assert t.render(foo=42) == '[42|23]' + + t = test_env.from_string('{% include ["missing", "missing2"] ignore missing %}') + assert t.render(foo=42) == '' + + t = test_env.from_string('{% include ["missing", "missing2"] %}') + self.assert_raises(TemplateNotFound, t.render) + try: + t.render() + except TemplatesNotFound as e: + assert e.templates == ['missing', 'missing2'] + assert e.name == 'missing2' + else: + assert False, 'thou shalt raise' + + def test_includes(t, **ctx): + ctx['foo'] = 42 + assert t.render(ctx) == '[42|23]' + + t = test_env.from_string('{% include ["missing", "header"] %}') + test_includes(t) + t = test_env.from_string('{% include x %}') + test_includes(t, x=['missing', 'header']) + t = test_env.from_string('{% include [x, "header"] %}') + test_includes(t, x='missing') + t = test_env.from_string('{% include x %}') + test_includes(t, x='header') + t = test_env.from_string('{% include x %}') + test_includes(t, x='header') + t = test_env.from_string('{% include [x] %}') + test_includes(t, x='header') + + def test_include_ignoring_missing(self): + t = test_env.from_string('{% include "missing" %}') + self.assert_raises(TemplateNotFound, t.render) + for extra in '', 'with context', 'without context': + t = test_env.from_string('{% include "missing" ignore missing ' + + extra + ' %}') + assert t.render() == '' + + def test_context_include_with_overrides(self): + env = Environment(loader=DictLoader(dict( + main="{% for item in [1, 2, 3] %}{% include 'item' %}{% endfor %}", + item="{{ item }}" + ))) + assert env.get_template("main").render() == "123" + + def test_unoptimized_scopes(self): + t = test_env.from_string(""" + {% macro outer(o) %} + {% macro inner() %} + {% include "o_printer" %} + {% endmacro %} + {{ inner() }} + {% endmacro %} + {{ outer("FOO") }} + """) + assert t.render().strip() == '(FOO)' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(ImportsTestCase)) + suite.addTest(unittest.makeSuite(IncludesTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/inheritance.py b/module/lib/jinja2/testsuite/inheritance.py new file mode 100644 index 000000000..e0f51cda9 --- /dev/null +++ b/module/lib/jinja2/testsuite/inheritance.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.inheritance + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests the template inheritance feature. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, DictLoader, TemplateError + + +LAYOUTTEMPLATE = '''\ +|{% block block1 %}block 1 from layout{% endblock %} +|{% block block2 %}block 2 from layout{% endblock %} +|{% block block3 %} +{% block block4 %}nested block 4 from layout{% endblock %} +{% endblock %}|''' + +LEVEL1TEMPLATE = '''\ +{% extends "layout" %} +{% block block1 %}block 1 from level1{% endblock %}''' + +LEVEL2TEMPLATE = '''\ +{% extends "level1" %} +{% block block2 %}{% block block5 %}nested block 5 from level2{% +endblock %}{% endblock %}''' + +LEVEL3TEMPLATE = '''\ +{% extends "level2" %} +{% block block5 %}block 5 from level3{% endblock %} +{% block block4 %}block 4 from level3{% endblock %} +''' + +LEVEL4TEMPLATE = '''\ +{% extends "level3" %} +{% block block3 %}block 3 from level4{% endblock %} +''' + +WORKINGTEMPLATE = '''\ +{% extends "layout" %} +{% block block1 %} + {% if false %} + {% block block2 %} + this should workd + {% endblock %} + {% endif %} +{% endblock %} +''' + +DOUBLEEXTENDS = '''\ +{% extends "layout" %} +{% extends "layout" %} +{% block block1 %} + {% if false %} + {% block block2 %} + this should workd + {% endblock %} + {% endif %} +{% endblock %} +''' + + +env = Environment(loader=DictLoader({ + 'layout': LAYOUTTEMPLATE, + 'level1': LEVEL1TEMPLATE, + 'level2': LEVEL2TEMPLATE, + 'level3': LEVEL3TEMPLATE, + 'level4': LEVEL4TEMPLATE, + 'working': WORKINGTEMPLATE, + 'doublee': DOUBLEEXTENDS, +}), trim_blocks=True) + + +class InheritanceTestCase(JinjaTestCase): + + def test_layout(self): + tmpl = env.get_template('layout') + assert tmpl.render() == ('|block 1 from layout|block 2 from ' + 'layout|nested block 4 from layout|') + + def test_level1(self): + tmpl = env.get_template('level1') + assert tmpl.render() == ('|block 1 from level1|block 2 from ' + 'layout|nested block 4 from layout|') + + def test_level2(self): + tmpl = env.get_template('level2') + assert tmpl.render() == ('|block 1 from level1|nested block 5 from ' + 'level2|nested block 4 from layout|') + + def test_level3(self): + tmpl = env.get_template('level3') + assert tmpl.render() == ('|block 1 from level1|block 5 from level3|' + 'block 4 from level3|') + + def test_level4(sel): + tmpl = env.get_template('level4') + assert tmpl.render() == ('|block 1 from level1|block 5 from ' + 'level3|block 3 from level4|') + + def test_super(self): + env = Environment(loader=DictLoader({ + 'a': '{% block intro %}INTRO{% endblock %}|' + 'BEFORE|{% block data %}INNER{% endblock %}|AFTER', + 'b': '{% extends "a" %}{% block data %}({{ ' + 'super() }}){% endblock %}', + 'c': '{% extends "b" %}{% block intro %}--{{ ' + 'super() }}--{% endblock %}\n{% block data ' + '%}[{{ super() }}]{% endblock %}' + })) + tmpl = env.get_template('c') + assert tmpl.render() == '--INTRO--|BEFORE|[(INNER)]|AFTER' + + def test_working(self): + tmpl = env.get_template('working') + + def test_reuse_blocks(self): + tmpl = env.from_string('{{ self.foo() }}|{% block foo %}42' + '{% endblock %}|{{ self.foo() }}') + assert tmpl.render() == '42|42|42' + + def test_preserve_blocks(self): + env = Environment(loader=DictLoader({ + 'a': '{% if false %}{% block x %}A{% endblock %}{% endif %}{{ self.x() }}', + 'b': '{% extends "a" %}{% block x %}B{{ super() }}{% endblock %}' + })) + tmpl = env.get_template('b') + assert tmpl.render() == 'BA' + + def test_dynamic_inheritance(self): + env = Environment(loader=DictLoader({ + 'master1': 'MASTER1{% block x %}{% endblock %}', + 'master2': 'MASTER2{% block x %}{% endblock %}', + 'child': '{% extends master %}{% block x %}CHILD{% endblock %}' + })) + tmpl = env.get_template('child') + for m in range(1, 3): + assert tmpl.render(master='master%d' % m) == 'MASTER%dCHILD' % m + + def test_multi_inheritance(self): + env = Environment(loader=DictLoader({ + 'master1': 'MASTER1{% block x %}{% endblock %}', + 'master2': 'MASTER2{% block x %}{% endblock %}', + 'child': '''{% if master %}{% extends master %}{% else %}{% extends + 'master1' %}{% endif %}{% block x %}CHILD{% endblock %}''' + })) + tmpl = env.get_template('child') + assert tmpl.render(master='master2') == 'MASTER2CHILD' + assert tmpl.render(master='master1') == 'MASTER1CHILD' + assert tmpl.render() == 'MASTER1CHILD' + + def test_scoped_block(self): + env = Environment(loader=DictLoader({ + 'master.html': '{% for item in seq %}[{% block item scoped %}' + '{% endblock %}]{% endfor %}' + })) + t = env.from_string('{% extends "master.html" %}{% block item %}' + '{{ item }}{% endblock %}') + assert t.render(seq=list(range(5))) == '[0][1][2][3][4]' + + def test_super_in_scoped_block(self): + env = Environment(loader=DictLoader({ + 'master.html': '{% for item in seq %}[{% block item scoped %}' + '{{ item }}{% endblock %}]{% endfor %}' + })) + t = env.from_string('{% extends "master.html" %}{% block item %}' + '{{ super() }}|{{ item * 2 }}{% endblock %}') + assert t.render(seq=list(range(5))) == '[0|0][1|2][2|4][3|6][4|8]' + + def test_scoped_block_after_inheritance(self): + env = Environment(loader=DictLoader({ + 'layout.html': ''' + {% block useless %}{% endblock %} + ''', + 'index.html': ''' + {%- extends 'layout.html' %} + {% from 'helpers.html' import foo with context %} + {% block useless %} + {% for x in [1, 2, 3] %} + {% block testing scoped %} + {{ foo(x) }} + {% endblock %} + {% endfor %} + {% endblock %} + ''', + 'helpers.html': ''' + {% macro foo(x) %}{{ the_foo + x }}{% endmacro %} + ''' + })) + rv = env.get_template('index.html').render(the_foo=42).split() + assert rv == ['43', '44', '45'] + + +class BugFixTestCase(JinjaTestCase): + + def test_fixed_macro_scoping_bug(self): + assert Environment(loader=DictLoader({ + 'test.html': '''\ + {% extends 'details.html' %} + + {% macro my_macro() %} + my_macro + {% endmacro %} + + {% block inner_box %} + {{ my_macro() }} + {% endblock %} + ''', + 'details.html': '''\ + {% extends 'standard.html' %} + + {% macro my_macro() %} + my_macro + {% endmacro %} + + {% block content %} + {% block outer_box %} + outer_box + {% block inner_box %} + inner_box + {% endblock %} + {% endblock %} + {% endblock %} + ''', + 'standard.html': ''' + {% block content %} {% endblock %} + ''' + })).get_template("test.html").render().split() == [u'outer_box', u'my_macro'] + + def test_double_extends(self): + """Ensures that a template with more than 1 {% extends ... %} usage + raises a ``TemplateError``. + """ + try: + tmpl = env.get_template('doublee') + except Exception as e: + assert isinstance(e, TemplateError) + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(InheritanceTestCase)) + suite.addTest(unittest.makeSuite(BugFixTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/lexnparse.py b/module/lib/jinja2/testsuite/lexnparse.py new file mode 100644 index 000000000..bd1c94cd3 --- /dev/null +++ b/module/lib/jinja2/testsuite/lexnparse.py @@ -0,0 +1,593 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.lexnparse + ~~~~~~~~~~~~~~~~~~~~~~~~~~ + + All the unittests regarding lexing, parsing and syntax. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment, Template, TemplateSyntaxError, \ + UndefinedError, nodes +from jinja2._compat import next, iteritems, text_type, PY2 +from jinja2.lexer import Token, TokenStream, TOKEN_EOF, \ + TOKEN_BLOCK_BEGIN, TOKEN_BLOCK_END + +env = Environment() + + +# how does a string look like in jinja syntax? +if PY2: + def jinja_string_repr(string): + return repr(string)[1:] +else: + jinja_string_repr = repr + + +class TokenStreamTestCase(JinjaTestCase): + test_tokens = [Token(1, TOKEN_BLOCK_BEGIN, ''), + Token(2, TOKEN_BLOCK_END, ''), + ] + + def test_simple(self): + ts = TokenStream(self.test_tokens, "foo", "bar") + assert ts.current.type is TOKEN_BLOCK_BEGIN + assert bool(ts) + assert not bool(ts.eos) + next(ts) + assert ts.current.type is TOKEN_BLOCK_END + assert bool(ts) + assert not bool(ts.eos) + next(ts) + assert ts.current.type is TOKEN_EOF + assert not bool(ts) + assert bool(ts.eos) + + def test_iter(self): + token_types = [t.type for t in TokenStream(self.test_tokens, "foo", "bar")] + assert token_types == ['block_begin', 'block_end', ] + + +class LexerTestCase(JinjaTestCase): + + def test_raw1(self): + tmpl = env.from_string('{% raw %}foo{% endraw %}|' + '{%raw%}{{ bar }}|{% baz %}{% endraw %}') + assert tmpl.render() == 'foo|{{ bar }}|{% baz %}' + + def test_raw2(self): + tmpl = env.from_string('1 {%- raw -%} 2 {%- endraw -%} 3') + assert tmpl.render() == '123' + + def test_balancing(self): + env = Environment('{%', '%}', '${', '}') + tmpl = env.from_string('''{% for item in seq + %}${{'foo': item}|upper}{% endfor %}''') + assert tmpl.render(seq=list(range(3))) == "{'FOO': 0}{'FOO': 1}{'FOO': 2}" + + def test_comments(self): + env = Environment('<!--', '-->', '{', '}') + tmpl = env.from_string('''\ +<ul> +<!--- for item in seq --> + <li>{item}</li> +<!--- endfor --> +</ul>''') + assert tmpl.render(seq=list(range(3))) == ("<ul>\n <li>0</li>\n " + "<li>1</li>\n <li>2</li>\n</ul>") + + def test_string_escapes(self): + for char in u'\0', u'\u2668', u'\xe4', u'\t', u'\r', u'\n': + tmpl = env.from_string('{{ %s }}' % jinja_string_repr(char)) + assert tmpl.render() == char + assert env.from_string('{{ "\N{HOT SPRINGS}" }}').render() == u'\u2668' + + def test_bytefallback(self): + from pprint import pformat + tmpl = env.from_string(u'''{{ 'foo'|pprint }}|{{ 'bÀr'|pprint }}''') + assert tmpl.render() == pformat('foo') + '|' + pformat(u'bÀr') + + def test_operators(self): + from jinja2.lexer import operators + for test, expect in iteritems(operators): + if test in '([{}])': + continue + stream = env.lexer.tokenize('{{ %s }}' % test) + next(stream) + assert stream.current.type == expect + + def test_normalizing(self): + for seq in '\r', '\r\n', '\n': + env = Environment(newline_sequence=seq) + tmpl = env.from_string('1\n2\r\n3\n4\n') + result = tmpl.render() + assert result.replace(seq, 'X') == '1X2X3X4' + + def test_trailing_newline(self): + for keep in [True, False]: + env = Environment(keep_trailing_newline=keep) + for template,expected in [ + ('', {}), + ('no\nnewline', {}), + ('with\nnewline\n', {False: 'with\nnewline'}), + ('with\nseveral\n\n\n', {False: 'with\nseveral\n\n'}), + ]: + tmpl = env.from_string(template) + expect = expected.get(keep, template) + result = tmpl.render() + assert result == expect, (keep, template, result, expect) + +class ParserTestCase(JinjaTestCase): + + def test_php_syntax(self): + env = Environment('<?', '?>', '<?=', '?>', '<!--', '-->') + tmpl = env.from_string('''\ +<!-- I'm a comment, I'm not interesting -->\ +<? for item in seq -?> + <?= item ?> +<?- endfor ?>''') + assert tmpl.render(seq=list(range(5))) == '01234' + + def test_erb_syntax(self): + env = Environment('<%', '%>', '<%=', '%>', '<%#', '%>') + tmpl = env.from_string('''\ +<%# I'm a comment, I'm not interesting %>\ +<% for item in seq -%> + <%= item %> +<%- endfor %>''') + assert tmpl.render(seq=list(range(5))) == '01234' + + def test_comment_syntax(self): + env = Environment('<!--', '-->', '${', '}', '<!--#', '-->') + tmpl = env.from_string('''\ +<!--# I'm a comment, I'm not interesting -->\ +<!-- for item in seq ---> + ${item} +<!--- endfor -->''') + assert tmpl.render(seq=list(range(5))) == '01234' + + def test_balancing(self): + tmpl = env.from_string('''{{{'foo':'bar'}.foo}}''') + assert tmpl.render() == 'bar' + + def test_start_comment(self): + tmpl = env.from_string('''{# foo comment +and bar comment #} +{% macro blub() %}foo{% endmacro %} +{{ blub() }}''') + assert tmpl.render().strip() == 'foo' + + def test_line_syntax(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%') + tmpl = env.from_string('''\ +<%# regular comment %> +% for item in seq: + ${item} +% endfor''') + assert [int(x.strip()) for x in tmpl.render(seq=list(range(5))).split()] == \ + list(range(5)) + + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##') + tmpl = env.from_string('''\ +<%# regular comment %> +% for item in seq: + ${item} ## the rest of the stuff +% endfor''') + assert [int(x.strip()) for x in tmpl.render(seq=list(range(5))).split()] == \ + list(range(5)) + + def test_line_syntax_priority(self): + # XXX: why is the whitespace there in front of the newline? + env = Environment('{%', '%}', '${', '}', '/*', '*/', '##', '#') + tmpl = env.from_string('''\ +/* ignore me. + I'm a multiline comment */ +## for item in seq: +* ${item} # this is just extra stuff +## endfor''') + assert tmpl.render(seq=[1, 2]).strip() == '* 1\n* 2' + env = Environment('{%', '%}', '${', '}', '/*', '*/', '#', '##') + tmpl = env.from_string('''\ +/* ignore me. + I'm a multiline comment */ +# for item in seq: +* ${item} ## this is just extra stuff + ## extra stuff i just want to ignore +# endfor''') + assert tmpl.render(seq=[1, 2]).strip() == '* 1\n\n* 2' + + def test_error_messages(self): + def assert_error(code, expected): + try: + Template(code) + except TemplateSyntaxError as e: + assert str(e) == expected, 'unexpected error message' + else: + assert False, 'that was supposed to be an error' + + assert_error('{% for item in seq %}...{% endif %}', + "Encountered unknown tag 'endif'. Jinja was looking " + "for the following tags: 'endfor' or 'else'. The " + "innermost block that needs to be closed is 'for'.") + assert_error('{% if foo %}{% for item in seq %}...{% endfor %}{% endfor %}', + "Encountered unknown tag 'endfor'. Jinja was looking for " + "the following tags: 'elif' or 'else' or 'endif'. The " + "innermost block that needs to be closed is 'if'.") + assert_error('{% if foo %}', + "Unexpected end of template. Jinja was looking for the " + "following tags: 'elif' or 'else' or 'endif'. The " + "innermost block that needs to be closed is 'if'.") + assert_error('{% for item in seq %}', + "Unexpected end of template. Jinja was looking for the " + "following tags: 'endfor' or 'else'. The innermost block " + "that needs to be closed is 'for'.") + assert_error('{% block foo-bar-baz %}', + "Block names in Jinja have to be valid Python identifiers " + "and may not contain hyphens, use an underscore instead.") + assert_error('{% unknown_tag %}', + "Encountered unknown tag 'unknown_tag'.") + + +class SyntaxTestCase(JinjaTestCase): + + def test_call(self): + env = Environment() + env.globals['foo'] = lambda a, b, c, e, g: a + b + c + e + g + tmpl = env.from_string("{{ foo('a', c='d', e='f', *['b'], **{'g': 'h'}) }}") + assert tmpl.render() == 'abdfh' + + def test_slicing(self): + tmpl = env.from_string('{{ [1, 2, 3][:] }}|{{ [1, 2, 3][::-1] }}') + assert tmpl.render() == '[1, 2, 3]|[3, 2, 1]' + + def test_attr(self): + tmpl = env.from_string("{{ foo.bar }}|{{ foo['bar'] }}") + assert tmpl.render(foo={'bar': 42}) == '42|42' + + def test_subscript(self): + tmpl = env.from_string("{{ foo[0] }}|{{ foo[-1] }}") + assert tmpl.render(foo=[0, 1, 2]) == '0|2' + + def test_tuple(self): + tmpl = env.from_string('{{ () }}|{{ (1,) }}|{{ (1, 2) }}') + assert tmpl.render() == '()|(1,)|(1, 2)' + + def test_math(self): + tmpl = env.from_string('{{ (1 + 1 * 2) - 3 / 2 }}|{{ 2**3 }}') + assert tmpl.render() == '1.5|8' + + def test_div(self): + tmpl = env.from_string('{{ 3 // 2 }}|{{ 3 / 2 }}|{{ 3 % 2 }}') + assert tmpl.render() == '1|1.5|1' + + def test_unary(self): + tmpl = env.from_string('{{ +3 }}|{{ -3 }}') + assert tmpl.render() == '3|-3' + + def test_concat(self): + tmpl = env.from_string("{{ [1, 2] ~ 'foo' }}") + assert tmpl.render() == '[1, 2]foo' + + def test_compare(self): + tmpl = env.from_string('{{ 1 > 0 }}|{{ 1 >= 1 }}|{{ 2 < 3 }}|' + '{{ 2 == 2 }}|{{ 1 <= 1 }}') + assert tmpl.render() == 'True|True|True|True|True' + + def test_inop(self): + tmpl = env.from_string('{{ 1 in [1, 2, 3] }}|{{ 1 not in [1, 2, 3] }}') + assert tmpl.render() == 'True|False' + + def test_literals(self): + tmpl = env.from_string('{{ [] }}|{{ {} }}|{{ () }}') + assert tmpl.render().lower() == '[]|{}|()' + + def test_bool(self): + tmpl = env.from_string('{{ true and false }}|{{ false ' + 'or true }}|{{ not false }}') + assert tmpl.render() == 'False|True|True' + + def test_grouping(self): + tmpl = env.from_string('{{ (true and false) or (false and true) and not false }}') + assert tmpl.render() == 'False' + + def test_django_attr(self): + tmpl = env.from_string('{{ [1, 2, 3].0 }}|{{ [[1]].0.0 }}') + assert tmpl.render() == '1|1' + + def test_conditional_expression(self): + tmpl = env.from_string('''{{ 0 if true else 1 }}''') + assert tmpl.render() == '0' + + def test_short_conditional_expression(self): + tmpl = env.from_string('<{{ 1 if false }}>') + assert tmpl.render() == '<>' + + tmpl = env.from_string('<{{ (1 if false).bar }}>') + self.assert_raises(UndefinedError, tmpl.render) + + def test_filter_priority(self): + tmpl = env.from_string('{{ "foo"|upper + "bar"|upper }}') + assert tmpl.render() == 'FOOBAR' + + def test_function_calls(self): + tests = [ + (True, '*foo, bar'), + (True, '*foo, *bar'), + (True, '*foo, bar=42'), + (True, '**foo, *bar'), + (True, '**foo, bar'), + (False, 'foo, bar'), + (False, 'foo, bar=42'), + (False, 'foo, bar=23, *args'), + (False, 'a, b=c, *d, **e'), + (False, '*foo, **bar') + ] + for should_fail, sig in tests: + if should_fail: + self.assert_raises(TemplateSyntaxError, + env.from_string, '{{ foo(%s) }}' % sig) + else: + env.from_string('foo(%s)' % sig) + + def test_tuple_expr(self): + for tmpl in [ + '{{ () }}', + '{{ (1, 2) }}', + '{{ (1, 2,) }}', + '{{ 1, }}', + '{{ 1, 2 }}', + '{% for foo, bar in seq %}...{% endfor %}', + '{% for x in foo, bar %}...{% endfor %}', + '{% for x in foo, %}...{% endfor %}' + ]: + assert env.from_string(tmpl) + + def test_trailing_comma(self): + tmpl = env.from_string('{{ (1, 2,) }}|{{ [1, 2,] }}|{{ {1: 2,} }}') + assert tmpl.render().lower() == '(1, 2)|[1, 2]|{1: 2}' + + def test_block_end_name(self): + env.from_string('{% block foo %}...{% endblock foo %}') + self.assert_raises(TemplateSyntaxError, env.from_string, + '{% block x %}{% endblock y %}') + + def test_constant_casing(self): + for const in True, False, None: + tmpl = env.from_string('{{ %s }}|{{ %s }}|{{ %s }}' % ( + str(const), str(const).lower(), str(const).upper() + )) + assert tmpl.render() == '%s|%s|' % (const, const) + + def test_test_chaining(self): + self.assert_raises(TemplateSyntaxError, env.from_string, + '{{ foo is string is sequence }}') + assert env.from_string('{{ 42 is string or 42 is number }}' + ).render() == 'True' + + def test_string_concatenation(self): + tmpl = env.from_string('{{ "foo" "bar" "baz" }}') + assert tmpl.render() == 'foobarbaz' + + def test_notin(self): + bar = range(100) + tmpl = env.from_string('''{{ not 42 in bar }}''') + assert tmpl.render(bar=bar) == text_type(not 42 in bar) + + def test_implicit_subscribed_tuple(self): + class Foo(object): + def __getitem__(self, x): + return x + t = env.from_string('{{ foo[1, 2] }}') + assert t.render(foo=Foo()) == u'(1, 2)' + + def test_raw2(self): + tmpl = env.from_string('{% raw %}{{ FOO }} and {% BAR %}{% endraw %}') + assert tmpl.render() == '{{ FOO }} and {% BAR %}' + + def test_const(self): + tmpl = env.from_string('{{ true }}|{{ false }}|{{ none }}|' + '{{ none is defined }}|{{ missing is defined }}') + assert tmpl.render() == 'True|False|None|True|False' + + def test_neg_filter_priority(self): + node = env.parse('{{ -1|foo }}') + assert isinstance(node.body[0].nodes[0], nodes.Filter) + assert isinstance(node.body[0].nodes[0].node, nodes.Neg) + + def test_const_assign(self): + constass1 = '''{% set true = 42 %}''' + constass2 = '''{% for none in seq %}{% endfor %}''' + for tmpl in constass1, constass2: + self.assert_raises(TemplateSyntaxError, env.from_string, tmpl) + + def test_localset(self): + tmpl = env.from_string('''{% set foo = 0 %}\ +{% for item in [1, 2] %}{% set foo = 1 %}{% endfor %}\ +{{ foo }}''') + assert tmpl.render() == '0' + + def test_parse_unary(self): + tmpl = env.from_string('{{ -foo["bar"] }}') + assert tmpl.render(foo={'bar': 42}) == '-42' + tmpl = env.from_string('{{ -foo["bar"]|abs }}') + assert tmpl.render(foo={'bar': 42}) == '42' + + +class LstripBlocksTestCase(JinjaTestCase): + + def test_lstrip(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {% if True %}\n {% endif %}''') + assert tmpl.render() == "\n" + + def test_lstrip_trim(self): + env = Environment(lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string(''' {% if True %}\n {% endif %}''') + assert tmpl.render() == "" + + def test_no_lstrip(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {%+ if True %}\n {%+ endif %}''') + assert tmpl.render() == " \n " + + def test_lstrip_endline(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' hello{% if True %}\n goodbye{% endif %}''') + assert tmpl.render() == " hello\n goodbye" + + def test_lstrip_inline(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {% if True %}hello {% endif %}''') + assert tmpl.render() == 'hello ' + + def test_lstrip_nested(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {% if True %}a {% if True %}b {% endif %}c {% endif %}''') + assert tmpl.render() == 'a b c ' + + def test_lstrip_left_chars(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' abc {% if True %} + hello{% endif %}''') + assert tmpl.render() == ' abc \n hello' + + def test_lstrip_embeded_strings(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {% set x = " {% str %} " %}{{ x }}''') + assert tmpl.render() == ' {% str %} ' + + def test_lstrip_preserve_leading_newlines(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string('''\n\n\n{% set hello = 1 %}''') + assert tmpl.render() == '\n\n\n' + + def test_lstrip_comment(self): + env = Environment(lstrip_blocks=True, trim_blocks=False) + tmpl = env.from_string(''' {# if True #} +hello + {#endif#}''') + assert tmpl.render() == '\nhello\n' + + def test_lstrip_angle_bracket_simple(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string(''' <% if True %>hello <% endif %>''') + assert tmpl.render() == 'hello ' + + def test_lstrip_angle_bracket_comment(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string(''' <%# if True %>hello <%# endif %>''') + assert tmpl.render() == 'hello ' + + def test_lstrip_angle_bracket(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ + <%# regular comment %> + <% for item in seq %> +${item} ## the rest of the stuff + <% endfor %>''') + assert tmpl.render(seq=range(5)) == \ + ''.join('%s\n' % x for x in range(5)) + + def test_lstrip_angle_bracket_compact(self): + env = Environment('<%', '%>', '${', '}', '<%#', '%>', '%', '##', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ + <%#regular comment%> + <%for item in seq%> +${item} ## the rest of the stuff + <%endfor%>''') + assert tmpl.render(seq=range(5)) == \ + ''.join('%s\n' % x for x in range(5)) + + def test_php_syntax_with_manual(self): + env = Environment('<?', '?>', '<?=', '?>', '<!--', '-->', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ + <!-- I'm a comment, I'm not interesting --> + <? for item in seq -?> + <?= item ?> + <?- endfor ?>''') + assert tmpl.render(seq=range(5)) == '01234' + + def test_php_syntax(self): + env = Environment('<?', '?>', '<?=', '?>', '<!--', '-->', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ + <!-- I'm a comment, I'm not interesting --> + <? for item in seq ?> + <?= item ?> + <? endfor ?>''') + assert tmpl.render(seq=range(5)) == ''.join(' %s\n' % x for x in range(5)) + + def test_php_syntax_compact(self): + env = Environment('<?', '?>', '<?=', '?>', '<!--', '-->', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ + <!-- I'm a comment, I'm not interesting --> + <?for item in seq?> + <?=item?> + <?endfor?>''') + assert tmpl.render(seq=range(5)) == ''.join(' %s\n' % x for x in range(5)) + + def test_erb_syntax(self): + env = Environment('<%', '%>', '<%=', '%>', '<%#', '%>', + lstrip_blocks=True, trim_blocks=True) + #env.from_string('') + #for n,r in env.lexer.rules.iteritems(): + # print n + #print env.lexer.rules['root'][0][0].pattern + #print "'%s'" % tmpl.render(seq=range(5)) + tmpl = env.from_string('''\ +<%# I'm a comment, I'm not interesting %> + <% for item in seq %> + <%= item %> + <% endfor %> +''') + assert tmpl.render(seq=range(5)) == ''.join(' %s\n' % x for x in range(5)) + + def test_erb_syntax_with_manual(self): + env = Environment('<%', '%>', '<%=', '%>', '<%#', '%>', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ +<%# I'm a comment, I'm not interesting %> + <% for item in seq -%> + <%= item %> + <%- endfor %>''') + assert tmpl.render(seq=range(5)) == '01234' + + def test_erb_syntax_no_lstrip(self): + env = Environment('<%', '%>', '<%=', '%>', '<%#', '%>', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ +<%# I'm a comment, I'm not interesting %> + <%+ for item in seq -%> + <%= item %> + <%- endfor %>''') + assert tmpl.render(seq=range(5)) == ' 01234' + + def test_comment_syntax(self): + env = Environment('<!--', '-->', '${', '}', '<!--#', '-->', + lstrip_blocks=True, trim_blocks=True) + tmpl = env.from_string('''\ +<!--# I'm a comment, I'm not interesting -->\ +<!-- for item in seq ---> + ${item} +<!--- endfor -->''') + assert tmpl.render(seq=range(5)) == '01234' + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TokenStreamTestCase)) + suite.addTest(unittest.makeSuite(LexerTestCase)) + suite.addTest(unittest.makeSuite(ParserTestCase)) + suite.addTest(unittest.makeSuite(SyntaxTestCase)) + suite.addTest(unittest.makeSuite(LstripBlocksTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/loader.py b/module/lib/jinja2/testsuite/loader.py new file mode 100644 index 000000000..a7350aab9 --- /dev/null +++ b/module/lib/jinja2/testsuite/loader.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.loader + ~~~~~~~~~~~~~~~~~~~~~~~ + + Test the loaders. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import os +import sys +import tempfile +import shutil +import unittest + +from jinja2.testsuite import JinjaTestCase, dict_loader, \ + package_loader, filesystem_loader, function_loader, \ + choice_loader, prefix_loader + +from jinja2 import Environment, loaders +from jinja2._compat import PYPY, PY2 +from jinja2.loaders import split_template_path +from jinja2.exceptions import TemplateNotFound + + +class LoaderTestCase(JinjaTestCase): + + def test_dict_loader(self): + env = Environment(loader=dict_loader) + tmpl = env.get_template('justdict.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_package_loader(self): + env = Environment(loader=package_loader) + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_filesystem_loader(self): + env = Environment(loader=filesystem_loader) + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + tmpl = env.get_template('foo/test.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_choice_loader(self): + env = Environment(loader=choice_loader) + tmpl = env.get_template('justdict.html') + assert tmpl.render().strip() == 'FOO' + tmpl = env.get_template('test.html') + assert tmpl.render().strip() == 'BAR' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_function_loader(self): + env = Environment(loader=function_loader) + tmpl = env.get_template('justfunction.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing.html') + + def test_prefix_loader(self): + env = Environment(loader=prefix_loader) + tmpl = env.get_template('a/test.html') + assert tmpl.render().strip() == 'BAR' + tmpl = env.get_template('b/justdict.html') + assert tmpl.render().strip() == 'FOO' + self.assert_raises(TemplateNotFound, env.get_template, 'missing') + + def test_caching(self): + changed = False + class TestLoader(loaders.BaseLoader): + def get_source(self, environment, template): + return u'foo', None, lambda: not changed + env = Environment(loader=TestLoader(), cache_size=-1) + tmpl = env.get_template('template') + assert tmpl is env.get_template('template') + changed = True + assert tmpl is not env.get_template('template') + changed = False + + env = Environment(loader=TestLoader(), cache_size=0) + assert env.get_template('template') \ + is not env.get_template('template') + + env = Environment(loader=TestLoader(), cache_size=2) + t1 = env.get_template('one') + t2 = env.get_template('two') + assert t2 is env.get_template('two') + assert t1 is env.get_template('one') + t3 = env.get_template('three') + assert 'one' in env.cache + assert 'two' not in env.cache + assert 'three' in env.cache + + def test_dict_loader_cache_invalidates(self): + mapping = {'foo': "one"} + env = Environment(loader=loaders.DictLoader(mapping)) + assert env.get_template('foo').render() == "one" + mapping['foo'] = "two" + assert env.get_template('foo').render() == "two" + + def test_split_template_path(self): + assert split_template_path('foo/bar') == ['foo', 'bar'] + assert split_template_path('./foo/bar') == ['foo', 'bar'] + self.assert_raises(TemplateNotFound, split_template_path, '../foo') + + +class ModuleLoaderTestCase(JinjaTestCase): + archive = None + + def compile_down(self, zip='deflated', py_compile=False): + super(ModuleLoaderTestCase, self).setup() + log = [] + self.reg_env = Environment(loader=prefix_loader) + if zip is not None: + self.archive = tempfile.mkstemp(suffix='.zip')[1] + else: + self.archive = tempfile.mkdtemp() + self.reg_env.compile_templates(self.archive, zip=zip, + log_function=log.append, + py_compile=py_compile) + self.mod_env = Environment(loader=loaders.ModuleLoader(self.archive)) + return ''.join(log) + + def teardown(self): + super(ModuleLoaderTestCase, self).teardown() + if hasattr(self, 'mod_env'): + if os.path.isfile(self.archive): + os.remove(self.archive) + else: + shutil.rmtree(self.archive) + self.archive = None + + def test_log(self): + log = self.compile_down() + assert 'Compiled "a/foo/test.html" as ' \ + 'tmpl_a790caf9d669e39ea4d280d597ec891c4ef0404a' in log + assert 'Finished compiling templates' in log + assert 'Could not compile "a/syntaxerror.html": ' \ + 'Encountered unknown tag \'endif\'' in log + + def _test_common(self): + tmpl1 = self.reg_env.get_template('a/test.html') + tmpl2 = self.mod_env.get_template('a/test.html') + assert tmpl1.render() == tmpl2.render() + + tmpl1 = self.reg_env.get_template('b/justdict.html') + tmpl2 = self.mod_env.get_template('b/justdict.html') + assert tmpl1.render() == tmpl2.render() + + def test_deflated_zip_compile(self): + self.compile_down(zip='deflated') + self._test_common() + + def test_stored_zip_compile(self): + self.compile_down(zip='stored') + self._test_common() + + def test_filesystem_compile(self): + self.compile_down(zip=None) + self._test_common() + + def test_weak_references(self): + self.compile_down() + tmpl = self.mod_env.get_template('a/test.html') + key = loaders.ModuleLoader.get_template_key('a/test.html') + name = self.mod_env.loader.module.__name__ + + assert hasattr(self.mod_env.loader.module, key) + assert name in sys.modules + + # unset all, ensure the module is gone from sys.modules + self.mod_env = tmpl = None + + try: + import gc + gc.collect() + except: + pass + + assert name not in sys.modules + + # This test only makes sense on non-pypy python 2 + if PY2 and not PYPY: + def test_byte_compilation(self): + log = self.compile_down(py_compile=True) + assert 'Byte-compiled "a/test.html"' in log + tmpl1 = self.mod_env.get_template('a/test.html') + mod = self.mod_env.loader.module. \ + tmpl_3c4ddf650c1a73df961a6d3d2ce2752f1b8fd490 + assert mod.__file__.endswith('.pyc') + + def test_choice_loader(self): + log = self.compile_down() + + self.mod_env.loader = loaders.ChoiceLoader([ + self.mod_env.loader, + loaders.DictLoader({'DICT_SOURCE': 'DICT_TEMPLATE'}) + ]) + + tmpl1 = self.mod_env.get_template('a/test.html') + self.assert_equal(tmpl1.render(), 'BAR') + tmpl2 = self.mod_env.get_template('DICT_SOURCE') + self.assert_equal(tmpl2.render(), 'DICT_TEMPLATE') + + def test_prefix_loader(self): + log = self.compile_down() + + self.mod_env.loader = loaders.PrefixLoader({ + 'MOD': self.mod_env.loader, + 'DICT': loaders.DictLoader({'test.html': 'DICT_TEMPLATE'}) + }) + + tmpl1 = self.mod_env.get_template('MOD/a/test.html') + self.assert_equal(tmpl1.render(), 'BAR') + tmpl2 = self.mod_env.get_template('DICT/test.html') + self.assert_equal(tmpl2.render(), 'DICT_TEMPLATE') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(LoaderTestCase)) + suite.addTest(unittest.makeSuite(ModuleLoaderTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/regression.py b/module/lib/jinja2/testsuite/regression.py new file mode 100644 index 000000000..c5f7d5c65 --- /dev/null +++ b/module/lib/jinja2/testsuite/regression.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.regression + ~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Tests corner cases and bugs. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Template, Environment, DictLoader, TemplateSyntaxError, \ + TemplateNotFound, PrefixLoader +from jinja2._compat import text_type + +env = Environment() + + +class CornerTestCase(JinjaTestCase): + + def test_assigned_scoping(self): + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {{- item -}} + ''') + assert t.render(item=42) == '[1][2][3][4]42' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {%- set item = 42 %} + {{- item -}} + ''') + assert t.render() == '[1][2][3][4]42' + + t = env.from_string(''' + {%- set item = 42 %} + {%- for item in (1, 2, 3, 4) -%} + [{{ item }}] + {%- endfor %} + {{- item -}} + ''') + assert t.render() == '[1][2][3][4]42' + + def test_closure_scoping(self): + t = env.from_string(''' + {%- set wrapper = "<FOO>" %} + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {{- wrapper -}} + ''') + assert t.render() == '[1][2][3][4]<FOO>' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {%- set wrapper = "<FOO>" %} + {{- wrapper -}} + ''') + assert t.render() == '[1][2][3][4]<FOO>' + + t = env.from_string(''' + {%- for item in (1, 2, 3, 4) %} + {%- macro wrapper() %}[{{ item }}]{% endmacro %} + {{- wrapper() }} + {%- endfor %} + {{- wrapper -}} + ''') + assert t.render(wrapper=23) == '[1][2][3][4]23' + + +class BugTestCase(JinjaTestCase): + + def test_keyword_folding(self): + env = Environment() + env.filters['testing'] = lambda value, some: value + some + assert env.from_string("{{ 'test'|testing(some='stuff') }}") \ + .render() == 'teststuff' + + def test_extends_output_bugs(self): + env = Environment(loader=DictLoader({ + 'parent.html': '(({% block title %}{% endblock %}))' + })) + + t = env.from_string('{% if expr %}{% extends "parent.html" %}{% endif %}' + '[[{% block title %}title{% endblock %}]]' + '{% for item in [1, 2, 3] %}({{ item }}){% endfor %}') + assert t.render(expr=False) == '[[title]](1)(2)(3)' + assert t.render(expr=True) == '((title))' + + def test_urlize_filter_escaping(self): + tmpl = env.from_string('{{ "http://www.example.org/<foo"|urlize }}') + assert tmpl.render() == '<a href="http://www.example.org/<foo">http://www.example.org/<foo</a>' + + def test_loop_call_loop(self): + tmpl = env.from_string(''' + + {% macro test() %} + {{ caller() }} + {% endmacro %} + + {% for num1 in range(5) %} + {% call test() %} + {% for num2 in range(10) %} + {{ loop.index }} + {% endfor %} + {% endcall %} + {% endfor %} + + ''') + + assert tmpl.render().split() == [text_type(x) for x in range(1, 11)] * 5 + + def test_weird_inline_comment(self): + env = Environment(line_statement_prefix='%') + self.assert_raises(TemplateSyntaxError, env.from_string, + '% for item in seq {# missing #}\n...% endfor') + + def test_old_macro_loop_scoping_bug(self): + tmpl = env.from_string('{% for i in (1, 2) %}{{ i }}{% endfor %}' + '{% macro i() %}3{% endmacro %}{{ i() }}') + assert tmpl.render() == '123' + + def test_partial_conditional_assignments(self): + tmpl = env.from_string('{% if b %}{% set a = 42 %}{% endif %}{{ a }}') + assert tmpl.render(a=23) == '23' + assert tmpl.render(b=True) == '42' + + def test_stacked_locals_scoping_bug(self): + env = Environment(line_statement_prefix='#') + t = env.from_string('''\ +# for j in [1, 2]: +# set x = 1 +# for i in [1, 2]: +# print x +# if i % 2 == 0: +# set x = x + 1 +# endif +# endfor +# endfor +# if a +# print 'A' +# elif b +# print 'B' +# elif c == d +# print 'C' +# else +# print 'D' +# endif + ''') + assert t.render(a=0, b=False, c=42, d=42.0) == '1111C' + + def test_stacked_locals_scoping_bug_twoframe(self): + t = Template(''' + {% set x = 1 %} + {% for item in foo %} + {% if item == 1 %} + {% set x = 2 %} + {% endif %} + {% endfor %} + {{ x }} + ''') + rv = t.render(foo=[1]).strip() + assert rv == u'1' + + def test_call_with_args(self): + t = Template("""{% macro dump_users(users) -%} + <ul> + {%- for user in users -%} + <li><p>{{ user.username|e }}</p>{{ caller(user) }}</li> + {%- endfor -%} + </ul> + {%- endmacro -%} + + {% call(user) dump_users(list_of_user) -%} + <dl> + <dl>Realname</dl> + <dd>{{ user.realname|e }}</dd> + <dl>Description</dl> + <dd>{{ user.description }}</dd> + </dl> + {% endcall %}""") + + assert [x.strip() for x in t.render(list_of_user=[{ + 'username':'apo', + 'realname':'something else', + 'description':'test' + }]).splitlines()] == [ + u'<ul><li><p>apo</p><dl>', + u'<dl>Realname</dl>', + u'<dd>something else</dd>', + u'<dl>Description</dl>', + u'<dd>test</dd>', + u'</dl>', + u'</li></ul>' + ] + + def test_empty_if_condition_fails(self): + self.assert_raises(TemplateSyntaxError, Template, '{% if %}....{% endif %}') + self.assert_raises(TemplateSyntaxError, Template, '{% if foo %}...{% elif %}...{% endif %}') + self.assert_raises(TemplateSyntaxError, Template, '{% for x in %}..{% endfor %}') + + def test_recursive_loop_bug(self): + tpl1 = Template(""" + {% for p in foo recursive%} + {{p.bar}} + {% for f in p.fields recursive%} + {{f.baz}} + {{p.bar}} + {% if f.rec %} + {{ loop(f.sub) }} + {% endif %} + {% endfor %} + {% endfor %} + """) + + tpl2 = Template(""" + {% for p in foo%} + {{p.bar}} + {% for f in p.fields recursive%} + {{f.baz}} + {{p.bar}} + {% if f.rec %} + {{ loop(f.sub) }} + {% endif %} + {% endfor %} + {% endfor %} + """) + + def test_else_loop_bug(self): + t = Template(''' + {% for x in y %} + {{ loop.index0 }} + {% else %} + {% for i in range(3) %}{{ i }}{% endfor %} + {% endfor %} + ''') + self.assertEqual(t.render(y=[]).strip(), '012') + + def test_correct_prefix_loader_name(self): + env = Environment(loader=PrefixLoader({ + 'foo': DictLoader({}) + })) + try: + env.get_template('foo/bar.html') + except TemplateNotFound as e: + assert e.name == 'foo/bar.html' + else: + assert False, 'expected error here' + + def test_contextfunction_callable_classes(self): + from jinja2.utils import contextfunction + class CallableClass(object): + @contextfunction + def __call__(self, ctx): + return ctx.resolve('hello') + + tpl = Template("""{{ callableclass() }}""") + output = tpl.render(callableclass = CallableClass(), hello = 'TEST') + expected = 'TEST' + + self.assert_equal(output, expected) + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(CornerTestCase)) + suite.addTest(unittest.makeSuite(BugTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/res/__init__.py b/module/lib/jinja2/testsuite/res/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/module/lib/jinja2/testsuite/res/__init__.py diff --git a/module/lib/jinja2/testsuite/res/templates/broken.html b/module/lib/jinja2/testsuite/res/templates/broken.html new file mode 100644 index 000000000..77669fae5 --- /dev/null +++ b/module/lib/jinja2/testsuite/res/templates/broken.html @@ -0,0 +1,3 @@ +Before +{{ fail() }} +After diff --git a/module/lib/jinja2/testsuite/res/templates/foo/test.html b/module/lib/jinja2/testsuite/res/templates/foo/test.html new file mode 100644 index 000000000..b7d6715e2 --- /dev/null +++ b/module/lib/jinja2/testsuite/res/templates/foo/test.html @@ -0,0 +1 @@ +FOO diff --git a/module/lib/jinja2/testsuite/res/templates/syntaxerror.html b/module/lib/jinja2/testsuite/res/templates/syntaxerror.html new file mode 100644 index 000000000..f21b81793 --- /dev/null +++ b/module/lib/jinja2/testsuite/res/templates/syntaxerror.html @@ -0,0 +1,4 @@ +Foo +{% for item in broken %} + ... +{% endif %} diff --git a/module/lib/jinja2/testsuite/res/templates/test.html b/module/lib/jinja2/testsuite/res/templates/test.html new file mode 100644 index 000000000..ba578e48b --- /dev/null +++ b/module/lib/jinja2/testsuite/res/templates/test.html @@ -0,0 +1 @@ +BAR diff --git a/module/lib/jinja2/testsuite/security.py b/module/lib/jinja2/testsuite/security.py new file mode 100644 index 000000000..246d0f073 --- /dev/null +++ b/module/lib/jinja2/testsuite/security.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.security + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + Checks the sandbox and other security features. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest + +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Environment +from jinja2.sandbox import SandboxedEnvironment, \ + ImmutableSandboxedEnvironment, unsafe +from jinja2 import Markup, escape +from jinja2.exceptions import SecurityError, TemplateSyntaxError, \ + TemplateRuntimeError +from jinja2._compat import text_type + + +class PrivateStuff(object): + + def bar(self): + return 23 + + @unsafe + def foo(self): + return 42 + + def __repr__(self): + return 'PrivateStuff' + + +class PublicStuff(object): + bar = lambda self: 23 + _foo = lambda self: 42 + + def __repr__(self): + return 'PublicStuff' + + +class SandboxTestCase(JinjaTestCase): + + def test_unsafe(self): + env = SandboxedEnvironment() + self.assert_raises(SecurityError, env.from_string("{{ foo.foo() }}").render, + foo=PrivateStuff()) + self.assert_equal(env.from_string("{{ foo.bar() }}").render(foo=PrivateStuff()), '23') + + self.assert_raises(SecurityError, env.from_string("{{ foo._foo() }}").render, + foo=PublicStuff()) + self.assert_equal(env.from_string("{{ foo.bar() }}").render(foo=PublicStuff()), '23') + self.assert_equal(env.from_string("{{ foo.__class__ }}").render(foo=42), '') + self.assert_equal(env.from_string("{{ foo.func_code }}").render(foo=lambda:None), '') + # security error comes from __class__ already. + self.assert_raises(SecurityError, env.from_string( + "{{ foo.__class__.__subclasses__() }}").render, foo=42) + + def test_immutable_environment(self): + env = ImmutableSandboxedEnvironment() + self.assert_raises(SecurityError, env.from_string( + '{{ [].append(23) }}').render) + self.assert_raises(SecurityError, env.from_string( + '{{ {1:2}.clear() }}').render) + + def test_restricted(self): + env = SandboxedEnvironment() + self.assert_raises(TemplateSyntaxError, env.from_string, + "{% for item.attribute in seq %}...{% endfor %}") + self.assert_raises(TemplateSyntaxError, env.from_string, + "{% for foo, bar.baz in seq %}...{% endfor %}") + + def test_markup_operations(self): + # adding two strings should escape the unsafe one + unsafe = '<script type="application/x-some-script">alert("foo");</script>' + safe = Markup('<em>username</em>') + assert unsafe + safe == text_type(escape(unsafe)) + text_type(safe) + + # string interpolations are safe to use too + assert Markup('<em>%s</em>') % '<bad user>' == \ + '<em><bad user></em>' + assert Markup('<em>%(username)s</em>') % { + 'username': '<bad user>' + } == '<em><bad user></em>' + + # an escaped object is markup too + assert type(Markup('foo') + 'bar') is Markup + + # and it implements __html__ by returning itself + x = Markup("foo") + assert x.__html__() is x + + # it also knows how to treat __html__ objects + class Foo(object): + def __html__(self): + return '<em>awesome</em>' + def __unicode__(self): + return 'awesome' + assert Markup(Foo()) == '<em>awesome</em>' + assert Markup('<strong>%s</strong>') % Foo() == \ + '<strong><em>awesome</em></strong>' + + # escaping and unescaping + assert escape('"<>&\'') == '"<>&'' + assert Markup("<em>Foo & Bar</em>").striptags() == "Foo & Bar" + assert Markup("<test>").unescape() == "<test>" + + def test_template_data(self): + env = Environment(autoescape=True) + t = env.from_string('{% macro say_hello(name) %}' + '<p>Hello {{ name }}!</p>{% endmacro %}' + '{{ say_hello("<blink>foo</blink>") }}') + escaped_out = '<p>Hello <blink>foo</blink>!</p>' + assert t.render() == escaped_out + assert text_type(t.module) == escaped_out + assert escape(t.module) == escaped_out + assert t.module.say_hello('<blink>foo</blink>') == escaped_out + assert escape(t.module.say_hello('<blink>foo</blink>')) == escaped_out + + def test_attr_filter(self): + env = SandboxedEnvironment() + tmpl = env.from_string('{{ cls|attr("__subclasses__")() }}') + self.assert_raises(SecurityError, tmpl.render, cls=int) + + def test_binary_operator_intercepting(self): + def disable_op(left, right): + raise TemplateRuntimeError('that operator so does not work') + for expr, ctx, rv in ('1 + 2', {}, '3'), ('a + 2', {'a': 2}, '4'): + env = SandboxedEnvironment() + env.binop_table['+'] = disable_op + t = env.from_string('{{ %s }}' % expr) + assert t.render(ctx) == rv + env.intercepted_binops = frozenset(['+']) + t = env.from_string('{{ %s }}' % expr) + try: + t.render(ctx) + except TemplateRuntimeError as e: + pass + else: + self.fail('expected runtime error') + + def test_unary_operator_intercepting(self): + def disable_op(arg): + raise TemplateRuntimeError('that operator so does not work') + for expr, ctx, rv in ('-1', {}, '-1'), ('-a', {'a': 2}, '-2'): + env = SandboxedEnvironment() + env.unop_table['-'] = disable_op + t = env.from_string('{{ %s }}' % expr) + assert t.render(ctx) == rv + env.intercepted_unops = frozenset(['-']) + t = env.from_string('{{ %s }}' % expr) + try: + t.render(ctx) + except TemplateRuntimeError as e: + pass + else: + self.fail('expected runtime error') + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(SandboxTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/tests.py b/module/lib/jinja2/testsuite/tests.py new file mode 100644 index 000000000..3ece7a8ff --- /dev/null +++ b/module/lib/jinja2/testsuite/tests.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.tests + ~~~~~~~~~~~~~~~~~~~~~~ + + Who tests the tests? + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import unittest +from jinja2.testsuite import JinjaTestCase + +from jinja2 import Markup, Environment + +env = Environment() + + +class TestsTestCase(JinjaTestCase): + + def test_defined(self): + tmpl = env.from_string('{{ missing is defined }}|{{ true is defined }}') + assert tmpl.render() == 'False|True' + + def test_even(self): + tmpl = env.from_string('''{{ 1 is even }}|{{ 2 is even }}''') + assert tmpl.render() == 'False|True' + + def test_odd(self): + tmpl = env.from_string('''{{ 1 is odd }}|{{ 2 is odd }}''') + assert tmpl.render() == 'True|False' + + def test_lower(self): + tmpl = env.from_string('''{{ "foo" is lower }}|{{ "FOO" is lower }}''') + assert tmpl.render() == 'True|False' + + def test_typechecks(self): + tmpl = env.from_string(''' + {{ 42 is undefined }} + {{ 42 is defined }} + {{ 42 is none }} + {{ none is none }} + {{ 42 is number }} + {{ 42 is string }} + {{ "foo" is string }} + {{ "foo" is sequence }} + {{ [1] is sequence }} + {{ range is callable }} + {{ 42 is callable }} + {{ range(5) is iterable }} + {{ {} is mapping }} + {{ mydict is mapping }} + {{ [] is mapping }} + ''') + class MyDict(dict): + pass + assert tmpl.render(mydict=MyDict()).split() == [ + 'False', 'True', 'False', 'True', 'True', 'False', + 'True', 'True', 'True', 'True', 'False', 'True', + 'True', 'True', 'False' + ] + + def test_sequence(self): + tmpl = env.from_string( + '{{ [1, 2, 3] is sequence }}|' + '{{ "foo" is sequence }}|' + '{{ 42 is sequence }}' + ) + assert tmpl.render() == 'True|True|False' + + def test_upper(self): + tmpl = env.from_string('{{ "FOO" is upper }}|{{ "foo" is upper }}') + assert tmpl.render() == 'True|False' + + def test_sameas(self): + tmpl = env.from_string('{{ foo is sameas false }}|' + '{{ 0 is sameas false }}') + assert tmpl.render(foo=False) == 'True|False' + + def test_no_paren_for_arg1(self): + tmpl = env.from_string('{{ foo is sameas none }}') + assert tmpl.render(foo=None) == 'True' + + def test_escaped(self): + env = Environment(autoescape=True) + tmpl = env.from_string('{{ x is escaped }}|{{ y is escaped }}') + assert tmpl.render(x='foo', y=Markup('foo')) == 'False|True' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestsTestCase)) + return suite diff --git a/module/lib/jinja2/testsuite/utils.py b/module/lib/jinja2/testsuite/utils.py new file mode 100644 index 000000000..cab9b09a9 --- /dev/null +++ b/module/lib/jinja2/testsuite/utils.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +""" + jinja2.testsuite.utils + ~~~~~~~~~~~~~~~~~~~~~~ + + Tests utilities jinja uses. + + :copyright: (c) 2010 by the Jinja Team. + :license: BSD, see LICENSE for more details. +""" +import gc +import unittest + +import pickle + +from jinja2.testsuite import JinjaTestCase + +from jinja2.utils import LRUCache, escape, object_type_repr + + +class LRUCacheTestCase(JinjaTestCase): + + def test_simple(self): + d = LRUCache(3) + d["a"] = 1 + d["b"] = 2 + d["c"] = 3 + d["a"] + d["d"] = 4 + assert len(d) == 3 + assert 'a' in d and 'c' in d and 'd' in d and 'b' not in d + + def test_pickleable(self): + cache = LRUCache(2) + cache["foo"] = 42 + cache["bar"] = 23 + cache["foo"] + + for protocol in range(3): + copy = pickle.loads(pickle.dumps(cache, protocol)) + assert copy.capacity == cache.capacity + assert copy._mapping == cache._mapping + assert copy._queue == cache._queue + + +class HelpersTestCase(JinjaTestCase): + + def test_object_type_repr(self): + class X(object): + pass + self.assert_equal(object_type_repr(42), 'int object') + self.assert_equal(object_type_repr([]), 'list object') + self.assert_equal(object_type_repr(X()), + 'jinja2.testsuite.utils.X object') + self.assert_equal(object_type_repr(None), 'None') + self.assert_equal(object_type_repr(Ellipsis), 'Ellipsis') + + +class MarkupLeakTestCase(JinjaTestCase): + + def test_markup_leaks(self): + counts = set() + for count in range(20): + for item in range(1000): + escape("foo") + escape("<foo>") + escape(u"foo") + escape(u"<foo>") + counts.add(len(gc.get_objects())) + assert len(counts) == 1, 'ouch, c extension seems to leak objects' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(LRUCacheTestCase)) + suite.addTest(unittest.makeSuite(HelpersTestCase)) + + # this test only tests the c extension + if not hasattr(escape, 'func_code'): + suite.addTest(unittest.makeSuite(MarkupLeakTestCase)) + + return suite diff --git a/module/lib/jinja2/utils.py b/module/lib/jinja2/utils.py index 7b77b8eb7..ddc47da0a 100644 --- a/module/lib/jinja2/utils.py +++ b/module/lib/jinja2/utils.py @@ -9,21 +9,17 @@ :license: BSD, see LICENSE for more details. """ import re -import sys import errno -try: - from thread import allocate_lock -except ImportError: - from dummy_thread import allocate_lock from collections import deque -from itertools import imap +from jinja2._compat import text_type, string_types, implements_iterator, \ + allocate_lock, url_quote _word_split_re = re.compile(r'(\s+)') _punctuation_re = re.compile( '^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % ( - '|'.join(imap(re.escape, ('(', '<', '<'))), - '|'.join(imap(re.escape, ('.', ',', ')', '>', '\n', '>'))) + '|'.join(map(re.escape, ('(', '<', '<'))), + '|'.join(map(re.escape, ('.', ',', ')', '>', '\n', '>'))) ) ) _simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') @@ -38,77 +34,7 @@ missing = type('MissingType', (), {'__repr__': lambda x: 'missing'})() # internal code internal_code = set() - -# concatenate a list of strings and convert them to unicode. -# unfortunately there is a bug in python 2.4 and lower that causes -# unicode.join trash the traceback. -_concat = u''.join -try: - def _test_gen_bug(): - raise TypeError(_test_gen_bug) - yield None - _concat(_test_gen_bug()) -except TypeError, _error: - if not _error.args or _error.args[0] is not _test_gen_bug: - def concat(gen): - try: - return _concat(list(gen)) - except: - # this hack is needed so that the current frame - # does not show up in the traceback. - exc_type, exc_value, tb = sys.exc_info() - raise exc_type, exc_value, tb.tb_next - else: - concat = _concat - del _test_gen_bug, _error - - -# for python 2.x we create outselves a next() function that does the -# basics without exception catching. -try: - next = next -except NameError: - def next(x): - return x.next() - - -# if this python version is unable to deal with unicode filenames -# when passed to encode we let this function encode it properly. -# This is used in a couple of places. As far as Jinja is concerned -# filenames are unicode *or* bytestrings in 2.x and unicode only in -# 3.x because compile cannot handle bytes -if sys.version_info < (3, 0): - def _encode_filename(filename): - if isinstance(filename, unicode): - return filename.encode('utf-8') - return filename -else: - def _encode_filename(filename): - assert filename is None or isinstance(filename, str), \ - 'filenames must be strings' - return filename - -from keyword import iskeyword as is_python_keyword - - -# common types. These do exist in the special types module too which however -# does not exist in IronPython out of the box. Also that way we don't have -# to deal with implementation specific stuff here -class _C(object): - def method(self): pass -def _func(): - yield None -FunctionType = type(_func) -GeneratorType = type(_func()) -MethodType = type(_C.method) -CodeType = type(_C.method.func_code) -try: - raise TypeError() -except TypeError: - _tb = sys.exc_info()[2] - TracebackType = type(_tb) - FrameType = type(_tb.tb_frame) -del _C, _tb, _func +concat = u''.join def contextfunction(f): @@ -128,7 +54,7 @@ def contextfunction(f): def evalcontextfunction(f): - """This decoraotr can be used to mark a function or method as an eval + """This decorator can be used to mark a function or method as an eval context callable. This is similar to the :func:`contextfunction` but instead of passing the context, an evaluation context object is passed. For more information about the eval context, see @@ -152,7 +78,7 @@ def environmentfunction(f): def internalcode(f): """Marks the function as internally used""" - internal_code.add(f.func_code) + internal_code.add(f.__code__) return f @@ -191,7 +117,7 @@ def clear_caches(): def import_string(import_name, silent=False): - """Imports an object based on a string. This use useful if you want to + """Imports an object based on a string. This is useful if you want to use import paths as endpoints or something similar. An import path can be specified either in dotted notation (``xml.sax.saxutils.escape``) or with a colon as object delimiter (``xml.sax.saxutils:escape``). @@ -222,7 +148,7 @@ def open_if_exists(filename, mode='rb'): """ try: return open(filename, mode) - except IOError, e: + except IOError as e: if e.errno not in (errno.ENOENT, errno.EISDIR): raise @@ -271,7 +197,7 @@ def urlize(text, trim_url_limit=None, nofollow=False): trim_url = lambda x, limit=trim_url_limit: limit is not None \ and (x[:limit] + (len(x) >=limit and '...' or '')) or x - words = _word_split_re.split(unicode(escape(text))) + words = _word_split_re.split(text_type(escape(text))) nofollow_attr = nofollow and ' rel="nofollow"' or '' for i, word in enumerate(words): match = _punctuation_re.match(word) @@ -280,6 +206,7 @@ def urlize(text, trim_url_limit=None, nofollow=False): if middle.startswith('www.') or ( '@' not in middle and not middle.startswith('http://') and + not middle.startswith('https://') and len(middle) > 0 and middle[0] in _letters + _digits and ( middle.endswith('.org') or @@ -307,7 +234,7 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100): words = LOREM_IPSUM_WORDS.split() result = [] - for _ in xrange(n): + for _ in range(n): next_capitalized = True last_comma = last_fullstop = 0 word = None @@ -315,7 +242,7 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100): p = [] # each paragraph contains out of 20 to 100 words. - for idx, _ in enumerate(xrange(randrange(min, max))): + for idx, _ in enumerate(range(randrange(min, max))): while True: word = choice(words) if word != last: @@ -349,6 +276,21 @@ def generate_lorem_ipsum(n=5, html=True, min=20, max=100): return Markup(u'\n'.join(u'<p>%s</p>' % escape(x) for x in result)) +def unicode_urlencode(obj, charset='utf-8'): + """URL escapes a single bytestring or unicode string with the + given charset if applicable to URL safe quoting under all rules + that need to be considered under all supported Python versions. + + If non strings are provided they are converted to their unicode + representation first. + """ + if not isinstance(obj, string_types): + obj = text_type(obj) + if isinstance(obj, text_type): + obj = obj.encode(charset) + return text_type(url_quote(obj)) + + class LRUCache(object): """A simple LRU Cache implementation.""" @@ -366,18 +308,10 @@ class LRUCache(object): # alias all queue methods for faster lookup self._popleft = self._queue.popleft self._pop = self._queue.pop - if hasattr(self._queue, 'remove'): - self._remove = self._queue.remove + self._remove = self._queue.remove self._wlock = allocate_lock() self._append = self._queue.append - def _remove(self, obj): - """Python 2.4 compatibility.""" - for idx, item in enumerate(self._queue): - if item == obj: - del self._queue[idx] - break - def __getstate__(self): return { 'capacity': self.capacity, @@ -393,7 +327,7 @@ class LRUCache(object): return (self.capacity,) def copy(self): - """Return an shallow copy of the instance.""" + """Return a shallow copy of the instance.""" rv = self.__class__(self.capacity) rv._mapping.update(self._mapping) rv._queue = deque(self._queue) @@ -410,11 +344,15 @@ class LRUCache(object): """Set `default` if the key is not in the cache otherwise leave unchanged. Return the value of this key. """ + self._wlock.acquire() try: - return self[key] - except KeyError: - self[key] = default - return default + try: + return self[key] + except KeyError: + self[key] = default + return default + finally: + self._wlock.release() def clear(self): """Clear the cache.""" @@ -443,19 +381,23 @@ class LRUCache(object): """Get an item from the cache. Moves the item up so that it has the highest priority then. - Raise an `KeyError` if it does not exist. + Raise a `KeyError` if it does not exist. """ - rv = self._mapping[key] - if self._queue[-1] != key: - try: - self._remove(key) - except ValueError: - # if something removed the key from the container - # when we read, ignore the ValueError that we would - # get otherwise. - pass - self._append(key) - return rv + self._wlock.acquire() + try: + rv = self._mapping[key] + if self._queue[-1] != key: + try: + self._remove(key) + except ValueError: + # if something removed the key from the container + # when we read, ignore the ValueError that we would + # get otherwise. + pass + self._append(key) + return rv + finally: + self._wlock.release() def __setitem__(self, key, value): """Sets the value for an item. Moves the item up so that it @@ -464,11 +406,7 @@ class LRUCache(object): self._wlock.acquire() try: if key in self._mapping: - try: - self._remove(key) - except ValueError: - # __getitem__ is not locked, it might happen - pass + self._remove(key) elif len(self._mapping) == self.capacity: del self._mapping[self._popleft()] self._append(key) @@ -478,7 +416,7 @@ class LRUCache(object): def __delitem__(self, key): """Remove an item from the cache dict. - Raise an `KeyError` if it does not exist. + Raise a `KeyError` if it does not exist. """ self._wlock.acquire() try: @@ -538,6 +476,7 @@ except ImportError: pass +@implements_iterator class Cycler(object): """A cycle helper for templates.""" @@ -556,7 +495,7 @@ class Cycler(object): """Returns the current item.""" return self.items[self.pos] - def next(self): + def __next__(self): """Goes one item ahead and returns it.""" rv = self.current self.pos = (self.pos + 1) % len(self.items) @@ -577,25 +516,5 @@ class Joiner(object): return self.sep -# try markupsafe first, if that fails go with Jinja2's bundled version -# of markupsafe. Markupsafe was previously Jinja2's implementation of -# the Markup object but was moved into a separate package in a patchleve -# release -try: - from markupsafe import Markup, escape, soft_unicode -except ImportError: - from jinja2._markupsafe import Markup, escape, soft_unicode - - -# partials -try: - from functools import partial -except ImportError: - class partial(object): - def __init__(self, _func, *args, **kwargs): - self._func = _func - self._args = args - self._kwargs = kwargs - def __call__(self, *args, **kwargs): - kwargs.update(self._kwargs) - return self._func(*(self._args + args), **kwargs) +# Imported here because that's where it was in the past +from markupsafe import Markup, escape, soft_unicode diff --git a/module/lib/jinja2/_markupsafe/__init__.py b/module/lib/markupsafe/__init__.py index ec7bd572d..275540154 100644 --- a/module/lib/jinja2/_markupsafe/__init__.py +++ b/module/lib/markupsafe/__init__.py @@ -9,7 +9,10 @@ :license: BSD, see LICENSE for more details. """ import re -from itertools import imap +import string +from collections import Mapping +from markupsafe._compat import text_type, string_types, int_types, \ + unichr, iteritems, PY2 __all__ = ['Markup', 'soft_unicode', 'escape', 'escape_silent'] @@ -19,7 +22,7 @@ _striptags_re = re.compile(r'(<!--.*?-->|<[^>]*>)') _entity_re = re.compile(r'&([^;]+);') -class Markup(unicode): +class Markup(text_type): r"""Marks a string as being safe for inclusion in HTML/XML output without needing to be escaped. This implements the `__html__` interface a couple of frameworks and web applications use. :class:`Markup` is a direct @@ -40,7 +43,7 @@ class Markup(unicode): >>> class Foo(object): ... def __html__(self): ... return '<a href="#">foo</a>' - ... + ... >>> Markup(Foo()) Markup(u'<a href="#">foo</a>') @@ -68,65 +71,66 @@ class Markup(unicode): if hasattr(base, '__html__'): base = base.__html__() if encoding is None: - return unicode.__new__(cls, base) - return unicode.__new__(cls, base, encoding, errors) + return text_type.__new__(cls, base) + return text_type.__new__(cls, base, encoding, errors) def __html__(self): return self def __add__(self, other): - if hasattr(other, '__html__') or isinstance(other, basestring): - return self.__class__(unicode(self) + unicode(escape(other))) + if isinstance(other, string_types) or hasattr(other, '__html__'): + return self.__class__(super(Markup, self).__add__(self.escape(other))) return NotImplemented def __radd__(self, other): - if hasattr(other, '__html__') or isinstance(other, basestring): - return self.__class__(unicode(escape(other)) + unicode(self)) + if hasattr(other, '__html__') or isinstance(other, string_types): + return self.escape(other).__add__(self) return NotImplemented def __mul__(self, num): - if isinstance(num, (int, long)): - return self.__class__(unicode.__mul__(self, num)) + if isinstance(num, int_types): + return self.__class__(text_type.__mul__(self, num)) return NotImplemented __rmul__ = __mul__ def __mod__(self, arg): if isinstance(arg, tuple): - arg = tuple(imap(_MarkupEscapeHelper, arg)) + arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg) else: - arg = _MarkupEscapeHelper(arg) - return self.__class__(unicode.__mod__(self, arg)) + arg = _MarkupEscapeHelper(arg, self.escape) + return self.__class__(text_type.__mod__(self, arg)) def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, - unicode.__repr__(self) + text_type.__repr__(self) ) def join(self, seq): - return self.__class__(unicode.join(self, imap(escape, seq))) - join.__doc__ = unicode.join.__doc__ + return self.__class__(text_type.join(self, map(self.escape, seq))) + join.__doc__ = text_type.join.__doc__ def split(self, *args, **kwargs): - return map(self.__class__, unicode.split(self, *args, **kwargs)) - split.__doc__ = unicode.split.__doc__ + return list(map(self.__class__, text_type.split(self, *args, **kwargs))) + split.__doc__ = text_type.split.__doc__ def rsplit(self, *args, **kwargs): - return map(self.__class__, unicode.rsplit(self, *args, **kwargs)) - rsplit.__doc__ = unicode.rsplit.__doc__ + return list(map(self.__class__, text_type.rsplit(self, *args, **kwargs))) + rsplit.__doc__ = text_type.rsplit.__doc__ def splitlines(self, *args, **kwargs): - return map(self.__class__, unicode.splitlines(self, *args, **kwargs)) - splitlines.__doc__ = unicode.splitlines.__doc__ + return list(map(self.__class__, text_type.splitlines( + self, *args, **kwargs))) + splitlines.__doc__ = text_type.splitlines.__doc__ def unescape(self): - r"""Unescape markup again into an unicode string. This also resolves + r"""Unescape markup again into an text_type string. This also resolves known HTML4 and XHTML entities: >>> Markup("Main » <em>About</em>").unescape() u'Main \xbb <em>About</em>' """ - from jinja2._markupsafe._constants import HTML_ENTITIES + from markupsafe._constants import HTML_ENTITIES def handle_match(m): name = m.group(1) if name in HTML_ENTITIES: @@ -139,10 +143,10 @@ class Markup(unicode): except ValueError: pass return u'' - return _entity_re.sub(handle_match, unicode(self)) + return _entity_re.sub(handle_match, text_type(self)) def striptags(self): - r"""Unescape markup into an unicode string and strip all tags. This + r"""Unescape markup into an text_type string and strip all tags. This also resolves known HTML4 and XHTML entities. Whitespace is normalized to one: @@ -163,11 +167,11 @@ class Markup(unicode): return cls(rv) return rv - def make_wrapper(name): - orig = getattr(unicode, name) + def make_simple_escaping_wrapper(name): + orig = getattr(text_type, name) def func(self, *args, **kwargs): - args = _escape_argspec(list(args), enumerate(args)) - _escape_argspec(kwargs, kwargs.iteritems()) + args = _escape_argspec(list(args), enumerate(args), self.escape) + _escape_argspec(kwargs, iteritems(kwargs), self.escape) return self.__class__(orig(self, *args, **kwargs)) func.__name__ = orig.__name__ func.__doc__ = orig.__doc__ @@ -177,28 +181,93 @@ class Markup(unicode): 'title', 'lower', 'upper', 'replace', 'ljust', \ 'rjust', 'lstrip', 'rstrip', 'center', 'strip', \ 'translate', 'expandtabs', 'swapcase', 'zfill': - locals()[method] = make_wrapper(method) + locals()[method] = make_simple_escaping_wrapper(method) # new in python 2.5 - if hasattr(unicode, 'partition'): - partition = make_wrapper('partition'), - rpartition = make_wrapper('rpartition') + if hasattr(text_type, 'partition'): + def partition(self, sep): + return tuple(map(self.__class__, + text_type.partition(self, self.escape(sep)))) + def rpartition(self, sep): + return tuple(map(self.__class__, + text_type.rpartition(self, self.escape(sep)))) # new in python 2.6 - if hasattr(unicode, 'format'): - format = make_wrapper('format') + if hasattr(text_type, 'format'): + def format(*args, **kwargs): + self, args = args[0], args[1:] + formatter = EscapeFormatter(self.escape) + kwargs = _MagicFormatMapping(args, kwargs) + return self.__class__(formatter.vformat(self, args, kwargs)) + + def __html_format__(self, format_spec): + if format_spec: + raise ValueError('Unsupported format specification ' + 'for Markup.') + return self # not in python 3 - if hasattr(unicode, '__getslice__'): - __getslice__ = make_wrapper('__getslice__') + if hasattr(text_type, '__getslice__'): + __getslice__ = make_simple_escaping_wrapper('__getslice__') - del method, make_wrapper + del method, make_simple_escaping_wrapper -def _escape_argspec(obj, iterable): +class _MagicFormatMapping(Mapping): + """This class implements a dummy wrapper to fix a bug in the Python + standard library for string formatting. + + See http://bugs.python.org/issue13598 for information about why + this is necessary. + """ + + def __init__(self, args, kwargs): + self._args = args + self._kwargs = kwargs + self._last_index = 0 + + def __getitem__(self, key): + if key == '': + idx = self._last_index + self._last_index += 1 + try: + return self._args[idx] + except LookupError: + pass + key = str(idx) + return self._kwargs[key] + + def __iter__(self): + return iter(self._kwargs) + + def __len__(self): + return len(self._kwargs) + + +if hasattr(text_type, 'format'): + class EscapeFormatter(string.Formatter): + + def __init__(self, escape): + self.escape = escape + + def format_field(self, value, format_spec): + if hasattr(value, '__html_format__'): + rv = value.__html_format__(format_spec) + elif hasattr(value, '__html__'): + if format_spec: + raise ValueError('No format specification allowed ' + 'when formatting an object with ' + 'its __html__ method.') + rv = value.__html__() + else: + rv = string.Formatter.format_field(self, value, format_spec) + return text_type(self.escape(rv)) + + +def _escape_argspec(obj, iterable, escape): """Helper for various string-wrapped functions.""" for key, value in iterable: - if hasattr(value, '__html__') or isinstance(value, basestring): + if hasattr(value, '__html__') or isinstance(value, string_types): obj[key] = escape(value) return obj @@ -206,13 +275,13 @@ def _escape_argspec(obj, iterable): class _MarkupEscapeHelper(object): """Helper for Markup.__mod__""" - def __init__(self, obj): + def __init__(self, obj, escape): self.obj = obj + self.escape = escape - __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x]) - __str__ = lambda s: str(escape(s.obj)) - __unicode__ = lambda s: unicode(escape(s.obj)) - __repr__ = lambda s: str(escape(repr(s.obj))) + __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x], s.escape) + __unicode__ = __str__ = lambda s: text_type(s.escape(s.obj)) + __repr__ = lambda s: str(s.escape(repr(s.obj))) __int__ = lambda s: int(s.obj) __float__ = lambda s: float(s.obj) @@ -220,6 +289,10 @@ class _MarkupEscapeHelper(object): # we have to import it down here as the speedups and native # modules imports the markup type which is define above. try: - from jinja2._markupsafe._speedups import escape, escape_silent, soft_unicode + from markupsafe._speedups import escape, escape_silent, soft_unicode except ImportError: - from jinja2._markupsafe._native import escape, escape_silent, soft_unicode + from markupsafe._native import escape, escape_silent, soft_unicode + +if not PY2: + soft_str = soft_unicode + __all__.append('soft_str') diff --git a/module/lib/markupsafe/_compat.py b/module/lib/markupsafe/_compat.py new file mode 100644 index 000000000..62e5632ad --- /dev/null +++ b/module/lib/markupsafe/_compat.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +""" + markupsafe._compat + ~~~~~~~~~~~~~~~~~~ + + Compatibility module for different Python versions. + + :copyright: (c) 2013 by Armin Ronacher. + :license: BSD, see LICENSE for more details. +""" +import sys + +PY2 = sys.version_info[0] == 2 + +if not PY2: + text_type = str + string_types = (str,) + unichr = chr + int_types = (int,) + iteritems = lambda x: iter(x.items()) +else: + text_type = unicode + string_types = (str, unicode) + unichr = unichr + int_types = (int, long) + iteritems = lambda x: x.iteritems() diff --git a/module/lib/jinja2/_markupsafe/_constants.py b/module/lib/markupsafe/_constants.py index 919bf03c5..919bf03c5 100644 --- a/module/lib/jinja2/_markupsafe/_constants.py +++ b/module/lib/markupsafe/_constants.py diff --git a/module/lib/jinja2/_markupsafe/_native.py b/module/lib/markupsafe/_native.py index 7b95828ec..5e83f10a1 100644 --- a/module/lib/jinja2/_markupsafe/_native.py +++ b/module/lib/markupsafe/_native.py @@ -8,7 +8,8 @@ :copyright: (c) 2010 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ -from jinja2._markupsafe import Markup +from markupsafe import Markup +from markupsafe._compat import text_type def escape(s): @@ -18,7 +19,7 @@ def escape(s): """ if hasattr(s, '__html__'): return s.__html__() - return Markup(unicode(s) + return Markup(text_type(s) .replace('&', '&') .replace('>', '>') .replace('<', '<') @@ -40,6 +41,6 @@ def soft_unicode(s): """Make a string unicode if it isn't already. That way a markup string is not converted back to unicode. """ - if not isinstance(s, unicode): - s = unicode(s) + if not isinstance(s, text_type): + s = text_type(s) return s diff --git a/module/lib/markupsafe/_speedups.c b/module/lib/markupsafe/_speedups.c new file mode 100644 index 000000000..f349febf2 --- /dev/null +++ b/module/lib/markupsafe/_speedups.c @@ -0,0 +1,239 @@ +/** + * markupsafe._speedups + * ~~~~~~~~~~~~~~~~~~~~ + * + * This module implements functions for automatic escaping in C for better + * performance. + * + * :copyright: (c) 2010 by Armin Ronacher. + * :license: BSD. + */ + +#include <Python.h> + +#define ESCAPED_CHARS_TABLE_SIZE 63 +#define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL))); + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + + +static PyObject* markup; +static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE]; +static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE]; + +static int +init_constants(void) +{ + PyObject *module; + /* happing of characters to replace */ + escaped_chars_repl['"'] = UNICHR("""); + escaped_chars_repl['\''] = UNICHR("'"); + escaped_chars_repl['&'] = UNICHR("&"); + escaped_chars_repl['<'] = UNICHR("<"); + escaped_chars_repl['>'] = UNICHR(">"); + + /* lengths of those characters when replaced - 1 */ + memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len)); + escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \ + escaped_chars_delta_len['&'] = 4; + escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3; + + /* import markup type so that we can mark the return value */ + module = PyImport_ImportModule("markupsafe"); + if (!module) + return 0; + markup = PyObject_GetAttrString(module, "Markup"); + Py_DECREF(module); + + return 1; +} + +static PyObject* +escape_unicode(PyUnicodeObject *in) +{ + PyUnicodeObject *out; + Py_UNICODE *inp = PyUnicode_AS_UNICODE(in); + const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in); + Py_UNICODE *next_escp; + Py_UNICODE *outp; + Py_ssize_t delta=0, erepl=0, delta_len=0; + + /* First we need to figure out how long the escaped string will be */ + while (*(inp) || inp < inp_end) { + if (*inp < ESCAPED_CHARS_TABLE_SIZE) { + delta += escaped_chars_delta_len[*inp]; + erepl += !!escaped_chars_delta_len[*inp]; + } + ++inp; + } + + /* Do we need to escape anything at all? */ + if (!erepl) { + Py_INCREF(in); + return (PyObject*)in; + } + + out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta); + if (!out) + return NULL; + + outp = PyUnicode_AS_UNICODE(out); + inp = PyUnicode_AS_UNICODE(in); + while (erepl-- > 0) { + /* look for the next substitution */ + next_escp = inp; + while (next_escp < inp_end) { + if (*next_escp < ESCAPED_CHARS_TABLE_SIZE && + (delta_len = escaped_chars_delta_len[*next_escp])) { + ++delta_len; + break; + } + ++next_escp; + } + + if (next_escp > inp) { + /* copy unescaped chars between inp and next_escp */ + Py_UNICODE_COPY(outp, inp, next_escp-inp); + outp += next_escp - inp; + } + + /* escape 'next_escp' */ + Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len); + outp += delta_len; + + inp = next_escp + 1; + } + if (inp < inp_end) + Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in))); + + return (PyObject*)out; +} + + +static PyObject* +escape(PyObject *self, PyObject *text) +{ + PyObject *s = NULL, *rv = NULL, *html; + + /* we don't have to escape integers, bools or floats */ + if (PyLong_CheckExact(text) || +#if PY_MAJOR_VERSION < 3 + PyInt_CheckExact(text) || +#endif + PyFloat_CheckExact(text) || PyBool_Check(text) || + text == Py_None) + return PyObject_CallFunctionObjArgs(markup, text, NULL); + + /* if the object has an __html__ method that performs the escaping */ + html = PyObject_GetAttrString(text, "__html__"); + if (html) { + rv = PyObject_CallObject(html, NULL); + Py_DECREF(html); + return rv; + } + + /* otherwise make the object unicode if it isn't, then escape */ + PyErr_Clear(); + if (!PyUnicode_Check(text)) { +#if PY_MAJOR_VERSION < 3 + PyObject *unicode = PyObject_Unicode(text); +#else + PyObject *unicode = PyObject_Str(text); +#endif + if (!unicode) + return NULL; + s = escape_unicode((PyUnicodeObject*)unicode); + Py_DECREF(unicode); + } + else + s = escape_unicode((PyUnicodeObject*)text); + + /* convert the unicode string into a markup object. */ + rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL); + Py_DECREF(s); + return rv; +} + + +static PyObject* +escape_silent(PyObject *self, PyObject *text) +{ + if (text != Py_None) + return escape(self, text); + return PyObject_CallFunctionObjArgs(markup, NULL); +} + + +static PyObject* +soft_unicode(PyObject *self, PyObject *s) +{ + if (!PyUnicode_Check(s)) +#if PY_MAJOR_VERSION < 3 + return PyObject_Unicode(s); +#else + return PyObject_Str(s); +#endif + Py_INCREF(s); + return s; +} + + +static PyMethodDef module_methods[] = { + {"escape", (PyCFunction)escape, METH_O, + "escape(s) -> markup\n\n" + "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n" + "sequences. Use this if you need to display text that might contain\n" + "such characters in HTML. Marks return value as markup string."}, + {"escape_silent", (PyCFunction)escape_silent, METH_O, + "escape_silent(s) -> markup\n\n" + "Like escape but converts None to an empty string."}, + {"soft_unicode", (PyCFunction)soft_unicode, METH_O, + "soft_unicode(object) -> string\n\n" + "Make a string unicode if it isn't already. That way a markup\n" + "string is not converted back to unicode."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + + +#if PY_MAJOR_VERSION < 3 + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif +PyMODINIT_FUNC +init_speedups(void) +{ + if (!init_constants()) + return; + + Py_InitModule3("markupsafe._speedups", module_methods, ""); +} + +#else /* Python 3.x module initialization */ + +static struct PyModuleDef module_definition = { + PyModuleDef_HEAD_INIT, + "markupsafe._speedups", + NULL, + -1, + module_methods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC +PyInit__speedups(void) +{ + if (!init_constants()) + return NULL; + + return PyModule_Create(&module_definition); +} + +#endif diff --git a/module/lib/markupsafe/tests.py b/module/lib/markupsafe/tests.py new file mode 100644 index 000000000..636993629 --- /dev/null +++ b/module/lib/markupsafe/tests.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +import gc +import sys +import unittest +from markupsafe import Markup, escape, escape_silent +from markupsafe._compat import text_type + + +class MarkupTestCase(unittest.TestCase): + + def test_adding(self): + # adding two strings should escape the unsafe one + unsafe = '<script type="application/x-some-script">alert("foo");</script>' + safe = Markup('<em>username</em>') + assert unsafe + safe == text_type(escape(unsafe)) + text_type(safe) + + def test_string_interpolation(self): + # string interpolations are safe to use too + assert Markup('<em>%s</em>') % '<bad user>' == \ + '<em><bad user></em>' + assert Markup('<em>%(username)s</em>') % { + 'username': '<bad user>' + } == '<em><bad user></em>' + + assert Markup('%i') % 3.14 == '3' + assert Markup('%.2f') % 3.14 == '3.14' + + def test_type_behavior(self): + # an escaped object is markup too + assert type(Markup('foo') + 'bar') is Markup + + # and it implements __html__ by returning itself + x = Markup("foo") + assert x.__html__() is x + + def test_html_interop(self): + # it also knows how to treat __html__ objects + class Foo(object): + def __html__(self): + return '<em>awesome</em>' + def __unicode__(self): + return 'awesome' + __str__ = __unicode__ + assert Markup(Foo()) == '<em>awesome</em>' + assert Markup('<strong>%s</strong>') % Foo() == \ + '<strong><em>awesome</em></strong>' + + def test_tuple_interpol(self): + self.assertEqual(Markup('<em>%s:%s</em>') % ( + '<foo>', + '<bar>', + ), Markup(u'<em><foo>:<bar></em>')) + + def test_dict_interpol(self): + self.assertEqual(Markup('<em>%(foo)s</em>') % { + 'foo': '<foo>', + }, Markup(u'<em><foo></em>')) + self.assertEqual(Markup('<em>%(foo)s:%(bar)s</em>') % { + 'foo': '<foo>', + 'bar': '<bar>', + }, Markup(u'<em><foo>:<bar></em>')) + + def test_escaping(self): + # escaping and unescaping + assert escape('"<>&\'') == '"<>&'' + assert Markup("<em>Foo & Bar</em>").striptags() == "Foo & Bar" + assert Markup("<test>").unescape() == "<test>" + + def test_formatting(self): + for actual, expected in ( + (Markup('%i') % 3.14, '3'), + (Markup('%.2f') % 3.14159, '3.14'), + (Markup('%s %s %s') % ('<', 123, '>'), '< 123 >'), + (Markup('<em>{awesome}</em>').format(awesome='<awesome>'), + '<em><awesome></em>'), + (Markup('{0[1][bar]}').format([0, {'bar': '<bar/>'}]), + '<bar/>'), + (Markup('{0[1][bar]}').format([0, {'bar': Markup('<bar/>')}]), + '<bar/>')): + assert actual == expected, "%r should be %r!" % (actual, expected) + + # This is new in 2.7 + if sys.version_info >= (2, 7): + def test_formatting_empty(self): + formatted = Markup('{}').format(0) + assert formatted == Markup('0') + + def test_custom_formatting(self): + class HasHTMLOnly(object): + def __html__(self): + return Markup('<foo>') + + class HasHTMLAndFormat(object): + def __html__(self): + return Markup('<foo>') + def __html_format__(self, spec): + return Markup('<FORMAT>') + + assert Markup('{0}').format(HasHTMLOnly()) == Markup('<foo>') + assert Markup('{0}').format(HasHTMLAndFormat()) == Markup('<FORMAT>') + + def test_complex_custom_formatting(self): + class User(object): + def __init__(self, id, username): + self.id = id + self.username = username + def __html_format__(self, format_spec): + if format_spec == 'link': + return Markup('<a href="/user/{0}">{1}</a>').format( + self.id, + self.__html__(), + ) + elif format_spec: + raise ValueError('Invalid format spec') + return self.__html__() + def __html__(self): + return Markup('<span class=user>{0}</span>').format(self.username) + + user = User(1, 'foo') + assert Markup('<p>User: {0:link}').format(user) == \ + Markup('<p>User: <a href="/user/1"><span class=user>foo</span></a>') + + def test_all_set(self): + import markupsafe as markup + for item in markup.__all__: + getattr(markup, item) + + def test_escape_silent(self): + assert escape_silent(None) == Markup() + assert escape(None) == Markup(None) + assert escape_silent('<foo>') == Markup(u'<foo>') + + def test_splitting(self): + self.assertEqual(Markup('a b').split(), [ + Markup('a'), + Markup('b') + ]) + self.assertEqual(Markup('a b').rsplit(), [ + Markup('a'), + Markup('b') + ]) + self.assertEqual(Markup('a\nb').splitlines(), [ + Markup('a'), + Markup('b') + ]) + + def test_mul(self): + self.assertEqual(Markup('a') * 3, Markup('aaa')) + + +class MarkupLeakTestCase(unittest.TestCase): + + def test_markup_leaks(self): + counts = set() + for count in range(20): + for item in range(1000): + escape("foo") + escape("<foo>") + escape(u"foo") + escape(u"<foo>") + counts.add(len(gc.get_objects())) + assert len(counts) == 1, 'ouch, c extension seems to leak objects' + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(MarkupTestCase)) + + # this test only tests the c extension + if not hasattr(escape, 'func_code'): + suite.addTest(unittest.makeSuite(MarkupLeakTestCase)) + + return suite + + +if __name__ == '__main__': + unittest.main(defaultTest='suite') + +# vim:sts=4:sw=4:et: diff --git a/module/lib/simplejson/__init__.py b/module/lib/simplejson/__init__.py index ef5c0db48..a5c01379a 100644 --- a/module/lib/simplejson/__init__.py +++ b/module/lib/simplejson/__init__.py @@ -14,15 +14,15 @@ Encoding basic Python object hierarchies:: >>> import simplejson as json >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) '["foo", {"bar": ["baz", null, 1.0, 2]}]' - >>> print json.dumps("\"foo\bar") + >>> print(json.dumps("\"foo\bar")) "\"foo\bar" - >>> print json.dumps(u'\u1234') + >>> print(json.dumps(u'\u1234')) "\u1234" - >>> print json.dumps('\\') + >>> print(json.dumps('\\')) "\\" - >>> print json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True) + >>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True)) {"a": 0, "b": 0, "c": 0} - >>> from StringIO import StringIO + >>> from simplejson.compat import StringIO >>> io = StringIO() >>> json.dump(['streaming API'], io) >>> io.getvalue() @@ -31,14 +31,14 @@ Encoding basic Python object hierarchies:: Compact encoding:: >>> import simplejson as json - >>> json.dumps([1,2,3,{'4': 5, '6': 7}], separators=(',',':')) + >>> obj = [1,2,3,{'4': 5, '6': 7}] + >>> json.dumps(obj, separators=(',',':'), sort_keys=True) '[1,2,3,{"4":5,"6":7}]' Pretty printing:: >>> import simplejson as json - >>> s = json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ') - >>> print '\n'.join([l.rstrip() for l in s.splitlines()]) + >>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ')) { "4": 5, "6": 7 @@ -52,7 +52,7 @@ Decoding JSON:: True >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' True - >>> from StringIO import StringIO + >>> from simplejson.compat import StringIO >>> io = StringIO('["streaming API"]') >>> json.load(io)[0] == 'streaming API' True @@ -95,33 +95,35 @@ Using simplejson.tool from the shell to validate and pretty-print:: "json": "obj" } $ echo '{ 1.2:3.4}' | python -m simplejson.tool - Expecting property name: line 1 column 2 (char 2) + Expecting property name: line 1 column 3 (char 2) """ -__version__ = '2.2.1' +from __future__ import absolute_import +__version__ = '3.5.3' __all__ = [ 'dump', 'dumps', 'load', 'loads', 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', - 'OrderedDict', + 'OrderedDict', 'simple_first', ] __author__ = 'Bob Ippolito <bob@redivi.com>' from decimal import Decimal -from decoder import JSONDecoder, JSONDecodeError -from encoder import JSONEncoder +from .scanner import JSONDecodeError +from .decoder import JSONDecoder +from .encoder import JSONEncoder, JSONEncoderForHTML def _import_OrderedDict(): import collections try: return collections.OrderedDict except AttributeError: - import ordered_dict + from . import ordered_dict return ordered_dict.OrderedDict OrderedDict = _import_OrderedDict() def _import_c_make_encoder(): try: - from simplejson._speedups import make_encoder + from ._speedups import make_encoder return make_encoder except ImportError: return None @@ -138,34 +140,41 @@ _default_encoder = JSONEncoder( use_decimal=True, namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, + item_sort_key=None, + for_json=False, + ignore_nan=False, + int_as_string_bitcount=None, ) def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, - allow_nan=True, cls=None, indent=None, separators=None, - encoding='utf-8', default=None, use_decimal=True, - namedtuple_as_object=True, tuple_as_array=True, - **kw): + allow_nan=True, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw): """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a ``.write()``-supporting file-like object). - If ``skipkeys`` is true then ``dict`` keys that are not basic types + If *skipkeys* is true then ``dict`` keys that are not basic types (``str``, ``unicode``, ``int``, ``long``, ``float``, ``bool``, ``None``) will be skipped instead of raising a ``TypeError``. - If ``ensure_ascii`` is false, then the some chunks written to ``fp`` + If *ensure_ascii* is false, then the some chunks written to ``fp`` may be ``unicode`` instances, subject to normal Python ``str`` to ``unicode`` coercion rules. Unless ``fp.write()`` explicitly understands ``unicode`` (as in ``codecs.getwriter()``) this is likely to cause an error. - If ``check_circular`` is false, then the circular reference check + If *check_circular* is false, then the circular reference check for container types will be skipped and a circular reference will result in an ``OverflowError`` (or worse). - If ``allow_nan`` is false, then it will be a ``ValueError`` to + If *allow_nan* is false, then it will be a ``ValueError`` to serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) - in strict compliance of the JSON specification, instead of using the - JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). + in strict compliance of the original JSON specification, instead of using + the JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). See + *ignore_nan* for ECMA-262 compliant behavior. If *indent* is a string, then JSON array elements and object members will be pretty-printed with a newline followed by that string repeated @@ -174,14 +183,16 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, versions of simplejson earlier than 2.1.0, an integer is also accepted and is converted to a string with that many spaces. - If ``separators`` is an ``(item_separator, dict_separator)`` tuple - then it will be used instead of the default ``(', ', ': ')`` separators. - ``(',', ':')`` is the most compact JSON representation. + If specified, *separators* should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. - ``encoding`` is the character encoding for str instances, default is UTF-8. + *encoding* is the character encoding for str instances, default is UTF-8. - ``default(obj)`` is a function that should return a serializable version - of obj or raise TypeError. The default simply raises TypeError. + *default(obj)* is a function that should return a serializable version + of obj or raise ``TypeError``. The default simply raises ``TypeError``. If *use_decimal* is true (default: ``True``) then decimal.Decimal will be natively serialized to JSON with full precision. @@ -189,13 +200,41 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, If *namedtuple_as_object* is true (default: ``True``), :class:`tuple` subclasses with ``_asdict()`` methods will be encoded as JSON objects. - + If *tuple_as_array* is true (default: ``True``), :class:`tuple` (and subclasses) will be encoded as JSON arrays. + If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. Note that this is still a + lossy operation that will not round-trip correctly and should be used + sparingly. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the ``.default()`` method to serialize additional types), specify it with - the ``cls`` kwarg. + the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead + of subclassing whenever possible. """ # cached encoder @@ -203,7 +242,9 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, check_circular and allow_nan and cls is None and indent is None and separators is None and encoding == 'utf-8' and default is None and use_decimal - and namedtuple_as_object and tuple_as_array and not kw): + and namedtuple_as_object and tuple_as_array + and not bigint_as_string and int_as_string_bitcount is None + and not item_sort_key and not for_json and not ignore_nan and not kw): iterable = _default_encoder.iterencode(obj) else: if cls is None: @@ -214,6 +255,12 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, default=default, use_decimal=use_decimal, namedtuple_as_object=namedtuple_as_object, tuple_as_array=tuple_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, **kw).iterencode(obj) # could accelerate with writelines in some versions of Python, at # a debuggability cost @@ -222,11 +269,11 @@ def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, - allow_nan=True, cls=None, indent=None, separators=None, - encoding='utf-8', default=None, use_decimal=True, - namedtuple_as_object=True, - tuple_as_array=True, - **kw): + allow_nan=True, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, **kw): """Serialize ``obj`` to a JSON formatted ``str``. If ``skipkeys`` is false then ``dict`` keys that are not basic types @@ -253,9 +300,11 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, versions of simplejson earlier than 2.1.0, an integer is also accepted and is converted to a string with that many spaces. - If ``separators`` is an ``(item_separator, dict_separator)`` tuple - then it will be used instead of the default ``(', ', ': ')`` separators. - ``(',', ':')`` is the most compact JSON representation. + If specified, ``separators`` should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. ``encoding`` is the character encoding for str instances, default is UTF-8. @@ -268,21 +317,52 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, If *namedtuple_as_object* is true (default: ``True``), :class:`tuple` subclasses with ``_asdict()`` methods will be encoded as JSON objects. - + If *tuple_as_array* is true (default: ``True``), :class:`tuple` (and subclasses) will be encoded as JSON arrays. + If *bigint_as_string* is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precendence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the ``.default()`` method to serialize additional types), specify it with - the ``cls`` kwarg. + the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing + whenever possible. """ # cached encoder - if (not skipkeys and ensure_ascii and + if ( + not skipkeys and ensure_ascii and check_circular and allow_nan and cls is None and indent is None and separators is None and encoding == 'utf-8' and default is None and use_decimal - and namedtuple_as_object and tuple_as_array and not kw): + and namedtuple_as_object and tuple_as_array + and not bigint_as_string and int_as_string_bitcount is None + and not sort_keys and not item_sort_key and not for_json + and not ignore_nan and not kw + ): return _default_encoder.encode(obj) if cls is None: cls = JSONEncoder @@ -293,6 +373,12 @@ def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, use_decimal=use_decimal, namedtuple_as_object=namedtuple_as_object, tuple_as_array=tuple_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, **kw).encode(obj) @@ -347,7 +433,8 @@ def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, parse_float=decimal.Decimal for parity with ``dump``. To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` - kwarg. + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. """ return loads(fp.read(), @@ -403,7 +490,8 @@ def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, parse_float=decimal.Decimal for parity with ``dump``. To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` - kwarg. + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. """ if (cls is None and encoding is None and object_hook is None and @@ -431,14 +519,14 @@ def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, def _toggle_speedups(enabled): - import simplejson.decoder as dec - import simplejson.encoder as enc - import simplejson.scanner as scan + from . import decoder as dec + from . import encoder as enc + from . import scanner as scan c_make_encoder = _import_c_make_encoder() if enabled: dec.scanstring = dec.c_scanstring or dec.py_scanstring enc.c_make_encoder = c_make_encoder - enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or + enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or enc.py_encode_basestring_ascii) scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner else: @@ -464,3 +552,9 @@ def _toggle_speedups(enabled): encoding='utf-8', default=None, ) + +def simple_first(kv): + """Helper function to pass to item_sort_key to sort simple + elements to the top, then container elements. + """ + return (isinstance(kv[1], (list, dict, tuple)), kv[0]) diff --git a/module/lib/simplejson/compat.py b/module/lib/simplejson/compat.py new file mode 100644 index 000000000..a0af4a1cb --- /dev/null +++ b/module/lib/simplejson/compat.py @@ -0,0 +1,46 @@ +"""Python 3 compatibility shims +""" +import sys +if sys.version_info[0] < 3: + PY3 = False + def b(s): + return s + def u(s): + return unicode(s, 'unicode_escape') + import cStringIO as StringIO + StringIO = BytesIO = StringIO.StringIO + text_type = unicode + binary_type = str + string_types = (basestring,) + integer_types = (int, long) + unichr = unichr + reload_module = reload + def fromhex(s): + return s.decode('hex') + +else: + PY3 = True + if sys.version_info[:2] >= (3, 4): + from importlib import reload as reload_module + else: + from imp import reload as reload_module + import codecs + def b(s): + return codecs.latin_1_encode(s)[0] + def u(s): + return s + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + text_type = str + binary_type = bytes + string_types = (str,) + integer_types = (int,) + + def unichr(s): + return u(chr(s)) + + def fromhex(s): + return bytes.fromhex(s) + +long_type = integer_types[-1] diff --git a/module/lib/simplejson/decoder.py b/module/lib/simplejson/decoder.py index e5496d6e7..1a6c5d938 100644 --- a/module/lib/simplejson/decoder.py +++ b/module/lib/simplejson/decoder.py @@ -1,24 +1,28 @@ """Implementation of JSONDecoder """ +from __future__ import absolute_import import re import sys import struct +from .compat import fromhex, b, u, text_type, binary_type, PY3, unichr +from .scanner import make_scanner, JSONDecodeError -from simplejson.scanner import make_scanner def _import_c_scanstring(): try: - from simplejson._speedups import scanstring + from ._speedups import scanstring return scanstring except ImportError: return None c_scanstring = _import_c_scanstring() +# NOTE (3.1.0): JSONDecodeError may still be imported from this module for +# compatibility, but it was never in the __all__ __all__ = ['JSONDecoder'] FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL def _floatconstants(): - _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') + _BYTES = fromhex('7FF80000000000007FF0000000000000') # The struct module in Python 2.4 would get frexp() out of range here # when an endian is specified in the format string. Fixed in Python 2.5+ if sys.byteorder != 'big': @@ -28,57 +32,6 @@ def _floatconstants(): NaN, PosInf, NegInf = _floatconstants() - -class JSONDecodeError(ValueError): - """Subclass of ValueError with the following additional properties: - - msg: The unformatted error message - doc: The JSON document being parsed - pos: The start index of doc where parsing failed - end: The end index of doc where parsing failed (may be None) - lineno: The line corresponding to pos - colno: The column corresponding to pos - endlineno: The line corresponding to end (may be None) - endcolno: The column corresponding to end (may be None) - - """ - def __init__(self, msg, doc, pos, end=None): - ValueError.__init__(self, errmsg(msg, doc, pos, end=end)) - self.msg = msg - self.doc = doc - self.pos = pos - self.end = end - self.lineno, self.colno = linecol(doc, pos) - if end is not None: - self.endlineno, self.endcolno = linecol(doc, end) - else: - self.endlineno, self.endcolno = None, None - - -def linecol(doc, pos): - lineno = doc.count('\n', 0, pos) + 1 - if lineno == 1: - colno = pos - else: - colno = pos - doc.rindex('\n', 0, pos) - return lineno, colno - - -def errmsg(msg, doc, pos, end=None): - # Note that this function is called from _speedups - lineno, colno = linecol(doc, pos) - if end is None: - #fmt = '{0}: line {1} column {2} (char {3})' - #return fmt.format(msg, lineno, colno, pos) - fmt = '%s: line %d column %d (char %d)' - return fmt % (msg, lineno, colno, pos) - endlineno, endcolno = linecol(doc, end) - #fmt = '{0}: line {1} column {2} - line {3} column {4} (char {5} - {6})' - #return fmt.format(msg, lineno, colno, endlineno, endcolno, pos, end) - fmt = '%s: line %d column %d - line %d column %d (char %d - %d)' - return fmt % (msg, lineno, colno, endlineno, endcolno, pos, end) - - _CONSTANTS = { '-Infinity': NegInf, 'Infinity': PosInf, @@ -87,14 +40,15 @@ _CONSTANTS = { STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) BACKSLASH = { - '"': u'"', '\\': u'\\', '/': u'/', - 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', + '"': u('"'), '\\': u('\u005c'), '/': u('/'), + 'b': u('\b'), 'f': u('\f'), 'n': u('\n'), 'r': u('\r'), 't': u('\t'), } DEFAULT_ENCODING = "utf-8" def py_scanstring(s, end, encoding=None, strict=True, - _b=BACKSLASH, _m=STRINGCHUNK.match): + _b=BACKSLASH, _m=STRINGCHUNK.match, _join=u('').join, + _PY3=PY3, _maxunicode=sys.maxunicode): """Scan the string s for a JSON string. End is the index of the character in s after the quote that started the JSON string. Unescapes all valid JSON string escape sequences and raises ValueError @@ -117,8 +71,8 @@ def py_scanstring(s, end, encoding=None, strict=True, content, terminator = chunk.groups() # Content is contains zero or more unescaped string characters if content: - if not isinstance(content, unicode): - content = unicode(content, encoding) + if not _PY3 and not isinstance(content, text_type): + content = text_type(content, encoding) _append(content) # Terminator is the end of string, a literal control character, # or a backslash denoting that an escape sequence follows @@ -126,8 +80,7 @@ def py_scanstring(s, end, encoding=None, strict=True, break elif terminator != '\\': if strict: - msg = "Invalid control character %r at" % (terminator,) - #msg = "Invalid control character {0!r} at".format(terminator) + msg = "Invalid control character %r at" raise JSONDecodeError(msg, s, end) else: _append(terminator) @@ -142,33 +95,42 @@ def py_scanstring(s, end, encoding=None, strict=True, try: char = _b[esc] except KeyError: - msg = "Invalid \\escape: " + repr(esc) + msg = "Invalid \\X escape sequence %r" raise JSONDecodeError(msg, s, end) end += 1 else: # Unicode escape sequence + msg = "Invalid \\uXXXX escape sequence" esc = s[end + 1:end + 5] - next_end = end + 5 - if len(esc) != 4: - msg = "Invalid \\uXXXX escape" - raise JSONDecodeError(msg, s, end) - uni = int(esc, 16) + escX = esc[1:2] + if len(esc) != 4 or escX == 'x' or escX == 'X': + raise JSONDecodeError(msg, s, end - 1) + try: + uni = int(esc, 16) + except ValueError: + raise JSONDecodeError(msg, s, end - 1) + end += 5 # Check for surrogate pair on UCS-4 systems - if 0xd800 <= uni <= 0xdbff and sys.maxunicode > 65535: - msg = "Invalid \\uXXXX\\uXXXX surrogate pair" - if not s[end + 5:end + 7] == '\\u': - raise JSONDecodeError(msg, s, end) - esc2 = s[end + 7:end + 11] - if len(esc2) != 4: - raise JSONDecodeError(msg, s, end) - uni2 = int(esc2, 16) - uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00)) - next_end += 6 + # Note that this will join high/low surrogate pairs + # but will also pass unpaired surrogates through + if (_maxunicode > 65535 and + uni & 0xfc00 == 0xd800 and + s[end:end + 2] == '\\u'): + esc2 = s[end + 2:end + 6] + escX = esc2[1:2] + if len(esc2) == 4 and not (escX == 'x' or escX == 'X'): + try: + uni2 = int(esc2, 16) + except ValueError: + raise JSONDecodeError(msg, s, end) + if uni2 & 0xfc00 == 0xdc00: + uni = 0x10000 + (((uni - 0xd800) << 10) | + (uni2 - 0xdc00)) + end += 6 char = unichr(uni) - end = next_end # Append the unescaped character _append(char) - return u''.join(chunks), end + return _join(chunks), end # Use speedup if available @@ -177,9 +139,10 @@ scanstring = c_scanstring or py_scanstring WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) WHITESPACE_STR = ' \t\n\r' -def JSONObject((s, end), encoding, strict, scan_once, object_hook, +def JSONObject(state, encoding, strict, scan_once, object_hook, object_pairs_hook, memo=None, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state # Backwards compatibility if memo is None: memo = {} @@ -203,7 +166,9 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, pairs = object_hook(pairs) return pairs, end + 1 elif nextchar != '"': - raise JSONDecodeError("Expecting property name", s, end) + raise JSONDecodeError( + "Expecting property name enclosed in double quotes", + s, end) end += 1 while True: key, end = scanstring(s, end, encoding, strict) @@ -214,7 +179,7 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, if s[end:end + 1] != ':': end = _w(s, end).end() if s[end:end + 1] != ':': - raise JSONDecodeError("Expecting : delimiter", s, end) + raise JSONDecodeError("Expecting ':' delimiter", s, end) end += 1 @@ -226,10 +191,7 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, except IndexError: pass - try: - value, end = scan_once(s, end) - except StopIteration: - raise JSONDecodeError("Expecting object", s, end) + value, end = scan_once(s, end) pairs.append((key, value)) try: @@ -244,7 +206,7 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, if nextchar == '}': break elif nextchar != ',': - raise JSONDecodeError("Expecting , delimiter", s, end - 1) + raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1) try: nextchar = s[end] @@ -259,7 +221,9 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, end += 1 if nextchar != '"': - raise JSONDecodeError("Expecting property name", s, end - 1) + raise JSONDecodeError( + "Expecting property name enclosed in double quotes", + s, end - 1) if object_pairs_hook is not None: result = object_pairs_hook(pairs) @@ -269,7 +233,8 @@ def JSONObject((s, end), encoding, strict, scan_once, object_hook, pairs = object_hook(pairs) return pairs, end -def JSONArray((s, end), scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): +def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state values = [] nextchar = s[end:end + 1] if nextchar in _ws: @@ -278,12 +243,11 @@ def JSONArray((s, end), scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): # Look-ahead for trivial empty array if nextchar == ']': return values, end + 1 + elif nextchar == '': + raise JSONDecodeError("Expecting value or ']'", s, end) _append = values.append while True: - try: - value, end = scan_once(s, end) - except StopIteration: - raise JSONDecodeError("Expecting object", s, end) + value, end = scan_once(s, end) _append(value) nextchar = s[end:end + 1] if nextchar in _ws: @@ -293,7 +257,7 @@ def JSONArray((s, end), scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): if nextchar == ']': break elif nextchar != ',': - raise JSONDecodeError("Expecting , delimiter", s, end) + raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1) try: if s[end] in _ws: @@ -317,7 +281,7 @@ class JSONDecoder(object): +---------------+-------------------+ | array | list | +---------------+-------------------+ - | string | unicode | + | string | str, unicode | +---------------+-------------------+ | number (int) | int, long | +---------------+-------------------+ @@ -381,6 +345,8 @@ class JSONDecoder(object): ``False`` then control characters will be allowed in strings. """ + if encoding is None: + encoding = DEFAULT_ENCODING self.encoding = encoding self.object_hook = object_hook self.object_pairs_hook = object_pairs_hook @@ -394,28 +360,34 @@ class JSONDecoder(object): self.memo = {} self.scan_once = make_scanner(self) - def decode(self, s, _w=WHITESPACE.match): + def decode(self, s, _w=WHITESPACE.match, _PY3=PY3): """Return the Python representation of ``s`` (a ``str`` or ``unicode`` instance containing a JSON document) """ - obj, end = self.raw_decode(s, idx=_w(s, 0).end()) + if _PY3 and isinstance(s, binary_type): + s = s.decode(self.encoding) + obj, end = self.raw_decode(s) end = _w(s, end).end() if end != len(s): raise JSONDecodeError("Extra data", s, end, len(s)) return obj - def raw_decode(self, s, idx=0): + def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3): """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` beginning with a JSON document) and return a 2-tuple of the Python representation and the index in ``s`` where the document ended. + Optionally, ``idx`` can be used to specify an offset in ``s`` where + the JSON document begins. This can be used to decode a JSON document from a string that may have extraneous data at the end. """ - try: - obj, end = self.scan_once(s, idx) - except StopIteration: - raise JSONDecodeError("No JSON object could be decoded", s, idx) - return obj, end + if idx < 0: + # Ensure that raw_decode bails on negative indexes, the regex + # would otherwise mask this behavior. #98 + raise JSONDecodeError('Expecting value', s, idx) + if _PY3 and not isinstance(s, text_type): + raise TypeError("Input string must be text, not bytes") + return self.scan_once(s, idx=_w(s, idx).end()) diff --git a/module/lib/simplejson/encoder.py b/module/lib/simplejson/encoder.py index 5ec7440f1..db18244ec 100644 --- a/module/lib/simplejson/encoder.py +++ b/module/lib/simplejson/encoder.py @@ -1,11 +1,13 @@ """Implementation of JSONEncoder """ +from __future__ import absolute_import import re +from operator import itemgetter from decimal import Decimal - +from .compat import u, unichr, binary_type, string_types, integer_types, PY3 def _import_speedups(): try: - from simplejson import _speedups + from . import _speedups return _speedups.encode_basestring_ascii, _speedups.make_encoder except ImportError: return None, None @@ -13,7 +15,10 @@ c_encode_basestring_ascii, c_make_encoder = _import_speedups() from simplejson.decoder import PosInf -ESCAPE = re.compile(ur'[\x00-\x1f\\"\b\f\n\r\t\u2028\u2029]') +#ESCAPE = re.compile(ur'[\x00-\x1f\\"\b\f\n\r\t\u2028\u2029]') +# This is required because u() will mangle the string and ur'' isn't valid +# python3 syntax +ESCAPE = re.compile(u'[\\x00-\\x1f\\\\"\\b\\f\\n\\r\\t\u2028\u2029]') ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') HAS_UTF8 = re.compile(r'[\x80-\xff]') ESCAPE_DCT = { @@ -24,32 +29,40 @@ ESCAPE_DCT = { '\n': '\\n', '\r': '\\r', '\t': '\\t', - u'\u2028': '\\u2028', - u'\u2029': '\\u2029', } for i in range(0x20): #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) +for i in [0x2028, 0x2029]: + ESCAPE_DCT.setdefault(unichr(i), '\\u%04x' % (i,)) FLOAT_REPR = repr -def encode_basestring(s): +def encode_basestring(s, _PY3=PY3, _q=u('"')): """Return a JSON representation of a Python string """ - if isinstance(s, str) and HAS_UTF8.search(s) is not None: - s = s.decode('utf-8') + if _PY3: + if isinstance(s, binary_type): + s = s.decode('utf-8') + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = s.decode('utf-8') def replace(match): return ESCAPE_DCT[match.group(0)] - return u'"' + ESCAPE.sub(replace, s) + u'"' + return _q + ESCAPE.sub(replace, s) + _q -def py_encode_basestring_ascii(s): +def py_encode_basestring_ascii(s, _PY3=PY3): """Return an ASCII-only JSON representation of a Python string """ - if isinstance(s, str) and HAS_UTF8.search(s) is not None: - s = s.decode('utf-8') + if _PY3: + if isinstance(s, binary_type): + s = s.decode('utf-8') + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = s.decode('utf-8') def replace(match): s = match.group(0) try: @@ -103,11 +116,14 @@ class JSONEncoder(object): """ item_separator = ', ' key_separator = ': ' + def __init__(self, skipkeys=False, ensure_ascii=True, - check_circular=True, allow_nan=True, sort_keys=False, - indent=None, separators=None, encoding='utf-8', default=None, - use_decimal=True, namedtuple_as_object=True, - tuple_as_array=True): + check_circular=True, allow_nan=True, sort_keys=False, + indent=None, separators=None, encoding='utf-8', default=None, + use_decimal=True, namedtuple_as_object=True, + tuple_as_array=True, bigint_as_string=False, + item_sort_key=None, for_json=False, ignore_nan=False, + int_as_string_bitcount=None): """Constructor for JSONEncoder, with sensible defaults. If skipkeys is false, then it is a TypeError to attempt @@ -139,9 +155,10 @@ class JSONEncoder(object): versions of simplejson earlier than 2.1.0, an integer is also accepted and is converted to a string with that many spaces. - If specified, separators should be a (item_separator, key_separator) - tuple. The default is (', ', ': '). To get the most compact JSON - representation you should specify (',', ':') to eliminate whitespace. + If specified, separators should be an (item_separator, key_separator) + tuple. The default is (', ', ': ') if *indent* is ``None`` and + (',', ': ') otherwise. To get the most compact JSON representation, + you should specify (',', ':') to eliminate whitespace. If specified, default is a function that gets called for objects that can't otherwise be serialized. It should return a JSON encodable @@ -155,11 +172,33 @@ class JSONEncoder(object): be supported directly by the encoder. For the inverse, decode JSON with ``parse_float=decimal.Decimal``. - If namedtuple_as_object is true (the default), tuple subclasses with + If namedtuple_as_object is true (the default), objects with ``_asdict()`` methods will be encoded as JSON objects. - + If tuple_as_array is true (the default), tuple (and subclasses) will be encoded as JSON arrays. + + If bigint_as_string is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If int_as_string_bitcount is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, item_sort_key is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. + + If for_json is true (not the default), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized + as ``null`` in compliance with the ECMA-262 specification. If true, + this will override *allow_nan*. + """ self.skipkeys = skipkeys @@ -170,8 +209,13 @@ class JSONEncoder(object): self.use_decimal = use_decimal self.namedtuple_as_object = namedtuple_as_object self.tuple_as_array = tuple_as_array - if isinstance(indent, (int, long)): - indent = ' ' * indent + self.bigint_as_string = bigint_as_string + self.item_sort_key = item_sort_key + self.for_json = for_json + self.ignore_nan = ignore_nan + self.int_as_string_bitcount = int_as_string_bitcount + if indent is not None and not isinstance(indent, string_types): + indent = indent * ' ' self.indent = indent if separators is not None: self.item_separator, self.key_separator = separators @@ -210,12 +254,11 @@ class JSONEncoder(object): """ # This is for extremely simple cases and benchmarks. - if isinstance(o, basestring): - if isinstance(o, str): - _encoding = self.encoding - if (_encoding is not None - and not (_encoding == 'utf-8')): - o = o.decode(_encoding) + if isinstance(o, binary_type): + _encoding = self.encoding + if (_encoding is not None and not (_encoding == 'utf-8')): + o = o.decode(_encoding) + if isinstance(o, string_types): if self.ensure_ascii: return encode_basestring_ascii(o) else: @@ -251,11 +294,11 @@ class JSONEncoder(object): _encoder = encode_basestring if self.encoding != 'utf-8': def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): - if isinstance(o, str): + if isinstance(o, binary_type): o = o.decode(_encoding) return _orig_encoder(o) - def floatstr(o, allow_nan=self.allow_nan, + def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan, _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): # Check for specials. Note that this type of test is processor # and/or platform-specific, so do tests which don't depend on @@ -270,28 +313,37 @@ class JSONEncoder(object): else: return _repr(o) - if not allow_nan: + if ignore_nan: + text = 'null' + elif not allow_nan: raise ValueError( "Out of range float values are not JSON compliant: " + repr(o)) return text - key_memo = {} + int_as_string_bitcount = ( + 53 if self.bigint_as_string else self.int_as_string_bitcount) if (_one_shot and c_make_encoder is not None and self.indent is None): _iterencode = c_make_encoder( markers, self.default, _encoder, self.indent, self.key_separator, self.item_separator, self.sort_keys, self.skipkeys, self.allow_nan, key_memo, self.use_decimal, - self.namedtuple_as_object, self.tuple_as_array) + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.ignore_nan, Decimal) else: _iterencode = _make_iterencode( markers, self.default, _encoder, self.indent, floatstr, self.key_separator, self.item_separator, self.sort_keys, self.skipkeys, _one_shot, self.use_decimal, - self.namedtuple_as_object, self.tuple_as_array) + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + Decimal=Decimal) try: return _iterencode(o, 0) finally: @@ -328,22 +380,46 @@ class JSONEncoderForHTML(JSONEncoder): def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, _key_separator, _item_separator, _sort_keys, _skipkeys, _one_shot, _use_decimal, _namedtuple_as_object, _tuple_as_array, + _int_as_string_bitcount, _item_sort_key, + _encoding,_for_json, ## HACK: hand-optimized bytecode; turn globals into locals - False=False, - True=True, + _PY3=PY3, ValueError=ValueError, - basestring=basestring, + string_types=string_types, Decimal=Decimal, dict=dict, float=float, id=id, - int=int, + integer_types=integer_types, isinstance=isinstance, list=list, - long=long, str=str, tuple=tuple, ): + if _item_sort_key and not callable(_item_sort_key): + raise TypeError("item_sort_key must be None or callable") + elif _sort_keys and not _item_sort_key: + _item_sort_key = itemgetter(0) + + if (_int_as_string_bitcount is not None and + (_int_as_string_bitcount <= 0 or + not isinstance(_int_as_string_bitcount, integer_types))): + raise TypeError("int_as_string_bitcount must be a positive integer") + + def _encode_int(value): + skip_quoting = ( + _int_as_string_bitcount is None + or + _int_as_string_bitcount < 1 + ) + if ( + skip_quoting or + (-1 << _int_as_string_bitcount) + < value < + (1 << _int_as_string_bitcount) + ): + return str(value) + return '"' + str(value) + '"' def _iterencode_list(lst, _current_indent_level): if not lst: @@ -369,7 +445,8 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, first = False else: buf = separator - if isinstance(value, basestring): + if (isinstance(value, string_types) or + (_PY3 and isinstance(value, binary_type))): yield buf + _encoder(value) elif value is None: yield buf + 'null' @@ -377,26 +454,30 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, yield buf + 'true' elif value is False: yield buf + 'false' - elif isinstance(value, (int, long)): - yield buf + str(value) + elif isinstance(value, integer_types): + yield buf + _encode_int(value) elif isinstance(value, float): yield buf + _floatstr(value) elif _use_decimal and isinstance(value, Decimal): yield buf + str(value) else: yield buf - if isinstance(value, list): + for_json = _for_json and getattr(value, 'for_json', None) + if for_json and callable(for_json): + chunks = _iterencode(for_json(), _current_indent_level) + elif isinstance(value, list): chunks = _iterencode_list(value, _current_indent_level) - elif (_namedtuple_as_object and isinstance(value, tuple) and - hasattr(value, '_asdict')): - chunks = _iterencode_dict(value._asdict(), - _current_indent_level) - elif _tuple_as_array and isinstance(value, tuple): - chunks = _iterencode_list(value, _current_indent_level) - elif isinstance(value, dict): - chunks = _iterencode_dict(value, _current_indent_level) else: - chunks = _iterencode(value, _current_indent_level) + _asdict = _namedtuple_as_object and getattr(value, '_asdict', None) + if _asdict and callable(_asdict): + chunks = _iterencode_dict(_asdict(), + _current_indent_level) + elif _tuple_as_array and isinstance(value, tuple): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) for chunk in chunks: yield chunk if newline_indent is not None: @@ -406,6 +487,29 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, if markers is not None: del markers[markerid] + def _stringify_key(key): + if isinstance(key, string_types): # pragma: no cover + pass + elif isinstance(key, binary_type): + key = key.decode(_encoding) + elif isinstance(key, float): + key = _floatstr(key) + elif key is True: + key = 'true' + elif key is False: + key = 'false' + elif key is None: + key = 'null' + elif isinstance(key, integer_types): + key = str(key) + elif _use_decimal and isinstance(key, Decimal): + key = str(key) + elif _skipkeys: + key = None + else: + raise TypeError("key " + repr(key) + " is not a string") + return key + def _iterencode_dict(dct, _current_indent_level): if not dct: yield '{}' @@ -425,37 +529,35 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, newline_indent = None item_separator = _item_separator first = True - if _sort_keys: - items = dct.items() - items.sort(key=lambda kv: kv[0]) + if _PY3: + iteritems = dct.items() else: - items = dct.iteritems() + iteritems = dct.iteritems() + if _item_sort_key: + items = [] + for k, v in dct.items(): + if not isinstance(k, string_types): + k = _stringify_key(k) + if k is None: + continue + items.append((k, v)) + items.sort(key=_item_sort_key) + else: + items = iteritems for key, value in items: - if isinstance(key, basestring): - pass - # JavaScript is weakly typed for these, so it makes sense to - # also allow them. Many encoders seem to do something like this. - elif isinstance(key, float): - key = _floatstr(key) - elif key is True: - key = 'true' - elif key is False: - key = 'false' - elif key is None: - key = 'null' - elif isinstance(key, (int, long)): - key = str(key) - elif _skipkeys: - continue - else: - raise TypeError("key " + repr(key) + " is not a string") + if not (_item_sort_key or isinstance(key, string_types)): + key = _stringify_key(key) + if key is None: + # _skipkeys must be True + continue if first: first = False else: yield item_separator yield _encoder(key) yield _key_separator - if isinstance(value, basestring): + if (isinstance(value, string_types) or + (_PY3 and isinstance(value, binary_type))): yield _encoder(value) elif value is None: yield 'null' @@ -463,25 +565,29 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, yield 'true' elif value is False: yield 'false' - elif isinstance(value, (int, long)): - yield str(value) + elif isinstance(value, integer_types): + yield _encode_int(value) elif isinstance(value, float): yield _floatstr(value) elif _use_decimal and isinstance(value, Decimal): yield str(value) else: - if isinstance(value, list): + for_json = _for_json and getattr(value, 'for_json', None) + if for_json and callable(for_json): + chunks = _iterencode(for_json(), _current_indent_level) + elif isinstance(value, list): chunks = _iterencode_list(value, _current_indent_level) - elif (_namedtuple_as_object and isinstance(value, tuple) and - hasattr(value, '_asdict')): - chunks = _iterencode_dict(value._asdict(), - _current_indent_level) - elif _tuple_as_array and isinstance(value, tuple): - chunks = _iterencode_list(value, _current_indent_level) - elif isinstance(value, dict): - chunks = _iterencode_dict(value, _current_indent_level) else: - chunks = _iterencode(value, _current_indent_level) + _asdict = _namedtuple_as_object and getattr(value, '_asdict', None) + if _asdict and callable(_asdict): + chunks = _iterencode_dict(_asdict(), + _current_indent_level) + elif _tuple_as_array and isinstance(value, tuple): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) for chunk in chunks: yield chunk if newline_indent is not None: @@ -492,7 +598,8 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, del markers[markerid] def _iterencode(o, _current_indent_level): - if isinstance(o, basestring): + if (isinstance(o, string_types) or + (_PY3 and isinstance(o, binary_type))): yield _encoder(o) elif o is None: yield 'null' @@ -500,35 +607,42 @@ def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, yield 'true' elif o is False: yield 'false' - elif isinstance(o, (int, long)): - yield str(o) + elif isinstance(o, integer_types): + yield _encode_int(o) elif isinstance(o, float): yield _floatstr(o) - elif isinstance(o, list): - for chunk in _iterencode_list(o, _current_indent_level): - yield chunk - elif (_namedtuple_as_object and isinstance(o, tuple) and - hasattr(o, '_asdict')): - for chunk in _iterencode_dict(o._asdict(), _current_indent_level): - yield chunk - elif (_tuple_as_array and isinstance(o, tuple)): - for chunk in _iterencode_list(o, _current_indent_level): - yield chunk - elif isinstance(o, dict): - for chunk in _iterencode_dict(o, _current_indent_level): - yield chunk - elif _use_decimal and isinstance(o, Decimal): - yield str(o) else: - if markers is not None: - markerid = id(o) - if markerid in markers: - raise ValueError("Circular reference detected") - markers[markerid] = o - o = _default(o) - for chunk in _iterencode(o, _current_indent_level): - yield chunk - if markers is not None: - del markers[markerid] + for_json = _for_json and getattr(o, 'for_json', None) + if for_json and callable(for_json): + for chunk in _iterencode(for_json(), _current_indent_level): + yield chunk + elif isinstance(o, list): + for chunk in _iterencode_list(o, _current_indent_level): + yield chunk + else: + _asdict = _namedtuple_as_object and getattr(o, '_asdict', None) + if _asdict and callable(_asdict): + for chunk in _iterencode_dict(_asdict(), + _current_indent_level): + yield chunk + elif (_tuple_as_array and isinstance(o, tuple)): + for chunk in _iterencode_list(o, _current_indent_level): + yield chunk + elif isinstance(o, dict): + for chunk in _iterencode_dict(o, _current_indent_level): + yield chunk + elif _use_decimal and isinstance(o, Decimal): + yield str(o) + else: + if markers is not None: + markerid = id(o) + if markerid in markers: + raise ValueError("Circular reference detected") + markers[markerid] = o + o = _default(o) + for chunk in _iterencode(o, _current_indent_level): + yield chunk + if markers is not None: + del markers[markerid] return _iterencode diff --git a/module/lib/simplejson/scanner.py b/module/lib/simplejson/scanner.py index 54593a371..5abed357b 100644 --- a/module/lib/simplejson/scanner.py +++ b/module/lib/simplejson/scanner.py @@ -9,12 +9,62 @@ def _import_c_make_scanner(): return None c_make_scanner = _import_c_make_scanner() -__all__ = ['make_scanner'] +__all__ = ['make_scanner', 'JSONDecodeError'] NUMBER_RE = re.compile( r'(-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?', (re.VERBOSE | re.MULTILINE | re.DOTALL)) +class JSONDecodeError(ValueError): + """Subclass of ValueError with the following additional properties: + + msg: The unformatted error message + doc: The JSON document being parsed + pos: The start index of doc where parsing failed + end: The end index of doc where parsing failed (may be None) + lineno: The line corresponding to pos + colno: The column corresponding to pos + endlineno: The line corresponding to end (may be None) + endcolno: The column corresponding to end (may be None) + + """ + # Note that this exception is used from _speedups + def __init__(self, msg, doc, pos, end=None): + ValueError.__init__(self, errmsg(msg, doc, pos, end=end)) + self.msg = msg + self.doc = doc + self.pos = pos + self.end = end + self.lineno, self.colno = linecol(doc, pos) + if end is not None: + self.endlineno, self.endcolno = linecol(doc, end) + else: + self.endlineno, self.endcolno = None, None + + def __reduce__(self): + return self.__class__, (self.msg, self.doc, self.pos, self.end) + + +def linecol(doc, pos): + lineno = doc.count('\n', 0, pos) + 1 + if lineno == 1: + colno = pos + 1 + else: + colno = pos - doc.rindex('\n', 0, pos) + return lineno, colno + + +def errmsg(msg, doc, pos, end=None): + lineno, colno = linecol(doc, pos) + msg = msg.replace('%r', repr(doc[pos:pos + 1])) + if end is None: + fmt = '%s: line %d column %d (char %d)' + return fmt % (msg, lineno, colno, pos) + endlineno, endcolno = linecol(doc, end) + fmt = '%s: line %d column %d - line %d column %d (char %d - %d)' + return fmt % (msg, lineno, colno, endlineno, endcolno, pos, end) + + def py_make_scanner(context): parse_object = context.parse_object parse_array = context.parse_array @@ -30,10 +80,11 @@ def py_make_scanner(context): memo = context.memo def _scan_once(string, idx): + errmsg = 'Expecting value' try: nextchar = string[idx] except IndexError: - raise StopIteration + raise JSONDecodeError(errmsg, string, idx) if nextchar == '"': return parse_string(string, idx + 1, encoding, strict) @@ -64,9 +115,14 @@ def py_make_scanner(context): elif nextchar == '-' and string[idx:idx + 9] == '-Infinity': return parse_constant('-Infinity'), idx + 9 else: - raise StopIteration + raise JSONDecodeError(errmsg, string, idx) def scan_once(string, idx): + if idx < 0: + # Ensure the same behavior as the C speedup, otherwise + # this would work for *some* negative string indices due + # to the behavior of __getitem__ for strings. #98 + raise JSONDecodeError('Expecting value', string, idx) try: return _scan_once(string, idx) finally: diff --git a/module/lib/simplejson/tool.py b/module/lib/simplejson/tool.py index 73370db55..062e8e2c1 100644 --- a/module/lib/simplejson/tool.py +++ b/module/lib/simplejson/tool.py @@ -10,6 +10,7 @@ Usage:: Expecting property name: line 1 column 2 (char 2) """ +from __future__ import with_statement import sys import simplejson as json @@ -18,21 +19,23 @@ def main(): infile = sys.stdin outfile = sys.stdout elif len(sys.argv) == 2: - infile = open(sys.argv[1], 'rb') + infile = open(sys.argv[1], 'r') outfile = sys.stdout elif len(sys.argv) == 3: - infile = open(sys.argv[1], 'rb') - outfile = open(sys.argv[2], 'wb') + infile = open(sys.argv[1], 'r') + outfile = open(sys.argv[2], 'w') else: raise SystemExit(sys.argv[0] + " [infile [outfile]]") - try: - obj = json.load(infile, - object_pairs_hook=json.OrderedDict, - use_decimal=True) - except ValueError, e: - raise SystemExit(e) - json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) - outfile.write('\n') + with infile: + try: + obj = json.load(infile, + object_pairs_hook=json.OrderedDict, + use_decimal=True) + except ValueError: + raise SystemExit(sys.exc_info()[1]) + with outfile: + json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) + outfile.write('\n') if __name__ == '__main__': diff --git a/module/lib/thrift/TSCons.py b/module/lib/thrift/TSCons.py index 24046256c..da8d2833b 100644 --- a/module/lib/thrift/TSCons.py +++ b/module/lib/thrift/TSCons.py @@ -20,14 +20,16 @@ from os import path from SCons.Builder import Builder + def scons_env(env, add=''): opath = path.dirname(path.abspath('$TARGET')) lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action = lstr) - env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) + def gen_cpp(env, dir, file): scons_env(env) suffixes = ['_types.h', '_types.cpp'] targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) - return env.ThriftCpp(targets, dir+file+'.thrift') + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/module/lib/thrift/TSerialization.py b/module/lib/thrift/TSerialization.py index b19f98aa8..8a58d89df 100644 --- a/module/lib/thrift/TSerialization.py +++ b/module/lib/thrift/TSerialization.py @@ -20,15 +20,19 @@ from protocol import TBinaryProtocol from transport import TTransport -def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def serialize(thrift_object, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): transport = TTransport.TMemoryBuffer() protocol = protocol_factory.getProtocol(transport) thrift_object.write(protocol) return transport.getvalue() -def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def deserialize(base, + buf, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): transport = TTransport.TMemoryBuffer(buf) protocol = protocol_factory.getProtocol(transport) base.read(protocol) return base - diff --git a/module/lib/thrift/TTornado.py b/module/lib/thrift/TTornado.py new file mode 100644 index 000000000..af309c3d9 --- /dev/null +++ b/module/lib/thrift/TTornado.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +import logging +import socket +import struct + +from thrift.transport import TTransport +from thrift.transport.TTransport import TTransportException + +from tornado import gen +from tornado import iostream +from tornado import netutil + + +class TTornadoStreamTransport(TTransport.TTransportBase): + """a framed, buffered transport over a Tornado stream""" + def __init__(self, host, port, stream=None): + self.host = host + self.port = port + self.is_queuing_reads = False + self.read_queue = [] + self.__wbuf = StringIO() + + # servers provide a ready-to-go stream + self.stream = stream + if self.stream is not None: + self._set_close_callback() + + # not the same number of parameters as TTransportBase.open + def open(self, callback): + logging.debug('socket connecting') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.stream = iostream.IOStream(sock) + + def on_close_in_connect(*_): + message = 'could not connect to {}:{}'.format(self.host, self.port) + raise TTransportException( + type=TTransportException.NOT_OPEN, + message=message) + self.stream.set_close_callback(on_close_in_connect) + + def finish(*_): + self._set_close_callback() + callback() + + self.stream.connect((self.host, self.port), callback=finish) + + def _set_close_callback(self): + def on_close(): + raise TTransportException( + type=TTransportException.END_OF_FILE, + message='socket closed') + self.stream.set_close_callback(self.close) + + def close(self): + # don't raise if we intend to close + self.stream.set_close_callback(None) + self.stream.close() + + def read(self, _): + # The generated code for Tornado shouldn't do individual reads -- only + # frames at a time + assert "you're doing it wrong" is True + + @gen.engine + def readFrame(self, callback): + self.read_queue.append(callback) + logging.debug('read queue: %s', self.read_queue) + + if self.is_queuing_reads: + # If a read is already in flight, then the while loop below should + # pull it from self.read_queue + return + + self.is_queuing_reads = True + while self.read_queue: + next_callback = self.read_queue.pop() + result = yield gen.Task(self._readFrameFromStream) + next_callback(result) + self.is_queuing_reads = False + + @gen.engine + def _readFrameFromStream(self, callback): + logging.debug('_readFrameFromStream') + frame_header = yield gen.Task(self.stream.read_bytes, 4) + frame_length, = struct.unpack('!i', frame_header) + logging.debug('received frame header, frame length = %i', frame_length) + frame = yield gen.Task(self.stream.read_bytes, frame_length) + logging.debug('received frame payload') + callback(frame) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self, callback=None): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = struct.pack("!i", wsz) + wout + + logging.debug('writing frame length = %i', wsz) + self.stream.write(buf, callback) + + +class TTornadoServer(netutil.TCPServer): + def __init__(self, processor, iprot_factory, oprot_factory=None, + *args, **kwargs): + super(TTornadoServer, self).__init__(*args, **kwargs) + + self._processor = processor + self._iprot_factory = iprot_factory + self._oprot_factory = (oprot_factory if oprot_factory is not None + else iprot_factory) + + def handle_stream(self, stream, address): + try: + host, port = address + trans = TTornadoStreamTransport(host=host, port=port, stream=stream) + oprot = self._oprot_factory.getProtocol(trans) + + def next_pass(): + if not trans.stream.closed(): + self._processor.process(trans, self._iprot_factory, oprot, + callback=next_pass) + + next_pass() + + except Exception: + logging.exception('thrift exception in handle_stream') + trans.close() diff --git a/module/lib/thrift/Thrift.py b/module/lib/thrift/Thrift.py index 1d271fcff..9890af7e1 100644 --- a/module/lib/thrift/Thrift.py +++ b/module/lib/thrift/Thrift.py @@ -19,6 +19,7 @@ import sys + class TType: STOP = 0 VOID = 1 @@ -38,7 +39,7 @@ class TType: UTF8 = 16 UTF16 = 17 - _VALUES_TO_NAMES = ( 'STOP', + _VALUES_TO_NAMES = ('STOP', 'VOID', 'BOOL', 'BYTE', @@ -48,46 +49,48 @@ class TType: None, 'I32', None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16' ) + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16') + class TMessageType: - CALL = 1 + CALL = 1 REPLY = 2 EXCEPTION = 3 ONEWAY = 4 -class TProcessor: +class TProcessor: """Base class for procsessor, which works on two streams.""" def process(iprot, oprot): pass -class TException(Exception): +class TException(Exception): """Base class for all thrift exceptions.""" # BaseException.message is deprecated in Python v[2.6,3.0) - if (2,6,0) <= sys.version_info < (3,0): + if (2, 6, 0) <= sys.version_info < (3, 0): def _get_message(self): - return self._message + return self._message + def _set_message(self, message): - self._message = message + self._message = message message = property(_get_message, _set_message) def __init__(self, message=None): Exception.__init__(self, message) self.message = message -class TApplicationException(TException): +class TApplicationException(TException): """Application level thrift exceptions.""" UNKNOWN = 0 @@ -98,6 +101,9 @@ class TApplicationException(TException): MISSING_RESULT = 5 INTERNAL_ERROR = 6 PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 def __init__(self, type=UNKNOWN, message=None): TException.__init__(self, message) @@ -116,6 +122,16 @@ class TApplicationException(TException): return 'Bad sequence ID' elif self.type == self.MISSING_RESULT: return 'Missing result' + elif self.type == self.INTERNAL_ERROR: + return 'Internal error' + elif self.type == self.PROTOCOL_ERROR: + return 'Protocol error' + elif self.type == self.INVALID_TRANSFORM: + return 'Invalid transform' + elif self.type == self.INVALID_PROTOCOL: + return 'Invalid protocol' + elif self.type == self.UNSUPPORTED_CLIENT_TYPE: + return 'Unsupported client type' else: return 'Default (unknown) TApplicationException' @@ -127,12 +143,12 @@ class TApplicationException(TException): break if fid == 1: if ftype == TType.STRING: - self.message = iprot.readString(); + self.message = iprot.readString() else: iprot.skip(ftype) elif fid == 2: if ftype == TType.I32: - self.type = iprot.readI32(); + self.type = iprot.readI32() else: iprot.skip(ftype) else: @@ -142,11 +158,11 @@ class TApplicationException(TException): def write(self, oprot): oprot.writeStructBegin('TApplicationException') - if self.message != None: + if self.message is not None: oprot.writeFieldBegin('message', TType.STRING, 1) oprot.writeString(self.message) oprot.writeFieldEnd() - if self.type != None: + if self.type is not None: oprot.writeFieldBegin('type', TType.I32, 2) oprot.writeI32(self.type) oprot.writeFieldEnd() diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py index e675c7dc0..6cbd5f39a 100644 --- a/module/lib/thrift/protocol/TBase.py +++ b/module/lib/thrift/protocol/TBase.py @@ -26,12 +26,13 @@ try: except: fastbinary = None + class TBase(object): __slots__ = [] def __repr__(self): L = ['%s=%r' % (key, getattr(self, key)) - for key in self.__slots__ ] + for key in self.__slots__] return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) def __eq__(self, other): @@ -43,30 +44,38 @@ class TBase(object): if my_val != other_val: return False return True - + def __ne__(self, other): return not (self == other) - + def read(self, iprot): - if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: - fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None and + fastbinary is not None): + fastbinary.decode_binary(self, + iprot.trans, + (self.__class__, self.thrift_spec)) return iprot.readStruct(self, self.thrift_spec) def write(self, oprot): - if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: - oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + self.thrift_spec is not None and + fastbinary is not None): + oprot.trans.write( + fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) return oprot.writeStruct(self, self.thrift_spec) + class TExceptionBase(Exception): # old style class so python2.4 can raise exceptions derived from this # This can't inherit from TBase because of that limitation. __slots__ = [] - + __repr__ = TBase.__repr__.im_func __eq__ = TBase.__eq__.im_func __ne__ = TBase.__ne__.im_func read = TBase.read.im_func write = TBase.write.im_func - diff --git a/module/lib/thrift/protocol/TBinaryProtocol.py b/module/lib/thrift/protocol/TBinaryProtocol.py index 50c6aa896..6fdd08c26 100644 --- a/module/lib/thrift/protocol/TBinaryProtocol.py +++ b/module/lib/thrift/protocol/TBinaryProtocol.py @@ -20,8 +20,8 @@ from TProtocol import * from struct import pack, unpack -class TBinaryProtocol(TProtocolBase): +class TBinaryProtocol(TProtocolBase): """Binary implementation of the Thrift protocol driver.""" # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be @@ -68,7 +68,7 @@ class TBinaryProtocol(TProtocolBase): pass def writeFieldStop(self): - self.writeByte(TType.STOP); + self.writeByte(TType.STOP) def writeMapBegin(self, ktype, vtype, size): self.writeByte(ktype) @@ -127,13 +127,16 @@ class TBinaryProtocol(TProtocolBase): if sz < 0: version = sz & TBinaryProtocol.VERSION_MASK if version != TBinaryProtocol.VERSION_1: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) type = sz & TBinaryProtocol.TYPE_MASK name = self.readString() seqid = self.readI32() else: if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') name = self.trans.readAll(sz) type = self.readByte() seqid = self.readI32() @@ -231,7 +234,6 @@ class TBinaryProtocolFactory: class TBinaryProtocolAccelerated(TBinaryProtocol): - """C-Accelerated version of TBinaryProtocol. This class does not override any of TBinaryProtocol's methods, @@ -250,7 +252,6 @@ class TBinaryProtocolAccelerated(TBinaryProtocol): Please feel free to report bugs and/or success stories to the public mailing list. """ - pass diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py index 016a33171..cdec60773 100644 --- a/module/lib/thrift/protocol/TCompactProtocol.py +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -32,6 +32,7 @@ CONTAINER_READ = 6 VALUE_READ = 7 BOOL_READ = 8 + def make_helper(v_from, container): def helper(func): def nested(self, *args, **kwargs): @@ -42,12 +43,15 @@ def make_helper(v_from, container): writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) reader = make_helper(VALUE_READ, CONTAINER_READ) + def makeZigZag(n, bits): return (n << 1) ^ (n >> (bits - 1)) + def fromZigZag(n): return (n >> 1) ^ -(n & 1) + def writeVarint(trans, n): out = [] while True: @@ -59,6 +63,7 @@ def writeVarint(trans, n): n = n >> 7 trans.write(''.join(map(chr, out))) + def readVarint(trans): result = 0 shift = 0 @@ -70,6 +75,7 @@ def readVarint(trans): return result shift += 7 + class CompactType: STOP = 0x00 TRUE = 0x01 @@ -86,7 +92,7 @@ class CompactType: STRUCT = 0x0C CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection + TType.BOOL: CompactType.TRUE, # used for collection TType.BYTE: CompactType.BYTE, TType.I16: CompactType.I16, TType.I32: CompactType.I32, @@ -106,8 +112,9 @@ TTYPES[CompactType.FALSE] = TType.BOOL del k del v + class TCompactProtocol(TProtocolBase): - "Compact implementation of the Thrift protocol driver." + """Compact implementation of the Thrift protocol driver.""" PROTOCOL_ID = 0x82 VERSION = 1 @@ -217,18 +224,18 @@ class TCompactProtocol(TProtocolBase): def writeBool(self, bool): if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) else: - raise AssertionError, "Invalid state in compact protocol" + raise AssertionError("Invalid state in compact protocol") writeByte = writer(__writeByte) writeI16 = writer(__writeI16) @@ -364,7 +371,8 @@ class TCompactProtocol(TProtocolBase): elif self.state == CONTAINER_READ: return self.__readByte() == CompactType.TRUE else: - raise AssertionError, "Invalid state in compact protocol: %d" % self.state + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) readByte = reader(__readByte) __readI16 = __readZigZag diff --git a/module/lib/thrift/protocol/TJSONProtocol.py b/module/lib/thrift/protocol/TJSONProtocol.py new file mode 100644 index 000000000..3048197d4 --- /dev/null +++ b/module/lib/thrift/protocol/TJSONProtocol.py @@ -0,0 +1,550 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import TType, TProtocolBase, TProtocolException +import base64 +import json +import math + +__all__ = ['TJSONProtocol', + 'TJSONProtocolFactory', + 'TSimpleJSONProtocol', + 'TSimpleJSONProtocolFactory'] + +VERSION = 1 + +COMMA = ',' +COLON = ':' +LBRACE = '{' +RBRACE = '}' +LBRACKET = '[' +RBRACKET = ']' +QUOTE = '"' +BACKSLASH = '\\' +ZERO = '0' + +ESCSEQ = '\\u00' +ESCAPE_CHAR = '"\\bfnrt' +ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t'] +NUMERIC_CHAR = '+-.0123456789Ee' + +CTYPES = {TType.BOOL: 'tf', + TType.BYTE: 'i8', + TType.I16: 'i16', + TType.I32: 'i32', + TType.I64: 'i64', + TType.DOUBLE: 'dbl', + TType.STRING: 'str', + TType.STRUCT: 'rec', + TType.LIST: 'lst', + TType.SET: 'set', + TType.MAP: 'map'} + +JTYPES = {} +for key in CTYPES.keys(): + JTYPES[CTYPES[key]] = key + + +class JSONBaseContext(object): + + def __init__(self, protocol): + self.protocol = protocol + self.first = True + + def doIO(self, function): + pass + + def write(self): + pass + + def read(self): + pass + + def escapeNum(self): + return False + + def __str__(self): + return self.__class__.__name__ + + +class JSONListContext(JSONBaseContext): + + def doIO(self, function): + if self.first is True: + self.first = False + else: + function(COMMA) + + def write(self): + self.doIO(self.protocol.trans.write) + + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) + + +class JSONPairContext(JSONBaseContext): + + def __init__(self, protocol): + super(JSONPairContext, self).__init__(protocol) + self.colon = True + + def doIO(self, function): + if self.first: + self.first = False + self.colon = True + else: + function(COLON if self.colon else COMMA) + self.colon = not self.colon + + def write(self): + self.doIO(self.protocol.trans.write) + + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) + + def escapeNum(self): + return self.colon + + def __str__(self): + return '%s, colon=%s' % (self.__class__.__name__, self.colon) + + +class LookaheadReader(): + hasData = False + data = '' + + def __init__(self, protocol): + self.protocol = protocol + + def read(self): + if self.hasData is True: + self.hasData = False + else: + self.data = self.protocol.trans.read(1) + return self.data + + def peek(self): + if self.hasData is False: + self.data = self.protocol.trans.read(1) + self.hasData = True + return self.data + +class TJSONProtocolBase(TProtocolBase): + + def __init__(self, trans): + TProtocolBase.__init__(self, trans) + self.resetWriteContext() + self.resetReadContext() + + def resetWriteContext(self): + self.context = JSONBaseContext(self) + self.contextStack = [self.context] + + def resetReadContext(self): + self.resetWriteContext() + self.reader = LookaheadReader(self) + + def pushContext(self, ctx): + self.contextStack.append(ctx) + self.context = ctx + + def popContext(self): + self.contextStack.pop() + if self.contextStack: + self.context = self.contextStack[-1] + else: + self.context = JSONBaseContext(self) + + def writeJSONString(self, string): + self.context.write() + self.trans.write(json.dumps(string)) + + def writeJSONNumber(self, number): + self.context.write() + jsNumber = str(number) + if self.context.escapeNum(): + jsNumber = "%s%s%s" % (QUOTE, jsNumber, QUOTE) + self.trans.write(jsNumber) + + def writeJSONBase64(self, binary): + self.context.write() + self.trans.write(QUOTE) + self.trans.write(base64.b64encode(binary)) + self.trans.write(QUOTE) + + def writeJSONObjectStart(self): + self.context.write() + self.trans.write(LBRACE) + self.pushContext(JSONPairContext(self)) + + def writeJSONObjectEnd(self): + self.popContext() + self.trans.write(RBRACE) + + def writeJSONArrayStart(self): + self.context.write() + self.trans.write(LBRACKET) + self.pushContext(JSONListContext(self)) + + def writeJSONArrayEnd(self): + self.popContext() + self.trans.write(RBRACKET) + + def readJSONSyntaxChar(self, character): + current = self.reader.read() + if character != current: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Unexpected character: %s" % current) + + def readJSONString(self, skipContext): + string = [] + if skipContext is False: + self.context.read() + self.readJSONSyntaxChar(QUOTE) + while True: + character = self.reader.read() + if character == QUOTE: + break + if character == ESCSEQ[0]: + character = self.reader.read() + if character == ESCSEQ[1]: + self.readJSONSyntaxChar(ZERO) + self.readJSONSyntaxChar(ZERO) + character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2)) + else: + off = ESCAPE_CHAR.find(character) + if off == -1: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Expected control char") + character = ESCAPE_CHAR_VALS[off] + string.append(character) + return ''.join(string) + + def isJSONNumeric(self, character): + return (True if NUMERIC_CHAR.find(character) != - 1 else False) + + def readJSONQuotes(self): + if (self.context.escapeNum()): + self.readJSONSyntaxChar(QUOTE) + + def readJSONNumericChars(self): + numeric = [] + while True: + character = self.reader.peek() + if self.isJSONNumeric(character) is False: + break + numeric.append(self.reader.read()) + return ''.join(numeric) + + def readJSONInteger(self): + self.context.read() + self.readJSONQuotes() + numeric = self.readJSONNumericChars() + self.readJSONQuotes() + try: + return int(numeric) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONDouble(self): + self.context.read() + if self.reader.peek() == QUOTE: + string = self.readJSONString(True) + try: + double = float(string) + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): + raise TProtocolException(TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted") + return double + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + else: + if self.context.escapeNum() is True: + self.readJSONSyntaxChar(QUOTE) + try: + return float(self.readJSONNumericChars()) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONBase64(self): + string = self.readJSONString(False) + return base64.b64decode(string) + + def readJSONObjectStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACE) + self.pushContext(JSONPairContext(self)) + + def readJSONObjectEnd(self): + self.readJSONSyntaxChar(RBRACE) + self.popContext() + + def readJSONArrayStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACKET) + self.pushContext(JSONListContext(self)) + + def readJSONArrayEnd(self): + self.readJSONSyntaxChar(RBRACKET) + self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + + def readMessageBegin(self): + self.resetReadContext() + self.readJSONArrayStart() + if self.readJSONInteger() != VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version.") + name = self.readJSONString(False) + typen = self.readJSONInteger() + seqid = self.readJSONInteger() + return (name, typen, seqid) + + def readMessageEnd(self): + self.readJSONArrayEnd() + + def readStructBegin(self): + self.readJSONObjectStart() + + def readStructEnd(self): + self.readJSONObjectEnd() + + def readFieldBegin(self): + character = self.reader.peek() + ttype = 0 + id = 0 + if character == RBRACE: + ttype = TType.STOP + else: + id = self.readJSONInteger() + self.readJSONObjectStart() + ttype = JTYPES[self.readJSONString(False)] + return (None, ttype, id) + + def readFieldEnd(self): + self.readJSONObjectEnd() + + def readMapBegin(self): + self.readJSONArrayStart() + keyType = JTYPES[self.readJSONString(False)] + valueType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + self.readJSONObjectStart() + return (keyType, valueType, size) + + def readMapEnd(self): + self.readJSONObjectEnd() + self.readJSONArrayEnd() + + def readCollectionBegin(self): + self.readJSONArrayStart() + elemType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + return (elemType, size) + readListBegin = readCollectionBegin + readSetBegin = readCollectionBegin + + def readCollectionEnd(self): + self.readJSONArrayEnd() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + + def readBool(self): + return (False if self.readJSONInteger() == 0 else True) + + def readNumber(self): + return self.readJSONInteger() + readByte = readNumber + readI16 = readNumber + readI32 = readNumber + readI64 = readNumber + + def readDouble(self): + return self.readJSONDouble() + + def readString(self): + return self.readJSONString(False) + + def readBinary(self): + return self.readJSONBase64() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + self.writeJSONArrayStart() + self.writeJSONNumber(VERSION) + self.writeJSONString(name) + self.writeJSONNumber(request_type) + self.writeJSONNumber(seqid) + + def writeMessageEnd(self): + self.writeJSONArrayEnd() + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, id): + self.writeJSONNumber(id) + self.writeJSONObjectStart() + self.writeJSONString(CTYPES[ttype]) + + def writeFieldEnd(self): + self.writeJSONObjectEnd() + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[ktype]) + self.writeJSONString(CTYPES[vtype]) + self.writeJSONNumber(size) + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + self.writeJSONArrayEnd() + + def writeListBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeListEnd(self): + self.writeJSONArrayEnd() + + def writeSetBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeSetEnd(self): + self.writeJSONArrayEnd() + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeInteger(self, integer): + self.writeJSONNumber(integer) + writeByte = writeInteger + writeI16 = writeInteger + writeI32 = writeInteger + writeI64 = writeInteger + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TJSONProtocolFactory: + + def getProtocol(self, trans): + return TJSONProtocol(trans) + + +class TSimpleJSONProtocol(TJSONProtocolBase): + """Simple, readable, write-only JSON protocol. + + Useful for interacting with scripting languages. + """ + + def readMessageBegin(self): + raise NotImplementedError() + + def readMessageEnd(self): + raise NotImplementedError() + + def readStructBegin(self): + raise NotImplementedError() + + def readStructEnd(self): + raise NotImplementedError() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, fid): + self.writeJSONString(name) + + def writeFieldEnd(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + + def _writeCollectionBegin(self, etype, size): + self.writeJSONArrayStart() + + def _writeCollectionEnd(self): + self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin + writeListEnd = _writeCollectionEnd + writeSetBegin = _writeCollectionBegin + writeSetEnd = _writeCollectionEnd + + def writeInteger(self, integer): + self.writeJSONNumber(integer) + writeByte = writeInteger + writeI16 = writeInteger + writeI32 = writeInteger + writeI64 = writeInteger + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(object): + + def getProtocol(self, trans): + return TSimpleJSONProtocol(trans) diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py index 7338ff68a..dc2b095de 100644 --- a/module/lib/thrift/protocol/TProtocol.py +++ b/module/lib/thrift/protocol/TProtocol.py @@ -19,8 +19,8 @@ from thrift.Thrift import * -class TProtocolException(TException): +class TProtocolException(TException): """Custom Protocol Exception class""" UNKNOWN = 0 @@ -33,14 +33,14 @@ class TProtocolException(TException): TException.__init__(self, message) self.type = type -class TProtocolBase: +class TProtocolBase: """Base class for Thrift protocol driver.""" def __init__(self, trans): self.trans = trans - def writeMessageBegin(self, name, type, seqid): + def writeMessageBegin(self, name, ttype, seqid): pass def writeMessageEnd(self): @@ -52,7 +52,7 @@ class TProtocolBase: def writeStructEnd(self): pass - def writeFieldBegin(self, name, type, id): + def writeFieldBegin(self, name, ttype, fid): pass def writeFieldEnd(self): @@ -79,7 +79,7 @@ class TProtocolBase: def writeSetEnd(self): pass - def writeBool(self, bool): + def writeBool(self, bool_val): pass def writeByte(self, byte): @@ -97,7 +97,7 @@ class TProtocolBase: def writeDouble(self, dub): pass - def writeString(self, str): + def writeString(self, str_val): pass def readMessageBegin(self): @@ -157,69 +157,69 @@ class TProtocolBase: def readString(self): pass - def skip(self, type): - if type == TType.STOP: + def skip(self, ttype): + if ttype == TType.STOP: return - elif type == TType.BOOL: + elif ttype == TType.BOOL: self.readBool() - elif type == TType.BYTE: + elif ttype == TType.BYTE: self.readByte() - elif type == TType.I16: + elif ttype == TType.I16: self.readI16() - elif type == TType.I32: + elif ttype == TType.I32: self.readI32() - elif type == TType.I64: + elif ttype == TType.I64: self.readI64() - elif type == TType.DOUBLE: + elif ttype == TType.DOUBLE: self.readDouble() - elif type == TType.STRING: + elif ttype == TType.STRING: self.readString() - elif type == TType.STRUCT: + elif ttype == TType.STRUCT: name = self.readStructBegin() while True: - (name, type, id) = self.readFieldBegin() - if type == TType.STOP: + (name, ttype, id) = self.readFieldBegin() + if ttype == TType.STOP: break - self.skip(type) + self.skip(ttype) self.readFieldEnd() self.readStructEnd() - elif type == TType.MAP: + elif ttype == TType.MAP: (ktype, vtype, size) = self.readMapBegin() - for i in range(size): + for i in xrange(size): self.skip(ktype) self.skip(vtype) self.readMapEnd() - elif type == TType.SET: + elif ttype == TType.SET: (etype, size) = self.readSetBegin() - for i in range(size): + for i in xrange(size): self.skip(etype) self.readSetEnd() - elif type == TType.LIST: + elif ttype == TType.LIST: (etype, size) = self.readListBegin() - for i in range(size): + for i in xrange(size): self.skip(etype) self.readListEnd() - # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) _TTYPE_HANDLERS = ( - (None, None, False), # 0 == TType,STOP - (None, None, False), # 1 == TType.VOID # TODO: handle void? - ('readBool', 'writeBool', False), # 2 == TType.BOOL - ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE - (None, None, False), # 5, undefined - ('readI16', 'writeI16', False), # 6 == TType.I16 - (None, None, False), # 7, undefined - ('readI32', 'writeI32', False), # 8 == TType.I32 - (None, None, False), # 9, undefined - ('readI64', 'writeI64', False), # 10 == TType.I64 - ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET - ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST - (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? - (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? ) def readFieldByTType(self, ttype, spec): @@ -270,7 +270,7 @@ class TProtocolBase: container_reader = self._TTYPE_HANDLERS[set_type][0] val_reader = getattr(self, container_reader) for idx in xrange(set_len): - results.add(val_reader(tspec)) + results.add(val_reader(tspec)) self.readSetEnd() return results @@ -279,13 +279,14 @@ class TProtocolBase: obj = obj_class() obj.read(self) return obj - + def readContainerMap(self, spec): results = dict() key_ttype, key_spec = spec[0], spec[1] val_ttype, val_spec = spec[2], spec[3] (map_ktype, map_vtype, map_len) = self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) # list values are simple types @@ -298,7 +299,8 @@ class TProtocolBase: v_val = val_reader() else: v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails + # this raises a TypeError with unhashable keys types + # i.e. this fails: d=dict(); d[[0,1]] = 2 results[k_val] = v_val self.readMapEnd() return results @@ -329,7 +331,7 @@ class TProtocolBase: def writeContainerList(self, val, spec): self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] e_writer = getattr(self, w_handler) if not is_container: for elem in val: @@ -398,7 +400,7 @@ class TProtocolBase: else: writer(val) + class TProtocolFactory: def getProtocol(self, trans): pass - diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py index d53359b28..7eefb458a 100644 --- a/module/lib/thrift/protocol/__init__.py +++ b/module/lib/thrift/protocol/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol'] diff --git a/module/lib/thrift/protocol/fastbinary.c b/module/lib/thrift/protocol/fastbinary.c new file mode 100644 index 000000000..2ce56603c --- /dev/null +++ b/module/lib/thrift/protocol/fastbinary.c @@ -0,0 +1,1219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdint.h> +#ifndef _WIN32 +# include <stdbool.h> +# include <netinet/in.h> +#else +# include <WinSock2.h> +# pragma comment (lib, "ws2_32.lib") +# define BIG_ENDIAN (4321) +# define LITTLE_ENDIAN (1234) +# define BYTE_ORDER LITTLE_ENDIAN +# if defined(_MSC_VER) && _MSC_VER < 1600 + typedef int _Bool; +# define bool _Bool +# define false 0 +# define true 1 +# endif +# define inline __inline +#endif + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) + #define __i386__ + #endif + + #ifndef BIG_ENDIAN + #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN + #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) + #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType element_type; + PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + PyObject* klass; + PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { + PyObject* stringiobuf; + PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, INT32_MAX)) { + PyErr_SetString(PyExc_OverflowError, "string size out of range"); + return false; + } + return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = (int32_t) val; + return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + return false; + } + + dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 4) { + PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + return false; + } + + dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + + return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); + return false; + } + + dest->klass = PyTuple_GET_ITEM(typeargs, 0); + dest->spec = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + return false; + } + + dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { + int8_t net = val; + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { + int16_t net = (int16_t)htons(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { + int32_t net = (int32_t)htonl(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { + int64_t net = (int64_t)htonll(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * output_val assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + + writeByte(output, (int8_t) v); + break; + } + case T_I08: { + int32_t val; + + if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { + return false; + } + + writeByte(output, (int8_t) val); + break; + } + case T_I16: { + int32_t val; + + if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { + return false; + } + + writeI16(output, (int16_t) val); + break; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { + return false; + } + + writeI32(output, val); + break; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + writeI64(output, nval); + break; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + writeDouble(output, nval); + break; + } + + case T_STRING: { + Py_ssize_t len = PyString_Size(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeI32(output, (int32_t) len); + PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); + break; + } + + case T_LIST: + case T_SET: { + Py_ssize_t len; + SetListTypeArgs parsedargs; + PyObject *item; + PyObject *iterator; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + len = PyObject_Length(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeByte(output, parsedargs.element_type); + writeI32(output, (int32_t) len); + + iterator = PyObject_GetIter(value); + if (iterator == NULL) { + return false; + } + + while ((item = PyIter_Next(iterator))) { + if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { + Py_DECREF(item); + Py_DECREF(iterator); + return false; + } + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + return false; + } + + break; + } + + case T_MAP: { + PyObject *k, *v; + Py_ssize_t pos = 0; + Py_ssize_t len; + + MapTypeArgs parsedargs; + + len = PyDict_Size(value); + if (!check_ssize_t_32(len)) { + return false; + } + + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + writeByte(output, parsedargs.ktag); + writeByte(output, parsedargs.vtag); + writeI32(output, len); + + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + // TODO(dreiss): Think hard about whether these INCREFs actually + // turn any unsafe scenarios into safe scenarios. + Py_INCREF(k); + Py_INCREF(v); + + if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) + || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { + Py_DECREF(k); + Py_DECREF(v); + return false; + } + Py_DECREF(k); + Py_DECREF(v); + } + break; + } + + // TODO(dreiss): Consider breaking this out as a function + // the way we did for decode_struct. + case T_STRUCT: { + StructTypeArgs parsedargs; + Py_ssize_t nspec; + Py_ssize_t i; + + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + nspec = PyTuple_Size(parsedargs.spec); + + if (nspec == -1) { + return false; + } + + for (i = 0; i < nspec; i++) { + StructItemSpec parsedspec; + PyObject* spec_tuple; + PyObject* instval = NULL; + + spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { + return false; + } + + instval = PyObject_GetAttr(value, parsedspec.attrname); + + if (!instval) { + return false; + } + + if (instval == Py_None) { + Py_DECREF(instval); + continue; + } + + writeByte(output, (int8_t) parsedspec.type); + writeI16(output, parsedspec.tag); + + if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { + Py_DECREF(instval); + return false; + } + + Py_DECREF(instval); + } + + writeByte(output, (int8_t)T_STOP); + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { + PyObject* enc_obj; + PyObject* type_args; + PyObject* buf; + PyObject* ret = NULL; + + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + + buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); + if (output_val(buf, enc_obj, T_STRUCT, type_args)) { + ret = PycStringIO->cgetvalue(buf); + } + + Py_DECREF(buf); + return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { + Py_XDECREF(d->stringiobuf); + Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { + dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); + if (!dest->stringiobuf) { + return false; + } + + if (!PycStringIO_InputCheck(dest->stringiobuf)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting stringio input"); + return false; + } + + dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + + if(!dest->refill_callable) { + free_decodebuf(dest); + return false; + } + + if (!PyCallable_Check(dest->refill_callable)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting callable"); + return false; + } + + return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { + int read; + + // TODO(dreiss): Don't fear the malloc. Think about taking a copy of + // the partial read instead of forcing the transport + // to prepend it to its buffer. + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + PyObject* newiobuf; + + // using building functions as this is a rare codepath + newiobuf = PyObject_CallFunction( + input->refill_callable, "s#i", *output, read, len, NULL); + if (newiobuf == NULL) { + return false; + } + + // must do this *AFTER* the call so that we don't deref the io buffer + Py_CLEAR(input->stringiobuf); + input->stringiobuf = newiobuf; + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + // TODO(dreiss): This could be a valid code path for big binary blobs. + PyErr_SetString(PyExc_TypeError, + "refill claimed to have refilled the buffer, but didn't!!"); + return false; + } + } +} + +static int8_t readByte(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int8_t))) { + return -1; + } + + return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int16_t))) { + return -1; + } + + return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int32_t))) { + return -1; + } + return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int64_t))) { + return -1; + } + + return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { + union { + int64_t f; + double t; + } transfer; + + transfer.f = readI64(input); + if (transfer.f == -1) { + return -1; + } + return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { + TType got = readByte(input); + if (INT_CONV_ERROR_OCCURRED(got)) { + return false; + } + + if (expected != got) { + PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); + return false; + } + return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(input, &dummy_buf, (n))) { \ + return false; \ + } \ + } while(0) + + char* dummy_buf; + + switch (type) { + + case T_BOOL: + case T_I08: SKIPBYTES(1); break; + case T_I16: SKIPBYTES(2); break; + case T_I32: SKIPBYTES(4); break; + case T_I64: + case T_DOUBLE: SKIPBYTES(8); break; + + case T_STRING: { + // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. + int len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + SKIPBYTES(len); + break; + } + + case T_LIST: + case T_SET: { + TType etype; + int len, i; + + etype = readByte(input); + if (etype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!skip(input, etype)) { + return false; + } + } + break; + } + + case T_MAP: { + TType ktype, vtype; + int len, i; + + ktype = readByte(input); + if (ktype == -1) { + return false; + } + + vtype = readByte(input); + if (vtype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!(skip(input, ktype) && skip(input, vtype))) { + return false; + } + } + break; + } + + case T_STRUCT: { + while (true) { + TType type; + + type = readByte(input); + if (type == -1) { + return false; + } + + if (type == T_STOP) + break; + + SKIPBYTES(2); // tag + if (!skip(input, type)) { + return false; + } + } + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { + int spec_seq_len = PyTuple_Size(spec_seq); + if (spec_seq_len == -1) { + return false; + } + + while (true) { + TType type; + int16_t tag; + PyObject* item_spec; + PyObject* fieldval = NULL; + StructItemSpec parsedspec; + + type = readByte(input); + if (type == -1) { + return false; + } + if (type == T_STOP) { + break; + } + tag = readI16(input); + if (INT_CONV_ERROR_OCCURRED(tag)) { + return false; + } + if (tag >= 0 && tag < spec_seq_len) { + item_spec = PyTuple_GET_ITEM(spec_seq, tag); + } else { + item_spec = Py_None; + } + + if (item_spec == Py_None) { + if (!skip(input, type)) { + return false; + } else { + continue; + } + } + + if (!parse_struct_item_spec(&parsedspec, item_spec)) { + return false; + } + if (parsedspec.type != type) { + if (!skip(input, type)) { + PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); + return false; + } else { + continue; + } + } + + fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); + if (fieldval == NULL) { + return false; + } + + if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { + Py_DECREF(fieldval); + return false; + } + Py_DECREF(fieldval); + } + return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { + switch (type) { + + case T_BOOL: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + switch (v) { + case 0: Py_RETURN_FALSE; + case 1: Py_RETURN_TRUE; + // Don't laugh. This is a potentially serious issue. + default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; + } + break; + } + case T_I08: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + return PyInt_FromLong(v); + } + case T_I16: { + int16_t v = readI16(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I32: { + int32_t v = readI32(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + + case T_I64: { + int64_t v = readI64(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + // TODO(dreiss): Find out if we can take this fastpath always when + // sizeof(long) == sizeof(long long). + if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { + return PyInt_FromLong((long) v); + } + + return PyLong_FromLongLong(v); + } + + case T_DOUBLE: { + double v = readDouble(input); + if (v == -1.0 && PyErr_Occurred()) { + return false; + } + return PyFloat_FromDouble(v); + } + + case T_STRING: { + Py_ssize_t len = readI32(input); + char* buf; + if (!readBytes(input, &buf, len)) { + return NULL; + } + + return PyString_FromStringAndSize(buf, len); + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + int32_t len; + PyObject* ret = NULL; + int i; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.element_type)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return NULL; + } + + ret = PyList_New(len); + if (!ret) { + return NULL; + } + + for (i = 0; i < len; i++) { + PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); + if (!item) { + Py_DECREF(ret); + return NULL; + } + PyList_SET_ITEM(ret, i, item); + } + + // TODO(dreiss): Consider biting the bullet and making two separate cases + // for list and set, avoiding this post facto conversion. + if (type == T_SET) { + PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) + // hack needed for older versions + setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else + // official version + setret = PySet_New(ret); +#endif + Py_DECREF(ret); + return setret; + } + return ret; + } + + case T_MAP: { + int32_t len; + int i; + MapTypeArgs parsedargs; + PyObject* ret = NULL; + + if (!parse_map_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.ktag)) { + return NULL; + } + if (!checkTypeByte(input, parsedargs.vtag)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + ret = PyDict_New(); + if (!ret) { + goto error; + } + + for (i = 0; i < len; i++) { + PyObject* k = NULL; + PyObject* v = NULL; + k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); + if (k == NULL) { + goto loop_error; + } + v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); + if (v == NULL) { + goto loop_error; + } + if (PyDict_SetItem(ret, k, v) == -1) { + goto loop_error; + } + + Py_DECREF(k); + Py_DECREF(v); + continue; + + // Yuck! Destructors, anyone? + loop_error: + Py_XDECREF(k); + Py_XDECREF(v); + goto error; + } + + return ret; + + error: + Py_XDECREF(ret); + return NULL; + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + PyObject* ret; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + ret = PyObject_CallObject(parsedargs.klass, NULL); + if (!ret) { + return NULL; + } + + if (!decode_struct(input, ret, parsedargs.spec)) { + Py_DECREF(ret); + return NULL; + } + + return ret; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return NULL; + } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { + PyObject* output_obj = NULL; + PyObject* transport = NULL; + PyObject* typeargs = NULL; + StructTypeArgs parsedargs; + DecodeBuffer input = {0, 0}; + + if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { + return NULL; + } + + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!decode_buffer_from_obj(&input, transport)) { + return NULL; + } + + if (!decode_struct(&input, output_obj, parsedargs.spec)) { + free_decodebuf(&input); + return NULL; + } + + free_decodebuf(&input); + + Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + + {"encode_binary", encode_binary, METH_VARARGS, ""}, + {"decode_binary", decode_binary, METH_VARARGS, ""}, + + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ + do { \ + INTERN_STRING(value) = PyString_InternFromString(#value); \ + if(!INTERN_STRING(value)) return; \ + } while(0) + + INIT_INTERN_STRING(cstringio_buf); + INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + + PycString_IMPORT; + if (PycStringIO == NULL) return; + + (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/module/lib/thrift/server/THttpServer.py b/module/lib/thrift/server/THttpServer.py index 3047d9c00..be54bab94 100644 --- a/module/lib/thrift/server/THttpServer.py +++ b/module/lib/thrift/server/THttpServer.py @@ -22,6 +22,7 @@ import BaseHTTPServer from thrift.server import TServer from thrift.transport import TTransport + class ResponseException(Exception): """Allows handlers to override the HTTP response @@ -39,16 +40,19 @@ class THttpServer(TServer.TServer): """A simple HTTP-based Thrift server This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint.""" - - def __init__(self, processor, server_address, - inputProtocolFactory, outputProtocolFactory = None, - server_class = BaseHTTPServer.HTTPServer): + acting as a mock version of an Apache-based PHP Thrift endpoint. + """ + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer): """Set up protocol factories and HTTP server. See BaseHTTPServer for server_address. - See TServer for protocol factories.""" - + See TServer for protocol factories. + """ if outputProtocolFactory is None: outputProtocolFactory = inputProtocolFactory @@ -62,7 +66,8 @@ class THttpServer(TServer.TServer): # Don't care about the request path. itrans = TTransport.TFileObjectTransport(self.rfile) otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) + itrans = TTransport.TBufferedTransport( + itrans, int(self.headers['Content-Length'])) otrans = TTransport.TMemoryBuffer() iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) diff --git a/module/lib/thrift/server/TNonblockingServer.py b/module/lib/thrift/server/TNonblockingServer.py index ea348a0b6..fa478d01f 100644 --- a/module/lib/thrift/server/TNonblockingServer.py +++ b/module/lib/thrift/server/TNonblockingServer.py @@ -18,10 +18,11 @@ # """Implementation of non-blocking server. -The main idea of the server is reciving and sending requests -only from main thread. +The main idea of the server is to receive and send requests +only from the main thread. -It also makes thread pool server in tasks terms, not connections. +The thread poool should be sized for concurrent tasks, not +maximum connections """ import threading import socket @@ -35,8 +36,10 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory __all__ = ['TNonblockingServer'] + class Worker(threading.Thread): """Worker is a small helper to process incoming connection.""" + def __init__(self, queue): threading.Thread.__init__(self) self.queue = queue @@ -60,8 +63,9 @@ WAIT_PROCESS = 2 SEND_ANSWER = 3 CLOSED = 4 + def locked(func): - "Decorator which locks self.lock." + """Decorator which locks self.lock.""" def nested(self, *args, **kwargs): self.lock.acquire() try: @@ -70,8 +74,9 @@ def locked(func): self.lock.release() return nested + def socket_exception(func): - "Decorator close object on socket.error." + """Decorator close object on socket.error.""" def read(self, *args, **kwargs): try: return func(self, *args, **kwargs) @@ -79,16 +84,17 @@ def socket_exception(func): self.close() return read + class Connection: """Basic class is represented connection. - + It can be in state: WAIT_LEN --- connection is reading request len. WAIT_MESSAGE --- connection is reading request. - WAIT_PROCESS --- connection has just read whole request and - waits for call ready routine. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. SEND_ANSWER --- connection is sending answer string (including length - of answer). + of answer). CLOSED --- socket was closed and connection should be deleted. """ def __init__(self, new_socket, wake_up): @@ -102,13 +108,13 @@ class Connection: def _read_len(self): """Reads length of request. - - It's really paranoic routine and it may be replaced by - self.socket.recv(4).""" + + It's a safer alternative to self.socket.recv(4) + """ read = self.socket.recv(4 - len(self.message)) if len(read) == 0: - # if we read 0 bytes and self.message is empty, it means client close - # connection + # if we read 0 bytes and self.message is empty, then + # the client closed the connection if len(self.message) != 0: logging.error("can't read frame size from socket") self.close() @@ -117,8 +123,8 @@ class Connection: if len(self.message) == 4: self.len, = struct.unpack('!i', self.message) if self.len < 0: - logging.error("negative frame size, it seems client"\ - " doesn't use FramedTransport") + logging.error("negative frame size, it seems client " + "doesn't use FramedTransport") self.close() elif self.len == 0: logging.error("empty frame, it's really strange") @@ -139,8 +145,8 @@ class Connection: elif self.status == WAIT_MESSAGE: read = self.socket.recv(self.len - len(self.message)) if len(read) == 0: - logging.error("can't read frame from socket (get %d of %d bytes)" % - (len(self.message), self.len)) + logging.error("can't read frame from socket (get %d of " + "%d bytes)" % (len(self.message), self.len)) self.close() return self.message += read @@ -162,14 +168,14 @@ class Connection: @locked def ready(self, all_ok, message): """Callback function for switching state and waking up main thread. - + This function is the only function witch can be called asynchronous. - + The ready can switch Connection to three states: WAIT_LEN if request was oneway. SEND_ANSWER if request was processed in normal way. CLOSED if request throws unexpected exception. - + The one wakes up main thread. """ assert self.status == WAIT_PROCESS @@ -189,33 +195,39 @@ class Connection: @locked def is_writeable(self): - "Returns True if connection should be added to write list of select." + """Return True if connection should be added to write list of select""" return self.status == SEND_ANSWER # it's not necessary, but... @locked def is_readable(self): - "Returns True if connection should be added to read list of select." + """Return True if connection should be added to read list of select""" return self.status in (WAIT_LEN, WAIT_MESSAGE) @locked def is_closed(self): - "Returns True if connection is closed." + """Returns True if connection is closed.""" return self.status == CLOSED def fileno(self): - "Returns the file descriptor of the associated socket." + """Returns the file descriptor of the associated socket.""" return self.socket.fileno() def close(self): - "Closes connection" + """Closes connection""" self.status = CLOSED self.socket.close() + class TNonblockingServer: """Non-blocking server.""" - def __init__(self, processor, lsocket, inputProtocolFactory=None, - outputProtocolFactory=None, threads=10): + + def __init__(self, + processor, + lsocket, + inputProtocolFactory=None, + outputProtocolFactory=None, + threads=10): self.processor = processor self.socket = lsocket self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() @@ -225,15 +237,18 @@ class TNonblockingServer: self.tasks = Queue.Queue() self._read, self._write = socket.socketpair() self.prepared = False + self._stop = False def setNumThreads(self, num): """Set the number of worker threads that should be created.""" # implement ThreadPool interface - assert not self.prepared, "You can't change number of threads for working server" + assert not self.prepared, "Can't change number of threads after start" self.threads = num def prepare(self): """Prepares server for serve requests.""" + if self.prepared: + return self.socket.listen() for _ in xrange(self.threads): thread = Worker(self.tasks) @@ -243,16 +258,32 @@ class TNonblockingServer: def wake_up(self): """Wake up main thread. - + The server usualy waits in select call in we should terminate one. The simplest way is using socketpair. - + Select always wait to read from the first socket of socketpair. - + In this case, we can just write anything to the second socket from - socketpair.""" + socketpair. + """ self._write.send('1') + def stop(self): + """Stop the server. + + This method causes the serve() method to return. stop() may be invoked + from within your handler, or from another thread. + + After stop() is called, serve() will return but the server will still + be listening on the socket. serve() may then be called again to resume + processing requests. Alternatively, close() may be called after + serve() returns to close the server socket and shutdown all worker + threads. + """ + self._stop = True + self.wake_up() + def _select(self): """Does select on open connections.""" readable = [self.socket.handle.fileno(), self._read.fileno()] @@ -265,21 +296,22 @@ class TNonblockingServer: if connection.is_closed(): del self.clients[i] return select.select(readable, writable, readable) - + def handle(self): """Handle requests. - - WARNING! You must call prepare BEFORE calling handle. + + WARNING! You must call prepare() BEFORE calling handle() """ assert self.prepared, "You have to call prepare before handle" rset, wset, xset = self._select() for readable in rset: if readable == self._read.fileno(): # don't care i just need to clean readable flag - self._read.recv(1024) + self._read.recv(1024) elif readable == self.socket.handle.fileno(): client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, self.wake_up) + self.clients[client.fileno()] = Connection(client, + self.wake_up) else: connection = self.clients[readable] connection.read() @@ -288,7 +320,7 @@ class TNonblockingServer: otransport = TTransport.TMemoryBuffer() iprot = self.in_protocol.getProtocol(itransport) oprot = self.out_protocol.getProtocol(otransport) - self.tasks.put([self.processor, iprot, oprot, + self.tasks.put([self.processor, iprot, oprot, otransport, connection.ready]) for writeable in wset: self.clients[writeable].write() @@ -302,9 +334,13 @@ class TNonblockingServer: self.tasks.put([None, None, None, None, None]) self.socket.close() self.prepared = False - + def serve(self): - """Serve forever.""" + """Serve requests. + + Serve requests forever, or until stop() is called. + """ + self._stop = False self.prepare() - while True: + while not self._stop: self.handle() diff --git a/module/lib/thrift/server/TProcessPoolServer.py b/module/lib/thrift/server/TProcessPoolServer.py index 7ed814a88..7a695a883 100644 --- a/module/lib/thrift/server/TProcessPoolServer.py +++ b/module/lib/thrift/server/TProcessPoolServer.py @@ -24,15 +24,14 @@ from multiprocessing import Process, Value, Condition, reduction from TServer import TServer from thrift.transport.TTransport import TTransportException + class TProcessPoolServer(TServer): + """Server with a fixed size pool of worker subprocesses to service requests - """ - Server with a fixed size pool of worker subprocesses which service requests. Note that if you need shared state between the handlers - it's up to you! Written by Dvir Volk, doat.com """ - - def __init__(self, * args): + def __init__(self, *args): TServer.__init__(self, *args) self.numWorkers = 10 self.workers = [] @@ -50,12 +49,11 @@ class TProcessPoolServer(TServer): self.numWorkers = num def workerProcess(self): - """Loop around getting clients from the shared queue and process them.""" - + """Loop getting clients from the shared queue and process them""" if self.postForkCallback: self.postForkCallback() - while self.isRunning.value == True: + while self.isRunning.value: try: client = self.serverTransport.accept() self.serveClient(client) @@ -82,17 +80,15 @@ class TProcessPoolServer(TServer): itrans.close() otrans.close() - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - - #this is a shared state that can tell the workers to exit when set as false + """Start workers and put into queue""" + # this is a shared state that can tell the workers to exit when False self.isRunning.value = True - #first bind and listen to the port + # first bind and listen to the port self.serverTransport.listen() - #fork the children + # fork the children for i in range(self.numWorkers): try: w = Process(target=self.workerProcess) @@ -102,16 +98,14 @@ class TProcessPoolServer(TServer): except Exception, x: logging.exception(x) - #wait until the condition is set by stop() - + # wait until the condition is set by stop() while True: - self.stopCondition.acquire() try: self.stopCondition.wait() break except (SystemExit, KeyboardInterrupt): - break + break except Exception, x: logging.exception(x) @@ -122,4 +116,3 @@ class TProcessPoolServer(TServer): self.stopCondition.acquire() self.stopCondition.notify() self.stopCondition.release() - diff --git a/module/lib/thrift/server/TServer.py b/module/lib/thrift/server/TServer.py index 8456e2d40..2f24842c4 100644 --- a/module/lib/thrift/server/TServer.py +++ b/module/lib/thrift/server/TServer.py @@ -17,27 +17,28 @@ # under the License. # +import Queue import logging -import sys import os -import traceback +import sys import threading -import Queue +import traceback from thrift.Thrift import TProcessor -from thrift.transport import TTransport from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport -class TServer: - """Base interface for a server, which must have a serve method.""" +class TServer: + """Base interface for a server, which must have a serve() method. - """ 3 constructors for all servers: + Three constructors for all servers: 1) (processor, serverTransport) 2) (processor, serverTransport, transportFactory, protocolFactory) 3) (processor, serverTransport, inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory)""" + inputProtocolFactory, outputProtocolFactory) + """ def __init__(self, *args): if (len(args) == 2): self.__initArgs__(args[0], args[1], @@ -63,8 +64,8 @@ class TServer: def serve(self): pass -class TSimpleServer(TServer): +class TSimpleServer(TServer): """Simple single-threaded server that just pumps around one transport.""" def __init__(self, *args): @@ -89,8 +90,8 @@ class TSimpleServer(TServer): itrans.close() otrans.close() -class TThreadedServer(TServer): +class TThreadedServer(TServer): """Threaded server that spawns a new thread per each connection.""" def __init__(self, *args, **kwargs): @@ -102,7 +103,7 @@ class TThreadedServer(TServer): while True: try: client = self.serverTransport.accept() - t = threading.Thread(target = self.handle, args=(client,)) + t = threading.Thread(target=self.handle, args=(client,)) t.setDaemon(self.daemon) t.start() except KeyboardInterrupt: @@ -126,8 +127,8 @@ class TThreadedServer(TServer): itrans.close() otrans.close() -class TThreadPoolServer(TServer): +class TThreadPoolServer(TServer): """Server with a fixed size pool of threads which service requests.""" def __init__(self, *args, **kwargs): @@ -170,7 +171,7 @@ class TThreadPoolServer(TServer): """Start a fixed number of worker threads and put client into a queue""" for i in range(self.threads): try: - t = threading.Thread(target = self.serveThread) + t = threading.Thread(target=self.serveThread) t.setDaemon(self.daemon) t.start() except Exception, x: @@ -187,9 +188,8 @@ class TThreadPoolServer(TServer): class TForkingServer(TServer): + """A Thrift server that forks a new process for each request - """A Thrift server that forks a new process for each request""" - """ This is more scalable than the threaded server as it does not cause GIL contention. @@ -200,7 +200,6 @@ class TForkingServer(TServer): This code is heavily inspired by SocketServer.ForkingMixIn in the Python stdlib. """ - def __init__(self, *args): TServer.__init__(self, *args) self.children = [] @@ -212,14 +211,13 @@ class TForkingServer(TServer): except IOError, e: logging.warning(e, exc_info=True) - self.serverTransport.listen() while True: client = self.serverTransport.accept() try: pid = os.fork() - if pid: # parent + if pid: # parent # add before collect, otherwise you race w/ waitpid self.children.append(pid) self.collect_children() @@ -258,7 +256,6 @@ class TForkingServer(TServer): except Exception, x: logging.exception(x) - def collect_children(self): while self.children: try: @@ -270,5 +267,3 @@ class TForkingServer(TServer): self.children.remove(pid) else: break - - diff --git a/module/lib/thrift/transport/THttpClient.py b/module/lib/thrift/transport/THttpClient.py index 50269785c..ea80a1ae8 100644 --- a/module/lib/thrift/transport/THttpClient.py +++ b/module/lib/thrift/transport/THttpClient.py @@ -17,16 +17,20 @@ # under the License. # -from TTransport import * -from cStringIO import StringIO - -import urlparse import httplib -import warnings +import os import socket +import sys +import urllib +import urlparse +import warnings -class THttpClient(TTransportBase): +from cStringIO import StringIO +from TTransport import * + + +class THttpClient(TTransportBase): """Http implementation of TTransport base.""" def __init__(self, uri_or_host, port=None, path=None): @@ -35,10 +39,13 @@ class THttpClient(TTransportBase): THttpClient(host, port, path) - deprecated THttpClient(uri) - Only the second supports https.""" - + Only the second supports https. + """ if port is not None: - warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) + warnings.warn( + "Please use the THttpClient('http://host:port/path') syntax", + DeprecationWarning, + stacklevel=2) self.host = uri_or_host self.port = port assert path @@ -59,6 +66,7 @@ class THttpClient(TTransportBase): self.__wbuf = StringIO() self.__http = None self.__timeout = None + self.__custom_headers = None def open(self): if self.scheme == 'http': @@ -71,7 +79,7 @@ class THttpClient(TTransportBase): self.__http = None def isOpen(self): - return self.__http != None + return self.__http is not None def setTimeout(self, ms): if not hasattr(socket, 'getdefaulttimeout'): @@ -80,7 +88,10 @@ class THttpClient(TTransportBase): if ms is None: self.__timeout = None else: - self.__timeout = ms/1000.0 + self.__timeout = ms / 1000.0 + + def setCustomHeaders(self, headers): + self.__custom_headers = headers def read(self, sz): return self.__http.file.read(sz) @@ -100,7 +111,7 @@ class THttpClient(TTransportBase): def flush(self): if self.isOpen(): self.close() - self.open(); + self.open() # Pull data out of buffer data = self.__wbuf.getvalue() @@ -113,6 +124,18 @@ class THttpClient(TTransportBase): self.__http.putheader('Host', self.host) self.__http.putheader('Content-Type', 'application/x-thrift') self.__http.putheader('Content-Length', str(len(data))) + + if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: + user_agent = 'Python/THttpClient' + script = os.path.basename(sys.argv[0]) + if script: + user_agent = '%s (%s)' % (user_agent, urllib.quote(script)) + self.__http.putheader('User-Agent', user_agent) + + if self.__custom_headers: + for key, val in self.__custom_headers.iteritems(): + self.__http.putheader(key, val) + self.__http.endheaders() # Write payload diff --git a/module/lib/thrift/transport/TSSLSocket.py b/module/lib/thrift/transport/TSSLSocket.py new file mode 100644 index 000000000..81e098426 --- /dev/null +++ b/module/lib/thrift/transport/TSSLSocket.py @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +import socket +import ssl + +from thrift.transport import TSocket +from thrift.transport.TTransport import TTransportException + + +class TSSLSocket(TSocket.TSocket): + """ + SSL implementation of client-side TSocket + + This class creates outbound sockets wrapped using the + python standard ssl module for encrypted connections. + + The protocol used is set using the class variable + SSL_VERSION, which must be one of ssl.PROTOCOL_* and + defaults to ssl.PROTOCOL_TLSv1 for greatest security. + """ + SSL_VERSION = ssl.PROTOCOL_TLSv1 + + def __init__(self, + host='localhost', + port=9090, + validate=True, + ca_certs=None, + keyfile=None, + certfile=None, + unix_socket=None): + """Create SSL TSocket + + @param validate: Set to False to disable SSL certificate validation + @type validate: bool + @param ca_certs: Filename to the Certificate Authority pem file, possibly a + file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to + the ssl_wrap function as the 'ca_certs' parameter. + @type ca_certs: str + @param keyfile: The private key + @type keyfile: str + @param certfile: The cert file + @type certfile: str + + Raises an IOError exception if validate is True and the ca_certs file is + None, not present or unreadable. + """ + self.validate = validate + self.is_valid = False + self.peercert = None + if not validate: + self.cert_reqs = ssl.CERT_NONE + else: + self.cert_reqs = ssl.CERT_REQUIRED + self.ca_certs = ca_certs + self.keyfile = keyfile + self.certfile = certfile + if validate: + if ca_certs is None or not os.access(ca_certs, os.R_OK): + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (ca_certs)) + TSocket.TSocket.__init__(self, host, port, unix_socket) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + sock_family, sock_type = res[0:2] + ip_port = res[4] + plain_sock = socket.socket(sock_family, sock_type) + self.handle = ssl.wrap_socket(plain_sock, + ssl_version=self.SSL_VERSION, + do_handshake_on_connect=True, + ca_certs=self.ca_certs, + keyfile=self.keyfile, + certfile=self.certfile, + cert_reqs=self.cert_reqs) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(ip_port) + except socket.error, e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error, e: + if self._unix_socket: + message = 'Could not connect to secure socket %s: %s' \ + % (self._unix_socket, e) + else: + message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) + if self.validate: + self._validate_cert() + + def _validate_cert(self): + """internal method to validate the peer's SSL certificate, and to check the + commonName of the certificate to ensure it matches the hostname we + used to make this connection. Does not support subjectAltName records + in certificates. + + raises TTransportException if the certificate fails validation. + """ + cert = self.handle.getpeercert() + self.peercert = cert + if 'subject' not in cert: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message='No SSL certificate found from %s:%s' % (self.host, self.port)) + fields = cert['subject'] + for field in fields: + # ensure structure we get back is what we expect + if not isinstance(field, tuple): + continue + cert_pair = field[0] + if len(cert_pair) < 2: + continue + cert_key, cert_value = cert_pair[0:2] + if cert_key != 'commonName': + continue + certhost = cert_value + # this check should be performed by some sort of Access Manager + if certhost == self.host: + # success, cert commonName matches desired hostname + self.is_valid = True + return + else: + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (self.host, certhost)) + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Could not validate SSL certificate from ' + 'host "%s". Cert=%s' % (self.host, cert)) + + +class TSSLServerSocket(TSocket.TServerSocket): + """SSL implementation of TServerSocket + + This uses the ssl module's wrap_socket() method to provide SSL + negotiated encryption. + """ + SSL_VERSION = ssl.PROTOCOL_TLSv1 + + def __init__(self, + host=None, + port=9090, + certfile='cert.pem', + unix_socket=None): + """Initialize a TSSLServerSocket + + @param certfile: filename of the server certificate, defaults to cert.pem + @type certfile: str + @param host: The hostname or IP to bind the listen socket to, + i.e. 'localhost' for only allowing local network connections. + Pass None to bind to all interfaces. + @type host: str + @param port: The port to listen on for inbound connections. + @type port: int + """ + self.setCertfile(certfile) + TSocket.TServerSocket.__init__(self, host, port) + + def setCertfile(self, certfile): + """Set or change the server certificate file used to wrap new connections. + + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' + @type certfile: str + + Raises an IOError exception if the certfile is not present or unreadable. + """ + if not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % (certfile)) + self.certfile = certfile + + def accept(self): + plain_client, addr = self.handle.accept() + try: + client = ssl.wrap_socket(plain_client, certfile=self.certfile, + server_side=True, ssl_version=self.SSL_VERSION) + except ssl.SSLError, ssl_exc: + # failed handshake/ssl wrap, close socket to client + plain_client.close() + # raise ssl_exc + # We can't raise the exception, because it kills most TServer derived + # serve() methods. + # Instead, return None, and let the TServer instance deal with it in + # other exception handling. (but TSimpleServer dies anyway) + return None + result = TSocket.TSocket() + result.setHandle(client) + return result diff --git a/module/lib/thrift/transport/TSocket.py b/module/lib/thrift/transport/TSocket.py index 4e0e1874f..9e2b3849b 100644 --- a/module/lib/thrift/transport/TSocket.py +++ b/module/lib/thrift/transport/TSocket.py @@ -17,24 +17,33 @@ # under the License. # -from TTransport import * -import os import errno +import os import socket import sys +from TTransport import * + + class TSocketBase(TTransportBase): def _resolveAddr(self): if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] else: - return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + return socket.getaddrinfo(self.host, + self.port, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG) def close(self): if self.handle: self.handle.close() self.handle = None + class TSocket(TSocketBase): """Socket implementation of TTransport base.""" @@ -46,7 +55,6 @@ class TSocket(TSocketBase): @param unix_socket(str) The filename of a unix socket to connect to. (host and port will be ignored.) """ - self.host = host self.port = port self.handle = None @@ -63,7 +71,7 @@ class TSocket(TSocketBase): if ms is None: self._timeout = None else: - self._timeout = ms/1000.0 + self._timeout = ms / 1000.0 if self.handle is not None: self.handle.settimeout(self._timeout) @@ -87,7 +95,8 @@ class TSocket(TSocketBase): message = 'Could not connect to socket %s' % self._unix_socket else: message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) def read(self, sz): try: @@ -105,24 +114,28 @@ class TSocket(TSocketBase): else: raise if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') return buff def write(self, buff): if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') sent = 0 have = len(buff) while sent < have: plus = self.handle.send(buff) if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') sent += plus buff = buff[plus:] def flush(self): pass + class TServerSocket(TSocketBase, TServerTransportBase): """Socket implementation of TServerTransport base.""" diff --git a/module/lib/thrift/transport/TTransport.py b/module/lib/thrift/transport/TTransport.py index 12e51a9bf..4481371a6 100644 --- a/module/lib/thrift/transport/TTransport.py +++ b/module/lib/thrift/transport/TTransport.py @@ -18,11 +18,11 @@ # from cStringIO import StringIO -from struct import pack,unpack +from struct import pack, unpack from thrift.Thrift import TException -class TTransportException(TException): +class TTransportException(TException): """Custom Transport Exception class""" UNKNOWN = 0 @@ -35,8 +35,8 @@ class TTransportException(TException): TException.__init__(self, message) self.type = type -class TTransportBase: +class TTransportBase: """Base class for Thrift transport layer.""" def isOpen(self): @@ -55,7 +55,7 @@ class TTransportBase: buff = '' have = 0 while (have < sz): - chunk = self.read(sz-have) + chunk = self.read(sz - have) have += len(chunk) buff += chunk @@ -70,6 +70,7 @@ class TTransportBase: def flush(self): pass + # This class should be thought of as an interface. class CReadableTransport: """base class for transports that are readable from C""" @@ -98,8 +99,8 @@ class CReadableTransport: """ pass -class TServerTransportBase: +class TServerTransportBase: """Base class for Thrift server transports.""" def listen(self): @@ -111,15 +112,15 @@ class TServerTransportBase: def close(self): pass -class TTransportFactoryBase: +class TTransportFactoryBase: """Base class for a Transport Factory""" def getTransport(self, trans): return trans -class TBufferedTransportFactory: +class TBufferedTransportFactory: """Factory transport that builds buffered transports""" def getTransport(self, trans): @@ -127,17 +128,15 @@ class TBufferedTransportFactory: return buffered -class TBufferedTransport(TTransportBase,CReadableTransport): - +class TBufferedTransport(TTransportBase, CReadableTransport): """Class that wraps another transport and buffers its I/O. The implementation uses a (configurable) fixed-size read buffer but buffers all writes until a flush is performed. """ - DEFAULT_BUFFER = 4096 - def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): self.__trans = trans self.__wbuf = StringIO() self.__rbuf = StringIO("") @@ -188,6 +187,7 @@ class TBufferedTransport(TTransportBase,CReadableTransport): self.__rbuf = StringIO(retstring) return self.__rbuf + class TMemoryBuffer(TTransportBase, CReadableTransport): """Wraps a cStringIO object as a TTransport. @@ -237,8 +237,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport): # only one shot at reading... raise EOFError() -class TFramedTransportFactory: +class TFramedTransportFactory: """Factory transport that builds framed transports""" def getTransport(self, trans): @@ -247,7 +247,6 @@ class TFramedTransportFactory: class TFramedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and frames its I/O when writing.""" def __init__(self, trans,): diff --git a/module/lib/thrift/transport/TTwisted.py b/module/lib/thrift/transport/TTwisted.py index b6dcb4e0b..3ce3eb220 100644 --- a/module/lib/thrift/transport/TTwisted.py +++ b/module/lib/thrift/transport/TTwisted.py @@ -16,6 +16,9 @@ # specific language governing permissions and limitations # under the License. # + +from cStringIO import StringIO + from zope.interface import implements, Interface, Attribute from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ connectionDone @@ -25,7 +28,6 @@ from twisted.python import log from twisted.web import server, resource, http from thrift.transport import TTransport -from cStringIO import StringIO class TMessageSenderTransport(TTransport.TTransportBase): @@ -79,7 +81,7 @@ class ThriftClientProtocol(basic.Int32StringReceiver): self.started.callback(self.client) def connectionLost(self, reason=connectionDone): - for k,v in self.client._reqs.iteritems(): + for k, v in self.client._reqs.iteritems(): tex = TTransport.TTransportException( type=TTransport.TTransportException.END_OF_FILE, message='Connection closed') diff --git a/module/lib/thrift/transport/TZlibTransport.py b/module/lib/thrift/transport/TZlibTransport.py index 784d4e1e0..a2f42a5d2 100644 --- a/module/lib/thrift/transport/TZlibTransport.py +++ b/module/lib/thrift/transport/TZlibTransport.py @@ -16,50 +16,49 @@ # specific language governing permissions and limitations # under the License. # -''' -TZlibTransport provides a compressed transport and transport factory + +"""TZlibTransport provides a compressed transport and transport factory class, using the python standard library zlib module to implement data compression. -''' +""" from __future__ import division import zlib from cStringIO import StringIO from TTransport import TTransportBase, CReadableTransport + class TZlibTransportFactory(object): - ''' - Factory transport that builds zlib compressed transports. - + """Factory transport that builds zlib compressed transports. + This factory caches the last single client/transport that it was passed and returns the same TZlibTransport object that was created. - + This caching means the TServer class will get the _same_ transport object for both input and output transports from this factory. (For non-threaded scenarios only, since the cache only holds one object) - + The purpose of this caching is to allocate only one TZlibTransport where only one is really needed (since it must have separate read/write buffers), and makes the statistics from getCompSavings() and getCompRatio() easier to understand. - ''' - + """ # class scoped cache of last transport given and zlibtransport returned _last_trans = None _last_z = None def getTransport(self, trans, compresslevel=9): - '''Wrap a transport , trans, with the TZlibTransport + """Wrap a transport, trans, with the TZlibTransport compressed transport class, returning a new transport to the caller. - + @param compresslevel: The zlib compression level, ranging from 0 (no compression) to 9 (best compression). Defaults to 9. @type compresslevel: int - + This method returns a TZlibTransport which wraps the passed C{trans} TTransport derived instance. - ''' + """ if trans == self._last_trans: return self._last_z ztrans = TZlibTransport(trans, compresslevel) @@ -69,27 +68,24 @@ class TZlibTransportFactory(object): class TZlibTransport(TTransportBase, CReadableTransport): - ''' - Class that wraps a transport with zlib, compressing writes + """Class that wraps a transport with zlib, compressing writes and decompresses reads, using the python standard library zlib module. - ''' - + """ # Read buffer size for the python fastbinary C extension, # the TBinaryProtocolAccelerated class. DEFAULT_BUFFSIZE = 4096 def __init__(self, trans, compresslevel=9): - ''' - Create a new TZlibTransport, wrapping C{trans}, another + """Create a new TZlibTransport, wrapping C{trans}, another TTransport derived object. - + @param trans: A thrift transport object, i.e. a TSocket() object. @type trans: TTransport @param compresslevel: The zlib compression level, ranging from 0 (no compression) to 9 (best compression). Default is 9. @type compresslevel: int - ''' + """ self.__trans = trans self.compresslevel = compresslevel self.__rbuf = StringIO() @@ -98,49 +94,45 @@ class TZlibTransport(TTransportBase, CReadableTransport): self._init_stats() def _reinit_buffers(self): - ''' - Internal method to initialize/reset the internal StringIO objects + """Internal method to initialize/reset the internal StringIO objects for read and write buffers. - ''' + """ self.__rbuf = StringIO() self.__wbuf = StringIO() def _init_stats(self): - ''' - Internal method to reset the internal statistics counters + """Internal method to reset the internal statistics counters for compression ratios and bandwidth savings. - ''' + """ self.bytes_in = 0 self.bytes_out = 0 self.bytes_in_comp = 0 self.bytes_out_comp = 0 def _init_zlib(self): - ''' - Internal method for setting up the zlib compression and + """Internal method for setting up the zlib compression and decompression objects. - ''' + """ self._zcomp_read = zlib.decompressobj() self._zcomp_write = zlib.compressobj(self.compresslevel) def getCompRatio(self): - ''' - Get the current measured compression ratios (in,out) from + """Get the current measured compression ratios (in,out) from this transport. - - Returns a tuple of: + + Returns a tuple of: (inbound_compression_ratio, outbound_compression_ratio) - + The compression ratios are computed as: compressed / uncompressed E.g., data that compresses by 10x will have a ratio of: 0.10 and data that compresses to half of ts original size will have a ratio of 0.5 - + None is returned if no bytes have yet been processed in a particular direction. - ''' + """ r_percent, w_percent = (None, None) if self.bytes_in > 0: r_percent = self.bytes_in_comp / self.bytes_in @@ -149,23 +141,22 @@ class TZlibTransport(TTransportBase, CReadableTransport): return (r_percent, w_percent) def getCompSavings(self): - ''' - Get the current count of saved bytes due to data + """Get the current count of saved bytes due to data compression. - + Returns a tuple of: (inbound_saved_bytes, outbound_saved_bytes) - + Note: if compression is actually expanding your data (only likely with very tiny thrift objects), then the values returned will be negative. - ''' + """ r_saved = self.bytes_in - self.bytes_in_comp w_saved = self.bytes_out - self.bytes_out_comp return (r_saved, w_saved) def isOpen(self): - '''Return the underlying transport's open status''' + """Return the underlying transport's open status""" return self.__trans.isOpen() def open(self): @@ -174,25 +165,24 @@ class TZlibTransport(TTransportBase, CReadableTransport): return self.__trans.open() def listen(self): - '''Invoke the underlying transport's listen() method''' + """Invoke the underlying transport's listen() method""" self.__trans.listen() def accept(self): - '''Accept connections on the underlying transport''' + """Accept connections on the underlying transport""" return self.__trans.accept() def close(self): - '''Close the underlying transport,''' + """Close the underlying transport,""" self._reinit_buffers() self._init_zlib() return self.__trans.close() def read(self, sz): - ''' - Read up to sz bytes from the decompressed bytes buffer, and + """Read up to sz bytes from the decompressed bytes buffer, and read from the underlying transport if the decompression buffer is empty. - ''' + """ ret = self.__rbuf.read(sz) if len(ret) > 0: return ret @@ -204,10 +194,9 @@ class TZlibTransport(TTransportBase, CReadableTransport): return ret def readComp(self, sz): - ''' - Read compressed data from the underlying transport, then + """Read compressed data from the underlying transport, then decompress it and append it to the internal StringIO read buffer - ''' + """ zbuf = self.__trans.read(sz) zbuf = self._zcomp_read.unconsumed_tail + zbuf buf = self._zcomp_read.decompress(zbuf) @@ -220,17 +209,15 @@ class TZlibTransport(TTransportBase, CReadableTransport): return True def write(self, buf): - ''' - Write some bytes, putting them into the internal write + """Write some bytes, putting them into the internal write buffer for eventual compression. - ''' + """ self.__wbuf.write(buf) def flush(self): - ''' - Flush any queued up data in the write buffer and ensure the + """Flush any queued up data in the write buffer and ensure the compression buffer is flushed out to the underlying transport - ''' + """ wout = self.__wbuf.getvalue() if len(wout) > 0: zbuf = self._zcomp_write.compress(wout) @@ -247,11 +234,11 @@ class TZlibTransport(TTransportBase, CReadableTransport): @property def cstringio_buf(self): - '''Implement the CReadableTransport interface''' + """Implement the CReadableTransport interface""" return self.__rbuf def cstringio_refill(self, partialread, reqlen): - '''Implement the CReadableTransport interface for refill''' + """Implement the CReadableTransport interface for refill""" retstring = partialread if reqlen < self.DEFAULT_BUFFSIZE: retstring += self.read(self.DEFAULT_BUFFSIZE) diff --git a/module/lib/thrift/transport/__init__.py b/module/lib/thrift/transport/__init__.py index 46e54fe6b..c9596d9a6 100644 --- a/module/lib/thrift/transport/__init__.py +++ b/module/lib/thrift/transport/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] +__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport'] diff --git a/module/lib/wsgiserver/__init__.py b/module/lib/wsgiserver/__init__.py index c380e18b0..1058b19ff 100644 --- a/module/lib/wsgiserver/__init__.py +++ b/module/lib/wsgiserver/__init__.py @@ -1,7 +1,7 @@ -"""A high-speed, production ready, thread pooled, generic WSGI server. +"""A high-speed, production ready, thread pooled, generic HTTP server. Simplest example on how to use this module directly -(without using CherryPy's application machinery): +(without using CherryPy's application machinery):: from cherrypy import wsgiserver @@ -9,37 +9,29 @@ Simplest example on how to use this module directly status = '200 OK' response_headers = [('Content-type','text/plain')] start_response(status, response_headers) - return ['Hello world!\n'] + return ['Hello world!'] server = wsgiserver.CherryPyWSGIServer( ('0.0.0.0', 8070), my_crazy_app, server_name='www.cherrypy.example') + server.start() The CherryPy WSGI server can serve as many WSGI applications -as you want in one instance by using a WSGIPathInfoDispatcher: +as you want in one instance by using a WSGIPathInfoDispatcher:: d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) -Want SSL support? Just set these attributes: - - server.ssl_certificate = <filename> - server.ssl_private_key = <filename> - - if __name__ == '__main__': - try: - server.start() - except KeyboardInterrupt: - server.stop() +Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. This won't call the CherryPy engine (application side) at all, only the -WSGI server, which is independant from the rest of CherryPy. Don't +HTTP server, which is independent from the rest of CherryPy. Don't let the name "CherryPyWSGIServer" throw you; the name merely reflects its origin, not its coupling. For those of you wanting to understand internals of this module, here's the basic call flow. The server's listening thread runs a very tight loop, -sticking incoming connections onto a Queue: +sticking incoming connections onto a Queue:: server = CherryPyWSGIServer(...) server.start() @@ -52,7 +44,7 @@ sticking incoming connections onto a Queue: Worker threads are kept in a pool and poll the Queue, popping off and then handling each connection in turn. Each connection can consist of an arbitrary -number of requests and their responses, so we run a nested loop: +number of requests and their responses, so we run a nested loop:: while True: conn = server.requests.get() @@ -62,9 +54,9 @@ number of requests and their responses, so we run a nested loop: req.parse_request() -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" req.rfile.readline() - req.read_headers() + read_headers(req.rfile, req.inheaders) req.respond() - -> response = wsgi_app(...) + -> response = app(...) try: for chunk in response: if chunk: @@ -76,34 +68,81 @@ number of requests and their responses, so we run a nested loop: return """ +__all__ = ['HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'CP_fileobject', + 'MaxSizeExceeded', 'NoSSLError', 'FatalSSLAlert', + 'WorkerThread', 'ThreadPool', 'SSLAdapter', + 'CherryPyWSGIServer', + 'Gateway', 'WSGIGateway', 'WSGIGateway_10', 'WSGIGateway_u0', + 'WSGIPathInfoDispatcher'] -import base64 import os -import Queue +try: + import queue +except: + import Queue as queue import re -quoted_slash = re.compile("(?i)%2F") import rfc822 import socket +import sys +if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 try: import cStringIO as StringIO except ImportError: import StringIO +DEFAULT_BUFFER_SIZE = -1 _fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) -import sys import threading import time import traceback +def format_exc(limit=None): + """Like print_exc() but return a string. Backport for Python 2.3.""" + try: + etype, value, tb = sys.exc_info() + return ''.join(traceback.format_exception(etype, value, tb, limit)) + finally: + etype = value = tb = None + + from urllib import unquote from urlparse import urlparse import warnings -try: - from OpenSSL import SSL - from OpenSSL import crypto -except ImportError: - SSL = None +if sys.version_info >= (3, 0): + bytestr = bytes + unicodestr = str + basestring = (bytes, str) + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 3, the native string type is unicode + return n.encode(encoding) +else: + bytestr = str + unicodestr = unicode + basestring = basestring + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given encoding.""" + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + +LF = ntob('\n') +CRLF = ntob('\r\n') +TAB = ntob('\t') +SPACE = ntob(' ') +COLON = ntob(':') +SEMICOLON = ntob(';') +EMPTY = ntob('') +NUMBER_SIGN = ntob('#') +QUESTION_MARK = ntob('?') +ASTERISK = ntob('*') +FORWARD_SLASH = ntob('/') +quoted_slash = re.compile(ntob("(?i)%2F")) import errno @@ -117,7 +156,7 @@ def plat_specific_errors(*errnames): errno_names = dir(errno) nums = [getattr(errno, k) for k in errnames if k in errno_names] # de-dupe the list - return dict.fromkeys(nums).keys() + return list(dict.fromkeys(nums).keys()) socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") @@ -133,51 +172,71 @@ socket_errors_to_ignore = plat_specific_errors( "EHOSTDOWN", "EHOSTUNREACH", ) socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") socket_errors_nonblocking = plat_specific_errors( 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') -comma_separated_headers = ['ACCEPT', 'ACCEPT-CHARSET', 'ACCEPT-ENCODING', - 'ACCEPT-LANGUAGE', 'ACCEPT-RANGES', 'ALLOW', 'CACHE-CONTROL', - 'CONNECTION', 'CONTENT-ENCODING', 'CONTENT-LANGUAGE', 'EXPECT', - 'IF-MATCH', 'IF-NONE-MATCH', 'PRAGMA', 'PROXY-AUTHENTICATE', 'TE', - 'TRAILER', 'TRANSFER-ENCODING', 'UPGRADE', 'VARY', 'VIA', 'WARNING', - 'WWW-AUTHENTICATE'] +comma_separated_headers = [ntob(h) for h in + ['Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate']] -class WSGIPathInfoDispatcher(object): - """A WSGI dispatcher for dispatch based on the PATH_INFO. +import logging +if not hasattr(logging, 'statistics'): logging.statistics = {} + + +def read_headers(rfile, hdict=None): + """Read headers from the given stream into the given header dict. - apps: a dict or list of (path_prefix, app) pairs. + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP spec. + You should probably return "400 Bad Request" if this happens. """ + if hdict is None: + hdict = {} - def __init__(self, apps): - try: - apps = apps.items() - except AttributeError: - pass + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") - # Sort the apps by len(path), descending - apps.sort() - apps.reverse() + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") - # The path_prefix strings must start, but not end, with a slash. - # Use "" instead of "/". - self.apps = [(p.rstrip("/"), a) for p, a in apps] - - def __call__(self, environ, start_response): - path = environ["PATH_INFO"] or "/" - for p, app in self.apps: - # The apps list should be sorted by length, descending. - if path.startswith(p + "/") or path == p: - environ = environ.copy() - environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p - environ["PATH_INFO"] = path[len(p):] - return app(environ, start_response) + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError("Illegal header line.") + # TODO: what about TE and WWW-Authenticate? + k = k.strip().title() + v = v.strip() + hname = k - start_response('404 Not Found', [('Content-Type', 'text/plain'), - ('Content-Length', '0')]) - return [''] + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = ", ".join((existing, v)) + hdict[hname] = v + + return hdict class MaxSizeExceeded(Exception): @@ -218,7 +277,7 @@ class SizeCheckWrapper(object): res.append(data) # See http://www.cherrypy.org/ticket/421 if len(data) < 256 or data[-1:] == "\n": - return ''.join(res) + return EMPTY.join(res) def readlines(self, sizehint=0): # Shamelessly stolen from StringIO @@ -239,6 +298,12 @@ class SizeCheckWrapper(object): def __iter__(self): return self + def __next__(self): + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + def next(self): data = self.rfile.next() self.bytes_read += len(data) @@ -246,70 +311,293 @@ class SizeCheckWrapper(object): return data +class KnownLengthRFile(object): + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + if self.remaining == 0: + return '' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def __next__(self): + data = next(self.rfile) + self.remaining -= len(data) + return data + + +class ChunkedRFile(object): + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded("Request Entity Too Large", self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError("Bad chunked transfer size: " + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +## if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError("Request Entity Too Large") + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + "got " + repr(crlf) + ")") + + def read(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + + def readline(self, size=None): + data = EMPTY + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + if not self.closed: + raise ValueError( + "Cannot read trailers until the request body has been read.") + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError("Request Entity Too Large") + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError("HTTP requires CRLF terminators") + + yield line + + def close(self): + self.rfile.close() + + def __iter__(self): + # Shamelessly stolen from StringIO + total = 0 + line = self.readline(sizehint) + while line: + yield line + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + + class HTTPRequest(object): """An HTTP Request (and response). A single HTTP connection may consist of multiple request/response pairs. - - send: the 'send' method from the connection's socket object. - wsgi_app: the WSGI application to call. - environ: a partial WSGI environ (server and connection entries). - The caller MUST set the following entries: - * All wsgi.* entries, including .input - * SERVER_NAME and SERVER_PORT - * Any SSL_* entries - * Any custom entries like REMOTE_ADDR and REMOTE_PORT - * SERVER_SOFTWARE: the value to write in the "Server" response header. - * ACTUAL_SERVER_PROTOCOL: the value to write in the Status-Line of - the response. From RFC 2145: "An HTTP server SHOULD send a - response version equal to the highest version for which the - server is at least conditionally compliant, and whose major - version is less than or equal to the one received in the - request. An HTTP server MUST NOT send a version for which - it is not at least conditionally compliant." - - outheaders: a list of header tuples to write in the response. - ready: when True, the request has been parsed and is ready to begin - generating the response. When False, signals the calling Connection - that the response should not be generated and the connection should - close. - close_connection: signals the calling Connection that the request - should close. This does not imply an error! The client and/or - server may each request that the connection be closed. - chunked_write: if True, output will be encoded with the "chunked" - transfer-coding. This value is set automatically inside - send_headers. """ - max_request_header_size = 0 - max_request_body_size = 0 + server = None + """The HTTPServer object which is receiving this request.""" - def __init__(self, wfile, environ, wsgi_app): - self.rfile = environ['wsgi.input'] - self.wfile = wfile - self.environ = environ.copy() - self.wsgi_app = wsgi_app + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + def __init__(self, server, conn): + self.server= server + self.conn = conn self.ready = False - self.started_response = False + self.started_request = False + self.scheme = ntob("http") + if self.server.ssl_adapter is not None: + self.scheme = ntob("https") + # Use the lowest-common protocol in case read_request_line errors. + self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + self.status = "" self.outheaders = [] self.sent_headers = False - self.close_connection = False - self.chunked_write = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write def parse_request(self): """Parse the next HTTP request start-line and message-headers.""" - self.rfile.maxlen = self.max_request_header_size - self.rfile.bytes_read = 0 + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + self.read_request_line() + except MaxSizeExceeded: + self.simple_response("414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.") + return try: - self._parse_request() + success = self.read_request_headers() except MaxSizeExceeded: - self.simple_response("413 Request Entity Too Large") + self.simple_response("413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.") return + else: + if not success: + return + + self.ready = True - def _parse_request(self): + def read_request_line(self): # HTTP/1.1 connections are persistent by default. If a client # requests a page, then idles (leaves the connection open), # then rfile.readline() will raise socket.error("timed out"). @@ -318,12 +606,16 @@ class HTTPRequest(object): # (although your TCP stack might suffer for it: cf Apache's history # with FIN_WAIT_2). request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True if not request_line: # Force self.ready = False so the connection will close. self.ready = False return - if request_line == "\r\n": + if request_line == CRLF: # RFC 2616 sec 4.1: "...if the server is reading the protocol # stream at the beginning of a message and receives a CRLF # first, it should ignore the CRLF." @@ -333,44 +625,53 @@ class HTTPRequest(object): self.ready = False return - environ = self.environ + if not request_line.endswith(CRLF): + self.simple_response("400 Bad Request", "HTTP requires CRLF terminators") + return try: - method, path, req_protocol = request_line.strip().split(" ", 2) - except ValueError: - self.simple_response(400, "Malformed Request-Line") + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + rp = int(req_protocol[5]), int(req_protocol[7]) + except (ValueError, IndexError): + self.simple_response("400 Bad Request", "Malformed Request-Line") return - environ["REQUEST_METHOD"] = method - - # path may be an abs_path (including "http://host.domain.tld"); - scheme, location, path, params, qs, frag = urlparse(path) + self.uri = uri + self.method = method - if frag: + # uri may be an abs_path (including "http://host.domain.tld"); + scheme, authority, path = self.parse_request_uri(uri) + if NUMBER_SIGN in path: self.simple_response("400 Bad Request", "Illegal #fragment in Request-URI.") return if scheme: - environ["wsgi.url_scheme"] = scheme - if params: - path = path + ";" + params + self.scheme = scheme - environ["SCRIPT_NAME"] = "" + qs = EMPTY + if QUESTION_MARK in path: + path, qs = path.split(QUESTION_MARK, 1) - # Unquote the path+params (e.g. "/this%20path" -> "this path"). + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 # # But note that "...a URI must be separated into its components # before the escaped characters within those components can be # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 - atoms = [unquote(x) for x in quoted_slash.split(path)] + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path". + try: + atoms = [unquote(x) for x in quoted_slash.split(path)] + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return path = "%2F".join(atoms) - environ["PATH_INFO"] = path + self.path = path - # Note that, like wsgiref and most other WSGI servers, - # we unquote the path but not the query string. - environ["QUERY_STRING"] = qs + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs # Compare request and server HTTP protocol versions, in case our # server does not support the requested protocol. Limit our output @@ -384,46 +685,46 @@ class HTTPRequest(object): # Notice that, in (b), the response will be "HTTP/1.1" even though # the client only understands 1.0. RFC 2616 10.5.6 says we should # only return 505 if the _major_ version is different. - rp = int(req_protocol[5]), int(req_protocol[7]) - server_protocol = environ["ACTUAL_SERVER_PROTOCOL"] - sp = int(server_protocol[5]), int(server_protocol[7]) + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + if sp[0] != rp[0]: self.simple_response("505 HTTP Version Not Supported") return - # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. - environ["SERVER_PROTOCOL"] = req_protocol + self.request_protocol = req_protocol self.response_protocol = "HTTP/%s.%s" % min(rp, sp) - - # If the Request-URI was an absoluteURI, use its location atom. - if location: - environ["SERVER_NAME"] = location + + def read_request_headers(self): + """Read self.rfile into self.inheaders. Return success.""" # then all the http headers try: - self.read_headers() - except ValueError, ex: - self.simple_response("400 Bad Request", repr(ex.args)) - return + read_headers(self.rfile, self.inheaders) + except ValueError: + ex = sys.exc_info()[1] + self.simple_response("400 Bad Request", ex.args[0]) + return False - mrbs = self.max_request_body_size - if mrbs and int(environ.get("CONTENT_LENGTH", 0)) > mrbs: - self.simple_response("413 Request Entity Too Large") - return + mrbs = self.server.max_request_body_size + if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") + return False # Persistent connection support if self.response_protocol == "HTTP/1.1": # Both server and client are HTTP/1.1 - if environ.get("HTTP_CONNECTION", "") == "close": + if self.inheaders.get("Connection", "") == "close": self.close_connection = True else: # Either the server or client (or both) are HTTP/1.0 - if environ.get("HTTP_CONNECTION", "") != "Keep-Alive": + if self.inheaders.get("Connection", "") != "Keep-Alive": self.close_connection = True # Transfer-Encoding support te = None if self.response_protocol == "HTTP/1.1": - te = environ.get("HTTP_TRANSFER_ENCODING") + te = self.inheaders.get("Transfer-Encoding") if te: te = [x.strip().lower() for x in te.split(",") if x.strip()] @@ -438,7 +739,7 @@ class HTTPRequest(object): # if there is an extension we don't recognize. self.simple_response("501 Unimplemented") self.close_connection = True - return + return False # From PEP 333: # "Servers and gateways that implement HTTP 1.1 must provide @@ -457,188 +758,128 @@ class HTTPRequest(object): # # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, # but it seems like it would be a big slowdown for such a rare case. - if environ.get("HTTP_EXPECT", "") == "100-continue": - self.simple_response(100) - - self.ready = True + if self.inheaders.get("Expect", "") == "100-continue": + # Don't use simple_response here, because it emits headers + # we don't want. See http://www.cherrypy.org/ticket/951 + msg = self.server.protocol + " 100 Continue\r\n\r\n" + try: + self.conn.wfile.sendall(msg) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return True - def read_headers(self): - """Read header lines from the incoming stream.""" - environ = self.environ + def parse_request_uri(self, uri): + """Parse a Request-URI into (scheme, authority, path). - while True: - line = self.rfile.readline() - if not line: - # No more data--illegal end of headers - raise ValueError("Illegal end of headers.") - - if line == '\r\n': - # Normal end of headers - break - - if line[0] in ' \t': - # It's a continuation line. - v = line.strip() - else: - k, v = line.split(":", 1) - k, v = k.strip().upper(), v.strip() - envname = "HTTP_" + k.replace("-", "_") + Note that Request-URI's must be one of:: - if k in comma_separated_headers: - existing = environ.get(envname) - if existing: - v = ", ".join((existing, v)) - environ[envname] = v + Request-URI = "*" | absoluteURI | abs_path | authority - ct = environ.pop("HTTP_CONTENT_TYPE", None) - if ct is not None: - environ["CONTENT_TYPE"] = ct - cl = environ.pop("HTTP_CONTENT_LENGTH", None) - if cl is not None: - environ["CONTENT_LENGTH"] = cl - - def decode_chunked(self): - """Decode the 'chunked' transfer coding.""" - cl = 0 - data = StringIO.StringIO() - while True: - line = self.rfile.readline().strip().split(";", 1) - chunk_size = int(line.pop(0), 16) - if chunk_size <= 0: - break -## if line: chunk_extension = line[0] - cl += chunk_size - data.write(self.rfile.read(chunk_size)) - crlf = self.rfile.read(2) - if crlf != "\r\n": - self.simple_response("400 Bad Request", - "Bad chunked transfer coding " - "(expected '\\r\\n', got %r)" % crlf) - return + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: - # Grab any trailer headers - self.read_headers() + net_path = "//" authority [ abs_path ] - data.seek(0) - self.environ["wsgi.input"] = data - self.environ["CONTENT_LENGTH"] = str(cl) or "" - return True + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == ASTERISK: + return None, None, uri + + i = uri.find('://') + if i > 0 and QUESTION_MARK not in uri[:i]: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]] + scheme, remainder = uri[:i].lower(), uri[i + 3:] + authority, path = remainder.split(FORWARD_SLASH, 1) + path = FORWARD_SLASH + path + return scheme, authority, path + + if uri.startswith(FORWARD_SLASH): + # An abs_path. + return None, None, uri + else: + # An authority. + return None, uri, None def respond(self): - """Call the appropriate WSGI app and write its iterable output.""" - # Set rfile.maxlen to ensure we don't read past Content-Length. - # This will also be used to read the entire request body if errors - # are raised before the app can read the body. + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size if self.chunked_read: - # If chunked, Content-Length will be 0. - self.rfile.maxlen = self.max_request_body_size + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) else: - cl = int(self.environ.get("CONTENT_LENGTH", 0)) - if self.max_request_body_size: - self.rfile.maxlen = min(cl, self.max_request_body_size) - else: - self.rfile.maxlen = cl - self.rfile.bytes_read = 0 - - try: - self._respond() - except MaxSizeExceeded: - if not self.sent_headers: - self.simple_response("413 Request Entity Too Large") - return - - def _respond(self): - if self.chunked_read: - if not self.decode_chunked(): - self.close_connection = True + cl = int(self.inheaders.get("Content-Length", 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response("413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.") return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) - response = self.wsgi_app(self.environ, self.start_response) - try: - for chunk in response: - # "The start_response callable must not actually transmit - # the response headers. Instead, it must store them for the - # server or gateway to transmit only after the first - # iteration of the application return value that yields - # a NON-EMPTY string, or upon the application's first - # invocation of the write() callable." (PEP 333) - if chunk: - self.write(chunk) - finally: - if hasattr(response, "close"): - response.close() + self.server.gateway(self).respond() if (self.ready and not self.sent_headers): self.sent_headers = True self.send_headers() if self.chunked_write: - self.wfile.sendall("0\r\n\r\n") + self.conn.wfile.sendall("0\r\n\r\n") def simple_response(self, status, msg=""): """Write a simple response back to the client.""" status = str(status) - buf = ["%s %s\r\n" % (self.environ['ACTUAL_SERVER_PROTOCOL'], status), + buf = [self.server.protocol + SPACE + + status + CRLF, "Content-Length: %s\r\n" % len(msg), "Content-Type: text/plain\r\n"] - if status[:3] == "413" and self.response_protocol == 'HTTP/1.1': - # Request Entity Too Large + if status[:3] in ("413", "414"): + # Request Entity Too Large / Request-URI Too Long self.close_connection = True - buf.append("Connection: close\r\n") + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append("Connection: close\r\n") + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = "400 Bad Request" - buf.append("\r\n") + buf.append(CRLF) if msg: + if isinstance(msg, unicodestr): + msg = msg.encode("ISO-8859-1") buf.append(msg) try: - self.wfile.sendall("".join(buf)) - except socket.error, x: + self.conn.wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] if x.args[0] not in socket_errors_to_ignore: raise - def start_response(self, status, headers, exc_info = None): - """WSGI callable to begin the HTTP response.""" - # "The application may call start_response more than once, - # if and only if the exc_info argument is provided." - if self.started_response and not exc_info: - raise AssertionError("WSGI start_response called a second " - "time with no exc_info.") - - # "if exc_info is provided, and the HTTP headers have already been - # sent, start_response must raise an error, and should raise the - # exc_info tuple." - if self.sent_headers: - try: - raise exc_info[0], exc_info[1], exc_info[2] - finally: - exc_info = None - - self.started_response = True - self.status = status - self.outheaders.extend(headers) - return self.write - def write(self, chunk): - """WSGI callable to write unbuffered data to the client. - - This method is also used internally by start_response (to write - data from the iterable returned by the WSGI application). - """ - if not self.started_response: - raise AssertionError("WSGI write called before start_response.") - - if not self.sent_headers: - self.sent_headers = True - self.send_headers() - + """Write unbuffered data to the client.""" if self.chunked_write and chunk: - buf = [hex(len(chunk))[2:], "\r\n", chunk, "\r\n"] - self.wfile.sendall("".join(buf)) + buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF] + self.conn.wfile.sendall(EMPTY.join(buf)) else: - self.wfile.sendall(chunk) + self.conn.wfile.sendall(chunk) def send_headers(self): - """Assert, process, and send the HTTP response message-headers.""" + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ hkeys = [key.lower() for key, value in self.outheaders] status = int(self.status[:3]) @@ -653,7 +894,7 @@ class HTTPRequest(object): pass else: if (self.response_protocol == 'HTTP/1.1' - and self.environ["REQUEST_METHOD"] != 'HEAD'): + and self.method != 'HEAD'): # Use the chunked transfer-coding self.chunked_write = True self.outheaders.append(("Transfer-Encoding", "chunked")) @@ -684,28 +925,21 @@ class HTTPRequest(object): # requirement is not be construed as preventing a server from # defending itself against denial-of-service attacks, or from # badly broken client implementations." - size = self.rfile.maxlen - self.rfile.bytes_read - if size > 0: - self.rfile.read(size) + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) if "date" not in hkeys: self.outheaders.append(("Date", rfc822.formatdate())) if "server" not in hkeys: - self.outheaders.append(("Server", self.environ['SERVER_SOFTWARE'])) + self.outheaders.append(("Server", self.server.server_name)) - buf = [self.environ['ACTUAL_SERVER_PROTOCOL'], " ", self.status, "\r\n"] - try: - buf += [k + ": " + v + "\r\n" for k, v in self.outheaders] - except TypeError: - if not isinstance(k, str): - raise TypeError("WSGI response header key %r is not a string.") - if not isinstance(v, str): - raise TypeError("WSGI response header value %r is not a string.") - else: - raise - buf.append("\r\n") - self.wfile.sendall("".join(buf)) + buf = [self.server.protocol + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.sendall(EMPTY.join(buf)) class NoSSLError(Exception): @@ -718,38 +952,47 @@ class FatalSSLAlert(Exception): pass -if not _fileobject_uses_str_type: - class CP_fileobject(socket._fileobject): - """Faux file object attached to a socket object.""" - - def sendall(self, data): - """Sendall for non-blocking sockets.""" - while data: - try: - bytes_sent = self.send(data) - data = data[bytes_sent:] - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking: - raise - - def send(self, data): - return self._sock.send(data) - - def flush(self): - if self._wbuf: - buffer = "".join(self._wbuf) - self._wbuf = [] - self.sendall(buffer) - - def recv(self, size): - while True: - try: - return self._sock.recv(size) - except socket.error, e: - if (e.args[0] not in socket_errors_nonblocking - and e.args[0] not in socket_error_eintr): - raise +class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" + def __init__(self, *args, **kwargs): + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + bytes_sent = self._sock.send(data) + self.bytes_written += bytes_sent + return bytes_sent + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + if not _fileobject_uses_str_type: def read(self, size=-1): # Use max, disallow tiny reads in a loop as they are very inefficient. # We never leave read() with any leftover data from a new recv() call @@ -895,39 +1138,7 @@ if not _fileobject_uses_str_type: buf_len += n #assert buf_len == buf.tell() return buf.getvalue() - -else: - class CP_fileobject(socket._fileobject): - """Faux file object attached to a socket object.""" - - def sendall(self, data): - """Sendall for non-blocking sockets.""" - while data: - try: - bytes_sent = self.send(data) - data = data[bytes_sent:] - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking: - raise - - def send(self, data): - return self._sock.send(data) - - def flush(self): - if self._wbuf: - buffer = "".join(self._wbuf) - self._wbuf = [] - self.sendall(buffer) - - def recv(self, size): - while True: - try: - return self._sock.recv(size) - except socket.error, e: - if (e.args[0] not in socket_errors_nonblocking - and e.args[0] not in socket_error_eintr): - raise - + else: def read(self, size=-1): if size < 0: # Read until EOF @@ -1039,168 +1250,102 @@ else: break buf_len += n return "".join(buffers) - - -class SSL_fileobject(CP_fileobject): - """SSL file object attached to a socket object.""" - - ssl_timeout = 3 - ssl_retry = .01 - - def _safe_call(self, is_reader, call, *args, **kwargs): - """Wrap the given call with SSL error-trapping. - - is_reader: if False EOF errors will be raised. If True, EOF errors - will return "" (to emulate normal sockets). - """ - start = time.time() - while True: - try: - return call(*args, **kwargs) - except SSL.WantReadError: - # Sleep and try again. This is dangerous, because it means - # the rest of the stack has no way of differentiating - # between a "new handshake" error and "client dropped". - # Note this isn't an endless loop: there's a timeout below. - time.sleep(self.ssl_retry) - except SSL.WantWriteError: - time.sleep(self.ssl_retry) - except SSL.SysCallError, e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return "" - - errnum = e.args[0] - if is_reader and errnum in socket_errors_to_ignore: - return "" - raise socket.error(errnum) - except SSL.Error, e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return "" - - thirdarg = None - try: - thirdarg = e.args[0][0][2] - except IndexError: - pass - - if thirdarg == 'http request': - # The client is talking HTTP to an HTTPS server. - raise NoSSLError() - raise FatalSSLAlert(*e.args) - except: - raise - - if time.time() - start > self.ssl_timeout: - raise socket.timeout("timed out") - - def recv(self, *args, **kwargs): - buf = [] - r = super(SSL_fileobject, self).recv - while True: - data = self._safe_call(True, r, *args, **kwargs) - buf.append(data) - p = self._sock.pending() - if not p: - return "".join(buf) - - def sendall(self, *args, **kwargs): - return self._safe_call(False, super(SSL_fileobject, self).sendall, *args, **kwargs) - - def send(self, *args, **kwargs): - return self._safe_call(False, super(SSL_fileobject, self).send, *args, **kwargs) class HTTPConnection(object): """An HTTP connection (active socket). + server: the Server object which received this connection. socket: the raw socket object (usually TCP) for this connection. - wsgi_app: the WSGI application for this server/connection. - environ: a WSGI environ template. This will be copied for each request. - - rfile: a fileobject for reading from the socket. - send: a function for writing (+ flush) to the socket. + makefile: a fileobject class for reading from the socket. """ - rbufsize = -1 + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = DEFAULT_BUFFER_SIZE + wbufsize = DEFAULT_BUFFER_SIZE RequestHandlerClass = HTTPRequest - environ = {"wsgi.version": (1, 0), - "wsgi.url_scheme": "http", - "wsgi.multithread": True, - "wsgi.multiprocess": False, - "wsgi.run_once": False, - "wsgi.errors": sys.stderr, - } - - def __init__(self, sock, wsgi_app, environ): + + def __init__(self, server, sock, makefile=CP_fileobject): + self.server = server self.socket = sock - self.wsgi_app = wsgi_app - - # Copy the class environ into self. - self.environ = self.environ.copy() - self.environ.update(environ) - - if SSL and isinstance(sock, SSL.ConnectionType): - timeout = sock.gettimeout() - self.rfile = SSL_fileobject(sock, "rb", self.rbufsize) - self.rfile.ssl_timeout = timeout - self.wfile = SSL_fileobject(sock, "wb", -1) - self.wfile.ssl_timeout = timeout - else: - self.rfile = CP_fileobject(sock, "rb", self.rbufsize) - self.wfile = CP_fileobject(sock, "wb", -1) - - # Wrap wsgi.input but not HTTPConnection.rfile itself. - # We're also not setting maxlen yet; we'll do that separately - # for headers and body for each iteration of self.communicate - # (if maxlen is 0 the wrapper doesn't check length). - self.environ["wsgi.input"] = SizeCheckWrapper(self.rfile, 0) + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) + self.requests_seen = 0 def communicate(self): """Read each request and respond appropriately.""" + request_seen = False try: while True: # (re)set req to None so that if something goes wrong in # the RequestHandlerClass constructor, the error doesn't # get written to the previous request. req = None - req = self.RequestHandlerClass(self.wfile, self.environ, - self.wsgi_app) + req = self.RequestHandlerClass(self.server, self) # This order of operations should guarantee correct pipelining. req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. return + request_seen = True req.respond() if req.close_connection: return - - except socket.error, e: + except socket.error: + e = sys.exc_info()[1] errnum = e.args[0] - if errnum == 'timed out': - if req and not req.sent_headers: - req.simple_response("408 Request Timeout") + # sadly SSL sockets return a different (longer) time out string + if errnum == 'timed out' or errnum == 'The read operation timed out': + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See http://www.cherrypy.org/ticket/853 + if (not request_seen) or (req and req.started_request): + # Don't bother writing the 408 if the response + # has already started being written. + if req and not req.sent_headers: + try: + req.simple_response("408 Request Timeout") + except FatalSSLAlert: + # Close the connection. + return elif errnum not in socket_errors_to_ignore: if req and not req.sent_headers: - req.simple_response("500 Internal Server Error", - format_exc()) + try: + req.simple_response("500 Internal Server Error", + format_exc()) + except FatalSSLAlert: + # Close the connection. + return return except (KeyboardInterrupt, SystemExit): raise - except FatalSSLAlert, e: + except FatalSSLAlert: # Close the connection. return except NoSSLError: if req and not req.sent_headers: # Unwrap our wfile - req.wfile = CP_fileobject(self.socket._sock, "wb", -1) + self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize) req.simple_response("400 Bad Request", "The client sent a plain HTTP request, but " "this server only speaks HTTPS on this port.") self.linger = True - except Exception, e: + except Exception: if req and not req.sent_headers: - req.simple_response("500 Internal Server Error", format_exc()) + try: + req.simple_response("500 Internal Server Error", format_exc()) + except FatalSSLAlert: + # Close the connection. + return linger = False @@ -1214,7 +1359,8 @@ class HTTPConnection(object): # want this server to send a FIN TCP segment immediately. Note this # must be called *before* calling socket.close(), because the latter # drops its reference to the kernel socket. - self.socket._sock.close() + if hasattr(self.socket, '_sock'): + self.socket._sock.close() self.socket.close() else: # On the other hand, sometimes we want to hang around for a bit @@ -1226,13 +1372,13 @@ class HTTPConnection(object): pass -def format_exc(limit=None): - """Like print_exc() but return a string. Backport for Python 2.3.""" - try: - etype, value, tb = sys.exc_info() - return ''.join(traceback.format_exception(etype, value, tb, limit)) - finally: - etype = value = tb = None +class TrueyZero(object): + """An object which equals and does math like the integer '0' but evals True.""" + def __add__(self, other): + return other + def __radd__(self, other): + return other +trueyzero = TrueyZero() _SHUTDOWNREQUEST = None @@ -1240,11 +1386,6 @@ _SHUTDOWNREQUEST = None class WorkerThread(threading.Thread): """Thread which continuously polls a Queue for Connection objects. - server: the HTTP Server which spawned this thread, and which owns the - Queue and is placing active connections into it. - ready: a simple flag for the calling server to know when this thread - has begun polling the Queue. - Due to the timing issues of polling a Queue, a WorkerThread does not check its own 'ready' flag after it has started. To stop the thread, it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue @@ -1252,13 +1393,38 @@ class WorkerThread(threading.Thread): """ conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + def __init__(self, server): self.ready = False self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and trueyzero or self.conn.requests_seen), + 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and trueyzero or self.conn.rfile.bytes_read), + 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and trueyzero or self.conn.wfile.bytes_written), + 'Work Time': lambda s: self.work_time + ((self.start_time is None) and trueyzero or time.time() - self.start_time), + 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6), + } threading.Thread.__init__(self) def run(self): + self.server.stats['Worker Threads'][self.getName()] = self.stats try: self.ready = True while True: @@ -1267,17 +1433,26 @@ class WorkerThread(threading.Thread): return self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() try: conn.communicate() finally: conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None self.conn = None - except (KeyboardInterrupt, SystemExit), exc: + except (KeyboardInterrupt, SystemExit): + exc = sys.exc_info()[1] self.server.interrupt = exc class ThreadPool(object): - """A Request Queue for the CherryPyWSGIServer which pools threads. + """A Request Queue for an HTTPServer which pools threads. ThreadPool objects must provide min, get(), put(obj), start() and stop(timeout) attributes. @@ -1288,15 +1463,15 @@ class ThreadPool(object): self.min = min self.max = max self._threads = [] - self._queue = Queue.Queue() + self._queue = queue.Queue() self.get = self._queue.get def start(self): """Start the pool of threads.""" - for i in xrange(self.min): + for i in range(self.min): self._threads.append(WorkerThread(self.server)) for worker in self._threads: - worker.setName("CP WSGIServer " + worker.getName()) + worker.setName("CP Server " + worker.getName()) worker.start() for worker in self._threads: while not worker.ready: @@ -1314,11 +1489,11 @@ class ThreadPool(object): def grow(self, amount): """Spawn new worker threads (not above self.max).""" - for i in xrange(amount): + for i in range(amount): if self.max > 0 and len(self._threads) >= self.max: break worker = WorkerThread(self.server) - worker.setName("CP WSGIServer " + worker.getName()) + worker.setName("CP Server " + worker.getName()) self._threads.append(worker) worker.start() @@ -1332,7 +1507,7 @@ class ThreadPool(object): amount -= 1 if amount > 0: - for i in xrange(min(amount, len(self._threads) - self.min)): + for i in range(min(amount, len(self._threads) - self.min)): # Put a number of shutdown requests on the queue equal # to 'amount'. Once each of those is processed by a worker, # that worker will terminate and be culled from our list @@ -1347,6 +1522,8 @@ class ThreadPool(object): # Don't join currentThread (when stop is called inside a request). current = threading.currentThread() + if timeout and timeout >= 0: + endtime = time.time() + timeout while self._threads: worker = self._threads.pop() if worker is not current and worker.isAlive(): @@ -1354,51 +1531,30 @@ class ThreadPool(object): if timeout is None or timeout < 0: worker.join() else: - worker.join(timeout) + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) if worker.isAlive(): # We exhausted the timeout. # Forcibly shut down the socket. c = worker.conn if c and not c.rfile.closed: - if SSL and isinstance(c.socket, SSL.ConnectionType): - # pyOpenSSL.socket.shutdown takes no args - c.socket.shutdown() - else: + try: c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() worker.join() except (AssertionError, # Ignore repeated Ctrl-C. # See http://www.cherrypy.org/ticket/691. - KeyboardInterrupt), exc1: + KeyboardInterrupt): pass - - - -class SSLConnection: - """A thread-safe wrapper for an SSL.Connection. - - *args: the arguments to create the wrapped SSL.Connection(*args). - """ - def __init__(self, *args): - self._ssl_conn = SSL.Connection(*args) - self._lock = threading.RLock() - - for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read', - 'renegotiate', 'bind', 'listen', 'connect', 'accept', - 'setblocking', 'fileno', 'shutdown', 'close', 'get_cipher_list', - 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', - 'makefile', 'get_app_data', 'set_app_data', 'state_string', - 'sock_shutdown', 'get_peer_certificate', 'want_read', - 'want_write', 'set_connect_state', 'set_accept_state', - 'connect_ex', 'sendall', 'settimeout'): - exec """def %s(self, *args): - self._lock.acquire() - try: - return self._ssl_conn.%s(*args) - finally: - self._lock.release() -""" % (f, f) + def _get_qsize(self): + return self._queue.qsize() + qsize = property(_get_qsize) + try: @@ -1423,96 +1579,136 @@ else: fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) -class CherryPyWSGIServer(object): - """An HTTP server for WSGI. +class SSLAdapter(object): + """Base class for SSL driver library adapters. - bind_addr: The interface on which to listen for connections. - For TCP sockets, a (host, port) tuple. Host values may be any IPv4 - or IPv6 address, or any valid hostname. The string 'localhost' is a - synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). - The string '0.0.0.0' is a special IPv4 entry meaning "any active - interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for - IPv6. The empty string or None are not allowed. - - For UNIX sockets, supply the filename as a string. - wsgi_app: the WSGI 'application callable'; multiple WSGI applications - may be passed as (path_prefix, app) pairs. - numthreads: the number of worker threads to create (default 10). - server_name: the string to set for WSGI's SERVER_NAME environ entry. - Defaults to socket.gethostname(). - max: the maximum number of queued requests (defaults to -1 = no limit). - request_queue_size: the 'backlog' argument to socket.listen(); - specifies the maximum number of queued connections (default 5). - timeout: the timeout in seconds for accepted connections (default 10). + Required methods: - nodelay: if True (the default since 3.1), sets the TCP_NODELAY socket - option. + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object`` + """ - protocol: the version string to write in the Status-Line of all - HTTP responses. For example, "HTTP/1.1" (the default). This - also limits the supported features used in the response. + def __init__(self, certificate, private_key, certificate_chain=None): + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + def wrap(self, sock): + raise NotImplemented - SSL/HTTPS - --------- - The OpenSSL module must be importable for SSL functionality. - You can obtain it from http://pyopenssl.sourceforge.net/ + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + raise NotImplemented + + +class HTTPServer(object): + """An HTTP server.""" - ssl_certificate: the filename of the server SSL certificate. - ssl_privatekey: the filename of the server's private key file. + _bind_addr = "127.0.0.1" + _interrupt = None - If either of these is None (both are None by default), this server - will not use SSL. If both are given and are valid, they will be read - on server start and used in the SSL context for the listening socket. - """ + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create (default -1 = no limit).""" + + server_name = None + """The name of the server; defaults to socket.gethostname().""" protocol = "HTTP/1.1" - _bind_addr = "127.0.0.1" - version = "CherryPy/3.1.2" + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections (default 5).""" + + shutdown_timeout = 5 + """The total time, in seconds, to wait for worker threads to cleanly exit.""" + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = "CherryPy/3.2.1" + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. + + If None, this defaults to ``'%s Server' % self.version``.""" + ready = False - _interrupt = None + """An internal flag which marks whether the socket is accepting connections.""" + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" ConnectionClass = HTTPConnection - environ = {} + """The class to use for handling HTTP connections.""" - # Paths to certificate and private key files - ssl_certificate = None - ssl_private_key = None + ssl_adapter = None + """An instance of SSLAdapter (or a subclass). - def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, - max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): - self.requests = ThreadPool(self, min=numthreads or 1, max=max) + You must have the corresponding SSL driver library installed.""" + + def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, + server_name=None): + self.bind_addr = bind_addr + self.gateway = gateway - if callable(wsgi_app): - # We've been handed a single wsgi_app, in CP-2.1 style. - # Assume it's mounted at "". - self.wsgi_app = wsgi_app - else: - # We've been handed a list of (path_prefix, wsgi_app) tuples, - # so that the server can call different wsgi_apps, and also - # correctly set SCRIPT_NAME. - warnings.warn("The ability to pass multiple apps is deprecated " - "and will be removed in 3.2. You should explicitly " - "include a WSGIPathInfoDispatcher instead.", - DeprecationWarning) - self.wsgi_app = WSGIPathInfoDispatcher(wsgi_app) + self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) - self.bind_addr = bind_addr if not server_name: server_name = socket.gethostname() self.server_name = server_name - self.request_queue_size = request_queue_size - - self.timeout = timeout - self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def clear_stats(self): + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, "qsize", None), + 'Threads': lambda s: len(getattr(self.requests, "_threads", [])), + 'Threads Idle': lambda s: getattr(self.requests, "idle", None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum([w['Requests'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Read'](w) for w + in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum([w['Bytes Written'](w) for w + in s['Worker Threads'].values()], 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or sum([w['Work Time'](w) for w + in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats - def _get_numthreads(self): - return self.requests.min - def _set_numthreads(self, value): - self.requests.min = value - numthreads = property(_get_numthreads, _set_numthreads) + def runtime(self): + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) def __str__(self): return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, @@ -1556,6 +1752,28 @@ class CherryPyWSGIServer(object): # trap those exceptions in whatever code block calls start(). self._interrupt = None + if self.software is None: + self.software = "%s Server" % self.version + + # SSL backward compatibility + if (self.ssl_adapter is None and + getattr(self, 'ssl_certificate', None) and + getattr(self, 'ssl_private_key', None)): + warnings.warn( + "SSL attributes are deprecated in CherryPy 3.2, and will " + "be removed in CherryPy 3.3. Use an ssl_adapter attribute " + "instead.", + DeprecationWarning + ) + try: + from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter + except ImportError: + pass + else: + self.ssl_adapter = pyOpenSSLAdapter( + self.ssl_certificate, self.ssl_private_key, + getattr(self, 'ssl_certificate_chain', None)) + # Select the appropriate socket if isinstance(self.bind_addr, basestring): # AF_UNIX socket @@ -1565,7 +1783,7 @@ class CherryPyWSGIServer(object): except: pass # So everyone can access the socket... - try: os.chmod(self.bind_addr, 0777) + try: os.chmod(self.bind_addr, 511) # 0777 except: pass info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] @@ -1577,8 +1795,12 @@ class CherryPyWSGIServer(object): info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) except socket.gaierror: - # Probably a DNS issue. Assume IPv4. - info = [(socket.AF_INET, socket.SOCK_STREAM, 0, "", self.bind_addr)] + if ':' in self.bind_addr[0]: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_addr + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_addr)] self.socket = None msg = "No socket could be created" @@ -1586,14 +1808,14 @@ class CherryPyWSGIServer(object): af, socktype, proto, canonname, sa = res try: self.bind(af, socktype, proto) - except socket.error, msg: + except socket.error: if self.socket: self.socket.close() self.socket = None continue break if not self.socket: - raise socket.error, msg + raise socket.error(msg) # Timeout so KeyboardInterrupt can be caught on Win32 self.socket.settimeout(1) @@ -1603,6 +1825,7 @@ class CherryPyWSGIServer(object): self.requests.start() self.ready = True + self._start_time = time.time() while self.ready: self.tick() if self.interrupt: @@ -1617,29 +1840,22 @@ class CherryPyWSGIServer(object): self.socket = socket.socket(family, type, proto) prevent_socket_inheritance(self.socket) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if self.nodelay: + if self.nodelay and not isinstance(self.bind_addr, str): self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - if self.ssl_certificate and self.ssl_private_key: - if SSL is None: - raise ImportError("You must install pyOpenSSL to use HTTPS.") - - # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473 - ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_privatekey_file(self.ssl_private_key) - ctx.use_certificate_file(self.ssl_certificate) - self.socket = SSLConnection(ctx, self.socket) - self.populate_ssl_environ() - - # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), - # activate dual-stack. See http://www.cherrypy.org/ticket/871. - if (not isinstance(self.bind_addr, basestring) - and self.bind_addr[0] == '::' and family == socket.AF_INET6): - try: - self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except (AttributeError, socket.error): - # Apparently, the socket option is not available in - # this machine's TCP stack - pass + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See http://www.cherrypy.org/ticket/871. + if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 + and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): + try: + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass self.socket.bind(self.bind_addr) @@ -1647,42 +1863,72 @@ class CherryPyWSGIServer(object): """Accept a new connection and put it on the Queue.""" try: s, addr = self.socket.accept() - prevent_socket_inheritance(s) + if self.stats['Enabled']: + self.stats['Accepts'] += 1 if not self.ready: return + + prevent_socket_inheritance(s) if hasattr(s, 'settimeout'): s.settimeout(self.timeout) - environ = self.environ.copy() - # SERVER_SOFTWARE is common for IIS. It's also helpful for - # us to pass a default value for the "Server" response header. - if environ.get("SERVER_SOFTWARE") is None: - environ["SERVER_SOFTWARE"] = "%s WSGI Server" % self.version - # set a non-standard environ entry so the WSGI app can know what - # the *real* server protocol is (and what features to support). - # See http://www.faqs.org/rfcs/rfc2145.html. - environ["ACTUAL_SERVER_PROTOCOL"] = self.protocol - environ["SERVER_NAME"] = self.server_name + makefile = CP_fileobject + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except NoSSLError: + msg = ("The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + buf = ["%s 400 Bad Request\r\n" % self.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n\r\n", + msg] + + wfile = makefile(s, "wb", DEFAULT_BUFFER_SIZE) + try: + wfile.sendall("".join(buf)) + except socket.error: + x = sys.exc_info()[1] + if x.args[0] not in socket_errors_to_ignore: + raise + return + if not s: + return + makefile = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) - if isinstance(self.bind_addr, basestring): - # AF_UNIX. This isn't really allowed by WSGI, which doesn't - # address unix domain sockets. But it's better than nothing. - environ["SERVER_PORT"] = "" - else: - environ["SERVER_PORT"] = str(self.bind_addr[1]) + conn = self.ConnectionClass(self, s, makefile) + + if not isinstance(self.bind_addr, basestring): # optional values # Until we do DNS lookups, omit REMOTE_HOST - environ["REMOTE_ADDR"] = addr[0] - environ["REMOTE_PORT"] = str(addr[1]) + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env - conn = self.ConnectionClass(s, self.wsgi_app, environ) self.requests.put(conn) except socket.timeout: # The only reason for the timeout in start() is so we can # notice keyboard interrupts on Win32, which don't interrupt # accept() by default return - except socket.error, x: + except socket.error: + x = sys.exc_info()[1] + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 if x.args[0] in socket_error_eintr: # I *think* this is right. EINTR should occur when a signal # is received during the accept() call; all docs say retry @@ -1712,6 +1958,9 @@ class CherryPyWSGIServer(object): def stop(self): """Gracefully shutdown a server that is serving forever.""" self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None sock = getattr(self, "socket", None) if sock: @@ -1719,8 +1968,11 @@ class CherryPyWSGIServer(object): # Touch our own socket to make accept() return immediately. try: host, port = sock.getsockname()[:2] - except socket.error, x: + except socket.error: + x = sys.exc_info()[1] if x.args[0] not in socket_errors_to_ignore: + # Changed to use error code and not message + # See http://www.cherrypy.org/ticket/860. raise else: # Note that we're explicitly NOT using AI_PASSIVE, @@ -1746,49 +1998,302 @@ class CherryPyWSGIServer(object): self.socket = None self.requests.stop(self.shutdown_timeout) + + +class Gateway(object): + """A base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplemented + + +# These may either be wsgiserver.SSLAdapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter', + } + +def get_ssl_adapter_class(name='pyopenssl'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, basestring): + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter + +# -------------------------------- WSGI Stuff -------------------------------- # + + +class CherryPyWSGIServer(HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): + self.requests = ThreadPool(self, min=numthreads or 1, max=max) + self.wsgi_app = wsgi_app + self.gateway = wsgi_gateways[self.wsgi_version] + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.clear_stats() + + def _get_numthreads(self): + return self.requests.min + def _set_numthreads(self, value): + self.requests.min = value + numthreads = property(_get_numthreads, _set_numthreads) + + +class WSGIGateway(Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + self.req = req + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None - def populate_ssl_environ(self): - """Create WSGI environ entries to be merged into each request.""" - cert = open(self.ssl_certificate, 'rb').read() - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) - ssl_environ = { - "wsgi.url_scheme": "https", - "HTTPS": "on", - # pyOpenSSL doesn't provide access to any of these AFAICT -## 'SSL_PROTOCOL': 'SSLv2', -## SSL_CIPHER string The cipher specification name -## SSL_VERSION_INTERFACE string The mod_ssl program version -## SSL_VERSION_LIBRARY string The OpenSSL program version + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + raise NotImplemented + + def respond(self): + """Process the current request.""" + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." (PEP 333) + if chunk: + if isinstance(chunk, unicodestr): + chunk = chunk.encode('ISO-8859-1') + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + def start_response(self, status, headers, exc_info = None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + + self.req.status = status + for k, v in headers: + if not isinstance(k, bytestr): + raise TypeError("WSGI response header key %r is not a byte string." % k) + if not isinstance(v, bytestr): + raise TypeError("WSGI response header value %r is not a byte string." % v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + self.req.outheaders.extend(headers) + + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response("500 Internal Server Error", + "The requested resource returned more bytes than the " + "declared Content-Length.") + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. + chunk = chunk[:rbo] + + if not self.req.sent_headers: + self.req.sent_headers = True + self.req.send_headers() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + "Response body exceeds the declared Content-Length.") + + +class WSGIGateway_10(WSGIGateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': req.path, + 'QUERY_STRING': req.qs, + 'REMOTE_ADDR': req.conn.remote_addr or '', + 'REMOTE_PORT': str(req.conn.remote_port or ''), + 'REQUEST_METHOD': req.method, + 'REQUEST_URI': req.uri, + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': req.request_protocol, + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': req.scheme, + 'wsgi.version': (1, 0), } - # Server certificate attributes - ssl_environ.update({ - 'SSL_SERVER_M_VERSION': cert.get_version(), - 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), -## 'SSL_SERVER_V_START': Validity of server's certificate (start time), -## 'SSL_SERVER_V_END': Validity of server's certificate (end time), - }) - - for prefix, dn in [("I", cert.get_issuer()), - ("S", cert.get_subject())]: - # X509Name objects don't seem to have a way to get the - # complete DN string. Use str() and slice it instead, - # because str(dn) == "<X509Name object '/C=US/ST=...'>" - dnstr = str(dn)[18:-2] - - wsgikey = 'SSL_SERVER_%s_DN' % prefix - ssl_environ[wsgikey] = dnstr - - # The DN should be of the form: /k1=v1/k2=v2, but we must allow - # for any value to contain slashes itself (in a URL). - while dnstr: - pos = dnstr.rfind("=") - dnstr, value = dnstr[:pos], dnstr[pos + 1:] - pos = dnstr.rfind("/") - dnstr, key = dnstr[:pos], dnstr[pos + 1:] - if key and value: - wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) - ssl_environ[wsgikey] = value - - self.environ.update(ssl_environ) + if isinstance(req.server.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env["SERVER_PORT"] = "" + else: + env["SERVER_PORT"] = str(req.server.bind_addr[1]) + + # Request headers + for k, v in req.inheaders.iteritems(): + env["HTTP_" + k.upper().replace("-", "_")] = v + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + env["CONTENT_LENGTH"] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + +class WSGIGateway_u0(WSGIGateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys and values + in both Python 2 and Python 3. + """ + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version""" + req = self.req + env_10 = WSGIGateway_10.get_environ(self) + env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()]) + env[u'wsgi.version'] = ('u', 0) + + # Request-URI + env.setdefault(u'wsgi.url_encoding', u'utf-8') + try: + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. + env[u'wsgi.url_encoding'] = u'ISO-8859-1' + for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]: + env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding']) + + for k, v in sorted(env.items()): + if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'): + env[k] = v.decode('ISO-8859-1') + + return env + +wsgi_gateways = { + (1, 0): WSGIGateway_10, + ('u', 0): WSGIGateway_u0, +} + +class WSGIPathInfoDispatcher(object): + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0]))) + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] |