• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ultimate-notion / ultimate-notion / 15740234214

18 Jun 2025 06:04PM CUT coverage: 89.532% (+0.1%) from 89.392%
15740234214

Pull #79

github

web-flow
Merge cfb4f1aab into 949e4df22
Pull Request #79: Change the attributes of database property types.

66 of 71 new or added lines in 4 files covered. (92.96%)

39 existing lines in 3 files now uncovered.

5568 of 6219 relevant lines covered (89.53%)

5.34 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.65
/src/ultimate_notion/utils.py
1
"""Additional utilities that fit nowhere else."""
2

3
from __future__ import annotations
6✔
4

5
import datetime as dt
6✔
6
import re
6✔
7
import textwrap
6✔
8
from collections.abc import Callable, Generator, Mapping, Sequence
6✔
9
from contextlib import contextmanager
6✔
10
from copy import deepcopy
6✔
11
from functools import wraps
6✔
12
from hashlib import sha256
6✔
13
from itertools import chain
6✔
14
from pathlib import Path
6✔
15
from typing import Any, TypeVar
6✔
16

17
import numpy as np
6✔
18
import pendulum as pnd
6✔
19
from packaging.version import Version
6✔
20
from pydantic import BaseModel
6✔
21

22
from ultimate_notion import __version__
6✔
23
from ultimate_notion.errors import EmptyListError, MultipleItemsError
6✔
24

25
# Generic element type shared by `SList` and the sequence helpers in this module.
T = TypeVar('T')  # ToDo: Use new syntax when requires-python >= 3.12
6✔
26

27

28
class SList(list[T]):
6✔
29
    """A list that holds often only a single element."""
30

31
    def item(self) -> T:
6✔
32
        if len(self) == 1:
6✔
33
            return self[0]
6✔
34
        elif len(self) == 0:
6✔
35
            msg = 'list is empty'
6✔
36
            raise EmptyListError(msg)
6✔
37
        else:
38
            msg = 'list has multiple items'
6✔
39
            raise MultipleItemsError(msg)
6✔
40

41

42
def flatten(nested_list: Sequence[Sequence[T]], /) -> list[T]:
    """Concatenate the inner sequences of `nested_list` into one flat list (one level deep only)."""
    return [element for inner in nested_list for element in inner]
45

46

47
def safe_list_get(lst: Sequence[T], idx: int, *, default: T | None = None) -> T | None:
    """Return `lst[idx]`, or `default` if indexing raises an `IndexError`.

    Negative indices behave as usual for sequences.
    """
    try:
        element = lst[idx]
    except IndexError:
        element = default
    return element
53

54

55
def is_notebook() -> bool:
    """Tell whether the code runs inside a Jupyter notebook (or qtconsole).

    Returns `False` if IPython is not installed, for a terminal IPython session,
    and for any other kind of shell.
    """
    try:
        from IPython import get_ipython  # noqa: PLC0415
    except ModuleNotFoundError:
        return False  # IPython not installed, so most likely the standard Python interpreter

    # Only the ZMQ-based kernel shell (Jupyter notebook / qtconsole) counts as a notebook
    shell_name = get_ipython().__class__.__name__
    return shell_name == 'ZMQInteractiveShell'
69

70

71
def store_retvals(func):
    """Wrap `func` so that every value it returns is collected in the list `wrapped.retvals`.

    Handy for pytest fixtures that hand out a factory and must clean up whatever
    the factory created once the test is done:
    ```
    @pytest.fixture
    def create_blank_db(notion, test_area):
        @store_retvals
        def nested_func(db_name):
            db = notion.databases.create(
                parent=test_area,
                title=db_name,
                schema={
                    'Name': schema.Title(),
                },
            )
            return db

        yield nested_func

        # clean up by deleting the db of each prior call
        for db in nested_func.retvals:
            notion.databases.delete(db)
    ```
    """

    def wrapped(*args, **kwargs):
        result = func(*args, **kwargs)
        # Look the list up on the wrapper at call time so a reassigned `retvals` is honored
        wrapped.retvals.append(result)
        return result

    # Copy name, docstring etc. of `func` onto the wrapper (returns the same function object)
    wrapped = wraps(func)(wrapped)
    wrapped.retvals = []
    return wrapped
105

106

107
def find_indices(elements: np.ndarray | Sequence[Any], total_set: np.ndarray | Sequence[Any]) -> np.ndarray:
6✔
108
    """Finds the indices of the elements in the total set."""
109
    if not isinstance(total_set, np.ndarray):
6✔
110
        total_set = np.array(total_set)
×
111
    mask = np.isin(total_set, elements)
6✔
112
    indices = np.where(mask)[0]
6✔
113
    lookup = dict(zip(total_set[mask], indices, strict=True))
6✔
114
    result = np.array([lookup.get(x) for x in elements])
6✔
115
    return result
6✔
116

117

118
def find_index(elem: Any, lst: list[Any]) -> int | None:
    """Return the position of the first occurrence of `elem` in `lst`, or `None` if absent."""
    try:
        position = lst.index(elem)
    except ValueError:
        position = None
    return position
124

125

126
def deepcopy_with_sharing(obj: Any, shared_attributes: Sequence[str], memo: dict[int, Any] | None = None):
    """Like `deepcopy` but specified attributes are shared.

    Deepcopy an object, except for a given list of attributes, which should
    be shared between the original object and its copy. The original object is
    always restored, even if the deep copy fails.

    Args:
        obj: some object to copy; the shared attributes must be stored in its instance `__dict__`.
        shared_attributes: A list of strings identifying the attributes that should be shared instead of copied.
        memo: dictionary passed into __deepcopy__.  Ignore this argument if not calling from within __deepcopy__.

    Returns:
        The copy of `obj`, whose shared attributes are the very same objects as those of `obj`.

    Raises:
        AttributeError: If `obj` lacks one of the shared attributes.
        KeyError: If a shared attribute exists but is not stored in the instance `__dict__` of `obj`.

    Example:
        ```python
        class A(object):
            def __init__(self):
                self.copy_me = []
                self.share_me = []

            def __deepcopy__(self, memo):
                return deepcopy_with_sharing(
                    self, shared_attributes=['share_me'], memo=memo
                )


        a = A()
        b = deepcopy(a)
        assert a.copy_me is not b.copy_me
        assert a.share_me is b.share_me

        c = deepcopy(b)
        assert c.copy_me is not b.copy_me
        assert c.share_me is b.share_me
        ```

    Original from https://stackoverflow.com/a/24621200
    """
    shared_attrs = {k: getattr(obj, k) for k in shared_attributes}

    deepcopy_defined = hasattr(obj, '__deepcopy__')
    if deepcopy_defined:
        # Do hack to prevent infinite recursion in call to deepcopy: with `__deepcopy__`
        # set to `None` on the instance, `deepcopy` falls back to its default copying.
        deepcopy_on_instance = '__deepcopy__' in obj.__dict__
        deepcopy_method = obj.__deepcopy__
        obj.__deepcopy__ = None

    removed: list[str] = []
    try:
        # Temporarily hide the shared attributes so that `deepcopy` does not copy them
        for attr in shared_attributes:
            del obj.__dict__[attr]
            removed.append(attr)
        clone = deepcopy(obj, memo)
    finally:
        # Undo the hack on the original object, even if copying failed
        for attr in removed:
            setattr(obj, attr, shared_attrs[attr])
        if deepcopy_defined:
            if deepcopy_on_instance:
                obj.__deepcopy__ = deepcopy_method
            else:
                # Drop the temporary instance attribute so the class-level method applies again,
                # instead of leaving a bound method of `obj` in its `__dict__`.
                del obj.__deepcopy__

    for attr, val in shared_attrs.items():
        setattr(clone, attr, val)
    if deepcopy_defined:
        # The copy presumably carried over `__deepcopy__ = None` from the original's `__dict__`
        del clone.__deepcopy__

    return clone
185

186

187
# Generic key type of the mappings compared by `dict_diff` and `dict_diff_str`.
KT = TypeVar('KT')  # ToDo: Use new syntax when requires-python >= 3.12
6✔
188
# Generic value type of the mappings compared by `dict_diff` and `dict_diff_str`.
VT = TypeVar('VT')  # ToDo: Use new syntax when requires-python >= 3.12
6✔
189

190

191
def dict_diff(dct1: Mapping[KT, VT], dct2: Mapping[KT, VT]) -> tuple[list[KT], list[KT], dict[KT, tuple[VT, VT]]]:
    """Compare two mappings key by key.

    Returns:
        A tuple `(added, removed, changed)`: `added` lists the keys only present in `dct2`,
        `removed` lists the keys only present in `dct1`, and `changed` maps every common key
        whose value differs to the pair `(value_in_dct1, value_in_dct2)`.
        The order of the key lists is arbitrary.
    """
    old_keys = set(dct1)
    new_keys = set(dct2)

    changed = {}
    for key in old_keys.intersection(new_keys):
        old_value, new_value = dct1[key], dct2[key]
        if old_value != new_value:
            changed[key] = (old_value, new_value)

    return list(new_keys.difference(old_keys)), list(old_keys.difference(new_keys)), changed
198

199

200
def dict_diff_str(dct1: Mapping[KT, VT], dct2: Mapping[KT, VT]) -> tuple[str, str, str]:
    """Like `dict_diff`, but render each of its three results as a comma-separated string for printing.

    Changed values are shown as `key: old -> new`; an empty result is rendered as `'None'`.
    """
    added, removed, changed = dict_diff(dct1, dct2)

    def _join(parts):
        return ', '.join(parts) or 'None'

    return (
        _join(str(key) for key in added),
        _join(str(key) for key in removed),
        _join(f'{key}: {old} -> {new}' for key, (old, new) in changed.items()),
    )
207

208

209
def convert_md_to_py(path: Path | str, *, target_path: Path | str | None = None) -> None:
    """Converts a Markdown file to a py file by extracting all python codeblocks.

    Only fenced code blocks whose info string is `python` (surrounding whitespace is ignored)
    are extracted. They are written one after another, separated by a blank line. Both files
    are read and written as UTF-8.

    Args:
        path: Path to the Markdown file to convert
        target_path: Path to save the new Python file. If not provided, the new file will be the same file with .py.
            The file is always written with a `.py` suffix, replacing any other suffix of `target_path`.

    Raises:
        RuntimeError: If `path` is not an existing file.

    !!! warning

        If a file with the same name already exists, it will be overwritten.
    """
    path = Path(path)
    if not path.is_file():
        msg = f'{path} is no file!'
        raise RuntimeError(msg)

    target_path = path.with_suffix('.py') if target_path is None else Path(target_path)

    md_str = textwrap.dedent(path.read_text(encoding='utf-8'))

    def check_codeblock(block_lines: list[str]) -> str:
        """Return the code of a fenced block (without its opening fence) if it is Python, else ''."""
        if block_lines[0][3:].strip() != 'python':
            return ''
        return ''.join(f'{line}\n' for line in block_lines[1:])

    in_block = False
    block_lines: list[str] = []
    codeblocks: list[str] = []
    for line in md_str.split('\n'):
        if line.startswith('```'):
            if in_block:
                # Closing fence: the collected block is complete
                codeblocks.append(check_codeblock(block_lines))
                block_lines = []
            in_block = not in_block
        if in_block:
            # Also collects the opening fence line, which `check_codeblock` inspects
            block_lines.append(line)
    py_str = '\n'.join(code for code in codeblocks if code)

    target_path.with_suffix('.py').write_text(py_str, encoding='utf-8')
254

255

256
def str_hash(*args: str, n_chars: int = 16) -> str:
    """Return the first `n_chars` hex digits of the SHA-256 digest of the concatenated arguments.

    The arguments are concatenated without a separator, so e.g. `('ab', 'c')` and `('a', 'bc')`
    produce the same hash.
    """
    payload = ''.join(args).encode('utf-8')
    digest = sha256(payload).hexdigest()
    return digest[:n_chars]
259

260

261
def rank(arr: np.ndarray) -> np.ndarray:
    """Returns the rank of the elements in the array and gives the same rank to equal elements.

    Ranks are dense and zero-based: the smallest value gets rank 0 and each further distinct
    value gets the next integer, e.g. `[3, 1, 3, 2] -> [2, 0, 2, 1]`.

    Args:
        arr: One-dimensional array of mutually comparable values (numbers, strings, booleans, ...).

    Returns:
        An integer array of the same length as `arr` holding the rank of each element.
    """
    arr = np.asarray(arr)
    order = np.argsort(arr)
    sorted_arr = arr[order]
    # Integer ranks regardless of the input dtype (`zeros_like` would inherit e.g. float or datetime)
    ranks = np.zeros(len(arr), dtype=np.intp)
    # A new rank starts wherever a sorted value differs from its predecessor. Comparing neighbours
    # (instead of `np.diff(...) != 0`) also works for strings and booleans, and keeps equal
    # infinities together (`inf - inf` is NaN).
    ranks[1:] = np.cumsum(sorted_arr[1:] != sorted_arr[:-1])
    # `argsort` of the sorting permutation is its inverse and maps the ranks back to input order
    return ranks[np.argsort(order)]
267

268

269
def is_stable_version(version_str: str) -> bool:
    """Tell whether `version_str` denotes a final release.

    Versions that `packaging.version.Version` flags as pre-, dev- or post-release are not stable.
    """
    version = Version(version_str)
    unstable_flags = (version.is_prerelease, version.is_devrelease, version.is_postrelease)
    return not any(unstable_flags)
273

274

275
def is_stable_release() -> bool:
    """Tell whether the installed `ultimate_notion` version (`__version__`) is a stable release."""
    current_version = __version__
    return is_stable_version(current_version)
278

279

280
def parse_dt_str(dt_str: str) -> pnd.DateTime | pnd.Date | pnd.Interval:
6✔
281
    """Parse typical Notion date/datetime/interval strings to pendulum objects.
282

283
    If no timezone is provided assume local timezone and convert everything else to UTC for consistency."""
284

285
    def set_tz(dt_spec: pnd.DateTime | pnd.Date | dt.datetime | dt.date) -> pnd.DateTime | pnd.Date:
6✔
286
        """Set the timezone of the datetime specifier object if necessary."""
287
        match dt_spec:
6✔
288
            case pnd.DateTime() if dt_spec.tz is None:
6✔
289
                return dt_spec.in_tz('local')
6✔
290
            case pnd.DateTime():
6✔
291
                return dt_spec.in_tz('UTC')  # to avoid unnamed timezones we convert to UTC
6✔
292
            case pnd.Date():
6✔
293
                return dt_spec  # as it is a date and has no tz information
6✔
294
            case _:
×
295
                msg = f'Unexpected type `{type(dt_spec)}` for `{dt_spec}`'
×
296
                raise TypeError(msg)
×
297

298
    # Handle strings with "Europe/Berlin" and "UTC" style timezone
299
    if match := re.match(r'(.+)\s+([A-Za-z/]+)$', dt_str.strip()):
6✔
300
        dt_part, tz_part = match.groups()
6✔
301
        dt_spec = pnd.parse(dt_part, exact=True, tz=None)
6✔
302
        match dt_spec:
6✔
303
            case pnd.DateTime():
6✔
304
                dt_spec = dt_spec.in_tz(tz_part)
6✔
305
            case _:
×
306
                msg = f'Expected a datetime string but got {dt_str}'
×
307
                raise ValueError(msg)
×
308
    else:
309
        dt_spec = pnd.parse(dt_str, exact=True, tz=None)
6✔
310

311
    match dt_spec:
6✔
312
        case pnd.DateTime():
6✔
313
            return set_tz(dt_spec)
6✔
314
        case pnd.Date():
6✔
315
            return dt_spec
6✔
316
        case pnd.Interval():
6✔
317
            # We extend the interval to the full day if only a date is given
318
            start, end = set_tz(dt_spec.start), set_tz(dt_spec.end)
6✔
319
            if not isinstance(dt_spec.start, pnd.DateTime):
6✔
320
                start = pnd.datetime(start.year, start.month, start.day, 0, 0, 0)
6✔
321
            if not isinstance(dt_spec.end, pnd.DateTime):
6✔
322
                end = pnd.datetime(end.year, end.month, end.day, 23, 59, 59)
6✔
323
            return pnd.Interval(start=start, end=end)
6✔
324
        case _:
×
325
            msg = f'Unexpected parsing result of type {type(dt_spec)} for {dt_str}'
×
326
            raise TypeError(msg)
×
327

328

329
def is_dt_str(dt_str: str) -> bool:
    """Return `True` if `parse_dt_str` accepts `dt_str`, i.e. raises neither `ValueError` nor `TypeError`."""
    try:
        parse_dt_str(dt_str)
    except (ValueError, TypeError):
        return False
    else:
        return True
336

337

338
def to_pendulum(dt_spec: str | dt.datetime | dt.date | pnd.Interval) -> pnd.DateTime | pnd.Date | pnd.Interval:
    """Convert a string, datetime or date object to the corresponding pendulum object.

    - pendulum `DateTime`, `Date` and `Interval` objects are returned unchanged,
    - strings are parsed with `parse_dt_str`,
    - naive datetimes are given the local timezone,
    - timezone-aware datetimes are converted to UTC,
    - plain dates are converted with `pnd.instance`.

    Raises:
        TypeError: If `dt_spec` has any other type.
    """
    if isinstance(dt_spec, pnd.DateTime | pnd.Date | pnd.Interval):
        return dt_spec
    if isinstance(dt_spec, str):
        return parse_dt_str(dt_spec)
    if isinstance(dt_spec, dt.datetime):
        if dt_spec.tzinfo is None:
            return pnd.instance(dt_spec, tz='local')
        return pnd.instance(dt_spec).in_tz('UTC')  # to avoid unnamed timezones we convert to UTC
    if isinstance(dt_spec, dt.date):
        return pnd.instance(dt_spec)
    msg = f'Unexpected type {type(dt_spec)} for {dt_spec}'
    raise TypeError(msg)
354

355

356
@contextmanager
def temp_timezone(tz: str | pnd.Timezone):
    """Context manager that switches pendulum's local timezone to `tz` and restores it on exit.

    Mostly used by unit tests.

    Args:
        tz: A timezone name or a `pnd.Timezone` object.

    Raises:
        RuntimeError: If the current local timezone is not a `pnd.Timezone` object.
    """
    new_tz = tz if isinstance(tz, pnd.Timezone) else pnd.timezone(tz)

    previous_tz = pnd.local_timezone()
    if not isinstance(previous_tz, pnd.Timezone):
        msg = f'Expected a Timezone object but got type {type(previous_tz)}.'
        raise RuntimeError(msg)

    pnd.set_local_timezone(new_tz)
    try:
        yield
    finally:
        # Restore the previous timezone even if the body raised
        pnd.set_local_timezone(previous_tz)
371

372

373
# Pydantic model type handled by `del_nested_attr`; the `BaseModel` bound provides `model_copy`.
PT = TypeVar('PT', bound=BaseModel)  # ToDo: Use new syntax when requires-python >= 3.12
6✔
374

375

376
def del_nested_attr(
    obj: PT, attr_paths: str | Sequence[str] | None, *, inplace: bool = False, missing_ok: bool = False
) -> PT:
    """Remove nested attributes from an object.

    Args:
        obj: The model to remove the attributes from.
        attr_paths: A dotted attribute path like `'a.b.c'` or a sequence of such paths.
            `None` means that nothing is removed.
        inplace: If `True`, modify `obj` itself; otherwise work on `obj.model_copy(deep=True)`.
        missing_ok: If `True`, silently skip paths whose intermediate or final attribute is missing.

    Returns:
        The object with the attributes removed (`obj` itself if `inplace` is set or `attr_paths` is `None`).

    Raises:
        AttributeError: If an attribute along a path does not exist and `missing_ok` is `False`.
            An intermediate attribute whose value is `None` counts as missing.
    """
    if attr_paths is None:
        return obj
    if isinstance(attr_paths, str):
        attr_paths = [attr_paths]

    if not inplace:
        obj = obj.model_copy(deep=True)

    for attr_path in attr_paths:
        *parent_attrs, last_attr = attr_path.split('.')

        curr_obj: Any = obj
        for lvl, attr in enumerate(parent_attrs):
            curr_obj = getattr(curr_obj, attr, None)
            if curr_obj is None:
                if missing_ok:
                    break  # the path does not exist, so there is nothing to delete
                # The object lacking `attr` is reached via all path components before it
                parent_path = '.'.join(parent_attrs[:lvl]) or 'the object'
                msg = f'{attr} does not exist in {parent_path}.'
                raise AttributeError(msg)
        else:
            # Only reached if the whole parent path could be resolved
            if hasattr(curr_obj, last_attr):
                delattr(curr_obj, last_attr)
            elif not missing_ok:
                parent_path = '.'.join(parent_attrs) or 'the object'
                msg = f'{last_attr} does not exist in {parent_path}.'
                raise AttributeError(msg)

    return obj
405

406

407
@contextmanager
def temp_attr(obj: object, **kwargs: Any) -> Generator[None, None, None]:
    """
    Temporarily sets multiple attributes of an object to specified values,
    and restores their original values after the context exits.

    Attributes that did not exist before are removed again on exit. If setting one of the
    attributes fails, the attributes already set are restored before the error propagates.

    Args:
        obj (object): The object whose attributes will be modified.
        **kwargs (Any): The attributes and their temporary values to modify.
    """
    missing = object()  # sentinel marking attributes that did not exist before
    orig_values = {attr: getattr(obj, attr, missing) for attr in kwargs}
    changed: list[str] = []
    try:
        # Setting happens inside `try` so a failure midway still restores what was already set
        for attr, new_value in kwargs.items():
            setattr(obj, attr, new_value)
            changed.append(attr)
        yield
    finally:
        for attr in changed:
            original_value = orig_values[attr]
            if original_value is missing:
                if hasattr(obj, attr):
                    delattr(obj, attr)
            else:
                setattr(obj, attr, original_value)
425

426

427
def rec_apply(func: Callable[[Any], Any], obj: Any) -> Any:
6✔
428
    """
429
    Recursively applies a function `func` to all elements in a nested structure.
430

431
    - Applies `func` to every non-container element.
432
    - Recurses into lists and tuples.
433
    - Strings are treated as atomic elements and are **not** considered containers.
434

435
    Example:
436
        rows = [[1, 2], [3, [4, 5]]]
437
        result = recursive_apply(rows, lambda x: x * 2)
438
        print(result)  # [[2, 4], [6, [8, 10]]]
439
    """
440
    if isinstance(obj, list | tuple):
6✔
441
        return type(obj)(rec_apply(func, item) for item in obj)
6✔
442
    else:
443
        return func(obj)
6✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc