• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

feihoo87 / waveforms / 10366062234

13 Aug 2024 08:00AM UTC coverage: 18.563% (-11.3%) from 29.887%
10366062234

push

github

feihoo87
rm codes

1294 of 6971 relevant lines covered (18.56%)

1.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/waveforms/scan/base.py
1

2
import inspect
import logging
import warnings
from abc import ABC, abstractclassmethod, abstractmethod
from concurrent.futures import Executor, Future
from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from itertools import chain, count
from queue import Empty, Queue
from typing import Any, Callable, Iterable, Sequence, Type
×
13

14
# Module-level logger. Its level is pinned to ERROR here, so the debug
# records emitted by the helpers below are dropped unless the level is changed.
log = logging.getLogger(__name__)
log.setLevel(logging.ERROR)
×
16

17

18
class BaseOptimizer(ABC):
    """Abstract ask/tell optimizer that can be used as a loop source.

    Within this module an optimizer is instantiated as
    ``cls(dimensions, *args, **kwds)`` from an ``OptimizerConfig`` and is
    recognised by having ``ask``, ``tell`` and ``get_result`` attributes.

    The methods are plain abstract *instance* methods.  (They used to be
    declared with the deprecated ``abstractclassmethod``, which turned them
    into classmethods even though they take ``self``.)
    """

    @abstractmethod
    def ask(self) -> tuple:
        """Return the next point to evaluate.

        The returned values are zipped with the loop's key names, so there
        must be one value per key.
        """

    @abstractmethod
    def tell(self, suggested: Sequence, value: Any):
        """Report the ``value`` obtained at the ``suggested`` point."""

    @abstractmethod
    def get_result(self):
        """Return the optimization result.

        The scan reads the ``x`` attribute of the returned object as the
        final point.
        """
×
31

32

33
@dataclass
class OptimizerConfig:
    """Recipe for building an optimizer-driven loop.

    When a loop value is an ``OptimizerConfig`` the optimizer is created as
    ``cls(dimensions, *args, **kwds)`` and the loop is limited to
    ``max_iters`` iterations.
    """
    # Optimizer class to instantiate.
    cls: Type[BaseOptimizer]
    # First positional argument passed to ``cls``.
    dimensions: list = field(default_factory=list)
    # Extra positional arguments for ``cls``.
    args: tuple = field(default=())
    # Extra keyword arguments for ``cls``.
    kwds: dict = field(default_factory=dict)
    # Upper bound on the number of iterations of the loop.
    max_iters: int = field(default=100)
×
40

41

42
class FeedbackPipe:
    """FIFO channel carrying feedback objects for one set of loop keys.

    Objects are queued with :meth:`send`; iterating (or calling) the pipe
    drains everything that is currently queued without blocking.
    """

    __slots__ = ('keys', '_queue')

    def __init__(self, keys):
        self.keys = keys
        self._queue = Queue()

    def __iter__(self):
        """Yield queued objects in order until the queue is empty."""
        while True:
            try:
                yield self._queue.get_nowait()
            except Empty:
                return

    def __call__(self):
        """Return a draining iterator, same as ``iter(pipe)``."""
        return iter(self)

    def send(self, obj):
        """Append ``obj`` to the queue."""
        self._queue.put(obj)

    def __repr__(self):
        # Report the real class name; the previous text said "FeedbackProxy",
        # which is a different class in this module.
        name = type(self).__name__
        if isinstance(self.keys, tuple):
            return f'{name}{self.keys}'
        return f'{name}({self.keys!r})'
×
70

71

72
class FeedbackProxy:
    """Mixin giving step objects access to feedback pipes and trackers.

    The methods read the attributes ``kwds``, ``_pipes`` and ``_trackers``,
    which the concrete classes are expected to provide.
    """

    def feedback(self, keywords, obj, suggested=None):
        """Send ``(suggested, obj)`` down the pipe registered for ``keywords``.

        If ``suggested`` is omitted it is taken from ``self.kwds`` for each
        name in ``keywords``.  A ``RuntimeWarning`` is issued when no pipe is
        registered under ``keywords``.
        """
        if keywords not in self._pipes:
            warnings.warn(f'No feedback pipe for {keywords}', RuntimeWarning,
                          2)
            return
        if suggested is None:
            suggested = [self.kwds[name] for name in keywords]
        self._pipes[keywords].send((suggested, obj))

    def feed(self, obj, **options):
        """Forward ``obj`` and ``options`` to every tracker's ``feed``."""
        for tracker in self._trackers:
            tracker.feed(self, obj, **options)

    def store(self, obj, **options):
        """Same as :meth:`feed` with ``store=True`` added to the options."""
        self.feed(obj, store=True, **options)

    def __getstate__(self):
        # Pipes and trackers are left out of the pickled state.
        state = dict(self.__dict__)
        for name in ('_pipes', '_trackers'):
            del state[name]
        return state

    def __setstate__(self, state):
        # Restore the attributes and start with no pipes and no trackers.
        self.__dict__ = state
        self._pipes = {}
        self._trackers = []
×
100

101

102
@dataclass
class StepStatus(FeedbackProxy):
    """One point of a scan, as yielded by ``scan_iters``."""
    # Running number of the step; assigned by scan_iters.
    iteration: int = 0
    # Loop counter of every level, outermost first.
    pos: tuple = ()
    # Per-level index assigned by scan_iters from the steps actually yielded.
    index: tuple = ()
    # Keyword values (loop variables, function results, constants) of the step.
    kwds: dict = field(default_factory=dict)
    # One tuple of variable names per loop level.
    # (Was ``field(default=list)``, which made the default the ``list`` type
    # itself instead of a fresh empty list.)
    vars: list[tuple[str, ...]] = field(default_factory=list)
    # Number of leading entries of ``pos`` shared with the previous step;
    # assigned by scan_iters.
    unchanged: int = 0

    # Feedback pipes and trackers used by the FeedbackProxy methods.
    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)
×
113

114

115
@dataclass
class Begin(FeedbackProxy):
    """Marker yielded before the body of one iteration of a loop level."""
    # Loop level this marker belongs to (0 is the outermost loop).
    level: int = 0
    # Copied from the last yielded step by scan_iters, when there is one.
    iteration: int = 0
    # Loop counters of the levels down to and including ``level``.
    pos: tuple = ()
    index: tuple = ()
    # Keyword values known when the iteration starts.
    kwds: dict = field(default_factory=dict)
    # One tuple of variable names per loop level.
    # (Was ``field(default=list)``, which made the default the ``list`` type
    # itself instead of a fresh empty list.)
    vars: list[tuple[str, ...]] = field(default_factory=list)

    # Feedback pipes and trackers used by the FeedbackProxy methods.
    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)

    def __repr__(self):
        return f'Begin(level={self.level}, kwds={self.kwds}, vars={self.vars})'
×
129

130

131
@dataclass
class End(FeedbackProxy):
    """Marker yielded after the body of one iteration of a loop level."""
    # Loop level this marker belongs to (0 is the outermost loop).
    level: int = 0
    # Copied from the last yielded step by scan_iters, when there is one.
    iteration: int = 0
    # Loop counters of the levels down to and including ``level``.
    pos: tuple = ()
    index: tuple = ()
    # Keyword values of the iteration that just finished.
    kwds: dict = field(default_factory=dict)
    # One tuple of variable names per loop level.
    # (Was ``field(default=list)``, which made the default the ``list`` type
    # itself instead of a fresh empty list.)
    vars: list[tuple[str, ...]] = field(default_factory=list)

    # Feedback pipes and trackers used by the FeedbackProxy methods.
    _pipes: dict = field(default_factory=dict, repr=False)
    _trackers: list = field(default_factory=list, repr=False)

    def __repr__(self):
        return f'End(level={self.level}, kwds={self.kwds}, vars={self.vars})'
×
145

146

147
class Tracker:
    """Base class for scan observers; every hook is a no-op by default."""

    def init(self, loops: dict, functions: dict, constants: dict, graph: dict,
             order: list):
        """Called once by ``scan_iters`` before the scan starts."""
        return None

    def update(self, kwds: dict):
        """Called for every step with its keyword values.

        The returned mapping replaces ``kwds`` for that step; the default
        implementation returns it unchanged.
        """
        return kwds

    def feed(self, step: StepStatus, obj: Any, **options):
        """Called when a step object's ``feed``/``store`` is invoked."""
        return None
×
158

159

160
def _call_func_with_kwds(func, args, kwds):
    """Call ``func`` with ``args`` and the subset of ``kwds`` it accepts.

    If ``func`` declares a ``**kwargs`` parameter it receives ``args`` and
    all of ``kwds`` unchanged.  Otherwise only the keywords naming a
    parameter that is not already taken by a positional argument are passed,
    and any ``Future`` among the arguments is replaced by its result first.

    Exceptions raised while resolving futures or calling ``func`` are logged
    and re-raised.
    """
    funcname = getattr(func, '__name__', repr(func))
    params = inspect.signature(func).parameters
    if any(p.kind == p.VAR_KEYWORD for p in params.values()):
        # NOTE(review): Future values are passed through unresolved on this
        # path, unlike the filtered path below.
        return func(*args, **kwds)

    # Build the set of still-free parameter names once, instead of
    # re-slicing the parameter list for every keyword.
    accepted = set(list(params)[len(args):])
    kw = {k: v for k, v in kwds.items() if k in accepted}
    try:
        args = [
            arg.result() if isinstance(arg, Future) else arg for arg in args
        ]
        kw = {
            k: v.result() if isinstance(v, Future) else v
            for k, v in kw.items()
        }
        return func(*args, **kw)
    except Exception:
        log.exception('Call %s with %s and %s', funcname, args, kw)
        raise
    finally:
        # Lazy %-formatting: the arguments are only rendered when the record
        # is actually emitted.
        log.debug('Call %s with %s and %s', funcname, args, kw)
×
185

186

187
def _try_to_call(x, args, kwds):
    """Return ``x`` itself, or the result of calling it when it is callable.

    Callables are invoked through ``_call_func_with_kwds`` with ``args`` and
    ``kwds``.
    """
    return _call_func_with_kwds(x, args, kwds) if callable(x) else x
×
191

192

193
def _get_current_iters(loops, level, kwds, pipes):
    """Prepare the value sources of the loop at ``level``.

    ``loops[level]`` is a ``(keys, current)`` pair.  Accepted shapes:

    * ``keys`` is a string: one source for one variable;
    * ``keys`` and ``current`` are tuples of equal length: one source per
      entry of ``keys``;
    * ``keys`` is a tuple and ``current`` is not: one source yielding
      values for all names.

    Returns ``(keys, iters, pipes, limit)`` where ``iters`` is a list of
    ``(source, FeedbackPipe)`` pairs, ``pipes`` is the given mapping with
    the new pipes added, and ``limit`` is the smallest ``max_iters`` of the
    ``OptimizerConfig`` sources, or -1 if there are none.

    Raises ``TypeError`` for any other combination of ``keys``/``current``.
    """
    keys, current = loops[level]
    limit = -1

    keys_is_tuple = isinstance(keys, tuple)
    current_is_tuple = isinstance(current, tuple)
    if isinstance(keys, str):
        keys, current = (keys, ), (current, )
    elif keys_is_tuple and current_is_tuple and len(keys) == len(current):
        keys = tuple(k if isinstance(k, tuple) else (k, ) for k in keys)
    elif keys_is_tuple and not current_is_tuple:
        current = (current, )
        if isinstance(keys[0], str):
            keys = (keys, )
    else:
        log.error(f'Illegal keys {keys} on level {level}.')
        raise TypeError(f'Illegal keys {keys} on level {level}.')
    # Every accepted branch above leaves both ``keys`` and ``current`` as
    # tuples, so no further normalisation is required.

    iters = []
    for k, source in zip(keys, current):
        pipe = FeedbackPipe(k)
        if isinstance(source, OptimizerConfig):
            # The tightest optimizer budget bounds the whole level.
            if limit < 0 or limit > source.max_iters:
                limit = source.max_iters
            source = source.cls(source.dimensions, *source.args,
                                **source.kwds)
        else:
            source = iter(_try_to_call(source, (), kwds))

        iters.append((source, pipe))
        pipes[k] = pipe

    return keys, iters, pipes, limit
×
230

231

232
def _generate_kwds(keys, iters, kwds, iteration, limit):
    """Draw the next value of every source of one loop level.

    ``iters`` holds ``(source, pipe)`` pairs aligned with ``keys``.  A source
    having ``ask``, ``tell`` and ``get_result`` attributes is asked for a
    point, or, on the last allowed iteration, for the ``x`` of its result.
    Any other source is advanced with ``next``, so ``StopIteration``
    propagates when it is exhausted.

    Returns a dict mapping the variable names to the drawn values.
    """
    ret = {}
    last_round = limit > 0 and iteration >= limit - 1
    for ks, entry in zip(keys, iters):
        source = entry[0]
        if isinstance(ks, str):
            ks = (ks, )
        is_optimizer = all(
            hasattr(source, name) for name in ('ask', 'tell', 'get_result'))
        if is_optimizer:
            if last_round:
                result = _call_func_with_kwds(source.get_result, (), kwds)
                value = result.x
            else:
                value = _call_func_with_kwds(source.ask, (), kwds)
        else:
            value = next(source)
            if len(ks) == 1:
                # A single variable takes the drawn item as a whole.
                value = (value, )
        ret.update(zip(ks, value))
    return ret
×
249

250

251
def _send_feedback(generator, feedback):
    """Pass one feedback item to ``generator`` if it is an optimizer.

    ``generator`` counts as an optimizer when it has ``ask``, ``tell`` and
    ``get_result`` attributes; anything else is ignored.  ``Future`` entries
    of ``feedback`` are replaced by their results before ``tell`` is called
    with the entries as positional arguments.
    """
    is_optimizer = all(
        hasattr(generator, name) for name in ('ask', 'tell', 'get_result'))
    if not is_optimizer:
        return
    resolved = [x.result() if isinstance(x, Future) else x for x in feedback]
    generator.tell(*resolved)
256

257

258
def _feedback(iters):
    """Deliver all queued feedback of a loop level to its sources.

    ``iters`` is a sequence of ``(source, pipe)`` pairs; each pipe is called
    to obtain its pending items, which are handed to ``_send_feedback``.
    """
    for generator, pipe in iters:
        pending = pipe()
        for item in pending:
            _send_feedback(generator, item)
×
262

263

264
def _call_functions(functions, kwds, order, pool: Executor | None = None):
×
265
    vars = []
×
266
    for i, ready in enumerate(order):
×
267
        rest = []
×
268
        for k in ready:
×
269
            if k in kwds:
×
270
                continue
×
271
            elif k in functions:
×
272
                if pool is None:
×
273
                    kwds[k] = _try_to_call(functions[k], (), kwds)
×
274
                else:
275
                    kwds[k] = pool.submit(_try_to_call, functions[k], (), kwds)
×
276
                vars.append(k)
×
277
            else:
278
                rest.append(k)
×
279
        if rest:
×
280
            break
×
281
    else:
282
        return [], vars
×
283
    if rest:
×
284
        return [rest] + order[i:], vars
×
285
    else:
286
        return order[i:], vars
×
287

288

289
def _args_generator(loops: list,
                    kwds: dict[str, Any],
                    level: int,
                    pos: tuple[int, ...],
                    vars: list[tuple[str]],
                    filter: Callable[..., bool] | None,
                    functions: dict[str, Callable],
                    trackers: list[Tracker],
                    pipes: dict[str | tuple[str, ...], FeedbackPipe],
                    order: list[str],
                    pool: Executor | None = None):
    """Recursively walk the nested loops, yielding markers and steps.

    For every iteration of the loop at ``level`` a ``Begin`` marker is
    yielded, then everything produced by the deeper levels, then an ``End``
    marker; afterwards queued feedback is delivered to the level's sources.
    Below the innermost loop a single ``StepStatus`` is yielded, unless
    ``filter`` rejects the keyword values.

    Raises ``TypeError`` if names remain unresolved at the innermost level.
    """
    order, local_vars = _call_functions(functions, kwds, order, pool)

    if level > 0 and level == len(loops):
        # Innermost level: everything must be resolved by now.
        if order:
            log.error(f'Unresolved functions: {order}')
            raise TypeError(f'Unresolved functions: {order}')
        for tracker in trackers:
            kwds = tracker.update(kwds)
        if filter is None or _call_func_with_kwds(filter, (), kwds):
            *outer, innermost = vars
            yield StepStatus(pos=pos,
                             kwds=kwds,
                             vars=[*outer, (*innermost, *local_vars)],
                             _pipes=pipes,
                             _trackers=trackers)
        return

    keys, current_iters, pipes, limit = _get_current_iters(
        loops, level, kwds, pipes)

    i = 0
    while limit <= 0 or i < limit:
        try:
            kw = _generate_kwds(keys, current_iters, kwds, i, limit)
        except StopIteration:
            break
        here = pos + (i, )
        names = (*local_vars, *kw.keys())
        # Begin, the recursion and End each get their own merged dict and
        # their own ``vars`` list, so in-place changes made further down do
        # not leak into the markers.
        yield Begin(level=level,
                    pos=here,
                    kwds=kwds | kw,
                    vars=[*vars, names],
                    _pipes=pipes,
                    _trackers=trackers)
        # NOTE(review): ``pool`` is not forwarded, so only the outermost call
        # submits functions to the executor -- confirm this is intended.
        yield from _args_generator(loops, kwds | kw, level + 1, here,
                                   [*vars, names], filter, functions,
                                   trackers, pipes, order)
        yield End(level=level,
                  pos=here,
                  kwds=kwds | kw,
                  vars=[*vars, names],
                  _pipes=pipes,
                  _trackers=trackers)
        _feedback(current_iters)
        i += 1
×
343

344

345
def _find_common_prefix(a: tuple, b: tuple):
    """Return the length of the common prefix of ``a`` and ``b``.

    That is the index of the first position where they differ.  When no
    difference is found within the compared range the length of the shorter
    tuple is returned; this also covers empty input, which used to raise
    ``UnboundLocalError``.
    """
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return min(len(a), len(b))
×
350

351

352
def _add_dependence(graph, keys, function, loop_names, var_names):
    """Record in ``graph`` what ``function`` needs to produce ``keys``.

    For each key (a single string counts as one key) the set
    ``graph[key]`` is created if missing and extended with:

    * every named parameter of ``function`` that is in ``var_names``;
    * all of ``loop_names`` if ``function`` takes ``**kwargs`` and the key
      is not itself a loop name.
    """
    if isinstance(keys, str):
        keys = (keys, )
    for key in keys:
        deps = graph.setdefault(key, set())
        for name, param in inspect.signature(function).parameters.items():
            is_named = param.kind in (param.POSITIONAL_ONLY,
                                      param.POSITIONAL_OR_KEYWORD,
                                      param.KEYWORD_ONLY)
            if is_named and name in var_names:
                deps.add(name)
            if param.kind == param.VAR_KEYWORD and key not in loop_names:
                deps.update(loop_names)
×
364

365

366
def _build_dependence(loops, functions, constants, loop_deps=True):
    """Build the dependency graph of all loop variables and functions.

    The result maps every name to the set of names it depends on.

    With ``loop_deps`` true each loop variable depends on all variables of
    the loops before it.  Otherwise a loop variable only depends on the
    earlier loop variables named in the signature of its source, when that
    source is callable.  Callable values in ``loops`` and ``functions`` then
    add their parameter dependencies through ``_add_dependence``.

    Fixes over the previous version:

    * a loop source that is not a tuple is now paired with every key
      instead of being iterated by ``zip``; iterating it consumed items
      from one-shot iterators, failed for non-iterable sources and dropped
      keys when the source was shorter than the key tuple;
    * the signature is taken from the selected source (``iter_var``), not
      from the enclosing tuple;
    * an unreachable ``callable(iters)`` branch was removed.
    """
    graph = {}
    loop_names = set()
    var_names = set()
    for keys, iters in loops.items():
        level_vars = set()
        if isinstance(keys, str):
            keys = (keys, )
        if not isinstance(iters, tuple):
            # One shared source for all keys: pair it with each key without
            # touching its contents.
            iters = tuple(iters for _ in keys)
        for ks, iter_vars in zip(keys, iters):
            if isinstance(ks, str):
                ks = (ks, )
            level_vars.update(ks)
            for i, k in enumerate(ks):
                d = graph.setdefault(k, set())
                if loop_deps:
                    d.update(loop_names)
                    continue
                if isinstance(iter_vars, tuple):
                    iter_var = iter_vars[i]
                else:
                    iter_var = iter_vars
                if callable(iter_var):
                    d.update(
                        set(inspect.signature(iter_var).parameters.keys())
                        & loop_names)

        # Variables of this level become visible to the following levels only.
        loop_names.update(level_vars)
        var_names.update(level_vars)
    var_names.update(functions.keys())
    var_names.update(constants.keys())

    for keys, values in chain(loops.items(), functions.items()):
        if callable(values):
            _add_dependence(graph, keys, values, loop_names, var_names)
        elif isinstance(values, tuple):
            for ks, v in zip(keys, values):
                if callable(v):
                    _add_dependence(graph, ks, v, loop_names, var_names)

    return graph
×
411

412

413
def _get_all_dependence(key, graph):
    """Return every name ``key`` depends on, directly or indirectly.

    ``graph`` maps a name to the set of names it depends on.  Names missing
    from ``graph`` have no dependencies.  Each name is visited once, so
    shared dependencies are not re-expanded and a cyclic graph terminates
    instead of recursing without end.
    """
    ret = set()
    pending = list(graph[key]) if key in graph else []
    while pending:
        k = pending.pop()
        if k in ret:
            continue
        ret.add(k)
        if k in graph:
            pending.extend(graph[k])
    return ret
×
421

422

423
def scan_iters(loops: dict[str | tuple[str, ...],
                           Iterable | Callable | OptimizerConfig
                           | tuple[Iterable | Callable | OptimizerConfig,
                                   ...]] | None = None,
               filter: Callable[..., bool] | None = None,
               functions: dict[str, Callable] | None = None,
               constants: dict[str, Any] | None = None,
               trackers: list[Tracker] | None = None,
               level_marker: bool = False,
               pool: Executor | None = None,
               **kwds) -> Iterable[StepStatus]:
    """
    Scan the given nested loops.

    Parameters
    ----------
    loops : dict
        A map from variable name(s) to the iterable, callable or
        OptimizerConfig producing their values. Each entry is one loop
        level, the first entry being the outermost loop.
    filter : Callable[..., bool]
        A filter function that is called for each step.
        If it returns False, the step is skipped.
    functions : dict
        A map of functions that are called for each step.
    constants : dict
        Additional keyword arguments that are passed to the iterables.
        The mapping is copied, so it is not modified by the scan.
    trackers : list
        Tracker objects; their ``init`` is called once before the scan and
        their ``update`` for every step.
    level_marker : bool
        If True, the Begin and End markers of every loop iteration are
        yielded as well.
    pool : Executor
        Optional executor used to evaluate functions.
    **kwds
        Only the deprecated aliases ``additional_kwds`` (for ``functions``)
        and ``iters`` (for ``loops``) are recognised.

    Returns
    -------
    Iterable[StepStatus]
        An iterable of StepStatus objects.

    Examples
    --------
    >>> loops = {
    ...     'a': range(2),
    ...     'b': range(3),
    ... }
    >>> [(s.pos, s.kwds) for s in scan_iters(loops)]
    [((0, 0), {'a': 0, 'b': 0}),
     ((0, 1), {'a': 0, 'b': 1}),
     ((0, 2), {'a': 0, 'b': 2}),
     ((1, 0), {'a': 1, 'b': 0}),
     ((1, 1), {'a': 1, 'b': 1}),
     ((1, 2), {'a': 1, 'b': 2})]

    >>> [(s.pos, s.index) for s in scan_iters(loops, lambda a, b: a < b)]
    [((0, 1), (0, 0)), ((0, 2), (0, 1)), ((1, 2), (1, 0))]
    """

    # TODO: if a callable value in `loops` has a VAR_KEYWORD parameter and
    #       actually depends at run time on some values in `functions`, the
    #       dependency graph will be wrong.
    # TODO: if a callable value in `functions` has a VAR_KEYWORD parameter,
    #       its dependency on those parameters is treated as a dependency on
    #       all loop variables, and such functions are assumed not to depend
    #       on each other.

    # None defaults replace the former shared mutable ``{}`` / ``[]``.
    loops = {} if loops is None else loops
    functions = {} if functions is None else functions
    constants = {} if constants is None else constants
    trackers = [] if trackers is None else trackers

    if 'additional_kwds' in kwds:
        functions = functions | kwds['additional_kwds']
        warnings.warn(
            "The argument 'additional_kwds' is deprecated, "
            "use 'functions' instead.", DeprecationWarning)
    if 'iters' in kwds:
        loops = loops | kwds['iters']
        warnings.warn(
            "The argument 'iters' is deprecated, "
            "use 'loops' instead.", DeprecationWarning)

    if not loops:
        return

    # Evaluation order: groups of names whose dependencies are all met by
    # the groups before them.
    graph = _build_dependence(loops, functions, constants)
    ts = TopologicalSorter(graph)
    order = []
    ts.prepare()
    while ts.is_active():
        ready = ts.get_ready()
        ts.done(*ready)
        order.append(ready)
    # The trackers receive the graph built without the implicit
    # loop-to-loop dependencies.
    graph = _build_dependence(loops, functions, constants, False)

    for tracker in trackers:
        tracker.init(loops, functions, constants, graph, order)

    last_step = None
    index = ()
    iteration = count()

    # ``constants`` is copied because function results are written into the
    # top-level keyword dict in place; without the copy the caller's dict
    # would be modified.
    for step in _args_generator(list(loops.items()),
                                kwds=dict(constants),
                                level=0,
                                pos=(),
                                vars=[],
                                filter=filter,
                                functions=functions,
                                trackers=trackers,
                                pipes={},
                                order=order,
                                pool=pool):
        if isinstance(step, (Begin, End)):
            if level_marker:
                if last_step is not None:
                    step.iteration = last_step.iteration
                yield step
            continue

        if last_step is None:
            i = 0
            index = (0, ) * len(step.pos)
        else:
            # Levels before ``i`` keep their index, level ``i`` advances by
            # one and deeper levels restart from zero.
            i = _find_common_prefix(last_step.pos, step.pos)
            index = tuple((j <= i) * n + (j == i) for j, n in enumerate(index))
        step.iteration = next(iteration)
        step.index = index
        step.unchanged = i
        yield step
        last_step = step
×
544

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc