diff options
| author | 2011-02-07 00:32:21 +0100 | |
|---|---|---|
| committer | 2011-02-07 00:32:21 +0100 | |
| commit | 69f2d3660e44bd138d24b6cf97b234bc2efc1c78 (patch) | |
| tree | af276b8ed0b5cac5a62e6949a00d607afcc94e5d /module/lib/thrift | |
| parent | closed #231 (diff) | |
| download | pyload-69f2d3660e44bd138d24b6cf97b234bc2efc1c78.tar.xz | |
added Thrift backend
Diffstat (limited to 'module/lib/thrift')
| -rw-r--r-- | module/lib/thrift/TSCons.py | 33 | ||||
| -rw-r--r-- | module/lib/thrift/TSerialization.py | 34 | ||||
| -rw-r--r-- | module/lib/thrift/Thrift.py | 133 | ||||
| -rw-r--r-- | module/lib/thrift/__init__.py | 20 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TBinaryProtocol.py | 259 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TCompactProtocol.py | 368 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TProtocol.py | 205 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/__init__.py | 20 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/fastbinary.c | 1203 | ||||
| -rw-r--r-- | module/lib/thrift/server/THttpServer.py | 82 | ||||
| -rw-r--r-- | module/lib/thrift/server/TNonblockingServer.py | 310 | ||||
| -rw-r--r-- | module/lib/thrift/server/TServer.py | 275 | ||||
| -rw-r--r-- | module/lib/thrift/server/__init__.py | 20 | ||||
| -rw-r--r-- | module/lib/thrift/transport/THttpClient.py | 126 | ||||
| -rw-r--r-- | module/lib/thrift/transport/TSocket.py | 163 | ||||
| -rw-r--r-- | module/lib/thrift/transport/TTransport.py | 331 | ||||
| -rw-r--r-- | module/lib/thrift/transport/TTwisted.py | 219 | ||||
| -rw-r--r-- | module/lib/thrift/transport/__init__.py | 20 | 
18 files changed, 3821 insertions, 0 deletions
diff --git a/module/lib/thrift/TSCons.py b/module/lib/thrift/TSCons.py new file mode 100644 index 000000000..24046256c --- /dev/null +++ b/module/lib/thrift/TSCons.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder + +def scons_env(env, add=''): +  opath = path.dirname(path.abspath('$TARGET')) +  lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' +  cppbuild = Builder(action = lstr) +  env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + +def gen_cpp(env, dir, file): +  scons_env(env) +  suffixes = ['_types.h', '_types.cpp'] +  targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) +  return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/module/lib/thrift/TSerialization.py b/module/lib/thrift/TSerialization.py new file mode 100644 index 000000000..b19f98aa8 --- /dev/null +++ b/module/lib/thrift/TSerialization.py @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from protocol import TBinaryProtocol +from transport import TTransport + +def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): +    transport = TTransport.TMemoryBuffer() +    protocol = protocol_factory.getProtocol(transport) +    thrift_object.write(protocol) +    return transport.getvalue() + +def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): +    transport = TTransport.TMemoryBuffer(buf) +    protocol = protocol_factory.getProtocol(transport) +    base.read(protocol) +    return base + diff --git a/module/lib/thrift/Thrift.py b/module/lib/thrift/Thrift.py new file mode 100644 index 000000000..91728a776 --- /dev/null +++ b/module/lib/thrift/Thrift.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +class TType: +  STOP   = 0 +  VOID   = 1 +  BOOL   = 2 +  BYTE   = 3 +  I08    = 3 +  DOUBLE = 4 +  I16    = 6 +  I32    = 8 +  I64    = 10 +  STRING = 11 +  UTF7   = 11 +  STRUCT = 12 +  MAP    = 13 +  SET    = 14 +  LIST   = 15 +  UTF8   = 16 +  UTF16  = 17 + +class TMessageType: +  CALL  = 1 +  REPLY = 2 +  EXCEPTION = 3 +  ONEWAY = 4 + +class TProcessor: + +  """Base class for procsessor, which works on two streams.""" + +  def process(iprot, oprot): +    pass + +class TException(Exception): + +  """Base class for all thrift exceptions.""" + +  # BaseException.message is deprecated in Python v[2.6,3.0) +  if (2,6,0) <= sys.version_info < (3,0): +    def _get_message(self): +	    return self._message +    def _set_message(self, message): +	    self._message = message +    message = property(_get_message, _set_message) + +  def __init__(self, message=None): +    Exception.__init__(self, message) +    self.message = message + +class TApplicationException(TException): + +  """Application level thrift exceptions.""" + +  UNKNOWN = 0 +  UNKNOWN_METHOD = 1 +  INVALID_MESSAGE_TYPE = 2 +  WRONG_METHOD_NAME = 3 +  BAD_SEQUENCE_ID = 4 +  MISSING_RESULT = 5 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +  def __str__(self): +    if self.message: +      return self.message +    elif self.type == self.UNKNOWN_METHOD: +      return 'Unknown method' +    elif self.type == self.INVALID_MESSAGE_TYPE: +      return 'Invalid message type' +    elif self.type == self.WRONG_METHOD_NAME: +      return 'Wrong method name' +    elif self.type == self.BAD_SEQUENCE_ID: +      return 'Bad sequence ID' +    elif self.type == self.MISSING_RESULT: +      return 'Missing result' +    else: +      return 'Default (unknown) TApplicationException' + +  def read(self, iprot): +    iprot.readStructBegin() +    while True: +      (fname, ftype, fid) = iprot.readFieldBegin() +      if ftype == TType.STOP: +        break +      if fid == 1: +        if ftype == TType.STRING: +          self.message = iprot.readString(); +        else: +          iprot.skip(ftype) +      elif fid == 2: +        if ftype == TType.I32: +          self.type = iprot.readI32(); +        else: +          iprot.skip(ftype) +      else: +        iprot.skip(ftype) +      iprot.readFieldEnd() +    iprot.readStructEnd() + +  def write(self, oprot): +    oprot.writeStructBegin('TApplicationException') +    if self.message != None: +      oprot.writeFieldBegin('message', TType.STRING, 1) +      oprot.writeString(self.message) +      oprot.writeFieldEnd() +    if self.type != None: +      oprot.writeFieldBegin('type', TType.I32, 2) +      oprot.writeI32(self.type) +      oprot.writeFieldEnd() +    oprot.writeFieldStop() +    oprot.writeStructEnd() diff --git a/module/lib/thrift/__init__.py b/module/lib/thrift/__init__.py new file mode 100644 index 000000000..48d659c40 --- /dev/null +++ b/module/lib/thrift/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/module/lib/thrift/protocol/TBinaryProtocol.py b/module/lib/thrift/protocol/TBinaryProtocol.py new file mode 100644 index 000000000..50c6aa896 --- /dev/null +++ b/module/lib/thrift/protocol/TBinaryProtocol.py @@ -0,0 +1,259 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +class TBinaryProtocol(TProtocolBase): + +  """Binary implementation of the Thrift protocol driver.""" + +  # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be +  # positive, converting this into a long. If we hardcode the int value +  # instead it'll stay in 32 bit-land. + +  # VERSION_MASK = 0xffff0000 +  VERSION_MASK = -65536 + +  # VERSION_1 = 0x80010000 +  VERSION_1 = -2147418112 + +  TYPE_MASK = 0x000000ff + +  def __init__(self, trans, strictRead=False, strictWrite=True): +    TProtocolBase.__init__(self, trans) +    self.strictRead = strictRead +    self.strictWrite = strictWrite + +  def writeMessageBegin(self, name, type, seqid): +    if self.strictWrite: +      self.writeI32(TBinaryProtocol.VERSION_1 | type) +      self.writeString(name) +      self.writeI32(seqid) +    else: +      self.writeString(name) +      self.writeByte(type) +      self.writeI32(seqid) + +  def writeMessageEnd(self): +    pass + +  def writeStructBegin(self, name): +    pass + +  def writeStructEnd(self): +    pass + +  def writeFieldBegin(self, name, type, id): +    self.writeByte(type) +    self.writeI16(id) + +  def writeFieldEnd(self): +    pass + +  def writeFieldStop(self): +    self.writeByte(TType.STOP); + +  def writeMapBegin(self, ktype, vtype, size): +    self.writeByte(ktype) +    self.writeByte(vtype) +    self.writeI32(size) + +  def writeMapEnd(self): +    pass + +  def writeListBegin(self, etype, size): +    self.writeByte(etype) +    self.writeI32(size) + +  def writeListEnd(self): +    pass + +  def writeSetBegin(self, etype, size): +    self.writeByte(etype) +    self.writeI32(size) + +  def writeSetEnd(self): +    pass + +  def writeBool(self, bool): +    if bool: +      self.writeByte(1) +    else: +      self.writeByte(0) + +  def writeByte(self, byte): +    buff = pack("!b", byte) +    self.trans.write(buff) + +  def writeI16(self, i16): +    buff = pack("!h", i16) +    self.trans.write(buff) + +  def writeI32(self, i32): +    buff = pack("!i", i32) +    self.trans.write(buff) + +  def writeI64(self, i64): +    buff = pack("!q", i64) +    self.trans.write(buff) + +  def writeDouble(self, dub): +    buff = pack("!d", dub) +    self.trans.write(buff) + +  def writeString(self, str): +    self.writeI32(len(str)) +    self.trans.write(str) + +  def readMessageBegin(self): +    sz = self.readI32() +    if sz < 0: +      version = sz & TBinaryProtocol.VERSION_MASK +      if version != TBinaryProtocol.VERSION_1: +        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) +      type = sz & TBinaryProtocol.TYPE_MASK +      name = self.readString() +      seqid = self.readI32() +    else: +      if self.strictRead: +        raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') +      name = self.trans.readAll(sz) +      type = self.readByte() +      seqid = self.readI32() +    return (name, type, seqid) + +  def readMessageEnd(self): +    pass + +  def readStructBegin(self): +    pass + +  def readStructEnd(self): +    pass + +  def readFieldBegin(self): +    type = self.readByte() +    if type == TType.STOP: +      return (None, type, 0) +    id = self.readI16() +    return (None, type, id) + +  def readFieldEnd(self): +    pass + +  def readMapBegin(self): +    ktype = self.readByte() +    vtype = self.readByte() +    size = self.readI32() +    return (ktype, vtype, size) + +  def readMapEnd(self): +    pass + +  def readListBegin(self): +    etype = self.readByte() +    size = self.readI32() +    return (etype, size) + +  def readListEnd(self): +    pass + +  def readSetBegin(self): +    etype = self.readByte() +    size = self.readI32() +    return (etype, size) + +  def readSetEnd(self): +    pass + +  def readBool(self): +    byte = self.readByte() +    if byte == 0: +      return False +    return True + +  def readByte(self): +    buff = self.trans.readAll(1) +    val, = unpack('!b', buff) +    return val + +  def readI16(self): +    buff = self.trans.readAll(2) +    val, = unpack('!h', buff) +    return val + +  def readI32(self): +    buff = self.trans.readAll(4) +    val, = unpack('!i', buff) +    return val + +  def readI64(self): +    buff = self.trans.readAll(8) +    val, = unpack('!q', buff) +    return val + +  def readDouble(self): +    buff = self.trans.readAll(8) +    val, = unpack('!d', buff) +    return val + +  def readString(self): +    len = self.readI32() +    str = self.trans.readAll(len) +    return str + + +class TBinaryProtocolFactory: +  def __init__(self, strictRead=False, strictWrite=True): +    self.strictRead = strictRead +    self.strictWrite = strictWrite + +  def getProtocol(self, trans): +    prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) +    return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + +  """C-Accelerated version of TBinaryProtocol. + +  This class does not override any of TBinaryProtocol's methods, +  but the generated code recognizes it directly and will call into +  our C module to do the encoding, bypassing this object entirely. +  We inherit from TBinaryProtocol so that the normal TBinaryProtocol +  encoding can happen if the fastbinary module doesn't work for some +  reason.  (TODO(dreiss): Make this happen sanely in more cases.) + +  In order to take advantage of the C module, just use +  TBinaryProtocolAccelerated instead of TBinaryProtocol. + +  NOTE:  This code was contributed by an external developer. +         The internal Thrift team has reviewed and tested it, +         but we cannot guarantee that it is production-ready. +         Please feel free to report bugs and/or success stories +         to the public mailing list. +  """ + +  pass + + +class TBinaryProtocolAcceleratedFactory: +  def getProtocol(self, trans): +    return TBinaryProtocolAccelerated(trans) diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py new file mode 100644 index 000000000..fbc156a8f --- /dev/null +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -0,0 +1,368 @@ +from TProtocol import * +from struct import pack, unpack + +__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] + +CLEAR = 0 +FIELD_WRITE = 1 +VALUE_WRITE = 2 +CONTAINER_WRITE = 3 +BOOL_WRITE = 4 +FIELD_READ = 5 +CONTAINER_READ = 6 +VALUE_READ = 7 +BOOL_READ = 8 + +def make_helper(v_from, container): +  def helper(func): +    def nested(self, *args, **kwargs): +      assert self.state in (v_from, container), (self.state, v_from, container) +      return func(self, *args, **kwargs) +    return nested +  return helper +writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) +reader = make_helper(VALUE_READ, CONTAINER_READ) + +def makeZigZag(n, bits): +  return (n << 1) ^ (n >> (bits - 1)) + +def fromZigZag(n): +  return (n >> 1) ^ -(n & 1) + +def writeVarint(trans, n): +  out = [] +  while True: +    if n & ~0x7f == 0: +      out.append(n) +      break +    else: +      out.append((n & 0xff) | 0x80) +      n = n >> 7 +  trans.write(''.join(map(chr, out))) + +def readVarint(trans): +  result = 0 +  shift = 0 +  while True: +    x = trans.readAll(1) +    byte = ord(x) +    result |= (byte & 0x7f) << shift +    if byte >> 7 == 0: +      return result +    shift += 7 + +class CompactType: +  TRUE = 1 +  FALSE = 2 +  BYTE = 0x03 +  I16 = 0x04 +  I32 = 0x05 +  I64 = 0x06 +  DOUBLE = 0x07 +  BINARY = 0x08 +  LIST = 0x09 +  SET = 0x0A +  MAP = 0x0B +  STRUCT = 0x0C + +CTYPES = {TType.BOOL: CompactType.TRUE, # used for collection +          TType.BYTE: CompactType.BYTE, +          TType.I16: CompactType.I16, +          TType.I32: CompactType.I32, +          TType.I64: CompactType.I64, +          TType.DOUBLE: CompactType.DOUBLE, +          TType.STRING: CompactType.BINARY, +          TType.STRUCT: CompactType.STRUCT, +          TType.LIST: CompactType.LIST, +          TType.SET: CompactType.SET, +          TType.MAP: CompactType.MAP, +          } + +TTYPES = {} +for k, v in CTYPES.items(): +  TTYPES[v] = k +TTYPES[CompactType.FALSE] = TType.BOOL +del k +del v + +class TCompactProtocol(TProtocolBase): +  "Compact implementation of the Thrift protocol driver." + +  PROTOCOL_ID = 0x82 +  VERSION = 1 +  VERSION_MASK = 0x1f +  TYPE_MASK = 0xe0 +  TYPE_SHIFT_AMOUNT = 5 + +  def __init__(self, trans): +    TProtocolBase.__init__(self, trans) +    self.state = CLEAR +    self.__last_fid = 0 +    self.__bool_fid = None +    self.__bool_value = None +    self.__structs = [] +    self.__containers = [] + +  def __writeVarint(self, n): +    writeVarint(self.trans, n) + +  def writeMessageBegin(self, name, type, seqid): +    assert self.state == CLEAR +    self.__writeUByte(self.PROTOCOL_ID) +    self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) +    self.__writeVarint(seqid) +    self.__writeString(name) +    self.state = VALUE_WRITE + +  def writeMessageEnd(self): +    assert self.state == VALUE_WRITE +    self.state = CLEAR + +  def writeStructBegin(self, name): +    assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state +    self.__structs.append((self.state, self.__last_fid)) +    self.state = FIELD_WRITE +    self.__last_fid = 0 + +  def writeStructEnd(self): +    assert self.state == FIELD_WRITE +    self.state, self.__last_fid = self.__structs.pop() + +  def writeFieldStop(self): +    self.__writeByte(0) + +  def __writeFieldHeader(self, type, fid): +    delta = fid - self.__last_fid +    if 0 < delta <= 15: +      self.__writeUByte(delta << 4 | type) +    else: +      self.__writeByte(type) +      self.__writeI16(fid) +    self.__last_fid = fid + +  def writeFieldBegin(self, name, type, fid): +    assert self.state == FIELD_WRITE, self.state +    if type == TType.BOOL: +      self.state = BOOL_WRITE +      self.__bool_fid = fid +    else: +      self.state = VALUE_WRITE +      self.__writeFieldHeader(CTYPES[type], fid) + +  def writeFieldEnd(self): +    assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state +    self.state = FIELD_WRITE + +  def __writeUByte(self, byte): +    self.trans.write(pack('!B', byte)) + +  def __writeByte(self, byte): +    self.trans.write(pack('!b', byte)) + +  def __writeI16(self, i16): +    self.__writeVarint(makeZigZag(i16, 16)) + +  def __writeSize(self, i32): +    self.__writeVarint(i32) + +  def writeCollectionBegin(self, etype, size): +    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state +    if size <= 14: +      self.__writeUByte(size << 4 | CTYPES[etype]) +    else: +      self.__writeUByte(0xf0 | CTYPES[etype]) +      self.__writeSize(size) +    self.__containers.append(self.state) +    self.state = CONTAINER_WRITE +  writeSetBegin = writeCollectionBegin +  writeListBegin = writeCollectionBegin + +  def writeMapBegin(self, ktype, vtype, size): +    assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state +    if size == 0: +      self.__writeByte(0) +    else: +      self.__writeSize(size) +      self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) +    self.__containers.append(self.state) +    self.state = CONTAINER_WRITE + +  def writeCollectionEnd(self): +    assert self.state == CONTAINER_WRITE, self.state +    self.state = self.__containers.pop() +  writeMapEnd = writeCollectionEnd +  writeSetEnd = writeCollectionEnd +  writeListEnd = writeCollectionEnd + +  def writeBool(self, bool): +    if self.state == BOOL_WRITE: +      self.__writeFieldHeader(types[bool], self.__bool_fid) +    elif self.state == CONTAINER_WRITE: +      self.__writeByte(int(bool)) +    else: +      raise AssertetionError, "Invalid state in compact protocol" + +  writeByte = writer(__writeByte) +  writeI16 = writer(__writeI16) + +  @writer +  def writeI32(self, i32): +    self.__writeVarint(makeZigZag(i32, 32)) + +  @writer +  def writeI64(self, i64): +    self.__writeVarint(makeZigZag(i64, 64)) + +  @writer +  def writeDouble(self, dub): +    self.trans.write(pack('!d', dub)) + +  def __writeString(self, s): +    self.__writeSize(len(s)) +    self.trans.write(s) +  writeString = writer(__writeString) + +  def readFieldBegin(self): +    assert self.state == FIELD_READ, self.state +    type = self.__readUByte() +    if type & 0x0f == TType.STOP: +      return (None, 0, 0) +    delta = type >> 4 +    if delta == 0: +      fid = self.__readI16() +    else: +      fid = self.__last_fid + delta +    self.__last_fid = fid +    type = type & 0x0f +    if type == CompactType.TRUE: +      self.state = BOOL_READ +      self.__bool_value = True +    elif type == CompactType.FALSE: +      self.state = BOOL_READ +      self.__bool_value = False +    else: +      self.state = VALUE_READ +    return (None, self.__getTType(type), fid) + +  def readFieldEnd(self): +    assert self.state in (VALUE_READ, BOOL_READ), self.state +    self.state = FIELD_READ + +  def __readUByte(self): +    result, = unpack('!B', self.trans.readAll(1)) +    return result + +  def __readByte(self): +    result, = unpack('!b', self.trans.readAll(1)) +    return result + +  def __readVarint(self): +    return readVarint(self.trans) + +  def __readZigZag(self): +    return fromZigZag(self.__readVarint()) + +  def __readSize(self): +    result = self.__readVarint() +    if result < 0: +      raise TException("Length < 0") +    return result + +  def readMessageBegin(self): +    assert self.state == CLEAR +    proto_id = self.__readUByte() +    if proto_id != self.PROTOCOL_ID: +      raise TProtocolException(TProtocolException.BAD_VERSION, +          'Bad protocol id in the message: %d' % proto_id) +    ver_type = self.__readUByte() +    type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT +    version = ver_type & self.VERSION_MASK +    if version != self.VERSION: +      raise TProtocolException(TProtocolException.BAD_VERSION, +          'Bad version: %d (expect %d)' % (version, self.VERSION)) +    seqid = self.__readVarint() +    name = self.__readString() +    return (name, type, seqid) + +  def readMessageEnd(self): +    assert self.state == VALUE_READ +    assert len(self.__structs) == 0 +    self.state = CLEAR + +  def readStructBegin(self): +    assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state +    self.__structs.append((self.state, self.__last_fid)) +    self.state = FIELD_READ +    self.__last_fid = 0 + +  def readStructEnd(self): +    assert self.state == FIELD_READ +    self.state, self.__last_fid = self.__structs.pop() + +  def readCollectionBegin(self): +    assert self.state in (VALUE_READ, CONTAINER_READ), self.state +    size_type = self.__readUByte() +    size = size_type >> 4 +    type = self.__getTType(size_type) +    if size == 15: +      size = self.__readSize() +    self.__containers.append(self.state) +    self.state = CONTAINER_READ +    return type, size +  readSetBegin = readCollectionBegin +  readListBegin = readCollectionBegin + +  def readMapBegin(self): +    assert self.state in (VALUE_READ, CONTAINER_READ), self.state +    size = self.__readSize() +    types = 0 +    if size > 0: +      types = self.__readUByte() +    vtype = self.__getTType(types) +    ktype = self.__getTType(types >> 4) +    self.__containers.append(self.state) +    self.state = CONTAINER_READ +    return (ktype, vtype, size) + +  def readCollectionEnd(self): +    assert self.state == CONTAINER_READ, self.state +    self.state = self.__containers.pop() +  readSetEnd = readCollectionEnd +  readListEnd = readCollectionEnd +  readMapEnd = readCollectionEnd + +  def readBool(self): +    if self.state == BOOL_READ: +      return self.__bool_value +    elif self.state == CONTAINER_READ: +      return bool(self.__readByte()) +    else: +      raise AssertionError, "Invalid state in compact protocol: %d" % self.state + +  readByte = reader(__readByte) +  __readI16 = __readZigZag +  readI16 = reader(__readZigZag) +  readI32 = reader(__readZigZag) +  readI64 = reader(__readZigZag) + +  @reader +  def readDouble(self): +    buff = self.trans.readAll(8) +    val, = unpack('!d', buff) +    return val + +  def __readString(self): +    len = self.__readSize() +    return self.trans.readAll(len) +  readString = reader(__readString) + +  def __getTType(self, byte): +    return TTYPES[byte & 0x0f] + + +class TCompactProtocolFactory: +  def __init__(self): +    pass + +  def getProtocol(self, trans): +    return TCompactProtocol(trans) diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py new file mode 100644 index 000000000..be3cb1403 --- /dev/null +++ b/module/lib/thrift/protocol/TProtocol.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * + +class TProtocolException(TException): + +  """Custom Protocol Exception class""" + +  UNKNOWN = 0 +  INVALID_DATA = 1 +  NEGATIVE_SIZE = 2 +  SIZE_LIMIT = 3 +  BAD_VERSION = 4 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +class TProtocolBase: + +  """Base class for Thrift protocol driver.""" + +  def __init__(self, trans): +    self.trans = trans + +  def writeMessageBegin(self, name, type, seqid): +    pass + +  def writeMessageEnd(self): +    pass + +  def writeStructBegin(self, name): +    pass + +  def writeStructEnd(self): +    pass + +  def writeFieldBegin(self, name, type, id): +    pass + +  def writeFieldEnd(self): +    pass + +  def writeFieldStop(self): +    pass + +  def writeMapBegin(self, ktype, vtype, size): +    pass + +  def writeMapEnd(self): +    pass + +  def writeListBegin(self, etype, size): +    pass + +  def writeListEnd(self): +    pass + +  def writeSetBegin(self, etype, size): +    pass + +  def writeSetEnd(self): +    pass + +  def writeBool(self, bool): +    pass + +  def writeByte(self, byte): +    pass + +  def writeI16(self, i16): +    pass + +  def writeI32(self, i32): +    pass + +  def writeI64(self, i64): +    pass + +  def writeDouble(self, dub): +    pass + +  def writeString(self, str): +    pass + +  def readMessageBegin(self): +    pass + +  def readMessageEnd(self): +    pass + +  def readStructBegin(self): +    pass + +  def readStructEnd(self): +    pass + +  def readFieldBegin(self): +    pass + +  def readFieldEnd(self): +    pass + +  def readMapBegin(self): +    pass + +  def readMapEnd(self): +    pass + +  def readListBegin(self): +    pass + +  def readListEnd(self): +    pass + +  def readSetBegin(self): +    pass + +  def readSetEnd(self): +    pass + +  def readBool(self): +    pass + +  def readByte(self): +    pass + +  def readI16(self): +    pass + +  def readI32(self): +    pass + +  def readI64(self): +    pass + +  def readDouble(self): +    pass + +  def readString(self): +    pass + +  def skip(self, type): +    if type == TType.STOP: +      return +    elif type == TType.BOOL: +      self.readBool() +    elif type == TType.BYTE: +      self.readByte() +    elif type == TType.I16: +      self.readI16() +    elif type == TType.I32: +      self.readI32() +    elif type == TType.I64: +      self.readI64() +    elif type == TType.DOUBLE: +      self.readDouble() +    elif type == TType.STRING: +      self.readString() +    elif type == TType.STRUCT: +      name = self.readStructBegin() +      while True: +        (name, type, id) = self.readFieldBegin() +        if type == TType.STOP: +          break +        self.skip(type) +        self.readFieldEnd() +      self.readStructEnd() +    elif type == TType.MAP: +      (ktype, vtype, size) = self.readMapBegin() +      for i in range(size): +        self.skip(ktype) +        self.skip(vtype) +      self.readMapEnd() +    elif type == TType.SET: +      (etype, size) = self.readSetBegin() +      for i in range(size): +        self.skip(etype) +      self.readSetEnd() +    elif type == TType.LIST: +      (etype, size) = self.readListBegin() +      for i in range(size): +        self.skip(etype) +      self.readListEnd() + +class TProtocolFactory: +  def getProtocol(self, trans): +    pass diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py new file mode 100644 index 000000000..01bfe18e5 --- /dev/null +++ b/module/lib/thrift/protocol/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] diff --git a/module/lib/thrift/protocol/fastbinary.c b/module/lib/thrift/protocol/fastbinary.c new file mode 100644 index 000000000..67b215a83 --- /dev/null +++ b/module/lib/thrift/protocol/fastbinary.c @@ -0,0 +1,1203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdbool.h> +#include <stdint.h> +#include <netinet/in.h> + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) +  #define __i386__ + #endif + + #ifndef BIG_ENDIAN +  #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN +  #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) +  #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused.  Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +//               permanently in the object.  (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +//               Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { +  T_STOP       = 0, +  T_VOID       = 1, +  T_BOOL       = 2, +  T_BYTE       = 3, +  T_I08        = 3, +  T_I16        = 6, +  T_I32        = 8, +  T_U64        = 9, +  T_I64        = 10, +  T_DOUBLE     = 4, +  T_STRING     = 11, +  T_UTF7       = 11, +  T_STRUCT     = 12, +  T_MAP        = 13, +  T_SET        = 14, +  T_LIST       = 15, +  T_UTF8       = 16, +  T_UTF16      = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +#  define __BYTE_ORDER BYTE_ORDER +#  define __LITTLE_ENDIAN LITTLE_ENDIAN +#  define __BIG_ENDIAN BIG_ENDIAN +# else +#  error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum.  Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +#  include <byteswap.h> +#  define ntohll(n) bswap_64(n) +#  define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +#  define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +#  define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType element_type; +  PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  TType ktag; +  TType vtag; +  PyObject* ktypeargs; +  PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  PyObject* klass; +  PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { +  int tag; +  TType type; +  PyObject* attrname; +  PyObject* typeargs; +  PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { +  PyObject* stringiobuf; +  PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { +  // error from getting the int +  if (INT_CONV_ERROR_OCCURRED(len)) { +    return false; +  } +  if (!CHECK_RANGE(len, 0, INT32_MAX)) { +    PyErr_SetString(PyExc_OverflowError, "string size out of range"); +    return false; +  } +  return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { +  long val = PyInt_AsLong(o); + +  if (INT_CONV_ERROR_OCCURRED(val)) { +    return false; +  } +  if (!CHECK_RANGE(val, min, max)) { +    PyErr_SetString(PyExc_OverflowError, "int out of range"); +    return false; +  } + +  *ret = (int32_t) val; +  return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); +    return false; +  } + +  dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { +    return false; +  } + +  dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 4) { +    PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); +    return false; +  } + +  dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { +    return false; +  } + +  dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); +  if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { +    return false; +  } + +  dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); +  dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + +  return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { +  if (PyTuple_Size(typeargs) != 2) { +    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); +    return false; +  } + +  dest->klass = PyTuple_GET_ITEM(typeargs, 0); +  dest->spec = PyTuple_GET_ITEM(typeargs, 1); + +  return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + +  // i'd like to use ParseArgs here, but it seems to be a bottleneck. +  if (PyTuple_Size(spec_tuple) != 5) { +    PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); +    return false; +  } + +  dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); +  if (INT_CONV_ERROR_OCCURRED(dest->tag)) { +    return false; +  } + +  dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); +  if (INT_CONV_ERROR_OCCURRED(dest->type)) { +    return false; +  } + +  dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); +  dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); +  dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); +  return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { +  int8_t net = val; +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { +  int16_t net = (int16_t)htons(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { +  int32_t net = (int32_t)htonl(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { +  int64_t net = (int64_t)htonll(val); +  PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { +  // Unfortunately, bitwise_cast doesn't work in C.  Bad C! +  union { +    double f; +    int64_t t; +  } transfer; +  transfer.f = dub; +  writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { +  /* +   * Refcounting Strategy: +   * +   * We assume that elements of the thrift_spec tuple are not going to be +   * mutated, so we don't ref count those at all. Other than that, we try to +   * keep a reference to all the user-created objects while we work with them. +   * output_val assumes that a reference is already held. The *caller* is +   * responsible for handling references +   */ + +  switch (type) { + +  case T_BOOL: { +    int v = PyObject_IsTrue(value); +    if (v == -1) { +      return false; +    } + +    writeByte(output, (int8_t) v); +    break; +  } +  case T_I08: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { +      return false; +    } + +    writeByte(output, (int8_t) val); +    break; +  } +  case T_I16: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { +      return false; +    } + +    writeI16(output, (int16_t) val); +    break; +  } +  case T_I32: { +    int32_t val; + +    if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { +      return false; +    } + +    writeI32(output, val); +    break; +  } +  case T_I64: { +    int64_t nval = PyLong_AsLongLong(value); + +    if (INT_CONV_ERROR_OCCURRED(nval)) { +      return false; +    } + +    if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { +      PyErr_SetString(PyExc_OverflowError, "int out of range"); +      return false; +    } + +    writeI64(output, nval); +    break; +  } + +  case T_DOUBLE: { +    double nval = PyFloat_AsDouble(value); +    if (nval == -1.0 && PyErr_Occurred()) { +      return false; +    } + +    writeDouble(output, nval); +    break; +  } + +  case T_STRING: { +    Py_ssize_t len = PyString_Size(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeI32(output, (int32_t) len); +    PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    Py_ssize_t len; +    SetListTypeArgs parsedargs; +    PyObject *item; +    PyObject *iterator; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return false; +    } + +    len = PyObject_Length(value); + +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    writeByte(output, parsedargs.element_type); +    writeI32(output, (int32_t) len); + +    iterator =  PyObject_GetIter(value); +    if (iterator == NULL) { +      return false; +    } + +    while ((item = PyIter_Next(iterator))) { +      if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { +        Py_DECREF(item); +        Py_DECREF(iterator); +        return false; +      } +      Py_DECREF(item); +    } + +    Py_DECREF(iterator); + +    if (PyErr_Occurred()) { +      return false; +    } + +    break; +  } + +  case T_MAP: { +    PyObject *k, *v; +    Py_ssize_t pos = 0; +    Py_ssize_t len; + +    MapTypeArgs parsedargs; + +    len = PyDict_Size(value); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return false; +    } + +    writeByte(output, parsedargs.ktag); +    writeByte(output, parsedargs.vtag); +    writeI32(output, len); + +    // TODO(bmaurer): should support any mapping, not just dicts +    while (PyDict_Next(value, &pos, &k, &v)) { +      // TODO(dreiss): Think hard about whether these INCREFs actually +      //               turn any unsafe scenarios into safe scenarios. +      Py_INCREF(k); +      Py_INCREF(v); + +      if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) +          || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { +        Py_DECREF(k); +        Py_DECREF(v); +        return false; +      } +      Py_DECREF(k); +      Py_DECREF(v); +    } +    break; +  } + +  // TODO(dreiss): Consider breaking this out as a function +  //               the way we did for decode_struct. +  case T_STRUCT: { +    StructTypeArgs parsedargs; +    Py_ssize_t nspec; +    Py_ssize_t i; + +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return false; +    } + +    nspec = PyTuple_Size(parsedargs.spec); + +    if (nspec == -1) { +      return false; +    } + +    for (i = 0; i < nspec; i++) { +      StructItemSpec parsedspec; +      PyObject* spec_tuple; +      PyObject* instval = NULL; + +      spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); +      if (spec_tuple == Py_None) { +        continue; +      } + +      if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { +        return false; +      } + +      instval = PyObject_GetAttr(value, parsedspec.attrname); + +      if (!instval) { +        return false; +      } + +      if (instval == Py_None) { +        Py_DECREF(instval); +        continue; +      } + +      writeByte(output, (int8_t) parsedspec.type); +      writeI16(output, parsedspec.tag); + +      if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { +        Py_DECREF(instval); +        return false; +      } + +      Py_DECREF(instval); +    } + +    writeByte(output, (int8_t)T_STOP); +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { +  PyObject* enc_obj; +  PyObject* type_args; +  PyObject* buf; +  PyObject* ret = NULL; + +  if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { +    return NULL; +  } + +  buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); +  if (output_val(buf, enc_obj, T_STRUCT, type_args)) { +    ret = PycStringIO->cgetvalue(buf); +  } + +  Py_DECREF(buf); +  return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { +  Py_XDECREF(d->stringiobuf); +  Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { +  dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); +  if (!dest->stringiobuf) { +    return false; +  } + +  if (!PycStringIO_InputCheck(dest->stringiobuf)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting stringio input"); +    return false; +  } + +  dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + +  if(!dest->refill_callable) { +    free_decodebuf(dest); +    return false; +  } + +  if (!PyCallable_Check(dest->refill_callable)) { +    free_decodebuf(dest); +    PyErr_SetString(PyExc_TypeError, "expecting callable"); +    return false; +  } + +  return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { +  int read; + +  // TODO(dreiss): Don't fear the malloc.  Think about taking a copy of +  //               the partial read instead of forcing the transport +  //               to prepend it to its buffer. + +  read = PycStringIO->cread(input->stringiobuf, output, len); + +  if (read == len) { +    return true; +  } else if (read == -1) { +    return false; +  } else { +    PyObject* newiobuf; + +    // using building functions as this is a rare codepath +    newiobuf = PyObject_CallFunction( +        input->refill_callable, "s#i", *output, read, len, NULL); +    if (newiobuf == NULL) { +      return false; +    } + +    // must do this *AFTER* the call so that we don't deref the io buffer +    Py_CLEAR(input->stringiobuf); +    input->stringiobuf = newiobuf; + +    read = PycStringIO->cread(input->stringiobuf, output, len); + +    if (read == len) { +      return true; +    } else if (read == -1) { +      return false; +    } else { +      // TODO(dreiss): This could be a valid code path for big binary blobs. +      PyErr_SetString(PyExc_TypeError, +          "refill claimed to have refilled the buffer, but didn't!!"); +      return false; +    } +  } +} + +static int8_t readByte(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int8_t))) { +    return -1; +  } + +  return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int16_t))) { +    return -1; +  } + +  return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int32_t))) { +    return -1; +  } +  return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { +  char* buf; +  if (!readBytes(input, &buf, sizeof(int64_t))) { +    return -1; +  } + +  return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { +  union { +    int64_t f; +    double t; +  } transfer; + +  transfer.f = readI64(input); +  if (transfer.f == -1) { +    return -1; +  } +  return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { +  TType got = readByte(input); +  if (INT_CONV_ERROR_OCCURRED(got)) { +    return false; +  } + +  if (expected != got) { +    PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); +    return false; +  } +  return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ +  do { \ +    if (!readBytes(input, &dummy_buf, (n))) { \ +      return false; \ +    } \ +  } while(0) + +  char* dummy_buf; + +  switch (type) { + +  case T_BOOL: +  case T_I08: SKIPBYTES(1); break; +  case T_I16: SKIPBYTES(2); break; +  case T_I32: SKIPBYTES(4); break; +  case T_I64: +  case T_DOUBLE: SKIPBYTES(8); break; + +  case T_STRING: { +    // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. +    int len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } +    SKIPBYTES(len); +    break; +  } + +  case T_LIST: +  case T_SET: { +    TType etype; +    int len, i; + +    etype = readByte(input); +    if (etype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!skip(input, etype)) { +        return false; +      } +    } +    break; +  } + +  case T_MAP: { +    TType ktype, vtype; +    int len, i; + +    ktype = readByte(input); +    if (ktype == -1) { +      return false; +    } + +    vtype = readByte(input); +    if (vtype == -1) { +      return false; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    for (i = 0; i < len; i++) { +      if (!(skip(input, ktype) && skip(input, vtype))) { +        return false; +      } +    } +    break; +  } + +  case T_STRUCT: { +    while (true) { +      TType type; + +      type = readByte(input); +      if (type == -1) { +        return false; +      } + +      if (type == T_STOP) +        break; + +      SKIPBYTES(2); // tag +      if (!skip(input, type)) { +        return false; +      } +    } +    break; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return false; + +  } + +  return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { +  int spec_seq_len = PyTuple_Size(spec_seq); +  if (spec_seq_len == -1) { +    return false; +  } + +  while (true) { +    TType type; +    int16_t tag; +    PyObject* item_spec; +    PyObject* fieldval = NULL; +    StructItemSpec parsedspec; + +    type = readByte(input); +    if (type == -1) { +      return false; +    } +    if (type == T_STOP) { +      break; +    } +    tag = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(tag)) { +      return false; +    } +    if (tag >= 0 && tag < spec_seq_len) { +      item_spec = PyTuple_GET_ITEM(spec_seq, tag); +    } else { +      item_spec = Py_None; +    } + +    if (item_spec == Py_None) { +      if (!skip(input, type)) { +        return false; +      } else { +        continue; +      } +    } + +    if (!parse_struct_item_spec(&parsedspec, item_spec)) { +      return false; +    } +    if (parsedspec.type != type) { +      if (!skip(input, type)) { +        PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); +        return false; +      } else { +        continue; +      } +    } + +    fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); +    if (fieldval == NULL) { +      return false; +    } + +    if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { +      Py_DECREF(fieldval); +      return false; +    } +    Py_DECREF(fieldval); +  } +  return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { +  switch (type) { + +  case T_BOOL: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    switch (v) { +    case 0: Py_RETURN_FALSE; +    case 1: Py_RETURN_TRUE; +    // Don't laugh.  This is a potentially serious issue. +    default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; +    } +    break; +  } +  case T_I08: { +    int8_t v = readByte(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } + +    return PyInt_FromLong(v); +  } +  case T_I16: { +    int16_t v = readI16(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } +  case T_I32: { +    int32_t v = readI32(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    return PyInt_FromLong(v); +  } + +  case T_I64: { +    int64_t v = readI64(input); +    if (INT_CONV_ERROR_OCCURRED(v)) { +      return NULL; +    } +    // TODO(dreiss): Find out if we can take this fastpath always when +    //               sizeof(long) == sizeof(long long). +    if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { +      return PyInt_FromLong((long) v); +    } + +    return PyLong_FromLongLong(v); +  } + +  case T_DOUBLE: { +    double v = readDouble(input); +    if (v == -1.0 && PyErr_Occurred()) { +      return false; +    } +    return PyFloat_FromDouble(v); +  } + +  case T_STRING: { +    Py_ssize_t len = readI32(input); +    char* buf; +    if (!readBytes(input, &buf, len)) { +      return NULL; +    } + +    return PyString_FromStringAndSize(buf, len); +  } + +  case T_LIST: +  case T_SET: { +    SetListTypeArgs parsedargs; +    int32_t len; +    PyObject* ret = NULL; +    int i; + +    if (!parse_set_list_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.element_type)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return NULL; +    } + +    ret = PyList_New(len); +    if (!ret) { +      return NULL; +    } + +    for (i = 0; i < len; i++) { +      PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); +      if (!item) { +        Py_DECREF(ret); +        return NULL; +      } +      PyList_SET_ITEM(ret, i, item); +    } + +    // TODO(dreiss): Consider biting the bullet and making two separate cases +    //               for list and set, avoiding this post facto conversion. +    if (type == T_SET) { +      PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) +      // hack needed for older versions +      setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else +      // official version +      setret = PySet_New(ret); +#endif +      Py_DECREF(ret); +      return setret; +    } +    return ret; +  } + +  case T_MAP: { +    int32_t len; +    int i; +    MapTypeArgs parsedargs; +    PyObject* ret = NULL; + +    if (!parse_map_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    if (!checkTypeByte(input, parsedargs.ktag)) { +      return NULL; +    } +    if (!checkTypeByte(input, parsedargs.vtag)) { +      return NULL; +    } + +    len = readI32(input); +    if (!check_ssize_t_32(len)) { +      return false; +    } + +    ret = PyDict_New(); +    if (!ret) { +      goto error; +    } + +    for (i = 0; i < len; i++) { +      PyObject* k = NULL; +      PyObject* v = NULL; +      k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); +      if (k == NULL) { +        goto loop_error; +      } +      v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); +      if (v == NULL) { +        goto loop_error; +      } +      if (PyDict_SetItem(ret, k, v) == -1) { +        goto loop_error; +      } + +      Py_DECREF(k); +      Py_DECREF(v); +      continue; + +      // Yuck!  Destructors, anyone? +      loop_error: +      Py_XDECREF(k); +      Py_XDECREF(v); +      goto error; +    } + +    return ret; + +    error: +    Py_XDECREF(ret); +    return NULL; +  } + +  case T_STRUCT: { +    StructTypeArgs parsedargs; +    if (!parse_struct_args(&parsedargs, typeargs)) { +      return NULL; +    } + +    PyObject* ret = PyObject_CallObject(parsedargs.klass, NULL); +    if (!ret) { +      return NULL; +    } + +    if (!decode_struct(input, ret, parsedargs.spec)) { +      Py_DECREF(ret); +      return NULL; +    } + +    return ret; +  } + +  case T_STOP: +  case T_VOID: +  case T_UTF16: +  case T_UTF8: +  case T_U64: +  default: +    PyErr_SetString(PyExc_TypeError, "Unexpected TType"); +    return NULL; +  } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { +  PyObject* output_obj = NULL; +  PyObject* transport = NULL; +  PyObject* typeargs = NULL; +  StructTypeArgs parsedargs; +  DecodeBuffer input = {}; + +  if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { +    return NULL; +  } + +  if (!parse_struct_args(&parsedargs, typeargs)) { +    return NULL; +  } + +  if (!decode_buffer_from_obj(&input, transport)) { +    return NULL; +  } + +  if (!decode_struct(&input, output_obj, parsedargs.spec)) { +    free_decodebuf(&input); +    return NULL; +  } + +  free_decodebuf(&input); + +  Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + +  {"encode_binary",  encode_binary, METH_VARARGS, ""}, +  {"decode_binary",  decode_binary, METH_VARARGS, ""}, + +  {NULL, NULL, 0, NULL}        /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ +  do { \ +    INTERN_STRING(value) = PyString_InternFromString(#value); \ +    if(!INTERN_STRING(value)) return; \ +  } while(0) + +  INIT_INTERN_STRING(cstringio_buf); +  INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + +  PycString_IMPORT; +  if (PycStringIO == NULL) return; + +  (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/module/lib/thrift/server/THttpServer.py b/module/lib/thrift/server/THttpServer.py new file mode 100644 index 000000000..3047d9c00 --- /dev/null +++ b/module/lib/thrift/server/THttpServer.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import BaseHTTPServer + +from thrift.server import TServer +from thrift.transport import TTransport + +class ResponseException(Exception): +  """Allows handlers to override the HTTP response + +  Normally, THttpServer always sends a 200 response.  If a handler wants +  to override this behavior (e.g., to simulate a misconfigured or +  overloaded web server during testing), it can raise a ResponseException. +  The function passed to the constructor will be called with the +  RequestHandler as its only argument. +  """ +  def __init__(self, handler): +    self.handler = handler + + +class THttpServer(TServer.TServer): +  """A simple HTTP-based Thrift server + +  This class is not very performant, but it is useful (for example) for +  acting as a mock version of an Apache-based PHP Thrift endpoint.""" + +  def __init__(self, processor, server_address, +      inputProtocolFactory, outputProtocolFactory = None, +      server_class = BaseHTTPServer.HTTPServer): +    """Set up protocol factories and HTTP server. + +    See BaseHTTPServer for server_address. +    See TServer for protocol factories.""" + +    if outputProtocolFactory is None: +      outputProtocolFactory = inputProtocolFactory + +    TServer.TServer.__init__(self, processor, None, None, None, +        inputProtocolFactory, outputProtocolFactory) + +    thttpserver = self + +    class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): +      def do_POST(self): +        # Don't care about the request path. +        itrans = TTransport.TFileObjectTransport(self.rfile) +        otrans = TTransport.TFileObjectTransport(self.wfile) +        itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) +        otrans = TTransport.TMemoryBuffer() +        iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) +        oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) +        try: +          thttpserver.processor.process(iprot, oprot) +        except ResponseException, exn: +          exn.handler(self) +        else: +          self.send_response(200) +          self.send_header("content-type", "application/x-thrift") +          self.end_headers() +          self.wfile.write(otrans.getvalue()) + +    self.httpd = server_class(server_address, RequestHander) + +  def serve(self): +    self.httpd.serve_forever() diff --git a/module/lib/thrift/server/TNonblockingServer.py b/module/lib/thrift/server/TNonblockingServer.py new file mode 100644 index 000000000..ea348a0b6 --- /dev/null +++ b/module/lib/thrift/server/TNonblockingServer.py @@ -0,0 +1,310 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is reciving and sending requests +only from main thread. + +It also makes thread pool server in tasks terms, not connections. +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +class Worker(threading.Thread): +    """Worker is a small helper to process incoming connection.""" +    def __init__(self, queue): +        threading.Thread.__init__(self) +        self.queue = queue + +    def run(self): +        """Process queries from task queue, stop if processor is None.""" +        while True: +            try: +                processor, iprot, oprot, otrans, callback = self.queue.get() +                if processor is None: +                    break +                processor.process(iprot, oprot) +                callback(True, otrans.getvalue()) +            except Exception: +                logging.exception("Exception while processing request") +                callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + +def locked(func): +    "Decorator which locks self.lock." +    def nested(self, *args, **kwargs): +        self.lock.acquire() +        try: +            return func(self, *args, **kwargs) +        finally: +            self.lock.release() +    return nested + +def socket_exception(func): +    "Decorator close object on socket.error." +    def read(self, *args, **kwargs): +        try: +            return func(self, *args, **kwargs) +        except socket.error: +            self.close() +    return read + +class Connection: +    """Basic class is represented connection. +     +    It can be in state: +        WAIT_LEN --- connection is reading request len. +        WAIT_MESSAGE --- connection is reading request. +        WAIT_PROCESS --- connection has just read whole request and  +            waits for call ready routine. +        SEND_ANSWER --- connection is sending answer string (including length +            of answer). +        CLOSED --- socket was closed and connection should be deleted. +    """ +    def __init__(self, new_socket, wake_up): +        self.socket = new_socket +        self.socket.setblocking(False) +        self.status = WAIT_LEN +        self.len = 0 +        self.message = '' +        self.lock = threading.Lock() +        self.wake_up = wake_up + +    def _read_len(self): +        """Reads length of request. +         +        It's really paranoic routine and it may be replaced by  +        self.socket.recv(4).""" +        read = self.socket.recv(4 - len(self.message)) +        if len(read) == 0: +            # if we read 0 bytes and self.message is empty, it means client close  +            # connection +            if len(self.message) != 0: +                logging.error("can't read frame size from socket") +            self.close() +            return +        self.message += read +        if len(self.message) == 4: +            self.len, = struct.unpack('!i', self.message) +            if self.len < 0: +                logging.error("negative frame size, it seems client"\ +                    " doesn't use FramedTransport") +                self.close() +            elif self.len == 0: +                logging.error("empty frame, it's really strange") +                self.close() +            else: +                self.message = '' +                self.status = WAIT_MESSAGE + +    @socket_exception +    def read(self): +        """Reads data from stream and switch state.""" +        assert self.status in (WAIT_LEN, WAIT_MESSAGE) +        if self.status == WAIT_LEN: +            self._read_len() +            # go back to the main loop here for simplicity instead of +            # falling through, even though there is a good chance that +            # the message is already available +        elif self.status == WAIT_MESSAGE: +            read = self.socket.recv(self.len - len(self.message)) +            if len(read) == 0: +                logging.error("can't read frame from socket (get %d of %d bytes)" % +                    (len(self.message), self.len)) +                self.close() +                return +            self.message += read +            if len(self.message) == self.len: +                self.status = WAIT_PROCESS + +    @socket_exception +    def write(self): +        """Writes data from socket and switch state.""" +        assert self.status == SEND_ANSWER +        sent = self.socket.send(self.message) +        if sent == len(self.message): +            self.status = WAIT_LEN +            self.message = '' +            self.len = 0 +        else: +            self.message = self.message[sent:] + +    @locked +    def ready(self, all_ok, message): +        """Callback function for switching state and waking up main thread. +         +        This function is the only function witch can be called asynchronous. +         +        The ready can switch Connection to three states: +            WAIT_LEN if request was oneway. +            SEND_ANSWER if request was processed in normal way. +            CLOSED if request throws unexpected exception. +         +        The one wakes up main thread. +        """ +        assert self.status == WAIT_PROCESS +        if not all_ok: +            self.close() +            self.wake_up() +            return +        self.len = '' +        if len(message) == 0: +            # it was a oneway request, do not write answer +            self.message = '' +            self.status = WAIT_LEN +        else: +            self.message = struct.pack('!i', len(message)) + message +            self.status = SEND_ANSWER +        self.wake_up() + +    @locked +    def is_writeable(self): +        "Returns True if connection should be added to write list of select." +        return self.status == SEND_ANSWER + +    # it's not necessary, but... +    @locked +    def is_readable(self): +        "Returns True if connection should be added to read list of select." +        return self.status in (WAIT_LEN, WAIT_MESSAGE) + +    @locked +    def is_closed(self): +        "Returns True if connection is closed." +        return self.status == CLOSED + +    def fileno(self): +        "Returns the file descriptor of the associated socket." +        return self.socket.fileno() + +    def close(self): +        "Closes connection" +        self.status = CLOSED +        self.socket.close() + +class TNonblockingServer: +    """Non-blocking server.""" +    def __init__(self, processor, lsocket, inputProtocolFactory=None,  +            outputProtocolFactory=None, threads=10): +        self.processor = processor +        self.socket = lsocket +        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() +        self.out_protocol = outputProtocolFactory or self.in_protocol +        self.threads = int(threads) +        self.clients = {} +        self.tasks = Queue.Queue() +        self._read, self._write = socket.socketpair() +        self.prepared = False + +    def setNumThreads(self, num): +        """Set the number of worker threads that should be created.""" +        # implement ThreadPool interface +        assert not self.prepared, "You can't change number of threads for working server" +        self.threads = num + +    def prepare(self): +        """Prepares server for serve requests.""" +        self.socket.listen() +        for _ in xrange(self.threads): +            thread = Worker(self.tasks) +            thread.setDaemon(True) +            thread.start() +        self.prepared = True + +    def wake_up(self): +        """Wake up main thread. +         +        The server usualy waits in select call in we should terminate one. +        The simplest way is using socketpair. +         +        Select always wait to read from the first socket of socketpair. +         +        In this case, we can just write anything to the second socket from +        socketpair.""" +        self._write.send('1') + +    def _select(self): +        """Does select on open connections.""" +        readable = [self.socket.handle.fileno(), self._read.fileno()] +        writable = [] +        for i, connection in self.clients.items(): +            if connection.is_readable(): +                readable.append(connection.fileno()) +            if connection.is_writeable(): +                writable.append(connection.fileno()) +            if connection.is_closed(): +                del self.clients[i] +        return select.select(readable, writable, readable) +         +    def handle(self): +        """Handle requests. +        +        WARNING! You must call prepare BEFORE calling handle. +        """ +        assert self.prepared, "You have to call prepare before handle" +        rset, wset, xset = self._select() +        for readable in rset: +            if readable == self._read.fileno(): +                # don't care i just need to clean readable flag +                self._read.recv(1024)  +            elif readable == self.socket.handle.fileno(): +                client = self.socket.accept().handle +                self.clients[client.fileno()] = Connection(client, self.wake_up) +            else: +                connection = self.clients[readable] +                connection.read() +                if connection.status == WAIT_PROCESS: +                    itransport = TTransport.TMemoryBuffer(connection.message) +                    otransport = TTransport.TMemoryBuffer() +                    iprot = self.in_protocol.getProtocol(itransport) +                    oprot = self.out_protocol.getProtocol(otransport) +                    self.tasks.put([self.processor, iprot, oprot,  +                                    otransport, connection.ready]) +        for writeable in wset: +            self.clients[writeable].write() +        for oob in xset: +            self.clients[oob].close() +            del self.clients[oob] + +    def close(self): +        """Closes the server.""" +        for _ in xrange(self.threads): +            self.tasks.put([None, None, None, None, None]) +        self.socket.close() +        self.prepared = False +         +    def serve(self): +        """Serve forever.""" +        self.prepare() +        while True: +            self.handle() diff --git a/module/lib/thrift/server/TServer.py b/module/lib/thrift/server/TServer.py new file mode 100644 index 000000000..6b0707b54 --- /dev/null +++ b/module/lib/thrift/server/TServer.py @@ -0,0 +1,275 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys +import os +import traceback +import threading +import Queue + +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +class TServer: + +  """Base interface for a server, which must have a serve method.""" + +  """ 3 constructors for all servers: +  1) (processor, serverTransport) +  2) (processor, serverTransport, transportFactory, protocolFactory) +  3) (processor, serverTransport, +      inputTransportFactory, outputTransportFactory, +      inputProtocolFactory, outputProtocolFactory)""" +  def __init__(self, *args): +    if (len(args) == 2): +      self.__initArgs__(args[0], args[1], +                        TTransport.TTransportFactoryBase(), +                        TTransport.TTransportFactoryBase(), +                        TBinaryProtocol.TBinaryProtocolFactory(), +                        TBinaryProtocol.TBinaryProtocolFactory()) +    elif (len(args) == 4): +      self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) +    elif (len(args) == 6): +      self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + +  def __initArgs__(self, processor, serverTransport, +                   inputTransportFactory, outputTransportFactory, +                   inputProtocolFactory, outputProtocolFactory): +    self.processor = processor +    self.serverTransport = serverTransport +    self.inputTransportFactory = inputTransportFactory +    self.outputTransportFactory = outputTransportFactory +    self.inputProtocolFactory = inputProtocolFactory +    self.outputProtocolFactory = outputProtocolFactory + +  def serve(self): +    pass + +class TSimpleServer(TServer): + +  """Simple single-threaded server that just pumps around one transport.""" + +  def __init__(self, *args): +    TServer.__init__(self, *args) + +  def serve(self): +    self.serverTransport.listen() +    while True: +      client = self.serverTransport.accept() +      itrans = self.inputTransportFactory.getTransport(client) +      otrans = self.outputTransportFactory.getTransport(client) +      iprot = self.inputProtocolFactory.getProtocol(itrans) +      oprot = self.outputProtocolFactory.getProtocol(otrans) +      try: +        while True: +          self.processor.process(iprot, oprot) +      except TTransport.TTransportException, tx: +        pass +      except Exception, x: +        logging.exception(x) + +      itrans.close() +      otrans.close() + +class TThreadedServer(TServer): + +  """Threaded server that spawns a new thread per each connection.""" + +  def __init__(self, *args, **kwargs): +    TServer.__init__(self, *args) +    self.daemon = kwargs.get("daemon", False) + +  def serve(self): +    self.serverTransport.listen() +    while True: +      try: +        client = self.serverTransport.accept() +        print client +        t = threading.Thread(target = self.handle, args=(client,)) +        t.setDaemon(self.daemon) +        t.start() +      except KeyboardInterrupt: +        raise +      except Exception, x: +        logging.exception(x) + +  def handle(self, client): +    itrans = self.inputTransportFactory.getTransport(client) +    otrans = self.outputTransportFactory.getTransport(client) +    iprot = self.inputProtocolFactory.getProtocol(itrans) +    oprot = self.outputProtocolFactory.getProtocol(otrans) +    try: +      while True: +        self.processor.process(iprot, oprot) +    except TTransport.TTransportException, tx: +      pass +    except Exception, x: +      logging.exception(x) + +    itrans.close() +    otrans.close() + +class TThreadPoolServer(TServer): + +  """Server with a fixed size pool of threads which service requests.""" + +  def __init__(self, *args, **kwargs): +    TServer.__init__(self, *args) +    self.clients = Queue.Queue() +    self.threads = 10 +    self.daemon = kwargs.get("daemon", False) + +  def setNumThreads(self, num): +    """Set the number of worker threads that should be created""" +    self.threads = num + +  def serveThread(self): +    """Loop around getting clients from the shared queue and process them.""" +    while True: +      try: +        client = self.clients.get() +        self.serveClient(client) +      except Exception, x: +        logging.exception(x) + +  def serveClient(self, client): +    """Process input/output from a client for as long as possible""" +    itrans = self.inputTransportFactory.getTransport(client) +    otrans = self.outputTransportFactory.getTransport(client) +    iprot = self.inputProtocolFactory.getProtocol(itrans) +    oprot = self.outputProtocolFactory.getProtocol(otrans) +    try: +      while True: +        self.processor.process(iprot, oprot) +    except TTransport.TTransportException, tx: +      pass +    except Exception, x: +      logging.exception(x) + +    itrans.close() +    otrans.close() + +  def serve(self): +    """Start a fixed number of worker threads and put client into a queue""" +    for i in range(self.threads): +      try: +        t = threading.Thread(target = self.serveThread) +        t.setDaemon(self.daemon) +        t.start() +      except Exception, x: +        logging.exception(x) + +    # Pump the socket for clients +    self.serverTransport.listen() +    while True: +      try: +        client = self.serverTransport.accept() +        self.clients.put(client) +      except Exception, x: +        logging.exception(x) + + +class TForkingServer(TServer): + +  """A Thrift server that forks a new process for each request""" +  """ +  This is more scalable than the threaded server as it does not cause +  GIL contention. + +  Note that this has different semantics from the threading server. +  Specifically, updates to shared variables will no longer be shared. +  It will also not work on windows. + +  This code is heavily inspired by SocketServer.ForkingMixIn in the +  Python stdlib. +  """ + +  def __init__(self, *args): +    TServer.__init__(self, *args) +    self.children = [] + +  def serve(self): +    def try_close(file): +      try: +        file.close() +      except IOError, e: +        logging.warning(e, exc_info=True) + + +    self.serverTransport.listen() +    while True: +      client = self.serverTransport.accept() +      try: +        pid = os.fork() + +        if pid: # parent +          # add before collect, otherwise you race w/ waitpid +          self.children.append(pid) +          self.collect_children() + +          # Parent must close socket or the connection may not get +          # closed promptly +          itrans = self.inputTransportFactory.getTransport(client) +          otrans = self.outputTransportFactory.getTransport(client) +          try_close(itrans) +          try_close(otrans) +        else: +          itrans = self.inputTransportFactory.getTransport(client) +          otrans = self.outputTransportFactory.getTransport(client) + +          iprot = self.inputProtocolFactory.getProtocol(itrans) +          oprot = self.outputProtocolFactory.getProtocol(otrans) + +          ecode = 0 +          try: +            try: +              while True: +                self.processor.process(iprot, oprot) +            except TTransport.TTransportException, tx: +              pass +            except Exception, e: +              logging.exception(e) +              ecode = 1 +          finally: +            try_close(itrans) +            try_close(otrans) + +          os._exit(ecode) + +      except TTransport.TTransportException, tx: +        pass +      except Exception, x: +        logging.exception(x) + + +  def collect_children(self): +    while self.children: +      try: +        pid, status = os.waitpid(0, os.WNOHANG) +      except os.error: +        pid = None + +      if pid: +        self.children.remove(pid) +      else: +        break + + diff --git a/module/lib/thrift/server/__init__.py b/module/lib/thrift/server/__init__.py new file mode 100644 index 000000000..1bf6e254e --- /dev/null +++ b/module/lib/thrift/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/module/lib/thrift/transport/THttpClient.py b/module/lib/thrift/transport/THttpClient.py new file mode 100644 index 000000000..50269785c --- /dev/null +++ b/module/lib/thrift/transport/THttpClient.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +from cStringIO import StringIO + +import urlparse +import httplib +import warnings +import socket + +class THttpClient(TTransportBase): + +  """Http implementation of TTransport base.""" + +  def __init__(self, uri_or_host, port=None, path=None): +    """THttpClient supports two different types constructor parameters. + +    THttpClient(host, port, path) - deprecated +    THttpClient(uri) + +    Only the second supports https.""" + +    if port is not None: +      warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) +      self.host = uri_or_host +      self.port = port +      assert path +      self.path = path +      self.scheme = 'http' +    else: +      parsed = urlparse.urlparse(uri_or_host) +      self.scheme = parsed.scheme +      assert self.scheme in ('http', 'https') +      if self.scheme == 'http': +        self.port = parsed.port or httplib.HTTP_PORT +      elif self.scheme == 'https': +        self.port = parsed.port or httplib.HTTPS_PORT +      self.host = parsed.hostname +      self.path = parsed.path +      if parsed.query: +        self.path += '?%s' % parsed.query +    self.__wbuf = StringIO() +    self.__http = None +    self.__timeout = None + +  def open(self): +    if self.scheme == 'http': +      self.__http = httplib.HTTP(self.host, self.port) +    else: +      self.__http = httplib.HTTPS(self.host, self.port) + +  def close(self): +    self.__http.close() +    self.__http = None + +  def isOpen(self): +    return self.__http != None + +  def setTimeout(self, ms): +    if not hasattr(socket, 'getdefaulttimeout'): +      raise NotImplementedError + +    if ms is None: +      self.__timeout = None +    else: +      self.__timeout = ms/1000.0 + +  def read(self, sz): +    return self.__http.file.read(sz) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def __withTimeout(f): +    def _f(*args, **kwargs): +      orig_timeout = socket.getdefaulttimeout() +      socket.setdefaulttimeout(args[0].__timeout) +      result = f(*args, **kwargs) +      socket.setdefaulttimeout(orig_timeout) +      return result +    return _f + +  def flush(self): +    if self.isOpen(): +      self.close() +    self.open(); + +    # Pull data out of buffer +    data = self.__wbuf.getvalue() +    self.__wbuf = StringIO() + +    # HTTP request +    self.__http.putrequest('POST', self.path) + +    # Write headers +    self.__http.putheader('Host', self.host) +    self.__http.putheader('Content-Type', 'application/x-thrift') +    self.__http.putheader('Content-Length', str(len(data))) +    self.__http.endheaders() + +    # Write payload +    self.__http.send(data) + +    # Get reply to flush the request +    self.code, self.message, self.headers = self.__http.getreply() + +  # Decorate if we know how to timeout +  if hasattr(socket, 'getdefaulttimeout'): +    flush = __withTimeout(flush) diff --git a/module/lib/thrift/transport/TSocket.py b/module/lib/thrift/transport/TSocket.py new file mode 100644 index 000000000..d77e358a2 --- /dev/null +++ b/module/lib/thrift/transport/TSocket.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +import os +import errno +import socket +import sys + +class TSocketBase(TTransportBase): +  def _resolveAddr(self): +    if self._unix_socket is not None: +      return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] +    else: +      return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + +  def close(self): +    if self.handle: +      self.handle.close() +      self.handle = None + +class TSocket(TSocketBase): +  """Socket implementation of TTransport base.""" + +  def __init__(self, host='localhost', port=9090, unix_socket=None): +    """Initialize a TSocket + +    @param host(str)  The host to connect to. +    @param port(int)  The (TCP) port to connect to. +    @param unix_socket(str)  The filename of a unix socket to connect to. +                             (host and port will be ignored.) +    """ + +    self.host = host +    self.port = port +    self.handle = None +    self._unix_socket = unix_socket +    self._timeout = None + +  def setHandle(self, h): +    self.handle = h + +  def isOpen(self): +    return self.handle != None + +  def setTimeout(self, ms): +    if ms is None: +      self._timeout = None +    else: +      self._timeout = ms/1000.0 + +    if (self.handle != None): +      self.handle.settimeout(self._timeout) + +  def open(self): +    try: +      res0 = self._resolveAddr() +      for res in res0: +        self.handle = socket.socket(res[0], res[1]) +        self.handle.settimeout(self._timeout) +        try: +          self.handle.connect(res[4]) +        except socket.error, e: +          if res is not res0[-1]: +            continue +          else: +            raise e +        break +    except socket.error, e: +      if self._unix_socket: +        message = 'Could not connect to socket %s' % self._unix_socket +      else: +        message = 'Could not connect to %s:%d' % (self.host, self.port) +      raise TTransportException(type=TTransportException.NOT_OPEN, message=message) + +  def read(self, sz): +    try: +      buff = self.handle.recv(sz) +    except socket.error, e: +      if (e.args[0] == errno.ECONNRESET and +          (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): +        # freebsd and Mach don't follow POSIX semantic of recv +        # and fail with ECONNRESET if peer performed shutdown. +        # See corresponding comment and code in TSocket::read() +        # in lib/cpp/src/transport/TSocket.cpp. +        self.close() +        # Trigger the check to raise the END_OF_FILE exception below. +        buff = '' +      else: +        raise +    if len(buff) == 0: +      raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') +    return buff + +  def write(self, buff): +    if not self.handle: +      raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') +    sent = 0 +    have = len(buff) +    while sent < have: +      plus = self.handle.send(buff) +      if plus == 0: +        raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') +      sent += plus +      buff = buff[plus:] + +  def flush(self): +    pass + +class TServerSocket(TSocketBase, TServerTransportBase): +  """Socket implementation of TServerTransport base.""" + +  def __init__(self, port=9090, unix_socket=None): +    self.host = None +    self.port = port +    self._unix_socket = unix_socket +    self.handle = None + +  def listen(self): +    res0 = self._resolveAddr() +    for res in res0: +      if res[0] is socket.AF_INET6 or res is res0[-1]: +        break + +    # We need remove the old unix socket if the file exists and +    # nobody is listening on it. +    if self._unix_socket: +      tmp = socket.socket(res[0], res[1]) +      try: +        tmp.connect(res[4]) +      except socket.error, err: +        eno, message = err.args +        if eno == errno.ECONNREFUSED: +          os.unlink(res[4]) + +    self.handle = socket.socket(res[0], res[1]) +    self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +    if hasattr(self.handle, 'set_timeout'): +      self.handle.set_timeout(None) +    self.handle.bind(res[4]) +    self.handle.listen(128) + +  def accept(self): +    client, addr = self.handle.accept() +    result = TSocket() +    result.setHandle(client) +    return result diff --git a/module/lib/thrift/transport/TTransport.py b/module/lib/thrift/transport/TTransport.py new file mode 100644 index 000000000..12e51a9bf --- /dev/null +++ b/module/lib/thrift/transport/TTransport.py @@ -0,0 +1,331 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +from struct import pack,unpack +from thrift.Thrift import TException + +class TTransportException(TException): + +  """Custom Transport Exception class""" + +  UNKNOWN = 0 +  NOT_OPEN = 1 +  ALREADY_OPEN = 2 +  TIMED_OUT = 3 +  END_OF_FILE = 4 + +  def __init__(self, type=UNKNOWN, message=None): +    TException.__init__(self, message) +    self.type = type + +class TTransportBase: + +  """Base class for Thrift transport layer.""" + +  def isOpen(self): +    pass + +  def open(self): +    pass + +  def close(self): +    pass + +  def read(self, sz): +    pass + +  def readAll(self, sz): +    buff = '' +    have = 0 +    while (have < sz): +      chunk = self.read(sz-have) +      have += len(chunk) +      buff += chunk + +      if len(chunk) == 0: +        raise EOFError() + +    return buff + +  def write(self, buf): +    pass + +  def flush(self): +    pass + +# This class should be thought of as an interface. +class CReadableTransport: +  """base class for transports that are readable from C""" + +  # TODO(dreiss): Think about changing this interface to allow us to use +  #               a (Python, not c) StringIO instead, because it allows +  #               you to write after reading. + +  # NOTE: This is a classic class, so properties will NOT work +  #       correctly for setting. +  @property +  def cstringio_buf(self): +    """A cStringIO buffer that contains the current chunk we are reading.""" +    pass + +  def cstringio_refill(self, partialread, reqlen): +    """Refills cstringio_buf. + +    Returns the currently used buffer (which can but need not be the same as +    the old cstringio_buf). partialread is what the C code has read from the +    buffer, and should be inserted into the buffer before any more reads.  The +    return value must be a new, not borrowed reference.  Something along the +    lines of self._buf should be fine. + +    If reqlen bytes can't be read, throw EOFError. +    """ +    pass + +class TServerTransportBase: + +  """Base class for Thrift server transports.""" + +  def listen(self): +    pass + +  def accept(self): +    pass + +  def close(self): +    pass + +class TTransportFactoryBase: + +  """Base class for a Transport Factory""" + +  def getTransport(self, trans): +    return trans + +class TBufferedTransportFactory: + +  """Factory transport that builds buffered transports""" + +  def getTransport(self, trans): +    buffered = TBufferedTransport(trans) +    return buffered + + +class TBufferedTransport(TTransportBase,CReadableTransport): + +  """Class that wraps another transport and buffers its I/O. + +  The implementation uses a (configurable) fixed-size read buffer +  but buffers all writes until a flush is performed. +  """ + +  DEFAULT_BUFFER = 4096 + +  def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): +    self.__trans = trans +    self.__wbuf = StringIO() +    self.__rbuf = StringIO("") +    self.__rbuf_size = rbuf_size + +  def isOpen(self): +    return self.__trans.isOpen() + +  def open(self): +    return self.__trans.open() + +  def close(self): +    return self.__trans.close() + +  def read(self, sz): +    ret = self.__rbuf.read(sz) +    if len(ret) != 0: +      return ret + +    self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) +    return self.__rbuf.read(sz) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def flush(self): +    out = self.__wbuf.getvalue() +    # reset wbuf before write/flush to preserve state on underlying failure +    self.__wbuf = StringIO() +    self.__trans.write(out) +    self.__trans.flush() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self.__rbuf + +  def cstringio_refill(self, partialread, reqlen): +    retstring = partialread +    if reqlen < self.__rbuf_size: +      # try to make a read of as much as we can. +      retstring += self.__trans.read(self.__rbuf_size) + +    # but make sure we do read reqlen bytes. +    if len(retstring) < reqlen: +      retstring += self.__trans.readAll(reqlen - len(retstring)) + +    self.__rbuf = StringIO(retstring) +    return self.__rbuf + +class TMemoryBuffer(TTransportBase, CReadableTransport): +  """Wraps a cStringIO object as a TTransport. + +  NOTE: Unlike the C++ version of this class, you cannot write to it +        then immediately read from it.  If you want to read from a +        TMemoryBuffer, you must either pass a string to the constructor. +  TODO(dreiss): Make this work like the C++ version. +  """ + +  def __init__(self, value=None): +    """value -- a value to read from for stringio + +    If value is set, this will be a transport for reading, +    otherwise, it is for writing""" +    if value is not None: +      self._buffer = StringIO(value) +    else: +      self._buffer = StringIO() + +  def isOpen(self): +    return not self._buffer.closed + +  def open(self): +    pass + +  def close(self): +    self._buffer.close() + +  def read(self, sz): +    return self._buffer.read(sz) + +  def write(self, buf): +    self._buffer.write(buf) + +  def flush(self): +    pass + +  def getvalue(self): +    return self._buffer.getvalue() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self._buffer + +  def cstringio_refill(self, partialread, reqlen): +    # only one shot at reading... +    raise EOFError() + +class TFramedTransportFactory: + +  """Factory transport that builds framed transports""" + +  def getTransport(self, trans): +    framed = TFramedTransport(trans) +    return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + +  """Class that wraps another transport and frames its I/O when writing.""" + +  def __init__(self, trans,): +    self.__trans = trans +    self.__rbuf = StringIO() +    self.__wbuf = StringIO() + +  def isOpen(self): +    return self.__trans.isOpen() + +  def open(self): +    return self.__trans.open() + +  def close(self): +    return self.__trans.close() + +  def read(self, sz): +    ret = self.__rbuf.read(sz) +    if len(ret) != 0: +      return ret + +    self.readFrame() +    return self.__rbuf.read(sz) + +  def readFrame(self): +    buff = self.__trans.readAll(4) +    sz, = unpack('!i', buff) +    self.__rbuf = StringIO(self.__trans.readAll(sz)) + +  def write(self, buf): +    self.__wbuf.write(buf) + +  def flush(self): +    wout = self.__wbuf.getvalue() +    wsz = len(wout) +    # reset wbuf before write/flush to preserve state on underlying failure +    self.__wbuf = StringIO() +    # N.B.: Doing this string concatenation is WAY cheaper than making +    # two separate calls to the underlying socket object. Socket writes in +    # Python turn out to be REALLY expensive, but it seems to do a pretty +    # good job of managing string buffer operations without excessive copies +    buf = pack("!i", wsz) + wout +    self.__trans.write(buf) +    self.__trans.flush() + +  # Implement the CReadableTransport interface. +  @property +  def cstringio_buf(self): +    return self.__rbuf + +  def cstringio_refill(self, prefix, reqlen): +    # self.__rbuf will already be empty here because fastbinary doesn't +    # ask for a refill until the previous buffer is empty.  Therefore, +    # we can start reading new frames immediately. +    while len(prefix) < reqlen: +      self.readFrame() +      prefix += self.__rbuf.getvalue() +    self.__rbuf = StringIO(prefix) +    return self.__rbuf + + +class TFileObjectTransport(TTransportBase): +  """Wraps a file-like object to make it work as a Thrift transport.""" + +  def __init__(self, fileobj): +    self.fileobj = fileobj + +  def isOpen(self): +    return True + +  def close(self): +    self.fileobj.close() + +  def read(self, sz): +    return self.fileobj.read(sz) + +  def write(self, buf): +    self.fileobj.write(buf) + +  def flush(self): +    self.fileobj.flush() diff --git a/module/lib/thrift/transport/TTwisted.py b/module/lib/thrift/transport/TTwisted.py new file mode 100644 index 000000000..b6dcb4e0b --- /dev/null +++ b/module/lib/thrift/transport/TTwisted.py @@ -0,0 +1,219 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from zope.interface import implements, Interface, Attribute +from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ +    connectionDone +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log +from twisted.web import server, resource, http + +from thrift.transport import TTransport +from cStringIO import StringIO + + +class TMessageSenderTransport(TTransport.TTransportBase): + +    def __init__(self): +        self.__wbuf = StringIO() + +    def write(self, buf): +        self.__wbuf.write(buf) + +    def flush(self): +        msg = self.__wbuf.getvalue() +        self.__wbuf = StringIO() +        self.sendMessage(msg) + +    def sendMessage(self, message): +        raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + +    def __init__(self, func): +        TMessageSenderTransport.__init__(self) +        self.func = func + +    def sendMessage(self, message): +        self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + +    MAX_LENGTH = 2 ** 31 - 1 + +    def __init__(self, client_class, iprot_factory, oprot_factory=None): +        self._client_class = client_class +        self._iprot_factory = iprot_factory +        if oprot_factory is None: +            self._oprot_factory = iprot_factory +        else: +            self._oprot_factory = oprot_factory + +        self.recv_map = {} +        self.started = defer.Deferred() + +    def dispatch(self, msg): +        self.sendString(msg) + +    def connectionMade(self): +        tmo = TCallbackTransport(self.dispatch) +        self.client = self._client_class(tmo, self._oprot_factory) +        self.started.callback(self.client) + +    def connectionLost(self, reason=connectionDone): +        for k,v in self.client._reqs.iteritems(): +            tex = TTransport.TTransportException( +                type=TTransport.TTransportException.END_OF_FILE, +                message='Connection closed') +            v.errback(tex) + +    def stringReceived(self, frame): +        tr = TTransport.TMemoryBuffer(frame) +        iprot = self._iprot_factory.getProtocol(tr) +        (fname, mtype, rseqid) = iprot.readMessageBegin() + +        try: +            method = self.recv_map[fname] +        except KeyError: +            method = getattr(self.client, 'recv_' + fname) +            self.recv_map[fname] = method + +        method(iprot, mtype, rseqid) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + +    MAX_LENGTH = 2 ** 31 - 1 + +    def dispatch(self, msg): +        self.sendString(msg) + +    def processError(self, error): +        self.transport.loseConnection() + +    def processOk(self, _, tmo): +        msg = tmo.getvalue() + +        if len(msg) > 0: +            self.dispatch(msg) + +    def stringReceived(self, frame): +        tmi = TTransport.TMemoryBuffer(frame) +        tmo = TTransport.TMemoryBuffer() + +        iprot = self.factory.iprot_factory.getProtocol(tmi) +        oprot = self.factory.oprot_factory.getProtocol(tmo) + +        d = self.factory.processor.process(iprot, oprot) +        d.addCallbacks(self.processOk, self.processError, +            callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + +    processor = Attribute("Thrift processor") + +    iprot_factory = Attribute("Input protocol factory") + +    oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + +    client_class = Attribute("Thrift client class") + +    iprot_factory = Attribute("Input protocol factory") + +    oprot_factory = Attribute("Output protocol factory") + + +class ThriftServerFactory(ServerFactory): + +    implements(IThriftServerFactory) + +    protocol = ThriftServerProtocol + +    def __init__(self, processor, iprot_factory, oprot_factory=None): +        self.processor = processor +        self.iprot_factory = iprot_factory +        if oprot_factory is None: +            self.oprot_factory = iprot_factory +        else: +            self.oprot_factory = oprot_factory + + +class ThriftClientFactory(ClientFactory): + +    implements(IThriftClientFactory) + +    protocol = ThriftClientProtocol + +    def __init__(self, client_class, iprot_factory, oprot_factory=None): +        self.client_class = client_class +        self.iprot_factory = iprot_factory +        if oprot_factory is None: +            self.oprot_factory = iprot_factory +        else: +            self.oprot_factory = oprot_factory + +    def buildProtocol(self, addr): +        p = self.protocol(self.client_class, self.iprot_factory, +            self.oprot_factory) +        p.factory = self +        return p + + +class ThriftResource(resource.Resource): + +    allowedMethods = ('POST',) + +    def __init__(self, processor, inputProtocolFactory, +        outputProtocolFactory=None): +        resource.Resource.__init__(self) +        self.inputProtocolFactory = inputProtocolFactory +        if outputProtocolFactory is None: +            self.outputProtocolFactory = inputProtocolFactory +        else: +            self.outputProtocolFactory = outputProtocolFactory +        self.processor = processor + +    def getChild(self, path, request): +        return self + +    def _cbProcess(self, _, request, tmo): +        msg = tmo.getvalue() +        request.setResponseCode(http.OK) +        request.setHeader("content-type", "application/x-thrift") +        request.write(msg) +        request.finish() + +    def render_POST(self, request): +        request.content.seek(0, 0) +        data = request.content.read() +        tmi = TTransport.TMemoryBuffer(data) +        tmo = TTransport.TMemoryBuffer() + +        iprot = self.inputProtocolFactory.getProtocol(tmi) +        oprot = self.outputProtocolFactory.getProtocol(tmo) + +        d = self.processor.process(iprot, oprot) +        d.addCallback(self._cbProcess, request, tmo) +        return server.NOT_DONE_YET diff --git a/module/lib/thrift/transport/__init__.py b/module/lib/thrift/transport/__init__.py new file mode 100644 index 000000000..02c6048a9 --- /dev/null +++ b/module/lib/thrift/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient']  | 
