diff options
Diffstat (limited to 'module/remote/wsbackend')
| -rw-r--r-- | module/remote/wsbackend/Server.py | 174 | 
1 files changed, 113 insertions, 61 deletions
| diff --git a/module/remote/wsbackend/Server.py b/module/remote/wsbackend/Server.py index 65f4fe5d5..9b52f5464 100644 --- a/module/remote/wsbackend/Server.py +++ b/module/remote/wsbackend/Server.py @@ -57,6 +57,9 @@ from mod_pywebsocket import memorizingfile  from mod_pywebsocket import util +_DEFAULT_LOG_MAX_BYTES = 1024 * 256 +_DEFAULT_LOG_BACKUP_COUNT = 5 +  _DEFAULT_REQUEST_QUEUE_SIZE = 128  # 1024 is practically large enough to contain WebSocket handshake lines. @@ -129,8 +132,7 @@ class _StandaloneRequest(object):              request_handler: A WebSocketRequestHandler instance.          """ -        self._logger = logging.getLogger("log") -        self.api = None +        self._logger = util.get_class_logger(self)          self._request_handler = request_handler          self.connection = _StandaloneConnection(request_handler) @@ -149,6 +151,12 @@ class _StandaloneRequest(object):          return self._request_handler.command      method = property(get_method) +    def get_protocol(self): +        """Getter to mimic request.protocol.""" + +        return self._request_handler.request_version +    protocol = property(get_protocol) +      def is_https(self):          """Mimic request.is_https().""" @@ -189,6 +197,32 @@ class _StandaloneSSLConnection(object):          return socket._fileobject(self._connection, mode, bufsize) +def _alias_handlers(dispatcher, websock_handlers_map_file): +    """Set aliases specified in websock_handler_map_file in dispatcher. + +    Args: +        dispatcher: dispatch.Dispatcher instance +        websock_handler_map_file: alias map file +    """ + +    fp = open(websock_handlers_map_file) +    try: +        for line in fp: +            if line[0] == '#' or line.isspace(): +                continue +            m = re.match('(\S+)\s+(\S+)', line) +            if not m: +                logging.warning('Wrong format in map file:' + line) +                continue +            try: +                dispatcher.add_resource_path_alias( +                    m.group(1), m.group(2)) +            except dispatch.DispatchException, e: +                logging.error(str(e)) +    finally: +        fp.close() + +  class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):      """HTTPServer specialized for WebSocket.""" @@ -202,7 +236,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          socket object to self.socket before server_bind and server_activate,          if necessary.          """ - +        # Removed dispatcher init here          self._logger = logging.getLogger("log")          self.request_queue_size = options.request_queue_size @@ -236,10 +270,10 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):                  (socket.AF_INET, socket.SOCK_STREAM, '', '', '')]          else:              addrinfo_array = socket.getaddrinfo(self.server_name, -                self.server_port, -                socket.AF_UNSPEC, -                socket.SOCK_STREAM, -                socket.IPPROTO_TCP) +                                                self.server_port, +                                                socket.AF_UNSPEC, +                                                socket.SOCK_STREAM, +                                                socket.IPPROTO_TCP)          for addrinfo in addrinfo_array:              family, socktype, proto, canonname, sockaddr = addrinfo              try: @@ -249,16 +283,16 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):                  continue              if self.websocket_server_options.use_tls:                  if _HAS_SSL: -                    if self.websocket_server_options.ca_certificate: +                    if self.websocket_server_options.tls_client_auth:                          client_cert_ = ssl.CERT_REQUIRED                      else:                          client_cert_ = ssl.CERT_NONE                      socket_ = ssl.wrap_socket(socket_, -                        keyfile=self.websocket_server_options.private_key, -                        certfile=self.websocket_server_options.certificate, -                        ssl_version=ssl.PROTOCOL_SSLv23, -                        ca_certs=self.websocket_server_options.ca_certificate, -                        cert_reqs=client_cert_) +                                              keyfile=self.websocket_server_options.private_key, +                                              certfile=self.websocket_server_options.certificate, +                                              ssl_version=ssl.PROTOCOL_SSLv23, +                                              ca_certs=self.websocket_server_options.tls_client_ca, +                                              cert_reqs=client_cert_)                  if _HAS_OPEN_SSL:                      ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)                      ctx.use_privatekey_file( @@ -285,6 +319,15 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):                  self._logger.info('Skip by failure: %r', e)                  socket_.close()                  failed_sockets.append(socketinfo) +            if self.server_address[1] == 0: +                # The operating system assigns the actual port number for port +                # number 0. This case, the second and later sockets should use +                # the same port number. Also self.server_port is rewritten +                # because it is exported, and will be used by external code. +                self.server_address = ( +                    self.server_name, socket_.getsockname()[1]) +                self.server_port = self.server_address[1] +                self._logger.info('Port %r is assigned', self.server_port)          for socketinfo in failed_sockets:              self._sockets.remove(socketinfo) @@ -309,6 +352,10 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):          for socketinfo in failed_sockets:              self._sockets.remove(socketinfo) +        if len(self._sockets) == 0: +            self._logger.critical( +                'No sockets activated. Use info log level to see the reason.') +      def server_close(self):          """Override SocketServer.TCPServer.server_close to enable multiple          sockets close. @@ -436,6 +483,17 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):          # attributes).          if not CGIHTTPServer.CGIHTTPRequestHandler.parse_request(self):              return False + +        if self._options.use_basic_auth: +            auth = self.headers.getheader('Authorization') +            if auth != self._options.basic_auth_credential: +                self.send_response(401) +                self.send_header('WWW-Authenticate', +                                 'Basic realm="Pywebsocket"') +                self.end_headers() +                self._logger.info('Request basic authentication') +                return True +          host, port, resource = http_header_util.parse_uri(self.path)          if resource is None:              self._logger.info('Invalid URI: %r', self.path) @@ -446,16 +504,16 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):              validation_host = server_options.validation_host              if validation_host is not None and host != validation_host:                  self._logger.info('Invalid host: %r (expected: %r)', -                    host, -                    validation_host) +                                  host, +                                  validation_host)                  self._logger.info('Fallback to CGIHTTPRequestHandler')                  return True          if port is not None:              validation_port = server_options.validation_port              if validation_port is not None and port != validation_port:                  self._logger.info('Invalid port: %r (expected: %r)', -                    port, -                    validation_port) +                                  port, +                                  validation_port)                  self._logger.info('Fallback to CGIHTTPRequestHandler')                  return True          self.path = resource @@ -467,7 +525,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):              # we don't have request handlers.              if not self._options.dispatcher.get_handler_suite(self.path):                  self._logger.info('No handler for resource: %r', -                    self.path) +                                  self.path)                  self._logger.info('Fallback to CGIHTTPRequestHandler')                  return True          except dispatch.DispatchException, e: @@ -490,7 +548,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):                  self._logger.info('%s', e)                  self.send_response(common.HTTP_STATUS_BAD_REQUEST)                  self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER, -                    e.supported_versions) +                                 e.supported_versions)                  self.end_headers()                  return False              except handshake.HandshakeException, e: @@ -509,7 +567,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):          """Override BaseHTTPServer.log_request."""          self._logger.info('"%s" %s %s', -            self.requestline, str(code), str(size)) +                          self.requestline, str(code), str(size))      def log_error(self, *args):          """Override BaseHTTPServer.log_error.""" @@ -517,8 +575,8 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):          # Despite the name, this method is for warnings than for errors.          # For example, HTTP status code is logged by this method.          self._logger.warning('%s - %s', -            self.address_string(), -            args[0] % args[1:]) +                             self.address_string(), +                             args[0] % args[1:])      def is_cgi(self):          """Test whether self.path corresponds to a CGI script. @@ -548,30 +606,27 @@ def _get_logger_from_class(c):      return logging.getLogger('%s.%s' % (c.__module__, c.__name__)) -def _alias_handlers(dispatcher, websock_handlers_map_file): -    """Set aliases specified in websock_handler_map_file in dispatcher. - -    Args: -        dispatcher: dispatch.Dispatcher instance -        websock_handler_map_file: alias map file -    """ - -    fp = open(websock_handlers_map_file) -    try: -        for line in fp: -            if line[0] == '#' or line.isspace(): -                continue -            m = re.match('(\S+)\s+(\S+)', line) -            if not m: -                logging.warning('Wrong format in map file:' + line) -                continue -            try: -                dispatcher.add_resource_path_alias( -                    m.group(1), m.group(2)) -            except dispatch.DispatchException, e: -                logging.error(str(e)) -    finally: -        fp.close() +def _configure_logging(options): +    logging.addLevelName(common.LOGLEVEL_FINE, 'FINE') + +    logger = logging.getLogger() +    logger.setLevel(logging.getLevelName(options.log_level.upper())) +    if options.log_file: +        handler = logging.handlers.RotatingFileHandler( +            options.log_file, 'a', options.log_max, options.log_count) +    else: +        handler = logging.StreamHandler() +    formatter = logging.Formatter( +        '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s') +    handler.setFormatter(formatter) +    logger.addHandler(handler) + +    deflate_log_level_name = logging.getLevelName( +        options.deflate_log_level.upper()) +    _get_logger_from_class(util._Deflater).setLevel( +        deflate_log_level_name) +    _get_logger_from_class(util._Inflater).setLevel( +        deflate_log_level_name)  class DefaultOptions:      server_host = '' @@ -582,6 +637,7 @@ class DefaultOptions:      ca_certificate = ''      dispatcher = None      request_queue_size = _DEFAULT_REQUEST_QUEUE_SIZE +    use_basic_auth = False      allow_draft75 = False      strict = False @@ -591,6 +647,12 @@ class DefaultOptions:      is_executable_method = False  def _main(args=None): +    """You can call this function from your own program, but please note that +    this function has some side-effects that might affect your program. For +    example, util.wrap_popen3_for_win use in this method replaces implementation +    of os.popen3. +    """ +      options, args = _parse_args_and_config(args=args)      os.chdir(options.document_root) @@ -627,7 +689,7 @@ def _main(args=None):                  'To use TLS, specify private_key and certificate.')              sys.exit(1) -    if options.ca_certificate: +    if options.tls_client_auth:          if not options.use_tls:              logging.critical('TLS must be enabled for client authentication.')              sys.exit(1) @@ -637,26 +699,16 @@ def _main(args=None):      if not options.scan_dir:          options.scan_dir = options.websock_handlers +    if options.use_basic_auth: +        options.basic_auth_credential = 'Basic ' + base64.b64encode( +            options.basic_auth_credential) +      try:          if options.thread_monitor_interval_in_sec > 0:              # Run a thread monitor to show the status of server threads for              # debugging.              ThreadMonitor(options.thread_monitor_interval_in_sec).start() -        # Share a Dispatcher among request handlers to save time for -        # instantiation.  Dispatcher can be shared because it is thread-safe. -        options.dispatcher = dispatch.Dispatcher( -            options.websock_handlers, -            options.scan_dir, -            options.allow_handlers_outside_root_dir) -        if options.websock_handlers_map_file: -            _alias_handlers(options.dispatcher, -                options.websock_handlers_map_file) -        warnings = options.dispatcher.source_warnings() -        if warnings: -            for warning in warnings: -                logging.warning('mod_pywebsocket: %s' % warning) -          server = WebSocketServer(options)          server.serve_forever()      except Exception, e: | 
