diff options
| author | 2014-06-28 16:12:44 +0200 | |
|---|---|---|
| committer | 2014-06-28 20:23:59 +0200 | |
| commit | 2c797aba90ec32fa979dce8c89789309f936ddce (patch) | |
| tree | fd08def9e9e9cd9892e8b5743fa91be4ad45b26c /module/lib/thrift/protocol | |
| parent | [Lib] Update simplejson to version 3.5.3 (diff) | |
| download | pyload-2c797aba90ec32fa979dce8c89789309f936ddce.tar.xz | |
[Lib] Update thrift to version 0.9.1
Diffstat (limited to 'module/lib/thrift/protocol')
| -rw-r--r-- | module/lib/thrift/protocol/TBase.py | 27 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TBinaryProtocol.py | 13 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TCompactProtocol.py | 34 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TJSONProtocol.py | 550 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TProtocol.py | 102 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/__init__.py | 2 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/fastbinary.c | 1219 | 
7 files changed, 1868 insertions, 79 deletions
| diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py index e675c7dc0..6cbd5f39a 100644 --- a/module/lib/thrift/protocol/TBase.py +++ b/module/lib/thrift/protocol/TBase.py @@ -26,12 +26,13 @@ try:  except:    fastbinary = None +  class TBase(object):    __slots__ = []    def __repr__(self):      L = ['%s=%r' % (key, getattr(self, key)) -              for key in self.__slots__ ] +              for key in self.__slots__]      return '%s(%s)' % (self.__class__.__name__, ', '.join(L))    def __eq__(self, other): @@ -43,30 +44,38 @@ class TBase(object):        if my_val != other_val:          return False      return True -     +    def __ne__(self, other):      return not (self == other) -   +    def read(self, iprot): -    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: -      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +    if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and +        isinstance(iprot.trans, TTransport.CReadableTransport) and +        self.thrift_spec is not None and +        fastbinary is not None): +      fastbinary.decode_binary(self, +                               iprot.trans, +                               (self.__class__, self.thrift_spec))        return      iprot.readStruct(self, self.thrift_spec)    def write(self, oprot): -    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: -      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +    if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and +        self.thrift_spec is not None and +        fastbinary is not None): +      oprot.trans.write( +        fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))        return      oprot.writeStruct(self, self.thrift_spec) +  class TExceptionBase(Exception):    # old style class so python2.4 can raise exceptions derived from this    #  This can't inherit from TBase because of that limitation.    __slots__ = [] -   +    __repr__ = TBase.__repr__.im_func    __eq__ = TBase.__eq__.im_func    __ne__ = TBase.__ne__.im_func    read = TBase.read.im_func    write = TBase.write.im_func -   diff --git a/module/lib/thrift/protocol/TBinaryProtocol.py b/module/lib/thrift/protocol/TBinaryProtocol.py index 50c6aa896..6fdd08c26 100644 --- a/module/lib/thrift/protocol/TBinaryProtocol.py +++ b/module/lib/thrift/protocol/TBinaryProtocol.py @@ -20,8 +20,8 @@  from TProtocol import *  from struct import pack, unpack -class TBinaryProtocol(TProtocolBase): +class TBinaryProtocol(TProtocolBase):    """Binary implementation of the Thrift protocol driver."""    # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be @@ -68,7 +68,7 @@ class TBinaryProtocol(TProtocolBase):      pass    def writeFieldStop(self): -    self.writeByte(TType.STOP); +    self.writeByte(TType.STOP)    def writeMapBegin(self, ktype, vtype, size):      self.writeByte(ktype) @@ -127,13 +127,16 @@ class TBinaryProtocol(TProtocolBase):      if sz < 0:        version = sz & TBinaryProtocol.VERSION_MASK        if version != TBinaryProtocol.VERSION_1: -        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) +        raise TProtocolException( +          type=TProtocolException.BAD_VERSION, +          message='Bad version in readMessageBegin: %d' % (sz))        type = sz & TBinaryProtocol.TYPE_MASK        name = self.readString()        seqid = self.readI32()      else:        if self.strictRead: -        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') +        raise TProtocolException(type=TProtocolException.BAD_VERSION, +                                 message='No protocol version header')        name = self.trans.readAll(sz)        type = self.readByte()        seqid = self.readI32() @@ -231,7 +234,6 @@ class TBinaryProtocolFactory:  class TBinaryProtocolAccelerated(TBinaryProtocol): -    """C-Accelerated version of TBinaryProtocol.    This class does not override any of TBinaryProtocol's methods, @@ -250,7 +252,6 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):           Please feel free to report bugs and/or success stories           to the public mailing list.    """ -    pass diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py index 016a33171..cdec60773 100644 --- a/module/lib/thrift/protocol/TCompactProtocol.py +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -32,6 +32,7 @@ CONTAINER_READ = 6  VALUE_READ = 7  BOOL_READ = 8 +  def make_helper(v_from, container):    def helper(func):      def nested(self, *args, **kwargs): @@ -42,12 +43,15 @@ def make_helper(v_from, container):  writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)  reader = make_helper(VALUE_READ, CONTAINER_READ) +  def makeZigZag(n, bits):    return (n << 1) ^ (n >> (bits - 1)) +  def fromZigZag(n):    return (n >> 1) ^ -(n & 1) +  def writeVarint(trans, n):    out = []    while True: @@ -59,6 +63,7 @@ def writeVarint(trans, n):        n = n >> 7    trans.write(''.join(map(chr, out))) +  def readVarint(trans):    result = 0    shift = 0 @@ -70,6 +75,7 @@ def readVarint(trans):        return result      shift += 7 +  class CompactType:    STOP = 0x00    TRUE = 0x01 @@ -86,7 +92,7 @@ class CompactType:    STRUCT = 0x0C  CTYPES = {TType.STOP: CompactType.STOP, -          TType.BOOL: CompactType.TRUE, # used for collection +          TType.BOOL: CompactType.TRUE,  # used for collection            TType.BYTE: CompactType.BYTE,            TType.I16: CompactType.I16,            TType.I32: CompactType.I32, @@ -106,8 +112,9 @@ TTYPES[CompactType.FALSE] = TType.BOOL  del k  del v +  class TCompactProtocol(TProtocolBase): -  "Compact implementation of the Thrift protocol driver." +  """Compact implementation of the Thrift protocol driver."""    PROTOCOL_ID = 0x82    VERSION = 1 @@ -217,18 +224,18 @@ class TCompactProtocol(TProtocolBase):    def writeBool(self, bool):      if self.state == BOOL_WRITE: -        if bool: -            ctype = CompactType.TRUE -        else: -            ctype = CompactType.FALSE -        self.__writeFieldHeader(ctype, self.__bool_fid) +      if bool: +        ctype = CompactType.TRUE +      else: +        ctype = CompactType.FALSE +      self.__writeFieldHeader(ctype, self.__bool_fid)      elif self.state == CONTAINER_WRITE: -       if bool: -           self.__writeByte(CompactType.TRUE) -       else: -           self.__writeByte(CompactType.FALSE) +      if bool: +        self.__writeByte(CompactType.TRUE) +      else: +        self.__writeByte(CompactType.FALSE)      else: -      raise AssertionError, "Invalid state in compact protocol" +      raise AssertionError("Invalid state in compact protocol")    writeByte = writer(__writeByte)    writeI16 = writer(__writeI16) @@ -364,7 +371,8 @@ class TCompactProtocol(TProtocolBase):      elif self.state == CONTAINER_READ:        return self.__readByte() == CompactType.TRUE      else: -      raise AssertionError, "Invalid state in compact protocol: %d" % self.state +      raise AssertionError("Invalid state in compact protocol: %d" % +                           self.state)    readByte = reader(__readByte)    __readI16 = __readZigZag diff --git a/module/lib/thrift/protocol/TJSONProtocol.py b/module/lib/thrift/protocol/TJSONProtocol.py new file mode 100644 index 000000000..3048197d4 --- /dev/null +++ b/module/lib/thrift/protocol/TJSONProtocol.py @@ -0,0 +1,550 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import TType, TProtocolBase, TProtocolException +import base64 +import json +import math + +__all__ = ['TJSONProtocol', +           'TJSONProtocolFactory', +           'TSimpleJSONProtocol', +           'TSimpleJSONProtocolFactory'] + +VERSION = 1 + +COMMA = ',' +COLON = ':' +LBRACE = '{' +RBRACE = '}' +LBRACKET = '[' +RBRACKET = ']' +QUOTE = '"' +BACKSLASH = '\\' +ZERO = '0' + +ESCSEQ = '\\u00' +ESCAPE_CHAR = '"\\bfnrt' +ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t'] +NUMERIC_CHAR = '+-.0123456789Ee' + +CTYPES = {TType.BOOL:       'tf', +          TType.BYTE:       'i8', +          TType.I16:        'i16', +          TType.I32:        'i32', +          TType.I64:        'i64', +          TType.DOUBLE:     'dbl', +          TType.STRING:     'str', +          TType.STRUCT:     'rec', +          TType.LIST:       'lst', +          TType.SET:        'set', +          TType.MAP:        'map'} + +JTYPES = {} +for key in CTYPES.keys(): +  JTYPES[CTYPES[key]] = key + + +class JSONBaseContext(object): + +  def __init__(self, protocol): +    self.protocol = protocol +    self.first = True + +  def doIO(self, function): +    pass +   +  def write(self): +    pass + +  def read(self): +    pass + +  def escapeNum(self): +    return False + +  def __str__(self): +    return self.__class__.__name__ + + +class JSONListContext(JSONBaseContext): +     +  def doIO(self, function): +    if self.first is True: +      self.first = False +    else: +      function(COMMA) + +  def write(self): +    self.doIO(self.protocol.trans.write) + +  def read(self): +    self.doIO(self.protocol.readJSONSyntaxChar) + + +class JSONPairContext(JSONBaseContext): +   +  def __init__(self, protocol): +    super(JSONPairContext, self).__init__(protocol) +    self.colon = True + +  def doIO(self, function): +    if self.first: +      self.first = False +      self.colon = True +    else: +      function(COLON if self.colon else COMMA) +      self.colon = not self.colon + +  def write(self): +    self.doIO(self.protocol.trans.write) + +  def read(self): +    self.doIO(self.protocol.readJSONSyntaxChar) + +  def escapeNum(self): +    return self.colon + +  def __str__(self): +    return '%s, colon=%s' % (self.__class__.__name__, self.colon) + + +class LookaheadReader(): +  hasData = False +  data = '' + +  def __init__(self, protocol): +    self.protocol = protocol + +  def read(self): +    if self.hasData is True: +      self.hasData = False +    else: +      self.data = self.protocol.trans.read(1) +    return self.data + +  def peek(self): +    if self.hasData is False: +      self.data = self.protocol.trans.read(1) +    self.hasData = True +    return self.data + +class TJSONProtocolBase(TProtocolBase): + +  def __init__(self, trans): +    TProtocolBase.__init__(self, trans) +    self.resetWriteContext() +    self.resetReadContext() + +  def resetWriteContext(self): +    self.context = JSONBaseContext(self) +    self.contextStack = [self.context] + +  def resetReadContext(self): +    self.resetWriteContext() +    self.reader = LookaheadReader(self) + +  def pushContext(self, ctx): +    self.contextStack.append(ctx) +    self.context = ctx + +  def popContext(self): +    self.contextStack.pop() +    if self.contextStack: +      self.context = self.contextStack[-1] +    else: +      self.context = JSONBaseContext(self) + +  def writeJSONString(self, string): +    self.context.write() +    self.trans.write(json.dumps(string)) + +  def writeJSONNumber(self, number): +    self.context.write() +    jsNumber = str(number) +    if self.context.escapeNum(): +      jsNumber = "%s%s%s" % (QUOTE, jsNumber,  QUOTE) +    self.trans.write(jsNumber) + +  def writeJSONBase64(self, binary): +    self.context.write() +    self.trans.write(QUOTE) +    self.trans.write(base64.b64encode(binary)) +    self.trans.write(QUOTE) + +  def writeJSONObjectStart(self): +    self.context.write() +    self.trans.write(LBRACE) +    self.pushContext(JSONPairContext(self)) + +  def writeJSONObjectEnd(self): +    self.popContext() +    self.trans.write(RBRACE) + +  def writeJSONArrayStart(self): +    self.context.write() +    self.trans.write(LBRACKET) +    self.pushContext(JSONListContext(self)) + +  def writeJSONArrayEnd(self): +    self.popContext() +    self.trans.write(RBRACKET) + +  def readJSONSyntaxChar(self, character): +    current = self.reader.read() +    if character != current: +      raise TProtocolException(TProtocolException.INVALID_DATA, +                               "Unexpected character: %s" % current) + +  def readJSONString(self, skipContext): +    string = [] +    if skipContext is False: +      self.context.read() +    self.readJSONSyntaxChar(QUOTE) +    while True: +      character = self.reader.read() +      if character == QUOTE: +        break +      if character == ESCSEQ[0]: +        character = self.reader.read() +        if character == ESCSEQ[1]: +          self.readJSONSyntaxChar(ZERO) +          self.readJSONSyntaxChar(ZERO) +          character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2)) +        else: +          off = ESCAPE_CHAR.find(character) +          if off == -1: +            raise TProtocolException(TProtocolException.INVALID_DATA, +                                     "Expected control char") +          character = ESCAPE_CHAR_VALS[off] +      string.append(character) +    return ''.join(string) + +  def isJSONNumeric(self, character): +    return (True if NUMERIC_CHAR.find(character) != - 1 else False) + +  def readJSONQuotes(self): +    if (self.context.escapeNum()): +      self.readJSONSyntaxChar(QUOTE) + +  def readJSONNumericChars(self): +    numeric = [] +    while True: +      character = self.reader.peek() +      if self.isJSONNumeric(character) is False: +        break +      numeric.append(self.reader.read()) +    return ''.join(numeric) + +  def readJSONInteger(self): +    self.context.read() +    self.readJSONQuotes() +    numeric = self.readJSONNumericChars() +    self.readJSONQuotes() +    try: +      return int(numeric) +    except ValueError: +      raise TProtocolException(TProtocolException.INVALID_DATA, +                               "Bad data encounted in numeric data") + +  def readJSONDouble(self): +    self.context.read() +    if self.reader.peek() == QUOTE: +      string  = self.readJSONString(True) +      try: +        double = float(string) +        if (self.context.escapeNum is False and +            not math.isinf(double) and +            not math.isnan(double)): +          raise TProtocolException(TProtocolException.INVALID_DATA, +                                   "Numeric data unexpectedly quoted") +        return double +      except ValueError: +        raise TProtocolException(TProtocolException.INVALID_DATA, +                                 "Bad data encounted in numeric data") +    else: +      if self.context.escapeNum() is True: +        self.readJSONSyntaxChar(QUOTE) +      try: +        return float(self.readJSONNumericChars()) +      except ValueError: +        raise TProtocolException(TProtocolException.INVALID_DATA, +                                 "Bad data encounted in numeric data") + +  def readJSONBase64(self): +    string = self.readJSONString(False) +    return base64.b64decode(string) + +  def readJSONObjectStart(self): +    self.context.read() +    self.readJSONSyntaxChar(LBRACE) +    self.pushContext(JSONPairContext(self)) + +  def readJSONObjectEnd(self): +    self.readJSONSyntaxChar(RBRACE) +    self.popContext() + +  def readJSONArrayStart(self): +    self.context.read() +    self.readJSONSyntaxChar(LBRACKET) +    self.pushContext(JSONListContext(self)) + +  def readJSONArrayEnd(self): +    self.readJSONSyntaxChar(RBRACKET) +    self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + +  def readMessageBegin(self): +    self.resetReadContext() +    self.readJSONArrayStart() +    if self.readJSONInteger() != VERSION: +      raise TProtocolException(TProtocolException.BAD_VERSION, +                               "Message contained bad version.") +    name = self.readJSONString(False) +    typen = self.readJSONInteger() +    seqid = self.readJSONInteger() +    return (name, typen, seqid) + +  def readMessageEnd(self): +    self.readJSONArrayEnd() + +  def readStructBegin(self): +    self.readJSONObjectStart() + +  def readStructEnd(self): +    self.readJSONObjectEnd() + +  def readFieldBegin(self): +    character = self.reader.peek() +    ttype = 0 +    id = 0 +    if character == RBRACE: +      ttype = TType.STOP +    else: +      id = self.readJSONInteger() +      self.readJSONObjectStart() +      ttype = JTYPES[self.readJSONString(False)] +    return (None, ttype, id) + +  def readFieldEnd(self): +    self.readJSONObjectEnd() + +  def readMapBegin(self): +    self.readJSONArrayStart() +    keyType = JTYPES[self.readJSONString(False)] +    valueType = JTYPES[self.readJSONString(False)] +    size = self.readJSONInteger() +    self.readJSONObjectStart() +    return (keyType, valueType, size) + +  def readMapEnd(self): +    self.readJSONObjectEnd() +    self.readJSONArrayEnd() + +  def readCollectionBegin(self): +    self.readJSONArrayStart() +    elemType = JTYPES[self.readJSONString(False)] +    size = self.readJSONInteger() +    return (elemType, size) +  readListBegin = readCollectionBegin +  readSetBegin = readCollectionBegin + +  def readCollectionEnd(self): +    self.readJSONArrayEnd() +  readSetEnd = readCollectionEnd +  readListEnd = readCollectionEnd + +  def readBool(self): +    return (False if self.readJSONInteger() == 0 else True) + +  def readNumber(self): +    return self.readJSONInteger() +  readByte = readNumber +  readI16 = readNumber +  readI32 = readNumber +  readI64 = readNumber + +  def readDouble(self): +    return self.readJSONDouble() + +  def readString(self): +    return self.readJSONString(False) + +  def readBinary(self): +    return self.readJSONBase64() + +  def writeMessageBegin(self, name, request_type, seqid): +    self.resetWriteContext() +    self.writeJSONArrayStart() +    self.writeJSONNumber(VERSION) +    self.writeJSONString(name) +    self.writeJSONNumber(request_type) +    self.writeJSONNumber(seqid) + +  def writeMessageEnd(self): +    self.writeJSONArrayEnd() + +  def writeStructBegin(self, name): +    self.writeJSONObjectStart() + +  def writeStructEnd(self): +    self.writeJSONObjectEnd() + +  def writeFieldBegin(self, name, ttype, id): +    self.writeJSONNumber(id) +    self.writeJSONObjectStart() +    self.writeJSONString(CTYPES[ttype]) + +  def writeFieldEnd(self): +    self.writeJSONObjectEnd() + +  def writeFieldStop(self): +    pass + +  def writeMapBegin(self, ktype, vtype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[ktype]) +    self.writeJSONString(CTYPES[vtype]) +    self.writeJSONNumber(size) +    self.writeJSONObjectStart() + +  def writeMapEnd(self): +    self.writeJSONObjectEnd() +    self.writeJSONArrayEnd() +     +  def writeListBegin(self, etype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[etype]) +    self.writeJSONNumber(size) +     +  def writeListEnd(self): +    self.writeJSONArrayEnd() + +  def writeSetBegin(self, etype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[etype]) +    self.writeJSONNumber(size) +     +  def writeSetEnd(self): +    self.writeJSONArrayEnd() + +  def writeBool(self, boolean): +    self.writeJSONNumber(1 if boolean is True else 0) + +  def writeInteger(self, integer): +    self.writeJSONNumber(integer) +  writeByte = writeInteger +  writeI16 = writeInteger +  writeI32 = writeInteger +  writeI64 = writeInteger + +  def writeDouble(self, dbl): +    self.writeJSONNumber(dbl) + +  def writeString(self, string): +    self.writeJSONString(string) +     +  def writeBinary(self, binary): +    self.writeJSONBase64(binary) + + +class TJSONProtocolFactory: + +  def getProtocol(self, trans): +    return TJSONProtocol(trans) + + +class TSimpleJSONProtocol(TJSONProtocolBase): +    """Simple, readable, write-only JSON protocol. +     +    Useful for interacting with scripting languages. +    """ + +    def readMessageBegin(self): +        raise NotImplementedError() +     +    def readMessageEnd(self): +        raise NotImplementedError() +     +    def readStructBegin(self): +        raise NotImplementedError() +     +    def readStructEnd(self): +        raise NotImplementedError() +     +    def writeMessageBegin(self, name, request_type, seqid): +        self.resetWriteContext() +     +    def writeMessageEnd(self): +        pass +     +    def writeStructBegin(self, name): +        self.writeJSONObjectStart() +     +    def writeStructEnd(self): +        self.writeJSONObjectEnd() +       +    def writeFieldBegin(self, name, ttype, fid): +        self.writeJSONString(name) +     +    def writeFieldEnd(self): +        pass +     +    def writeMapBegin(self, ktype, vtype, size): +        self.writeJSONObjectStart() +     +    def writeMapEnd(self): +        self.writeJSONObjectEnd() +     +    def _writeCollectionBegin(self, etype, size): +        self.writeJSONArrayStart() +     +    def _writeCollectionEnd(self): +        self.writeJSONArrayEnd() +    writeListBegin = _writeCollectionBegin +    writeListEnd = _writeCollectionEnd +    writeSetBegin = _writeCollectionBegin +    writeSetEnd = _writeCollectionEnd + +    def writeInteger(self, integer): +        self.writeJSONNumber(integer) +    writeByte = writeInteger +    writeI16 = writeInteger +    writeI32 = writeInteger +    writeI64 = writeInteger +     +    def writeBool(self, boolean): +        self.writeJSONNumber(1 if boolean is True else 0) + +    def writeDouble(self, dbl): +        self.writeJSONNumber(dbl) +     +    def writeString(self, string): +        self.writeJSONString(string) +       +    def writeBinary(self, binary): +        self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(object): + +    def getProtocol(self, trans): +        return TSimpleJSONProtocol(trans) diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py index 7338ff68a..dc2b095de 100644 --- a/module/lib/thrift/protocol/TProtocol.py +++ b/module/lib/thrift/protocol/TProtocol.py @@ -19,8 +19,8 @@  from thrift.Thrift import * -class TProtocolException(TException): +class TProtocolException(TException):    """Custom Protocol Exception class"""    UNKNOWN = 0 @@ -33,14 +33,14 @@ class TProtocolException(TException):      TException.__init__(self, message)      self.type = type -class TProtocolBase: +class TProtocolBase:    """Base class for Thrift protocol driver."""    def __init__(self, trans):      self.trans = trans -  def writeMessageBegin(self, name, type, seqid): +  def writeMessageBegin(self, name, ttype, seqid):      pass    def writeMessageEnd(self): @@ -52,7 +52,7 @@ class TProtocolBase:    def writeStructEnd(self):      pass -  def writeFieldBegin(self, name, type, id): +  def writeFieldBegin(self, name, ttype, fid):      pass    def writeFieldEnd(self): @@ -79,7 +79,7 @@ class TProtocolBase:    def writeSetEnd(self):      pass -  def writeBool(self, bool): +  def writeBool(self, bool_val):      pass    def writeByte(self, byte): @@ -97,7 +97,7 @@ class TProtocolBase:    def writeDouble(self, dub):      pass -  def writeString(self, str): +  def writeString(self, str_val):      pass    def readMessageBegin(self): @@ -157,69 +157,69 @@ class TProtocolBase:    def readString(self):      pass -  def skip(self, type): -    if type == TType.STOP: +  def skip(self, ttype): +    if ttype == TType.STOP:        return -    elif type == TType.BOOL: +    elif ttype == TType.BOOL:        self.readBool() -    elif type == TType.BYTE: +    elif ttype == TType.BYTE:        self.readByte() -    elif type == TType.I16: +    elif ttype == TType.I16:        self.readI16() -    elif type == TType.I32: +    elif ttype == TType.I32:        self.readI32() -    elif type == TType.I64: +    elif ttype == TType.I64:        self.readI64() -    elif type == TType.DOUBLE: +    elif ttype == TType.DOUBLE:        self.readDouble() -    elif type == TType.STRING: +    elif ttype == TType.STRING:        self.readString() -    elif type == TType.STRUCT: +    elif ttype == TType.STRUCT:        name = self.readStructBegin()        while True: -        (name, type, id) = self.readFieldBegin() -        if type == TType.STOP: +        (name, ttype, id) = self.readFieldBegin() +        if ttype == TType.STOP:            break -        self.skip(type) +        self.skip(ttype)          self.readFieldEnd()        self.readStructEnd() -    elif type == TType.MAP: +    elif ttype == TType.MAP:        (ktype, vtype, size) = self.readMapBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(ktype)          self.skip(vtype)        self.readMapEnd() -    elif type == TType.SET: +    elif ttype == TType.SET:        (etype, size) = self.readSetBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(etype)        self.readSetEnd() -    elif type == TType.LIST: +    elif ttype == TType.LIST:        (etype, size) = self.readListBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(etype)        self.readListEnd() -  # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) +  # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )    _TTYPE_HANDLERS = ( -       (None, None, False), # 0 == TType,STOP -       (None, None, False), # 1 == TType.VOID # TODO: handle void? -       ('readBool', 'writeBool', False), # 2 == TType.BOOL -       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08 -       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE -       (None, None, False), # 5, undefined -       ('readI16', 'writeI16', False), # 6 == TType.I16 -       (None, None, False), # 7, undefined -       ('readI32', 'writeI32', False), # 8 == TType.I32 -       (None, None, False), # 9, undefined -       ('readI64', 'writeI64', False), # 10 == TType.I64 -       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 -       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT -       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP -       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET -       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST -       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? -       (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? +       (None, None, False),  # 0 TType.STOP +       (None, None, False),  # 1 TType.VOID # TODO: handle void? +       ('readBool', 'writeBool', False),  # 2 TType.BOOL +       ('readByte',  'writeByte', False),  # 3 TType.BYTE and I08 +       ('readDouble', 'writeDouble', False),  # 4 TType.DOUBLE +       (None, None, False),  # 5 undefined +       ('readI16', 'writeI16', False),  # 6 TType.I16 +       (None, None, False),  # 7 undefined +       ('readI32', 'writeI32', False),  # 8 TType.I32 +       (None, None, False),  # 9 undefined +       ('readI64', 'writeI64', False),  # 10 TType.I64 +       ('readString', 'writeString', False),  # 11 TType.STRING and UTF7 +       ('readContainerStruct', 'writeContainerStruct', True),  # 12 *.STRUCT +       ('readContainerMap', 'writeContainerMap', True),  # 13 TType.MAP +       ('readContainerSet', 'writeContainerSet', True),  # 14 TType.SET +       ('readContainerList', 'writeContainerList', True),  # 15 TType.LIST +       (None, None, False),  # 16 TType.UTF8 # TODO: handle utf8 types? +       (None, None, False)  # 17 TType.UTF16 # TODO: handle utf16 types?        )    def readFieldByTType(self, ttype, spec): @@ -270,7 +270,7 @@ class TProtocolBase:        container_reader = self._TTYPE_HANDLERS[set_type][0]        val_reader = getattr(self, container_reader)        for idx in xrange(set_len): -        results.add(val_reader(tspec))  +        results.add(val_reader(tspec))      self.readSetEnd()      return results @@ -279,13 +279,14 @@ class TProtocolBase:      obj = obj_class()      obj.read(self)      return obj -   +    def readContainerMap(self, spec):      results = dict()      key_ttype, key_spec = spec[0], spec[1]      val_ttype, val_spec = spec[2], spec[3]      (map_ktype, map_vtype, map_len) = self.readMapBegin() -    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree +    # TODO: compare types we just decoded with thrift_spec and +    # abort/skip if types disagree      key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])      val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])      # list values are simple types @@ -298,7 +299,8 @@ class TProtocolBase:          v_val = val_reader()        else:          v_val = self.readFieldByTType(val_ttype, val_spec) -      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails +      # this raises a TypeError with unhashable keys types +      # i.e. this fails: d=dict(); d[[0,1]] = 2        results[k_val] = v_val      self.readMapEnd()      return results @@ -329,7 +331,7 @@ class TProtocolBase:    def writeContainerList(self, val, spec):      self.writeListBegin(spec[0], len(val)) -    r_handler, w_handler, is_container  = self._TTYPE_HANDLERS[spec[0]] +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]      e_writer = getattr(self, w_handler)      if not is_container:        for elem in val: @@ -398,7 +400,7 @@ class TProtocolBase:      else:        writer(val) +  class TProtocolFactory:    def getProtocol(self, trans):      pass - diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py index d53359b28..7eefb458a 100644 --- a/module/lib/thrift/protocol/__init__.py +++ b/module/lib/thrift/protocol/__init__.py @@ -17,4 +17,4 @@  # under the License.  # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol'] diff --git a/module/lib/thrift/protocol/fastbinary.c b/module/lib/thrift/protocol/fastbinary.c new file mode 100644 index 000000000..2ce56603c --- /dev/null +++ b/module/lib/thrift/protocol/fastbinary.c @@ -0,0 +1,1219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdint.h> +#ifndef _WIN32 +# include <stdbool.h> +# include <netinet/in.h> +#else +# include <WinSock2.h> +# pragma comment (lib, "ws2_32.lib") +# define BIG_ENDIAN (4321) +# define LITTLE_ENDIAN (1234) +# define BYTE_ORDER LITTLE_ENDIAN +# if defined(_MSC_VER) && _MSC_VER < 1600 +   typedef int _Bool; +#  define bool _Bool +#  define false 0  +#  define true 1 +# endif +# define inline __inline +#endif + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) +  #define __i386__ + #endif + + #ifndef BIG_ENDIAN +  #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN +  #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) +  #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused.  Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +//               permanently in the object.  (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +//               Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { +  T_STOP       = 0, +  T_VOID       = 1, +  T_BOOL       = 2, +  T_BYTE       = 3, +  T_I08        = 3, +  T_I16        = 6, +  T_I32        = 8, +  T_U64        = 9, +  T_I64        = 10, +  T_DOUBLE     = 4, +  T_STRING     = 11, +  T_UTF7       = 11, +  T_STRUCT     = 12, +  T_MAP        = 13, +  T_SET        = 14, +  T_LIST       = 15, +  T_UTF8       = 16, +  T_UTF16      = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +#  define __BYTE_ORDER BYTE_ORDER +#  define __LITTLE_ENDIAN LITTLE_ENDIAN +#  define __BIG_ENDIAN BIG_ENDIAN +# else +#  error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum.  Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +#  include <byteswap.h> +#  define ntohll(n) bswap_64(n) +#  define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +#  define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +#  define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType element_type; +  PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType ktag; +  TType vtag; +  PyObject* ktypeargs; +  PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  PyObject* klass; +  PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  int tag; +  TType type; +  PyObject* attrname; +  PyObject* typeargs; +  PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { +  PyObject* stringiobuf; +  PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { +  // error from getting the int +  if (INT_CONV_ERROR_OCCURRED(len)) { +    return false; +  } +  if (!CHECK_RANGE(len, 0, INT32_MAX)) { +    PyErr_SetString(PyExc_OverflowError, "string size out of range"); +    return false; +  } +  return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { +  long val = PyInt_AsLong(o); + +  if (INT_CONV_ERROR_OCCURRED(val)) { +    return false; +  } +  if (!CHECK_RANGE(val, min, max)) { +    PyErr_SetString(PyExc_OverflowError, "int out of range"); +    return false; +  } + +  *ret = (int32_t) val; +  return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); +    return false; +  } + +  dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { +    return false; +  } + +  dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 4) { +    PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); +    return false; +  } + +  dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { +    return false; +  } + +  dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); +  if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { +    return false; +  } + +  dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); +  dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + +  return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); +    return false; +  } + +  dest->klass = PyTuple_GET_ITEM(typeargs, 0); +  dest->spec = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + +  // i'd like to use ParseArgs here, but it seems to be a bottleneck. +  if (PyTuple_Size(spec_tuple) != 5) { +    PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); +    return false; +  } + +  dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->tag)) { +    return false; +  } + +  dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); +  if (INT_CONV_ERROR_OCCURRED(dest->type)) { +    return false; +  } + +  dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); +  dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); +  dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); +  return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { +  int8_t net = val; +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { +  int16_t net = (int16_t)htons(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { +  int32_t net = (int32_t)htonl(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { +  int64_t net = (int64_t)htonll(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { +  // Unfortunately, bitwise_cast doesn't work in C.  Bad C! +  union { +    double f; +    int64_t t; +  } transfer; +  transfer.f = dub; +  writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { +  /* +   * Refcounting Strategy: +   * +   * We assume that elements of the thrift_spec tuple are not going to be +   * mutated, so we don't ref count those at all. Other than that, we try to +   * keep a reference to all the user-created objects while we work with them. +   * output_val assumes that a reference is already held. The *caller* is +   * responsible for handling references +   */ + +  switch (type) { + +  case T_BOOL: { +    int v = PyObject_IsTrue(value); +    if (v == -1) { +      return false; +    } + +    writeByte(output, (int8_t) v); +    break; +  } +  case T_I08: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { +      return false; +    } + +    writeByte(output, (int8_t) val); +    break; +  } +  case T_I16: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { +      return false; +    } + +    writeI16(output, (int16_t) val); +    break; +  } +  case T_I32: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { +      return false; +    } + +    writeI32(output, val); +    break; +  } +  case T_I64: { +    int64_t nval = PyLong_AsLongLong(value); + +    if (INT_CONV_ERROR_OCCURRED(nval)) { +      return false; +    } + +    if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { +      PyErr_SetString(PyExc_OverflowError, "int out of range"); +      return false; +    } + +    writeI64(output, nval); +    break; +  } + +  case T_DOUBLE: { +    double nval = PyFloat_AsDouble(value); +    if (nval == -1.0 && PyErr_Occurred()) { +      return false; +    } + +    writeDouble(output, nval); +    break; +  } + +  case T_STRING: { +    Py_ssize_t len = PyString_Size(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeI32(output, (int32_t) len); +    PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    Py_ssize_t len; +    SetListTypeArgs parsedargs; +    PyObject *item; +    PyObject *iterator; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return false; +    } + +    len = PyObject_Length(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeByte(output, parsedargs.element_type); +    writeI32(output, (int32_t) len); + +    iterator =  PyObject_GetIter(value); +    if (iterator == NULL) { +      return false; +    } + +    while ((item = PyIter_Next(iterator))) { +      if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { +        Py_DECREF(item); +        Py_DECREF(iterator); +        return false; +      } +      Py_DECREF(item); +    } + +    Py_DECREF(iterator); + +    if (PyErr_Occurred()) { +      return false; +    } + +    break; +  } + +  case T_MAP: { +    PyObject *k, *v; +    Py_ssize_t pos = 0; +    Py_ssize_t len; + +    MapTypeArgs parsedargs; + +    len = PyDict_Size(value); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return false; +    } + +    writeByte(output, parsedargs.ktag); +    writeByte(output, parsedargs.vtag); +    writeI32(output, len); + +    // TODO(bmaurer): should support any mapping, not just dicts +    while (PyDict_Next(value, &pos, &k, &v)) { +      // TODO(dreiss): Think hard about whether these INCREFs actually +      //               turn any unsafe scenarios into safe scenarios. +      Py_INCREF(k); +      Py_INCREF(v); + +      if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) +          || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { +        Py_DECREF(k); +        Py_DECREF(v); +        return false; +      } +      Py_DECREF(k); +      Py_DECREF(v); +    } +    break; +  } + +  // TODO(dreiss): Consider breaking this out as a function +  //               the way we did for decode_struct. +  case T_STRUCT: { +    StructTypeArgs parsedargs; +    Py_ssize_t nspec; +    Py_ssize_t i; + +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return false; +    } + +    nspec = PyTuple_Size(parsedargs.spec); + +    if (nspec == -1) { +      return false; +    } + +    for (i = 0; i < nspec; i++) { +      StructItemSpec parsedspec; +      PyObject* spec_tuple; +      PyObject* instval = NULL; + +      spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); +      if (spec_tuple == Py_None) { +        continue; +      } + +      if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { +        return false; +      } + +      instval = PyObject_GetAttr(value, parsedspec.attrname); + +      if (!instval) { +        return false; +      } + +      if (instval == Py_None) { +        Py_DECREF(instval); +        continue; +      } + +      writeByte(output, (int8_t) parsedspec.type); +      writeI16(output, parsedspec.tag); + +      if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { +        Py_DECREF(instval); +        return false; +      } + +      Py_DECREF(instval); +    } + +    writeByte(output, (int8_t)T_STOP); +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { +  PyObject* enc_obj; +  PyObject* type_args; +  PyObject* buf; +  PyObject* ret = NULL; + +  if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { +    return NULL; +  } + +  buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); +  if (output_val(buf, enc_obj, T_STRUCT, type_args)) { +    ret = PycStringIO->cgetvalue(buf); +  } + +  Py_DECREF(buf); +  return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { +  Py_XDECREF(d->stringiobuf); +  Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { +  dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); +  if (!dest->stringiobuf) { +    return false; +  } + +  if (!PycStringIO_InputCheck(dest->stringiobuf)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting stringio input"); +    return false; +  } + +  dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + +  if(!dest->refill_callable) { +    free_decodebuf(dest); +    return false; +  } + +  if (!PyCallable_Check(dest->refill_callable)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting callable"); +    return false; +  } + +  return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { +  int read; + +  // TODO(dreiss): Don't fear the malloc.  Think about taking a copy of +  //               the partial read instead of forcing the transport +  //               to prepend it to its buffer. + +  read = PycStringIO->cread(input->stringiobuf, output, len); + +  if (read == len) { +    return true; +  } else if (read == -1) { +    return false; +  } else { +    PyObject* newiobuf; + +    // using building functions as this is a rare codepath +    newiobuf = PyObject_CallFunction( +        input->refill_callable, "s#i", *output, read, len, NULL); +    if (newiobuf == NULL) { +      return false; +    } + +    // must do this *AFTER* the call so that we don't deref the io buffer +    Py_CLEAR(input->stringiobuf); +    input->stringiobuf = newiobuf; + +    read = PycStringIO->cread(input->stringiobuf, output, len); + +    if (read == len) { +      return true; +    } else if (read == -1) { +      return false; +    } else { +      // TODO(dreiss): This could be a valid code path for big binary blobs. +      PyErr_SetString(PyExc_TypeError, +          "refill claimed to have refilled the buffer, but didn't!!"); +      return false; +    } +  } +} + +static int8_t readByte(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int8_t))) { +    return -1; +  } + +  return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int16_t))) { +    return -1; +  } + +  return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int32_t))) { +    return -1; +  } +  return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int64_t))) { +    return -1; +  } + +  return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { +  union { +    int64_t f; +    double t; +  } transfer; + +  transfer.f = readI64(input); +  if (transfer.f == -1) { +    return -1; +  } +  return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { +  TType got = readByte(input); +  if (INT_CONV_ERROR_OCCURRED(got)) { +    return false; +  } + +  if (expected != got) { +    PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); +    return false; +  } +  return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ +  do { \ +    if (!readBytes(input, &dummy_buf, (n))) { \ +      return false; \ +    } \ +  } while(0) + +  char* dummy_buf; + +  switch (type) { + +  case T_BOOL: +  case T_I08: SKIPBYTES(1); break; +  case T_I16: SKIPBYTES(2); break; +  case T_I32: SKIPBYTES(4); break; +  case T_I64: +  case T_DOUBLE: SKIPBYTES(8); break; + +  case T_STRING: { +    // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. +    int len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } +    SKIPBYTES(len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    TType etype; +    int len, i; + +    etype = readByte(input); +    if (etype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!skip(input, etype)) { +        return false; +      } +    } +    break; +  } + +  case T_MAP: { +    TType ktype, vtype; +    int len, i; + +    ktype = readByte(input); +    if (ktype == -1) { +      return false; +    } + +    vtype = readByte(input); +    if (vtype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!(skip(input, ktype) && skip(input, vtype))) { +        return false; +      } +    } +    break; +  } + +  case T_STRUCT: { +    while (true) { +      TType type; + +      type = readByte(input); +      if (type == -1) { +        return false; +      } + +      if (type == T_STOP) +        break; + +      SKIPBYTES(2); // tag +      if (!skip(input, type)) { +        return false; +      } +    } +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { +  int spec_seq_len = PyTuple_Size(spec_seq); +  if (spec_seq_len == -1) { +    return false; +  } + +  while (true) { +    TType type; +    int16_t tag; +    PyObject* item_spec; +    PyObject* fieldval = NULL; +    StructItemSpec parsedspec; + +    type = readByte(input); +    if (type == -1) { +      return false; +    } +    if (type == T_STOP) { +      break; +    } +    tag = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(tag)) { +      return false; +    } +    if (tag >= 0 && tag < spec_seq_len) { +      item_spec = PyTuple_GET_ITEM(spec_seq, tag); +    } else { +      item_spec = Py_None; +    } + +    if (item_spec == Py_None) { +      if (!skip(input, type)) { +        return false; +      } else { +        continue; +      } +    } + +    if (!parse_struct_item_spec(&parsedspec, item_spec)) { +      return false; +    } +    if (parsedspec.type != type) { +      if (!skip(input, type)) { +        PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); +        return false; +      } else { +        continue; +      } +    } + +    fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); +    if (fieldval == NULL) { +      return false; +    } + +    if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { +      Py_DECREF(fieldval); +      return false; +    } +    Py_DECREF(fieldval); +  } +  return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { +  switch (type) { + +  case T_BOOL: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    switch (v) { +    case 0: Py_RETURN_FALSE; +    case 1: Py_RETURN_TRUE; +    // Don't laugh.  This is a potentially serious issue. +    default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; +    } +    break; +  } +  case T_I08: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    return PyInt_FromLong(v); +  } +  case T_I16: { +    int16_t v = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } +  case T_I32: { +    int32_t v = readI32(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } + +  case T_I64: { +    int64_t v = readI64(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    // TODO(dreiss): Find out if we can take this fastpath always when +    //               sizeof(long) == sizeof(long long). +    if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { +      return PyInt_FromLong((long) v); +    } + +    return PyLong_FromLongLong(v); +  } + +  case T_DOUBLE: { +    double v = readDouble(input); +    if (v == -1.0 && PyErr_Occurred()) { +      return false; +    } +    return PyFloat_FromDouble(v); +  } + +  case T_STRING: { +    Py_ssize_t len = readI32(input); +    char* buf; +    if (!readBytes(input, &buf, len)) { +      return NULL; +    } + +    return PyString_FromStringAndSize(buf, len); +  } + +  case T_LIST: +  case T_SET: { +    SetListTypeArgs parsedargs; +    int32_t len; +    PyObject* ret = NULL; +    int i; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.element_type)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return NULL; +    } + +    ret = PyList_New(len); +    if (!ret) { +      return NULL; +    } + +    for (i = 0; i < len; i++) { +      PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); +      if (!item) { +        Py_DECREF(ret); +        return NULL; +      } +      PyList_SET_ITEM(ret, i, item); +    } + +    // TODO(dreiss): Consider biting the bullet and making two separate cases +    //               for list and set, avoiding this post facto conversion. +    if (type == T_SET) { +      PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) +      // hack needed for older versions +      setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else +      // official version +      setret = PySet_New(ret); +#endif +      Py_DECREF(ret); +      return setret; +    } +    return ret; +  } + +  case T_MAP: { +    int32_t len; +    int i; +    MapTypeArgs parsedargs; +    PyObject* ret = NULL; + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.ktag)) { +      return NULL; +    } +    if (!checkTypeByte(input, parsedargs.vtag)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    ret = PyDict_New(); +    if (!ret) { +      goto error; +    } + +    for (i = 0; i < len; i++) { +      PyObject* k = NULL; +      PyObject* v = NULL; +      k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); +      if (k == NULL) { +        goto loop_error; +      } +      v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); +      if (v == NULL) { +        goto loop_error; +      } +      if (PyDict_SetItem(ret, k, v) == -1) { +        goto loop_error; +      } + +      Py_DECREF(k); +      Py_DECREF(v); +      continue; + +      // Yuck!  Destructors, anyone? +      loop_error: +      Py_XDECREF(k); +      Py_XDECREF(v); +      goto error; +    } + +    return ret; + +    error: +    Py_XDECREF(ret); +    return NULL; +  } + +  case T_STRUCT: { +    StructTypeArgs parsedargs; +	PyObject* ret; +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    ret = PyObject_CallObject(parsedargs.klass, NULL); +    if (!ret) { +      return NULL; +    } + +    if (!decode_struct(input, ret, parsedargs.spec)) { +      Py_DECREF(ret); +      return NULL; +    } + +    return ret; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return NULL; +  } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { +  PyObject* output_obj = NULL; +  PyObject* transport = NULL; +  PyObject* typeargs = NULL; +  StructTypeArgs parsedargs; +  DecodeBuffer input = {0, 0}; +   +  if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { +    return NULL; +  } + +  if (!parse_struct_args(&parsedargs, typeargs)) { +    return NULL; +  } + +  if (!decode_buffer_from_obj(&input, transport)) { +    return NULL; +  } + +  if (!decode_struct(&input, output_obj, parsedargs.spec)) { +    free_decodebuf(&input); +    return NULL; +  } + +  free_decodebuf(&input); + +  Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + +  {"encode_binary",  encode_binary, METH_VARARGS, ""}, +  {"decode_binary",  decode_binary, METH_VARARGS, ""}, + +  {NULL, NULL, 0, NULL}        /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ +  do { \ +    INTERN_STRING(value) = PyString_InternFromString(#value); \ +    if(!INTERN_STRING(value)) return; \ +  } while(0) + +  INIT_INTERN_STRING(cstringio_buf); +  INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + +  PycString_IMPORT; +  if (PycStringIO == NULL) return; + +  (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} | 
