diff options
Diffstat (limited to 'pyload/remote/wsbackend')
| -rw-r--r-- | pyload/remote/wsbackend/Server.py | 263 | 
1 files changed, 203 insertions, 60 deletions
| diff --git a/pyload/remote/wsbackend/Server.py b/pyload/remote/wsbackend/Server.py index 9a6649ca9..3ffe198eb 100644 --- a/pyload/remote/wsbackend/Server.py +++ b/pyload/remote/wsbackend/Server.py @@ -37,6 +37,7 @@  import BaseHTTPServer  import CGIHTTPServer  import SocketServer +import base64  import httplib  import logging  import os @@ -46,9 +47,6 @@ import socket  import sys  import threading -_HAS_SSL = False -_HAS_OPEN_SSL = False -  from mod_pywebsocket import common  from mod_pywebsocket import dispatch  from mod_pywebsocket import handshake @@ -65,20 +63,9 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128  # 1024 is practically large enough to contain WebSocket handshake lines.  _MAX_MEMORIZED_LINES = 1024 -def import_ssl(): -    global _HAS_SSL, _HAS_OPEN_SSL -    global ssl, OpenSSL -    try: -        import ssl -        _HAS_SSL = True -    except ImportError: -        try: -            import OpenSSL.SSL -            _HAS_OPEN_SSL = True -        except ImportError: -            pass - -    return _HAS_OPEN_SSL or _HAS_SSL +# Constants for the --tls_module flag. +_TLS_BY_STANDARD_MODULE = 'ssl' +_TLS_BY_PYOPENSSL = 'pyopenssl'  class _StandaloneConnection(object): @@ -143,11 +130,23 @@ class _StandaloneRequest(object):          self.headers_in = request_handler.headers      def get_uri(self): -        """Getter to mimic request.uri.""" +        """Getter to mimic request.uri. + +        This method returns the raw data at the Request-URI part of the +        Request-Line, while the uri method on the request object of mod_python +        returns the path portion after parsing the raw data. This behavior is +        kept for compatibility. +        """          return self._request_handler.path      uri = property(get_uri) +    def get_unparsed_uri(self): +        """Getter to mimic request.unparsed_uri.""" + +        return self._request_handler.path +    unparsed_uri = property(get_unparsed_uri) +      def get_method(self):          """Getter to mimic request.method.""" @@ -178,26 +177,67 @@ class _StandaloneRequest(object):                  'Drained data following close frame: %r', drained_data) +def _import_ssl(): +    global ssl +    try: +        import ssl +        return True +    except ImportError: +        return False + + +def _import_pyopenssl(): +    global OpenSSL +    try: +        import OpenSSL.SSL +        return True +    except ImportError: +        return False + +  class _StandaloneSSLConnection(object): -    """A wrapper class for OpenSSL.SSL.Connection to provide makefile method -    which is not supported by the class. +    """A wrapper class for OpenSSL.SSL.Connection to +    - provide makefile method which is not supported by the class +    - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't +      accept the "how" argument. +    - convert SysCallError exceptions that its recv method may raise into a +      return value of '', meaning EOF. We cannot overwrite the recv method on +      self._connection since it's immutable.      """ +    _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv'] +      def __init__(self, connection):          self._connection = connection      def __getattribute__(self, name): -        if name in ('_connection', 'makefile'): +        if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:              return object.__getattribute__(self, name)          return self._connection.__getattribute__(name)      def __setattr__(self, name, value): -        if name in ('_connection', 'makefile'): +        if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:              return object.__setattr__(self, name, value)          return self._connection.__setattr__(name, value)      def makefile(self, mode='r', bufsize=-1): -        return socket._fileobject(self._connection, mode, bufsize) +        return socket._fileobject(self, mode, bufsize) + +    def shutdown(self, unused_how): +        self._connection.shutdown() + +    def recv(self, bufsize, flags=0): +        if flags != 0: +            raise ValueError('Non-zero flags not allowed') + +        try: +            return self._connection.recv(bufsize) +        except OpenSSL.SSL.SysCallError, (err, message): +            if err == -1: +                # Suppress "unexpected EOF" exception. See the OpenSSL document +                # for SSL_get_error. +                return '' +            raise  def _alias_handlers(dispatcher, websock_handlers_map_file): @@ -284,25 +324,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):              except Exception, e:                  self._logger.info('Skip by failure: %r', e)                  continue -            if self.websocket_server_options.use_tls: -                if _HAS_SSL: -                    if self.websocket_server_options.tls_client_auth: -                        client_cert_ = ssl.CERT_REQUIRED +            server_options = self.websocket_server_options +            if server_options.use_tls: +                # For the case of _HAS_OPEN_SSL, we do wrapper setup after +                # accept. +                if server_options.tls_module == _TLS_BY_STANDARD_MODULE: +                    if server_options.tls_client_auth: +                        if server_options.tls_client_cert_optional: +                            client_cert_ = ssl.CERT_OPTIONAL +                        else: +                            client_cert_ = ssl.CERT_REQUIRED                      else:                          client_cert_ = ssl.CERT_NONE                      socket_ = ssl.wrap_socket(socket_, -                                              keyfile=self.websocket_server_options.private_key, -                                              certfile=self.websocket_server_options.certificate, -                                              ssl_version=ssl.PROTOCOL_SSLv23, -                                              ca_certs=self.websocket_server_options.tls_client_ca, -                                              cert_reqs=client_cert_) -                if _HAS_OPEN_SSL: -                    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) -                    ctx.use_privatekey_file( -                        self.websocket_server_options.private_key) -                    ctx.use_certificate_file( -                        self.websocket_server_options.certificate) -                    socket_ = OpenSSL.SSL.Connection(ctx, socket_) +                        keyfile=server_options.private_key, +                        certfile=server_options.certificate, +                        ssl_version=ssl.PROTOCOL_SSLv23, +                        ca_certs=server_options.tls_client_ca, +                        cert_reqs=client_cert_, +                        do_handshake_on_connect=False)              self._sockets.append((socket_, addrinfo))      def server_bind(self): @@ -375,7 +415,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          self._logger.critical('Not supported: fileno')          return self._sockets[0][0].fileno() -    def handle_error(self, rquest, client_address): +    def handle_error(self, request, client_address):          """Override SocketServer.handle_error."""          self._logger.error( @@ -392,8 +432,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          """          accepted_socket, client_address = self.socket.accept() -        if self.websocket_server_options.use_tls and _HAS_OPEN_SSL: -            accepted_socket = _StandaloneSSLConnection(accepted_socket) + +        server_options = self.websocket_server_options +        if server_options.use_tls: +            if server_options.tls_module == _TLS_BY_STANDARD_MODULE: +                try: +                    accepted_socket.do_handshake() +                except ssl.SSLError, e: +                    self._logger.debug('%r', e) +                    raise + +                # Print cipher in use. Handshake is done on accept. +                self._logger.debug('Cipher: %s', accepted_socket.cipher()) +                self._logger.debug('Client cert: %r', +                                   accepted_socket.getpeercert()) +            elif server_options.tls_module == _TLS_BY_PYOPENSSL: +                # We cannot print the cipher in use. pyOpenSSL doesn't provide +                # any method to fetch that. + +                ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) +                ctx.use_privatekey_file(server_options.private_key) +                ctx.use_certificate_file(server_options.certificate) + +                def default_callback(conn, cert, errnum, errdepth, ok): +                    return ok == 1 + +                # See the OpenSSL document for SSL_CTX_set_verify. +                if server_options.tls_client_auth: +                    verify_mode = OpenSSL.SSL.VERIFY_PEER +                    if not server_options.tls_client_cert_optional: +                        verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT +                    ctx.set_verify(verify_mode, default_callback) +                    ctx.load_verify_locations(server_options.tls_client_ca, +                                              None) +                else: +                    ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback) + +                accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket) +                accepted_socket.set_accept_state() + +                # Convert SSL related error into socket.error so that +                # SocketServer ignores them and keeps running. +                # +                # TODO(tyoshino): Convert all kinds of errors. +                try: +                    accepted_socket.do_handshake() +                except OpenSSL.SSL.Error, e: +                    # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module +                    # does. +                    self._logger.debug('%r', e) +                    raise socket.error(1, '%r' % e) +                cert = accepted_socket.get_peer_certificate() +                self._logger.debug('Client cert subject: %r', +                                   cert.get_subject().get_components()) +                accepted_socket = _StandaloneSSLConnection(accepted_socket) +            else: +                raise ValueError('No TLS support module is available') +          return accepted_socket, client_address      def serve_forever(self, poll_interval=0.5): @@ -474,8 +569,6 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  raise              self._logger.debug("WS: Broken pipe") - -      def parse_request(self):          """Override BaseHTTPServer.BaseHTTPRequestHandler.parse_request. @@ -545,7 +638,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  self._logger.info('Fallback to CGIHTTPRequestHandler')                  return False          except dispatch.DispatchException, e: -            self._logger.info('%s', e) +            self._logger.info('Dispatch failed for error: %s', e)              self.send_error(e.status)              return False @@ -561,7 +654,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                      allowDraft75=self._options.allow_draft75,                      strict=self._options.strict)              except handshake.VersionException, e: -                self._logger.info('%s', e) +                self._logger.info('Handshake failed for version error: %s', e)                  self.send_response(common.HTTP_STATUS_BAD_REQUEST)                  self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,                                   e.supported_versions) @@ -569,14 +662,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  return False              except handshake.HandshakeException, e:                  # Handshake for ws(s) failed. -                self._logger.info('%s', e) +                self._logger.info('Handshake failed for error: %s', e)                  self.send_error(e.status)                  return False              request._dispatcher = self._options.dispatcher              self._options.dispatcher.transfer_data(request)          except handshake.AbortedByUserException, e: -            self._logger.info('%s', e) +            self._logger.info('Aborted: %s', e)          return False      def log_request(self, code='-', size='-'): @@ -606,7 +699,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):          if CGIHTTPServer.CGIHTTPRequestHandler.is_cgi(self):              if '..' in self.path:                  return False -                # strip query parameter from request path +            # strip query parameter from request path              resource_name = self.path.split('?', 2)[0]              # convert resource_name into real path name in filesystem.              scriptfile = self.translate_path(resource_name) @@ -629,11 +722,11 @@ def _configure_logging(options):      logger.setLevel(logging.getLevelName(options.log_level.upper()))      if options.log_file:          handler = logging.handlers.RotatingFileHandler( -            options.log_file, 'a', options.log_max, options.log_count) +                options.log_file, 'a', options.log_max, options.log_count)      else:          handler = logging.StreamHandler()      formatter = logging.Formatter( -        '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s') +            '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')      handler.setFormatter(formatter)      logger.addHandler(handler) @@ -650,9 +743,10 @@ class DefaultOptions:      use_tls = False      private_key = ''      certificate = '' -    ca_certificate = '' -    tls_client_ca = '' +    tls_client_ca = None      tls_client_auth = False +    tls_client_cert_optional = False +    tls_module = _TLS_BY_STANDARD_MODULE      dispatcher = None      request_queue_size = _DEFAULT_REQUEST_QUEUE_SIZE      use_basic_auth = False @@ -664,6 +758,16 @@ class DefaultOptions:      cgi_directories = ''      is_executable_method = False + +def import_ssl(): +    if _import_ssl(): +        return _TLS_BY_STANDARD_MODULE + +    elif _import_pyopenssl(): +        return _TLS_BY_PYOPENSSL + +    return False +  def _main(args=None):      """You can call this function from your own program, but please note that      this function has some side-effects that might affect your program. For @@ -677,6 +781,12 @@ def _main(args=None):      _configure_logging(options) +    if options.allow_draft75: +        logging.warning('--allow_draft75 option is obsolete.') + +    if options.strict: +        logging.warning('--strict option is obsolete.') +      # TODO(tyoshino): Clean up initialization of CGI related values. Move some      # of code here to WebSocketRequestHandler class if it's better.      options.cgi_directories = [] @@ -699,20 +809,53 @@ def _main(args=None):              options.is_executable_method = __check_script      if options.use_tls: -        if not (_HAS_SSL or _HAS_OPEN_SSL): -            logging.critical('TLS support requires ssl or pyOpenSSL module.') +        if options.tls_module is None: +            if _import_ssl(): +                options.tls_module = _TLS_BY_STANDARD_MODULE +                logging.debug('Using ssl module') +            elif _import_pyopenssl(): +                options.tls_module = _TLS_BY_PYOPENSSL +                logging.debug('Using pyOpenSSL module') +            else: +                logging.critical( +                        'TLS support requires ssl or pyOpenSSL module.') +                sys.exit(1) +        elif options.tls_module == _TLS_BY_STANDARD_MODULE: +            if not _import_ssl(): +                logging.critical('ssl module is not available') +                sys.exit(1) +        elif options.tls_module == _TLS_BY_PYOPENSSL: +            if not _import_pyopenssl(): +                logging.critical('pyOpenSSL module is not available') +                sys.exit(1) +        else: +            logging.critical('Invalid --tls-module option: %r', +                             options.tls_module)              sys.exit(1) +          if not options.private_key or not options.certificate:              logging.critical( -                'To use TLS, specify private_key and certificate.') +                    'To use TLS, specify private_key and certificate.') +            sys.exit(1) + +        if (options.tls_client_cert_optional and +            not options.tls_client_auth): +            logging.critical('Client authentication must be enabled to ' +                             'specify tls_client_cert_optional') +            sys.exit(1) +    else: +        if options.tls_module is not None: +            logging.critical('Use --tls-module option only together with ' +                             '--use-tls option.') +            sys.exit(1) + +        if options.tls_client_auth: +            logging.critical('TLS must be enabled for client authentication.')              sys.exit(1) -    if options.tls_client_auth: -        if not options.use_tls: +        if options.tls_client_cert_optional:              logging.critical('TLS must be enabled for client authentication.')              sys.exit(1) -        if not _HAS_SSL: -            logging.critical('Client authentication requires ssl module.')      if not options.scan_dir:          options.scan_dir = options.websock_handlers | 
