diff --git a/config b/config index eb4e5782..7ded0091 100644 --- a/config +++ b/config @@ -24,6 +24,15 @@ # File storing the PID in daemon mode #pid = +# Max parallel connections +#max_connections = 20 + +# Max size of request body (bytes) +#max_content_length = 10000000 + +# Socket timeout (seconds) +#timeout = 10 + # SSL flag, enable HTTPS protocol #ssl = False diff --git a/radicale/__init__.py b/radicale/__init__.py index fa0e0e6e..4f2e26d6 100644 --- a/radicale/__init__.py +++ b/radicale/__init__.py @@ -29,9 +29,11 @@ should have been included in this package. import os import pprint import base64 +import contextlib import socket import socketserver import ssl +import threading import wsgiref.simple_server import re import zlib @@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$") class HTTPServer(wsgiref.simple_server.WSGIServer): """HTTP server.""" + + # These class attributes must be set before creating instance + client_timeout = None + max_connections = None + def __init__(self, address, handler, bind_and_activate=True): """Create server.""" ipv6 = ":" in address[0] @@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): self.server_bind() self.server_activate() + if self.max_connections: + self.connections_guard = threading.BoundedSemaphore( + self.max_connections) + else: + # use dummy context manager + self.connections_guard = contextlib.suppress() + + def get_request(self): + # Set timeout for client + _socket, address = super().get_request() + if self.client_timeout: + _socket.settimeout(self.client_timeout) + return _socket, address + class HTTPSServer(HTTPServer): """HTTPS server.""" @@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer): class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): - pass + def process_request_thread(self, request, client_address): + with self.connections_guard: + return super().process_request_thread(request, client_address) class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): - pass + def process_request_thread(self, request, client_address): + with self.connections_guard: + return super().process_request_thread(request, client_address) class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): @@ -218,6 +243,15 @@ class Application: def __call__(self, environ, start_response): """Manage a request.""" + def response(status, headers={}, answer=None): + # Start response + status = "%i %s" % (status, + client.responses.get(status, "Unknown")) + self.logger.debug("Answer status: %s" % status) + start_response(status, list(headers.items())) + # Return response content + return [answer] if answer else [] + self.logger.info("%s request at %s received" % ( environ["REQUEST_METHOD"], environ["PATH_INFO"])) headers = pprint.pformat(self.headers_log(environ)) @@ -234,9 +268,7 @@ class Application: # Request path not starting with base_prefix, not allowed self.logger.debug( "Path not starting with prefix: %s", environ["PATH_INFO"]) - status, headers, _ = NOT_ALLOWED - start_response(status, list(headers.items())) - return [] + return response(*NOT_ALLOWED) # Sanitize request URI environ["PATH_INFO"] = storage.sanitize_path( @@ -275,10 +307,7 @@ class Application: status = client.SEE_OTHER self.logger.info("/.well-known/ redirection to: %s" % redirect) headers = {"Location": redirect} - status = "%i %s" % ( - status, client.responses.get(status, "Unknown")) - start_response(status, list(headers.items())) - return [] + return response(status, headers) is_authenticated = self.is_authenticated(user, password) is_valid_user = is_authenticated or not user @@ -286,8 +315,17 @@ class Application: # Get content content_length = int(environ.get("CONTENT_LENGTH") or 0) if content_length: - content = self.decode( - environ["wsgi.input"].read(content_length), environ) + max_content_length = self.configuration.getint( + "server", "max_content_length") + if max_content_length and content_length > max_content_length: + self.logger.debug( + "Request body too large: %d", content_length) + return response(client.REQUEST_ENTITY_TOO_LARGE) + try: + content = self.decode( + environ["wsgi.input"].read(content_length), environ) + except socket.timeout: + return response(client.REQUEST_TIMEOUT) self.logger.debug("Request content:\n%s" % content) else: content = None @@ -345,13 +383,7 @@ class Application: for key in self.configuration.options("headers"): headers[key] = self.configuration.get("headers", key) - # Start response - status = "%i %s" % (status, client.responses.get(status, "Unknown")) - self.logger.debug("Answer status: %s" % status) - start_response(status, list(headers.items())) - - # Return response content - return [answer] if answer else [] + return response(status, headers, answer) # All these functions must have the same parameters, some are useless # pylint: disable=W0612,W0613,R0201 diff --git a/radicale/__main__.py b/radicale/__main__.py index 849e7187..d1f82be3 100644 --- a/radicale/__main__.py +++ b/radicale/__main__.py @@ -175,6 +175,9 @@ def serve(configuration, logger): name, filename, exception)) else: server_class = ThreadedHTTPServer + server_class.client_timeout = configuration.getint("server", "timeout") + server_class.max_connections = configuration.getint("server", + "max_connections") if not configuration.getboolean("server", "dns_lookup"): RequestHandler.address_string = lambda self: self.client_address[0] diff --git a/radicale/config.py b/radicale/config.py index 5d5e69de..7886c3be 100644 --- a/radicale/config.py +++ b/radicale/config.py @@ -32,6 +32,9 @@ INITIAL_CONFIG = { "hosts": "0.0.0.0:5232", "daemon": "False", "pid": "", + "max_connections": "20", + "max_content_length": "10000000", + "timeout": "10", "ssl": "False", "certificate": "/etc/apache2/ssl/server.crt", "key": "/etc/apache2/ssl/server.key",