1
0
Fork 0
mirror of https://github.com/Kozea/Radicale.git synced 2025-08-01 18:18:31 +00:00

More type hints

This commit is contained in:
Unrud 2021-07-26 20:56:46 +02:00 committed by Unrud
parent 12fe5ce637
commit cecb17df03
51 changed files with 1374 additions and 957 deletions

View file

@ -27,27 +27,35 @@ import binascii
import math
import os
import sys
from datetime import timedelta
from datetime import datetime, timedelta
from hashlib import sha256
from typing import (Any, Callable, List, MutableMapping, Optional, Sequence,
Tuple)
import vobject
from radicale import storage # noqa:F401
from radicale import pathutils
from radicale.item import filter as radicale_filter
from radicale.log import logger
def predict_tag_of_parent_collection(vobject_items):
def predict_tag_of_parent_collection(
vobject_items: Sequence[vobject.base.Component]) -> Optional[str]:
"""Returns the predicted tag or `None`"""
if len(vobject_items) != 1:
return ""
return None
if vobject_items[0].name == "VCALENDAR":
return "VCALENDAR"
if vobject_items[0].name in ("VCARD", "VLIST"):
return "VADDRESSBOOK"
return ""
return None
def predict_tag_of_whole_collection(vobject_items, fallback_tag=None):
def predict_tag_of_whole_collection(
vobject_items: Sequence[vobject.base.Component],
fallback_tag: Optional[str] = None) -> Optional[str]:
"""Returns the predicted tag or `fallback_tag`"""
if vobject_items and vobject_items[0].name == "VCALENDAR":
return "VCALENDAR"
if vobject_items and vobject_items[0].name in ("VCARD", "VLIST"):
@ -58,9 +66,13 @@ def predict_tag_of_whole_collection(vobject_items, fallback_tag=None):
return fallback_tag
def check_and_sanitize_items(vobject_items, is_collection=False, tag=None):
def check_and_sanitize_items(
vobject_items: List[vobject.base.Component],
is_collection: bool = False, tag: str = "") -> None:
"""Check vobject items for common errors and add missing UIDs.
Modifies the list `vobject_items`.
``is_collection`` indicates that vobject_item contains unrelated
components.
@ -169,9 +181,14 @@ def check_and_sanitize_items(vobject_items, is_collection=False, tag=None):
(i.name, repr(tag) if tag else "generic"))
def check_and_sanitize_props(props):
"""Check collection properties for common errors."""
for k, v in props.copy().items(): # Make copy to be able to delete items
def check_and_sanitize_props(props: MutableMapping[Any, Any]
) -> MutableMapping[str, str]:
"""Check collection properties for common errors.
Modifies the dict `props`.
"""
for k, v in list(props.items()): # Make copy to be able to delete items
if not isinstance(k, str):
raise ValueError("Key must be %r not %r: %r" % (
str.__name__, type(k).__name__, k))
@ -182,14 +199,13 @@ def check_and_sanitize_props(props):
raise ValueError("Value of %r must be %r not %r: %r" % (
k, str.__name__, type(v).__name__, v))
if k == "tag":
if not v:
del props[k]
continue
if v not in ("VCALENDAR", "VADDRESSBOOK"):
if v not in ("", "VCALENDAR", "VADDRESSBOOK"):
raise ValueError("Unsupported collection tag: %r" % v)
return props
def find_available_uid(exists_fn, suffix=""):
def find_available_uid(exists_fn: Callable[[str], bool], suffix: str = ""
) -> str:
"""Generate a pseudo-random UID"""
# Prevent infinite loop
for _ in range(1000):
@ -202,7 +218,7 @@ def find_available_uid(exists_fn, suffix=""):
raise RuntimeError("No unique random sequence found")
def get_etag(text):
def get_etag(text: str) -> str:
"""Etag from collection or item.
Encoded as quoted-string (see RFC 2616).
@ -213,13 +229,13 @@ def get_etag(text):
return '"%s"' % etag.hexdigest()
def get_uid(vobject_component):
def get_uid(vobject_component: vobject.base.Component) -> str:
"""UID value of an item if defined."""
return (vobject_component.uid.value
if hasattr(vobject_component, "uid") else None)
return (vobject_component.uid.value or ""
if hasattr(vobject_component, "uid") else "")
def get_uid_from_object(vobject_item):
def get_uid_from_object(vobject_item: vobject.base.Component) -> str:
"""UID value of an calendar/addressbook object."""
if vobject_item.name == "VCALENDAR":
if hasattr(vobject_item, "vevent"):
@ -230,10 +246,10 @@ def get_uid_from_object(vobject_item):
return get_uid(vobject_item.vtodo)
elif vobject_item.name == "VCARD":
return get_uid(vobject_item)
return None
return ""
def find_tag(vobject_item):
def find_tag(vobject_item: vobject.base.Component) -> str:
"""Find component name from ``vobject_item``."""
if vobject_item.name == "VCALENDAR":
for component in vobject_item.components():
@ -242,22 +258,24 @@ def find_tag(vobject_item):
return ""
def find_tag_and_time_range(vobject_item):
"""Find component name and enclosing time range from ``vobject item``.
def find_time_range(vobject_item: vobject.base.Component, tag: str
) -> Tuple[int, int]:
"""Find enclosing time range from ``vobject item``.
Returns a tuple (``tag``, ``start``, ``end``) where ``tag`` is a string
and ``start`` and ``end`` are POSIX timestamps (as int).
``tag`` must be set to the return value of ``find_tag``.
Returns a tuple (``start``, ``end``) where ``start`` and ``end`` are
POSIX timestamps.
This is intened to be used for matching against simplified prefilters.
"""
tag = find_tag(vobject_item)
if not tag:
return (
tag, radicale_filter.TIMESTAMP_MIN, radicale_filter.TIMESTAMP_MAX)
return radicale_filter.TIMESTAMP_MIN, radicale_filter.TIMESTAMP_MAX
start = end = None
def range_fn(range_start, range_end, is_recurrence):
def range_fn(range_start: datetime, range_end: datetime,
is_recurrence: bool) -> bool:
nonlocal start, end
if start is None or range_start < start:
start = range_start
@ -265,7 +283,7 @@ def find_tag_and_time_range(vobject_item):
end = range_end
return False
def infinity_fn(range_start):
def infinity_fn(range_start: datetime) -> bool:
nonlocal start, end
if start is None or range_start < start:
start = range_start
@ -278,7 +296,7 @@ def find_tag_and_time_range(vobject_item):
if end is None:
end = radicale_filter.DATETIME_MAX
try:
return tag, math.floor(start.timestamp()), math.ceil(end.timestamp())
return math.floor(start.timestamp()), math.ceil(end.timestamp())
except ValueError as e:
if str(e) == ("offset must be a timedelta representing a whole "
"number of minutes") and sys.version_info < (3, 6):
@ -289,10 +307,31 @@ def find_tag_and_time_range(vobject_item):
class Item:
"""Class for address book and calendar entries."""
def __init__(self, collection_path=None, collection=None,
vobject_item=None, href=None, last_modified=None, text=None,
etag=None, uid=None, name=None, component_name=None,
time_range=None):
collection: Optional["storage.BaseCollection"]
href: Optional[str]
last_modified: Optional[str]
_collection_path: str
_text: Optional[str]
_vobject_item: Optional[vobject.base.Component]
_etag: Optional[str]
_uid: Optional[str]
_name: Optional[str]
_component_name: Optional[str]
_time_range: Optional[Tuple[int, int]]
def __init__(self,
collection_path: Optional[str] = None,
collection: Optional["storage.BaseCollection"] = None,
vobject_item: Optional[vobject.base.Component] = None,
href: Optional[str] = None,
last_modified: Optional[str] = None,
text: Optional[str] = None,
etag: Optional[str] = None,
uid: Optional[str] = None,
name: Optional[str] = None,
component_name: Optional[str] = None,
time_range: Optional[Tuple[int, int]] = None):
"""Initialize an item.
``collection_path`` the path of the parent collection (optional if
@ -318,8 +357,7 @@ class Item:
``component_name`` the name of the primary component (optional).
See ``find_tag``.
``time_range`` the enclosing time range.
See ``find_tag_and_time_range``.
``time_range`` the enclosing time range. See ``find_time_range``.
"""
if text is None and vobject_item is None:
@ -344,7 +382,7 @@ class Item:
self._component_name = component_name
self._time_range = time_range
def serialize(self):
def serialize(self) -> str:
if self._text is None:
try:
self._text = self.vobject_item.serialize()
@ -366,38 +404,38 @@ class Item:
return self._vobject_item
@property
def etag(self):
def etag(self) -> str:
"""Encoded as quoted-string (see RFC 2616)."""
if self._etag is None:
self._etag = get_etag(self.serialize())
return self._etag
@property
def uid(self):
def uid(self) -> str:
if self._uid is None:
self._uid = get_uid_from_object(self.vobject_item)
return self._uid
@property
def name(self):
def name(self) -> str:
if self._name is None:
self._name = self.vobject_item.name or ""
return self._name
@property
def component_name(self):
if self._component_name is not None:
return self._component_name
return find_tag(self.vobject_item)
def component_name(self) -> str:
if self._component_name is None:
self._component_name = find_tag(self.vobject_item)
return self._component_name
@property
def time_range(self):
def time_range(self) -> Tuple[int, int]:
if self._time_range is None:
self._component_name, *self._time_range = (
find_tag_and_time_range(self.vobject_item))
self._time_range = find_time_range(
self.vobject_item, self.component_name)
return self._time_range
def prepare(self):
def prepare(self) -> None:
"""Fill cache with values."""
orig_vobject_item = self._vobject_item
self.serialize()

View file

@ -19,35 +19,40 @@
import math
import xml.etree.ElementTree as ET
from datetime import date, datetime, timedelta, timezone
from itertools import chain
from typing import (Callable, Iterable, Iterator, List, Optional, Sequence,
Tuple)
from radicale import xmlutils
import vobject
from radicale import item, xmlutils
from radicale.log import logger
DAY = timedelta(days=1)
SECOND = timedelta(seconds=1)
DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)
DATETIME_MAX = datetime.max.replace(tzinfo=timezone.utc)
TIMESTAMP_MIN = math.floor(DATETIME_MIN.timestamp())
TIMESTAMP_MAX = math.ceil(DATETIME_MAX.timestamp())
DAY: timedelta = timedelta(days=1)
SECOND: timedelta = timedelta(seconds=1)
DATETIME_MIN: datetime = datetime.min.replace(tzinfo=timezone.utc)
DATETIME_MAX: datetime = datetime.max.replace(tzinfo=timezone.utc)
TIMESTAMP_MIN: int = math.floor(DATETIME_MIN.timestamp())
TIMESTAMP_MAX: int = math.ceil(DATETIME_MAX.timestamp())
def date_to_datetime(date_):
"""Transform a date to a UTC datetime.
def date_to_datetime(d: date) -> datetime:
"""Transform any date to a UTC datetime.
If date_ is a datetime without timezone, return as UTC datetime. If date_
If ``d`` is a datetime without timezone, return as UTC datetime. If ``d``
is already a datetime with timezone, return as is.
"""
if not isinstance(date_, datetime):
date_ = datetime.combine(date_, datetime.min.time())
if not date_.tzinfo:
date_ = date_.replace(tzinfo=timezone.utc)
return date_
if not isinstance(d, datetime):
d = datetime.combine(d, datetime.min.time())
if not d.tzinfo:
d = d.replace(tzinfo=timezone.utc)
return d
def comp_match(item, filter_, level=0):
def comp_match(item: "item.Item", filter_: ET.Element, level: int = 0) -> bool:
"""Check whether the ``item`` matches the comp ``filter_``.
If ``level`` is ``0``, the filter is applied on the
@ -70,7 +75,7 @@ def comp_match(item, filter_, level=0):
return True
if not tag:
return False
name = filter_.get("name").upper()
name = filter_.get("name", "").upper()
if len(filter_) == 0:
# Point #1 of rfc4791-9.7.1
return name == tag
@ -104,13 +109,14 @@ def comp_match(item, filter_, level=0):
return True
def prop_match(vobject_item, filter_, ns):
def prop_match(vobject_item: vobject.base.Component,
filter_: ET.Element, ns: str) -> bool:
"""Check whether the ``item`` matches the prop ``filter_``.
See rfc4791-9.7.2 and rfc6352-10.5.1.
"""
name = filter_.get("name").lower()
name = filter_.get("name", "").lower()
if len(filter_) == 0:
# Point #1 of rfc4791-9.7.2
return name in vobject_item.contents
@ -136,20 +142,21 @@ def prop_match(vobject_item, filter_, ns):
return True
def time_range_match(vobject_item, filter_, child_name):
def time_range_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str) -> bool:
"""Check whether the component/property ``child_name`` of
``vobject_item`` matches the time-range ``filter_``."""
start = filter_.get("start")
end = filter_.get("end")
if not start and not end:
start_text = filter_.get("start")
end_text = filter_.get("end")
if not start_text and not end_text:
return False
if start:
start = datetime.strptime(start, "%Y%m%dT%H%M%SZ")
if start_text:
start = datetime.strptime(start_text, "%Y%m%dT%H%M%SZ")
else:
start = datetime.min
if end:
end = datetime.strptime(end, "%Y%m%dT%H%M%SZ")
if end_text:
end = datetime.strptime(end_text, "%Y%m%dT%H%M%SZ")
else:
end = datetime.max
start = start.replace(tzinfo=timezone.utc)
@ -157,7 +164,8 @@ def time_range_match(vobject_item, filter_, child_name):
matched = False
def range_fn(range_start, range_end, is_recurrence):
def range_fn(range_start: datetime, range_end: datetime,
is_recurrence: bool) -> bool:
nonlocal matched
if start < range_end and range_start < end:
matched = True
@ -166,14 +174,16 @@ def time_range_match(vobject_item, filter_, child_name):
return True
return False
def infinity_fn(start):
def infinity_fn(start: datetime) -> bool:
return False
visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn)
return matched
def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
def visit_time_ranges(vobject_item: vobject.base.Component, child_name: str,
range_fn: Callable[[datetime, datetime, bool], bool],
infinity_fn: Callable[[datetime], bool]) -> None:
"""Visit all time ranges in the component/property ``child_name`` of
`vobject_item`` with visitors ``range_fn`` and ``infinity_fn``.
@ -194,7 +204,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
# recurrences too. This is not respected and client don't seem to bother
# either.
def getrruleset(child, ignore=()):
def getrruleset(child: vobject.base.Component, ignore: Sequence[date]
) -> Tuple[Iterable[date], bool]:
if (hasattr(child, "rrule") and
";UNTIL=" not in child.rrule.value.upper() and
";COUNT=" not in child.rrule.value.upper()):
@ -207,7 +218,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
return filter(lambda dtstart: dtstart not in ignore,
child.getrruleset(addRDate=True)), False
def get_children(components):
def get_children(components: Iterable[vobject.base.Component]) -> Iterator[
Tuple[vobject.base.Component, bool, List[date]]]:
main = None
recurrences = []
for comp in components:
@ -216,7 +228,7 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
if comp.rruleset:
# Prevent possible infinite loop
raise ValueError("Overwritten recurrence with RRULESET")
yield comp, True, ()
yield comp, True, []
else:
if main is not None:
raise ValueError("Multiple main components")
@ -418,7 +430,9 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
range_fn(child, child + DAY, False)
def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
def text_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str, ns: str,
attrib_name: Optional[str] = None) -> bool:
"""Check whether the ``item`` matches the text-match ``filter_``.
See rfc4791-9.7.5.
@ -432,7 +446,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
if ns == "CR":
match_type = filter_.get("match-type", match_type)
def match(value):
def match(value: str) -> bool:
value = value.lower()
if match_type == "equals":
return value == text
@ -445,7 +459,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
raise ValueError("Unexpected text-match match-type: %r" % match_type)
children = getattr(vobject_item, "%s_list" % child_name, [])
if attrib_name:
if attrib_name is not None:
condition = any(
match(attrib) for child in children
for attrib in child.params.get(attrib_name, []))
@ -456,13 +470,14 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
return condition
def param_filter_match(vobject_item, filter_, parent_name, ns):
def param_filter_match(vobject_item: vobject.base.Component,
filter_: ET.Element, parent_name: str, ns: str) -> bool:
"""Check whether the ``item`` matches the param-filter ``filter_``.
See rfc4791-9.7.3.
"""
name = filter_.get("name").upper()
name = filter_.get("name", "").upper()
children = getattr(vobject_item, "%s_list" % parent_name, [])
condition = any(name in child.params for child in children)
if len(filter_) > 0:
@ -474,7 +489,8 @@ def param_filter_match(vobject_item, filter_, parent_name, ns):
return condition
def simplify_prefilters(filters, collection_tag="VCALENDAR"):
def simplify_prefilters(filters: Iterable[ET.Element], collection_tag: str
) -> Tuple[Optional[str], int, int, bool]:
"""Creates a simplified condition from ``filters``.
Returns a tuple (``tag``, ``start``, ``end``, ``simple``) where ``tag`` is
@ -483,14 +499,14 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
and the simplified condition are identical.
"""
flat_filters = tuple(chain.from_iterable(filters))
flat_filters = list(chain.from_iterable(filters))
simple = len(flat_filters) <= 1
for col_filter in flat_filters:
if collection_tag != "VCALENDAR":
simple = False
break
if (col_filter.tag != xmlutils.make_clark("C:comp-filter") or
col_filter.get("name").upper() != "VCALENDAR"):
col_filter.get("name", "").upper() != "VCALENDAR"):
simple = False
continue
simple &= len(col_filter) <= 1
@ -498,7 +514,7 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if comp_filter.tag != xmlutils.make_clark("C:comp-filter"):
simple = False
continue
tag = comp_filter.get("name").upper()
tag = comp_filter.get("name", "").upper()
if comp_filter.find(
xmlutils.make_clark("C:is-not-defined")) is not None:
simple = False
@ -511,17 +527,17 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if time_filter.tag != xmlutils.make_clark("C:time-range"):
simple = False
continue
start = time_filter.get("start")
end = time_filter.get("end")
if start:
start_text = time_filter.get("start")
end_text = time_filter.get("end")
if start_text:
start = math.floor(datetime.strptime(
start, "%Y%m%dT%H%M%SZ").replace(
start_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp())
else:
start = TIMESTAMP_MIN
if end:
if end_text:
end = math.ceil(datetime.strptime(
end, "%Y%m%dT%H%M%SZ").replace(
end_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp())
else:
end = TIMESTAMP_MAX