diff options
Diffstat (limited to 'module')
22 files changed, 2524 insertions, 293 deletions
diff --git a/module/lib/thrift/TSCons.py b/module/lib/thrift/TSCons.py index 24046256c..da8d2833b 100644 --- a/module/lib/thrift/TSCons.py +++ b/module/lib/thrift/TSCons.py @@ -20,14 +20,16 @@  from os import path  from SCons.Builder import Builder +  def scons_env(env, add=''):    opath = path.dirname(path.abspath('$TARGET'))    lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' -  cppbuild = Builder(action = lstr) -  env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) +  cppbuild = Builder(action=lstr) +  env.Append(BUILDERS={'ThriftCpp': cppbuild}) +  def gen_cpp(env, dir, file):    scons_env(env)    suffixes = ['_types.h', '_types.cpp']    targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) -  return env.ThriftCpp(targets, dir+file+'.thrift') +  return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/module/lib/thrift/TSerialization.py b/module/lib/thrift/TSerialization.py index b19f98aa8..8a58d89df 100644 --- a/module/lib/thrift/TSerialization.py +++ b/module/lib/thrift/TSerialization.py @@ -20,15 +20,19 @@  from protocol import TBinaryProtocol  from transport import TTransport -def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def serialize(thrift_object, +              protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):      transport = TTransport.TMemoryBuffer()      protocol = protocol_factory.getProtocol(transport)      thrift_object.write(protocol)      return transport.getvalue() -def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def deserialize(base, +                buf, +                protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):      transport = TTransport.TMemoryBuffer(buf)      protocol = protocol_factory.getProtocol(transport)      base.read(protocol)      return base - diff --git a/module/lib/thrift/TTornado.py b/module/lib/thrift/TTornado.py new file mode 100644 index 000000000..af309c3d9 --- /dev/null +++ b/module/lib/thrift/TTornado.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +import logging +import socket +import struct + +from thrift.transport import TTransport +from thrift.transport.TTransport import TTransportException + +from tornado import gen +from tornado import iostream +from tornado import netutil + + +class TTornadoStreamTransport(TTransport.TTransportBase): +    """a framed, buffered transport over a Tornado stream""" +    def __init__(self, host, port, stream=None): +        self.host = host +        self.port = port +        self.is_queuing_reads = False +        self.read_queue = [] +        self.__wbuf = StringIO() + +        # servers provide a ready-to-go stream +        self.stream = stream +        if self.stream is not None: +            self._set_close_callback() + +    # not the same number of parameters as TTransportBase.open +    def open(self, callback): +        logging.debug('socket connecting') +        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) +        self.stream = iostream.IOStream(sock) + +        def on_close_in_connect(*_): +            message = 'could not connect to {}:{}'.format(self.host, self.port) +            raise TTransportException( +                type=TTransportException.NOT_OPEN, +                message=message) +        self.stream.set_close_callback(on_close_in_connect) + +        def finish(*_): +            self._set_close_callback() +            callback() + +        self.stream.connect((self.host, self.port), callback=finish) + +    def _set_close_callback(self): +        def on_close(): +            raise TTransportException( +                type=TTransportException.END_OF_FILE, +                message='socket closed') +        self.stream.set_close_callback(self.close) + +    def close(self): +        # don't raise if we intend to close +        self.stream.set_close_callback(None) +        self.stream.close() + +    def read(self, _): +        # The generated code for Tornado shouldn't do individual reads -- only +        # frames at a time +        assert "you're doing it wrong" is True + +    @gen.engine +    def readFrame(self, callback): +        self.read_queue.append(callback) +        logging.debug('read queue: %s', self.read_queue) + +        if self.is_queuing_reads: +            # If a read is already in flight, then the while loop below should +            # pull it from self.read_queue +            return + +        self.is_queuing_reads = True +        while self.read_queue: +            next_callback = self.read_queue.pop() +            result = yield gen.Task(self._readFrameFromStream) +            next_callback(result) +        self.is_queuing_reads = False + +    @gen.engine +    def _readFrameFromStream(self, callback): +        logging.debug('_readFrameFromStream') +        frame_header = yield gen.Task(self.stream.read_bytes, 4) +        frame_length, = struct.unpack('!i', frame_header) +        logging.debug('received frame header, frame length = %i', frame_length) +        frame = yield gen.Task(self.stream.read_bytes, frame_length) +        logging.debug('received frame payload') +        callback(frame) + +    def write(self, buf): +        self.__wbuf.write(buf) + +    def flush(self, callback=None): +        wout = self.__wbuf.getvalue() +        wsz = len(wout) +        # reset wbuf before write/flush to preserve state on underlying failure +        self.__wbuf = StringIO() +        # N.B.: Doing this string concatenation is WAY cheaper than making +        # two separate calls to the underlying socket object. Socket writes in +        # Python turn out to be REALLY expensive, but it seems to do a pretty +        # good job of managing string buffer operations without excessive copies +        buf = struct.pack("!i", wsz) + wout + +        logging.debug('writing frame length = %i', wsz) +        self.stream.write(buf, callback) + + +class TTornadoServer(netutil.TCPServer): +    def __init__(self, processor, iprot_factory, oprot_factory=None, +                 *args, **kwargs): +        super(TTornadoServer, self).__init__(*args, **kwargs) + +        self._processor = processor +        self._iprot_factory = iprot_factory +        self._oprot_factory = (oprot_factory if oprot_factory is not None +                               else iprot_factory) + +    def handle_stream(self, stream, address): +        try: +            host, port = address +            trans = TTornadoStreamTransport(host=host, port=port, stream=stream) +            oprot = self._oprot_factory.getProtocol(trans) + +            def next_pass(): +                if not trans.stream.closed(): +                    self._processor.process(trans, self._iprot_factory, oprot, +                                            callback=next_pass) + +            next_pass() + +        except Exception: +            logging.exception('thrift exception in handle_stream') +            trans.close() diff --git a/module/lib/thrift/Thrift.py b/module/lib/thrift/Thrift.py index 1d271fcff..9890af7e1 100644 --- a/module/lib/thrift/Thrift.py +++ b/module/lib/thrift/Thrift.py @@ -19,6 +19,7 @@  import sys +  class TType:    STOP   = 0    VOID   = 1 @@ -38,7 +39,7 @@ class TType:    UTF8   = 16    UTF16  = 17 -  _VALUES_TO_NAMES = ( 'STOP', +  _VALUES_TO_NAMES = ('STOP',                        'VOID',                        'BOOL',                        'BYTE', @@ -48,46 +49,48 @@ class TType:                        None,                        'I32',                        None, -                       'I64', -                       'STRING', -                       'STRUCT', -                       'MAP', -                       'SET', -                       'LIST', -                       'UTF8', -                       'UTF16' ) +                     'I64', +                     'STRING', +                     'STRUCT', +                     'MAP', +                     'SET', +                     'LIST', +                     'UTF8', +                     'UTF16') +  class TMessageType: -  CALL  = 1 +  CALL = 1    REPLY = 2    EXCEPTION = 3    ONEWAY = 4 -class TProcessor: +class TProcessor:    """Base class for procsessor, which works on two streams."""    def process(iprot, oprot):      pass -class TException(Exception): +class TException(Exception):    """Base class for all thrift exceptions."""    # BaseException.message is deprecated in Python v[2.6,3.0) -  if (2,6,0) <= sys.version_info < (3,0): +  if (2, 6, 0) <= sys.version_info < (3, 0):      def _get_message(self): -	    return self._message +      return self._message +      def _set_message(self, message): -	    self._message = message +      self._message = message      message = property(_get_message, _set_message)    def __init__(self, message=None):      Exception.__init__(self, message)      self.message = message -class TApplicationException(TException): +class TApplicationException(TException):    """Application level thrift exceptions."""    UNKNOWN = 0 @@ -98,6 +101,9 @@ class TApplicationException(TException):    MISSING_RESULT = 5    INTERNAL_ERROR = 6    PROTOCOL_ERROR = 7 +  INVALID_TRANSFORM = 8 +  INVALID_PROTOCOL = 9 +  UNSUPPORTED_CLIENT_TYPE = 10    def __init__(self, type=UNKNOWN, message=None):      TException.__init__(self, message) @@ -116,6 +122,16 @@ class TApplicationException(TException):        return 'Bad sequence ID'      elif self.type == self.MISSING_RESULT:        return 'Missing result' +    elif self.type == self.INTERNAL_ERROR: +      return 'Internal error' +    elif self.type == self.PROTOCOL_ERROR: +      return 'Protocol error' +    elif self.type == self.INVALID_TRANSFORM: +      return 'Invalid transform' +    elif self.type == self.INVALID_PROTOCOL: +      return 'Invalid protocol' +    elif self.type == self.UNSUPPORTED_CLIENT_TYPE: +      return 'Unsupported client type'      else:        return 'Default (unknown) TApplicationException' @@ -127,12 +143,12 @@ class TApplicationException(TException):          break        if fid == 1:          if ftype == TType.STRING: -          self.message = iprot.readString(); +          self.message = iprot.readString()          else:            iprot.skip(ftype)        elif fid == 2:          if ftype == TType.I32: -          self.type = iprot.readI32(); +          self.type = iprot.readI32()          else:            iprot.skip(ftype)        else: @@ -142,11 +158,11 @@ class TApplicationException(TException):    def write(self, oprot):      oprot.writeStructBegin('TApplicationException') -    if self.message != None: +    if self.message is not None:        oprot.writeFieldBegin('message', TType.STRING, 1)        oprot.writeString(self.message)        oprot.writeFieldEnd() -    if self.type != None: +    if self.type is not None:        oprot.writeFieldBegin('type', TType.I32, 2)        oprot.writeI32(self.type)        oprot.writeFieldEnd() diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py index e675c7dc0..6cbd5f39a 100644 --- a/module/lib/thrift/protocol/TBase.py +++ b/module/lib/thrift/protocol/TBase.py @@ -26,12 +26,13 @@ try:  except:    fastbinary = None +  class TBase(object):    __slots__ = []    def __repr__(self):      L = ['%s=%r' % (key, getattr(self, key)) -              for key in self.__slots__ ] +              for key in self.__slots__]      return '%s(%s)' % (self.__class__.__name__, ', '.join(L))    def __eq__(self, other): @@ -43,30 +44,38 @@ class TBase(object):        if my_val != other_val:          return False      return True -     +    def __ne__(self, other):      return not (self == other) -   +    def read(self, iprot): -    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: -      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +    if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and +        isinstance(iprot.trans, TTransport.CReadableTransport) and +        self.thrift_spec is not None and +        fastbinary is not None): +      fastbinary.decode_binary(self, +                               iprot.trans, +                               (self.__class__, self.thrift_spec))        return      iprot.readStruct(self, self.thrift_spec)    def write(self, oprot): -    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: -      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +    if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and +        self.thrift_spec is not None and +        fastbinary is not None): +      oprot.trans.write( +        fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))        return      oprot.writeStruct(self, self.thrift_spec) +  class TExceptionBase(Exception):    # old style class so python2.4 can raise exceptions derived from this    #  This can't inherit from TBase because of that limitation.    __slots__ = [] -   +    __repr__ = TBase.__repr__.im_func    __eq__ = TBase.__eq__.im_func    __ne__ = TBase.__ne__.im_func    read = TBase.read.im_func    write = TBase.write.im_func -   diff --git a/module/lib/thrift/protocol/TBinaryProtocol.py b/module/lib/thrift/protocol/TBinaryProtocol.py index 50c6aa896..6fdd08c26 100644 --- a/module/lib/thrift/protocol/TBinaryProtocol.py +++ b/module/lib/thrift/protocol/TBinaryProtocol.py @@ -20,8 +20,8 @@  from TProtocol import *  from struct import pack, unpack -class TBinaryProtocol(TProtocolBase): +class TBinaryProtocol(TProtocolBase):    """Binary implementation of the Thrift protocol driver."""    # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be @@ -68,7 +68,7 @@ class TBinaryProtocol(TProtocolBase):      pass    def writeFieldStop(self): -    self.writeByte(TType.STOP); +    self.writeByte(TType.STOP)    def writeMapBegin(self, ktype, vtype, size):      self.writeByte(ktype) @@ -127,13 +127,16 @@ class TBinaryProtocol(TProtocolBase):      if sz < 0:        version = sz & TBinaryProtocol.VERSION_MASK        if version != TBinaryProtocol.VERSION_1: -        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) +        raise TProtocolException( +          type=TProtocolException.BAD_VERSION, +          message='Bad version in readMessageBegin: %d' % (sz))        type = sz & TBinaryProtocol.TYPE_MASK        name = self.readString()        seqid = self.readI32()      else:        if self.strictRead: -        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') +        raise TProtocolException(type=TProtocolException.BAD_VERSION, +                                 message='No protocol version header')        name = self.trans.readAll(sz)        type = self.readByte()        seqid = self.readI32() @@ -231,7 +234,6 @@ class TBinaryProtocolFactory:  class TBinaryProtocolAccelerated(TBinaryProtocol): -    """C-Accelerated version of TBinaryProtocol.    This class does not override any of TBinaryProtocol's methods, @@ -250,7 +252,6 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):           Please feel free to report bugs and/or success stories           to the public mailing list.    """ -    pass diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py index 016a33171..cdec60773 100644 --- a/module/lib/thrift/protocol/TCompactProtocol.py +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -32,6 +32,7 @@ CONTAINER_READ = 6  VALUE_READ = 7  BOOL_READ = 8 +  def make_helper(v_from, container):    def helper(func):      def nested(self, *args, **kwargs): @@ -42,12 +43,15 @@ def make_helper(v_from, container):  writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)  reader = make_helper(VALUE_READ, CONTAINER_READ) +  def makeZigZag(n, bits):    return (n << 1) ^ (n >> (bits - 1)) +  def fromZigZag(n):    return (n >> 1) ^ -(n & 1) +  def writeVarint(trans, n):    out = []    while True: @@ -59,6 +63,7 @@ def writeVarint(trans, n):        n = n >> 7    trans.write(''.join(map(chr, out))) +  def readVarint(trans):    result = 0    shift = 0 @@ -70,6 +75,7 @@ def readVarint(trans):        return result      shift += 7 +  class CompactType:    STOP = 0x00    TRUE = 0x01 @@ -86,7 +92,7 @@ class CompactType:    STRUCT = 0x0C  CTYPES = {TType.STOP: CompactType.STOP, -          TType.BOOL: CompactType.TRUE, # used for collection +          TType.BOOL: CompactType.TRUE,  # used for collection            TType.BYTE: CompactType.BYTE,            TType.I16: CompactType.I16,            TType.I32: CompactType.I32, @@ -106,8 +112,9 @@ TTYPES[CompactType.FALSE] = TType.BOOL  del k  del v +  class TCompactProtocol(TProtocolBase): -  "Compact implementation of the Thrift protocol driver." +  """Compact implementation of the Thrift protocol driver."""    PROTOCOL_ID = 0x82    VERSION = 1 @@ -217,18 +224,18 @@ class TCompactProtocol(TProtocolBase):    def writeBool(self, bool):      if self.state == BOOL_WRITE: -        if bool: -            ctype = CompactType.TRUE -        else: -            ctype = CompactType.FALSE -        self.__writeFieldHeader(ctype, self.__bool_fid) +      if bool: +        ctype = CompactType.TRUE +      else: +        ctype = CompactType.FALSE +      self.__writeFieldHeader(ctype, self.__bool_fid)      elif self.state == CONTAINER_WRITE: -       if bool: -           self.__writeByte(CompactType.TRUE) -       else: -           self.__writeByte(CompactType.FALSE) +      if bool: +        self.__writeByte(CompactType.TRUE) +      else: +        self.__writeByte(CompactType.FALSE)      else: -      raise AssertionError, "Invalid state in compact protocol" +      raise AssertionError("Invalid state in compact protocol")    writeByte = writer(__writeByte)    writeI16 = writer(__writeI16) @@ -364,7 +371,8 @@ class TCompactProtocol(TProtocolBase):      elif self.state == CONTAINER_READ:        return self.__readByte() == CompactType.TRUE      else: -      raise AssertionError, "Invalid state in compact protocol: %d" % self.state +      raise AssertionError("Invalid state in compact protocol: %d" % +                           self.state)    readByte = reader(__readByte)    __readI16 = __readZigZag diff --git a/module/lib/thrift/protocol/TJSONProtocol.py b/module/lib/thrift/protocol/TJSONProtocol.py new file mode 100644 index 000000000..3048197d4 --- /dev/null +++ b/module/lib/thrift/protocol/TJSONProtocol.py @@ -0,0 +1,550 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import TType, TProtocolBase, TProtocolException +import base64 +import json +import math + +__all__ = ['TJSONProtocol', +           'TJSONProtocolFactory', +           'TSimpleJSONProtocol', +           'TSimpleJSONProtocolFactory'] + +VERSION = 1 + +COMMA = ',' +COLON = ':' +LBRACE = '{' +RBRACE = '}' +LBRACKET = '[' +RBRACKET = ']' +QUOTE = '"' +BACKSLASH = '\\' +ZERO = '0' + +ESCSEQ = '\\u00' +ESCAPE_CHAR = '"\\bfnrt' +ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t'] +NUMERIC_CHAR = '+-.0123456789Ee' + +CTYPES = {TType.BOOL:       'tf', +          TType.BYTE:       'i8', +          TType.I16:        'i16', +          TType.I32:        'i32', +          TType.I64:        'i64', +          TType.DOUBLE:     'dbl', +          TType.STRING:     'str', +          TType.STRUCT:     'rec', +          TType.LIST:       'lst', +          TType.SET:        'set', +          TType.MAP:        'map'} + +JTYPES = {} +for key in CTYPES.keys(): +  JTYPES[CTYPES[key]] = key + + +class JSONBaseContext(object): + +  def __init__(self, protocol): +    self.protocol = protocol +    self.first = True + +  def doIO(self, function): +    pass +   +  def write(self): +    pass + +  def read(self): +    pass + +  def escapeNum(self): +    return False + +  def __str__(self): +    return self.__class__.__name__ + + +class JSONListContext(JSONBaseContext): +     +  def doIO(self, function): +    if self.first is True: +      self.first = False +    else: +      function(COMMA) + +  def write(self): +    self.doIO(self.protocol.trans.write) + +  def read(self): +    self.doIO(self.protocol.readJSONSyntaxChar) + + +class JSONPairContext(JSONBaseContext): +   +  def __init__(self, protocol): +    super(JSONPairContext, self).__init__(protocol) +    self.colon = True + +  def doIO(self, function): +    if self.first: +      self.first = False +      self.colon = True +    else: +      function(COLON if self.colon else COMMA) +      self.colon = not self.colon + +  def write(self): +    self.doIO(self.protocol.trans.write) + +  def read(self): +    self.doIO(self.protocol.readJSONSyntaxChar) + +  def escapeNum(self): +    return self.colon + +  def __str__(self): +    return '%s, colon=%s' % (self.__class__.__name__, self.colon) + + +class LookaheadReader(): +  hasData = False +  data = '' + +  def __init__(self, protocol): +    self.protocol = protocol + +  def read(self): +    if self.hasData is True: +      self.hasData = False +    else: +      self.data = self.protocol.trans.read(1) +    return self.data + +  def peek(self): +    if self.hasData is False: +      self.data = self.protocol.trans.read(1) +    self.hasData = True +    return self.data + +class TJSONProtocolBase(TProtocolBase): + +  def __init__(self, trans): +    TProtocolBase.__init__(self, trans) +    self.resetWriteContext() +    self.resetReadContext() + +  def resetWriteContext(self): +    self.context = JSONBaseContext(self) +    self.contextStack = [self.context] + +  def resetReadContext(self): +    self.resetWriteContext() +    self.reader = LookaheadReader(self) + +  def pushContext(self, ctx): +    self.contextStack.append(ctx) +    self.context = ctx + +  def popContext(self): +    self.contextStack.pop() +    if self.contextStack: +      self.context = self.contextStack[-1] +    else: +      self.context = JSONBaseContext(self) + +  def writeJSONString(self, string): +    self.context.write() +    self.trans.write(json.dumps(string)) + +  def writeJSONNumber(self, number): +    self.context.write() +    jsNumber = str(number) +    if self.context.escapeNum(): +      jsNumber = "%s%s%s" % (QUOTE, jsNumber,  QUOTE) +    self.trans.write(jsNumber) + +  def writeJSONBase64(self, binary): +    self.context.write() +    self.trans.write(QUOTE) +    self.trans.write(base64.b64encode(binary)) +    self.trans.write(QUOTE) + +  def writeJSONObjectStart(self): +    self.context.write() +    self.trans.write(LBRACE) +    self.pushContext(JSONPairContext(self)) + +  def writeJSONObjectEnd(self): +    self.popContext() +    self.trans.write(RBRACE) + +  def writeJSONArrayStart(self): +    self.context.write() +    self.trans.write(LBRACKET) +    self.pushContext(JSONListContext(self)) + +  def writeJSONArrayEnd(self): +    self.popContext() +    self.trans.write(RBRACKET) + +  def readJSONSyntaxChar(self, character): +    current = self.reader.read() +    if character != current: +      raise TProtocolException(TProtocolException.INVALID_DATA, +                               "Unexpected character: %s" % current) + +  def readJSONString(self, skipContext): +    string = [] +    if skipContext is False: +      self.context.read() +    self.readJSONSyntaxChar(QUOTE) +    while True: +      character = self.reader.read() +      if character == QUOTE: +        break +      if character == ESCSEQ[0]: +        character = self.reader.read() +        if character == ESCSEQ[1]: +          self.readJSONSyntaxChar(ZERO) +          self.readJSONSyntaxChar(ZERO) +          character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2)) +        else: +          off = ESCAPE_CHAR.find(character) +          if off == -1: +            raise TProtocolException(TProtocolException.INVALID_DATA, +                                     "Expected control char") +          character = ESCAPE_CHAR_VALS[off] +      string.append(character) +    return ''.join(string) + +  def isJSONNumeric(self, character): +    return (True if NUMERIC_CHAR.find(character) != - 1 else False) + +  def readJSONQuotes(self): +    if (self.context.escapeNum()): +      self.readJSONSyntaxChar(QUOTE) + +  def readJSONNumericChars(self): +    numeric = [] +    while True: +      character = self.reader.peek() +      if self.isJSONNumeric(character) is False: +        break +      numeric.append(self.reader.read()) +    return ''.join(numeric) + +  def readJSONInteger(self): +    self.context.read() +    self.readJSONQuotes() +    numeric = self.readJSONNumericChars() +    self.readJSONQuotes() +    try: +      return int(numeric) +    except ValueError: +      raise TProtocolException(TProtocolException.INVALID_DATA, +                               "Bad data encounted in numeric data") + +  def readJSONDouble(self): +    self.context.read() +    if self.reader.peek() == QUOTE: +      string  = self.readJSONString(True) +      try: +        double = float(string) +        if (self.context.escapeNum is False and +            not math.isinf(double) and +            not math.isnan(double)): +          raise TProtocolException(TProtocolException.INVALID_DATA, +                                   "Numeric data unexpectedly quoted") +        return double +      except ValueError: +        raise TProtocolException(TProtocolException.INVALID_DATA, +                                 "Bad data encounted in numeric data") +    else: +      if self.context.escapeNum() is True: +        self.readJSONSyntaxChar(QUOTE) +      try: +        return float(self.readJSONNumericChars()) +      except ValueError: +        raise TProtocolException(TProtocolException.INVALID_DATA, +                                 "Bad data encounted in numeric data") + +  def readJSONBase64(self): +    string = self.readJSONString(False) +    return base64.b64decode(string) + +  def readJSONObjectStart(self): +    self.context.read() +    self.readJSONSyntaxChar(LBRACE) +    self.pushContext(JSONPairContext(self)) + +  def readJSONObjectEnd(self): +    self.readJSONSyntaxChar(RBRACE) +    self.popContext() + +  def readJSONArrayStart(self): +    self.context.read() +    self.readJSONSyntaxChar(LBRACKET) +    self.pushContext(JSONListContext(self)) + +  def readJSONArrayEnd(self): +    self.readJSONSyntaxChar(RBRACKET) +    self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + +  def readMessageBegin(self): +    self.resetReadContext() +    self.readJSONArrayStart() +    if self.readJSONInteger() != VERSION: +      raise TProtocolException(TProtocolException.BAD_VERSION, +                               "Message contained bad version.") +    name = self.readJSONString(False) +    typen = self.readJSONInteger() +    seqid = self.readJSONInteger() +    return (name, typen, seqid) + +  def readMessageEnd(self): +    self.readJSONArrayEnd() + +  def readStructBegin(self): +    self.readJSONObjectStart() + +  def readStructEnd(self): +    self.readJSONObjectEnd() + +  def readFieldBegin(self): +    character = self.reader.peek() +    ttype = 0 +    id = 0 +    if character == RBRACE: +      ttype = TType.STOP +    else: +      id = self.readJSONInteger() +      self.readJSONObjectStart() +      ttype = JTYPES[self.readJSONString(False)] +    return (None, ttype, id) + +  def readFieldEnd(self): +    self.readJSONObjectEnd() + +  def readMapBegin(self): +    self.readJSONArrayStart() +    keyType = JTYPES[self.readJSONString(False)] +    valueType = JTYPES[self.readJSONString(False)] +    size = self.readJSONInteger() +    self.readJSONObjectStart() +    return (keyType, valueType, size) + +  def readMapEnd(self): +    self.readJSONObjectEnd() +    self.readJSONArrayEnd() + +  def readCollectionBegin(self): +    self.readJSONArrayStart() +    elemType = JTYPES[self.readJSONString(False)] +    size = self.readJSONInteger() +    return (elemType, size) +  readListBegin = readCollectionBegin +  readSetBegin = readCollectionBegin + +  def readCollectionEnd(self): +    self.readJSONArrayEnd() +  readSetEnd = readCollectionEnd +  readListEnd = readCollectionEnd + +  def readBool(self): +    return (False if self.readJSONInteger() == 0 else True) + +  def readNumber(self): +    return self.readJSONInteger() +  readByte = readNumber +  readI16 = readNumber +  readI32 = readNumber +  readI64 = readNumber + +  def readDouble(self): +    return self.readJSONDouble() + +  def readString(self): +    return self.readJSONString(False) + +  def readBinary(self): +    return self.readJSONBase64() + +  def writeMessageBegin(self, name, request_type, seqid): +    self.resetWriteContext() +    self.writeJSONArrayStart() +    self.writeJSONNumber(VERSION) +    self.writeJSONString(name) +    self.writeJSONNumber(request_type) +    self.writeJSONNumber(seqid) + +  def writeMessageEnd(self): +    self.writeJSONArrayEnd() + +  def writeStructBegin(self, name): +    self.writeJSONObjectStart() + +  def writeStructEnd(self): +    self.writeJSONObjectEnd() + +  def writeFieldBegin(self, name, ttype, id): +    self.writeJSONNumber(id) +    self.writeJSONObjectStart() +    self.writeJSONString(CTYPES[ttype]) + +  def writeFieldEnd(self): +    self.writeJSONObjectEnd() + +  def writeFieldStop(self): +    pass + +  def writeMapBegin(self, ktype, vtype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[ktype]) +    self.writeJSONString(CTYPES[vtype]) +    self.writeJSONNumber(size) +    self.writeJSONObjectStart() + +  def writeMapEnd(self): +    self.writeJSONObjectEnd() +    self.writeJSONArrayEnd() +     +  def writeListBegin(self, etype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[etype]) +    self.writeJSONNumber(size) +     +  def writeListEnd(self): +    self.writeJSONArrayEnd() + +  def writeSetBegin(self, etype, size): +    self.writeJSONArrayStart() +    self.writeJSONString(CTYPES[etype]) +    self.writeJSONNumber(size) +     +  def writeSetEnd(self): +    self.writeJSONArrayEnd() + +  def writeBool(self, boolean): +    self.writeJSONNumber(1 if boolean is True else 0) + +  def writeInteger(self, integer): +    self.writeJSONNumber(integer) +  writeByte = writeInteger +  writeI16 = writeInteger +  writeI32 = writeInteger +  writeI64 = writeInteger + +  def writeDouble(self, dbl): +    self.writeJSONNumber(dbl) + +  def writeString(self, string): +    self.writeJSONString(string) +     +  def writeBinary(self, binary): +    self.writeJSONBase64(binary) + + +class TJSONProtocolFactory: + +  def getProtocol(self, trans): +    return TJSONProtocol(trans) + + +class TSimpleJSONProtocol(TJSONProtocolBase): +    """Simple, readable, write-only JSON protocol. +     +    Useful for interacting with scripting languages. +    """ + +    def readMessageBegin(self): +        raise NotImplementedError() +     +    def readMessageEnd(self): +        raise NotImplementedError() +     +    def readStructBegin(self): +        raise NotImplementedError() +     +    def readStructEnd(self): +        raise NotImplementedError() +     +    def writeMessageBegin(self, name, request_type, seqid): +        self.resetWriteContext() +     +    def writeMessageEnd(self): +        pass +     +    def writeStructBegin(self, name): +        self.writeJSONObjectStart() +     +    def writeStructEnd(self): +        self.writeJSONObjectEnd() +       +    def writeFieldBegin(self, name, ttype, fid): +        self.writeJSONString(name) +     +    def writeFieldEnd(self): +        pass +     +    def writeMapBegin(self, ktype, vtype, size): +        self.writeJSONObjectStart() +     +    def writeMapEnd(self): +        self.writeJSONObjectEnd() +     +    def _writeCollectionBegin(self, etype, size): +        self.writeJSONArrayStart() +     +    def _writeCollectionEnd(self): +        self.writeJSONArrayEnd() +    writeListBegin = _writeCollectionBegin +    writeListEnd = _writeCollectionEnd +    writeSetBegin = _writeCollectionBegin +    writeSetEnd = _writeCollectionEnd + +    def writeInteger(self, integer): +        self.writeJSONNumber(integer) +    writeByte = writeInteger +    writeI16 = writeInteger +    writeI32 = writeInteger +    writeI64 = writeInteger +     +    def writeBool(self, boolean): +        self.writeJSONNumber(1 if boolean is True else 0) + +    def writeDouble(self, dbl): +        self.writeJSONNumber(dbl) +     +    def writeString(self, string): +        self.writeJSONString(string) +       +    def writeBinary(self, binary): +        self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(object): + +    def getProtocol(self, trans): +        return TSimpleJSONProtocol(trans) diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py index 7338ff68a..dc2b095de 100644 --- a/module/lib/thrift/protocol/TProtocol.py +++ b/module/lib/thrift/protocol/TProtocol.py @@ -19,8 +19,8 @@  from thrift.Thrift import * -class TProtocolException(TException): +class TProtocolException(TException):    """Custom Protocol Exception class"""    UNKNOWN = 0 @@ -33,14 +33,14 @@ class TProtocolException(TException):      TException.__init__(self, message)      self.type = type -class TProtocolBase: +class TProtocolBase:    """Base class for Thrift protocol driver."""    def __init__(self, trans):      self.trans = trans -  def writeMessageBegin(self, name, type, seqid): +  def writeMessageBegin(self, name, ttype, seqid):      pass    def writeMessageEnd(self): @@ -52,7 +52,7 @@ class TProtocolBase:    def writeStructEnd(self):      pass -  def writeFieldBegin(self, name, type, id): +  def writeFieldBegin(self, name, ttype, fid):      pass    def writeFieldEnd(self): @@ -79,7 +79,7 @@ class TProtocolBase:    def writeSetEnd(self):      pass -  def writeBool(self, bool): +  def writeBool(self, bool_val):      pass    def writeByte(self, byte): @@ -97,7 +97,7 @@ class TProtocolBase:    def writeDouble(self, dub):      pass -  def writeString(self, str): +  def writeString(self, str_val):      pass    def readMessageBegin(self): @@ -157,69 +157,69 @@ class TProtocolBase:    def readString(self):      pass -  def skip(self, type): -    if type == TType.STOP: +  def skip(self, ttype): +    if ttype == TType.STOP:        return -    elif type == TType.BOOL: +    elif ttype == TType.BOOL:        self.readBool() -    elif type == TType.BYTE: +    elif ttype == TType.BYTE:        self.readByte() -    elif type == TType.I16: +    elif ttype == TType.I16:        self.readI16() -    elif type == TType.I32: +    elif ttype == TType.I32:        self.readI32() -    elif type == TType.I64: +    elif ttype == TType.I64:        self.readI64() -    elif type == TType.DOUBLE: +    elif ttype == TType.DOUBLE:        self.readDouble() -    elif type == TType.STRING: +    elif ttype == TType.STRING:        self.readString() -    elif type == TType.STRUCT: +    elif ttype == TType.STRUCT:        name = self.readStructBegin()        while True: -        (name, type, id) = self.readFieldBegin() -        if type == TType.STOP: +        (name, ttype, id) = self.readFieldBegin() +        if ttype == TType.STOP:            break -        self.skip(type) +        self.skip(ttype)          self.readFieldEnd()        self.readStructEnd() -    elif type == TType.MAP: +    elif ttype == TType.MAP:        (ktype, vtype, size) = self.readMapBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(ktype)          self.skip(vtype)        self.readMapEnd() -    elif type == TType.SET: +    elif ttype == TType.SET:        (etype, size) = self.readSetBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(etype)        self.readSetEnd() -    elif type == TType.LIST: +    elif ttype == TType.LIST:        (etype, size) = self.readListBegin() -      for i in range(size): +      for i in xrange(size):          self.skip(etype)        self.readListEnd() -  # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) +  # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )    _TTYPE_HANDLERS = ( -       (None, None, False), # 0 == TType,STOP -       (None, None, False), # 1 == TType.VOID # TODO: handle void? -       ('readBool', 'writeBool', False), # 2 == TType.BOOL -       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08 -       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE -       (None, None, False), # 5, undefined -       ('readI16', 'writeI16', False), # 6 == TType.I16 -       (None, None, False), # 7, undefined -       ('readI32', 'writeI32', False), # 8 == TType.I32 -       (None, None, False), # 9, undefined -       ('readI64', 'writeI64', False), # 10 == TType.I64 -       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 -       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT -       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP -       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET -       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST -       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? -       (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? +       (None, None, False),  # 0 TType.STOP +       (None, None, False),  # 1 TType.VOID # TODO: handle void? +       ('readBool', 'writeBool', False),  # 2 TType.BOOL +       ('readByte',  'writeByte', False),  # 3 TType.BYTE and I08 +       ('readDouble', 'writeDouble', False),  # 4 TType.DOUBLE +       (None, None, False),  # 5 undefined +       ('readI16', 'writeI16', False),  # 6 TType.I16 +       (None, None, False),  # 7 undefined +       ('readI32', 'writeI32', False),  # 8 TType.I32 +       (None, None, False),  # 9 undefined +       ('readI64', 'writeI64', False),  # 10 TType.I64 +       ('readString', 'writeString', False),  # 11 TType.STRING and UTF7 +       ('readContainerStruct', 'writeContainerStruct', True),  # 12 *.STRUCT +       ('readContainerMap', 'writeContainerMap', True),  # 13 TType.MAP +       ('readContainerSet', 'writeContainerSet', True),  # 14 TType.SET +       ('readContainerList', 'writeContainerList', True),  # 15 TType.LIST +       (None, None, False),  # 16 TType.UTF8 # TODO: handle utf8 types? +       (None, None, False)  # 17 TType.UTF16 # TODO: handle utf16 types?        )    def readFieldByTType(self, ttype, spec): @@ -270,7 +270,7 @@ class TProtocolBase:        container_reader = self._TTYPE_HANDLERS[set_type][0]        val_reader = getattr(self, container_reader)        for idx in xrange(set_len): -        results.add(val_reader(tspec))  +        results.add(val_reader(tspec))      self.readSetEnd()      return results @@ -279,13 +279,14 @@ class TProtocolBase:      obj = obj_class()      obj.read(self)      return obj -   +    def readContainerMap(self, spec):      results = dict()      key_ttype, key_spec = spec[0], spec[1]      val_ttype, val_spec = spec[2], spec[3]      (map_ktype, map_vtype, map_len) = self.readMapBegin() -    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree +    # TODO: compare types we just decoded with thrift_spec and +    # abort/skip if types disagree      key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])      val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])      # list values are simple types @@ -298,7 +299,8 @@ class TProtocolBase:          v_val = val_reader()        else:          v_val = self.readFieldByTType(val_ttype, val_spec) -      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails +      # this raises a TypeError with unhashable keys types +      # i.e. this fails: d=dict(); d[[0,1]] = 2        results[k_val] = v_val      self.readMapEnd()      return results @@ -329,7 +331,7 @@ class TProtocolBase:    def writeContainerList(self, val, spec):      self.writeListBegin(spec[0], len(val)) -    r_handler, w_handler, is_container  = self._TTYPE_HANDLERS[spec[0]] +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]      e_writer = getattr(self, w_handler)      if not is_container:        for elem in val: @@ -398,7 +400,7 @@ class TProtocolBase:      else:        writer(val) +  class TProtocolFactory:    def getProtocol(self, trans):      pass - diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py index d53359b28..7eefb458a 100644 --- a/module/lib/thrift/protocol/__init__.py +++ b/module/lib/thrift/protocol/__init__.py @@ -17,4 +17,4 @@  # under the License.  # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol'] diff --git a/module/lib/thrift/protocol/fastbinary.c b/module/lib/thrift/protocol/fastbinary.c new file mode 100644 index 000000000..2ce56603c --- /dev/null +++ b/module/lib/thrift/protocol/fastbinary.c @@ -0,0 +1,1219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdint.h> +#ifndef _WIN32 +# include <stdbool.h> +# include <netinet/in.h> +#else +# include <WinSock2.h> +# pragma comment (lib, "ws2_32.lib") +# define BIG_ENDIAN (4321) +# define LITTLE_ENDIAN (1234) +# define BYTE_ORDER LITTLE_ENDIAN +# if defined(_MSC_VER) && _MSC_VER < 1600 +   typedef int _Bool; +#  define bool _Bool +#  define false 0  +#  define true 1 +# endif +# define inline __inline +#endif + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) +  #define __i386__ + #endif + + #ifndef BIG_ENDIAN +  #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN +  #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) +  #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused.  Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +//               permanently in the object.  (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +//               Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { +  T_STOP       = 0, +  T_VOID       = 1, +  T_BOOL       = 2, +  T_BYTE       = 3, +  T_I08        = 3, +  T_I16        = 6, +  T_I32        = 8, +  T_U64        = 9, +  T_I64        = 10, +  T_DOUBLE     = 4, +  T_STRING     = 11, +  T_UTF7       = 11, +  T_STRUCT     = 12, +  T_MAP        = 13, +  T_SET        = 14, +  T_LIST       = 15, +  T_UTF8       = 16, +  T_UTF16      = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +#  define __BYTE_ORDER BYTE_ORDER +#  define __LITTLE_ENDIAN LITTLE_ENDIAN +#  define __BIG_ENDIAN BIG_ENDIAN +# else +#  error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum.  Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +#  include <byteswap.h> +#  define ntohll(n) bswap_64(n) +#  define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +#  define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +#  define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType element_type; +  PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType ktag; +  TType vtag; +  PyObject* ktypeargs; +  PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  PyObject* klass; +  PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  int tag; +  TType type; +  PyObject* attrname; +  PyObject* typeargs; +  PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { +  PyObject* stringiobuf; +  PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { +  // error from getting the int +  if (INT_CONV_ERROR_OCCURRED(len)) { +    return false; +  } +  if (!CHECK_RANGE(len, 0, INT32_MAX)) { +    PyErr_SetString(PyExc_OverflowError, "string size out of range"); +    return false; +  } +  return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { +  long val = PyInt_AsLong(o); + +  if (INT_CONV_ERROR_OCCURRED(val)) { +    return false; +  } +  if (!CHECK_RANGE(val, min, max)) { +    PyErr_SetString(PyExc_OverflowError, "int out of range"); +    return false; +  } + +  *ret = (int32_t) val; +  return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); +    return false; +  } + +  dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { +    return false; +  } + +  dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 4) { +    PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); +    return false; +  } + +  dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { +    return false; +  } + +  dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); +  if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { +    return false; +  } + +  dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); +  dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + +  return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); +    return false; +  } + +  dest->klass = PyTuple_GET_ITEM(typeargs, 0); +  dest->spec = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + +  // i'd like to use ParseArgs here, but it seems to be a bottleneck. +  if (PyTuple_Size(spec_tuple) != 5) { +    PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); +    return false; +  } + +  dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->tag)) { +    return false; +  } + +  dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); +  if (INT_CONV_ERROR_OCCURRED(dest->type)) { +    return false; +  } + +  dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); +  dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); +  dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); +  return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { +  int8_t net = val; +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { +  int16_t net = (int16_t)htons(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { +  int32_t net = (int32_t)htonl(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { +  int64_t net = (int64_t)htonll(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { +  // Unfortunately, bitwise_cast doesn't work in C.  Bad C! +  union { +    double f; +    int64_t t; +  } transfer; +  transfer.f = dub; +  writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { +  /* +   * Refcounting Strategy: +   * +   * We assume that elements of the thrift_spec tuple are not going to be +   * mutated, so we don't ref count those at all. Other than that, we try to +   * keep a reference to all the user-created objects while we work with them. +   * output_val assumes that a reference is already held. The *caller* is +   * responsible for handling references +   */ + +  switch (type) { + +  case T_BOOL: { +    int v = PyObject_IsTrue(value); +    if (v == -1) { +      return false; +    } + +    writeByte(output, (int8_t) v); +    break; +  } +  case T_I08: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { +      return false; +    } + +    writeByte(output, (int8_t) val); +    break; +  } +  case T_I16: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { +      return false; +    } + +    writeI16(output, (int16_t) val); +    break; +  } +  case T_I32: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { +      return false; +    } + +    writeI32(output, val); +    break; +  } +  case T_I64: { +    int64_t nval = PyLong_AsLongLong(value); + +    if (INT_CONV_ERROR_OCCURRED(nval)) { +      return false; +    } + +    if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { +      PyErr_SetString(PyExc_OverflowError, "int out of range"); +      return false; +    } + +    writeI64(output, nval); +    break; +  } + +  case T_DOUBLE: { +    double nval = PyFloat_AsDouble(value); +    if (nval == -1.0 && PyErr_Occurred()) { +      return false; +    } + +    writeDouble(output, nval); +    break; +  } + +  case T_STRING: { +    Py_ssize_t len = PyString_Size(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeI32(output, (int32_t) len); +    PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    Py_ssize_t len; +    SetListTypeArgs parsedargs; +    PyObject *item; +    PyObject *iterator; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return false; +    } + +    len = PyObject_Length(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeByte(output, parsedargs.element_type); +    writeI32(output, (int32_t) len); + +    iterator =  PyObject_GetIter(value); +    if (iterator == NULL) { +      return false; +    } + +    while ((item = PyIter_Next(iterator))) { +      if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { +        Py_DECREF(item); +        Py_DECREF(iterator); +        return false; +      } +      Py_DECREF(item); +    } + +    Py_DECREF(iterator); + +    if (PyErr_Occurred()) { +      return false; +    } + +    break; +  } + +  case T_MAP: { +    PyObject *k, *v; +    Py_ssize_t pos = 0; +    Py_ssize_t len; + +    MapTypeArgs parsedargs; + +    len = PyDict_Size(value); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return false; +    } + +    writeByte(output, parsedargs.ktag); +    writeByte(output, parsedargs.vtag); +    writeI32(output, len); + +    // TODO(bmaurer): should support any mapping, not just dicts +    while (PyDict_Next(value, &pos, &k, &v)) { +      // TODO(dreiss): Think hard about whether these INCREFs actually +      //               turn any unsafe scenarios into safe scenarios. +      Py_INCREF(k); +      Py_INCREF(v); + +      if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) +          || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { +        Py_DECREF(k); +        Py_DECREF(v); +        return false; +      } +      Py_DECREF(k); +      Py_DECREF(v); +    } +    break; +  } + +  // TODO(dreiss): Consider breaking this out as a function +  //               the way we did for decode_struct. +  case T_STRUCT: { +    StructTypeArgs parsedargs; +    Py_ssize_t nspec; +    Py_ssize_t i; + +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return false; +    } + +    nspec = PyTuple_Size(parsedargs.spec); + +    if (nspec == -1) { +      return false; +    } + +    for (i = 0; i < nspec; i++) { +      StructItemSpec parsedspec; +      PyObject* spec_tuple; +      PyObject* instval = NULL; + +      spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); +      if (spec_tuple == Py_None) { +        continue; +      } + +      if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { +        return false; +      } + +      instval = PyObject_GetAttr(value, parsedspec.attrname); + +      if (!instval) { +        return false; +      } + +      if (instval == Py_None) { +        Py_DECREF(instval); +        continue; +      } + +      writeByte(output, (int8_t) parsedspec.type); +      writeI16(output, parsedspec.tag); + +      if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { +        Py_DECREF(instval); +        return false; +      } + +      Py_DECREF(instval); +    } + +    writeByte(output, (int8_t)T_STOP); +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { +  PyObject* enc_obj; +  PyObject* type_args; +  PyObject* buf; +  PyObject* ret = NULL; + +  if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { +    return NULL; +  } + +  buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); +  if (output_val(buf, enc_obj, T_STRUCT, type_args)) { +    ret = PycStringIO->cgetvalue(buf); +  } + +  Py_DECREF(buf); +  return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { +  Py_XDECREF(d->stringiobuf); +  Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { +  dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); +  if (!dest->stringiobuf) { +    return false; +  } + +  if (!PycStringIO_InputCheck(dest->stringiobuf)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting stringio input"); +    return false; +  } + +  dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + +  if(!dest->refill_callable) { +    free_decodebuf(dest); +    return false; +  } + +  if (!PyCallable_Check(dest->refill_callable)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting callable"); +    return false; +  } + +  return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { +  int read; + +  // TODO(dreiss): Don't fear the malloc.  Think about taking a copy of +  //               the partial read instead of forcing the transport +  //               to prepend it to its buffer. + +  read = PycStringIO->cread(input->stringiobuf, output, len); + +  if (read == len) { +    return true; +  } else if (read == -1) { +    return false; +  } else { +    PyObject* newiobuf; + +    // using building functions as this is a rare codepath +    newiobuf = PyObject_CallFunction( +        input->refill_callable, "s#i", *output, read, len, NULL); +    if (newiobuf == NULL) { +      return false; +    } + +    // must do this *AFTER* the call so that we don't deref the io buffer +    Py_CLEAR(input->stringiobuf); +    input->stringiobuf = newiobuf; + +    read = PycStringIO->cread(input->stringiobuf, output, len); + +    if (read == len) { +      return true; +    } else if (read == -1) { +      return false; +    } else { +      // TODO(dreiss): This could be a valid code path for big binary blobs. +      PyErr_SetString(PyExc_TypeError, +          "refill claimed to have refilled the buffer, but didn't!!"); +      return false; +    } +  } +} + +static int8_t readByte(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int8_t))) { +    return -1; +  } + +  return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int16_t))) { +    return -1; +  } + +  return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int32_t))) { +    return -1; +  } +  return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int64_t))) { +    return -1; +  } + +  return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { +  union { +    int64_t f; +    double t; +  } transfer; + +  transfer.f = readI64(input); +  if (transfer.f == -1) { +    return -1; +  } +  return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { +  TType got = readByte(input); +  if (INT_CONV_ERROR_OCCURRED(got)) { +    return false; +  } + +  if (expected != got) { +    PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); +    return false; +  } +  return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ +  do { \ +    if (!readBytes(input, &dummy_buf, (n))) { \ +      return false; \ +    } \ +  } while(0) + +  char* dummy_buf; + +  switch (type) { + +  case T_BOOL: +  case T_I08: SKIPBYTES(1); break; +  case T_I16: SKIPBYTES(2); break; +  case T_I32: SKIPBYTES(4); break; +  case T_I64: +  case T_DOUBLE: SKIPBYTES(8); break; + +  case T_STRING: { +    // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. +    int len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } +    SKIPBYTES(len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    TType etype; +    int len, i; + +    etype = readByte(input); +    if (etype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!skip(input, etype)) { +        return false; +      } +    } +    break; +  } + +  case T_MAP: { +    TType ktype, vtype; +    int len, i; + +    ktype = readByte(input); +    if (ktype == -1) { +      return false; +    } + +    vtype = readByte(input); +    if (vtype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!(skip(input, ktype) && skip(input, vtype))) { +        return false; +      } +    } +    break; +  } + +  case T_STRUCT: { +    while (true) { +      TType type; + +      type = readByte(input); +      if (type == -1) { +        return false; +      } + +      if (type == T_STOP) +        break; + +      SKIPBYTES(2); // tag +      if (!skip(input, type)) { +        return false; +      } +    } +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { +  int spec_seq_len = PyTuple_Size(spec_seq); +  if (spec_seq_len == -1) { +    return false; +  } + +  while (true) { +    TType type; +    int16_t tag; +    PyObject* item_spec; +    PyObject* fieldval = NULL; +    StructItemSpec parsedspec; + +    type = readByte(input); +    if (type == -1) { +      return false; +    } +    if (type == T_STOP) { +      break; +    } +    tag = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(tag)) { +      return false; +    } +    if (tag >= 0 && tag < spec_seq_len) { +      item_spec = PyTuple_GET_ITEM(spec_seq, tag); +    } else { +      item_spec = Py_None; +    } + +    if (item_spec == Py_None) { +      if (!skip(input, type)) { +        return false; +      } else { +        continue; +      } +    } + +    if (!parse_struct_item_spec(&parsedspec, item_spec)) { +      return false; +    } +    if (parsedspec.type != type) { +      if (!skip(input, type)) { +        PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); +        return false; +      } else { +        continue; +      } +    } + +    fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); +    if (fieldval == NULL) { +      return false; +    } + +    if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { +      Py_DECREF(fieldval); +      return false; +    } +    Py_DECREF(fieldval); +  } +  return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { +  switch (type) { + +  case T_BOOL: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    switch (v) { +    case 0: Py_RETURN_FALSE; +    case 1: Py_RETURN_TRUE; +    // Don't laugh.  This is a potentially serious issue. +    default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; +    } +    break; +  } +  case T_I08: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    return PyInt_FromLong(v); +  } +  case T_I16: { +    int16_t v = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } +  case T_I32: { +    int32_t v = readI32(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } + +  case T_I64: { +    int64_t v = readI64(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    // TODO(dreiss): Find out if we can take this fastpath always when +    //               sizeof(long) == sizeof(long long). +    if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { +      return PyInt_FromLong((long) v); +    } + +    return PyLong_FromLongLong(v); +  } + +  case T_DOUBLE: { +    double v = readDouble(input); +    if (v == -1.0 && PyErr_Occurred()) { +      return false; +    } +    return PyFloat_FromDouble(v); +  } + +  case T_STRING: { +    Py_ssize_t len = readI32(input); +    char* buf; +    if (!readBytes(input, &buf, len)) { +      return NULL; +    } + +    return PyString_FromStringAndSize(buf, len); +  } + +  case T_LIST: +  case T_SET: { +    SetListTypeArgs parsedargs; +    int32_t len; +    PyObject* ret = NULL; +    int i; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.element_type)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return NULL; +    } + +    ret = PyList_New(len); +    if (!ret) { +      return NULL; +    } + +    for (i = 0; i < len; i++) { +      PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); +      if (!item) { +        Py_DECREF(ret); +        return NULL; +      } +      PyList_SET_ITEM(ret, i, item); +    } + +    // TODO(dreiss): Consider biting the bullet and making two separate cases +    //               for list and set, avoiding this post facto conversion. +    if (type == T_SET) { +      PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) +      // hack needed for older versions +      setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else +      // official version +      setret = PySet_New(ret); +#endif +      Py_DECREF(ret); +      return setret; +    } +    return ret; +  } + +  case T_MAP: { +    int32_t len; +    int i; +    MapTypeArgs parsedargs; +    PyObject* ret = NULL; + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.ktag)) { +      return NULL; +    } +    if (!checkTypeByte(input, parsedargs.vtag)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    ret = PyDict_New(); +    if (!ret) { +      goto error; +    } + +    for (i = 0; i < len; i++) { +      PyObject* k = NULL; +      PyObject* v = NULL; +      k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); +      if (k == NULL) { +        goto loop_error; +      } +      v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); +      if (v == NULL) { +        goto loop_error; +      } +      if (PyDict_SetItem(ret, k, v) == -1) { +        goto loop_error; +      } + +      Py_DECREF(k); +      Py_DECREF(v); +      continue; + +      // Yuck!  Destructors, anyone? +      loop_error: +      Py_XDECREF(k); +      Py_XDECREF(v); +      goto error; +    } + +    return ret; + +    error: +    Py_XDECREF(ret); +    return NULL; +  } + +  case T_STRUCT: { +    StructTypeArgs parsedargs; +	PyObject* ret; +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    ret = PyObject_CallObject(parsedargs.klass, NULL); +    if (!ret) { +      return NULL; +    } + +    if (!decode_struct(input, ret, parsedargs.spec)) { +      Py_DECREF(ret); +      return NULL; +    } + +    return ret; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return NULL; +  } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { +  PyObject* output_obj = NULL; +  PyObject* transport = NULL; +  PyObject* typeargs = NULL; +  StructTypeArgs parsedargs; +  DecodeBuffer input = {0, 0}; +   +  if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { +    return NULL; +  } + +  if (!parse_struct_args(&parsedargs, typeargs)) { +    return NULL; +  } + +  if (!decode_buffer_from_obj(&input, transport)) { +    return NULL; +  } + +  if (!decode_struct(&input, output_obj, parsedargs.spec)) { +    free_decodebuf(&input); +    return NULL; +  } + +  free_decodebuf(&input); + +  Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + +  {"encode_binary",  encode_binary, METH_VARARGS, ""}, +  {"decode_binary",  decode_binary, METH_VARARGS, ""}, + +  {NULL, NULL, 0, NULL}        /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ +  do { \ +    INTERN_STRING(value) = PyString_InternFromString(#value); \ +    if(!INTERN_STRING(value)) return; \ +  } while(0) + +  INIT_INTERN_STRING(cstringio_buf); +  INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + +  PycString_IMPORT; +  if (PycStringIO == NULL) return; + +  (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/module/lib/thrift/server/THttpServer.py b/module/lib/thrift/server/THttpServer.py index 3047d9c00..be54bab94 100644 --- a/module/lib/thrift/server/THttpServer.py +++ b/module/lib/thrift/server/THttpServer.py @@ -22,6 +22,7 @@ import BaseHTTPServer  from thrift.server import TServer  from thrift.transport import TTransport +  class ResponseException(Exception):    """Allows handlers to override the HTTP response @@ -39,16 +40,19 @@ class THttpServer(TServer.TServer):    """A simple HTTP-based Thrift server    This class is not very performant, but it is useful (for example) for -  acting as a mock version of an Apache-based PHP Thrift endpoint.""" - -  def __init__(self, processor, server_address, -      inputProtocolFactory, outputProtocolFactory = None, -      server_class = BaseHTTPServer.HTTPServer): +  acting as a mock version of an Apache-based PHP Thrift endpoint. +  """ +  def __init__(self, +               processor, +               server_address, +               inputProtocolFactory, +               outputProtocolFactory=None, +               server_class=BaseHTTPServer.HTTPServer):      """Set up protocol factories and HTTP server.      See BaseHTTPServer for server_address. -    See TServer for protocol factories.""" - +    See TServer for protocol factories. +    """      if outputProtocolFactory is None:        outputProtocolFactory = inputProtocolFactory @@ -62,7 +66,8 @@ class THttpServer(TServer.TServer):          # Don't care about the request path.          itrans = TTransport.TFileObjectTransport(self.rfile)          otrans = TTransport.TFileObjectTransport(self.wfile) -        itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) +        itrans = TTransport.TBufferedTransport( +          itrans, int(self.headers['Content-Length']))          otrans = TTransport.TMemoryBuffer()          iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)          oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) diff --git a/module/lib/thrift/server/TNonblockingServer.py b/module/lib/thrift/server/TNonblockingServer.py index ea348a0b6..fa478d01f 100644 --- a/module/lib/thrift/server/TNonblockingServer.py +++ b/module/lib/thrift/server/TNonblockingServer.py @@ -18,10 +18,11 @@  #  """Implementation of non-blocking server. -The main idea of the server is reciving and sending requests -only from main thread. +The main idea of the server is to receive and send requests +only from the main thread. -It also makes thread pool server in tasks terms, not connections. +The thread poool should be sized for concurrent tasks, not +maximum connections  """  import threading  import socket @@ -35,8 +36,10 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory  __all__ = ['TNonblockingServer'] +  class Worker(threading.Thread):      """Worker is a small helper to process incoming connection.""" +      def __init__(self, queue):          threading.Thread.__init__(self)          self.queue = queue @@ -60,8 +63,9 @@ WAIT_PROCESS = 2  SEND_ANSWER = 3  CLOSED = 4 +  def locked(func): -    "Decorator which locks self.lock." +    """Decorator which locks self.lock."""      def nested(self, *args, **kwargs):          self.lock.acquire()          try: @@ -70,8 +74,9 @@ def locked(func):              self.lock.release()      return nested +  def socket_exception(func): -    "Decorator close object on socket.error." +    """Decorator close object on socket.error."""      def read(self, *args, **kwargs):          try:              return func(self, *args, **kwargs) @@ -79,16 +84,17 @@ def socket_exception(func):              self.close()      return read +  class Connection:      """Basic class is represented connection. -     +      It can be in state:          WAIT_LEN --- connection is reading request len.          WAIT_MESSAGE --- connection is reading request. -        WAIT_PROCESS --- connection has just read whole request and  -            waits for call ready routine. +        WAIT_PROCESS --- connection has just read whole request and +                         waits for call ready routine.          SEND_ANSWER --- connection is sending answer string (including length -            of answer). +                        of answer).          CLOSED --- socket was closed and connection should be deleted.      """      def __init__(self, new_socket, wake_up): @@ -102,13 +108,13 @@ class Connection:      def _read_len(self):          """Reads length of request. -         -        It's really paranoic routine and it may be replaced by  -        self.socket.recv(4).""" + +        It's a safer alternative to self.socket.recv(4) +        """          read = self.socket.recv(4 - len(self.message))          if len(read) == 0: -            # if we read 0 bytes and self.message is empty, it means client close  -            # connection +            # if we read 0 bytes and self.message is empty, then +            # the client closed the connection              if len(self.message) != 0:                  logging.error("can't read frame size from socket")              self.close() @@ -117,8 +123,8 @@ class Connection:          if len(self.message) == 4:              self.len, = struct.unpack('!i', self.message)              if self.len < 0: -                logging.error("negative frame size, it seems client"\ -                    " doesn't use FramedTransport") +                logging.error("negative frame size, it seems client " +                              "doesn't use FramedTransport")                  self.close()              elif self.len == 0:                  logging.error("empty frame, it's really strange") @@ -139,8 +145,8 @@ class Connection:          elif self.status == WAIT_MESSAGE:              read = self.socket.recv(self.len - len(self.message))              if len(read) == 0: -                logging.error("can't read frame from socket (get %d of %d bytes)" % -                    (len(self.message), self.len)) +                logging.error("can't read frame from socket (get %d of " +                              "%d bytes)" % (len(self.message), self.len))                  self.close()                  return              self.message += read @@ -162,14 +168,14 @@ class Connection:      @locked      def ready(self, all_ok, message):          """Callback function for switching state and waking up main thread. -         +          This function is the only function witch can be called asynchronous. -         +          The ready can switch Connection to three states:              WAIT_LEN if request was oneway.              SEND_ANSWER if request was processed in normal way.              CLOSED if request throws unexpected exception. -         +          The one wakes up main thread.          """          assert self.status == WAIT_PROCESS @@ -189,33 +195,39 @@ class Connection:      @locked      def is_writeable(self): -        "Returns True if connection should be added to write list of select." +        """Return True if connection should be added to write list of select"""          return self.status == SEND_ANSWER      # it's not necessary, but...      @locked      def is_readable(self): -        "Returns True if connection should be added to read list of select." +        """Return True if connection should be added to read list of select"""          return self.status in (WAIT_LEN, WAIT_MESSAGE)      @locked      def is_closed(self): -        "Returns True if connection is closed." +        """Returns True if connection is closed."""          return self.status == CLOSED      def fileno(self): -        "Returns the file descriptor of the associated socket." +        """Returns the file descriptor of the associated socket."""          return self.socket.fileno()      def close(self): -        "Closes connection" +        """Closes connection"""          self.status = CLOSED          self.socket.close() +  class TNonblockingServer:      """Non-blocking server.""" -    def __init__(self, processor, lsocket, inputProtocolFactory=None,  -            outputProtocolFactory=None, threads=10): + +    def __init__(self, +                 processor, +                 lsocket, +                 inputProtocolFactory=None, +                 outputProtocolFactory=None, +                 threads=10):          self.processor = processor          self.socket = lsocket          self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() @@ -225,15 +237,18 @@ class TNonblockingServer:          self.tasks = Queue.Queue()          self._read, self._write = socket.socketpair()          self.prepared = False +        self._stop = False      def setNumThreads(self, num):          """Set the number of worker threads that should be created."""          # implement ThreadPool interface -        assert not self.prepared, "You can't change number of threads for working server" +        assert not self.prepared, "Can't change number of threads after start"          self.threads = num      def prepare(self):          """Prepares server for serve requests.""" +        if self.prepared: +            return          self.socket.listen()          for _ in xrange(self.threads):              thread = Worker(self.tasks) @@ -243,16 +258,32 @@ class TNonblockingServer:      def wake_up(self):          """Wake up main thread. -         +          The server usualy waits in select call in we should terminate one.          The simplest way is using socketpair. -         +          Select always wait to read from the first socket of socketpair. -         +          In this case, we can just write anything to the second socket from -        socketpair.""" +        socketpair. +        """          self._write.send('1') +    def stop(self): +        """Stop the server. + +        This method causes the serve() method to return.  stop() may be invoked +        from within your handler, or from another thread. + +        After stop() is called, serve() will return but the server will still +        be listening on the socket.  serve() may then be called again to resume +        processing requests.  Alternatively, close() may be called after +        serve() returns to close the server socket and shutdown all worker +        threads. +        """ +        self._stop = True +        self.wake_up() +      def _select(self):          """Does select on open connections."""          readable = [self.socket.handle.fileno(), self._read.fileno()] @@ -265,21 +296,22 @@ class TNonblockingServer:              if connection.is_closed():                  del self.clients[i]          return select.select(readable, writable, readable) -         +      def handle(self):          """Handle requests. -        -        WARNING! You must call prepare BEFORE calling handle. + +        WARNING! You must call prepare() BEFORE calling handle()          """          assert self.prepared, "You have to call prepare before handle"          rset, wset, xset = self._select()          for readable in rset:              if readable == self._read.fileno():                  # don't care i just need to clean readable flag -                self._read.recv(1024)  +                self._read.recv(1024)              elif readable == self.socket.handle.fileno():                  client = self.socket.accept().handle -                self.clients[client.fileno()] = Connection(client, self.wake_up) +                self.clients[client.fileno()] = Connection(client, +                                                           self.wake_up)              else:                  connection = self.clients[readable]                  connection.read() @@ -288,7 +320,7 @@ class TNonblockingServer:                      otransport = TTransport.TMemoryBuffer()                      iprot = self.in_protocol.getProtocol(itransport)                      oprot = self.out_protocol.getProtocol(otransport) -                    self.tasks.put([self.processor, iprot, oprot,  +                    self.tasks.put([self.processor, iprot, oprot,                                      otransport, connection.ready])          for writeable in wset:              self.clients[writeable].write() @@ -302,9 +334,13 @@ class TNonblockingServer:              self.tasks.put([None, None, None, None, None])          self.socket.close()          self.prepared = False -         +      def serve(self): -        """Serve forever.""" +        """Serve requests. + +        Serve requests forever, or until stop() is called. +        """ +        self._stop = False          self.prepare() -        while True: +        while not self._stop:              self.handle() diff --git a/module/lib/thrift/server/TProcessPoolServer.py b/module/lib/thrift/server/TProcessPoolServer.py index 7ed814a88..7a695a883 100644 --- a/module/lib/thrift/server/TProcessPoolServer.py +++ b/module/lib/thrift/server/TProcessPoolServer.py @@ -24,15 +24,14 @@ from multiprocessing import  Process, Value, Condition, reduction  from TServer import TServer  from thrift.transport.TTransport import TTransportException +  class TProcessPoolServer(TServer): +    """Server with a fixed size pool of worker subprocesses to service requests -    """ -    Server with a fixed size pool of worker subprocesses which service requests.      Note that if you need shared state between the handlers - it's up to you!      Written by Dvir Volk, doat.com      """ - -    def __init__(self, * args): +    def __init__(self, *args):          TServer.__init__(self, *args)          self.numWorkers = 10          self.workers = [] @@ -50,12 +49,11 @@ class TProcessPoolServer(TServer):          self.numWorkers = num      def workerProcess(self): -        """Loop around getting clients from the shared queue and process them.""" - +        """Loop getting clients from the shared queue and process them"""          if self.postForkCallback:              self.postForkCallback() -        while self.isRunning.value == True: +        while self.isRunning.value:              try:                  client = self.serverTransport.accept()                  self.serveClient(client) @@ -82,17 +80,15 @@ class TProcessPoolServer(TServer):          itrans.close()          otrans.close() -      def serve(self): -        """Start a fixed number of worker threads and put client into a queue""" - -        #this is a shared state that can tell the workers to exit when set as false +        """Start workers and put into queue""" +        # this is a shared state that can tell the workers to exit when False          self.isRunning.value = True -        #first bind and listen to the port +        # first bind and listen to the port          self.serverTransport.listen() -        #fork the children +        # fork the children          for i in range(self.numWorkers):              try:                  w = Process(target=self.workerProcess) @@ -102,16 +98,14 @@ class TProcessPoolServer(TServer):              except Exception, x:                  logging.exception(x) -        #wait until the condition is set by stop() - +        # wait until the condition is set by stop()          while True: -              self.stopCondition.acquire()              try:                  self.stopCondition.wait()                  break              except (SystemExit, KeyboardInterrupt): -		break +                break              except Exception, x:                  logging.exception(x) @@ -122,4 +116,3 @@ class TProcessPoolServer(TServer):          self.stopCondition.acquire()          self.stopCondition.notify()          self.stopCondition.release() - diff --git a/module/lib/thrift/server/TServer.py b/module/lib/thrift/server/TServer.py index 8456e2d40..2f24842c4 100644 --- a/module/lib/thrift/server/TServer.py +++ b/module/lib/thrift/server/TServer.py @@ -17,27 +17,28 @@  # under the License.  # +import Queue  import logging -import sys  import os -import traceback +import sys  import threading -import Queue +import traceback  from thrift.Thrift import TProcessor -from thrift.transport import TTransport  from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport -class TServer: -  """Base interface for a server, which must have a serve method.""" +class TServer: +  """Base interface for a server, which must have a serve() method. -  """ 3 constructors for all servers: +  Three constructors for all servers:    1) (processor, serverTransport)    2) (processor, serverTransport, transportFactory, protocolFactory)    3) (processor, serverTransport,        inputTransportFactory, outputTransportFactory, -      inputProtocolFactory, outputProtocolFactory)""" +      inputProtocolFactory, outputProtocolFactory) +  """    def __init__(self, *args):      if (len(args) == 2):        self.__initArgs__(args[0], args[1], @@ -63,8 +64,8 @@ class TServer:    def serve(self):      pass -class TSimpleServer(TServer): +class TSimpleServer(TServer):    """Simple single-threaded server that just pumps around one transport."""    def __init__(self, *args): @@ -89,8 +90,8 @@ class TSimpleServer(TServer):        itrans.close()        otrans.close() -class TThreadedServer(TServer): +class TThreadedServer(TServer):    """Threaded server that spawns a new thread per each connection."""    def __init__(self, *args, **kwargs): @@ -102,7 +103,7 @@ class TThreadedServer(TServer):      while True:        try:          client = self.serverTransport.accept() -        t = threading.Thread(target = self.handle, args=(client,)) +        t = threading.Thread(target=self.handle, args=(client,))          t.setDaemon(self.daemon)          t.start()        except KeyboardInterrupt: @@ -126,8 +127,8 @@ class TThreadedServer(TServer):      itrans.close()      otrans.close() -class TThreadPoolServer(TServer): +class TThreadPoolServer(TServer):    """Server with a fixed size pool of threads which service requests."""    def __init__(self, *args, **kwargs): @@ -170,7 +171,7 @@ class TThreadPoolServer(TServer):      """Start a fixed number of worker threads and put client into a queue"""      for i in range(self.threads):        try: -        t = threading.Thread(target = self.serveThread) +        t = threading.Thread(target=self.serveThread)          t.setDaemon(self.daemon)          t.start()        except Exception, x: @@ -187,9 +188,8 @@ class TThreadPoolServer(TServer):  class TForkingServer(TServer): +  """A Thrift server that forks a new process for each request -  """A Thrift server that forks a new process for each request""" -  """    This is more scalable than the threaded server as it does not cause    GIL contention. @@ -200,7 +200,6 @@ class TForkingServer(TServer):    This code is heavily inspired by SocketServer.ForkingMixIn in the    Python stdlib.    """ -    def __init__(self, *args):      TServer.__init__(self, *args)      self.children = [] @@ -212,14 +211,13 @@ class TForkingServer(TServer):        except IOError, e:          logging.warning(e, exc_info=True) -      self.serverTransport.listen()      while True:        client = self.serverTransport.accept()        try:          pid = os.fork() -        if pid: # parent +        if pid:  # parent            # add before collect, otherwise you race w/ waitpid            self.children.append(pid)            self.collect_children() @@ -258,7 +256,6 @@ class TForkingServer(TServer):        except Exception, x:          logging.exception(x) -    def collect_children(self):      while self.children:        try: @@ -270,5 +267,3 @@ class TForkingServer(TServer):          self.children.remove(pid)        else:          break - - diff --git a/module/lib/thrift/transport/THttpClient.py b/module/lib/thrift/transport/THttpClient.py index 50269785c..ea80a1ae8 100644 --- a/module/lib/thrift/transport/THttpClient.py +++ b/module/lib/thrift/transport/THttpClient.py @@ -17,16 +17,20 @@  # under the License.  # -from TTransport import * -from cStringIO import StringIO - -import urlparse  import httplib -import warnings +import os  import socket +import sys +import urllib +import urlparse +import warnings -class THttpClient(TTransportBase): +from cStringIO import StringIO +from TTransport import * + + +class THttpClient(TTransportBase):    """Http implementation of TTransport base."""    def __init__(self, uri_or_host, port=None, path=None): @@ -35,10 +39,13 @@ class THttpClient(TTransportBase):      THttpClient(host, port, path) - deprecated      THttpClient(uri) -    Only the second supports https.""" - +    Only the second supports https. +    """      if port is not None: -      warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) +      warnings.warn( +        "Please use the THttpClient('http://host:port/path') syntax", +        DeprecationWarning, +        stacklevel=2)        self.host = uri_or_host        self.port = port        assert path @@ -59,6 +66,7 @@ class THttpClient(TTransportBase):      self.__wbuf = StringIO()      self.__http = None      self.__timeout = None +    self.__custom_headers = None    def open(self):      if self.scheme == 'http': @@ -71,7 +79,7 @@ class THttpClient(TTransportBase):      self.__http = None    def isOpen(self): -    return self.__http != None +    return self.__http is not None    def setTimeout(self, ms):      if not hasattr(socket, 'getdefaulttimeout'): @@ -80,7 +88,10 @@ class THttpClient(TTransportBase):      if ms is None:        self.__timeout = None      else: -      self.__timeout = ms/1000.0 +      self.__timeout = ms / 1000.0 + +  def setCustomHeaders(self, headers): +    self.__custom_headers = headers    def read(self, sz):      return self.__http.file.read(sz) @@ -100,7 +111,7 @@ class THttpClient(TTransportBase):    def flush(self):      if self.isOpen():        self.close() -    self.open(); +    self.open()      # Pull data out of buffer      data = self.__wbuf.getvalue() @@ -113,6 +124,18 @@ class THttpClient(TTransportBase):      self.__http.putheader('Host', self.host)      self.__http.putheader('Content-Type', 'application/x-thrift')      self.__http.putheader('Content-Length', str(len(data))) + +    if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: +      user_agent = 'Python/THttpClient' +      script = os.path.basename(sys.argv[0]) +      if script: +        user_agent = '%s (%s)' % (user_agent, urllib.quote(script)) +      self.__http.putheader('User-Agent', user_agent) + +    if self.__custom_headers: +        for key, val in self.__custom_headers.iteritems(): +            self.__http.putheader(key, val) +      self.__http.endheaders()      # Write payload diff --git a/module/lib/thrift/transport/TSSLSocket.py b/module/lib/thrift/transport/TSSLSocket.py new file mode 100644 index 000000000..81e098426 --- /dev/null +++ b/module/lib/thrift/transport/TSSLSocket.py @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +import socket +import ssl + +from thrift.transport import TSocket +from thrift.transport.TTransport import TTransportException + + +class TSSLSocket(TSocket.TSocket): +  """ +  SSL implementation of client-side TSocket + +  This class creates outbound sockets wrapped using the +  python standard ssl module for encrypted connections. + +  The protocol used is set using the class variable +  SSL_VERSION, which must be one of ssl.PROTOCOL_* and +  defaults to  ssl.PROTOCOL_TLSv1 for greatest security. +  """ +  SSL_VERSION = ssl.PROTOCOL_TLSv1 + +  def __init__(self, +               host='localhost', +               port=9090, +               validate=True, +               ca_certs=None, +               keyfile=None, +               certfile=None, +               unix_socket=None): +    """Create SSL TSocket + +    @param validate: Set to False to disable SSL certificate validation +    @type validate: bool +    @param ca_certs: Filename to the Certificate Authority pem file, possibly a +    file downloaded from: http://curl.haxx.se/ca/cacert.pem  This is passed to +    the ssl_wrap function as the 'ca_certs' parameter. +    @type ca_certs: str +    @param keyfile: The private key +    @type keyfile: str +    @param certfile: The cert file +    @type certfile: str +     +    Raises an IOError exception if validate is True and the ca_certs file is +    None, not present or unreadable. +    """ +    self.validate = validate +    self.is_valid = False +    self.peercert = None +    if not validate: +      self.cert_reqs = ssl.CERT_NONE +    else: +      self.cert_reqs = ssl.CERT_REQUIRED +    self.ca_certs = ca_certs +    self.keyfile = keyfile +    self.certfile = certfile +    if validate: +      if ca_certs is None or not os.access(ca_certs, os.R_OK): +        raise IOError('Certificate Authority ca_certs file "%s" ' +                      'is not readable, cannot validate SSL ' +                      'certificates.' % (ca_certs)) +    TSocket.TSocket.__init__(self, host, port, unix_socket) + +  def open(self): +    try: +      res0 = self._resolveAddr() +      for res in res0: +        sock_family, sock_type = res[0:2] +        ip_port = res[4] +        plain_sock = socket.socket(sock_family, sock_type) +        self.handle = ssl.wrap_socket(plain_sock, +                                      ssl_version=self.SSL_VERSION, +                                      do_handshake_on_connect=True, +                                      ca_certs=self.ca_certs, +                                      keyfile=self.keyfile, +                                      certfile=self.certfile, +                                      cert_reqs=self.cert_reqs) +        self.handle.settimeout(self._timeout) +        try: +          self.handle.connect(ip_port) +        except socket.error, e: +          if res is not res0[-1]: +            continue +          else: +            raise e +        break +    except socket.error, e: +      if self._unix_socket: +        message = 'Could not connect to secure socket %s: %s' \ +                % (self._unix_socket, e) +      else: +        message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e) +      raise TTransportException(type=TTransportException.NOT_OPEN, +                                message=message) +    if self.validate: +      self._validate_cert() + +  def _validate_cert(self): +    """internal method to validate the peer's SSL certificate, and to check the +    commonName of the certificate to ensure it matches the hostname we +    used to make this connection.  Does not support subjectAltName records +    in certificates. + +    raises TTransportException if the certificate fails validation. +    """ +    cert = self.handle.getpeercert() +    self.peercert = cert +    if 'subject' not in cert: +      raise TTransportException( +        type=TTransportException.NOT_OPEN, +        message='No SSL certificate found from %s:%s' % (self.host, self.port)) +    fields = cert['subject'] +    for field in fields: +      # ensure structure we get back is what we expect +      if not isinstance(field, tuple): +        continue +      cert_pair = field[0] +      if len(cert_pair) < 2: +        continue +      cert_key, cert_value = cert_pair[0:2] +      if cert_key != 'commonName': +        continue +      certhost = cert_value +      # this check should be performed by some sort of Access Manager +      if certhost == self.host: +        # success, cert commonName matches desired hostname +        self.is_valid = True +        return +      else: +        raise TTransportException( +          type=TTransportException.UNKNOWN, +          message='Hostname we connected to "%s" doesn\'t match certificate ' +                  'provided commonName "%s"' % (self.host, certhost)) +    raise TTransportException( +      type=TTransportException.UNKNOWN, +      message='Could not validate SSL certificate from ' +              'host "%s".  Cert=%s' % (self.host, cert)) + + +class TSSLServerSocket(TSocket.TServerSocket): +  """SSL implementation of TServerSocket + +  This uses the ssl module's wrap_socket() method to provide SSL +  negotiated encryption. +  """ +  SSL_VERSION = ssl.PROTOCOL_TLSv1 + +  def __init__(self, +               host=None, +               port=9090, +               certfile='cert.pem', +               unix_socket=None): +    """Initialize a TSSLServerSocket + +    @param certfile: filename of the server certificate, defaults to cert.pem +    @type certfile: str +    @param host: The hostname or IP to bind the listen socket to, +                 i.e. 'localhost' for only allowing local network connections. +                 Pass None to bind to all interfaces. +    @type host: str +    @param port: The port to listen on for inbound connections. +    @type port: int +    """ +    self.setCertfile(certfile) +    TSocket.TServerSocket.__init__(self, host, port) + +  def setCertfile(self, certfile): +    """Set or change the server certificate file used to wrap new connections. + +    @param certfile: The filename of the server certificate, +                     i.e. '/etc/certs/server.pem' +    @type certfile: str + +    Raises an IOError exception if the certfile is not present or unreadable. +    """ +    if not os.access(certfile, os.R_OK): +      raise IOError('No such certfile found: %s' % (certfile)) +    self.certfile = certfile + +  def accept(self): +    plain_client, addr = self.handle.accept() +    try: +      client = ssl.wrap_socket(plain_client, certfile=self.certfile, +                      server_side=True, ssl_version=self.SSL_VERSION) +    except ssl.SSLError, ssl_exc: +      # failed handshake/ssl wrap, close socket to client +      plain_client.close() +      # raise ssl_exc +      # We can't raise the exception, because it kills most TServer derived +      # serve() methods. +      # Instead, return None, and let the TServer instance deal with it in +      # other exception handling.  (but TSimpleServer dies anyway) +      return None +    result = TSocket.TSocket() +    result.setHandle(client) +    return result diff --git a/module/lib/thrift/transport/TSocket.py b/module/lib/thrift/transport/TSocket.py index 4e0e1874f..9e2b3849b 100644 --- a/module/lib/thrift/transport/TSocket.py +++ b/module/lib/thrift/transport/TSocket.py @@ -17,24 +17,33 @@  # under the License.  # -from TTransport import * -import os  import errno +import os  import socket  import sys +from TTransport import * + +  class TSocketBase(TTransportBase):    def _resolveAddr(self):      if self._unix_socket is not None: -      return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] +      return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, +               self._unix_socket)]      else: -      return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) +      return socket.getaddrinfo(self.host, +                                self.port, +                                socket.AF_UNSPEC, +                                socket.SOCK_STREAM, +                                0, +                                socket.AI_PASSIVE | socket.AI_ADDRCONFIG)    def close(self):      if self.handle:        self.handle.close()        self.handle = None +  class TSocket(TSocketBase):    """Socket implementation of TTransport base.""" @@ -46,7 +55,6 @@ class TSocket(TSocketBase):      @param unix_socket(str)  The filename of a unix socket to connect to.                               (host and port will be ignored.)      """ -      self.host = host      self.port = port      self.handle = None @@ -63,7 +71,7 @@ class TSocket(TSocketBase):      if ms is None:        self._timeout = None      else: -      self._timeout = ms/1000.0 +      self._timeout = ms / 1000.0      if self.handle is not None:        self.handle.settimeout(self._timeout) @@ -87,7 +95,8 @@ class TSocket(TSocketBase):          message = 'Could not connect to socket %s' % self._unix_socket        else:          message = 'Could not connect to %s:%d' % (self.host, self.port) -      raise TTransportException(type=TTransportException.NOT_OPEN, message=message) +      raise TTransportException(type=TTransportException.NOT_OPEN, +                                message=message)    def read(self, sz):      try: @@ -105,24 +114,28 @@ class TSocket(TSocketBase):        else:          raise      if len(buff) == 0: -      raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') +      raise TTransportException(type=TTransportException.END_OF_FILE, +                                message='TSocket read 0 bytes')      return buff    def write(self, buff):      if not self.handle: -      raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') +      raise TTransportException(type=TTransportException.NOT_OPEN, +                                message='Transport not open')      sent = 0      have = len(buff)      while sent < have:        plus = self.handle.send(buff)        if plus == 0: -        raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') +        raise TTransportException(type=TTransportException.END_OF_FILE, +                                  message='TSocket sent 0 bytes')        sent += plus        buff = buff[plus:]    def flush(self):      pass +  class TServerSocket(TSocketBase, TServerTransportBase):    """Socket implementation of TServerTransport base.""" diff --git a/module/lib/thrift/transport/TTransport.py b/module/lib/thrift/transport/TTransport.py index 12e51a9bf..4481371a6 100644 --- a/module/lib/thrift/transport/TTransport.py +++ b/module/lib/thrift/transport/TTransport.py @@ -18,11 +18,11 @@  #  from cStringIO import StringIO -from struct import pack,unpack +from struct import pack, unpack  from thrift.Thrift import TException -class TTransportException(TException): +class TTransportException(TException):    """Custom Transport Exception class"""    UNKNOWN = 0 @@ -35,8 +35,8 @@ class TTransportException(TException):      TException.__init__(self, message)      self.type = type -class TTransportBase: +class TTransportBase:    """Base class for Thrift transport layer."""    def isOpen(self): @@ -55,7 +55,7 @@ class TTransportBase:      buff = ''      have = 0      while (have < sz): -      chunk = self.read(sz-have) +      chunk = self.read(sz - have)        have += len(chunk)        buff += chunk @@ -70,6 +70,7 @@ class TTransportBase:    def flush(self):      pass +  # This class should be thought of as an interface.  class CReadableTransport:    """base class for transports that are readable from C""" @@ -98,8 +99,8 @@ class CReadableTransport:      """      pass -class TServerTransportBase: +class TServerTransportBase:    """Base class for Thrift server transports."""    def listen(self): @@ -111,15 +112,15 @@ class TServerTransportBase:    def close(self):      pass -class TTransportFactoryBase: +class TTransportFactoryBase:    """Base class for a Transport Factory"""    def getTransport(self, trans):      return trans -class TBufferedTransportFactory: +class TBufferedTransportFactory:    """Factory transport that builds buffered transports"""    def getTransport(self, trans): @@ -127,17 +128,15 @@ class TBufferedTransportFactory:      return buffered -class TBufferedTransport(TTransportBase,CReadableTransport): - +class TBufferedTransport(TTransportBase, CReadableTransport):    """Class that wraps another transport and buffers its I/O.    The implementation uses a (configurable) fixed-size read buffer    but buffers all writes until a flush is performed.    """ -    DEFAULT_BUFFER = 4096 -  def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): +  def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):      self.__trans = trans      self.__wbuf = StringIO()      self.__rbuf = StringIO("") @@ -188,6 +187,7 @@ class TBufferedTransport(TTransportBase,CReadableTransport):      self.__rbuf = StringIO(retstring)      return self.__rbuf +  class TMemoryBuffer(TTransportBase, CReadableTransport):    """Wraps a cStringIO object as a TTransport. @@ -237,8 +237,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):      # only one shot at reading...      raise EOFError() -class TFramedTransportFactory: +class TFramedTransportFactory:    """Factory transport that builds framed transports"""    def getTransport(self, trans): @@ -247,7 +247,6 @@ class TFramedTransportFactory:  class TFramedTransport(TTransportBase, CReadableTransport): -    """Class that wraps another transport and frames its I/O when writing."""    def __init__(self, trans,): diff --git a/module/lib/thrift/transport/TTwisted.py b/module/lib/thrift/transport/TTwisted.py index b6dcb4e0b..3ce3eb220 100644 --- a/module/lib/thrift/transport/TTwisted.py +++ b/module/lib/thrift/transport/TTwisted.py @@ -16,6 +16,9 @@  # specific language governing permissions and limitations  # under the License.  # + +from cStringIO import StringIO +  from zope.interface import implements, Interface, Attribute  from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \      connectionDone @@ -25,7 +28,6 @@ from twisted.python import log  from twisted.web import server, resource, http  from thrift.transport import TTransport -from cStringIO import StringIO  class TMessageSenderTransport(TTransport.TTransportBase): @@ -79,7 +81,7 @@ class ThriftClientProtocol(basic.Int32StringReceiver):          self.started.callback(self.client)      def connectionLost(self, reason=connectionDone): -        for k,v in self.client._reqs.iteritems(): +        for k, v in self.client._reqs.iteritems():              tex = TTransport.TTransportException(                  type=TTransport.TTransportException.END_OF_FILE,                  message='Connection closed') diff --git a/module/lib/thrift/transport/TZlibTransport.py b/module/lib/thrift/transport/TZlibTransport.py index 784d4e1e0..a2f42a5d2 100644 --- a/module/lib/thrift/transport/TZlibTransport.py +++ b/module/lib/thrift/transport/TZlibTransport.py @@ -16,50 +16,49 @@  # specific language governing permissions and limitations  # under the License.  # -''' -TZlibTransport provides a compressed transport and transport factory + +"""TZlibTransport provides a compressed transport and transport factory  class, using the python standard library zlib module to implement  data compression. -''' +"""  from __future__ import division  import zlib  from cStringIO import StringIO  from TTransport import TTransportBase, CReadableTransport +  class TZlibTransportFactory(object): -  ''' -  Factory transport that builds zlib compressed transports. -   +  """Factory transport that builds zlib compressed transports. +    This factory caches the last single client/transport that it was passed    and returns the same TZlibTransport object that was created. -   +    This caching means the TServer class will get the _same_ transport    object for both input and output transports from this factory.    (For non-threaded scenarios only, since the cache only holds one object) -   +    The purpose of this caching is to allocate only one TZlibTransport where    only one is really needed (since it must have separate read/write buffers),    and makes the statistics from getCompSavings() and getCompRatio()    easier to understand. -  ''' - +  """    # class scoped cache of last transport given and zlibtransport returned    _last_trans = None    _last_z = None    def getTransport(self, trans, compresslevel=9): -    '''Wrap a transport , trans, with the TZlibTransport +    """Wrap a transport, trans, with the TZlibTransport      compressed transport class, returning a new      transport to the caller. -     +      @param compresslevel: The zlib compression level, ranging      from 0 (no compression) to 9 (best compression).  Defaults to 9.      @type compresslevel: int -     +      This method returns a TZlibTransport which wraps the      passed C{trans} TTransport derived instance. -    ''' +    """      if trans == self._last_trans:        return self._last_z      ztrans = TZlibTransport(trans, compresslevel) @@ -69,27 +68,24 @@ class TZlibTransportFactory(object):  class TZlibTransport(TTransportBase, CReadableTransport): -  ''' -  Class that wraps a transport with zlib, compressing writes +  """Class that wraps a transport with zlib, compressing writes    and decompresses reads, using the python standard    library zlib module. -  ''' - +  """    # Read buffer size for the python fastbinary C extension,    # the TBinaryProtocolAccelerated class.    DEFAULT_BUFFSIZE = 4096    def __init__(self, trans, compresslevel=9): -    ''' -    Create a new TZlibTransport, wrapping C{trans}, another +    """Create a new TZlibTransport, wrapping C{trans}, another      TTransport derived object. -     +      @param trans: A thrift transport object, i.e. a TSocket() object.      @type trans: TTransport      @param compresslevel: The zlib compression level, ranging      from 0 (no compression) to 9 (best compression).  Default is 9.      @type compresslevel: int -    ''' +    """      self.__trans = trans      self.compresslevel = compresslevel      self.__rbuf = StringIO() @@ -98,49 +94,45 @@ class TZlibTransport(TTransportBase, CReadableTransport):      self._init_stats()    def _reinit_buffers(self): -    ''' -    Internal method to initialize/reset the internal StringIO objects +    """Internal method to initialize/reset the internal StringIO objects      for read and write buffers. -    ''' +    """      self.__rbuf = StringIO()      self.__wbuf = StringIO()    def _init_stats(self): -    ''' -    Internal method to reset the internal statistics counters +    """Internal method to reset the internal statistics counters      for compression ratios and bandwidth savings. -    ''' +    """      self.bytes_in = 0      self.bytes_out = 0      self.bytes_in_comp = 0      self.bytes_out_comp = 0    def _init_zlib(self): -    ''' -    Internal method for setting up the zlib compression and +    """Internal method for setting up the zlib compression and      decompression objects. -    ''' +    """      self._zcomp_read = zlib.decompressobj()      self._zcomp_write = zlib.compressobj(self.compresslevel)    def getCompRatio(self): -    ''' -    Get the current measured compression ratios (in,out) from +    """Get the current measured compression ratios (in,out) from      this transport. -     -    Returns a tuple of:  + +    Returns a tuple of:      (inbound_compression_ratio, outbound_compression_ratio) -     +      The compression ratios are computed as:          compressed / uncompressed      E.g., data that compresses by 10x will have a ratio of: 0.10      and data that compresses to half of ts original size will      have a ratio of 0.5 -     +      None is returned if no bytes have yet been processed in      a particular direction. -    ''' +    """      r_percent, w_percent = (None, None)      if self.bytes_in > 0:        r_percent = self.bytes_in_comp / self.bytes_in @@ -149,23 +141,22 @@ class TZlibTransport(TTransportBase, CReadableTransport):      return (r_percent, w_percent)    def getCompSavings(self): -    ''' -    Get the current count of saved bytes due to data +    """Get the current count of saved bytes due to data      compression. -     +      Returns a tuple of:      (inbound_saved_bytes, outbound_saved_bytes) -     +      Note: if compression is actually expanding your      data (only likely with very tiny thrift objects), then      the values returned will be negative. -    ''' +    """      r_saved = self.bytes_in - self.bytes_in_comp      w_saved = self.bytes_out - self.bytes_out_comp      return (r_saved, w_saved)    def isOpen(self): -    '''Return the underlying transport's open status''' +    """Return the underlying transport's open status"""      return self.__trans.isOpen()    def open(self): @@ -174,25 +165,24 @@ class TZlibTransport(TTransportBase, CReadableTransport):      return self.__trans.open()    def listen(self): -    '''Invoke the underlying transport's listen() method''' +    """Invoke the underlying transport's listen() method"""      self.__trans.listen()    def accept(self): -    '''Accept connections on the underlying transport''' +    """Accept connections on the underlying transport"""      return self.__trans.accept()    def close(self): -    '''Close the underlying transport,''' +    """Close the underlying transport,"""      self._reinit_buffers()      self._init_zlib()      return self.__trans.close()    def read(self, sz): -    ''' -    Read up to sz bytes from the decompressed bytes buffer, and +    """Read up to sz bytes from the decompressed bytes buffer, and      read from the underlying transport if the decompression      buffer is empty. -    ''' +    """      ret = self.__rbuf.read(sz)      if len(ret) > 0:        return ret @@ -204,10 +194,9 @@ class TZlibTransport(TTransportBase, CReadableTransport):      return ret    def readComp(self, sz): -    ''' -    Read compressed data from the underlying transport, then +    """Read compressed data from the underlying transport, then      decompress it and append it to the internal StringIO read buffer -    ''' +    """      zbuf = self.__trans.read(sz)      zbuf = self._zcomp_read.unconsumed_tail + zbuf      buf = self._zcomp_read.decompress(zbuf) @@ -220,17 +209,15 @@ class TZlibTransport(TTransportBase, CReadableTransport):      return True    def write(self, buf): -    ''' -    Write some bytes, putting them into the internal write +    """Write some bytes, putting them into the internal write      buffer for eventual compression. -    ''' +    """      self.__wbuf.write(buf)    def flush(self): -    ''' -    Flush any queued up data in the write buffer and ensure the +    """Flush any queued up data in the write buffer and ensure the      compression buffer is flushed out to the underlying transport -    ''' +    """      wout = self.__wbuf.getvalue()      if len(wout) > 0:        zbuf = self._zcomp_write.compress(wout) @@ -247,11 +234,11 @@ class TZlibTransport(TTransportBase, CReadableTransport):    @property    def cstringio_buf(self): -    '''Implement the CReadableTransport interface''' +    """Implement the CReadableTransport interface"""      return self.__rbuf    def cstringio_refill(self, partialread, reqlen): -    '''Implement the CReadableTransport interface for refill''' +    """Implement the CReadableTransport interface for refill"""      retstring = partialread      if reqlen < self.DEFAULT_BUFFSIZE:        retstring += self.read(self.DEFAULT_BUFFSIZE) diff --git a/module/lib/thrift/transport/__init__.py b/module/lib/thrift/transport/__init__.py index 46e54fe6b..c9596d9a6 100644 --- a/module/lib/thrift/transport/__init__.py +++ b/module/lib/thrift/transport/__init__.py @@ -17,4 +17,4 @@  # under the License.  # -__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] +__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport']  | 
