mirror of
https://github.com/Kozea/Radicale.git
synced 2025-06-26 16:45:52 +00:00
More type hints
This commit is contained in:
parent
12fe5ce637
commit
cecb17df03
51 changed files with 1374 additions and 957 deletions
|
@ -23,14 +23,15 @@ Built-in WSGI server.
|
|||
"""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import http
|
||||
import select
|
||||
import socket
|
||||
import socketserver
|
||||
import ssl
|
||||
import sys
|
||||
import wsgiref.simple_server
|
||||
from typing import MutableMapping
|
||||
from typing import (Any, Callable, Dict, List, MutableMapping, Optional, Set,
|
||||
Tuple, Union)
|
||||
from urllib.parse import unquote
|
||||
|
||||
from radicale import Application, config
|
||||
|
@ -38,7 +39,7 @@ from radicale.log import logger
|
|||
|
||||
COMPAT_EAI_ADDRFAMILY: int
|
||||
if hasattr(socket, "EAI_ADDRFAMILY"):
|
||||
COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY # type: ignore[attr-defined]
|
||||
COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY # type:ignore[attr-defined]
|
||||
elif hasattr(socket, "EAI_NONAME"):
|
||||
# Windows and BSD don't have a special error code for this
|
||||
COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME
|
||||
|
@ -51,57 +52,99 @@ elif hasattr(socket, "EAI_NONAME"):
|
|||
COMPAT_IPPROTO_IPV6: int
|
||||
if hasattr(socket, "IPPROTO_IPV6"):
|
||||
COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
|
||||
elif os.name == "nt":
|
||||
# Workaround: https://bugs.python.org/issue29515
|
||||
elif sys.platform == "win32":
|
||||
# HACK: https://bugs.python.org/issue29515
|
||||
COMPAT_IPPROTO_IPV6 = 41
|
||||
|
||||
|
||||
def format_address(address):
|
||||
# IPv4 (host, port) and IPv6 (host, port, flowinfo, scopeid)
|
||||
ADDRESS_TYPE = Union[Tuple[str, int], Tuple[str, int, int, int]]
|
||||
|
||||
|
||||
def format_address(address: ADDRESS_TYPE) -> str:
|
||||
return "[%s]:%d" % address[:2]
|
||||
|
||||
|
||||
class ParallelHTTPServer(socketserver.ThreadingMixIn,
|
||||
wsgiref.simple_server.WSGIServer):
|
||||
|
||||
# We wait for child threads ourself
|
||||
block_on_close = False
|
||||
daemon_threads = True
|
||||
configuration: config.Configuration
|
||||
worker_sockets: Set[socket.socket]
|
||||
_timeout: float
|
||||
|
||||
def __init__(self, configuration, family, address, RequestHandlerClass):
|
||||
# We wait for child threads ourself (ThreadingMixIn)
|
||||
block_on_close: bool = False
|
||||
daemon_threads: bool = True
|
||||
|
||||
def __init__(self, configuration: config.Configuration, family: int,
|
||||
address: Tuple[str, int], RequestHandlerClass:
|
||||
Callable[..., http.server.BaseHTTPRequestHandler]) -> None:
|
||||
self.configuration = configuration
|
||||
self.address_family = family
|
||||
super().__init__(address, RequestHandlerClass)
|
||||
self.client_sockets = set()
|
||||
self.worker_sockets = set()
|
||||
self._timeout = configuration.get("server", "timeout")
|
||||
|
||||
def server_bind(self):
|
||||
def server_bind(self) -> None:
|
||||
if self.address_family == socket.AF_INET6:
|
||||
# Only allow IPv6 connections to the IPv6 socket
|
||||
self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||
super().server_bind()
|
||||
|
||||
def get_request(self):
|
||||
def get_request( # type:ignore[override]
|
||||
self) -> Tuple[socket.socket, Tuple[ADDRESS_TYPE, socket.socket]]:
|
||||
# Set timeout for client
|
||||
request, client_address = super().get_request()
|
||||
timeout = self.configuration.get("server", "timeout")
|
||||
if timeout:
|
||||
request.settimeout(timeout)
|
||||
client_socket, client_socket_out = socket.socketpair()
|
||||
self.client_sockets.add(client_socket_out)
|
||||
return request, (*client_address, client_socket)
|
||||
request: socket.socket
|
||||
client_address: ADDRESS_TYPE
|
||||
request, client_address = super().get_request() # type:ignore[misc]
|
||||
if self._timeout > 0:
|
||||
request.settimeout(self._timeout)
|
||||
worker_socket, worker_socket_out = socket.socketpair()
|
||||
self.worker_sockets.add(worker_socket_out)
|
||||
# HACK: Forward `worker_socket` via `client_address` return value
|
||||
# to worker thread.
|
||||
# The super class calls `verify_request`, `process_request` and
|
||||
# `handle_error` with modified `client_address` value.
|
||||
return request, (client_address, worker_socket)
|
||||
|
||||
def finish_request_locked(self, request, client_address):
|
||||
return super().finish_request(request, client_address)
|
||||
def verify_request( # type:ignore[override]
|
||||
self, request: socket.socket, client_address_and_socket:
|
||||
Tuple[ADDRESS_TYPE, socket.socket]) -> bool:
|
||||
return True
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
*client_address, client_socket = client_address
|
||||
client_address = tuple(client_address)
|
||||
def process_request( # type:ignore[override]
|
||||
self, request: socket.socket, client_address_and_socket:
|
||||
Tuple[ADDRESS_TYPE, socket.socket]) -> None:
|
||||
# HACK: Super class calls `finish_request` in new thread with
|
||||
# `client_address_and_socket`
|
||||
return super().process_request(
|
||||
request, client_address_and_socket) # type:ignore[arg-type]
|
||||
|
||||
def finish_request( # type:ignore[override]
|
||||
self, request: socket.socket, client_address_and_socket:
|
||||
Tuple[ADDRESS_TYPE, socket.socket]) -> None:
|
||||
# HACK: Unpack `client_address_and_socket` and call super class
|
||||
# `finish_request` with original `client_address`
|
||||
client_address, worker_socket = client_address_and_socket
|
||||
try:
|
||||
return self.finish_request_locked(request, client_address)
|
||||
finally:
|
||||
client_socket.close()
|
||||
worker_socket.close()
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
if issubclass(sys.exc_info()[0], socket.timeout):
|
||||
def finish_request_locked(self, request: socket.socket,
|
||||
client_address: ADDRESS_TYPE) -> None:
|
||||
return super().finish_request(
|
||||
request, client_address) # type:ignore[arg-type]
|
||||
|
||||
def handle_error( # type:ignore[override]
|
||||
self, request: socket.socket,
|
||||
client_address_or_client_address_and_socket:
|
||||
Union[ADDRESS_TYPE, Tuple[ADDRESS_TYPE, socket.socket]]) -> None:
|
||||
# HACK: This method can be called with the modified
|
||||
# `client_address_and_socket` or the original `client_address` value
|
||||
e = sys.exc_info()[1]
|
||||
assert e is not None
|
||||
if isinstance(e, socket.timeout):
|
||||
logger.info("Client timed out", exc_info=True)
|
||||
else:
|
||||
logger.error("An exception occurred during request: %s",
|
||||
|
@ -110,12 +153,12 @@ class ParallelHTTPServer(socketserver.ThreadingMixIn,
|
|||
|
||||
class ParallelHTTPSServer(ParallelHTTPServer):
|
||||
|
||||
def server_bind(self):
|
||||
def server_bind(self) -> None:
|
||||
super().server_bind()
|
||||
# Wrap the TCP socket in an SSL socket
|
||||
certfile = self.configuration.get("server", "certificate")
|
||||
keyfile = self.configuration.get("server", "key")
|
||||
cafile = self.configuration.get("server", "certificate_authority")
|
||||
certfile: str = self.configuration.get("server", "certificate")
|
||||
keyfile: str = self.configuration.get("server", "key")
|
||||
cafile: str = self.configuration.get("server", "certificate_authority")
|
||||
# Test if the files can be read
|
||||
for name, filename in [("certificate", certfile), ("key", keyfile),
|
||||
("certificate_authority", cafile)]:
|
||||
|
@ -139,7 +182,9 @@ class ParallelHTTPSServer(ParallelHTTPServer):
|
|||
self.socket = context.wrap_socket(
|
||||
self.socket, server_side=True, do_handshake_on_connect=False)
|
||||
|
||||
def finish_request_locked(self, request, client_address):
|
||||
def finish_request_locked( # type:ignore[override]
|
||||
self, request: ssl.SSLSocket, client_address: ADDRESS_TYPE
|
||||
) -> None:
|
||||
try:
|
||||
try:
|
||||
request.do_handshake()
|
||||
|
@ -151,7 +196,7 @@ class ParallelHTTPSServer(ParallelHTTPServer):
|
|||
try:
|
||||
self.handle_error(request, client_address)
|
||||
finally:
|
||||
self.shutdown_request(request)
|
||||
self.shutdown_request(request) # type:ignore[attr-defined]
|
||||
return
|
||||
return super().finish_request_locked(request, client_address)
|
||||
|
||||
|
@ -161,30 +206,34 @@ class ServerHandler(wsgiref.simple_server.ServerHandler):
|
|||
# Don't pollute WSGI environ with OS environment
|
||||
os_environ: MutableMapping[str, str] = {}
|
||||
|
||||
def log_exception(self, exc_info):
|
||||
def log_exception(self, exc_info: "wsgiref.handlers._exc_info") -> None:
|
||||
logger.error("An exception occurred during request: %s",
|
||||
exc_info[1], exc_info=exc_info)
|
||||
exc_info[1], exc_info=exc_info) # type:ignore[arg-type]
|
||||
|
||||
|
||||
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
||||
"""HTTP requests handler."""
|
||||
|
||||
def log_request(self, code="-", size="-"):
|
||||
# HACK: Assigned in `socketserver.StreamRequestHandler`
|
||||
connection: socket.socket
|
||||
|
||||
def log_request(self, code: Union[int, str] = "-",
|
||||
size: Union[int, str] = "-") -> None:
|
||||
pass # Disable request logging.
|
||||
|
||||
def log_error(self, format_, *args):
|
||||
def log_error(self, format_: str, *args: Any) -> None:
|
||||
logger.error("An error occurred during request: %s", format_ % args)
|
||||
|
||||
def get_environ(self):
|
||||
def get_environ(self) -> Dict[str, Any]:
|
||||
env = super().get_environ()
|
||||
if hasattr(self.connection, "getpeercert"):
|
||||
if isinstance(self.connection, ssl.SSLSocket):
|
||||
# The certificate can be evaluated by the auth module
|
||||
env["REMOTE_CERTIFICATE"] = self.connection.getpeercert()
|
||||
# Parent class only tries latin1 encoding
|
||||
env["PATH_INFO"] = unquote(self.path.split("?", 1)[0])
|
||||
return env
|
||||
|
||||
def handle(self):
|
||||
def handle(self) -> None:
|
||||
"""Copy of WSGIRequestHandler.handle with different ServerHandler"""
|
||||
|
||||
self.raw_requestline = self.rfile.readline(65537)
|
||||
|
@ -201,11 +250,13 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
|||
handler = ServerHandler(
|
||||
self.rfile, self.wfile, self.get_stderr(), self.get_environ()
|
||||
)
|
||||
handler.request_handler = self
|
||||
handler.run(self.server.get_app())
|
||||
handler.request_handler = self # type:ignore[attr-defined]
|
||||
app = self.server.get_app() # type:ignore[attr-defined]
|
||||
handler.run(app)
|
||||
|
||||
|
||||
def serve(configuration, shutdown_socket=None):
|
||||
def serve(configuration: config.Configuration,
|
||||
shutdown_socket: Optional[socket.socket] = None) -> None:
|
||||
"""Serve radicale from configuration.
|
||||
|
||||
`shutdown_socket` can be used to gracefully shutdown the server.
|
||||
|
@ -221,12 +272,13 @@ def serve(configuration, shutdown_socket=None):
|
|||
configuration.update({"server": {"_internal_server": "True"}}, "server",
|
||||
privileged=True)
|
||||
|
||||
use_ssl = configuration.get("server", "ssl")
|
||||
use_ssl: bool = configuration.get("server", "ssl")
|
||||
server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer
|
||||
application = Application(configuration)
|
||||
servers = {}
|
||||
try:
|
||||
for address in configuration.get("server", "hosts"):
|
||||
hosts: List[Tuple[str, int]] = configuration.get("server", "hosts")
|
||||
for address in hosts:
|
||||
# Try to bind sockets for IPv4 and IPv6
|
||||
possible_families = (socket.AF_INET, socket.AF_INET6)
|
||||
bind_ok = False
|
||||
|
@ -270,16 +322,16 @@ def serve(configuration, shutdown_socket=None):
|
|||
|
||||
# Mainloop
|
||||
select_timeout = None
|
||||
if os.name == "nt":
|
||||
if sys.platform == "win32":
|
||||
# Fallback to busy waiting. (select(...) blocks SIGINT on Windows.)
|
||||
select_timeout = 1.0
|
||||
max_connections = configuration.get("server", "max_connections")
|
||||
max_connections: int = configuration.get("server", "max_connections")
|
||||
logger.info("Radicale server ready")
|
||||
while True:
|
||||
rlist = []
|
||||
rlist: List[socket.socket] = []
|
||||
# Wait for finished clients
|
||||
for server in servers.values():
|
||||
rlist.extend(server.client_sockets)
|
||||
rlist.extend(server.worker_sockets)
|
||||
# Accept new connections if max_connections is not reached
|
||||
if max_connections <= 0 or len(rlist) < max_connections:
|
||||
rlist.extend(servers)
|
||||
|
@ -287,26 +339,26 @@ def serve(configuration, shutdown_socket=None):
|
|||
if shutdown_socket is not None:
|
||||
rlist.append(shutdown_socket)
|
||||
rlist, _, _ = select.select(rlist, [], [], select_timeout)
|
||||
rlist = set(rlist)
|
||||
if shutdown_socket in rlist:
|
||||
rset = set(rlist)
|
||||
if shutdown_socket in rset:
|
||||
logger.info("Stopping Radicale")
|
||||
break
|
||||
for server in servers.values():
|
||||
finished_sockets = server.client_sockets.intersection(rlist)
|
||||
finished_sockets = server.worker_sockets.intersection(rset)
|
||||
for s in finished_sockets:
|
||||
s.close()
|
||||
server.client_sockets.remove(s)
|
||||
rlist.remove(s)
|
||||
server.worker_sockets.remove(s)
|
||||
rset.remove(s)
|
||||
if finished_sockets:
|
||||
server.service_actions()
|
||||
if rlist:
|
||||
server = servers.get(rlist.pop())
|
||||
if server:
|
||||
server.handle_request()
|
||||
if rset:
|
||||
active_server = servers.get(rset.pop())
|
||||
if active_server:
|
||||
active_server.handle_request()
|
||||
finally:
|
||||
# Wait for clients to finish and close servers
|
||||
for server in servers.values():
|
||||
for s in server.client_sockets:
|
||||
for s in server.worker_sockets:
|
||||
s.recv(1)
|
||||
s.close()
|
||||
server.server_close()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue