diff options
| author | 2013-12-15 14:36:38 +0100 | |
|---|---|---|
| committer | 2013-12-15 14:36:38 +0100 | |
| commit | 367172406d1382f2fdd39936bde31ebcbcd1c56f (patch) | |
| tree | 6abd7b2adffdd7fbb4eb35a9171c25b0e8eee4a0 /pyload/lib | |
| parent | more options to get webUI through proxy working (diff) | |
| download | pyload-367172406d1382f2fdd39936bde31ebcbcd1c56f.tar.xz | |
updated pywebsocket
Diffstat (limited to 'pyload/lib')
| -rw-r--r-- | pyload/lib/mod_pywebsocket/_stream_base.py | 46 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/_stream_hybi.py | 24 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/deflate_stream_extension.py | 69 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/dispatch.py | 10 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/extensions.py | 171 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/_base.py | 72 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/hybi.py | 64 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/hybi00.py | 63 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/headerparserhandler.py | 28 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/mux.py | 470 | ||||
| -rwxr-xr-x | pyload/lib/mod_pywebsocket/standalone.py | 249 | ||||
| -rw-r--r-- | pyload/lib/mod_pywebsocket/util.py | 43 | 
12 files changed, 947 insertions, 362 deletions
| diff --git a/pyload/lib/mod_pywebsocket/_stream_base.py b/pyload/lib/mod_pywebsocket/_stream_base.py index 60fb33d2c..8235666bb 100644 --- a/pyload/lib/mod_pywebsocket/_stream_base.py +++ b/pyload/lib/mod_pywebsocket/_stream_base.py @@ -39,6 +39,8 @@  # writing/reading. +import socket +  from mod_pywebsocket import util @@ -109,20 +111,34 @@ class StreamBase(object):              ConnectionTerminatedException: when read returns empty string.          """ -        bytes = self._request.connection.read(length) -        if not bytes: +        try: +            read_bytes = self._request.connection.read(length) +            if not read_bytes: +                raise ConnectionTerminatedException( +                    'Receiving %d byte failed. Peer (%r) closed connection' % +                    (length, (self._request.connection.remote_addr,))) +            return read_bytes +        except socket.error, e: +            # Catch a socket.error. Because it's not a child class of the +            # IOError prior to Python 2.6, we cannot omit this except clause. +            # Use %s rather than %r for the exception to use human friendly +            # format. +            raise ConnectionTerminatedException( +                'Receiving %d byte failed. socket.error (%s) occurred' % +                (length, e)) +        except IOError, e: +            # Also catch an IOError because mod_python throws it.              raise ConnectionTerminatedException( -                'Receiving %d byte failed. Peer (%r) closed connection' % -                (length, (self._request.connection.remote_addr,))) -        return bytes +                'Receiving %d byte failed. IOError (%s) occurred' % +                (length, e)) -    def _write(self, bytes): +    def _write(self, bytes_to_write):          """Writes given bytes to connection. In case we catch any exception,          prepends remote address to the exception message and raise again.          """          try: -            self._request.connection.write(bytes) +            self._request.connection.write(bytes_to_write)          except Exception, e:              util.prepend_message_to_exception(                      'Failed to send message to %r: ' % @@ -138,12 +154,12 @@ class StreamBase(object):              ConnectionTerminatedException: when read returns empty string.          """ -        bytes = [] +        read_bytes = []          while length > 0: -            new_bytes = self._read(length) -            bytes.append(new_bytes) -            length -= len(new_bytes) -        return ''.join(bytes) +            new_read_bytes = self._read(length) +            read_bytes.append(new_read_bytes) +            length -= len(new_read_bytes) +        return ''.join(read_bytes)      def _read_until(self, delim_char):          """Reads bytes until we encounter delim_char. The result will not @@ -153,13 +169,13 @@ class StreamBase(object):              ConnectionTerminatedException: when read returns empty string.          """ -        bytes = [] +        read_bytes = []          while True:              ch = self._read(1)              if ch == delim_char:                  break -            bytes.append(ch) -        return ''.join(bytes) +            read_bytes.append(ch) +        return ''.join(read_bytes)  # vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/_stream_hybi.py b/pyload/lib/mod_pywebsocket/_stream_hybi.py index bd158fa6b..1c43249a4 100644 --- a/pyload/lib/mod_pywebsocket/_stream_hybi.py +++ b/pyload/lib/mod_pywebsocket/_stream_hybi.py @@ -280,7 +280,7 @@ def parse_frame(receive_bytes, logger=None,      if logger.isEnabledFor(common.LOGLEVEL_FINE):          unmask_start = time.time() -    bytes = masker.mask(raw_payload_bytes) +    unmasked_bytes = masker.mask(raw_payload_bytes)      if logger.isEnabledFor(common.LOGLEVEL_FINE):          logger.log( @@ -288,7 +288,7 @@ def parse_frame(receive_bytes, logger=None,              'Done unmasking payload data at %s MB/s',              payload_length / (time.time() - unmask_start) / 1000 / 1000) -    return opcode, bytes, fin, rsv1, rsv2, rsv3 +    return opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3  class FragmentedFrameBuilder(object): @@ -403,9 +403,6 @@ class StreamOptions(object):          self.encode_text_message_to_utf8 = True          self.mask_send = False          self.unmask_receive = True -        # RFC6455 disallows fragmented control frames, but mux extension -        # relaxes the restriction. -        self.allow_fragmented_control_frame = False  class Stream(StreamBase): @@ -463,10 +460,10 @@ class Stream(StreamBase):                             unmask_receive=self._options.unmask_receive)      def _receive_frame_as_frame_object(self): -        opcode, bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame() +        opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame()          return Frame(fin=fin, rsv1=rsv1, rsv2=rsv2, rsv3=rsv3, -                     opcode=opcode, payload=bytes) +                     opcode=opcode, payload=unmasked_bytes)      def receive_filtered_frame(self):          """Receives a frame and applies frame filters and message filters. @@ -602,8 +599,7 @@ class Stream(StreamBase):              else:                  # Start of fragmentation frame -                if (not self._options.allow_fragmented_control_frame and -                    common.is_control_opcode(frame.opcode)): +                if common.is_control_opcode(frame.opcode):                      raise InvalidFrameException(                          'Control frames must not be fragmented') @@ -672,7 +668,7 @@ class Stream(StreamBase):                  reason = ''          self._send_closing_handshake(code, reason)          self._logger.debug( -            'Sent ack for client-initiated closing handshake ' +            'Acknowledged closing handshake initiated by the peer '              '(code=%r, reason=%r)', code, reason)      def _process_ping_message(self, message): @@ -815,13 +811,15 @@ class Stream(StreamBase):          self._write(frame) -    def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''): +    def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason='', +                         wait_response=True):          """Closes a WebSocket connection.          Args:              code: Status code for close frame. If code is None, a close                  frame with empty body will be sent.              reason: string representing close reason. +            wait_response: True when caller want to wait the response.          Raises:              BadOperationException: when reason is specified with code None              or reason is not an instance of both str and unicode. @@ -844,11 +842,11 @@ class Stream(StreamBase):          self._send_closing_handshake(code, reason)          self._logger.debug( -            'Sent server-initiated closing handshake (code=%r, reason=%r)', +            'Initiated closing handshake (code=%r, reason=%r)',              code, reason)          if (code == common.STATUS_GOING_AWAY or -            code == common.STATUS_PROTOCOL_ERROR): +            code == common.STATUS_PROTOCOL_ERROR) or not wait_response:              # It doesn't make sense to wait for a close frame if the reason is              # protocol error or that the server is going away. For some of              # other reasons, it might not make sense to wait for a close frame, diff --git a/pyload/lib/mod_pywebsocket/deflate_stream_extension.py b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py new file mode 100644 index 000000000..d2ba477c4 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py @@ -0,0 +1,69 @@ +# Copyright 2013, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +#     * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +#     * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +#     * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from mod_pywebsocket import common +from mod_pywebsocket.extensions import _available_processors +from mod_pywebsocket.extensions import ExtensionProcessorInterface +from mod_pywebsocket import util + + +class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): +    """WebSocket DEFLATE stream extension processor. + +    Specification: +    Section 9.2.1 in +    http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 +    """ + +    def __init__(self, request): +        ExtensionProcessorInterface.__init__(self, request) +        self._logger = util.get_class_logger(self) + +    def name(self): +        return common.DEFLATE_STREAM_EXTENSION + +    def _get_extension_response_internal(self): +        if len(self._request.get_parameter_names()) != 0: +            return None + +        self._logger.debug( +            'Enable %s extension', common.DEFLATE_STREAM_EXTENSION) + +        return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION) + +    def _setup_stream_options_internal(self, stream_options): +        stream_options.deflate_stream = True + + +_available_processors[common.DEFLATE_STREAM_EXTENSION] = ( +    DeflateStreamExtensionProcessor) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/dispatch.py b/pyload/lib/mod_pywebsocket/dispatch.py index 25905f180..96c91e0c9 100644 --- a/pyload/lib/mod_pywebsocket/dispatch.py +++ b/pyload/lib/mod_pywebsocket/dispatch.py @@ -255,6 +255,9 @@ class Dispatcher(object):          try:              do_extra_handshake_(request)          except handshake.AbortedByUserException, e: +            # Re-raise to tell the caller of this function to finish this +            # connection without sending any error. +            self._logger.debug('%s', util.get_stack_trace())              raise          except Exception, e:              util.prepend_message_to_exception( @@ -294,11 +297,12 @@ class Dispatcher(object):                  request.ws_stream.close_connection()          # Catch non-critical exceptions the handler didn't handle.          except handshake.AbortedByUserException, e: -            self._logger.debug('%s', e) +            self._logger.debug('%s', util.get_stack_trace())              raise          except msgutil.BadOperationException, e:              self._logger.debug('%s', e) -            request.ws_stream.close_connection(common.STATUS_ABNORMAL_CLOSURE) +            request.ws_stream.close_connection( +                common.STATUS_INTERNAL_ENDPOINT_ERROR)          except msgutil.InvalidFrameException, e:              # InvalidFrameException must be caught before              # ConnectionTerminatedException that catches InvalidFrameException. @@ -314,6 +318,8 @@ class Dispatcher(object):          except msgutil.ConnectionTerminatedException, e:              self._logger.debug('%s', e)          except Exception, e: +            # Any other exceptions are forwarded to the caller of this +            # function.              util.prepend_message_to_exception(                  '%s raised exception for %s: ' % (                      _TRANSFER_DATA_HANDLER_NAME, request.ws_resource), diff --git a/pyload/lib/mod_pywebsocket/extensions.py b/pyload/lib/mod_pywebsocket/extensions.py index 03dbf9ee1..18841ed92 100644 --- a/pyload/lib/mod_pywebsocket/extensions.py +++ b/pyload/lib/mod_pywebsocket/extensions.py @@ -38,47 +38,42 @@ _available_processors = {}  class ExtensionProcessorInterface(object): -    def name(self): -        return None +    def __init__(self, request): +        self._request = request +        self._active = True -    def get_extension_response(self): +    def request(self): +        return self._request + +    def name(self):          return None -    def setup_stream_options(self, stream_options): +    def check_consistency_with_other_processors(self, processors):          pass +    def set_active(self, active): +        self._active = active -class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): -    """WebSocket DEFLATE stream extension processor. - -    Specification: -    Section 9.2.1 in -    http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 -    """ - -    def __init__(self, request): -        self._logger = util.get_class_logger(self) - -        self._request = request +    def is_active(self): +        return self._active -    def name(self): -        return common.DEFLATE_STREAM_EXTENSION +    def _get_extension_response_internal(self): +        return None      def get_extension_response(self): -        if len(self._request.get_parameter_names()) != 0: -            return None - -        self._logger.debug( -            'Enable %s extension', common.DEFLATE_STREAM_EXTENSION) +        if self._active: +            response = self._get_extension_response_internal() +            if response is None: +                self._active = False +            return response +        return None -        return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION) +    def _setup_stream_options_internal(self, stream_options): +        pass      def setup_stream_options(self, stream_options): -        stream_options.deflate_stream = True - - -_available_processors[common.DEFLATE_STREAM_EXTENSION] = ( -    DeflateStreamExtensionProcessor) +        if self._active: +            self._setup_stream_options_internal(stream_options)  def _log_compression_ratio(logger, original_bytes, total_original_bytes, @@ -109,6 +104,17 @@ def _log_decompression_ratio(logger, received_bytes, total_received_bytes,          (ratio, average_ratio)) +def _validate_window_bits(bits): +    if bits is not None: +        try: +            bits = int(bits) +        except ValueError, e: +            return False +        if bits < 8 or bits > 15: +            return False +    return True + +  class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):      """WebSocket Per-frame DEFLATE extension processor. @@ -120,10 +126,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):      _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'      def __init__(self, request): +        ExtensionProcessorInterface.__init__(self, request)          self._logger = util.get_class_logger(self) -        self._request = request -          self._response_window_bits = None          self._response_no_context_takeover = False          self._bfinal = False @@ -143,7 +148,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):      def name(self):          return common.DEFLATE_FRAME_EXTENSION -    def get_extension_response(self): +    def _get_extension_response_internal(self):          # Any unknown parameter will be just ignored.          window_bits = self._request.get_parameter_value( @@ -155,13 +160,8 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):                  self._NO_CONTEXT_TAKEOVER_PARAM) is not None):              return None -        if window_bits is not None: -            try: -                window_bits = int(window_bits) -            except ValueError, e: -                return None -            if window_bits < 8 or window_bits > 15: -                return None +        if not _validate_window_bits(window_bits): +            return None          self._deflater = util._RFC1979Deflater(              window_bits, no_context_takeover) @@ -191,7 +191,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):          return response -    def setup_stream_options(self, stream_options): +    def _setup_stream_options_internal(self, stream_options):          class _OutgoingFilter(object): @@ -311,8 +311,8 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):      _METHOD_PARAM = 'method'      def __init__(self, request): +        ExtensionProcessorInterface.__init__(self, request)          self._logger = util.get_class_logger(self) -        self._request = request          self._compression_method_name = None          self._compression_processor = None          self._compression_processor_hook = None @@ -357,7 +357,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):          self._compression_processor = compression_processor          return processor_response -    def get_extension_response(self): +    def _get_extension_response_internal(self):          processor_response = self._get_compression_processor_response()          if processor_response is None:              return None @@ -372,7 +372,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):              (self._request.name(), self._compression_method_name))          return response -    def setup_stream_options(self, stream_options): +    def _setup_stream_options_internal(self, stream_options):          if self._compression_processor is None:              return          self._compression_processor.setup_stream_options(stream_options) @@ -418,7 +418,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):      _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover'      def __init__(self, request): -        self._request = request +        ExtensionProcessorInterface.__init__(self, request)          self._logger = util.get_class_logger(self)          self._c2s_max_window_bits = None @@ -445,18 +445,13 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):      def name(self):          return 'deflate' -    def get_extension_response(self): +    def _get_extension_response_internal(self):          # Any unknown parameter will be just ignored.          s2c_max_window_bits = self._request.get_parameter_value(              self._S2C_MAX_WINDOW_BITS_PARAM) -        if s2c_max_window_bits is not None: -            try: -                s2c_max_window_bits = int(s2c_max_window_bits) -            except ValueError, e: -                return None -            if s2c_max_window_bits < 8 or s2c_max_window_bits > 15: -                return None +        if not _validate_window_bits(s2c_max_window_bits): +            return None          s2c_no_context_takeover = self._request.has_parameter(              self._S2C_NO_CONTEXT_TAKEOVER_PARAM) @@ -502,7 +497,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):          return response -    def setup_stream_options(self, stream_options): +    def _setup_stream_options_internal(self, stream_options):          class _OutgoingMessageFilter(object):              def __init__(self, parent): @@ -676,42 +671,72 @@ class MuxExtensionProcessor(ExtensionProcessorInterface):      _QUOTA_PARAM = 'quota'      def __init__(self, request): -        self._request = request +        ExtensionProcessorInterface.__init__(self, request) +        self._quota = 0 +        self._extensions = []      def name(self):          return common.MUX_EXTENSION -    def get_extension_response(self, ws_request, -                               logical_channel_extensions): -        # Mux extension cannot be used after extensions that depend on -        # frame boundary, extension data field, or any reserved bits -        # which are attributed to each frame. -        for extension in logical_channel_extensions: -            name = extension.name() -            if (name == common.PERFRAME_COMPRESSION_EXTENSION or -                name == common.DEFLATE_FRAME_EXTENSION or -                name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): -                return None - +    def check_consistency_with_other_processors(self, processors): +        before_mux = True +        for processor in processors: +            name = processor.name() +            if name == self.name(): +                before_mux = False +                continue +            if not processor.is_active(): +                continue +            if before_mux: +                # Mux extension cannot be used after extensions +                # that depend on frame boundary, extension data field, or any +                # reserved bits which are attributed to each frame. +                if (name == common.PERFRAME_COMPRESSION_EXTENSION or +                    name == common.DEFLATE_FRAME_EXTENSION or +                    name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): +                    self.set_active(False) +                    return +            else: +                # Mux extension should not be applied before any history-based +                # compression extension. +                if (name == common.PERFRAME_COMPRESSION_EXTENSION or +                    name == common.DEFLATE_FRAME_EXTENSION or +                    name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION or +                    name == common.PERMESSAGE_COMPRESSION_EXTENSION or +                    name == common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION): +                    self.set_active(False) +                    return + +    def _get_extension_response_internal(self): +        self._active = False          quota = self._request.get_parameter_value(self._QUOTA_PARAM) -        if quota is None: -            ws_request.mux_quota = 0 -        else: +        if quota is not None:              try:                  quota = int(quota)              except ValueError, e:                  return None              if quota < 0 or quota >= 2 ** 32:                  return None -            ws_request.mux_quota = quota +            self._quota = quota -        ws_request.mux = True -        ws_request.mux_extensions = logical_channel_extensions +        self._active = True          return common.ExtensionParameter(common.MUX_EXTENSION) -    def setup_stream_options(self, stream_options): +    def _setup_stream_options_internal(self, stream_options):          pass +    def set_quota(self, quota): +        self._quota = quota + +    def quota(self): +        return self._quota + +    def set_extensions(self, extensions): +        self._extensions = extensions + +    def extensions(self): +        return self._extensions +  _available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor diff --git a/pyload/lib/mod_pywebsocket/handshake/_base.py b/pyload/lib/mod_pywebsocket/handshake/_base.py index e5c94ca90..c993a584b 100644 --- a/pyload/lib/mod_pywebsocket/handshake/_base.py +++ b/pyload/lib/mod_pywebsocket/handshake/_base.py @@ -84,42 +84,29 @@ def get_default_port(is_secure):          return common.DEFAULT_WEB_SOCKET_PORT -def validate_subprotocol(subprotocol, hixie): +def validate_subprotocol(subprotocol):      """Validate a value in the Sec-WebSocket-Protocol field. -    See -    - RFC 6455: Section 4.1., 4.2.2., and 4.3. -    - HyBi 00: Section 4.1. Opening handshake - -    Args: -         hixie: if True, checks if characters in subprotocol are in range -                between U+0020 and U+007E. It's required by HyBi 00 but not by -                RFC 6455. +    See the Section 4.1., 4.2.2., and 4.3. of RFC 6455.      """      if not subprotocol:          raise HandshakeException('Invalid subprotocol name: empty') -    if hixie: -        # Parameter should be in the range U+0020 to U+007E. -        for c in subprotocol: -            if not 0x20 <= ord(c) <= 0x7e: -                raise HandshakeException( -                    'Illegal character in subprotocol name: %r' % c) -    else: -        # Parameter should be encoded HTTP token. -        state = http_header_util.ParsingState(subprotocol) -        token = http_header_util.consume_token(state) -        rest = http_header_util.peek(state) -        # If |rest| is not None, |subprotocol| is not one token or invalid. If -        # |rest| is None, |token| must not be None because |subprotocol| is -        # concatenation of |token| and |rest| and is not None. -        if rest is not None: -            raise HandshakeException('Invalid non-token string in subprotocol ' -                                     'name: %r' % rest) + +    # Parameter should be encoded HTTP token. +    state = http_header_util.ParsingState(subprotocol) +    token = http_header_util.consume_token(state) +    rest = http_header_util.peek(state) +    # If |rest| is not None, |subprotocol| is not one token or invalid. If +    # |rest| is None, |token| must not be None because |subprotocol| is +    # concatenation of |token| and |rest| and is not None. +    if rest is not None: +        raise HandshakeException('Invalid non-token string in subprotocol ' +                                 'name: %r' % rest)  def parse_host_header(request): -    fields = request.headers_in['Host'].split(':', 1) +    fields = request.headers_in[common.HOST_HEADER].split(':', 1)      if len(fields) == 1:          return fields[0], get_default_port(request.is_https())      try: @@ -132,27 +119,6 @@ def format_header(name, value):      return '%s: %s\r\n' % (name, value) -def build_location(request): -    """Build WebSocket location for request.""" -    location_parts = [] -    if request.is_https(): -        location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) -    else: -        location_parts.append(common.WEB_SOCKET_SCHEME) -    location_parts.append('://') -    host, port = parse_host_header(request) -    connection_port = request.connection.local_addr[1] -    if port != connection_port: -        raise HandshakeException('Header/connection port mismatch: %d/%d' % -                                 (port, connection_port)) -    location_parts.append(host) -    if (port != get_default_port(request.is_https())): -        location_parts.append(':') -        location_parts.append(str(port)) -    location_parts.append(request.uri) -    return ''.join(location_parts) - -  def get_mandatory_header(request, key):      value = request.headers_in.get(key)      if value is None: @@ -180,16 +146,6 @@ def check_request_line(request):                                   request.protocol) -def check_header_lines(request, mandatory_headers): -    check_request_line(request) - -    # The expected field names, and the meaning of their corresponding -    # values, are as follows. -    #  |Upgrade| and |Connection| -    for key, expected_value in mandatory_headers: -        validate_mandatory_header(request, key, expected_value) - -  def parse_token_list(data):      """Parses a header value which follows 1#token and returns parsed elements      as a list of strings. diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi.py b/pyload/lib/mod_pywebsocket/handshake/hybi.py index fc0e2a096..669097d77 100644 --- a/pyload/lib/mod_pywebsocket/handshake/hybi.py +++ b/pyload/lib/mod_pywebsocket/handshake/hybi.py @@ -48,6 +48,7 @@ import os  import re  from mod_pywebsocket import common +from mod_pywebsocket import deflate_stream_extension  from mod_pywebsocket.extensions import get_extension_processor  from mod_pywebsocket.handshake._base import check_request_line  from mod_pywebsocket.handshake._base import format_header @@ -180,44 +181,57 @@ class Handshaker(object):                          processors.append(processor)              self._request.ws_extension_processors = processors +            # List of extra headers. The extra handshake handler may add header +            # data as name/value pairs to this list and pywebsocket appends +            # them to the WebSocket handshake. +            self._request.extra_headers = [] +              # Extra handshake handler may modify/remove processors.              self._dispatcher.do_extra_handshake(self._request)              processors = filter(lambda processor: processor is not None,                                  self._request.ws_extension_processors) +            # Ask each processor if there are extensions on the request which +            # cannot co-exist. When processor decided other processors cannot +            # co-exist with it, the processor marks them (or itself) as +            # "inactive". The first extension processor has the right to +            # make the final call. +            for processor in reversed(processors): +                if processor.is_active(): +                    processor.check_consistency_with_other_processors( +                        processors) +            processors = filter(lambda processor: processor.is_active(), +                                processors) +              accepted_extensions = [] -            # We need to take care of mux extension here. Extensions that -            # are placed before mux should be applied to logical channels. +            # We need to take into account of mux extension here. +            # If mux extension exists: +            # - Remove processors of extensions for logical channel, +            #   which are processors located before the mux processor +            # - Pass extension requests for logical channel to mux processor +            # - Attach the mux processor to the request. It will be referred +            #   by dispatcher to see whether the dispatcher should use mux +            #   handler or not.              mux_index = -1              for i, processor in enumerate(processors):                  if processor.name() == common.MUX_EXTENSION:                      mux_index = i                      break              if mux_index >= 0: -                mux_processor = processors[mux_index] -                logical_channel_processors = processors[:mux_index] -                processors = processors[mux_index+1:] - -                for processor in logical_channel_processors: -                    extension_response = processor.get_extension_response() -                    if extension_response is None: -                        # Rejected. -                        continue -                    accepted_extensions.append(extension_response) -                # Pass a shallow copy of accepted_extensions as extensions for -                # logical channels. -                mux_response = mux_processor.get_extension_response( -                    self._request, accepted_extensions[:]) -                if mux_response is not None: -                    accepted_extensions.append(mux_response) +                logical_channel_extensions = [] +                for processor in processors[:mux_index]: +                    logical_channel_extensions.append(processor.request()) +                    processor.set_active(False) +                self._request.mux_processor = processors[mux_index] +                self._request.mux_processor.set_extensions( +                    logical_channel_extensions) +                processors = filter(lambda processor: processor.is_active(), +                                    processors)              stream_options = StreamOptions() -            # When there is mux extension, here, |processors| contain only -            # prosessors for extensions placed after mux.              for processor in processors: -                  extension_response = processor.get_extension_response()                  if extension_response is None:                      # Rejected. @@ -242,7 +256,7 @@ class Handshaker(object):                      raise HandshakeException(                          'do_extra_handshake must choose one subprotocol from '                          'ws_requested_protocols and set it to ws_protocol') -                validate_subprotocol(self._request.ws_protocol, hixie=False) +                validate_subprotocol(self._request.ws_protocol)                  self._logger.debug(                      'Subprotocol accepted: %r', @@ -375,6 +389,7 @@ class Handshaker(object):          response.append('HTTP/1.1 101 Switching Protocols\r\n') +        # WebSocket headers          response.append(format_header(              common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))          response.append(format_header( @@ -390,6 +405,11 @@ class Handshaker(object):              response.append(format_header(                  common.SEC_WEBSOCKET_EXTENSIONS_HEADER,                  common.format_extensions(self._request.ws_extensions))) + +        # Headers not specific for WebSocket +        for name, value in self._request.extra_headers: +            response.append(format_header(name, value)) +          response.append('\r\n')          return ''.join(response) diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi00.py b/pyload/lib/mod_pywebsocket/handshake/hybi00.py index cc6f8dc43..8757717a6 100644 --- a/pyload/lib/mod_pywebsocket/handshake/hybi00.py +++ b/pyload/lib/mod_pywebsocket/handshake/hybi00.py @@ -51,11 +51,12 @@ from mod_pywebsocket import common  from mod_pywebsocket.stream import StreamHixie75  from mod_pywebsocket import util  from mod_pywebsocket.handshake._base import HandshakeException -from mod_pywebsocket.handshake._base import build_location -from mod_pywebsocket.handshake._base import check_header_lines +from mod_pywebsocket.handshake._base import check_request_line  from mod_pywebsocket.handshake._base import format_header +from mod_pywebsocket.handshake._base import get_default_port  from mod_pywebsocket.handshake._base import get_mandatory_header -from mod_pywebsocket.handshake._base import validate_subprotocol +from mod_pywebsocket.handshake._base import parse_host_header +from mod_pywebsocket.handshake._base import validate_mandatory_header  _MANDATORY_HEADERS = [ @@ -65,6 +66,56 @@ _MANDATORY_HEADERS = [  ] +def _validate_subprotocol(subprotocol): +    """Checks if characters in subprotocol are in range between U+0020 and +    U+007E. A value in the Sec-WebSocket-Protocol field need to satisfy this +    requirement. + +    See the Section 4.1. Opening handshake of the spec. +    """ + +    if not subprotocol: +        raise HandshakeException('Invalid subprotocol name: empty') + +    # Parameter should be in the range U+0020 to U+007E. +    for c in subprotocol: +        if not 0x20 <= ord(c) <= 0x7e: +            raise HandshakeException( +                'Illegal character in subprotocol name: %r' % c) + + +def _check_header_lines(request, mandatory_headers): +    check_request_line(request) + +    # The expected field names, and the meaning of their corresponding +    # values, are as follows. +    #  |Upgrade| and |Connection| +    for key, expected_value in mandatory_headers: +        validate_mandatory_header(request, key, expected_value) + + +def _build_location(request): +    """Build WebSocket location for request.""" + +    location_parts = [] +    if request.is_https(): +        location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) +    else: +        location_parts.append(common.WEB_SOCKET_SCHEME) +    location_parts.append('://') +    host, port = parse_host_header(request) +    connection_port = request.connection.local_addr[1] +    if port != connection_port: +        raise HandshakeException('Header/connection port mismatch: %d/%d' % +                                 (port, connection_port)) +    location_parts.append(host) +    if (port != get_default_port(request.is_https())): +        location_parts.append(':') +        location_parts.append(str(port)) +    location_parts.append(request.unparsed_uri) +    return ''.join(location_parts) + +  class Handshaker(object):      """Opening handshake processor for the WebSocket protocol version HyBi 00.      """ @@ -101,7 +152,7 @@ class Handshaker(object):          # 5.1 Reading the client's opening handshake.          # dispatcher sets it in self._request. -        check_header_lines(self._request, _MANDATORY_HEADERS) +        _check_header_lines(self._request, _MANDATORY_HEADERS)          self._set_resource()          self._set_subprotocol()          self._set_location() @@ -121,14 +172,14 @@ class Handshaker(object):          subprotocol = self._request.headers_in.get(              common.SEC_WEBSOCKET_PROTOCOL_HEADER)          if subprotocol is not None: -            validate_subprotocol(subprotocol, hixie=True) +            _validate_subprotocol(subprotocol)          self._request.ws_protocol = subprotocol      def _set_location(self):          # |Host|          host = self._request.headers_in.get(common.HOST_HEADER)          if host is not None: -            self._request.ws_location = build_location(self._request) +            self._request.ws_location = _build_location(self._request)          # TODO(ukai): check host is this host.      def _set_origin(self): diff --git a/pyload/lib/mod_pywebsocket/headerparserhandler.py b/pyload/lib/mod_pywebsocket/headerparserhandler.py index 2cc62de04..c244421cf 100644 --- a/pyload/lib/mod_pywebsocket/headerparserhandler.py +++ b/pyload/lib/mod_pywebsocket/headerparserhandler.py @@ -167,7 +167,9 @@ def _create_dispatcher():          handler_root, handler_scan, allow_handlers_outside_root)      for warning in dispatcher.source_warnings(): -        apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING) +        apache.log_error( +            'mod_pywebsocket: Warning in source loading: %s' % warning, +            apache.APLOG_WARNING)      return dispatcher @@ -191,12 +193,16 @@ def headerparserhandler(request):          # Fallback to default http handler for request paths for which          # we don't have request handlers.          if not _dispatcher.get_handler_suite(request.uri): -            request.log_error('No handler for resource: %r' % request.uri, -                              apache.APLOG_INFO) -            request.log_error('Fallback to Apache', apache.APLOG_INFO) +            request.log_error( +                'mod_pywebsocket: No handler for resource: %r' % request.uri, +                apache.APLOG_INFO) +            request.log_error( +                'mod_pywebsocket: Fallback to Apache', apache.APLOG_INFO)              return apache.DECLINED      except dispatch.DispatchException, e: -        request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) +        request.log_error( +            'mod_pywebsocket: Dispatch failed for error: %s' % e, +            apache.APLOG_INFO)          if not handshake_is_done:              return e.status @@ -210,26 +216,30 @@ def headerparserhandler(request):              handshake.do_handshake(                  request, _dispatcher, allowDraft75=allow_draft75)          except handshake.VersionException, e: -            request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) +            request.log_error( +                'mod_pywebsocket: Handshake failed for version error: %s' % e, +                apache.APLOG_INFO)              request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER,                                          e.supported_versions)              return apache.HTTP_BAD_REQUEST          except handshake.HandshakeException, e:              # Handshake for ws/wss failed.              # Send http response with error status. -            request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) +            request.log_error( +                'mod_pywebsocket: Handshake failed for error: %s' % e, +                apache.APLOG_INFO)              return e.status          handshake_is_done = True          request._dispatcher = _dispatcher          _dispatcher.transfer_data(request)      except handshake.AbortedByUserException, e: -        request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) +        request.log_error('mod_pywebsocket: Aborted: %s' % e, apache.APLOG_INFO)      except Exception, e:          # DispatchException can also be thrown if something is wrong in          # pywebsocket code. It's caught here, then. -        request.log_error('mod_pywebsocket: %s\n%s' % +        request.log_error('mod_pywebsocket: Exception occurred: %s\n%s' %                            (e, util.get_stack_trace()),                            apache.APLOG_ERR)          # Unknown exceptions before handshake mean Apache must handle its diff --git a/pyload/lib/mod_pywebsocket/mux.py b/pyload/lib/mod_pywebsocket/mux.py index f0bdd2461..7923fb211 100644 --- a/pyload/lib/mod_pywebsocket/mux.py +++ b/pyload/lib/mod_pywebsocket/mux.py @@ -50,6 +50,7 @@ from mod_pywebsocket import handshake  from mod_pywebsocket import util  from mod_pywebsocket._stream_base import BadOperationException  from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException  from mod_pywebsocket._stream_hybi import Frame  from mod_pywebsocket._stream_hybi import Stream  from mod_pywebsocket._stream_hybi import StreamOptions @@ -94,10 +95,12 @@ _DROP_CODE_UNKNOWN_MUX_OPCODE = 2004  _DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005  _DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006  _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007 +_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 2010 -_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002  _DROP_CODE_SEND_QUOTA_VIOLATION = 3005 +_DROP_CODE_SEND_QUOTA_OVERFLOW = 3006  _DROP_CODE_ACKNOWLEDGED = 3008 +_DROP_CODE_BAD_FRAGMENTATION = 3009  class MuxUnexpectedException(Exception): @@ -158,8 +161,7 @@ def _encode_number(number):  def _create_add_channel_response(channel_id, encoded_handshake, -                                 encoding=0, rejected=False, -                                 outer_frame_mask=False): +                                 encoding=0, rejected=False):      if encoding != 0 and encoding != 1:          raise ValueError('Invalid encoding %d' % encoding) @@ -169,12 +171,10 @@ def _create_add_channel_response(channel_id, encoded_handshake,               _encode_channel_id(channel_id) +               _encode_number(len(encoded_handshake)) +               encoded_handshake) -    payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block -    return create_binary_frame(payload, mask=outer_frame_mask) +    return block -def _create_drop_channel(channel_id, code=None, message='', -                         outer_frame_mask=False): +def _create_drop_channel(channel_id, code=None, message=''):      if len(message) > 0 and code is None:          raise ValueError('Code must be specified if message is specified') @@ -187,36 +187,31 @@ def _create_drop_channel(channel_id, code=None, message='',          reason_size = _encode_number(len(reason))          block += reason_size + reason -    payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block -    return create_binary_frame(payload, mask=outer_frame_mask) +    return block -def _create_flow_control(channel_id, replenished_quota, -                         outer_frame_mask=False): +def _create_flow_control(channel_id, replenished_quota):      first_byte = _MUX_OPCODE_FLOW_CONTROL << 5      block = (chr(first_byte) +               _encode_channel_id(channel_id) +               _encode_number(replenished_quota)) -    payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block -    return create_binary_frame(payload, mask=outer_frame_mask) +    return block -def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False): +def _create_new_channel_slot(slots, send_quota):      if slots < 0 or send_quota < 0:          raise ValueError('slots and send_quota must be non-negative.')      first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5      block = (chr(first_byte) +               _encode_number(slots) +               _encode_number(send_quota)) -    payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block -    return create_binary_frame(payload, mask=outer_frame_mask) +    return block -def _create_fallback_new_channel_slot(outer_frame_mask=False): +def _create_fallback_new_channel_slot():      first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag      block = (chr(first_byte) + _encode_number(0) + _encode_number(0)) -    payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block -    return create_binary_frame(payload, mask=outer_frame_mask) +    return block  def _parse_request_text(request_text): @@ -318,44 +313,34 @@ class _MuxFramePayloadParser(object):      def _read_number(self):          if self._read_position + 1 > len(self._data): -            raise PhysicalConnectionError( -                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +            raise ValueError(                  'Cannot read the first byte of number field')          number = ord(self._data[self._read_position])          if number & 0x80 == 0x80: -            raise PhysicalConnectionError( -                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +            raise ValueError(                  'The most significant bit of the first byte of number should '                  'be unset')          self._read_position += 1          pos = self._read_position          if number == 127:              if pos + 8 > len(self._data): -                raise PhysicalConnectionError( -                    _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, -                    'Invalid number field') +                raise ValueError('Invalid number field')              self._read_position += 8              number = struct.unpack('!Q', self._data[pos:pos+8])[0]              if number > 0x7FFFFFFFFFFFFFFF: -                raise PhysicalConnectionError( -                    _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, -                    'Encoded number >= 2^63') +                raise ValueError('Encoded number(%d) >= 2^63' % number)              if number <= 0xFFFF: -                raise PhysicalConnectionError( -                    _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +                raise ValueError(                      '%d should not be encoded by 9 bytes encoding' % number)              return number          if number == 126:              if pos + 2 > len(self._data): -                raise PhysicalConnectionError( -                    _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, -                    'Invalid number field') +                raise ValueError('Invalid number field')              self._read_position += 2              number = struct.unpack('!H', self._data[pos:pos+2])[0]              if number <= 125: -                raise PhysicalConnectionError( -                    _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +                raise ValueError(                      '%d should not be encoded by 3 bytes encoding' % number)          return number @@ -366,7 +351,11 @@ class _MuxFramePayloadParser(object):              - the contents.          """ -        size = self._read_number() +        try: +            size = self._read_number() +        except ValueError, e: +            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +                                          str(e))          pos = self._read_position          if pos + size > len(self._data):              raise PhysicalConnectionError( @@ -419,9 +408,11 @@ class _MuxFramePayloadParser(object):          try:              control_block.channel_id = self.read_channel_id() +            control_block.send_quota = self._read_number()          except ValueError, e: -            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) -        control_block.send_quota = self._read_number() +            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +                                          str(e)) +          return control_block      def _read_drop_channel(self, first_byte, control_block): @@ -455,8 +446,12 @@ class _MuxFramePayloadParser(object):                  _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,                  'Reserved bits must be unset')          control_block.fallback = first_byte & 1 -        control_block.slots = self._read_number() -        control_block.send_quota = self._read_number() +        try: +            control_block.slots = self._read_number() +            control_block.send_quota = self._read_number() +        except ValueError, e: +            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, +                                          str(e))          return control_block      def read_control_blocks(self): @@ -549,8 +544,12 @@ class _LogicalConnection(object):          self._mux_handler = mux_handler          self._channel_id = channel_id          self._incoming_data = '' + +        # - Protects _waiting_write_completion +        # - Signals the thread waiting for completion of write by mux handler          self._write_condition = threading.Condition()          self._waiting_write_completion = False +          self._read_condition = threading.Condition()          self._read_state = self.STATE_ACTIVE @@ -594,6 +593,7 @@ class _LogicalConnection(object):              self._waiting_write_completion = True              self._mux_handler.send_data(self._channel_id, data)              self._write_condition.wait() +            # TODO(tyoshino): Raise an exception if woke up by on_writer_done.          finally:              self._write_condition.release() @@ -607,20 +607,31 @@ class _LogicalConnection(object):          self._mux_handler.send_control_data(data) -    def notify_write_done(self): +    def on_write_data_done(self):          """Called when sending data is completed."""          try:              self._write_condition.acquire()              if not self._waiting_write_completion:                  raise MuxUnexpectedException( -                    'Invalid call of notify_write_done for logical connection' -                    ' %d' % self._channel_id) +                    'Invalid call of on_write_data_done for logical ' +                    'connection %d' % self._channel_id) +            self._waiting_write_completion = False +            self._write_condition.notify() +        finally: +            self._write_condition.release() + +    def on_writer_done(self): +        """Called by the mux handler when the writer thread has finished.""" + +        try: +            self._write_condition.acquire()              self._waiting_write_completion = False              self._write_condition.notify()          finally:              self._write_condition.release() +      def append_frame_data(self, frame_data):          """Appends incoming frame data. Called when mux_handler dispatches          frame data to the corresponding application. @@ -686,37 +697,162 @@ class _LogicalConnection(object):          self._read_condition.release() +class _InnerMessage(object): +    """Holds the result of _InnerMessageBuilder.build(). +    """ + +    def __init__(self, opcode, payload): +        self.opcode = opcode +        self.payload = payload + + +class _InnerMessageBuilder(object): +    """A class that holds the context of inner message fragmentation and +    builds a message from fragmented inner frame(s). +    """ + +    def __init__(self): +        self._control_opcode = None +        self._pending_control_fragments = [] +        self._message_opcode = None +        self._pending_message_fragments = [] +        self._frame_handler = self._handle_first + +    def _handle_first(self, frame): +        if frame.opcode == common.OPCODE_CONTINUATION: +            raise InvalidFrameException('Sending invalid continuation opcode') + +        if common.is_control_opcode(frame.opcode): +            return self._process_first_fragmented_control(frame) +        else: +            return self._process_first_fragmented_message(frame) + +    def _process_first_fragmented_control(self, frame): +        self._control_opcode = frame.opcode +        self._pending_control_fragments.append(frame.payload) +        if not frame.fin: +            self._frame_handler = self._handle_fragmented_control +            return None +        return self._reassemble_fragmented_control() + +    def _process_first_fragmented_message(self, frame): +        self._message_opcode = frame.opcode +        self._pending_message_fragments.append(frame.payload) +        if not frame.fin: +            self._frame_handler = self._handle_fragmented_message +            return None +        return self._reassemble_fragmented_message() + +    def _handle_fragmented_control(self, frame): +        if frame.opcode != common.OPCODE_CONTINUATION: +            raise InvalidFrameException( +                'Sending invalid opcode %d while sending fragmented control ' +                'message' % frame.opcode) +        self._pending_control_fragments.append(frame.payload) +        if not frame.fin: +            return None +        return self._reassemble_fragmented_control() + +    def _reassemble_fragmented_control(self): +        opcode = self._control_opcode +        payload = ''.join(self._pending_control_fragments) +        self._control_opcode = None +        self._pending_control_fragments = [] +        if self._message_opcode is not None: +            self._frame_handler = self._handle_fragmented_message +        else: +            self._frame_handler = self._handle_first +        return _InnerMessage(opcode, payload) + +    def _handle_fragmented_message(self, frame): +        # Sender can interleave a control message while sending fragmented +        # messages. +        if common.is_control_opcode(frame.opcode): +            if self._control_opcode is not None: +                raise MuxUnexpectedException( +                    'Should not reach here(Bug in builder)') +            return self._process_first_fragmented_control(frame) + +        if frame.opcode != common.OPCODE_CONTINUATION: +            raise InvalidFrameException( +                'Sending invalid opcode %d while sending fragmented message' % +                frame.opcode) +        self._pending_message_fragments.append(frame.payload) +        if not frame.fin: +            return None +        return self._reassemble_fragmented_message() + +    def _reassemble_fragmented_message(self): +        opcode = self._message_opcode +        payload = ''.join(self._pending_message_fragments) +        self._message_opcode = None +        self._pending_message_fragments = [] +        self._frame_handler = self._handle_first +        return _InnerMessage(opcode, payload) + +    def build(self, frame): +        """Build an inner message. Returns an _InnerMessage instance when +        the given frame is the last fragmented frame. Returns None otherwise. + +        Args: +            frame: an inner frame. +        Raises: +            InvalidFrameException: when received invalid opcode. (e.g. +                receiving non continuation data opcode but the fin flag of +                the previous inner frame was not set.) +        """ + +        return self._frame_handler(frame) + +  class _LogicalStream(Stream):      """Mimics the Stream class. This class interprets multiplexed WebSocket      frames.      """ -    def __init__(self, request, send_quota, receive_quota): +    def __init__(self, request, stream_options, send_quota, receive_quota):          """Constructs an instance.          Args:              request: _LogicalRequest instance. +            stream_options: StreamOptions instance.              send_quota: Initial send quota.              receive_quota: Initial receive quota.          """ -        # TODO(bashi): Support frame filters. -        stream_options = StreamOptions()          # Physical stream is responsible for masking.          stream_options.unmask_receive = False -        # Control frames can be fragmented on logical channel. -        stream_options.allow_fragmented_control_frame = True          Stream.__init__(self, request, stream_options) + +        self._send_closed = False          self._send_quota = send_quota -        self._send_quota_condition = threading.Condition() +        # - Protects _send_closed and _send_quota +        # - Signals the thread waiting for send quota replenished +        self._send_condition = threading.Condition() + +        # The opcode of the first frame in messages. +        self._message_opcode = common.OPCODE_TEXT +        # True when the last message was fragmented. +        self._last_message_was_fragmented = False +          self._receive_quota = receive_quota          self._write_inner_frame_semaphore = threading.Semaphore() +        self._inner_message_builder = _InnerMessageBuilder() +      def _create_inner_frame(self, opcode, payload, end=True): -        # TODO(bashi): Support extensions that use reserved bits. -        first_byte = (end << 7) | opcode -        return (_encode_channel_id(self._request.channel_id) + -                chr(first_byte) + payload) +        frame = Frame(fin=end, opcode=opcode, payload=payload) +        for frame_filter in self._options.outgoing_frame_filters: +            frame_filter.filter(frame) + +        if len(payload) != len(frame.payload): +            raise MuxUnexpectedException( +                'Mux extension must not be used after extensions which change ' +                ' frame boundary') + +        first_byte = ((frame.fin << 7) | (frame.rsv1 << 6) | +                      (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode) +        return chr(first_byte) + frame.payload      def _write_inner_frame(self, opcode, payload, end=True):          payload_length = len(payload) @@ -730,14 +866,36 @@ class _LogicalStream(Stream):              # multiplexing control blocks can be inserted between fragmented              # inner frames on the physical channel.              self._write_inner_frame_semaphore.acquire() + +            # Consume an octet quota when this is the first fragmented frame. +            if opcode != common.OPCODE_CONTINUATION: +                try: +                    self._send_condition.acquire() +                    while (not self._send_closed) and self._send_quota == 0: +                        self._send_condition.wait() + +                    if self._send_closed: +                        raise BadOperationException( +                            'Logical connection %d is closed' % +                            self._request.channel_id) + +                    self._send_quota -= 1 +                finally: +                    self._send_condition.release() +              while write_position < payload_length:                  try: -                    self._send_quota_condition.acquire() -                    while self._send_quota == 0: +                    self._send_condition.acquire() +                    while (not self._send_closed) and self._send_quota == 0:                          self._logger.debug(                              'No quota. Waiting FlowControl message for %d.' %                              self._request.channel_id) -                        self._send_quota_condition.wait() +                        self._send_condition.wait() + +                    if self._send_closed: +                        raise BadOperationException( +                            'Logical connection %d is closed' % +                            self.request._channel_id)                      remaining = payload_length - write_position                      write_length = min(self._send_quota, remaining) @@ -749,18 +907,16 @@ class _LogicalStream(Stream):                          opcode,                          payload[write_position:write_position+write_length],                          inner_frame_end) -                    frame_data = self._writer.build( -                        inner_frame, end=True, binary=True)                      self._send_quota -= write_length                      self._logger.debug('Consumed quota=%d, remaining=%d' %                                         (write_length, self._send_quota))                  finally: -                    self._send_quota_condition.release() +                    self._send_condition.release()                  # Writing data will block the worker so we need to release -                # _send_quota_condition before writing. -                self._logger.debug('Sending inner frame: %r' % frame_data) -                self._request.connection.write(frame_data) +                # _send_condition before writing. +                self._logger.debug('Sending inner frame: %r' % inner_frame) +                self._request.connection.write(inner_frame)                  write_position += write_length                  opcode = common.OPCODE_CONTINUATION @@ -773,12 +929,18 @@ class _LogicalStream(Stream):      def replenish_send_quota(self, send_quota):          """Replenish send quota.""" -        self._send_quota_condition.acquire() -        self._send_quota += send_quota -        self._logger.debug('Replenished send quota for channel id %d: %d' % -                           (self._request.channel_id, self._send_quota)) -        self._send_quota_condition.notify() -        self._send_quota_condition.release() +        try: +            self._send_condition.acquire() +            if self._send_quota + send_quota > 0x7FFFFFFFFFFFFFFF: +                self._send_quota = 0 +                raise LogicalChannelError( +                    self._request.channel_id, _DROP_CODE_SEND_QUOTA_OVERFLOW) +            self._send_quota += send_quota +            self._logger.debug('Replenished send quota for channel id %d: %d' % +                               (self._request.channel_id, self._send_quota)) +        finally: +            self._send_condition.notify() +            self._send_condition.release()      def consume_receive_quota(self, amount):          """Consumes receive quota. Returns False on failure.""" @@ -808,7 +970,19 @@ class _LogicalStream(Stream):              opcode = common.OPCODE_TEXT              message = message.encode('utf-8') +        for message_filter in self._options.outgoing_message_filters: +            message = message_filter.filter(message, end, binary) + +        if self._last_message_was_fragmented: +            if opcode != self._message_opcode: +                raise BadOperationException('Message types are different in ' +                                            'frames for the same message') +            opcode = common.OPCODE_CONTINUATION +        else: +            self._message_opcode = opcode +          self._write_inner_frame(opcode, message, end) +        self._last_message_was_fragmented = not end      def _receive_frame(self):          """Overrides Stream._receive_frame. @@ -821,6 +995,9 @@ class _LogicalStream(Stream):          opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self)          amount = len(payload) +        # Replenish extra one octet when receiving the first fragmented frame. +        if opcode != common.OPCODE_CONTINUATION: +            amount += 1          self._receive_quota += amount          frame_data = _create_flow_control(self._request.channel_id,                                            amount) @@ -829,6 +1006,21 @@ class _LogicalStream(Stream):          self._request.connection.write_control_data(frame_data)          return opcode, payload, fin, rsv1, rsv2, rsv3 +    def _get_message_from_frame(self, frame): +        """Overrides Stream._get_message_from_frame. +        """ + +        try: +            inner_message = self._inner_message_builder.build(frame) +        except InvalidFrameException: +            raise LogicalChannelError( +                self._request.channel_id, _DROP_CODE_BAD_FRAGMENTATION) + +        if inner_message is None: +            return None +        self._original_opcode = inner_message.opcode +        return inner_message.payload +      def receive_message(self):          """Overrides Stream.receive_message.""" @@ -882,6 +1074,14 @@ class _LogicalStream(Stream):          pass +    def stop_sending(self): +        """Stops accepting new send operation (_write_inner_frame).""" + +        self._send_condition.acquire() +        self._send_closed = True +        self._send_condition.notify() +        self._send_condition.release() +  class _OutgoingData(object):      """A structure that holds data to be sent via physical connection and @@ -911,8 +1111,17 @@ class _PhysicalConnectionWriter(threading.Thread):          self._logger = util.get_class_logger(self)          self._mux_handler = mux_handler          self.setDaemon(True) + +        # When set, make this thread stop accepting new data, flush pending +        # data and exit.          self._stop_requested = False +        # The close code of the physical connection. +        self._close_code = common.STATUS_NORMAL_CLOSURE +        # Deque for passing write data. It's protected by _deque_condition +        # until _stop_requested is set.          self._deque = collections.deque() +        # - Protects _deque, _stop_requested and _close_code +        # - Signals threads waiting for them to be available          self._deque_condition = threading.Condition()      def put_outgoing_data(self, data): @@ -937,8 +1146,11 @@ class _PhysicalConnectionWriter(threading.Thread):              self._deque_condition.release()      def _write_data(self, outgoing_data): +        message = (_encode_channel_id(outgoing_data.channel_id) + +                   outgoing_data.data)          try: -            self._mux_handler.physical_connection.write(outgoing_data.data) +            self._mux_handler.physical_stream.send_message( +                message=message, end=True, binary=True)          except Exception, e:              util.prepend_message_to_exception(                  'Failed to send message to %r: ' % @@ -948,33 +1160,51 @@ class _PhysicalConnectionWriter(threading.Thread):          # TODO(bashi): It would be better to block the thread that sends          # control data as well.          if outgoing_data.channel_id != _CONTROL_CHANNEL_ID: -            self._mux_handler.notify_write_done(outgoing_data.channel_id) +            self._mux_handler.notify_write_data_done(outgoing_data.channel_id)      def run(self): -        self._deque_condition.acquire() -        while not self._stop_requested: -            if len(self._deque) == 0: -                self._deque_condition.wait() -                continue - -            outgoing_data = self._deque.popleft() -            self._deque_condition.release() -            self._write_data(outgoing_data) +        try:              self._deque_condition.acquire() +            while not self._stop_requested: +                if len(self._deque) == 0: +                    self._deque_condition.wait() +                    continue -        # Flush deque -        try: -            while len(self._deque) > 0:                  outgoing_data = self._deque.popleft() + +                self._deque_condition.release()                  self._write_data(outgoing_data) +                self._deque_condition.acquire() + +            # Flush deque. +            # +            # At this point, self._deque_condition is always acquired. +            try: +                while len(self._deque) > 0: +                    outgoing_data = self._deque.popleft() +                    self._write_data(outgoing_data) +            finally: +                self._deque_condition.release() + +            # Close physical connection. +            try: +                # Don't wait the response here. The response will be read +                # by the reader thread. +                self._mux_handler.physical_stream.close_connection( +                    self._close_code, wait_response=False) +            except Exception, e: +                util.prepend_message_to_exception( +                    'Failed to close the physical connection: %r' % e) +                raise          finally: -            self._deque_condition.release() +            self._mux_handler.notify_writer_done() -    def stop(self): +    def stop(self, close_code=common.STATUS_NORMAL_CLOSURE):          """Stops the writer thread."""          self._deque_condition.acquire()          self._stop_requested = True +        self._close_code = close_code          self._deque_condition.notify()          self._deque_condition.release() @@ -1055,6 +1285,9 @@ class _Worker(threading.Thread):          try:              # Non-critical exceptions will be handled by dispatcher.              self._mux_handler.dispatcher.transfer_data(self._request) +        except LogicalChannelError, e: +            self._mux_handler.fail_logical_channel( +                e.channel_id, e.drop_code, e.message)          finally:              self._mux_handler.notify_worker_done(self._request.channel_id) @@ -1083,8 +1316,6 @@ class _MuxHandshaker(hybi.Handshaker):          #     these headers are included already.          request.headers_in[common.UPGRADE_HEADER] = (              common.WEBSOCKET_UPGRADE_TYPE) -        request.headers_in[common.CONNECTION_HEADER] = ( -            common.UPGRADE_CONNECTION_TYPE)          request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = (              str(common.VERSION_HYBI_LATEST))          request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = ( @@ -1095,8 +1326,9 @@ class _MuxHandshaker(hybi.Handshaker):          self._logger.debug('Creating logical stream for %d' %                             self._request.channel_id) -        return _LogicalStream(self._request, self._send_quota, -                              self._receive_quota) +        return _LogicalStream( +            self._request, stream_options, self._send_quota, +            self._receive_quota)      def _create_handshake_response(self, accept):          """Override hybi._create_handshake_response.""" @@ -1105,7 +1337,9 @@ class _MuxHandshaker(hybi.Handshaker):          response.append('HTTP/1.1 101 Switching Protocols\r\n') -        # Upgrade, Connection and Sec-WebSocket-Accept should be excluded. +        # Upgrade and Sec-WebSocket-Accept should be excluded. +        response.append('%s: %s\r\n' % ( +            common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))          if self._request.ws_protocol is not None:              response.append('%s: %s\r\n' % (                  common.SEC_WEBSOCKET_PROTOCOL_HEADER, @@ -1169,8 +1403,6 @@ class _HandshakeDeltaBase(object):                      del headers[key]                  else:                      headers[key] = value -        # TODO(bashi): Support extensions -        headers['Sec-WebSocket-Extensions'] = ''          return headers @@ -1232,8 +1464,12 @@ class _MuxHandler(object):          # Create "Implicitly Opened Connection".          logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID) -        self._handshake_base = _HandshakeDeltaBase( -            self.original_request.headers_in) +        headers = copy.copy(self.original_request.headers_in) +        # Add extensions for logical channel. +        headers[common.SEC_WEBSOCKET_EXTENSIONS_HEADER] = ( +            common.format_extensions( +                self.original_request.mux_processor.extensions())) +        self._handshake_base = _HandshakeDeltaBase(headers)          logical_request = _LogicalRequest(              _DEFAULT_CHANNEL_ID,              self.original_request.method, @@ -1245,8 +1481,9 @@ class _MuxHandler(object):          # but we will send FlowControl later so set the initial quota to          # _INITIAL_QUOTA_FOR_CLIENT.          self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT) +        send_quota = self.original_request.mux_processor.quota()          if not self._do_handshake_for_logical_request( -            logical_request, send_quota=self.original_request.mux_quota): +            logical_request, send_quota=send_quota):              raise MuxUnexpectedException(                  'Failed handshake on the default channel id')          self._add_logical_channel(logical_request) @@ -1287,7 +1524,6 @@ class _MuxHandler(object):                  if not self._worker_done_notify_received:                      self._logger.debug('Waiting worker(s) timed out')                      return False -          finally:              self._logical_channels_condition.release() @@ -1297,7 +1533,7 @@ class _MuxHandler(object):          return True -    def notify_write_done(self, channel_id): +    def notify_write_data_done(self, channel_id):          """Called by the writer thread when a write operation has done.          Args: @@ -1308,7 +1544,7 @@ class _MuxHandler(object):              self._logical_channels_condition.acquire()              if channel_id in self._logical_channels:                  channel_data = self._logical_channels[channel_id] -                channel_data.request.connection.notify_write_done() +                channel_data.request.connection.on_write_data_done()              else:                  self._logger.debug('Seems that logical channel for %d has gone'                                     % channel_id) @@ -1469,9 +1705,11 @@ class _MuxHandler(object):                  return              channel_data = self._logical_channels[block.channel_id]              channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED +              # Close the logical channel              channel_data.request.connection.set_read_state(                  _LogicalConnection.STATE_TERMINATED) +            channel_data.request.ws_stream.stop_sending()          finally:              self._logical_channels_condition.release() @@ -1506,8 +1744,11 @@ class _MuxHandler(object):                  return              channel_data = self._logical_channels[channel_id]              fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame() +            consuming_byte = len(payload) +            if opcode != common.OPCODE_CONTINUATION: +                consuming_byte += 1              if not channel_data.request.ws_stream.consume_receive_quota( -                len(payload)): +                consuming_byte):                  # The client violates quota. Close logical channel.                  raise LogicalChannelError(                      channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION) @@ -1569,15 +1810,32 @@ class _MuxHandler(object):          finished.          """ -        # Terminate all logical connections -        self._logger.debug('termiating all logical connections...') +        self._logger.debug( +            'Termiating all logical connections waiting for incoming data ' +            '...')          self._logical_channels_condition.acquire()          for channel_data in self._logical_channels.values():              try:                  channel_data.request.connection.set_read_state(                      _LogicalConnection.STATE_TERMINATED)              except Exception: -                pass +                self._logger.debug(traceback.format_exc()) +        self._logical_channels_condition.release() + +    def notify_writer_done(self): +        """This method is called by the writer thread when the writer has +        finished. +        """ + +        self._logger.debug( +            'Termiating all logical connections waiting for write ' +            'completion ...') +        self._logical_channels_condition.acquire() +        for channel_data in self._logical_channels.values(): +            try: +                channel_data.request.connection.on_writer_done() +            except Exception: +                self._logger.debug(traceback.format_exc())          self._logical_channels_condition.release()      def fail_physical_connection(self, code, message): @@ -1590,8 +1848,7 @@ class _MuxHandler(object):          self._logger.debug('Failing the physical connection...')          self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message) -        self.physical_stream.close_connection( -            common.STATUS_INTERNAL_ENDPOINT_ERROR) +        self._writer.stop(common.STATUS_INTERNAL_ENDPOINT_ERROR)      def fail_logical_channel(self, channel_id, code, message):          """Fail a logical channel. @@ -1611,8 +1868,10 @@ class _MuxHandler(object):                  # called later and it will send DropChannel.                  channel_data.drop_code = code                  channel_data.drop_message = message +                  channel_data.request.connection.set_read_state(                      _LogicalConnection.STATE_TERMINATED) +                channel_data.request.ws_stream.stop_sending()              else:                  self._send_drop_channel(channel_id, code, message)          finally: @@ -1620,7 +1879,8 @@ class _MuxHandler(object):  def use_mux(request): -    return hasattr(request, 'mux') and request.mux +    return hasattr(request, 'mux_processor') and ( +        request.mux_processor.is_active())  def start(request, dispatcher): diff --git a/pyload/lib/mod_pywebsocket/standalone.py b/pyload/lib/mod_pywebsocket/standalone.py index 07a33d9c9..e9f083753 100755 --- a/pyload/lib/mod_pywebsocket/standalone.py +++ b/pyload/lib/mod_pywebsocket/standalone.py @@ -76,6 +76,9 @@ SUPPORTING TLS  To support TLS, run standalone.py with -t, -k, and -c options. +Note that when ssl module is used and the key/cert location is incorrect, +TLS connection silently fails while pyOpenSSL fails on startup. +  SUPPORTING CLIENT AUTHENTICATION @@ -140,18 +143,6 @@ import sys  import threading  import time -_HAS_SSL = False -_HAS_OPEN_SSL = False -try: -    import ssl -    _HAS_SSL = True -except ImportError: -    try: -        import OpenSSL.SSL -        _HAS_OPEN_SSL = True -    except ImportError: -        pass -  from mod_pywebsocket import common  from mod_pywebsocket import dispatch  from mod_pywebsocket import handshake @@ -168,6 +159,10 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128  # 1024 is practically large enough to contain WebSocket handshake lines.  _MAX_MEMORIZED_LINES = 1024 +# Constants for the --tls_module flag. +_TLS_BY_STANDARD_MODULE = 'ssl' +_TLS_BY_PYOPENSSL = 'pyopenssl' +  class _StandaloneConnection(object):      """Mimic mod_python mp_conn.""" @@ -231,11 +226,23 @@ class _StandaloneRequest(object):          self.headers_in = request_handler.headers      def get_uri(self): -        """Getter to mimic request.uri.""" +        """Getter to mimic request.uri. + +        This method returns the raw data at the Request-URI part of the +        Request-Line, while the uri method on the request object of mod_python +        returns the path portion after parsing the raw data. This behavior is +        kept for compatibility. +        """          return self._request_handler.path      uri = property(get_uri) +    def get_unparsed_uri(self): +        """Getter to mimic request.unparsed_uri.""" + +        return self._request_handler.path +    unparsed_uri = property(get_unparsed_uri) +      def get_method(self):          """Getter to mimic request.method.""" @@ -266,26 +273,67 @@ class _StandaloneRequest(object):                  'Drained data following close frame: %r', drained_data) +def _import_ssl(): +    global ssl +    try: +        import ssl +        return True +    except ImportError: +        return False + + +def _import_pyopenssl(): +    global OpenSSL +    try: +        import OpenSSL.SSL +        return True +    except ImportError: +        return False + +  class _StandaloneSSLConnection(object): -    """A wrapper class for OpenSSL.SSL.Connection to provide makefile method -    which is not supported by the class. +    """A wrapper class for OpenSSL.SSL.Connection to +    - provide makefile method which is not supported by the class +    - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't +      accept the "how" argument. +    - convert SysCallError exceptions that its recv method may raise into a +      return value of '', meaning EOF. We cannot overwrite the recv method on +      self._connection since it's immutable.      """ +    _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv'] +      def __init__(self, connection):          self._connection = connection      def __getattribute__(self, name): -        if name in ('_connection', 'makefile'): +        if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:              return object.__getattribute__(self, name)          return self._connection.__getattribute__(name)      def __setattr__(self, name, value): -        if name in ('_connection', 'makefile'): +        if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:              return object.__setattr__(self, name, value)          return self._connection.__setattr__(name, value)      def makefile(self, mode='r', bufsize=-1): -        return socket._fileobject(self._connection, mode, bufsize) +        return socket._fileobject(self, mode, bufsize) + +    def shutdown(self, unused_how): +        self._connection.shutdown() + +    def recv(self, bufsize, flags=0): +        if flags != 0: +            raise ValueError('Non-zero flags not allowed') + +        try: +            return self._connection.recv(bufsize) +        except OpenSSL.SSL.SysCallError, (err, message): +            if err == -1: +                # Suppress "unexpected EOF" exception. See the OpenSSL document +                # for SSL_get_error. +                return '' +            raise  def _alias_handlers(dispatcher, websock_handlers_map_file): @@ -340,7 +388,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          warnings = options.dispatcher.source_warnings()          if warnings:              for warning in warnings: -                logging.warning('mod_pywebsocket: %s' % warning) +                logging.warning('Warning in source loading: %s' % warning)          self._logger = util.get_class_logger(self) @@ -387,25 +435,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):              except Exception, e:                  self._logger.info('Skip by failure: %r', e)                  continue -            if self.websocket_server_options.use_tls: -                if _HAS_SSL: -                    if self.websocket_server_options.tls_client_auth: -                        client_cert_ = ssl.CERT_REQUIRED +            server_options = self.websocket_server_options +            if server_options.use_tls: +                # For the case of _HAS_OPEN_SSL, we do wrapper setup after +                # accept. +                if server_options.tls_module == _TLS_BY_STANDARD_MODULE: +                    if server_options.tls_client_auth: +                        if server_options.tls_client_cert_optional: +                            client_cert_ = ssl.CERT_OPTIONAL +                        else: +                            client_cert_ = ssl.CERT_REQUIRED                      else:                          client_cert_ = ssl.CERT_NONE                      socket_ = ssl.wrap_socket(socket_, -                        keyfile=self.websocket_server_options.private_key, -                        certfile=self.websocket_server_options.certificate, +                        keyfile=server_options.private_key, +                        certfile=server_options.certificate,                          ssl_version=ssl.PROTOCOL_SSLv23, -                        ca_certs=self.websocket_server_options.tls_client_ca, -                        cert_reqs=client_cert_) -                if _HAS_OPEN_SSL: -                    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) -                    ctx.use_privatekey_file( -                        self.websocket_server_options.private_key) -                    ctx.use_certificate_file( -                        self.websocket_server_options.certificate) -                    socket_ = OpenSSL.SSL.Connection(ctx, socket_) +                        ca_certs=server_options.tls_client_ca, +                        cert_reqs=client_cert_, +                        do_handshake_on_connect=False)              self._sockets.append((socket_, addrinfo))      def server_bind(self): @@ -479,7 +527,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          self._logger.critical('Not supported: fileno')          return self._sockets[0][0].fileno() -    def handle_error(self, rquest, client_address): +    def handle_error(self, request, client_address):          """Override SocketServer.handle_error."""          self._logger.error( @@ -496,8 +544,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          """          accepted_socket, client_address = self.socket.accept() -        if self.websocket_server_options.use_tls and _HAS_OPEN_SSL: -            accepted_socket = _StandaloneSSLConnection(accepted_socket) + +        server_options = self.websocket_server_options +        if server_options.use_tls: +            if server_options.tls_module == _TLS_BY_STANDARD_MODULE: +                try: +                    accepted_socket.do_handshake() +                except ssl.SSLError, e: +                    self._logger.debug('%r', e) +                    raise + +                # Print cipher in use. Handshake is done on accept. +                self._logger.debug('Cipher: %s', accepted_socket.cipher()) +                self._logger.debug('Client cert: %r', +                                   accepted_socket.getpeercert()) +            elif server_options.tls_module == _TLS_BY_PYOPENSSL: +                # We cannot print the cipher in use. pyOpenSSL doesn't provide +                # any method to fetch that. + +                ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) +                ctx.use_privatekey_file(server_options.private_key) +                ctx.use_certificate_file(server_options.certificate) + +                def default_callback(conn, cert, errnum, errdepth, ok): +                    return ok == 1 + +                # See the OpenSSL document for SSL_CTX_set_verify. +                if server_options.tls_client_auth: +                    verify_mode = OpenSSL.SSL.VERIFY_PEER +                    if not server_options.tls_client_cert_optional: +                        verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT +                    ctx.set_verify(verify_mode, default_callback) +                    ctx.load_verify_locations(server_options.tls_client_ca, +                                              None) +                else: +                    ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback) + +                accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket) +                accepted_socket.set_accept_state() + +                # Convert SSL related error into socket.error so that +                # SocketServer ignores them and keeps running. +                # +                # TODO(tyoshino): Convert all kinds of errors. +                try: +                    accepted_socket.do_handshake() +                except OpenSSL.SSL.Error, e: +                    # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module +                    # does. +                    self._logger.debug('%r', e) +                    raise socket.error(1, '%r' % e) +                cert = accepted_socket.get_peer_certificate() +                self._logger.debug('Client cert subject: %r', +                                   cert.get_subject().get_components()) +                accepted_socket = _StandaloneSSLConnection(accepted_socket) +            else: +                raise ValueError('No TLS support module is available') +          return accepted_socket, client_address      def serve_forever(self, poll_interval=0.5): @@ -636,7 +739,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  self._logger.info('Fallback to CGIHTTPRequestHandler')                  return True          except dispatch.DispatchException, e: -            self._logger.info('%s', e) +            self._logger.info('Dispatch failed for error: %s', e)              self.send_error(e.status)              return False @@ -652,7 +755,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                      allowDraft75=self._options.allow_draft75,                      strict=self._options.strict)              except handshake.VersionException, e: -                self._logger.info('%s', e) +                self._logger.info('Handshake failed for version error: %s', e)                  self.send_response(common.HTTP_STATUS_BAD_REQUEST)                  self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,                                   e.supported_versions) @@ -660,14 +763,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  return False              except handshake.HandshakeException, e:                  # Handshake for ws(s) failed. -                self._logger.info('%s', e) +                self._logger.info('Handshake failed for error: %s', e)                  self.send_error(e.status)                  return False              request._dispatcher = self._options.dispatcher              self._options.dispatcher.transfer_data(request)          except handshake.AbortedByUserException, e: -            self._logger.info('%s', e) +            self._logger.info('Aborted: %s', e)          return False      def log_request(self, code='-', size='-'): @@ -799,6 +902,12 @@ def _build_option_parser():                              'as CGI programs. Must be executable.'))      parser.add_option('-t', '--tls', dest='use_tls', action='store_true',                        default=False, help='use TLS (wss://)') +    parser.add_option('--tls-module', '--tls_module', dest='tls_module', +                      type='choice', +                      choices = [_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL], +                      help='Use ssl module if "%s" is specified. ' +                      'Use pyOpenSSL module if "%s" is specified' % +                      (_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL))      parser.add_option('-k', '--private-key', '--private_key',                        dest='private_key',                        default='', help='TLS private key file.') @@ -806,7 +915,12 @@ def _build_option_parser():                        default='', help='TLS certificate file.')      parser.add_option('--tls-client-auth', dest='tls_client_auth',                        action='store_true', default=False, -                      help='Requires TLS client auth on every connection.') +                      help='Requests TLS client auth on every connection.') +    parser.add_option('--tls-client-cert-optional', +                      dest='tls_client_cert_optional', +                      action='store_true', default=False, +                      help=('Makes client certificate optional even though ' +                            'TLS client auth is enabled.'))      parser.add_option('--tls-client-ca', dest='tls_client_ca', default='',                        help=('Specifies a pem file which contains a set of '                              'concatenated CA certificates which are used to ' @@ -933,6 +1047,12 @@ def _main(args=None):      _configure_logging(options) +    if options.allow_draft75: +        logging.warning('--allow_draft75 option is obsolete.') + +    if options.strict: +        logging.warning('--strict option is obsolete.') +      # TODO(tyoshino): Clean up initialization of CGI related values. Move some      # of code here to WebSocketRequestHandler class if it's better.      options.cgi_directories = [] @@ -955,20 +1075,53 @@ def _main(args=None):              options.is_executable_method = __check_script      if options.use_tls: -        if not (_HAS_SSL or _HAS_OPEN_SSL): -            logging.critical('TLS support requires ssl or pyOpenSSL module.') +        if options.tls_module is None: +            if _import_ssl(): +                options.tls_module = _TLS_BY_STANDARD_MODULE +                logging.debug('Using ssl module') +            elif _import_pyopenssl(): +                options.tls_module = _TLS_BY_PYOPENSSL +                logging.debug('Using pyOpenSSL module') +            else: +                logging.critical( +                        'TLS support requires ssl or pyOpenSSL module.') +                sys.exit(1) +        elif options.tls_module == _TLS_BY_STANDARD_MODULE: +            if not _import_ssl(): +                logging.critical('ssl module is not available') +                sys.exit(1) +        elif options.tls_module == _TLS_BY_PYOPENSSL: +            if not _import_pyopenssl(): +                logging.critical('pyOpenSSL module is not available') +                sys.exit(1) +        else: +            logging.critical('Invalid --tls-module option: %r', +                             options.tls_module)              sys.exit(1) +          if not options.private_key or not options.certificate:              logging.critical(                      'To use TLS, specify private_key and certificate.')              sys.exit(1) -    if options.tls_client_auth: -        if not options.use_tls: +        if (options.tls_client_cert_optional and +            not options.tls_client_auth): +            logging.critical('Client authentication must be enabled to ' +                             'specify tls_client_cert_optional') +            sys.exit(1) +    else: +        if options.tls_module is not None: +            logging.critical('Use --tls-module option only together with ' +                             '--use-tls option.') +            sys.exit(1) + +        if options.tls_client_auth: +            logging.critical('TLS must be enabled for client authentication.') +            sys.exit(1) + +        if options.tls_client_cert_optional:              logging.critical('TLS must be enabled for client authentication.')              sys.exit(1) -        if not _HAS_SSL: -            logging.critical('Client authentication requires ssl module.')      if not options.scan_dir:          options.scan_dir = options.websock_handlers diff --git a/pyload/lib/mod_pywebsocket/util.py b/pyload/lib/mod_pywebsocket/util.py index 7bb0b5d9e..fc8451be7 100644 --- a/pyload/lib/mod_pywebsocket/util.py +++ b/pyload/lib/mod_pywebsocket/util.py @@ -56,6 +56,11 @@ import socket  import traceback  import zlib +try: +    from mod_pywebsocket import fast_masking +except ImportError: +    pass +  def get_stack_trace():      """Get the current stack trace as string. @@ -169,26 +174,40 @@ class RepeatedXorMasker(object):      ended and resumes from that point on the next mask method call.      """ -    def __init__(self, mask): -        self._mask = map(ord, mask) -        self._mask_size = len(self._mask) -        self._count = 0 +    def __init__(self, masking_key): +        self._masking_key = masking_key +        self._masking_key_index = 0 -    def mask(self, s): +    def _mask_using_swig(self, s): +        masked_data = fast_masking.mask( +                s, self._masking_key, self._masking_key_index) +        self._masking_key_index = ( +                (self._masking_key_index + len(s)) % len(self._masking_key)) +        return masked_data + +    def _mask_using_array(self, s):          result = array.array('B')          result.fromstring(s) +          # Use temporary local variables to eliminate the cost to access          # attributes -        count = self._count -        mask = self._mask -        mask_size = self._mask_size +        masking_key = map(ord, self._masking_key) +        masking_key_size = len(masking_key) +        masking_key_index = self._masking_key_index +          for i in xrange(len(result)): -            result[i] ^= mask[count] -            count = (count + 1) % mask_size -        self._count = count +            result[i] ^= masking_key[masking_key_index] +            masking_key_index = (masking_key_index + 1) % masking_key_size + +        self._masking_key_index = masking_key_index          return result.tostring() +    if 'fast_masking' in globals(): +        mask = _mask_using_swig +    else: +        mask = _mask_using_array +  class DeflateRequest(object):      """A wrapper class for request object to intercept send and recv to perform @@ -252,6 +271,7 @@ class _Deflater(object):          self._logger.debug('Compress result %r', compressed_bytes)          return compressed_bytes +  class _Inflater(object):      def __init__(self): @@ -346,6 +366,7 @@ class _RFC1979Deflater(object):              return self._deflater.compress_and_flush(bytes)[:-4]          return self._deflater.compress(bytes) +  class _RFC1979Inflater(object):      """A decompressor class for byte sequence compressed and flushed following      the algorithm described in the RFC1979 section 2.1. | 
