diff options
Diffstat (limited to 'lib/Python/Lib/thrift')
20 files changed, 3322 insertions, 0 deletions
| diff --git a/lib/Python/Lib/thrift/TSCons.py b/lib/Python/Lib/thrift/TSCons.py new file mode 100644 index 000000000..24046256c --- /dev/null +++ b/lib/Python/Lib/thrift/TSCons.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder + +def scons_env(env, add=''): +  opath = path.dirname(path.abspath('$TARGET')) +  lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' +  cppbuild = Builder(action = lstr) +  env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + +def gen_cpp(env, dir, file): +  scons_env(env) +  suffixes = ['_types.h', '_types.cpp'] +  targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) +  return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/lib/Python/Lib/thrift/TSerialization.py b/lib/Python/Lib/thrift/TSerialization.py new file mode 100644 index 000000000..b19f98aa8 --- /dev/null +++ b/lib/Python/Lib/thrift/TSerialization.py @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from protocol import TBinaryProtocol +from transport import TTransport + +def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): +    transport = TTransport.TMemoryBuffer() +    protocol = protocol_factory.getProtocol(transport) +    thrift_object.write(protocol) +    return transport.getvalue() + +def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): +    transport = TTransport.TMemoryBuffer(buf) +    protocol = protocol_factory.getProtocol(transport) +    base.read(protocol) +    return base + diff --git a/lib/Python/Lib/thrift/Thrift.py b/lib/Python/Lib/thrift/Thrift.py new file mode 100644 index 000000000..1d271fcff --- /dev/null +++ b/lib/Python/Lib/thrift/Thrift.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +class TType: +  STOP   = 0 +  VOID   = 1 +  BOOL   = 2 +  BYTE   = 3 +  I08    = 3 +  DOUBLE = 4 +  I16    = 6 +  I32    = 8 +  I64    = 10 +  STRING = 11 +  UTF7   = 11 +  STRUCT = 12 +  MAP    = 13 +  SET    = 14 +  LIST   = 15 +  UTF8   = 16 +  UTF16  = 17 + +  _VALUES_TO_NAMES = ( 'STOP', +                      'VOID', +                      'BOOL', +                      'BYTE', +                      'DOUBLE', +                      None, +                      'I16', +                      None, +                      'I32', +                      None, +                       'I64', +                       'STRING', +                       'STRUCT', +                       'MAP', +                       'SET', +                       'LIST', +                       'UTF8', +                       'UTF16' ) + +class TMessageType: +  CALL  = 1 +  REPLY = 2 +  EXCEPTION = 3 +  ONEWAY = 4 + +class TProcessor: + +  """Base class for procsessor, which works on two streams.""" + +  def process(iprot, oprot): +    pass + +class TException(Exception): + +  """Base class for all thrift exceptions.""" + +  # BaseException.message is deprecated in Python v[2.6,3.0) +  if (2,6,0) <= sys.version_info < (3,0): +    def _get_message(self): +	    return self._message +    def _set_message(self, message): +	    self._message = message +    message = property(_get_message, _set_message) + +  def __init__(self, message=None): +    Exception.__init__(self, message) +    self.message = message + +class TApplicationException(TException): + +  """Application level thrift exceptions.""" + +  UNKNOWN = 0 +  UNKNOWN_METHOD = 1 +  INVALID_MESSAGE_TYPE = 2 +  WRONG_METHOD_NAME = 3 +  BAD_SEQUENCE_ID = 4 +  MISSING_RESULT = 5 +  INTERNAL_ERROR = 6 +  PROTOCOL_ERROR = 7 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +  def __str__(self): +    if self.message: +      return self.message +    elif self.type == self.UNKNOWN_METHOD: +      return 'Unknown method' +    elif self.type == self.INVALID_MESSAGE_TYPE: +      return 'Invalid message type' +    elif self.type == self.WRONG_METHOD_NAME: +      return 'Wrong method name' +    elif self.type == self.BAD_SEQUENCE_ID: +      return 'Bad sequence ID' +    elif self.type == self.MISSING_RESULT: +      return 'Missing result' +    else: +      return 'Default (unknown) TApplicationException' + +  def read(self, iprot): +    iprot.readStructBegin() +    while True: +      (fname, ftype, fid) = iprot.readFieldBegin() +      if ftype == TType.STOP: +        break +      if fid == 1: +        if ftype == TType.STRING: +          self.message = iprot.readString(); +        else: +          iprot.skip(ftype) +      elif fid == 2: +        if ftype == TType.I32: +          self.type = iprot.readI32(); +        else: +          iprot.skip(ftype) +      else: +        iprot.skip(ftype) +      iprot.readFieldEnd() +    iprot.readStructEnd() + +  def write(self, oprot): +    oprot.writeStructBegin('TApplicationException') +    if self.message != None: +      oprot.writeFieldBegin('message', TType.STRING, 1) +      oprot.writeString(self.message) +      oprot.writeFieldEnd() +    if self.type != None: +      oprot.writeFieldBegin('type', TType.I32, 2) +      oprot.writeI32(self.type) +      oprot.writeFieldEnd() +    oprot.writeFieldStop() +    oprot.writeStructEnd() diff --git a/lib/Python/Lib/thrift/__init__.py b/lib/Python/Lib/thrift/__init__.py new file mode 100644 index 000000000..48d659c40 --- /dev/null +++ b/lib/Python/Lib/thrift/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/lib/Python/Lib/thrift/protocol/TBase.py b/lib/Python/Lib/thrift/protocol/TBase.py new file mode 100644 index 000000000..e675c7dc0 --- /dev/null +++ b/lib/Python/Lib/thrift/protocol/TBase.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: +  from thrift.protocol import fastbinary +except: +  fastbinary = None + +class TBase(object): +  __slots__ = [] + +  def __repr__(self): +    L = ['%s=%r' % (key, getattr(self, key)) +              for key in self.__slots__ ] +    return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + +  def __eq__(self, other): +    if not isinstance(other, self.__class__): +      return False +    for attr in self.__slots__: +      my_val = getattr(self, attr) +      other_val = getattr(other, attr) +      if my_val != other_val: +        return False +    return True +     +  def __ne__(self, other): +    return not (self == other) +   +  def read(self, iprot): +    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: +      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +      return +    iprot.readStruct(self, self.thrift_spec) + +  def write(self, oprot): +    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: +      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +      return +    oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): +  # old style class so python2.4 can raise exceptions derived from this +  #  This can't inherit from TBase because of that limitation. +  __slots__ = [] +   +  __repr__ = TBase.__repr__.im_func +  __eq__ = TBase.__eq__.im_func +  __ne__ = TBase.__ne__.im_func +  read = TBase.read.im_func +  write = TBase.write.im_func +   diff --git a/lib/Python/Lib/thrift/protocol/TBinaryProtocol.py b/lib/Python/Lib/thrift/protocol/TBinaryProtocol.py new file mode 100644 index 000000000..50c6aa896 --- /dev/null +++ b/lib/Python/Lib/thrift/protocol/TBinaryProtocol.py @@ -0,0 +1,259 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +class TBinaryProtocol(TProtocolBase): + +  """Binary implementation of the Thrift protocol driver.""" + +  # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be +  # positive, converting this into a long. If we hardcode the int value +  # instead it'll stay in 32 bit-land. + +  # VERSION_MASK = 0xffff0000 +  VERSION_MASK = -65536 + +  # VERSION_1 = 0x80010000 +  VERSION_1 = -2147418112 + +  TYPE_MASK = 0x000000ff + +  def __init__(self, trans, strictRead=False, strictWrite=True): +    TProtocolBase.__init__(self, trans) +    self.strictRead = strictRead +    self.strictWrite = strictWrite + +  def writeMessageBegin(self, name, type, seqid): +    if self.strictWrite: +      self.writeI32(TBinaryProtocol.VERSION_1 | type) +      self.writeString(name) +      self.writeI32(seqid) +    else: +      self.writeString(name) +      self.writeByte(type) +      self.writeI32(seqid) + +  def writeMessageEnd(self): +    pass + +  def writeStructBegin(self, name): +    pass + +  def writeStructEnd(self): +    pass + +  def writeFieldBegin(self, name, type, id): +    self.writeByte(type) +    self.writeI16(id) + +  def writeFieldEnd(self): +    pass + +  def writeFieldStop(self): +    self.writeByte(TType.STOP); + +  def writeMapBegin(self, ktype, vtype, size): +    self.writeByte(ktype) +    self.writeByte(vtype) +    self.writeI32(size) + +  def writeMapEnd(self): +    pass + +  def writeListBegin(self, etype, size): +    self.writeByte(etype) +    self.writeI32(size) + +  def writeListEnd(self): +    pass + +  def writeSetBegin(self, etype, size): +    self.writeByte(etype) +    self.writeI32(size) + +  def writeSetEnd(self): +    pass + +  def writeBool(self, bool): +    if bool: +      self.writeByte(1) +    else: +      self.writeByte(0) + +  def writeByte(self, byte): +    buff = pack("!b", byte) +    self.trans.write(buff) + +  def writeI16(self, i16): +    buff = pack("!h", i16) +    self.trans.write(buff) + +  def writeI32(self, i32): +    buff = pack("!i", i32) +    self.trans.write(buff) + +  def writeI64(self, i64): +    buff = pack("!q", i64) +    self.trans.write(buff) + +  def writeDouble(self, dub): +    buff = pack("!d", dub) +    self.trans.write(buff) + +  def writeString(self, str): +    self.writeI32(len(str)) +    self.trans.write(str) + +  def readMessageBegin(self): +    sz = self.readI32() +    if sz < 0: +      version = sz & TBinaryProtocol.VERSION_MASK +      if version != TBinaryProtocol.VERSION_1: +        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) +      type = sz & TBinaryProtocol.TYPE_MASK +      name = self.readString() +      seqid = self.readI32() +    else: +      if self.strictRead: +        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') +      name = self.trans.readAll(sz) +      type = self.readByte() +      seqid = self.readI32() +    return (name, type, seqid) + +  def readMessageEnd(self): +    pass + +  def readStructBegin(self): +    pass + +  def readStructEnd(self): +    pass + +  def readFieldBegin(self): +    type = self.readByte() +    if type == TType.STOP: +      return (None, type, 0) +    id = self.readI16() +    return (None, type, id) + +  def readFieldEnd(self): +    pass + +  def readMapBegin(self): +    ktype = self.readByte() +    vtype = self.readByte() +    size = self.readI32() +    return (ktype, vtype, size) + +  def readMapEnd(self): +    pass + +  def readListBegin(self): +    etype = self.readByte() +    size = self.readI32() +    return (etype, size) + +  def readListEnd(self): +    pass + +  def readSetBegin(self): +    etype = self.readByte() +    size = self.readI32() +    return (etype, size) + +  def readSetEnd(self): +    pass + +  def readBool(self): +    byte = self.readByte() +    if byte == 0: +      return False +    return True + +  def readByte(self): +    buff = self.trans.readAll(1) +    val, = unpack('!b', buff) +    return val + +  def readI16(self): +    buff = self.trans.readAll(2) +    val, = unpack('!h', buff) +    return val + +  def readI32(self): +    buff = self.trans.readAll(4) +    val, = unpack('!i', buff) +    return val + +  def readI64(self): +    buff = self.trans.readAll(8) +    val, = unpack('!q', buff) +    return val + +  def readDouble(self): +    buff = self.trans.readAll(8) +    val, = unpack('!d', buff) +    return val + +  def readString(self): +    len = self.readI32() +    str = self.trans.readAll(len) +    return str + + +class TBinaryProtocolFactory: +  def __init__(self, strictRead=False, strictWrite=True): +    self.strictRead = strictRead +    self.strictWrite = strictWrite + +  def getProtocol(self, trans): +    prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) +    return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + +  """C-Accelerated version of TBinaryProtocol. + +  This class does not override any of TBinaryProtocol's methods, +  but the generated code recognizes it directly and will call into +  our C module to do the encoding, bypassing this object entirely. +  We inherit from TBinaryProtocol so that the normal TBinaryProtocol +  encoding can happen if the fastbinary module doesn't work for some +  reason.  (TODO(dreiss): Make this happen sanely in more cases.) + +  In order to take advantage of the C module, just use +  TBinaryProtocolAccelerated instead of TBinaryProtocol. + +  NOTE:  This code was contributed by an external developer. +         The internal Thrift team has reviewed and tested it, +         but we cannot guarantee that it is production-ready. +         Please feel free to report bugs and/or success stories +         to the public mailing list. +  """ + +  pass + + +class TBinaryProtocolAcceleratedFactory: +  def getProtocol(self, trans): +    return TBinaryProtocolAccelerated(trans) diff --git a/lib/Python/Lib/thrift/protocol/TCompactProtocol.py b/lib/Python/Lib/thrift/protocol/TCompactProtocol.py new file mode 100644 index 000000000..016a33171 --- /dev/null +++ b/lib/Python/Lib/thrift/protocol/TCompactProtocol.py @@ -0,0 +1,395 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] + +CLEAR = 0 +FIELD_WRITE = 1 +VALUE_WRITE = 2 +CONTAINER_WRITE = 3 +BOOL_WRITE = 4 +FIELD_READ = 5 +CONTAINER_READ = 6 +VALUE_READ = 7 +BOOL_READ = 8 + +def make_helper(v_from, container): +  def helper(func): +    def nested(self, *args, **kwargs): +      assert self.state in (v_from, container), (self.state, v_from, container) +      return func(self, *args, **kwargs) +    return nested +  return helper +writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) +reader = make_helper(VALUE_READ, CONTAINER_READ) + +def makeZigZag(n, bits): +  return (n << 1) ^ (n >> (bits - 1)) + +def fromZigZag(n): +  return (n >> 1) ^ -(n & 1) + +def writeVarint(trans, n): +  out = [] +  while True: +    if n & ~0x7f == 0: +      out.append(n) +      break +    else: +      out.append((n & 0xff) | 0x80) +      n = n >> 7 +  trans.write(''.join(map(chr, out))) + +def readVarint(trans): +  result = 0 +  shift = 0 +  while True: +    x = trans.readAll(1) +    byte = ord(x) +    result |= (byte & 0x7f) << shift +    if byte >> 7 == 0: +      return result +    shift += 7 + +class CompactType: +  STOP = 0x00 +  TRUE = 0x01 +  FALSE = 0x02 +  BYTE = 0x03 +  I16 = 0x04 +  I32 = 0x05 +  I64 = 0x06 +  DOUBLE = 0x07 +  BINARY = 0x08 +  LIST = 0x09 +  SET = 0x0A +  MAP = 0x0B +  STRUCT = 0x0C + +CTYPES = {TType.STOP: CompactType.STOP, +          TType.BOOL: CompactType.TRUE, # used for collection +          TType.BYTE: CompactType.BYTE, +          TType.I16: CompactType.I16, +          TType.I32: CompactType.I32, +          TType.I64: CompactType.I64, +          TType.DOUBLE: CompactType.DOUBLE, +          TType.STRING: CompactType.BINARY, +          TType.STRUCT: CompactType.STRUCT, +          TType.LIST: CompactType.LIST, +          TType.SET: CompactType.SET, +          TType.MAP: CompactType.MAP +          } + +TTYPES = {} +for k, v in CTYPES.items(): +  TTYPES[v] = k +TTYPES[CompactType.FALSE] = TType.BOOL +del k +del v + +class TCompactProtocol(TProtocolBase): +  "Compact implementation of the Thrift protocol driver." + +  PROTOCOL_ID = 0x82 +  VERSION = 1 +  VERSION_MASK = 0x1f +  TYPE_MASK = 0xe0 +  TYPE_SHIFT_AMOUNT = 5 + +  def __init__(self, trans): +    TProtocolBase.__init__(self, trans) +    self.state = CLEAR +    self.__last_fid = 0 +    self.__bool_fid = None +    self.__bool_value = None +    self.__structs = [] +    self.__containers = [] + +  def __writeVarint(self, n): +    writeVarint(self.trans, n) + +  def writeMessageBegin(self, name, type, seqid): +    assert self.state == CLEAR +    self.__writeUByte(self.PROTOCOL_ID) +    self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) +    self.__writeVarint(seqid) +    self.__writeString(name) +    self.state = VALUE_WRITE + +  def writeMessageEnd(self): +    assert self.state == VALUE_WRITE +    self.state = CLEAR + +  def writeStructBegin(self, name): +    assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state +    self.__structs.append((self.state, self.__last_fid)) +    self.state = FIELD_WRITE +    self.__last_fid = 0 + +  def writeStructEnd(self): +    assert self.state == FIELD_WRITE +    self.state, self.__last_fid = self.__structs.pop() + +  def writeFieldStop(self): +    self.__writeByte(0) + +  def __writeFieldHeader(self, type, fid): +    delta = fid - self.__last_fid +    if 0 < delta <= 15: +      self.__writeUByte(delta << 4 | type) +    else: +      self.__writeByte(type) +      self.__writeI16(fid) +    self.__last_fid = fid + +  def writeFieldBegin(self, name, type, fid): +    assert self.state == FIELD_WRITE, self.state +    if type == TType.BOOL: +      self.state = BOOL_WRITE +      self.__bool_fid = fid +    else: +      self.state = VALUE_WRITE +      self.__writeFieldHeader(CTYPES[type], fid) + +  def writeFieldEnd(self): +    assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state +    self.state = FIELD_WRITE + +  def __writeUByte(self, byte): +    self.trans.write(pack('!B', byte)) + +  def __writeByte(self, byte): +    self.trans.write(pack('!b', byte)) + +  def __writeI16(self, i16): +    self.__writeVarint(makeZigZag(i16, 16)) + +  def __writeSize(self, i32): +    self.__writeVarint(i32) + +  def writeCollectionBegin(self, etype, size): +    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state +    if size <= 14: +      self.__writeUByte(size << 4 | CTYPES[etype]) +    else: +      self.__writeUByte(0xf0 | CTYPES[etype]) +      self.__writeSize(size) +    self.__containers.append(self.state) +    self.state = CONTAINER_WRITE +  writeSetBegin = writeCollectionBegin +  writeListBegin = writeCollectionBegin + +  def writeMapBegin(self, ktype, vtype, size): +    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state +    if size == 0: +      self.__writeByte(0) +    else: +      self.__writeSize(size) +      self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) +    self.__containers.append(self.state) +    self.state = CONTAINER_WRITE + +  def writeCollectionEnd(self): +    assert self.state == CONTAINER_WRITE, self.state +    self.state = self.__containers.pop() +  writeMapEnd = writeCollectionEnd +  writeSetEnd = writeCollectionEnd +  writeListEnd = writeCollectionEnd + +  def writeBool(self, bool): +    if self.state == BOOL_WRITE: +        if bool: +            ctype = CompactType.TRUE +        else: +            ctype = CompactType.FALSE +        self.__writeFieldHeader(ctype, self.__bool_fid) +    elif self.state == CONTAINER_WRITE: +       if bool: +           self.__writeByte(CompactType.TRUE) +       else: +           self.__writeByte(CompactType.FALSE) +    else: +      raise AssertionError, "Invalid state in compact protocol" + +  writeByte = writer(__writeByte) +  writeI16 = writer(__writeI16) + +  @writer +  def writeI32(self, i32): +    self.__writeVarint(makeZigZag(i32, 32)) + +  @writer +  def writeI64(self, i64): +    self.__writeVarint(makeZigZag(i64, 64)) + +  @writer +  def writeDouble(self, dub): +    self.trans.write(pack('!d', dub)) + +  def __writeString(self, s): +    self.__writeSize(len(s)) +    self.trans.write(s) +  writeString = writer(__writeString) + +  def readFieldBegin(self): +    assert self.state == FIELD_READ, self.state +    type = self.__readUByte() +    if type & 0x0f == TType.STOP: +      return (None, 0, 0) +    delta = type >> 4 +    if delta == 0: +      fid = self.__readI16() +    else: +      fid = self.__last_fid + delta +    self.__last_fid = fid +    type = type & 0x0f +    if type == CompactType.TRUE: +      self.state = BOOL_READ +      self.__bool_value = True +    elif type == CompactType.FALSE: +      self.state = BOOL_READ +      self.__bool_value = False +    else: +      self.state = VALUE_READ +    return (None, self.__getTType(type), fid) + +  def readFieldEnd(self): +    assert self.state in (VALUE_READ, BOOL_READ), self.state +    self.state = FIELD_READ + +  def __readUByte(self): +    result, = unpack('!B', self.trans.readAll(1)) +    return result + +  def __readByte(self): +    result, = unpack('!b', self.trans.readAll(1)) +    return result + +  def __readVarint(self): +    return readVarint(self.trans) + +  def __readZigZag(self): +    return fromZigZag(self.__readVarint()) + +  def __readSize(self): +    result = self.__readVarint() +    if result < 0: +      raise TException("Length < 0") +    return result + +  def readMessageBegin(self): +    assert self.state == CLEAR +    proto_id = self.__readUByte() +    if proto_id != self.PROTOCOL_ID: +      raise TProtocolException(TProtocolException.BAD_VERSION, +          'Bad protocol id in the message: %d' % proto_id) +    ver_type = self.__readUByte() +    type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT +    version = ver_type & self.VERSION_MASK +    if version != self.VERSION: +      raise TProtocolException(TProtocolException.BAD_VERSION, +          'Bad version: %d (expect %d)' % (version, self.VERSION)) +    seqid = self.__readVarint() +    name = self.__readString() +    return (name, type, seqid) + +  def readMessageEnd(self): +    assert self.state == CLEAR +    assert len(self.__structs) == 0 + +  def readStructBegin(self): +    assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state +    self.__structs.append((self.state, self.__last_fid)) +    self.state = FIELD_READ +    self.__last_fid = 0 + +  def readStructEnd(self): +    assert self.state == FIELD_READ +    self.state, self.__last_fid = self.__structs.pop() + +  def readCollectionBegin(self): +    assert self.state in (VALUE_READ, CONTAINER_READ), self.state +    size_type = self.__readUByte() +    size = size_type >> 4 +    type = self.__getTType(size_type) +    if size == 15: +      size = self.__readSize() +    self.__containers.append(self.state) +    self.state = CONTAINER_READ +    return type, size +  readSetBegin = readCollectionBegin +  readListBegin = readCollectionBegin + +  def readMapBegin(self): +    assert self.state in (VALUE_READ, CONTAINER_READ), self.state +    size = self.__readSize() +    types = 0 +    if size > 0: +      types = self.__readUByte() +    vtype = self.__getTType(types) +    ktype = self.__getTType(types >> 4) +    self.__containers.append(self.state) +    self.state = CONTAINER_READ +    return (ktype, vtype, size) + +  def readCollectionEnd(self): +    assert self.state == CONTAINER_READ, self.state +    self.state = self.__containers.pop() +  readSetEnd = readCollectionEnd +  readListEnd = readCollectionEnd +  readMapEnd = readCollectionEnd + +  def readBool(self): +    if self.state == BOOL_READ: +      return self.__bool_value == CompactType.TRUE +    elif self.state == CONTAINER_READ: +      return self.__readByte() == CompactType.TRUE +    else: +      raise AssertionError, "Invalid state in compact protocol: %d" % self.state + +  readByte = reader(__readByte) +  __readI16 = __readZigZag +  readI16 = reader(__readZigZag) +  readI32 = reader(__readZigZag) +  readI64 = reader(__readZigZag) + +  @reader +  def readDouble(self): +    buff = self.trans.readAll(8) +    val, = unpack('!d', buff) +    return val + +  def __readString(self): +    len = self.__readSize() +    return self.trans.readAll(len) +  readString = reader(__readString) + +  def __getTType(self, byte): +    return TTYPES[byte & 0x0f] + + +class TCompactProtocolFactory: +  def __init__(self): +    pass + +  def getProtocol(self, trans): +    return TCompactProtocol(trans) diff --git a/lib/Python/Lib/thrift/protocol/TProtocol.py b/lib/Python/Lib/thrift/protocol/TProtocol.py new file mode 100644 index 000000000..7338ff68a --- /dev/null +++ b/lib/Python/Lib/thrift/protocol/TProtocol.py @@ -0,0 +1,404 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * + +class TProtocolException(TException): + +  """Custom Protocol Exception class""" + +  UNKNOWN = 0 +  INVALID_DATA = 1 +  NEGATIVE_SIZE = 2 +  SIZE_LIMIT = 3 +  BAD_VERSION = 4 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +class TProtocolBase: + +  """Base class for Thrift protocol driver.""" + +  def __init__(self, trans): +    self.trans = trans + +  def writeMessageBegin(self, name, type, seqid): +    pass + +  def writeMessageEnd(self): +    pass + +  def writeStructBegin(self, name): +    pass + +  def writeStructEnd(self): +    pass + +  def writeFieldBegin(self, name, type, id): +    pass + +  def writeFieldEnd(self): +    pass + +  def writeFieldStop(self): +    pass + +  def writeMapBegin(self, ktype, vtype, size): +    pass + +  def writeMapEnd(self): +    pass + +  def writeListBegin(self, etype, size): +    pass + +  def writeListEnd(self): +    pass + +  def writeSetBegin(self, etype, size): +    pass + +  def writeSetEnd(self): +    pass + +  def writeBool(self, bool): +    pass + +  def writeByte(self, byte): +    pass + +  def writeI16(self, i16): +    pass + +  def writeI32(self, i32): +    pass + +  def writeI64(self, i64): +    pass + +  def writeDouble(self, dub): +    pass + +  def writeString(self, str): +    pass + +  def readMessageBegin(self): +    pass + +  def readMessageEnd(self): +    pass + +  def readStructBegin(self): +    pass + +  def readStructEnd(self): +    pass + +  def readFieldBegin(self): +    pass + +  def readFieldEnd(self): +    pass + +  def readMapBegin(self): +    pass + +  def readMapEnd(self): +    pass + +  def readListBegin(self): +    pass + +  def readListEnd(self): +    pass + +  def readSetBegin(self): +    pass + +  def readSetEnd(self): +    pass + +  def readBool(self): +    pass + +  def readByte(self): +    pass + +  def readI16(self): +    pass + +  def readI32(self): +    pass + +  def readI64(self): +    pass + +  def readDouble(self): +    pass + +  def readString(self): +    pass + +  def skip(self, type): +    if type == TType.STOP: +      return +    elif type == TType.BOOL: +      self.readBool() +    elif type == TType.BYTE: +      self.readByte() +    elif type == TType.I16: +      self.readI16() +    elif type == TType.I32: +      self.readI32() +    elif type == TType.I64: +      self.readI64() +    elif type == TType.DOUBLE: +      self.readDouble() +    elif type == TType.STRING: +      self.readString() +    elif type == TType.STRUCT: +      name = self.readStructBegin() +      while True: +        (name, type, id) = self.readFieldBegin() +        if type == TType.STOP: +          break +        self.skip(type) +        self.readFieldEnd() +      self.readStructEnd() +    elif type == TType.MAP: +      (ktype, vtype, size) = self.readMapBegin() +      for i in range(size): +        self.skip(ktype) +        self.skip(vtype) +      self.readMapEnd() +    elif type == TType.SET: +      (etype, size) = self.readSetBegin() +      for i in range(size): +        self.skip(etype) +      self.readSetEnd() +    elif type == TType.LIST: +      (etype, size) = self.readListBegin() +      for i in range(size): +        self.skip(etype) +      self.readListEnd() + +  # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) +  _TTYPE_HANDLERS = ( +       (None, None, False), # 0 == TType,STOP +       (None, None, False), # 1 == TType.VOID # TODO: handle void? +       ('readBool', 'writeBool', False), # 2 == TType.BOOL +       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08 +       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE +       (None, None, False), # 5, undefined +       ('readI16', 'writeI16', False), # 6 == TType.I16 +       (None, None, False), # 7, undefined +       ('readI32', 'writeI32', False), # 8 == TType.I32 +       (None, None, False), # 9, undefined +       ('readI64', 'writeI64', False), # 10 == TType.I64 +       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 +       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT +       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP +       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET +       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST +       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? +       (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? +      ) + +  def readFieldByTType(self, ttype, spec): +    try: +      (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] +    except IndexError: +      raise TProtocolException(type=TProtocolException.INVALID_DATA, +                               message='Invalid field type %d' % (ttype)) +    if r_handler is None: +      raise TProtocolException(type=TProtocolException.INVALID_DATA, +                               message='Invalid field type %d' % (ttype)) +    reader = getattr(self, r_handler) +    if not is_container: +      return reader() +    return reader(spec) + +  def readContainerList(self, spec): +    results = [] +    ttype, tspec = spec[0], spec[1] +    r_handler = self._TTYPE_HANDLERS[ttype][0] +    reader = getattr(self, r_handler) +    (list_type, list_len) = self.readListBegin() +    if tspec is None: +      # list values are simple types +      for idx in xrange(list_len): +        results.append(reader()) +    else: +      # this is like an inlined readFieldByTType +      container_reader = self._TTYPE_HANDLERS[list_type][0] +      val_reader = getattr(self, container_reader) +      for idx in xrange(list_len): +        val = val_reader(tspec) +        results.append(val) +    self.readListEnd() +    return results + +  def readContainerSet(self, spec): +    results = set() +    ttype, tspec = spec[0], spec[1] +    r_handler = self._TTYPE_HANDLERS[ttype][0] +    reader = getattr(self, r_handler) +    (set_type, set_len) = self.readSetBegin() +    if tspec is None: +      # set members are simple types +      for idx in xrange(set_len): +        results.add(reader()) +    else: +      container_reader = self._TTYPE_HANDLERS[set_type][0] +      val_reader = getattr(self, container_reader) +      for idx in xrange(set_len): +        results.add(val_reader(tspec))  +    self.readSetEnd() +    return results + +  def readContainerStruct(self, spec): +    (obj_class, obj_spec) = spec +    obj = obj_class() +    obj.read(self) +    return obj +   +  def readContainerMap(self, spec): +    results = dict() +    key_ttype, key_spec = spec[0], spec[1] +    val_ttype, val_spec = spec[2], spec[3] +    (map_ktype, map_vtype, map_len) = self.readMapBegin() +    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree +    key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) +    val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) +    # list values are simple types +    for idx in xrange(map_len): +      if key_spec is None: +        k_val = key_reader() +      else: +        k_val = self.readFieldByTType(key_ttype, key_spec) +      if val_spec is None: +        v_val = val_reader() +      else: +        v_val = self.readFieldByTType(val_ttype, val_spec) +      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails +      results[k_val] = v_val +    self.readMapEnd() +    return results + +  def readStruct(self, obj, thrift_spec): +    self.readStructBegin() +    while True: +      (fname, ftype, fid) = self.readFieldBegin() +      if ftype == TType.STOP: +        break +      try: +        field = thrift_spec[fid] +      except IndexError: +        self.skip(ftype) +      else: +        if field is not None and ftype == field[1]: +          fname = field[2] +          fspec = field[3] +          val = self.readFieldByTType(ftype, fspec) +          setattr(obj, fname, val) +        else: +          self.skip(ftype) +      self.readFieldEnd() +    self.readStructEnd() + +  def writeContainerStruct(self, val, spec): +    val.write(self) + +  def writeContainerList(self, val, spec): +    self.writeListBegin(spec[0], len(val)) +    r_handler, w_handler, is_container  = self._TTYPE_HANDLERS[spec[0]] +    e_writer = getattr(self, w_handler) +    if not is_container: +      for elem in val: +        e_writer(elem) +    else: +      for elem in val: +        e_writer(elem, spec[1]) +    self.writeListEnd() + +  def writeContainerSet(self, val, spec): +    self.writeSetBegin(spec[0], len(val)) +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] +    e_writer = getattr(self, w_handler) +    if not is_container: +      for elem in val: +        e_writer(elem) +    else: +      for elem in val: +        e_writer(elem, spec[1]) +    self.writeSetEnd() + +  def writeContainerMap(self, val, spec): +    k_type = spec[0] +    v_type = spec[2] +    ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] +    ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] +    k_writer = getattr(self, ktype_name) +    v_writer = getattr(self, vtype_name) +    self.writeMapBegin(k_type, v_type, len(val)) +    for m_key, m_val in val.iteritems(): +      if not k_is_container: +        k_writer(m_key) +      else: +        k_writer(m_key, spec[1]) +      if not v_is_container: +        v_writer(m_val) +      else: +        v_writer(m_val, spec[3]) +    self.writeMapEnd() + +  def writeStruct(self, obj, thrift_spec): +    self.writeStructBegin(obj.__class__.__name__) +    for field in thrift_spec: +      if field is None: +        continue +      fname = field[2] +      val = getattr(obj, fname) +      if val is None: +        # skip writing out unset fields +        continue +      fid = field[0] +      ftype = field[1] +      fspec = field[3] +      # get the writer method for this value +      self.writeFieldBegin(fname, ftype, fid) +      self.writeFieldByTType(ftype, val, fspec) +      self.writeFieldEnd() +    self.writeFieldStop() +    self.writeStructEnd() + +  def writeFieldByTType(self, ttype, val, spec): +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] +    writer = getattr(self, w_handler) +    if is_container: +      writer(val, spec) +    else: +      writer(val) + +class TProtocolFactory: +  def getProtocol(self, trans): +    pass + diff --git a/lib/Python/Lib/thrift/protocol/__init__.py b/lib/Python/Lib/thrift/protocol/__init__.py new file mode 100644 index 000000000..d53359b28 --- /dev/null +++ b/lib/Python/Lib/thrift/protocol/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] diff --git a/lib/Python/Lib/thrift/server/THttpServer.py b/lib/Python/Lib/thrift/server/THttpServer.py new file mode 100644 index 000000000..3047d9c00 --- /dev/null +++ b/lib/Python/Lib/thrift/server/THttpServer.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import BaseHTTPServer + +from thrift.server import TServer +from thrift.transport import TTransport + +class ResponseException(Exception): +  """Allows handlers to override the HTTP response + +  Normally, THttpServer always sends a 200 response.  If a handler wants +  to override this behavior (e.g., to simulate a misconfigured or +  overloaded web server during testing), it can raise a ResponseException. +  The function passed to the constructor will be called with the +  RequestHandler as its only argument. +  """ +  def __init__(self, handler): +    self.handler = handler + + +class THttpServer(TServer.TServer): +  """A simple HTTP-based Thrift server + +  This class is not very performant, but it is useful (for example) for +  acting as a mock version of an Apache-based PHP Thrift endpoint.""" + +  def __init__(self, processor, server_address, +      inputProtocolFactory, outputProtocolFactory = None, +      server_class = BaseHTTPServer.HTTPServer): +    """Set up protocol factories and HTTP server. + +    See BaseHTTPServer for server_address. +    See TServer for protocol factories.""" + +    if outputProtocolFactory is None: +      outputProtocolFactory = inputProtocolFactory + +    TServer.TServer.__init__(self, processor, None, None, None, +        inputProtocolFactory, outputProtocolFactory) + +    thttpserver = self + +    class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): +      def do_POST(self): +        # Don't care about the request path. +        itrans = TTransport.TFileObjectTransport(self.rfile) +        otrans = TTransport.TFileObjectTransport(self.wfile) +        itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) +        otrans = TTransport.TMemoryBuffer() +        iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) +        oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) +        try: +          thttpserver.processor.process(iprot, oprot) +        except ResponseException, exn: +          exn.handler(self) +        else: +          self.send_response(200) +          self.send_header("content-type", "application/x-thrift") +          self.end_headers() +          self.wfile.write(otrans.getvalue()) + +    self.httpd = server_class(server_address, RequestHander) + +  def serve(self): +    self.httpd.serve_forever() diff --git a/lib/Python/Lib/thrift/server/TNonblockingServer.py b/lib/Python/Lib/thrift/server/TNonblockingServer.py new file mode 100644 index 000000000..ea348a0b6 --- /dev/null +++ b/lib/Python/Lib/thrift/server/TNonblockingServer.py @@ -0,0 +1,310 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is reciving and sending requests +only from main thread. + +It also makes thread pool server in tasks terms, not connections. +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +class Worker(threading.Thread): +    """Worker is a small helper to process incoming connection.""" +    def __init__(self, queue): +        threading.Thread.__init__(self) +        self.queue = queue + +    def run(self): +        """Process queries from task queue, stop if processor is None.""" +        while True: +            try: +                processor, iprot, oprot, otrans, callback = self.queue.get() +                if processor is None: +                    break +                processor.process(iprot, oprot) +                callback(True, otrans.getvalue()) +            except Exception: +                logging.exception("Exception while processing request") +                callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + +def locked(func): +    "Decorator which locks self.lock." +    def nested(self, *args, **kwargs): +        self.lock.acquire() +        try: +            return func(self, *args, **kwargs) +        finally: +            self.lock.release() +    return nested + +def socket_exception(func): +    "Decorator close object on socket.error." +    def read(self, *args, **kwargs): +        try: +            return func(self, *args, **kwargs) +        except socket.error: +            self.close() +    return read + +class Connection: +    """Basic class is represented connection. +     +    It can be in state: +        WAIT_LEN --- connection is reading request len. +        WAIT_MESSAGE --- connection is reading request. +        WAIT_PROCESS --- connection has just read whole request and  +            waits for call ready routine. +        SEND_ANSWER --- connection is sending answer string (including length +            of answer). +        CLOSED --- socket was closed and connection should be deleted. +    """ +    def __init__(self, new_socket, wake_up): +        self.socket = new_socket +        self.socket.setblocking(False) +        self.status = WAIT_LEN +        self.len = 0 +        self.message = '' +        self.lock = threading.Lock() +        self.wake_up = wake_up + +    def _read_len(self): +        """Reads length of request. +         +        It's really paranoic routine and it may be replaced by  +        self.socket.recv(4).""" +        read = self.socket.recv(4 - len(self.message)) +        if len(read) == 0: +            # if we read 0 bytes and self.message is empty, it means client close  +            # connection +            if len(self.message) != 0: +                logging.error("can't read frame size from socket") +            self.close() +            return +        self.message += read +        if len(self.message) == 4: +            self.len, = struct.unpack('!i', self.message) +            if self.len < 0: +                logging.error("negative frame size, it seems client"\ +                    " doesn't use FramedTransport") +                self.close() +            elif self.len == 0: +                logging.error("empty frame, it's really strange") +                self.close() +            else: +                self.message = '' +                self.status = WAIT_MESSAGE + +    @socket_exception +    def read(self): +        """Reads data from stream and switch state.""" +        assert self.status in (WAIT_LEN, WAIT_MESSAGE) +        if self.status == WAIT_LEN: +            self._read_len() +            # go back to the main loop here for simplicity instead of +            # falling through, even though there is a good chance that +            # the message is already available +        elif self.status == WAIT_MESSAGE: +            read = self.socket.recv(self.len - len(self.message)) +            if len(read) == 0: +                logging.error("can't read frame from socket (get %d of %d bytes)" % +                    (len(self.message), self.len)) +                self.close() +                return +            self.message += read +            if len(self.message) == self.len: +                self.status = WAIT_PROCESS + +    @socket_exception +    def write(self): +        """Writes data from socket and switch state.""" +        assert self.status == SEND_ANSWER +        sent = self.socket.send(self.message) +        if sent == len(self.message): +            self.status = WAIT_LEN +            self.message = '' +            self.len = 0 +        else: +            self.message = self.message[sent:] + +    @locked +    def ready(self, all_ok, message): +        """Callback function for switching state and waking up main thread. +         +        This function is the only function witch can be called asynchronous. +         +        The ready can switch Connection to three states: +            WAIT_LEN if request was oneway. +            SEND_ANSWER if request was processed in normal way. +            CLOSED if request throws unexpected exception. +         +        The one wakes up main thread. +        """ +        assert self.status == WAIT_PROCESS +        if not all_ok: +            self.close() +            self.wake_up() +            return +        self.len = '' +        if len(message) == 0: +            # it was a oneway request, do not write answer +            self.message = '' +            self.status = WAIT_LEN +        else: +            self.message = struct.pack('!i', len(message)) + message +            self.status = SEND_ANSWER +        self.wake_up() + +    @locked +    def is_writeable(self): +        "Returns True if connection should be added to write list of select." +        return self.status == SEND_ANSWER + +    # it's not necessary, but... +    @locked +    def is_readable(self): +        "Returns True if connection should be added to read list of select." +        return self.status in (WAIT_LEN, WAIT_MESSAGE) + +    @locked +    def is_closed(self): +        "Returns True if connection is closed." +        return self.status == CLOSED + +    def fileno(self): +        "Returns the file descriptor of the associated socket." +        return self.socket.fileno() + +    def close(self): +        "Closes connection" +        self.status = CLOSED +        self.socket.close() + +class TNonblockingServer: +    """Non-blocking server.""" +    def __init__(self, processor, lsocket, inputProtocolFactory=None,  +            outputProtocolFactory=None, threads=10): +        self.processor = processor +        self.socket = lsocket +        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() +        self.out_protocol = outputProtocolFactory or self.in_protocol +        self.threads = int(threads) +        self.clients = {} +        self.tasks = Queue.Queue() +        self._read, self._write = socket.socketpair() +        self.prepared = False + +    def setNumThreads(self, num): +        """Set the number of worker threads that should be created.""" +        # implement ThreadPool interface +        assert not self.prepared, "You can't change number of threads for working server" +        self.threads = num + +    def prepare(self): +        """Prepares server for serve requests.""" +        self.socket.listen() +        for _ in xrange(self.threads): +            thread = Worker(self.tasks) +            thread.setDaemon(True) +            thread.start() +        self.prepared = True + +    def wake_up(self): +        """Wake up main thread. +         +        The server usualy waits in select call in we should terminate one. +        The simplest way is using socketpair. +         +        Select always wait to read from the first socket of socketpair. +         +        In this case, we can just write anything to the second socket from +        socketpair.""" +        self._write.send('1') + +    def _select(self): +        """Does select on open connections.""" +        readable = [self.socket.handle.fileno(), self._read.fileno()] +        writable = [] +        for i, connection in self.clients.items(): +            if connection.is_readable(): +                readable.append(connection.fileno()) +            if connection.is_writeable(): +                writable.append(connection.fileno()) +            if connection.is_closed(): +                del self.clients[i] +        return select.select(readable, writable, readable) +         +    def handle(self): +        """Handle requests. +        +        WARNING! You must call prepare BEFORE calling handle. +        """ +        assert self.prepared, "You have to call prepare before handle" +        rset, wset, xset = self._select() +        for readable in rset: +            if readable == self._read.fileno(): +                # don't care i just need to clean readable flag +                self._read.recv(1024)  +            elif readable == self.socket.handle.fileno(): +                client = self.socket.accept().handle +                self.clients[client.fileno()] = Connection(client, self.wake_up) +            else: +                connection = self.clients[readable] +                connection.read() +                if connection.status == WAIT_PROCESS: +                    itransport = TTransport.TMemoryBuffer(connection.message) +                    otransport = TTransport.TMemoryBuffer() +                    iprot = self.in_protocol.getProtocol(itransport) +                    oprot = self.out_protocol.getProtocol(otransport) +                    self.tasks.put([self.processor, iprot, oprot,  +                                    otransport, connection.ready]) +        for writeable in wset: +            self.clients[writeable].write() +        for oob in xset: +            self.clients[oob].close() +            del self.clients[oob] + +    def close(self): +        """Closes the server.""" +        for _ in xrange(self.threads): +            self.tasks.put([None, None, None, None, None]) +        self.socket.close() +        self.prepared = False +         +    def serve(self): +        """Serve forever.""" +        self.prepare() +        while True: +            self.handle() diff --git a/lib/Python/Lib/thrift/server/TProcessPoolServer.py b/lib/Python/Lib/thrift/server/TProcessPoolServer.py new file mode 100644 index 000000000..7ed814a88 --- /dev/null +++ b/lib/Python/Lib/thrift/server/TProcessPoolServer.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import logging +from multiprocessing import  Process, Value, Condition, reduction + +from TServer import TServer +from thrift.transport.TTransport import TTransportException + +class TProcessPoolServer(TServer): + +    """ +    Server with a fixed size pool of worker subprocesses which service requests. +    Note that if you need shared state between the handlers - it's up to you! +    Written by Dvir Volk, doat.com +    """ + +    def __init__(self, * args): +        TServer.__init__(self, *args) +        self.numWorkers = 10 +        self.workers = [] +        self.isRunning = Value('b', False) +        self.stopCondition = Condition() +        self.postForkCallback = None + +    def setPostForkCallback(self, callback): +        if not callable(callback): +            raise TypeError("This is not a callback!") +        self.postForkCallback = callback + +    def setNumWorkers(self, num): +        """Set the number of worker threads that should be created""" +        self.numWorkers = num + +    def workerProcess(self): +        """Loop around getting clients from the shared queue and process them.""" + +        if self.postForkCallback: +            self.postForkCallback() + +        while self.isRunning.value == True: +            try: +                client = self.serverTransport.accept() +                self.serveClient(client) +            except (KeyboardInterrupt, SystemExit): +                return 0 +            except Exception, x: +                logging.exception(x) + +    def serveClient(self, client): +        """Process input/output from a client for as long as possible""" +        itrans = self.inputTransportFactory.getTransport(client) +        otrans = self.outputTransportFactory.getTransport(client) +        iprot = self.inputProtocolFactory.getProtocol(itrans) +        oprot = self.outputProtocolFactory.getProtocol(otrans) + +        try: +            while True: +                self.processor.process(iprot, oprot) +        except TTransportException, tx: +            pass +        except Exception, x: +            logging.exception(x) + +        itrans.close() +        otrans.close() + + +    def serve(self): +        """Start a fixed number of worker threads and put client into a queue""" + +        #this is a shared state that can tell the workers to exit when set as false +        self.isRunning.value = True + +        #first bind and listen to the port +        self.serverTransport.listen() + +        #fork the children +        for i in range(self.numWorkers): +            try: +                w = Process(target=self.workerProcess) +                w.daemon = True +                w.start() +                self.workers.append(w) +            except Exception, x: +                logging.exception(x) + +        #wait until the condition is set by stop() + +        while True: + +            self.stopCondition.acquire() +            try: +                self.stopCondition.wait() +                break +            except (SystemExit, KeyboardInterrupt): +		break +            except Exception, x: +                logging.exception(x) + +        self.isRunning.value = False + +    def stop(self): +        self.isRunning.value = False +        self.stopCondition.acquire() +        self.stopCondition.notify() +        self.stopCondition.release() + diff --git a/lib/Python/Lib/thrift/server/TServer.py b/lib/Python/Lib/thrift/server/TServer.py new file mode 100644 index 000000000..8456e2d40 --- /dev/null +++ b/lib/Python/Lib/thrift/server/TServer.py @@ -0,0 +1,274 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys +import os +import traceback +import threading +import Queue + +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +class TServer: + +  """Base interface for a server, which must have a serve method.""" + +  """ 3 constructors for all servers: +  1) (processor, serverTransport) +  2) (processor, serverTransport, transportFactory, protocolFactory) +  3) (processor, serverTransport, +      inputTransportFactory, outputTransportFactory, +      inputProtocolFactory, outputProtocolFactory)""" +  def __init__(self, *args): +    if (len(args) == 2): +      self.__initArgs__(args[0], args[1], +                        TTransport.TTransportFactoryBase(), +                        TTransport.TTransportFactoryBase(), +                        TBinaryProtocol.TBinaryProtocolFactory(), +                        TBinaryProtocol.TBinaryProtocolFactory()) +    elif (len(args) == 4): +      self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) +    elif (len(args) == 6): +      self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + +  def __initArgs__(self, processor, serverTransport, +                   inputTransportFactory, outputTransportFactory, +                   inputProtocolFactory, outputProtocolFactory): +    self.processor = processor +    self.serverTransport = serverTransport +    self.inputTransportFactory = inputTransportFactory +    self.outputTransportFactory = outputTransportFactory +    self.inputProtocolFactory = inputProtocolFactory +    self.outputProtocolFactory = outputProtocolFactory + +  def serve(self): +    pass + +class TSimpleServer(TServer): + +  """Simple single-threaded server that just pumps around one transport.""" + +  def __init__(self, *args): +    TServer.__init__(self, *args) + +  def serve(self): +    self.serverTransport.listen() +    while True: +      client = self.serverTransport.accept() +      itrans = self.inputTransportFactory.getTransport(client) +      otrans = self.outputTransportFactory.getTransport(client) +      iprot = self.inputProtocolFactory.getProtocol(itrans) +      oprot = self.outputProtocolFactory.getProtocol(otrans) +      try: +        while True: +          self.processor.process(iprot, oprot) +      except TTransport.TTransportException, tx: +        pass +      except Exception, x: +        logging.exception(x) + +      itrans.close() +      otrans.close() + +class TThreadedServer(TServer): + +  """Threaded server that spawns a new thread per each connection.""" + +  def __init__(self, *args, **kwargs): +    TServer.__init__(self, *args) +    self.daemon = kwargs.get("daemon", False) + +  def serve(self): +    self.serverTransport.listen() +    while True: +      try: +        client = self.serverTransport.accept() +        t = threading.Thread(target = self.handle, args=(client,)) +        t.setDaemon(self.daemon) +        t.start() +      except KeyboardInterrupt: +        raise +      except Exception, x: +        logging.exception(x) + +  def handle(self, client): +    itrans = self.inputTransportFactory.getTransport(client) +    otrans = self.outputTransportFactory.getTransport(client) +    iprot = self.inputProtocolFactory.getProtocol(itrans) +    oprot = self.outputProtocolFactory.getProtocol(otrans) +    try: +      while True: +        self.processor.process(iprot, oprot) +    except TTransport.TTransportException, tx: +      pass +    except Exception, x: +      logging.exception(x) + +    itrans.close() +    otrans.close() + +class TThreadPoolServer(TServer): + +  """Server with a fixed size pool of threads which service requests.""" + +  def __init__(self, *args, **kwargs): +    TServer.__init__(self, *args) +    self.clients = Queue.Queue() +    self.threads = 10 +    self.daemon = kwargs.get("daemon", False) + +  def setNumThreads(self, num): +    """Set the number of worker threads that should be created""" +    self.threads = num + +  def serveThread(self): +    """Loop around getting clients from the shared queue and process them.""" +    while True: +      try: +        client = self.clients.get() +        self.serveClient(client) +      except Exception, x: +        logging.exception(x) + +  def serveClient(self, client): +    """Process input/output from a client for as long as possible""" +    itrans = self.inputTransportFactory.getTransport(client) +    otrans = self.outputTransportFactory.getTransport(client) +    iprot = self.inputProtocolFactory.getProtocol(itrans) +    oprot = self.outputProtocolFactory.getProtocol(otrans) +    try: +      while True: +        self.processor.process(iprot, oprot) +    except TTransport.TTransportException, tx: +      pass +    except Exception, x: +      logging.exception(x) + +    itrans.close() +    otrans.close() + +  def serve(self): +    """Start a fixed number of worker threads and put client into a queue""" +    for i in range(self.threads): +      try: +        t = threading.Thread(target = self.serveThread) +        t.setDaemon(self.daemon) +        t.start() +      except Exception, x: +        logging.exception(x) + +    # Pump the socket for clients +    self.serverTransport.listen() +    while True: +      try: +        client = self.serverTransport.accept() +        self.clients.put(client) +      except Exception, x: +        logging.exception(x) + + +class TForkingServer(TServer): + +  """A Thrift server that forks a new process for each request""" +  """ +  This is more scalable than the threaded server as it does not cause +  GIL contention. + +  Note that this has different semantics from the threading server. +  Specifically, updates to shared variables will no longer be shared. +  It will also not work on windows. + +  This code is heavily inspired by SocketServer.ForkingMixIn in the +  Python stdlib. +  """ + +  def __init__(self, *args): +    TServer.__init__(self, *args) +    self.children = [] + +  def serve(self): +    def try_close(file): +      try: +        file.close() +      except IOError, e: +        logging.warning(e, exc_info=True) + + +    self.serverTransport.listen() +    while True: +      client = self.serverTransport.accept() +      try: +        pid = os.fork() + +        if pid: # parent +          # add before collect, otherwise you race w/ waitpid +          self.children.append(pid) +          self.collect_children() + +          # Parent must close socket or the connection may not get +          # closed promptly +          itrans = self.inputTransportFactory.getTransport(client) +          otrans = self.outputTransportFactory.getTransport(client) +          try_close(itrans) +          try_close(otrans) +        else: +          itrans = self.inputTransportFactory.getTransport(client) +          otrans = self.outputTransportFactory.getTransport(client) + +          iprot = self.inputProtocolFactory.getProtocol(itrans) +          oprot = self.outputProtocolFactory.getProtocol(otrans) + +          ecode = 0 +          try: +            try: +              while True: +                self.processor.process(iprot, oprot) +            except TTransport.TTransportException, tx: +              pass +            except Exception, e: +              logging.exception(e) +              ecode = 1 +          finally: +            try_close(itrans) +            try_close(otrans) + +          os._exit(ecode) + +      except TTransport.TTransportException, tx: +        pass +      except Exception, x: +        logging.exception(x) + + +  def collect_children(self): +    while self.children: +      try: +        pid, status = os.waitpid(0, os.WNOHANG) +      except os.error: +        pid = None + +      if pid: +        self.children.remove(pid) +      else: +        break + + diff --git a/lib/Python/Lib/thrift/server/__init__.py b/lib/Python/Lib/thrift/server/__init__.py new file mode 100644 index 000000000..1bf6e254e --- /dev/null +++ b/lib/Python/Lib/thrift/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/lib/Python/Lib/thrift/transport/THttpClient.py b/lib/Python/Lib/thrift/transport/THttpClient.py new file mode 100644 index 000000000..50269785c --- /dev/null +++ b/lib/Python/Lib/thrift/transport/THttpClient.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +from cStringIO import StringIO + +import urlparse +import httplib +import warnings +import socket + +class THttpClient(TTransportBase): + +  """Http implementation of TTransport base.""" + +  def __init__(self, uri_or_host, port=None, path=None): +    """THttpClient supports two different types constructor parameters. + +    THttpClient(host, port, path) - deprecated +    THttpClient(uri) + +    Only the second supports https.""" + +    if port is not None: +      warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) +      self.host = uri_or_host +      self.port = port +      assert path +      self.path = path +      self.scheme = 'http' +    else: +      parsed = urlparse.urlparse(uri_or_host) +      self.scheme = parsed.scheme +      assert self.scheme in ('http', 'https') +      if self.scheme == 'http': +        self.port = parsed.port or httplib.HTTP_PORT +      elif self.scheme == 'https': +        self.port = parsed.port or httplib.HTTPS_PORT +      self.host = parsed.hostname +      self.path = parsed.path +      if parsed.query: +        self.path += '?%s' % parsed.query +    self.__wbuf = StringIO() +    self.__http = None +    self.__timeout = None + +  def open(self): +    if self.scheme == 'http': +      self.__http = httplib.HTTP(self.host, self.port) +    else: +      self.__http = httplib.HTTPS(self.host, self.port) + +  def close(self): +    self.__http.close() +    self.__http = None + +  def isOpen(self): +    return self.__http != None + +  def setTimeout(self, ms): +    if not hasattr(socket, 'getdefaulttimeout'): +      raise NotImplementedError + +    if ms is None: +      self.__timeout = None +    else: +      self.__timeout = ms/1000.0 + +  def read(self, sz): +    return self.__http.file.read(sz) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def __withTimeout(f): +    def _f(*args, **kwargs): +      orig_timeout = socket.getdefaulttimeout() +      socket.setdefaulttimeout(args[0].__timeout) +      result = f(*args, **kwargs) +      socket.setdefaulttimeout(orig_timeout) +      return result +    return _f + +  def flush(self): +    if self.isOpen(): +      self.close() +    self.open(); + +    # Pull data out of buffer +    data = self.__wbuf.getvalue() +    self.__wbuf = StringIO() + +    # HTTP request +    self.__http.putrequest('POST', self.path) + +    # Write headers +    self.__http.putheader('Host', self.host) +    self.__http.putheader('Content-Type', 'application/x-thrift') +    self.__http.putheader('Content-Length', str(len(data))) +    self.__http.endheaders() + +    # Write payload +    self.__http.send(data) + +    # Get reply to flush the request +    self.code, self.message, self.headers = self.__http.getreply() + +  # Decorate if we know how to timeout +  if hasattr(socket, 'getdefaulttimeout'): +    flush = __withTimeout(flush) diff --git a/lib/Python/Lib/thrift/transport/TSocket.py b/lib/Python/Lib/thrift/transport/TSocket.py new file mode 100644 index 000000000..4e0e1874f --- /dev/null +++ b/lib/Python/Lib/thrift/transport/TSocket.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +import os +import errno +import socket +import sys + +class TSocketBase(TTransportBase): +  def _resolveAddr(self): +    if self._unix_socket is not None: +      return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] +    else: +      return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + +  def close(self): +    if self.handle: +      self.handle.close() +      self.handle = None + +class TSocket(TSocketBase): +  """Socket implementation of TTransport base.""" + +  def __init__(self, host='localhost', port=9090, unix_socket=None): +    """Initialize a TSocket + +    @param host(str)  The host to connect to. +    @param port(int)  The (TCP) port to connect to. +    @param unix_socket(str)  The filename of a unix socket to connect to. +                             (host and port will be ignored.) +    """ + +    self.host = host +    self.port = port +    self.handle = None +    self._unix_socket = unix_socket +    self._timeout = None + +  def setHandle(self, h): +    self.handle = h + +  def isOpen(self): +    return self.handle is not None + +  def setTimeout(self, ms): +    if ms is None: +      self._timeout = None +    else: +      self._timeout = ms/1000.0 + +    if self.handle is not None: +      self.handle.settimeout(self._timeout) + +  def open(self): +    try: +      res0 = self._resolveAddr() +      for res in res0: +        self.handle = socket.socket(res[0], res[1]) +        self.handle.settimeout(self._timeout) +        try: +          self.handle.connect(res[4]) +        except socket.error, e: +          if res is not res0[-1]: +            continue +          else: +            raise e +        break +    except socket.error, e: +      if self._unix_socket: +        message = 'Could not connect to socket %s' % self._unix_socket +      else: +        message = 'Could not connect to %s:%d' % (self.host, self.port) +      raise TTransportException(type=TTransportException.NOT_OPEN, message=message) + +  def read(self, sz): +    try: +      buff = self.handle.recv(sz) +    except socket.error, e: +      if (e.args[0] == errno.ECONNRESET and +          (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): +        # freebsd and Mach don't follow POSIX semantic of recv +        # and fail with ECONNRESET if peer performed shutdown. +        # See corresponding comment and code in TSocket::read() +        # in lib/cpp/src/transport/TSocket.cpp. +        self.close() +        # Trigger the check to raise the END_OF_FILE exception below. +        buff = '' +      else: +        raise +    if len(buff) == 0: +      raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') +    return buff + +  def write(self, buff): +    if not self.handle: +      raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') +    sent = 0 +    have = len(buff) +    while sent < have: +      plus = self.handle.send(buff) +      if plus == 0: +        raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') +      sent += plus +      buff = buff[plus:] + +  def flush(self): +    pass + +class TServerSocket(TSocketBase, TServerTransportBase): +  """Socket implementation of TServerTransport base.""" + +  def __init__(self, host=None, port=9090, unix_socket=None): +    self.host = host +    self.port = port +    self._unix_socket = unix_socket +    self.handle = None + +  def listen(self): +    res0 = self._resolveAddr() +    for res in res0: +      if res[0] is socket.AF_INET6 or res is res0[-1]: +        break + +    # We need remove the old unix socket if the file exists and +    # nobody is listening on it. +    if self._unix_socket: +      tmp = socket.socket(res[0], res[1]) +      try: +        tmp.connect(res[4]) +      except socket.error, err: +        eno, message = err.args +        if eno == errno.ECONNREFUSED: +          os.unlink(res[4]) + +    self.handle = socket.socket(res[0], res[1]) +    self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +    if hasattr(self.handle, 'settimeout'): +      self.handle.settimeout(None) +    self.handle.bind(res[4]) +    self.handle.listen(128) + +  def accept(self): +    client, addr = self.handle.accept() +    result = TSocket() +    result.setHandle(client) +    return result diff --git a/lib/Python/Lib/thrift/transport/TTransport.py b/lib/Python/Lib/thrift/transport/TTransport.py new file mode 100644 index 000000000..12e51a9bf --- /dev/null +++ b/lib/Python/Lib/thrift/transport/TTransport.py @@ -0,0 +1,331 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +from struct import pack,unpack +from thrift.Thrift import TException + +class TTransportException(TException): + +  """Custom Transport Exception class""" + +  UNKNOWN = 0 +  NOT_OPEN = 1 +  ALREADY_OPEN = 2 +  TIMED_OUT = 3 +  END_OF_FILE = 4 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +class TTransportBase: + +  """Base class for Thrift transport layer.""" + +  def isOpen(self): +    pass + +  def open(self): +    pass + +  def close(self): +    pass + +  def read(self, sz): +    pass + +  def readAll(self, sz): +    buff = '' +    have = 0 +    while (have < sz): +      chunk = self.read(sz-have) +      have += len(chunk) +      buff += chunk + +      if len(chunk) == 0: +        raise EOFError() + +    return buff + +  def write(self, buf): +    pass + +  def flush(self): +    pass + +# This class should be thought of as an interface. +class CReadableTransport: +  """base class for transports that are readable from C""" + +  # TODO(dreiss): Think about changing this interface to allow us to use +  #               a (Python, not c) StringIO instead, because it allows +  #               you to write after reading. + +  # NOTE: This is a classic class, so properties will NOT work +  #       correctly for setting. +  @property +  def cstringio_buf(self): +    """A cStringIO buffer that contains the current chunk we are reading.""" +    pass + +  def cstringio_refill(self, partialread, reqlen): +    """Refills cstringio_buf. + +    Returns the currently used buffer (which can but need not be the same as +    the old cstringio_buf). partialread is what the C code has read from the +    buffer, and should be inserted into the buffer before any more reads.  The +    return value must be a new, not borrowed reference.  Something along the +    lines of self._buf should be fine. + +    If reqlen bytes can't be read, throw EOFError. +    """ +    pass + +class TServerTransportBase: + +  """Base class for Thrift server transports.""" + +  def listen(self): +    pass + +  def accept(self): +    pass + +  def close(self): +    pass + +class TTransportFactoryBase: + +  """Base class for a Transport Factory""" + +  def getTransport(self, trans): +    return trans + +class TBufferedTransportFactory: + +  """Factory transport that builds buffered transports""" + +  def getTransport(self, trans): +    buffered = TBufferedTransport(trans) +    return buffered + + +class TBufferedTransport(TTransportBase,CReadableTransport): + +  """Class that wraps another transport and buffers its I/O. + +  The implementation uses a (configurable) fixed-size read buffer +  but buffers all writes until a flush is performed. +  """ + +  DEFAULT_BUFFER = 4096 + +  def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): +    self.__trans = trans +    self.__wbuf = StringIO() +    self.__rbuf = StringIO("") +    self.__rbuf_size = rbuf_size + +  def isOpen(self): +    return self.__trans.isOpen() + +  def open(self): +    return self.__trans.open() + +  def close(self): +    return self.__trans.close() + +  def read(self, sz): +    ret = self.__rbuf.read(sz) +    if len(ret) != 0: +      return ret + +    self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) +    return self.__rbuf.read(sz) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def flush(self): +    out = self.__wbuf.getvalue() +    # reset wbuf before write/flush to preserve state on underlying failure +    self.__wbuf = StringIO() +    self.__trans.write(out) +    self.__trans.flush() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self.__rbuf + +  def cstringio_refill(self, partialread, reqlen): +    retstring = partialread +    if reqlen < self.__rbuf_size: +      # try to make a read of as much as we can. +      retstring += self.__trans.read(self.__rbuf_size) + +    # but make sure we do read reqlen bytes. +    if len(retstring) < reqlen: +      retstring += self.__trans.readAll(reqlen - len(retstring)) + +    self.__rbuf = StringIO(retstring) +    return self.__rbuf + +class TMemoryBuffer(TTransportBase, CReadableTransport): +  """Wraps a cStringIO object as a TTransport. + +  NOTE: Unlike the C++ version of this class, you cannot write to it +        then immediately read from it.  If you want to read from a +        TMemoryBuffer, you must either pass a string to the constructor. +  TODO(dreiss): Make this work like the C++ version. +  """ + +  def __init__(self, value=None): +    """value -- a value to read from for stringio + +    If value is set, this will be a transport for reading, +    otherwise, it is for writing""" +    if value is not None: +      self._buffer = StringIO(value) +    else: +      self._buffer = StringIO() + +  def isOpen(self): +    return not self._buffer.closed + +  def open(self): +    pass + +  def close(self): +    self._buffer.close() + +  def read(self, sz): +    return self._buffer.read(sz) + +  def write(self, buf): +    self._buffer.write(buf) + +  def flush(self): +    pass + +  def getvalue(self): +    return self._buffer.getvalue() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self._buffer + +  def cstringio_refill(self, partialread, reqlen): +    # only one shot at reading... +    raise EOFError() + +class TFramedTransportFactory: + +  """Factory transport that builds framed transports""" + +  def getTransport(self, trans): +    framed = TFramedTransport(trans) +    return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + +  """Class that wraps another transport and frames its I/O when writing.""" + +  def __init__(self, trans,): +    self.__trans = trans +    self.__rbuf = StringIO() +    self.__wbuf = StringIO() + +  def isOpen(self): +    return self.__trans.isOpen() + +  def open(self): +    return self.__trans.open() + +  def close(self): +    return self.__trans.close() + +  def read(self, sz): +    ret = self.__rbuf.read(sz) +    if len(ret) != 0: +      return ret + +    self.readFrame() +    return self.__rbuf.read(sz) + +  def readFrame(self): +    buff = self.__trans.readAll(4) +    sz, = unpack('!i', buff) +    self.__rbuf = StringIO(self.__trans.readAll(sz)) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def flush(self): +    wout = self.__wbuf.getvalue() +    wsz = len(wout) +    # reset wbuf before write/flush to preserve state on underlying failure +    self.__wbuf = StringIO() +    # N.B.: Doing this string concatenation is WAY cheaper than making +    # two separate calls to the underlying socket object. Socket writes in +    # Python turn out to be REALLY expensive, but it seems to do a pretty +    # good job of managing string buffer operations without excessive copies +    buf = pack("!i", wsz) + wout +    self.__trans.write(buf) +    self.__trans.flush() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self.__rbuf + +  def cstringio_refill(self, prefix, reqlen): +    # self.__rbuf will already be empty here because fastbinary doesn't +    # ask for a refill until the previous buffer is empty.  Therefore, +    # we can start reading new frames immediately. +    while len(prefix) < reqlen: +      self.readFrame() +      prefix += self.__rbuf.getvalue() +    self.__rbuf = StringIO(prefix) +    return self.__rbuf + + +class TFileObjectTransport(TTransportBase): +  """Wraps a file-like object to make it work as a Thrift transport.""" + +  def __init__(self, fileobj): +    self.fileobj = fileobj + +  def isOpen(self): +    return True + +  def close(self): +    self.fileobj.close() + +  def read(self, sz): +    return self.fileobj.read(sz) + +  def write(self, buf): +    self.fileobj.write(buf) + +  def flush(self): +    self.fileobj.flush() diff --git a/lib/Python/Lib/thrift/transport/TTwisted.py b/lib/Python/Lib/thrift/transport/TTwisted.py new file mode 100644 index 000000000..b6dcb4e0b --- /dev/null +++ b/lib/Python/Lib/thrift/transport/TTwisted.py @@ -0,0 +1,219 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from zope.interface import implements, Interface, Attribute +from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ +    connectionDone +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log +from twisted.web import server, resource, http + +from thrift.transport import TTransport +from cStringIO import StringIO + + +class TMessageSenderTransport(TTransport.TTransportBase): + +    def __init__(self): +        self.__wbuf = StringIO() + +    def write(self, buf): +        self.__wbuf.write(buf) + +    def flush(self): +        msg = self.__wbuf.getvalue() +        self.__wbuf = StringIO() +        self.sendMessage(msg) + +    def sendMessage(self, message): +        raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + +    def __init__(self, func): +        TMessageSenderTransport.__init__(self) +        self.func = func + +    def sendMessage(self, message): +        self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + +    MAX_LENGTH = 2 ** 31 - 1 + +    def __init__(self, client_class, iprot_factory, oprot_factory=None): +        self._client_class = client_class +        self._iprot_factory = iprot_factory +        if oprot_factory is None: +            self._oprot_factory = iprot_factory +        else: +            self._oprot_factory = oprot_factory + +        self.recv_map = {} +        self.started = defer.Deferred() + +    def dispatch(self, msg): +        self.sendString(msg) + +    def connectionMade(self): +        tmo = TCallbackTransport(self.dispatch) +        self.client = self._client_class(tmo, self._oprot_factory) +        self.started.callback(self.client) + +    def connectionLost(self, reason=connectionDone): +        for k,v in self.client._reqs.iteritems(): +            tex = TTransport.TTransportException( +                type=TTransport.TTransportException.END_OF_FILE, +                message='Connection closed') +            v.errback(tex) + +    def stringReceived(self, frame): +        tr = TTransport.TMemoryBuffer(frame) +        iprot = self._iprot_factory.getProtocol(tr) +        (fname, mtype, rseqid) = iprot.readMessageBegin() + +        try: +            method = self.recv_map[fname] +        except KeyError: +            method = getattr(self.client, 'recv_' + fname) +            self.recv_map[fname] = method + +        method(iprot, mtype, rseqid) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + +    MAX_LENGTH = 2 ** 31 - 1 + +    def dispatch(self, msg): +        self.sendString(msg) + +    def processError(self, error): +        self.transport.loseConnection() + +    def processOk(self, _, tmo): +        msg = tmo.getvalue() + +        if len(msg) > 0: +            self.dispatch(msg) + +    def stringReceived(self, frame): +        tmi = TTransport.TMemoryBuffer(frame) +        tmo = TTransport.TMemoryBuffer() + +        iprot = self.factory.iprot_factory.getProtocol(tmi) +        oprot = self.factory.oprot_factory.getProtocol(tmo) + +        d = self.factory.processor.process(iprot, oprot) +        d.addCallbacks(self.processOk, self.processError, +            callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + +    processor = Attribute("Thrift processor") + +    iprot_factory = Attribute("Input protocol factory") + +    oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + +    client_class = Attribute("Thrift client class") + +    iprot_factory = Attribute("Input protocol factory") + +    oprot_factory = Attribute("Output protocol factory") + + +class ThriftServerFactory(ServerFactory): + +    implements(IThriftServerFactory) + +    protocol = ThriftServerProtocol + +    def __init__(self, processor, iprot_factory, oprot_factory=None): +        self.processor = processor +        self.iprot_factory = iprot_factory +        if oprot_factory is None: +            self.oprot_factory = iprot_factory +        else: +            self.oprot_factory = oprot_factory + + +class ThriftClientFactory(ClientFactory): + +    implements(IThriftClientFactory) + +    protocol = ThriftClientProtocol + +    def __init__(self, client_class, iprot_factory, oprot_factory=None): +        self.client_class = client_class +        self.iprot_factory = iprot_factory +        if oprot_factory is None: +            self.oprot_factory = iprot_factory +        else: +            self.oprot_factory = oprot_factory + +    def buildProtocol(self, addr): +        p = self.protocol(self.client_class, self.iprot_factory, +            self.oprot_factory) +        p.factory = self +        return p + + +class ThriftResource(resource.Resource): + +    allowedMethods = ('POST',) + +    def __init__(self, processor, inputProtocolFactory, +        outputProtocolFactory=None): +        resource.Resource.__init__(self) +        self.inputProtocolFactory = inputProtocolFactory +        if outputProtocolFactory is None: +            self.outputProtocolFactory = inputProtocolFactory +        else: +            self.outputProtocolFactory = outputProtocolFactory +        self.processor = processor + +    def getChild(self, path, request): +        return self + +    def _cbProcess(self, _, request, tmo): +        msg = tmo.getvalue() +        request.setResponseCode(http.OK) +        request.setHeader("content-type", "application/x-thrift") +        request.write(msg) +        request.finish() + +    def render_POST(self, request): +        request.content.seek(0, 0) +        data = request.content.read() +        tmi = TTransport.TMemoryBuffer(data) +        tmo = TTransport.TMemoryBuffer() + +        iprot = self.inputProtocolFactory.getProtocol(tmi) +        oprot = self.outputProtocolFactory.getProtocol(tmo) + +        d = self.processor.process(iprot, oprot) +        d.addCallback(self._cbProcess, request, tmo) +        return server.NOT_DONE_YET diff --git a/lib/Python/Lib/thrift/transport/TZlibTransport.py b/lib/Python/Lib/thrift/transport/TZlibTransport.py new file mode 100644 index 000000000..784d4e1e0 --- /dev/null +++ b/lib/Python/Lib/thrift/transport/TZlibTransport.py @@ -0,0 +1,261 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +''' +TZlibTransport provides a compressed transport and transport factory +class, using the python standard library zlib module to implement +data compression. +''' + +from __future__ import division +import zlib +from cStringIO import StringIO +from TTransport import TTransportBase, CReadableTransport + +class TZlibTransportFactory(object): +  ''' +  Factory transport that builds zlib compressed transports. +   +  This factory caches the last single client/transport that it was passed +  and returns the same TZlibTransport object that was created. +   +  This caching means the TServer class will get the _same_ transport +  object for both input and output transports from this factory. +  (For non-threaded scenarios only, since the cache only holds one object) +   +  The purpose of this caching is to allocate only one TZlibTransport where +  only one is really needed (since it must have separate read/write buffers), +  and makes the statistics from getCompSavings() and getCompRatio() +  easier to understand. +  ''' + +  # class scoped cache of last transport given and zlibtransport returned +  _last_trans = None +  _last_z = None + +  def getTransport(self, trans, compresslevel=9): +    '''Wrap a transport , trans, with the TZlibTransport +    compressed transport class, returning a new +    transport to the caller. +     +    @param compresslevel: The zlib compression level, ranging +    from 0 (no compression) to 9 (best compression).  Defaults to 9. +    @type compresslevel: int +     +    This method returns a TZlibTransport which wraps the +    passed C{trans} TTransport derived instance. +    ''' +    if trans == self._last_trans: +      return self._last_z +    ztrans = TZlibTransport(trans, compresslevel) +    self._last_trans = trans +    self._last_z = ztrans +    return ztrans + + +class TZlibTransport(TTransportBase, CReadableTransport): +  ''' +  Class that wraps a transport with zlib, compressing writes +  and decompresses reads, using the python standard +  library zlib module. +  ''' + +  # Read buffer size for the python fastbinary C extension, +  # the TBinaryProtocolAccelerated class. +  DEFAULT_BUFFSIZE = 4096 + +  def __init__(self, trans, compresslevel=9): +    ''' +    Create a new TZlibTransport, wrapping C{trans}, another +    TTransport derived object. +     +    @param trans: A thrift transport object, i.e. a TSocket() object. +    @type trans: TTransport +    @param compresslevel: The zlib compression level, ranging +    from 0 (no compression) to 9 (best compression).  Default is 9. +    @type compresslevel: int +    ''' +    self.__trans = trans +    self.compresslevel = compresslevel +    self.__rbuf = StringIO() +    self.__wbuf = StringIO() +    self._init_zlib() +    self._init_stats() + +  def _reinit_buffers(self): +    ''' +    Internal method to initialize/reset the internal StringIO objects +    for read and write buffers. +    ''' +    self.__rbuf = StringIO() +    self.__wbuf = StringIO() + +  def _init_stats(self): +    ''' +    Internal method to reset the internal statistics counters +    for compression ratios and bandwidth savings. +    ''' +    self.bytes_in = 0 +    self.bytes_out = 0 +    self.bytes_in_comp = 0 +    self.bytes_out_comp = 0 + +  def _init_zlib(self): +    ''' +    Internal method for setting up the zlib compression and +    decompression objects. +    ''' +    self._zcomp_read = zlib.decompressobj() +    self._zcomp_write = zlib.compressobj(self.compresslevel) + +  def getCompRatio(self): +    ''' +    Get the current measured compression ratios (in,out) from +    this transport. +     +    Returns a tuple of:  +    (inbound_compression_ratio, outbound_compression_ratio) +     +    The compression ratios are computed as: +        compressed / uncompressed + +    E.g., data that compresses by 10x will have a ratio of: 0.10 +    and data that compresses to half of ts original size will +    have a ratio of 0.5 +     +    None is returned if no bytes have yet been processed in +    a particular direction. +    ''' +    r_percent, w_percent = (None, None) +    if self.bytes_in > 0: +      r_percent = self.bytes_in_comp / self.bytes_in +    if self.bytes_out > 0: +      w_percent = self.bytes_out_comp / self.bytes_out +    return (r_percent, w_percent) + +  def getCompSavings(self): +    ''' +    Get the current count of saved bytes due to data +    compression. +     +    Returns a tuple of: +    (inbound_saved_bytes, outbound_saved_bytes) +     +    Note: if compression is actually expanding your +    data (only likely with very tiny thrift objects), then +    the values returned will be negative. +    ''' +    r_saved = self.bytes_in - self.bytes_in_comp +    w_saved = self.bytes_out - self.bytes_out_comp +    return (r_saved, w_saved) + +  def isOpen(self): +    '''Return the underlying transport's open status''' +    return self.__trans.isOpen() + +  def open(self): +    """Open the underlying transport""" +    self._init_stats() +    return self.__trans.open() + +  def listen(self): +    '''Invoke the underlying transport's listen() method''' +    self.__trans.listen() + +  def accept(self): +    '''Accept connections on the underlying transport''' +    return self.__trans.accept() + +  def close(self): +    '''Close the underlying transport,''' +    self._reinit_buffers() +    self._init_zlib() +    return self.__trans.close() + +  def read(self, sz): +    ''' +    Read up to sz bytes from the decompressed bytes buffer, and +    read from the underlying transport if the decompression +    buffer is empty. +    ''' +    ret = self.__rbuf.read(sz) +    if len(ret) > 0: +      return ret +    # keep reading from transport until something comes back +    while True: +      if self.readComp(sz): +        break +    ret = self.__rbuf.read(sz) +    return ret + +  def readComp(self, sz): +    ''' +    Read compressed data from the underlying transport, then +    decompress it and append it to the internal StringIO read buffer +    ''' +    zbuf = self.__trans.read(sz) +    zbuf = self._zcomp_read.unconsumed_tail + zbuf +    buf = self._zcomp_read.decompress(zbuf) +    self.bytes_in += len(zbuf) +    self.bytes_in_comp += len(buf) +    old = self.__rbuf.read() +    self.__rbuf = StringIO(old + buf) +    if len(old) + len(buf) == 0: +      return False +    return True + +  def write(self, buf): +    ''' +    Write some bytes, putting them into the internal write +    buffer for eventual compression. +    ''' +    self.__wbuf.write(buf) + +  def flush(self): +    ''' +    Flush any queued up data in the write buffer and ensure the +    compression buffer is flushed out to the underlying transport +    ''' +    wout = self.__wbuf.getvalue() +    if len(wout) > 0: +      zbuf = self._zcomp_write.compress(wout) +      self.bytes_out += len(wout) +      self.bytes_out_comp += len(zbuf) +    else: +      zbuf = '' +    ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) +    self.bytes_out_comp += len(ztail) +    if (len(zbuf) + len(ztail)) > 0: +      self.__wbuf = StringIO() +      self.__trans.write(zbuf + ztail) +    self.__trans.flush() + +  @property +  def cstringio_buf(self): +    '''Implement the CReadableTransport interface''' +    return self.__rbuf + +  def cstringio_refill(self, partialread, reqlen): +    '''Implement the CReadableTransport interface for refill''' +    retstring = partialread +    if reqlen < self.DEFAULT_BUFFSIZE: +      retstring += self.read(self.DEFAULT_BUFFSIZE) +    while len(retstring) < reqlen: +      retstring += self.read(reqlen - len(retstring)) +    self.__rbuf = StringIO(retstring) +    return self.__rbuf diff --git a/lib/Python/Lib/thrift/transport/__init__.py b/lib/Python/Lib/thrift/transport/__init__.py new file mode 100644 index 000000000..46e54fe6b --- /dev/null +++ b/lib/Python/Lib/thrift/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] | 
