1
0
Fork 0
mirror of https://github.com/Kozea/Radicale.git synced 2025-08-10 18:40:53 +00:00

More type hints

This commit is contained in:
Unrud 2021-07-26 20:56:46 +02:00 committed by Unrud
parent 12fe5ce637
commit cecb17df03
51 changed files with 1374 additions and 957 deletions

View file

@ -19,35 +19,40 @@
import math
import xml.etree.ElementTree as ET
from datetime import date, datetime, timedelta, timezone
from itertools import chain
from typing import (Callable, Iterable, Iterator, List, Optional, Sequence,
Tuple)
from radicale import xmlutils
import vobject
from radicale import item, xmlutils
from radicale.log import logger
DAY = timedelta(days=1)
SECOND = timedelta(seconds=1)
DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc)
DATETIME_MAX = datetime.max.replace(tzinfo=timezone.utc)
TIMESTAMP_MIN = math.floor(DATETIME_MIN.timestamp())
TIMESTAMP_MAX = math.ceil(DATETIME_MAX.timestamp())
DAY: timedelta = timedelta(days=1)
SECOND: timedelta = timedelta(seconds=1)
DATETIME_MIN: datetime = datetime.min.replace(tzinfo=timezone.utc)
DATETIME_MAX: datetime = datetime.max.replace(tzinfo=timezone.utc)
TIMESTAMP_MIN: int = math.floor(DATETIME_MIN.timestamp())
TIMESTAMP_MAX: int = math.ceil(DATETIME_MAX.timestamp())
def date_to_datetime(date_):
"""Transform a date to a UTC datetime.
def date_to_datetime(d: date) -> datetime:
"""Transform any date to a UTC datetime.
If date_ is a datetime without timezone, return as UTC datetime. If date_
If ``d`` is a datetime without timezone, return as UTC datetime. If ``d``
is already a datetime with timezone, return as is.
"""
if not isinstance(date_, datetime):
date_ = datetime.combine(date_, datetime.min.time())
if not date_.tzinfo:
date_ = date_.replace(tzinfo=timezone.utc)
return date_
if not isinstance(d, datetime):
d = datetime.combine(d, datetime.min.time())
if not d.tzinfo:
d = d.replace(tzinfo=timezone.utc)
return d
def comp_match(item, filter_, level=0):
def comp_match(item: "item.Item", filter_: ET.Element, level: int = 0) -> bool:
"""Check whether the ``item`` matches the comp ``filter_``.
If ``level`` is ``0``, the filter is applied on the
@ -70,7 +75,7 @@ def comp_match(item, filter_, level=0):
return True
if not tag:
return False
name = filter_.get("name").upper()
name = filter_.get("name", "").upper()
if len(filter_) == 0:
# Point #1 of rfc4791-9.7.1
return name == tag
@ -104,13 +109,14 @@ def comp_match(item, filter_, level=0):
return True
def prop_match(vobject_item, filter_, ns):
def prop_match(vobject_item: vobject.base.Component,
filter_: ET.Element, ns: str) -> bool:
"""Check whether the ``item`` matches the prop ``filter_``.
See rfc4791-9.7.2 and rfc6352-10.5.1.
"""
name = filter_.get("name").lower()
name = filter_.get("name", "").lower()
if len(filter_) == 0:
# Point #1 of rfc4791-9.7.2
return name in vobject_item.contents
@ -136,20 +142,21 @@ def prop_match(vobject_item, filter_, ns):
return True
def time_range_match(vobject_item, filter_, child_name):
def time_range_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str) -> bool:
"""Check whether the component/property ``child_name`` of
``vobject_item`` matches the time-range ``filter_``."""
start = filter_.get("start")
end = filter_.get("end")
if not start and not end:
start_text = filter_.get("start")
end_text = filter_.get("end")
if not start_text and not end_text:
return False
if start:
start = datetime.strptime(start, "%Y%m%dT%H%M%SZ")
if start_text:
start = datetime.strptime(start_text, "%Y%m%dT%H%M%SZ")
else:
start = datetime.min
if end:
end = datetime.strptime(end, "%Y%m%dT%H%M%SZ")
if end_text:
end = datetime.strptime(end_text, "%Y%m%dT%H%M%SZ")
else:
end = datetime.max
start = start.replace(tzinfo=timezone.utc)
@ -157,7 +164,8 @@ def time_range_match(vobject_item, filter_, child_name):
matched = False
def range_fn(range_start, range_end, is_recurrence):
def range_fn(range_start: datetime, range_end: datetime,
is_recurrence: bool) -> bool:
nonlocal matched
if start < range_end and range_start < end:
matched = True
@ -166,14 +174,16 @@ def time_range_match(vobject_item, filter_, child_name):
return True
return False
def infinity_fn(start):
def infinity_fn(start: datetime) -> bool:
return False
visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn)
return matched
def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
def visit_time_ranges(vobject_item: vobject.base.Component, child_name: str,
range_fn: Callable[[datetime, datetime, bool], bool],
infinity_fn: Callable[[datetime], bool]) -> None:
"""Visit all time ranges in the component/property ``child_name`` of
`vobject_item`` with visitors ``range_fn`` and ``infinity_fn``.
@ -194,7 +204,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
# recurrences too. This is not respected and client don't seem to bother
# either.
def getrruleset(child, ignore=()):
def getrruleset(child: vobject.base.Component, ignore: Sequence[date]
) -> Tuple[Iterable[date], bool]:
if (hasattr(child, "rrule") and
";UNTIL=" not in child.rrule.value.upper() and
";COUNT=" not in child.rrule.value.upper()):
@ -207,7 +218,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
return filter(lambda dtstart: dtstart not in ignore,
child.getrruleset(addRDate=True)), False
def get_children(components):
def get_children(components: Iterable[vobject.base.Component]) -> Iterator[
Tuple[vobject.base.Component, bool, List[date]]]:
main = None
recurrences = []
for comp in components:
@ -216,7 +228,7 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
if comp.rruleset:
# Prevent possible infinite loop
raise ValueError("Overwritten recurrence with RRULESET")
yield comp, True, ()
yield comp, True, []
else:
if main is not None:
raise ValueError("Multiple main components")
@ -418,7 +430,9 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
range_fn(child, child + DAY, False)
def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
def text_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str, ns: str,
attrib_name: Optional[str] = None) -> bool:
"""Check whether the ``item`` matches the text-match ``filter_``.
See rfc4791-9.7.5.
@ -432,7 +446,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
if ns == "CR":
match_type = filter_.get("match-type", match_type)
def match(value):
def match(value: str) -> bool:
value = value.lower()
if match_type == "equals":
return value == text
@ -445,7 +459,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
raise ValueError("Unexpected text-match match-type: %r" % match_type)
children = getattr(vobject_item, "%s_list" % child_name, [])
if attrib_name:
if attrib_name is not None:
condition = any(
match(attrib) for child in children
for attrib in child.params.get(attrib_name, []))
@ -456,13 +470,14 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
return condition
def param_filter_match(vobject_item, filter_, parent_name, ns):
def param_filter_match(vobject_item: vobject.base.Component,
filter_: ET.Element, parent_name: str, ns: str) -> bool:
"""Check whether the ``item`` matches the param-filter ``filter_``.
See rfc4791-9.7.3.
"""
name = filter_.get("name").upper()
name = filter_.get("name", "").upper()
children = getattr(vobject_item, "%s_list" % parent_name, [])
condition = any(name in child.params for child in children)
if len(filter_) > 0:
@ -474,7 +489,8 @@ def param_filter_match(vobject_item, filter_, parent_name, ns):
return condition
def simplify_prefilters(filters, collection_tag="VCALENDAR"):
def simplify_prefilters(filters: Iterable[ET.Element], collection_tag: str
) -> Tuple[Optional[str], int, int, bool]:
"""Creates a simplified condition from ``filters``.
Returns a tuple (``tag``, ``start``, ``end``, ``simple``) where ``tag`` is
@ -483,14 +499,14 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
and the simplified condition are identical.
"""
flat_filters = tuple(chain.from_iterable(filters))
flat_filters = list(chain.from_iterable(filters))
simple = len(flat_filters) <= 1
for col_filter in flat_filters:
if collection_tag != "VCALENDAR":
simple = False
break
if (col_filter.tag != xmlutils.make_clark("C:comp-filter") or
col_filter.get("name").upper() != "VCALENDAR"):
col_filter.get("name", "").upper() != "VCALENDAR"):
simple = False
continue
simple &= len(col_filter) <= 1
@ -498,7 +514,7 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if comp_filter.tag != xmlutils.make_clark("C:comp-filter"):
simple = False
continue
tag = comp_filter.get("name").upper()
tag = comp_filter.get("name", "").upper()
if comp_filter.find(
xmlutils.make_clark("C:is-not-defined")) is not None:
simple = False
@ -511,17 +527,17 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if time_filter.tag != xmlutils.make_clark("C:time-range"):
simple = False
continue
start = time_filter.get("start")
end = time_filter.get("end")
if start:
start_text = time_filter.get("start")
end_text = time_filter.get("end")
if start_text:
start = math.floor(datetime.strptime(
start, "%Y%m%dT%H%M%SZ").replace(
start_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp())
else:
start = TIMESTAMP_MIN
if end:
if end_text:
end = math.ceil(datetime.strptime(
end, "%Y%m%dT%H%M%SZ").replace(
end_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp())
else:
end = TIMESTAMP_MAX