• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

feihoo87 / waveforms / 6534953321

16 Oct 2023 02:19PM UTC coverage: 35.674% (-22.7%) from 58.421%
6534953321

push

github

feihoo87
fix Coveralls

5913 of 16575 relevant lines covered (35.67%)

3.21 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.09
/waveforms/scan/base.py
1

2
import inspect
import logging
import warnings
from abc import ABC, abstractclassmethod, abstractmethod
from concurrent.futures import Executor, Future
from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from itertools import chain, count

from queue import Empty, Queue
from typing import Any, Callable, Iterable, Sequence, Type
9✔
13

14
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
9✔
15
# Only records of severity ERROR or higher pass this logger's own level check.
log.setLevel(logging.ERROR)
9✔
16

17

18
class BaseOptimizer(ABC):
    """Abstract ask/tell interface for optimizers that drive a scan level.

    The scan machinery in this module duck-types on the three methods
    declared here: ``ask`` supplies the next values, ``tell`` receives
    feedback, and ``get_result`` is queried on the last allowed iteration
    (its return value is read through an ``x`` attribute).

    The methods are ordinary instance methods, so they are declared with
    ``abstractmethod``; ``abstractclassmethod`` is deprecated and would
    wrap them as class methods.
    """

    @abstractmethod
    def ask(self) -> tuple:
        """Return the next point to evaluate, one value per scanned name."""

    @abstractmethod
    def tell(self, suggested: Sequence, value: Any):
        """Record the objective ``value`` obtained at ``suggested``."""

    @abstractmethod
    def get_result(self):
        """Return the optimization result; callers read its ``x`` attribute."""
×
31

32

33
@dataclass
class OptimizerConfig:
    """Recipe for building an optimizer that drives one scan level.

    The scan code instantiates the optimizer as
    ``cls(dimensions, *args, **kwds)`` and stops that level after
    ``max_iters`` iterations.
    """

    # Optimizer class to instantiate.
    cls: Type[BaseOptimizer]
    # First positional argument handed to ``cls``.
    dimensions: list = field(default_factory=list)
    # Extra positional arguments for ``cls``.
    args: tuple = ()
    # Extra keyword arguments for ``cls``.
    kwds: dict = field(default_factory=dict)
    # Upper bound on the number of iterations of the level.
    max_iters: int = 100
9✔
40

41

42
class FeedbackPipe:
    """FIFO mailbox that carries feedback objects for one group of keys.

    Items are queued with :meth:`send` and drained, without blocking, by
    iterating over the pipe (or over the result of calling it).
    """

    __slots__ = (
        'keys',
        '_queue',
    )

    def __init__(self, keys):
        self.keys = keys
        self._queue = Queue()

    def __iter__(self):
        """Yield queued items in FIFO order until the queue is empty."""
        while True:
            try:
                yield self._queue.get_nowait()
            except Empty:
                return

    def __call__(self):
        """Return a fresh draining iterator, same as ``iter(pipe)``."""
        return iter(self)

    def send(self, obj):
        """Append ``obj`` to the queue."""
        self._queue.put(obj)

    def __repr__(self):
        # Report this class's own name; the previous text said
        # 'FeedbackProxy', which is a different class in this module.
        if isinstance(self.keys, tuple):
            return f'FeedbackPipe{self.keys}'
        return f'FeedbackPipe({self.keys!r})'
×
70

71

72
class FeedbackProxy:
    """Mixin that lets a scan event talk to feedback pipes and trackers.

    The methods read ``self.kwds``, ``self._pipes`` and ``self._trackers``;
    the subclasses in this module provide them as dataclass fields.
    """

    def feedback(self, keywords, obj, suggested=None):
        """Send ``(suggested, obj)`` to the pipe registered under ``keywords``.

        Parameters
        ----------
        keywords : str or tuple of str
            Key of the pipe in ``self._pipes``.
        obj : Any
            Feedback value to deliver.
        suggested : sequence, optional
            Point the feedback refers to. Defaults to the current values of
            the named variables taken from ``self.kwds``.

        A ``RuntimeWarning`` is issued when no pipe is registered.
        """
        if keywords not in self._pipes:
            warnings.warn(f'No feedback pipe for {keywords}', RuntimeWarning,
                          2)
            return
        if suggested is None:
            # A bare string names ONE variable; iterating it would look up
            # its individual characters instead.
            names = (keywords, ) if isinstance(keywords, str) else keywords
            suggested = [self.kwds[k] for k in names]
        self._pipes[keywords].send((suggested, obj))

    def feed(self, obj, **options):
        """Forward ``obj`` and ``options`` to every tracker's ``feed``."""
        for tracker in self._trackers:
            tracker.feed(self, obj, **options)

    def store(self, obj, **options):
        """Same as :meth:`feed` with ``store=True`` added to the options."""
        self.feed(obj, store=True, **options)

    def __getstate__(self):
        """Return the instance state without the pipes and trackers."""
        state = self.__dict__.copy()
        # Tolerate instances that never had these attributes set.
        state.pop('_pipes', None)
        state.pop('_trackers', None)
        return state

    def __setstate__(self, state):
        """Restore the state with empty pipes and trackers."""
        self.__dict__ = state
        self._pipes = {}
        self._trackers = []
×
100

101

102
@dataclass
class StepStatus(FeedbackProxy):
    """One point of a scan, as yielded by ``scan_iters``."""

    # Running number of the step, assigned by ``scan_iters``.
    iteration: int = 0
    # Loop counter of every nesting level for this step.
    pos: tuple = ()
    # Per-level index assigned by ``scan_iters``.
    index: tuple = ()
    # Variable values valid at this step.
    kwds: dict = field(default_factory=dict)
    # Names introduced per level. ``default_factory`` gives each instance
    # its own empty list; ``default=list`` stored the ``list`` type itself.
    vars: list[str] = field(default_factory=list)
    # First position at which ``pos`` differs from the previous step's.
    unchanged: int = 0

    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)
9✔
113

114

115
@dataclass
class Begin(FeedbackProxy):
    """Marker emitted before the body of one iteration of a loop level."""

    # Nesting level of the loop this marker belongs to.
    level: int = 0
    # Set by ``scan_iters`` to the iteration number of the last yielded step.
    iteration: int = 0
    # Loop counters up to and including this level.
    pos: tuple = ()
    index: tuple = ()
    # Variable values valid inside this iteration.
    kwds: dict = field(default_factory=dict)
    # Names introduced per level. ``default_factory`` gives each instance
    # its own empty list; ``default=list`` stored the ``list`` type itself.
    vars: list[str] = field(default_factory=list)

    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)

    def __repr__(self):
        return f'Begin(level={self.level}, kwds={self.kwds}, vars={self.vars})'
×
129

130

131
@dataclass
class End(FeedbackProxy):
    """Marker emitted after the body of one iteration of a loop level."""

    # Nesting level of the loop this marker belongs to.
    level: int = 0
    # Set by ``scan_iters`` to the iteration number of the last yielded step.
    iteration: int = 0
    # Loop counters up to and including this level.
    pos: tuple = ()
    index: tuple = ()
    # Variable values valid inside this iteration.
    kwds: dict = field(default_factory=dict)
    # Names introduced per level. ``default_factory`` gives each instance
    # its own empty list; ``default=list`` stored the ``list`` type itself.
    vars: list[str] = field(default_factory=list)

    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)

    def __repr__(self):
        return f'End(level={self.level}, kwds={self.kwds}, vars={self.vars})'
×
145

146

147
class Tracker:
    """Do-nothing base class for scan observers.

    ``scan_iters`` calls :meth:`init` once before scanning, :meth:`update`
    on the keyword dict of every step, and steps forward objects to
    :meth:`feed`. Subclasses override what they need.
    """

    def init(self, loops: dict, functions: dict, constants: dict, graph: dict,
             order: list):
        """Receive the scan description; the base class ignores it."""
        return None

    def update(self, kwds: dict):
        """Return the keyword dict to use for the step (unchanged here)."""
        return kwds

    def feed(self, step: StepStatus, obj: Any, **options):
        """Receive ``obj`` fed from ``step``; the base class ignores it."""
        return None
×
158

159

160
def _call_func_with_kwds(func, args, kwds):
9✔
161
    funcname = getattr(func, '__name__', repr(func))
9✔
162
    sig = inspect.signature(func)
9✔
163
    for p in sig.parameters.values():
9✔
164
        if p.kind == p.VAR_KEYWORD:
9✔
165
            return func(*args, **kwds)
9✔
166
    kw = {
9✔
167
        k: v
168
        for k, v in kwds.items()
169
        if k in list(sig.parameters.keys())[len(args):]
170
    }
171
    try:
9✔
172
        args = [
9✔
173
            arg.result() if isinstance(arg, Future) else arg for arg in args
174
        ]
175
        kw = {
9✔
176
            k: v.result() if isinstance(v, Future) else v
177
            for k, v in kw.items()
178
        }
179
        return func(*args, **kw)
9✔
180
    except:
×
181
        log.exception(f'Call {funcname} with {args} and {kw}')
×
182
        raise
×
183
    finally:
184
        log.debug(f'Call {funcname} with {args} and {kw}')
9✔
185

186

187
def _try_to_call(x, args, kwds):
9✔
188
    if callable(x):
9✔
189
        return _call_func_with_kwds(x, args, kwds)
9✔
190
    return x
9✔
191

192

193
def _get_current_iters(loops, level, kwds, pipes):
9✔
194
    keys, current = loops[level]
9✔
195
    limit = -1
9✔
196

197
    if isinstance(keys, str):
9✔
198
        keys = (keys, )
9✔
199
        current = (current, )
9✔
200
    elif isinstance(keys, tuple) and isinstance(
9✔
201
            current, tuple) and len(keys) == len(current):
202
        keys = tuple(k if isinstance(k, tuple) else (k, ) for k in keys)
9✔
203
    elif isinstance(keys, tuple) and not isinstance(current, tuple):
×
204
        current = (current, )
×
205
        if isinstance(keys[0], str):
×
206
            keys = (keys, )
×
207
    else:
208
        log.error(f'Illegal keys {keys} on level {level}.')
×
209
        raise TypeError(f'Illegal keys {keys} on level {level}.')
×
210

211
    if not isinstance(keys, tuple):
9✔
212
        keys = (keys, )
×
213
    if not isinstance(current, tuple):
9✔
214
        current = (current, )
×
215

216
    iters = []
9✔
217
    for k, it in zip(keys, current):
9✔
218
        pipe = FeedbackPipe(k)
9✔
219
        if isinstance(it, OptimizerConfig):
9✔
220
            if limit < 0 or limit > it.max_iters:
9✔
221
                limit = it.max_iters
9✔
222
            it = it.cls(it.dimensions, *it.args, **it.kwds)
9✔
223
        else:
224
            it = iter(_try_to_call(it, (), kwds))
9✔
225

226
        iters.append((it, pipe))
9✔
227
        pipes[k] = pipe
9✔
228

229
    return keys, iters, pipes, limit
9✔
230

231

232
def _generate_kwds(keys, iters, kwds, iteration, limit):
9✔
233
    ret = {}
9✔
234
    for ks, it in zip(keys, iters):
9✔
235
        if isinstance(ks, str):
9✔
236
            ks = (ks, )
9✔
237
        if hasattr(it[0], 'ask') and hasattr(it[0], 'tell') and hasattr(
9✔
238
                it[0], 'get_result'):
239
            if limit > 0 and iteration >= limit - 1:
9✔
240
                value = _call_func_with_kwds(it[0].get_result, (), kwds).x
9✔
241
            else:
242
                value = _call_func_with_kwds(it[0].ask, (), kwds)
9✔
243
        else:
244
            value = next(it[0])
9✔
245
            if len(ks) == 1:
9✔
246
                value = (value, )
9✔
247
        ret.update(zip(ks, value))
9✔
248
    return ret
9✔
249

250

251
def _send_feedback(generator, feedback):
9✔
252
    if hasattr(generator, 'ask') and hasattr(generator, 'tell') and hasattr(
9✔
253
            generator, 'get_result'):
254
        generator.tell(
9✔
255
            *[x.result() if isinstance(x, Future) else x for x in feedback])
256

257

258
def _feedback(iters):
9✔
259
    for generator, pipe in iters:
9✔
260
        for feedback in pipe():
9✔
261
            _send_feedback(generator, feedback)
9✔
262

263

264
def _call_functions(functions, kwds, order, pool: Executor | None = None):
9✔
265
    vars = []
9✔
266
    for i, ready in enumerate(order):
9✔
267
        rest = []
9✔
268
        for k in ready:
9✔
269
            if k in kwds:
9✔
270
                continue
9✔
271
            elif k in functions:
9✔
272
                if pool is None:
9✔
273
                    kwds[k] = _try_to_call(functions[k], (), kwds)
9✔
274
                else:
275
                    kwds[k] = pool.submit(_try_to_call, functions[k], (), kwds)
×
276
                vars.append(k)
9✔
277
            else:
278
                rest.append(k)
9✔
279
        if rest:
9✔
280
            break
9✔
281
    else:
282
        return [], vars
9✔
283
    if rest:
9✔
284
        return [rest] + order[i:], vars
9✔
285
    else:
286
        return order[i:], vars
×
287

288

289
def _args_generator(loops: list,
                    kwds: dict[str, Any],
                    level: int,
                    pos: tuple[int, ...],
                    vars: list[tuple[str]],
                    filter: Callable[..., bool] | None,
                    functions: dict[str, Callable],
                    trackers: list[Tracker],
                    pipes: dict[str | tuple[str, ...], FeedbackPipe],
                    order: list[str],
                    pool: Executor | None = None):
    """Recursively generate the events of the nested loops.

    At every level the functions that can already be evaluated are
    computed first. Past the innermost loop a single ``StepStatus`` is
    yielded (unless ``filter`` rejects it). Otherwise each iteration of the
    level yields a ``Begin`` marker, everything produced by the deeper
    levels, and an ``End`` marker; feedback queued during the iteration is
    then delivered to the level's sources.

    Parameters
    ----------
    loops : list
        ``(keys, sources)`` pairs, one per nesting level.
    kwds : dict
        Values known so far. Updated in place with the functions evaluated
        at this level.
    level : int
        Index of the level to generate.
    pos : tuple
        Loop counters of the enclosing levels.
    vars : list
        Names introduced by each enclosing level.
    filter : callable or None
        Predicate deciding whether a step is yielded.
    functions, trackers, pipes, order
        Shared scan state, passed down unchanged except ``order`` which
        shrinks as functions get resolved.
    pool : Executor or None
        Executor used to evaluate functions. It is forwarded to the nested
        levels as well; previously it was dropped after the first level.

    Raises
    ------
    TypeError
        If functions remain unresolved after the innermost loop.
    """
    order, local_vars = _call_functions(functions, kwds, order, pool)
    if len(loops) == level and level > 0:
        if order:
            log.error(f'Unresolved functions: {order}')
            raise TypeError(f'Unresolved functions: {order}')
        for tracker in trackers:
            kwds = tracker.update(kwds)
        if filter is None or _call_func_with_kwds(filter, (), kwds):
            yield StepStatus(
                pos=pos,
                kwds=kwds,
                # Functions resolved here belong to the innermost level.
                vars=[*vars[:-1], tuple([*vars[-1], *local_vars])],
                _pipes=pipes,
                _trackers=trackers)
        return

    keys, current_iters, pipes, limit = _get_current_iters(
        loops, level, kwds, pipes)

    for i in count():
        if limit > 0 and i >= limit:
            break
        try:
            kw = _generate_kwds(keys, current_iters, kwds, i, limit)
        except StopIteration:
            break
        # Every event gets its own dict and list: consumers and the nested
        # levels may mutate what they receive.
        yield Begin(level=level,
                    pos=pos + (i, ),
                    kwds=kwds | kw,
                    vars=[*vars, tuple([*local_vars, *kw.keys()])],
                    _pipes=pipes,
                    _trackers=trackers)
        yield from _args_generator(
            loops, kwds | kw, level + 1, pos + (i, ),
            [*vars, tuple([*local_vars, *kw.keys()])], filter, functions,
            trackers, pipes, order, pool)
        yield End(level=level,
                  pos=pos + (i, ),
                  kwds=kwds | kw,
                  vars=[*vars, tuple([*local_vars, *kw.keys()])],
                  _pipes=pipes,
                  _trackers=trackers)
        _feedback(current_iters)
9✔
343

344

345
def _find_common_prefix(a: tuple, b: tuple):
9✔
346
    for i, (x, y) in enumerate(zip(a, b)):
9✔
347
        if x != y:
9✔
348
            return i
9✔
349
    return i
×
350

351

352
def _add_dependence(graph, keys, function, loop_names, var_names):
9✔
353
    if isinstance(keys, str):
9✔
354
        keys = (keys, )
9✔
355
    for key in keys:
9✔
356
        graph.setdefault(key, set())
9✔
357
        for k, p in inspect.signature(function).parameters.items():
9✔
358
            if p.kind in [
9✔
359
                    p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY
360
            ] and k in var_names:
361
                graph[key].add(k)
9✔
362
            if p.kind == p.VAR_KEYWORD and key not in loop_names:
9✔
363
                graph[key].update(loop_names)
9✔
364

365

366
def _build_dependence(loops, functions, constants, loop_deps=True):
9✔
367
    graph = {}
9✔
368
    loop_names = set()
9✔
369
    var_names = set()
9✔
370
    for keys, iters in loops.items():
9✔
371
        level_vars = set()
9✔
372
        if isinstance(keys, str):
9✔
373
            keys = (keys, )
9✔
374
        if callable(iters):
9✔
375
            iters = tuple([iters for _ in keys])
9✔
376
        for ks, iter_vars in zip(keys, iters):
9✔
377
            if isinstance(ks, str):
9✔
378
                ks = (ks, )
9✔
379
            if callable(iters):
9✔
380
                iter_vars = tuple([iter_vars for _ in ks])
×
381
            level_vars.update(ks)
9✔
382
            for i, k in enumerate(ks):
9✔
383
                d = graph.setdefault(k, set())
9✔
384
                if loop_deps:
9✔
385
                    d.update(loop_names)
9✔
386
                else:
387
                    if isinstance(iter_vars, tuple):
9✔
388
                        iter_var = iter_vars[i]
×
389
                    else:
390
                        iter_var = iter_vars
9✔
391
                    if callable(iter_var):
9✔
392
                        d.update(
9✔
393
                            set(
394
                                inspect.signature(iter_vars).parameters.keys())
395
                            & loop_names)
396

397
        loop_names.update(level_vars)
9✔
398
        var_names.update(level_vars)
9✔
399
    var_names.update(functions.keys())
9✔
400
    var_names.update(constants.keys())
9✔
401

402
    for keys, values in chain(loops.items(), functions.items()):
9✔
403
        if callable(values):
9✔
404
            _add_dependence(graph, keys, values, loop_names, var_names)
9✔
405
        elif isinstance(values, tuple):
9✔
406
            for ks, v in zip(keys, values):
9✔
407
                if callable(v):
9✔
408
                    _add_dependence(graph, ks, v, loop_names, var_names)
9✔
409

410
    return graph
9✔
411

412

413
def _get_all_dependence(key, graph):
9✔
414
    ret = set()
×
415
    if key not in graph:
×
416
        return ret
×
417
    for k in graph[key]:
×
418
        ret.add(k)
×
419
        ret.update(_get_all_dependence(k, graph))
×
420
    return ret
×
421

422

423
def scan_iters(loops: dict[str | tuple[str, ...],
                           Iterable | Callable | OptimizerConfig
                           | tuple[Iterable | Callable | OptimizerConfig,
                                   ...]] | None = None,
               filter: Callable[..., bool] | None = None,
               functions: dict[str, Callable] | None = None,
               constants: dict[str, Any] | None = None,
               trackers: list[Tracker] | None = None,
               level_marker: bool = False,
               pool: Executor | None = None,
               **kwds) -> Iterable[StepStatus]:
    """
    Scan the given nested loops.

    Parameters
    ----------
    loops : dict
        A map of iterables that are scanned, outermost loop first.
    filter : Callable[..., bool]
        A filter function that is called for each step.
        If it returns False, the step is skipped.
    functions : dict
        A map of functions that are called for each step.
    constants : dict
        Additional keyword arguments that are passed to the iterables.
        The dict is copied, never modified.
    trackers : list
        ``Tracker`` objects: initialised once before the scan, then applied
        to the keyword dict of every step.
    level_marker : bool
        If true, the ``Begin`` and ``End`` markers of every loop iteration
        are yielded too.
    pool : Executor
        Optional executor used to evaluate ``functions``.
    **kwds
        Only the deprecated aliases ``additional_kwds`` (for ``functions``)
        and ``iters`` (for ``loops``) are read.

    Returns
    -------
    Iterable[StepStatus]
        An iterable of StepStatus objects.

    Examples
    --------
    >>> iters = {
    ...     'a': range(2),
    ...     'b': range(3),
    ... }
    >>> [step.kwds for step in scan_iters(iters)]
    [{'a': 0, 'b': 0}, {'a': 0, 'b': 1}, {'a': 0, 'b': 2}, {'a': 1, 'b': 0}, {'a': 1, 'b': 1}, {'a': 1, 'b': 2}]

    >>> [(step.pos, step.index)
    ...  for step in scan_iters(iters, lambda a, b: a < b)]
    [((0, 1), (0, 0)), ((0, 2), (0, 1)), ((1, 2), (1, 0))]
    """

    # TODO: if a callable in ``loops`` has a VAR_KEYWORD parameter and at
    #       run time actually depends on values from ``functions``, the
    #       dependency graph will be wrong.
    # TODO: a callable in ``functions`` with a VAR_KEYWORD parameter is
    #       treated as depending on every loop variable, and such
    #       functions are assumed not to depend on each other.

    # ``None`` defaults avoid sharing one mutable object between calls.
    loops = {} if loops is None else loops
    functions = {} if functions is None else functions
    constants = {} if constants is None else constants
    trackers = [] if trackers is None else trackers

    if 'additional_kwds' in kwds:
        functions = functions | kwds['additional_kwds']
        warnings.warn(
            "The argument 'additional_kwds' is deprecated, "
            "use 'functions' instead.", DeprecationWarning)
    if 'iters' in kwds:
        loops = loops | kwds['iters']
        warnings.warn(
            "The argument 'iters' is deprecated, "
            "use 'loops' instead.", DeprecationWarning)

    if not loops:
        return

    # Evaluation order: groups of names whose dependencies are all met.
    graph = _build_dependence(loops, functions, constants)
    ts = TopologicalSorter(graph)
    order = []
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        for k in ready:
            ts.done(k)
        order.append(ready)
    # The graph handed to the trackers is built with ``loop_deps=False``.
    graph = _build_dependence(loops, functions, constants, False)

    for tracker in trackers:
        tracker.init(loops, functions, constants, graph, order)

    last_step = None
    index = ()
    iteration = count()

    # Pass a copy of ``constants``: the generator stores function results
    # into its keyword dict, which must not leak into the caller's dict.
    for step in _args_generator(list(loops.items()),
                                kwds=dict(constants),
                                level=0,
                                pos=(),
                                vars=[],
                                filter=filter,
                                functions=functions,
                                trackers=trackers,
                                pipes={},
                                order=order,
                                pool=pool):
        if isinstance(step, (Begin, End)):
            if level_marker:
                if last_step is not None:
                    step.iteration = last_step.iteration
                yield step
            continue

        if last_step is None:
            i = 0
            index = (0, ) * len(step.pos)
        else:
            i = _find_common_prefix(last_step.pos, step.pos)
            # Keep the indices before ``i``, advance the one at ``i`` and
            # reset every index after it.
            index = tuple((j <= i) * n + (j == i) for j, n in enumerate(index))
        step.iteration = next(iteration)
        step.index = index
        step.unchanged = i
        yield step
        last_step = step
9✔
544

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc