• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

rindPHI / isla / 4532124943

pending completion
4532124943

push

github

GitHub
Merge pull request #61 from rindPHI/documentation

45 of 45 new or added lines in 1 file covered. (100.0%)

6071 of 6483 relevant lines covered (93.64%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.06
/src/isla/solver.py
1
# Copyright © 2022 CISPA Helmholtz Center for Information Security.
2
# Author: Dominic Steinhöfel.
3
#
4
# This file is part of ISLa.
5
#
6
# ISLa is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# ISLa is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with ISLa.  If not, see <http://www.gnu.org/licenses/>.
18

19
import copy
1✔
20
import functools
1✔
21
import heapq
1✔
22
import itertools
1✔
23
import logging
1✔
24
import math
1✔
25
import operator
1✔
26
import random
1✔
27
import sys
1✔
28
import time
1✔
29
from abc import ABC
1✔
30
from dataclasses import dataclass
1✔
31
from functools import reduce, lru_cache
1✔
32
from typing import (
1✔
33
    Dict,
34
    List,
35
    Set,
36
    Optional,
37
    Tuple,
38
    Union,
39
    cast,
40
    Callable,
41
    Iterable,
42
    Sequence,
43
)
44

45
import pkg_resources
1✔
46
import z3
1✔
47
from grammar_graph import gg
1✔
48
from grammar_graph.gg import GrammarGraph
1✔
49
from grammar_to_regex.cfg2regex import RegexConverter
1✔
50
from grammar_to_regex.regex import regex_to_z3
1✔
51
from orderedset import OrderedSet
1✔
52
from packaging import version
1✔
53

54
import isla.isla_shortcuts as sc
1✔
55
import isla.three_valued_truth
1✔
56
from isla import language
1✔
57
from isla.derivation_tree import DerivationTree
1✔
58
from isla.evaluator import (
1✔
59
    evaluate,
60
    quantified_formula_might_match,
61
    get_toplevel_quantified_formulas,
62
    eliminate_quantifiers,
63
)
64
from isla.evaluator import matches_for_quantified_formula
1✔
65
from isla.existential_helpers import (
1✔
66
    insert_tree,
67
    DIRECT_EMBEDDING,
68
    SELF_EMBEDDING,
69
    CONTEXT_ADDITION,
70
)
71
from isla.fuzzer import GrammarFuzzer, GrammarCoverageFuzzer, expansion_key
1✔
72
from isla.helpers import (
1✔
73
    delete_unreachable,
74
    shuffle,
75
    dict_of_lists_to_list_of_dicts,
76
    weighted_geometric_mean,
77
    assertions_activated,
78
    split_str_with_nonterminals,
79
    cluster_by_common_elements,
80
    is_nonterminal,
81
    canonical,
82
    lazyjoin,
83
    lazystr,
84
    Maybe,
85
    chain_functions,
86
    Exceptional,
87
    eliminate_suffixes,
88
    get_elem_by_equivalence,
89
    get_expansions,
90
    list_del,
91
    compute_nullable_nonterminals,
92
    eassert,
93
    merge_dict_of_sets,
94
)
95
from isla.isla_predicates import (
1✔
96
    STANDARD_STRUCTURAL_PREDICATES,
97
    STANDARD_SEMANTIC_PREDICATES,
98
    COUNT_PREDICATE,
99
)
100
from isla.isla_shortcuts import true
1✔
101
from isla.language import (
1✔
102
    VariablesCollector,
103
    split_conjunction,
104
    split_disjunction,
105
    convert_to_dnf,
106
    convert_to_nnf,
107
    ensure_unique_bound_variables,
108
    parse_isla,
109
    get_conjuncts,
110
    parse_bnf,
111
    ForallIntFormula,
112
    set_smt_auto_eval,
113
    NoopFormulaTransformer,
114
    set_smt_auto_subst,
115
    fresh_constant,
116
)
117
from isla.mutator import Mutator
1✔
118
from isla.parser import EarleyParser
1✔
119
from isla.type_defs import Grammar, Path, ImmutableList, CanonicalGrammar
1✔
120
from isla.z3_helpers import (
1✔
121
    z3_solve,
122
    z3_subst,
123
    z3_eq,
124
    z3_and,
125
    visit_z3_expr,
126
    smt_string_val_to_string,
127
    parent_relationships_in_z3_expr,
128
)
129

130

131
@dataclass(frozen=True)
class SolutionState:
    """
    An entry of the solver's search queue: a (possibly open) derivation tree
    together with the constraint that remains to be satisfied for it.

    Equality and hashing consider only :attr:`constraint` and :attr:`tree`;
    :attr:`level` is ignored by both.
    """

    constraint: language.Formula
    tree: DerivationTree
    level: int = 0

    # The hash is memoized lazily in ``self.__dict__["_cached_hash"]`` (see
    # ``__hash__``). It is deliberately *not* a dataclass field, so that neither
    # the generated ``__init__``/``__repr__`` nor ``dataclasses.replace`` ever
    # see or copy a (possibly stale) cached value. Do not give it a leading
    # double underscore: name mangling would make attribute reads and the
    # string-keyed ``object.__setattr__`` write disagree, silently disabling
    # the cache.

    def formula_satisfied(
        self, grammar: Grammar
    ) -> isla.three_valued_truth.ThreeValuedTruth:
        """
        Evaluates :attr:`constraint` on :attr:`tree`.

        :param grammar: The grammar passed on to :func:`~isla.evaluator.evaluate`.
        :return: "Unknown" if the tree is still open; otherwise the result of
          evaluating the constraint on the tree.
        """
        if self.tree.is_open():
            # Have to instantiate variables first
            return isla.three_valued_truth.ThreeValuedTruth.unknown()

        return evaluate(self.constraint, self.tree, grammar)

    def complete(self) -> bool:
        """
        :return: True iff the tree is complete and the remaining constraint is
          exactly ``true``.
        """
        if not self.tree.is_complete():
            return False

        # We assume that any universal quantifier has already been instantiated, if it
        # matches, and is thus satisfied, or another unsatisfied constraint resulted
        # from the instantiation. Existential, predicate, and SMT formulas have to be
        # eliminated first.

        return self.constraint == sc.true()

    # Less-than comparisons are needed for usage in the binary heap queue
    def __lt__(self, other: "SolutionState"):
        # Queue entries are (cost, state) tuples, so this is only consulted for
        # equal costs; ordering by hash is an arbitrary but consistent tie-break.
        return hash(self) < hash(other)

    def __hash__(self):
        cached = self.__dict__.get("_cached_hash")
        if cached is None:
            cached = hash((self.constraint, self.tree))
            # ``object.__setattr__`` bypasses the frozen-dataclass guard.
            object.__setattr__(self, "_cached_hash", cached)
        return cached

    def __getstate__(self):
        # Never persist the memoized hash: ``hash()`` of strings is salted per
        # interpreter process, so a pickled value could be stale once loaded.
        state = dict(self.__dict__)
        state.pop("_cached_hash", None)
        return state

    def __eq__(self, other):
        # NOTE(review): equality is structural on the tree, while ``__hash__``
        # uses ``hash(self.tree)``. This is only consistent if DerivationTree's
        # hash agrees with ``structurally_equal`` — confirm.
        return (
            isinstance(other, SolutionState)
            and self.constraint == other.constraint
            and self.tree.structurally_equal(other.tree)
        )
176

177

178
@dataclass(frozen=True)
class CostWeightVector:
    """
    Collection of weights for the
    :class:`~isla.solver.GrammarBasedBlackboxCostComputer`.

    The vector behaves like a fixed-size tuple of its five fields, in
    declaration order: it supports iteration (and thus tuple unpacking),
    integer indexing, and :func:`len`.
    """

    tree_closing_cost: float = 0
    constraint_cost: float = 0
    derivation_depth_penalty: float = 0
    low_k_coverage_penalty: float = 0
    low_global_k_path_coverage_penalty: float = 0

    def _as_tuple(self) -> Tuple[float, ...]:
        # Single source of truth for the element order used by iteration,
        # indexing, and len(); keep in sync with the field declarations.
        return (
            self.tree_closing_cost,
            self.constraint_cost,
            self.derivation_depth_penalty,
            self.low_k_coverage_penalty,
            self.low_global_k_path_coverage_penalty,
        )

    def __iter__(self):
        """
        Use tuple assignment for objects of this type:

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> a, b, c, d, e = v
        >>> (a, b, c, d, e)
        (1, 2, 3, 4, 5)

        :return: An iterator of the fixed-size list of elements of the weight vector.
        """
        return iter(self._as_tuple())

    def __len__(self) -> int:
        """
        :return: The number of elements of the weight vector (always 5).
        """
        return len(self._as_tuple())

    def __getitem__(self, item: int) -> float:
        """
        Tuple-like access of elements of the vector.

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> v[3]
        4

        :param item: A numeric index (negative indices count from the end).
        :return: The element at index :code:`item`.
        :raises TypeError: If :code:`item` is not an integer.
        :raises IndexError: If :code:`item` is out of range.
        """
        # operator.index rejects non-integers with a TypeError (unlike an
        # assert, which is stripped under ``python -O``).
        return self._as_tuple()[operator.index(item)]
231

232

233
@dataclass(frozen=True)
class CostSettings:
    """
    Settings bundle for cost computation: a :class:`CostWeightVector` and an
    integer ``k`` (presumably the path length used for k-path coverage —
    confirm in :class:`~isla.solver.GrammarBasedBlackboxCostComputer`).
    """

    weight_vector: CostWeightVector
    k: int = 3

    def __post_init__(self):
        # Type checks run right after the generated (frozen) initializer.
        assert isinstance(self.weight_vector, CostWeightVector)
        assert isinstance(self.k, int)


# Standard cost settings; the solver constructor falls back to these when no
# custom cost computer is supplied.
STD_COST_SETTINGS = CostSettings(
    weight_vector=CostWeightVector(
        tree_closing_cost=6.5,
        constraint_cost=1,
        derivation_depth_penalty=4,
        low_k_coverage_penalty=2,
        low_global_k_path_coverage_penalty=19,
    ),
    k=3,
)
255

256

257
class UnknownResultError(Exception):
    """
    Raised when the truth value of a constraint cannot be determined, e.g., by
    :meth:`~isla.solver.ISLaSolver.check` if the evaluation result is unknown.

    This is intentionally a plain exception class rather than a frozen
    dataclass. A frozen dataclass would reject constructor arguments such as an
    error message. It would also raise ``FrozenInstanceError`` on any attribute
    assignment, e.g., reassigning ``__traceback__`` or
    :meth:`BaseException.add_note`.
    """


class SemanticError(Exception):
    """
    Raised when an input conforms to the grammar but violates the constraint,
    e.g., by :meth:`~isla.solver.ISLaSolver.parse`.

    A plain exception class (not a frozen dataclass), so that it can carry a
    message and supports the attribute assignments exceptions normally allow.
    """
265

266

267
@dataclass(frozen=True)
class SolverDefaults:
    """
    Default values for the optional parameters of the
    :class:`~isla.solver.ISLaSolver` constructor. The field names coincide with
    the constructor's parameter names; see its docstring for their meaning.
    """

    # Constraint and predicates
    formula: Optional[language.Formula | str] = None
    structural_predicates: frozenset[language.StructuralPredicate] = (
        STANDARD_STRUCTURAL_PREDICATES
    )
    semantic_predicates: frozenset[language.SemanticPredicate] = (
        STANDARD_SEMANTIC_PREDICATES
    )

    # Bounds on instantiations and tree-insertion results
    max_number_free_instantiations: int = 10
    max_number_smt_instantiations: int = 10
    max_number_tree_insertion_results: int = 5

    # Queue, debugging, cost, and timeout behavior
    enforce_unique_trees_in_queue: bool = False
    debug: bool = False
    cost_computer: Optional["CostComputer"] = None
    timeout_seconds: Optional[int] = None

    # Fuzzing and quantifier handling
    global_fuzzer: bool = False
    predicates_unique_in_int_arg: tuple[language.SemanticPredicate, ...] = (
        COUNT_PREDICATE,
    )
    fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = (
        lambda grammar: GrammarCoverageFuzzer(grammar)
    )
    tree_insertion_methods: Optional[int] = None
    activate_unsat_support: bool = False
    grammar_unwinding_threshold: int = 4

    # Initial state and Z3 preprocessing
    initial_tree: Maybe[DerivationTree] = Maybe.nothing()
    enable_optimized_z3_queries: bool = True
    start_symbol: Optional[str] = None


# Shared instance whose fields provide the default argument values of
# ``ISLaSolver.__init__``.
_DEFAULTS = SolverDefaults()
299

300

301
class ISLaSolver:
1✔
302
    """
1✔
303
    The solver class for ISLa formulas/constraints. Its top-level methods are
304

305
    :meth:`~isla.solver.ISLaSolver.solve`
306
      Use to generate solutions for an ISLa constraint.
307

308
    :meth:`~isla.solver.ISLaSolver.check`
309
      Use to check if an ISLa constraint is satisfied for a given input.
310

311
    :meth:`~isla.solver.ISLaSolver.parse`
312
      Use to parse and validate an input.
313

314
    :meth:`~isla.solver.ISLaSolver.repair`
315
      Use to repair an input such that it satisfies a constraint.
316

317
    :meth:`~isla.solver.ISLaSolver.mutate`
318
      Use to mutate an input such that the result satisfies a constraint.
319
    """
320

321
    def __init__(
        self,
        grammar: Grammar | str,
        formula: Optional[language.Formula | str] = _DEFAULTS.formula,
        structural_predicates: Set[
            language.StructuralPredicate
        ] = _DEFAULTS.structural_predicates,
        semantic_predicates: Set[
            language.SemanticPredicate
        ] = _DEFAULTS.semantic_predicates,
        max_number_free_instantiations: int = _DEFAULTS.max_number_free_instantiations,
        max_number_smt_instantiations: int = _DEFAULTS.max_number_smt_instantiations,
        max_number_tree_insertion_results: int = _DEFAULTS.max_number_tree_insertion_results,
        enforce_unique_trees_in_queue: bool = _DEFAULTS.enforce_unique_trees_in_queue,
        debug: bool = _DEFAULTS.debug,
        cost_computer: Optional["CostComputer"] = _DEFAULTS.cost_computer,
        timeout_seconds: Optional[int] = _DEFAULTS.timeout_seconds,
        global_fuzzer: bool = _DEFAULTS.global_fuzzer,
        predicates_unique_in_int_arg: Tuple[
            language.SemanticPredicate, ...
        ] = _DEFAULTS.predicates_unique_in_int_arg,
        fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = _DEFAULTS.fuzzer_factory,
        tree_insertion_methods: Optional[int] = _DEFAULTS.tree_insertion_methods,
        activate_unsat_support: bool = _DEFAULTS.activate_unsat_support,
        grammar_unwinding_threshold: int = _DEFAULTS.grammar_unwinding_threshold,
        initial_tree: Maybe[DerivationTree] = _DEFAULTS.initial_tree,
        enable_optimized_z3_queries: bool = _DEFAULTS.enable_optimized_z3_queries,
        start_symbol: Optional[str] = _DEFAULTS.start_symbol,
    ):
        """
        The constructor of :class:`~isla.solver.ISLaSolver` accepts a large number of
        parameters. However, all but the first one, :code:`grammar`, are *optional.*

        The simplest way to construct an ISLa solver is by providing it with a
        grammar only; it then works like a grammar fuzzer.

        >>> import random
        >>> random.seed(1)

        >>> import string
        >>> LANG_GRAMMAR = {
        ...     "<start>":
        ...         ["<stmt>"],
        ...     "<stmt>":
        ...         ["<assgn> ; <stmt>", "<assgn>"],
        ...     "<assgn>":
        ...         ["<var> := <rhs>"],
        ...     "<rhs>":
        ...         ["<var>", "<digit>"],
        ...     "<var>": list(string.ascii_lowercase),
        ...     "<digit>": list(string.digits)
        ... }
        >>>
        >>> from isla.solver import ISLaSolver
        >>> solver = ISLaSolver(LANG_GRAMMAR)
        >>>
        >>> str(solver.solve())
        'd := 9'
        >>> str(solver.solve())
        'v := n ; s := r'

        :param grammar: The underlying grammar; either as a "Fuzzing Book" dictionary
          or in BNF syntax. Dictionaries are deep-copied, so the caller's object is
          never modified.
        :param formula: The formula to solve; either a string or a readily parsed
          formula. If no formula is given, a default `true` constraint is assumed, and
          the solver falls back to a grammar fuzzer. The number of produced solutions
          will then be bound by `max_number_free_instantiations`. The formula may
          contain at most one free (non-numeric) constant; otherwise an
          :class:`AssertionError` is raised.
        :param structural_predicates: Structural predicates to use when parsing a
          formula.
        :param semantic_predicates: Semantic predicates to use when parsing a formula.
        :param max_number_free_instantiations: Number of times that nonterminals that
          are not bound by any formula should be expanded by a coverage-based fuzzer.
        :param max_number_smt_instantiations: Number of solutions of SMT formulas that
          should be produced.
        :param max_number_tree_insertion_results: The maximum number of results when
          solving existential quantifiers by tree insertion.
        :param enforce_unique_trees_in_queue: If true, states with the same tree as an
          already existing tree in the queue are discarded, irrespective of the
          constraint.
        :param debug: If true, debug information about the evolution of states is
          collected, notably in the field `state_tree`. The root of the tree is in the
          field `state_tree_root`. The field `costs` stores the computed cost values
          for all new nodes.
        :param cost_computer: The `CostComputer` class for computing the cost relevant
          to placing states in ISLa's queue. Defaults to a
          `GrammarBasedBlackboxCostComputer` with `STD_COST_SETTINGS`.
        :param timeout_seconds: Number of seconds after which the solver will
          terminate. The clock starts with the first call to
          :meth:`~isla.solver.ISLaSolver.solve`, which then raises a
          :class:`TimeoutError`.
        :param global_fuzzer: If set to True, only one coverage-guided grammar fuzzer
          object is used to finish off unconstrained open derivation trees throughout
          the whole generation time. This may be beneficial for some targets; e.g., we
          experienced that CSV works significantly faster. However, the achieved k-path
          coverage can be lower with that setting.
        :param predicates_unique_in_int_arg: This is needed in certain cases for
          instantiating universal integer quantifiers. The supplied predicates should
          have exactly one integer argument, and hold for exactly one integer value
          once all other parameters are fixed.
        :param fuzzer_factory: Constructor of the fuzzer to use for instantiating
          "free" nonterminals.
        :param tree_insertion_methods: Combination of methods to use for existential
          quantifier elimination by tree insertion. Full selection (the default unless
          `activate_unsat_support` is set): `DIRECT_EMBEDDING + SELF_EMBEDDING +
          CONTEXT_ADDITION`.
        :param activate_unsat_support: Set to True if you assume that a formula might
          be unsatisfiable. This triggers additional tests for unsatisfiability that
          reduce input generation performance, but might ensure termination (with a
          negative solver result) for unsatisfiable problems for which the solver could
          otherwise diverge. If set and `tree_insertion_methods` is not given, tree
          insertion is disabled (value 0); if both are given, a warning recommending
          0 is printed to stderr.
        :param grammar_unwinding_threshold: When querying the SMT solver, ISLa passes a
          regular expression for the syntax of the involved nonterminals. If this
          syntax is not regular, we unwind the respective part in the reference grammar
          up to a depth of `grammar_unwinding_threshold`. If this is too shallow, it can
          happen that an equation etc. cannot be solved; if it is too deep, it can
          negatively impact performance (and quite tremendously so).
        :param initial_tree: An initial input tree for the queue, if the solver shall
          not start from the tree `(<start>, None)`. Must not be combined with
          `start_symbol`.
        :param enable_optimized_z3_queries: Enables preprocessing of Z3 queries (mainly
          numeric problems concerning things like length). This can improve performance
          significantly; however, it might happen that certain problems cannot be solved
          anymore. In that case, this option can/should be deactivated.
        :param start_symbol: This is an alternative to `initial_tree` for starting with
          a start symbol different from `<start>`. If `start_symbol` is provided, a tree
          consisting of a single root node with the value of `start_symbol` is chosen as
          initial tree. Must not be combined with `initial_tree`.
        """
        self.logger = logging.getLogger(type(self).__name__)

        # We require at least z3 4.8.13.0. ISLa might work for some older versions, but
        # at least for 4.8.8.0, we have witnessed that certain rather easy constraints,
        # like equations, don't work since z3 cannot handle the restrictions to more
        # complicated regular expressions. This happened in the XML case study with
        # constraints of the kind <id> == <id_no_prefix>. At least with 4.8.13.0, this
        # works flawlessly, but times out for 4.8.8.0.
        #
        # This should be solved using Python requirements, which however cannot be done
        # currently since the fuzzingbook library inflexibly binds z3 to 4.8.8.0. Thus,
        # one has to manually install a newer version and ignore the warning.

        z3_version = pkg_resources.get_distribution("z3-solver").version
        assert version.parse(z3_version) >= version.parse("4.8.13.0"), (
            f"ISLa requires at least z3 4.8.13.0, present: {z3_version}. "
            "Please install a newer z3 version, e.g., using 'pip install z3-solver==4.8.14.0'."
        )

        if isinstance(grammar, str):
            self.grammar = parse_bnf(grammar)
        else:
            # Deep copy: the grammar may be modified in place below (``|=``).
            self.grammar = copy.deepcopy(grammar)

        assert (
            start_symbol is None or not initial_tree.is_present()
        ), "You cannot supply a start symbol *and* an initial tree."

        if start_symbol is not None:
            # Make the custom start symbol the only expansion of <start> and drop
            # whatever is no longer reachable from it.
            self.grammar |= {"<start>": [start_symbol]}
            self.grammar = delete_unreachable(self.grammar)

        self.graph = GrammarGraph.from_grammar(self.grammar)
        self.canonical_grammar = canonical(self.grammar)
        self.timeout_seconds = timeout_seconds
        # Set lazily by the first call to ``solve``.
        self.start_time: Optional[int] = None
        self.global_fuzzer = global_fuzzer
        self.fuzzer = fuzzer_factory(self.grammar)
        self.fuzzer_factory = fuzzer_factory
        self.predicates_unique_in_int_arg: Set[language.SemanticPredicate] = set(
            predicates_unique_in_int_arg
        )
        self.grammar_unwinding_threshold = grammar_unwinding_threshold
        self.enable_optimized_z3_queries = enable_optimized_z3_queries

        if activate_unsat_support and tree_insertion_methods is None:
            # Unsat support without an explicit choice: disable tree insertion.
            self.tree_insertion_methods = 0
        else:
            if activate_unsat_support:
                assert tree_insertion_methods is not None
                print(
                    "With activate_unsat_support set, a 0 value for tree_insertion_methods is recommended, "
                    f"the current value is: {tree_insertion_methods}",
                    file=sys.stderr,
                )

            self.tree_insertion_methods = (
                DIRECT_EMBEDDING + SELF_EMBEDDING + CONTEXT_ADDITION
            )
            if tree_insertion_methods is not None:
                self.tree_insertion_methods = tree_insertion_methods

        self.activate_unsat_support = activate_unsat_support
        self.currently_unsat_checking: bool = False

        self.cost_computer = (
            cost_computer
            if cost_computer is not None
            else GrammarBasedBlackboxCostComputer(STD_COST_SETTINGS, self.graph)
        )

        formula = (
            sc.true()
            if formula is None
            else (
                parse_isla(
                    formula, self.grammar, structural_predicates, semantic_predicates
                )
                if isinstance(formula, str)
                else formula
            )
        )

        self.formula = ensure_unique_bound_variables(formula)

        # Free, non-numeric constants of the formula; at most one is allowed.
        top_constants: Set[language.Constant] = set(
            [
                c
                for c in VariablesCollector.collect(self.formula)
                if isinstance(c, language.Constant) and not c.is_numeric()
            ]
        )

        assert len(top_constants) <= 1, (
            "ISLa only accepts up to one constant (free variable), "
            + f'found {len(top_constants)}: {", ".join(map(str, top_constants))}'
        )

        self.top_constant = Maybe.from_iterator(iter(top_constants))

        # NOTE(review): computed from ``formula`` rather than ``self.formula``
        # (the variant with unique bound variables) — confirm this is intended.
        quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            tuple([f for f in c if isinstance(f, language.ForallFormula)])
            for c in get_quantifier_chains(formula)
        ]
        # TODO: Remove?
        self.quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            c for c in quantifier_chains if c
        ]

        self.max_number_free_instantiations: int = max_number_free_instantiations
        self.max_number_smt_instantiations: int = max_number_smt_instantiations
        self.max_number_tree_insertion_results = max_number_tree_insertion_results
        self.enforce_unique_trees_in_queue = enforce_unique_trees_in_queue

        # Initialize Queue
        # Candidates, in order: the given initial tree, a root node for the
        # custom start symbol, and a root node for the free constant's type (or
        # <start>). Presumably ``Maybe.__add__`` keeps the first present value —
        # confirm in isla.helpers.Maybe.
        self.initial_tree = (
            initial_tree
            + Maybe(start_symbol)
            .map(lambda s: eassert(s, s in self.grammar))
            .map(lambda s: DerivationTree(s, None))
            + Maybe(
                DerivationTree(
                    self.top_constant.map(lambda c: c.n_type)
                    .orelse(lambda: "<start>")
                    .get(),
                    None,
                )
            )
        ).get()

        # Bind the free constant (if any) to the initial tree.
        initial_formula = (
            self.top_constant.map(
                lambda c: self.formula.substitute_expressions({c: self.initial_tree})
            )
            .orelse(lambda: true())
            .get()
        )
        initial_state = SolutionState(initial_formula, self.initial_tree)
        initial_states = self.establish_invariant(initial_state)

        # Min-heap of (cost, state) pairs.
        self.queue: List[Tuple[float, SolutionState]] = []
        self.tree_hashes_in_queue: Set[int] = {self.initial_tree.structural_hash()}
        self.state_hashes_in_queue: Set[int] = {hash(state) for state in initial_states}
        for state in initial_states:
            heapq.heappush(self.queue, (self.compute_cost(state), state))

        self.seen_coverages: Set[str] = set()
        self.current_level: int = 0
        self.step_cnt: int = 0
        self.last_cost_recomputation: int = 0

        self.regex_cache = {}

        # Solutions found but not yet returned by ``solve``.
        self.solutions: List[DerivationTree] = []

        # Debugging stuff
        self.debug = debug
        self.state_tree: Dict[
            SolutionState, List[SolutionState]
        ] = {}  # is only filled if self.debug
        self.state_tree_root = None
        self.current_state = None
        self.costs: Dict[SolutionState, float] = {}

        if self.debug:
            self.state_tree[initial_state] = initial_states
            self.state_tree_root = initial_state
            self.costs[initial_state] = 0
            for state in initial_states:
                self.costs[state] = self.compute_cost(state)
613

614
    def solve(self) -> DerivationTree:
        """
        Attempts to compute a solution to the given ISLa formula. Returns that solution,
        if any. This function can be called repeatedly to obtain more solutions until
        one of two exception types is raised: A :class:`StopIteration` indicates that
        no more solution can be found; a :class:`TimeoutError` is raised if a timeout
        occurred. After that, an exception will be raised every time.

        The timeout can be controlled by the :code:`timeout_seconds`
        :meth:`constructor <isla.solver.ISLaSolver.__init__>` parameter. The clock
        starts with the first call of this method.

        :return: A solution for the ISLa formula passed to the
          :class:`isla.solver.ISLaSolver`.
        :raises StopIteration: If the queue is exhausted and no solution is left.
        :raises TimeoutError: If more than :code:`timeout_seconds` seconds have
          passed since the first call.
        """
        if self.timeout_seconds is not None and self.start_time is None:
            self.start_time = int(time.time())

        def process_and_extend_solutions(
            result_states: List[SolutionState],
        ) -> None:
            # Hand the successor states to ``process_new_states`` and buffer
            # whatever it returns as solutions.
            assert result_states is not None
            self.solutions.extend(self.process_new_states(result_states))

        while self.queue:
            self.step_cnt += 1

            if self.timeout_seconds is not None:
                if int(time.time()) - self.start_time > self.timeout_seconds:
                    self.logger.debug("TIMEOUT")
                    raise TimeoutError(self.timeout_seconds)

            # Solutions buffered by earlier steps are returned before any
            # further state is processed.
            if self.solutions:
                return self._pop_solution()

            cost: float
            state: SolutionState
            cost, state = heapq.heappop(self.queue)

            self.current_level = state.level
            self.tree_hashes_in_queue.discard(state.tree.structural_hash())
            self.state_hashes_in_queue.discard(hash(state))

            if self.debug:
                self.current_state = state
                self.state_tree.setdefault(state, [])

            # Rendering the whole tree is costly and happens once per step, so
            # only do it if the message will actually be emitted.
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    "Polling new state (%s, %s) (hash %d, cost %f)",
                    state.constraint,
                    state.tree.to_string(show_open_leaves=True, show_ids=True),
                    hash(state),
                    cost,
                )
            self.logger.debug("Queue length: %s", len(self.queue))

            assert not isinstance(state.constraint, language.DisjunctiveFormula)

            # Instantiate all top-level structural predicate formulas.
            state = self.instantiate_structural_predicates(state)

            # Apply the first elimination function that is applicable.
            # The later ones are ignored.
            monad = chain_functions(
                [
                    self.noop_on_false_constraint,
                    self.eliminate_existential_integer_quantifiers,
                    self.instantiate_universal_integer_quantifiers,
                    self.match_all_universal_formulas,
                    self.expand_to_match_quantifiers,
                    self.eliminate_all_semantic_formulas,
                    self.eliminate_all_ready_semantic_predicate_formulas,
                    self.eliminate_and_match_first_existential_formula_and_expand,
                    self.assert_remaining_formulas_are_lazy_binding_semantic,
                    self.finish_unconstrained_trees,
                    self.expand,
                ],
                state,
            )

            monad.if_present(process_and_extend_solutions)

        if self.solutions:
            return self._pop_solution()

        self.logger.debug("UNSAT")
        raise StopIteration()

    def _pop_solution(self) -> DerivationTree:
        """
        Removes and returns the oldest buffered solution, logging it.

        :return: The first element of :code:`self.solutions`.
        """
        solution = self.solutions.pop(0)
        self.logger.debug('Found solution "%s"', solution)
        return solution
712

713
    def check(self, inp: DerivationTree | str) -> bool:
        """
        Decides whether an input satisfies the constraint passed to the solver.

        A string input is parsed with :meth:`~isla.solver.ISLaSolver.parse`, which
        also evaluates the constraint on the parsed tree. If parsing fails with a
        :class:`SyntaxError` or :class:`SemanticError`, the result is False.

        :param inp: The input to evaluate, either readily parsed or as a string.
        :return: A truth value.
        :raises UnknownResultError: If the constraint could not be evaluated
          (e.g., because of a solver timeout or a semantic predicate that cannot be
          evaluated).
        """
        if isinstance(inp, str):
            try:
                self.parse(inp)
            except (SyntaxError, SemanticError):
                return False
            return True

        assert isinstance(inp, DerivationTree)

        truth_value = evaluate(self.formula, inp, self.grammar)
        if truth_value.is_unknown():
            raise UnknownResultError()

        return bool(truth_value)
738

739
    def parse(
        self, inp: str, nonterminal: str = "<start>", skip_check: bool = False
    ) -> DerivationTree:
        """
        Parses the given input `inp` and returns the parsed `DerivationTree`.

        :param inp: The input to parse.
        :param nonterminal: The nonterminal to start parsing with, if a string
          corresponding to a sub-grammar shall be parsed. We don't check semantic
          correctness in that case.
        :param skip_check: If True, the semantic check is left out.
        :return: A parsed `DerivationTree`.
        :raises SyntaxError: If the input does not conform to the grammar (the
          failure is also logged at error level).
        :raises SemanticError: If `nonterminal` is "<start>", `skip_check` is
          False, and the parsed tree does not satisfy the constraint.
        :raises UnknownResultError: If the semantic check cannot decide the
          constraint (propagated from :meth:`~isla.solver.ISLaSolver.check`).
        """
        # Work on a copy: ``|=`` below updates the dictionary in place and must
        # not leak into ``self.grammar``.
        grammar = copy.deepcopy(self.grammar)
        if nonterminal != "<start>":
            # Parse a sub-language by making ``nonterminal`` the only expansion
            # of <start> and pruning everything that is no longer reachable.
            grammar |= {"<start>": [nonterminal]}
            grammar = delete_unreachable(grammar)

        parser = EarleyParser(grammar)
        try:
            parse_tree = next(parser.parse(inp))
        except SyntaxError:
            self.logger.error(
                'Error parsing "%s" starting with "%s"', inp, nonterminal
            )
            raise

        if nonterminal != "<start>":
            # Strip the artificial <start> root introduced above.
            parse_tree = parse_tree[1][0]
        tree = DerivationTree.from_parse_tree(parse_tree)

        if not skip_check and nonterminal == "<start>" and not self.check(tree):
            raise SemanticError()

        return tree
774

775
    def repair(
        self, inp: DerivationTree | str, fix_timeout_seconds: float = 3
    ) -> Maybe[DerivationTree]:
        """
        Attempts to repair the given input. The input must not violate syntactic
        (grammar) constraints. If semantic constraints are violated, this method
        gradually abstracts the input and tries to turn it into a valid one.
        Note that intensive structural manipulations are not performed; we merely
        try to satisfy SMT-LIB and semantic predicate constraints.

        If the solver's formula has no top constant (i.e., nothing the input could
        be substituted for), the input is returned unchanged without any attempt
        at repair.

        :param inp: The input to fix, either readily parsed or as a string. Strings
          are parsed without checking the constraint.
        :param fix_timeout_seconds: A timeout used when calling the solver for an
          abstracted input. Usually, a low timeout suffices.
        :return: A fixed input (or the original, if it was not broken or if there
          is no top constant) or nothing.
        """

        inp = self.parse(inp, skip_check=True) if isinstance(inp, str) else inp

        # Without a top constant, there is nothing to substitute the input for.
        # Testing this *before* calling `check` guarantees that the `.get()` on
        # `self.top_constant` below never runs on an empty `Maybe` (which would
        # otherwise happen if `check` raised an `UnknownResultError`).
        if not self.top_constant.is_present():
            return Maybe(inp)

        try:
            if self.check(inp):
                return Maybe(inp)
        except UnknownResultError:
            # Unknown is not "valid"; we still try to repair the input.
            pass

        formula = self.top_constant.map(
            lambda c: self.formula.substitute_expressions({c: inp})
        ).get()

        set_smt_auto_eval(formula, False)
        set_smt_auto_subst(formula, False)

        qfr_free = eliminate_quantifiers(
            formula,
            grammar=self.grammar,
            numeric_constants={
                c
                for c in VariablesCollector.collect(formula)
                if isinstance(c, language.Constant) and c.is_numeric()
            },
        )

        # We first evaluate all structural predicates; for now, we do not interfere
        # with structure.
        semantic_only = qfr_free.transform(EvaluatePredicateFormulasTransformer(inp))

        if semantic_only == sc.false():
            # This cannot be repaired while preserving structure; for existential
            # problems, we could try tree insertion. We leave this for future work.
            return Maybe.nothing()

        # We try to satisfy any of the remaining disjunctive elements, in random order
        for formula_to_satisfy in shuffle(split_disjunction(semantic_only)):
            # Now, we consider all combinations of 1, 2, ... of the derivation trees
            # participating in the formula. We successively prune deeper and deeper
            # subtrees until the resulting input evaluates to "unknown" for the given
            # formula.

            participating_paths = {
                inp.find_node(arg) for arg in formula_to_satisfy.tree_arguments()
            }

            def do_complete(tree: DerivationTree) -> Maybe[DerivationTree]:
                # Run a fresh solver (sharing this solver's settings) on the
                # abstracted tree; solver failures map to "nothing".
                return (
                    Exceptional.of(
                        self.copy_without_queue(
                            initial_tree=Maybe(tree),
                            timeout_seconds=Maybe(fix_timeout_seconds),
                        ).solve
                    )
                    .map(Maybe)
                    .recover(
                        lambda _: Maybe.nothing(),
                        UnknownResultError,
                        TimeoutError,
                        StopIteration,
                    )
                    .reraise()
                    .get()
                )

            # If p1, p2 are in participating_paths, then we consider the following
            # path combinations (roughly) in the listed order:
            # {p1}, {p2}, {p1, p2}, {p1[:-1]}, {p2[:-1]}, {p1[:-1], p2}, {p1, p2[:-1]},
            # {p1[:-1], p2[:-1]}, ...
            for abstracted_tree in generate_abstracted_trees(inp, participating_paths):
                # Only abstractions whose check result is "unknown" are handed to
                # the solver for completion.
                maybe_completed: Maybe[DerivationTree] = (
                    Exceptional.of(lambda: self.check(abstracted_tree))
                    .map(lambda _: Maybe.nothing())
                    .recover(lambda _: Maybe(abstracted_tree), UnknownResultError)
                    .recover(lambda _: Maybe.nothing())
                    .get()
                    .bind(do_complete)
                )

                if maybe_completed.is_present():
                    return maybe_completed

        return Maybe.nothing()
873

874
    def mutate(
        self,
        inp: DerivationTree | str,
        min_mutations: int = 2,
        max_mutations: int = 5,
        fix_timeout_seconds: float = 1,
    ) -> DerivationTree:
        """
        Produces a mutant of `inp` that satisfies the constraint.

        The method repeatedly mutates `inp` and passes each structurally different
        mutant to `repair`. It returns the first mutant that `repair` can turn into
        a valid input. It loops until that happens.

        :param inp: The input to mutate, either parsed or as a string. Strings are
          parsed without checking the constraint.
        :param min_mutations: The minimum number of mutation steps per attempt.
        :param max_mutations: The maximum number of mutation steps per attempt.
        :param fix_timeout_seconds: The solver timeout `repair` uses for fixing an
          abstracted input. Usually, a low timeout suffices.
        :return: A mutated input satisfying the constraint.
        """
        original = inp if not isinstance(inp, str) else self.parse(inp, skip_check=True)

        mutator = Mutator(
            self.grammar,
            min_mutations=min_mutations,
            max_mutations=max_mutations,
            graph=self.graph,
        )

        while True:
            candidate = mutator.mutate(original)
            # Mutants identical to the input are useless; only try the others.
            if not candidate.structurally_equal(original):
                repaired = self.repair(candidate, fix_timeout_seconds)
                if repaired.is_present():
                    return repaired.get()
907

908
    def copy_without_queue(
        self,
        grammar: Maybe[Grammar | str] = Maybe.nothing(),
        formula: Maybe[language.Formula | str] = Maybe.nothing(),
        max_number_free_instantiations: Maybe[int] = Maybe.nothing(),
        max_number_smt_instantiations: Maybe[int] = Maybe.nothing(),
        max_number_tree_insertion_results: Maybe[int] = Maybe.nothing(),
        enforce_unique_trees_in_queue: Maybe[bool] = Maybe.nothing(),
        debug: Maybe[bool] = Maybe.nothing(),
        cost_computer: Maybe["CostComputer"] = Maybe.nothing(),
        timeout_seconds: Maybe[int] = Maybe.nothing(),
        global_fuzzer: Maybe[bool] = Maybe.nothing(),
        predicates_unique_in_int_arg: Maybe[
            Tuple[language.SemanticPredicate, ...]
        ] = Maybe.nothing(),
        fuzzer_factory: Maybe[Callable[[Grammar], GrammarFuzzer]] = Maybe.nothing(),
        tree_insertion_methods: Maybe[int] = Maybe.nothing(),
        activate_unsat_support: Maybe[bool] = Maybe.nothing(),
        grammar_unwinding_threshold: Maybe[int] = Maybe.nothing(),
        initial_tree: Maybe[DerivationTree] = Maybe.nothing(),
        enable_optimized_z3_queries: Maybe[bool] = Maybe.nothing(),
        start_symbol: Optional[str] = None,
    ):
        """
        Creates a new `ISLaSolver` with the settings of this one.

        Every present `Maybe` argument overrides the corresponding setting of this
        solver; absent ones inherit it. Exceptions to this rule:

        - `initial_tree` and `start_symbol` are passed through as given and are
          not inherited.
        - The new solver shares (does not copy) this solver's `regex_cache`.

        :return: The new solver.
        """

        def inherit(override: Maybe, attribute: str):
            # The explicitly passed value wins; otherwise, take our own setting.
            return override.orelse(lambda: getattr(self, attribute)).get()

        clone = ISLaSolver(
            grammar=inherit(grammar, "grammar"),
            formula=inherit(formula, "formula"),
            max_number_free_instantiations=inherit(
                max_number_free_instantiations, "max_number_free_instantiations"
            ),
            max_number_smt_instantiations=inherit(
                max_number_smt_instantiations, "max_number_smt_instantiations"
            ),
            max_number_tree_insertion_results=inherit(
                max_number_tree_insertion_results, "max_number_tree_insertion_results"
            ),
            enforce_unique_trees_in_queue=inherit(
                enforce_unique_trees_in_queue, "enforce_unique_trees_in_queue"
            ),
            debug=inherit(debug, "debug"),
            cost_computer=inherit(cost_computer, "cost_computer"),
            # The raw value `.a` is read instead of calling `.get()`, presumably
            # because the timeout may be unset (`None`).
            timeout_seconds=timeout_seconds.orelse(lambda: self.timeout_seconds).a,
            global_fuzzer=inherit(global_fuzzer, "global_fuzzer"),
            predicates_unique_in_int_arg=inherit(
                predicates_unique_in_int_arg, "predicates_unique_in_int_arg"
            ),
            fuzzer_factory=inherit(fuzzer_factory, "fuzzer_factory"),
            tree_insertion_methods=inherit(
                tree_insertion_methods, "tree_insertion_methods"
            ),
            activate_unsat_support=inherit(
                activate_unsat_support, "activate_unsat_support"
            ),
            grammar_unwinding_threshold=inherit(
                grammar_unwinding_threshold, "grammar_unwinding_threshold"
            ),
            initial_tree=initial_tree,
            enable_optimized_z3_queries=inherit(
                enable_optimized_z3_queries, "enable_optimized_z3_queries"
            ),
            start_symbol=start_symbol,
        )

        clone.regex_cache = self.regex_cache
        return clone
973

974
    @staticmethod
    def noop_on_false_constraint(
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Handles states whose constraint is `false`.

        :param state: The state to inspect.
        :return: A singleton list containing `state` if its constraint is
          `false`, nothing otherwise.
        """
        # A state with a false constraint can be silently discarded.
        is_false = state.constraint == sc.false()
        return Maybe([state]) if is_false else Maybe.nothing()
983

984
    def expand_to_match_quantifiers(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Expands the tree of `state` if its constraint has a universal quantifier
        (`ForallFormula`) among its top-level conjuncts.

        :param state: The state to expand.
        :return: The (non-empty) list of expanded states, or nothing if no
          top-level conjunct is a `ForallFormula`.
        """
        has_universal_conjunct = any(
            isinstance(conjunct, language.ForallFormula)
            for conjunct in get_conjuncts(state.constraint)
        )
        if not has_universal_conjunct:
            return Maybe.nothing()

        successors = self.expand_tree(state)

        # An empty expansion would leave the state in the queue forever.
        assert len(successors) > 0, f"State {state} will never leave the queue."
        self.logger.debug("Expanding state %s (%d successors)", state, len(successors))

        return Maybe(successors)
1002

1003
    def eliminate_and_match_first_existential_formula_and_expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Eliminates the first existential formula of `state` and, in addition,
        returns a limited number of plain expansions of `state`.

        :param state: The state to process.
        :return: The elimination results followed by up to two expansions
          (`limit=2`), or nothing if no existential formula could be eliminated.
        """
        eliminated = self.eliminate_and_match_first_existential_formula(state)
        if eliminated is not None:
            # Also add some expansions of the original state, to create a larger
            # solution stream (otherwise, it might be possible that only a small
            # finite number of solutions are generated for existential formulas).
            extra_expansions = self.expand_tree(state, limit=2, only_universal=False)
            return Maybe(eliminated + extra_expansions)

        return Maybe.nothing()
1017

1018
    def assert_remaining_formulas_are_lazy_binding_semantic(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Sanity check for states whose constraint could not be processed further.

        It asserts that the constraint is either `true` or a conjunction of
        (possibly negated) semantic predicate formulas. It further asserts that
        none of these predicate formulas binds an open leaf of the state's tree.
        The checks run only when assertions are enabled (`__debug__`).

        :param state: The state to check.
        :return: Always nothing; this handler never produces new states.
        """
        # SEMANTIC PREDICATE FORMULAS can remain if they bind lazily. In that case,
        # we can choose a random instantiation and let the predicate "fix" the
        # resulting tree.
        if __debug__ and state.constraint != sc.true():

            def underlying_predicate(
                conjunct: language.Formula,
            ) -> Optional[language.SemanticPredicateFormula]:
                # Returns the semantic predicate formula `conjunct` consists of
                # (directly or under one negation), or None.
                if isinstance(conjunct, language.SemanticPredicateFormula):
                    return conjunct
                if isinstance(conjunct, language.NegatedFormula) and isinstance(
                    conjunct.args[0], language.SemanticPredicateFormula
                ):
                    return cast(language.SemanticPredicateFormula, conjunct.args[0])
                return None

            predicate_formulas = [
                underlying_predicate(conjunct)
                for conjunct in get_conjuncts(state.constraint)
            ]

            assert all(
                predicate_formula is not None
                for predicate_formula in predicate_formulas
            ), (
                "Constraint is not true and contains formulas "
                f"other than semantic predicate formulas: {state.constraint}"
            )

            # Every open leaf must be free of bindings by *any* remaining
            # predicate formula, negated or not. A parenthesized generator
            # expression would be always truthy, so we check via `all` over
            # the leaves explicitly.
            open_leaves = [leaf for _, leaf in state.tree.open_leaves()]
            assert all(
                not predicate_formula.binds_tree(leaf)
                for predicate_formula in predicate_formulas
                for leaf in open_leaves
            ), (
                "Constraint is not true and contains semantic predicate formulas binding open tree leaves: "
                f"{state.constraint}, leaves: "
                + ", ".join(
                    [str(leaf) for leaf in open_leaves],
                )
            )

        return Maybe.nothing()
1062

1063
    def finish_unconstrained_trees(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Closes all open leaves of a state whose constraint is `true`.

        The open leaves are expanded with the (global or a freshly created)
        fuzzer. This happens `self.max_number_free_instantiations` times, yielding
        that many (not necessarily distinct) closed states.

        :param state: The state to finish.
        :return: The closed states, or nothing if the constraint is not `true`.
        """
        # Check first: for states that are not handled here, there is no need to
        # obtain (and, if `global_fuzzer` is unset, construct) a fuzzer.
        if state.constraint != sc.true():
            return Maybe.nothing()

        fuzzer = (
            self.fuzzer if self.global_fuzzer else self.fuzzer_factory(self.grammar)
        )

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            # Seed the coverage fuzzer with the expansions in `self.seen_coverages`.
            fuzzer.covered_expansions.update(self.seen_coverages)

        closed_results: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            result = state.tree
            for path, leaf in state.tree.open_leaves():
                leaf_inst = fuzzer.expand_tree(DerivationTree(leaf.value, None))
                result = result.replace_path(path, leaf_inst)

            closed_results.append(SolutionState(state.constraint, result))

        return Maybe(closed_results)
1087

1088
    def expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates all open leaves of `state` with fuzzer-generated trees.

        This is repeated `self.max_number_free_instantiations` times. Each time,
        the substitution is applied to both the constraint and the tree.

        :param state: The state to expand.
        :return: The expanded states. The list is empty if the tree has no open
          leaves.
        """
        if self.global_fuzzer:
            fuzzer = self.fuzzer
        else:
            fuzzer = self.fuzzer_factory(self.grammar)

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            # Seed the coverage fuzzer with the expansions in `self.seen_coverages`.
            fuzzer.covered_expansions.update(self.seen_coverages)

        expanded: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            leaf_substitutions: Dict[DerivationTree, DerivationTree] = {
                open_leaf: fuzzer.expand_tree(DerivationTree(open_leaf.value, None))
                for _path, open_leaf in state.tree.open_leaves()
            }

            if not leaf_substitutions:
                continue

            expanded.append(
                SolutionState(
                    state.constraint.substitute_expressions(leaf_substitutions),
                    state.tree.substitute(leaf_substitutions),
                )
            )

        return Maybe(expanded)
1115

1116
    def instantiate_structural_predicates(self, state: SolutionState) -> SolutionState:
        """
        Evaluates variable-free structural predicate formulas.

        Every structural predicate formula in the constraint whose arguments
        contain no `Variable` is evaluated on the state's tree. The formula is
        then replaced by the resulting truth value (as an SMT formula).

        :param state: The state to process.
        :return: A new state with the same tree and the simplified constraint.
        """

        def is_variable_free(pred_formula) -> bool:
            return isinstance(
                pred_formula, language.StructuralPredicateFormula
            ) and not any(isinstance(arg, language.Variable) for arg in pred_formula.args)

        structural_formulas = language.FilterVisitor(
            lambda f: isinstance(f, language.StructuralPredicateFormula)
        ).collect(state.constraint)
        ground_formulas = [f for f in structural_formulas if is_variable_free(f)]

        simplified = state.constraint
        for predicate_formula in ground_formulas:
            truth_value = language.SMTFormula(
                z3.BoolVal(predicate_formula.evaluate(state.tree))
            )
            self.logger.debug(
                "Eliminating (-> %s) structural predicate formula %s",
                truth_value,
                predicate_formula,
            )
            simplified = language.replace_formula(
                simplified, predicate_formula, truth_value
            )

        return SolutionState(simplified, state.tree)
1145

1146
    def eliminate_existential_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Removes top-level existential integer quantifiers from the constraint.

        The quantified formulas are processed in order:

        - If a quantified formula is already implied (by the tree and the other
          conjuncts of the *original* constraint), a state is returned
          immediately. In that state, this formula is replaced by `true` in the
          original constraint; instantiations made for earlier formulas in this
          call are not included.
        - Otherwise, the bound variable is replaced by a fresh constant.

        :param state: The state to process.
        :return: A singleton list with the new state, or nothing if there is no
          top-level existential integer formula.
        """
        exists_int_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ExistsIntFormula)
        ]
        if not exists_int_conjuncts:
            return Maybe.nothing()

        current = state.constraint
        for quantified in exists_int_conjuncts:
            # The following check for validity is not only a performance measure,
            # but required when existential integer formulas are re-inserted.
            # Otherwise, new constants get introduced, and the solver won't
            # terminate.
            other_conjuncts = {
                conjunct
                for conjunct in split_conjunction(state.constraint)
                if conjunct != quantified
            }
            already_implied = evaluate(
                quantified,
                state.tree,
                self.grammar,
                assumptions=other_conjuncts,
            ).is_true()

            if already_implied:
                self.logger.debug(
                    "Removing existential integer quantifier '%.30s', already implied "
                    "by tree and existing constraints",
                    quantified,
                )
                # This should simplify the process after quantifier re-insertion.
                simplified = language.replace_formula(
                    state.constraint, quantified, sc.true()
                )
                return Maybe([SolutionState(simplified, state.tree)])

            self.logger.debug(
                "Eliminating existential integer quantifier %s", quantified
            )
            witness = language.fresh_constant(
                set(VariablesCollector.collect(current)),
                language.Constant(
                    quantified.bound_variable.name,
                    quantified.bound_variable.n_type,
                ),
            )
            current = language.replace_formula(
                current,
                quantified,
                quantified.inner_formula.substitute_variables(
                    {quantified.bound_variable: witness}
                ),
            )

        return Maybe([SolutionState(current, state.tree)])
1209

1210
    def instantiate_universal_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates all top-level universal integer quantifiers of `state`.

        The quantifiers are processed one after the other. Each one is
        instantiated in every state produced for the previous quantifiers, so the
        result is the flattened product of all instantiations.

        :param state: The state to process.
        :return: The resulting states, or nothing if there is no top-level
          universal integer formula.
        """
        forall_int_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ForallIntFormula)
        ]
        if not forall_int_conjuncts:
            return Maybe.nothing()

        frontier: List[SolutionState] = [state]
        for quantified in forall_int_conjuncts:
            successors: List[SolutionState] = []
            for previous in frontier:
                successors.extend(
                    self.instantiate_universal_integer_quantifier(previous, quantified)
                )
            frontier = successors

        return Maybe(frontier)
1236

1237
    def instantiate_universal_integer_quantifier(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Instantiates a single universal integer quantifier.

        Enumeration is tried first. If it yields nothing (None or an empty list),
        the formula is transformed instead.

        :param state: The state containing `universal_int_formula`.
        :param universal_int_formula: The formula to instantiate.
        :return: The resulting states (possibly empty).
        """
        enumerated = self.instantiate_universal_integer_quantifier_by_enumeration(
            state, universal_int_formula
        )
        return enumerated or self.instantiate_universal_integer_quantifier_by_transformation(
            state, universal_int_formula
        )
1250

1251
    def instantiate_universal_integer_quantifier_by_transformation(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Rewrites a special form of universal integer formula into an existential
        one.

        This is the fallback for formulas that could not be instantiated by
        enumeration. It applies only if `universal_int_formula` has the shape
        `forall int i: exists <A> elem in v: not p(...)`, where the predicate of
        the semantic predicate formula `p(...)` is contained in
        `self.predicates_unique_in_int_arg`.

        :param state: The state containing `universal_int_formula`.
        :param universal_int_formula: The formula to transform.
        :return: A singleton list with the transformed state, or an empty list
          (after logging a warning) if the formula does not have that shape.
        """
        # If the enumeration approach was not successful, we can transform the
        # universal integer quantifier into an existential one in a particular
        # situation:
        #
        # Let phi(elem, i) be such that, for every fixed elem, phi(elem, .) is a
        # unary relation that holds for exactly one argument:
        #
        # forall <A> elem:
        #   exists int i:
        #     phi(elem, i) and
        #     forall int i':
        #       phi(elem, i') <==> i = i'
        #
        # Then, the following transformations are equivalence-preserving:
        #
        # forall int i:
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # <==> (*)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i) &
        #   forall int i':
        #     i != i' ->
        #     exists <A> elem'':
        #       not phi(elem'', i')
        #
        # <==> (+)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # (*)
        # Only the first inner conjunct is problematic. However, for every elem,
        # there has to be an i such that phi(elem, i) holds. If there is no such
        # elem in the first place, the original formula is unsatisfiable as well.
        # Without this conjunct, the transformation is a simple "quantifier
        # unwinding."
        #
        # (+)
        # Let i' != i. Choose elem'' := elem': Since phi(elem', i) holds and
        # i != i', "not phi(elem', i')" has to hold.

        exists_formula = universal_int_formula.inner_formula
        negated_formula = (
            exists_formula.inner_formula
            if isinstance(exists_formula, language.ExistsFormula)
            else None
        )
        predicate_formula = (
            negated_formula.args[0]
            if isinstance(negated_formula, language.NegatedFormula)
            else None
        )

        is_applicable = (
            isinstance(predicate_formula, language.SemanticPredicateFormula)
            and predicate_formula.predicate in self.predicates_unique_in_int_arg
        )
        if not is_applicable:
            self.logger.warning(
                "Did not find a way to instantiate formula %s!\n"
                + "Discarding this state. Please report this to your nearest ISLa developer.",
                universal_int_formula,
            )
            return []

        fresh_var = language.fresh_bound_variable(
            language.VariablesCollector().collect(state.constraint),
            exists_formula.bound_variable,
            add=False,
        )

        # exists <A> elem' in v: phi(elem', i)   (conjoined with the original body)
        witness_formula = language.ExistsFormula(
            fresh_var,
            exists_formula.in_variable,
            predicate_formula.substitute_variables(
                {exists_formula.bound_variable: fresh_var}
            ),
        )
        new_formula = language.ExistsIntFormula(
            universal_int_formula.bound_variable,
            witness_formula & exists_formula,
        )

        self.logger.debug(
            "Transforming universal integer quantifier "
            "(special case, see code comments for explanation):\n%s ==> %s",
            universal_int_formula,
            new_formula,
        )

        transformed_constraint = language.replace_formula(
            state.constraint, universal_int_formula, new_formula
        )
        return [SolutionState(transformed_constraint, state.tree)]
1365

1366
    def instantiate_universal_integer_quantifier_by_enumeration(
1✔
1367
        self, state: SolutionState, universal_int_formula: ForallIntFormula
1368
    ) -> Optional[List[SolutionState]]:
1369
        constant = language.Constant(
1✔
1370
            universal_int_formula.bound_variable.name,
1371
            universal_int_formula.bound_variable.n_type,
1372
        )
1373

1374
        # noinspection PyTypeChecker
1375
        inner_formula = universal_int_formula.inner_formula.substitute_variables(
1✔
1376
            {universal_int_formula.bound_variable: constant}
1377
        )
1378

1379
        instantiations: List[
1✔
1380
            Dict[
1381
                language.Constant | DerivationTree,
1382
                int | language.Constant | DerivationTree,
1383
            ]
1384
        ] = []
1385

1386
        if isinstance(universal_int_formula.inner_formula, language.DisjunctiveFormula):
1✔
1387
            # In the disjunctive case, we attempt to falsify all SMT formulas in the inner formula
1388
            # (on top level) that contain the bound variable as argument.
1389
            smt_disjuncts = [
1✔
1390
                formula
1391
                for formula in language.split_disjunction(inner_formula)
1392
                if isinstance(formula, language.SMTFormula)
1393
                and constant in formula.free_variables()
1394
            ]
1395

1396
            if smt_disjuncts and len(smt_disjuncts) < len(
1✔
1397
                language.split_disjunction(inner_formula)
1398
            ):
1399
                instantiation_values = (
1✔
1400
                    self.infer_satisfying_assignments_for_smt_formula(
1401
                        -reduce(language.SMTFormula.disjunction, smt_disjuncts),
1402
                        constant,
1403
                    )
1404
                )
1405

1406
                # We also try to falsify (negated) semantic predicate formulas, if present,
1407
                # if there exist any remaining disjuncts.
1408
                semantic_predicate_formulas: List[
1✔
1409
                    Tuple[language.SemanticPredicateFormula, bool]
1410
                ] = [
1411
                    (pred_formula, False)
1412
                    if isinstance(pred_formula, language.SemanticPredicateFormula)
1413
                    else (cast(language.NegatedFormula, pred_formula).args[0], True)
1414
                    for pred_formula in language.FilterVisitor(
1415
                        lambda f: (
1416
                            constant in f.free_variables()
1417
                            and (
1418
                                isinstance(f, language.SemanticPredicateFormula)
1419
                                or isinstance(f, language.NegatedFormula)
1420
                                and isinstance(
1421
                                    f.args[0], language.SemanticPredicateFormula
1422
                                )
1423
                            )
1424
                        ),
1425
                        do_continue=lambda f: isinstance(
1426
                            f, language.DisjunctiveFormula
1427
                        ),
1428
                    ).collect(inner_formula)
1429
                    if all(
1430
                        not isinstance(var, language.BoundVariable)
1431
                        for var in pred_formula.free_variables()
1432
                    )
1433
                ]
1434

1435
                if semantic_predicate_formulas and len(
1✔
1436
                    semantic_predicate_formulas
1437
                ) + len(smt_disjuncts) < len(language.split_disjunction(inner_formula)):
1438
                    for value in instantiation_values:
1✔
1439
                        instantiation: Dict[
1✔
1440
                            language.Constant | DerivationTree,
1441
                            int | language.Constant | DerivationTree,
1442
                        ] = {constant: value}
1443
                        for (
1✔
1444
                            semantic_predicate_formula,
1445
                            negated,
1446
                        ) in semantic_predicate_formulas:
1447
                            eval_result = cast(
1✔
1448
                                language.SemanticPredicateFormula,
1449
                                language.substitute(
1450
                                    semantic_predicate_formula, {constant: value}
1451
                                ),
1452
                            ).evaluate(self.graph, negate=not negated)
1453
                            if eval_result.ready() and not eval_result.is_boolean():
1✔
1454
                                instantiation.update(eval_result.result)
1✔
1455
                        instantiations.append(instantiation)
1✔
1456
                else:
1457
                    instantiations.extend(
×
1458
                        [{constant: value} for value in instantiation_values]
1459
                    )
1460

1461
        results: List[SolutionState] = []
1✔
1462
        for instantiation in instantiations:
1✔
1463
            self.logger.debug(
1✔
1464
                "Instantiating universal integer quantifier (%s -> %s) %s",
1465
                universal_int_formula.bound_variable,
1466
                instantiation[constant],
1467
                universal_int_formula,
1468
            )
1469

1470
            formula = language.replace_formula(
1✔
1471
                state.constraint,
1472
                universal_int_formula,
1473
                language.substitute(inner_formula, instantiation),
1474
            )
1475
            formula = language.substitute(formula, instantiation)
1✔
1476

1477
            tree = state.tree.substitute(
1✔
1478
                {
1479
                    tree: subst
1480
                    for tree, subst in instantiation.items()
1481
                    if isinstance(tree, DerivationTree)
1482
                }
1483
            )
1484

1485
            results.append(SolutionState(formula, tree))
1✔
1486

1487
        return results
1✔
1488

1489
    def infer_satisfying_assignments_for_smt_formula(
        self, smt_formula: language.SMTFormula, constant: language.Constant
    ) -> Set[int | language.Constant]:
        """
        Infers values for `constant` that satisfy the numeric formula `smt_formula`.

        If `constant` is the only free variable in the given
        :class:`~isla.language.SMTFormula`, this method returns up to
        `self.max_number_free_instantiations` many solutions. The given formula must
        be a numeric formula, i.e., all free variables must be numeric. If more than
        one free variable is present, at most one solution is returned (see example
        below). If the solver does not find any solution, the result is the empty
        set.

        :param smt_formula: The :class:`~isla.language.SMTFormula` to solve. Must
          only contain numeric free variables.
        :param constant: One free variable in `smt_formula`.
        :return: A set of solutions (see explanation above & comment below).
        :raises AssertionError: If a free variable of `smt_formula` is assigned a
          non-numeric value.

        We create a solver with a dummy grammar (it's not needed for this example),
        choosing a value of 5 for `max_number_free_instantiations`.

        >>> solver = ISLaSolver(
        ...     '<start> ::= "x"',  # dummy grammar
        ...     max_number_free_instantiations=5,
        ... )

        The formula we're considering is `x > 10`.

        >>> from isla.language import Constant, SMTFormula, Variable, unparse_isla
        >>> x = Constant("x", Variable.NUMERIC_NTYPE)
        >>> formula = SMTFormula(z3.StrToInt(x.to_smt()) > z3.IntVal(10), x)
        >>> unparse_isla(formula)
        '(< 10 (str.to.int x))'

        We obtain five results (due to our choice of `max_number_free_instantiations`).

        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        >>> len(results)
        5

        All results are `int`s...

        >>> all(isinstance(result, int) for result in results)
        True

        ...and all are strictly greater than 10.

        >>> all(result > 10 for result in results)
        True

        Now, let's consider `x == y`. This formula contains *two* free variables, `x`
        and `y`. It is the only type of formula with more than one variable for which
        this method will return a solution in the current state.

        >>> y = Constant("y", Variable.NUMERIC_NTYPE)
        >>> formula = SMTFormula(
        ...     z3_eq(z3.StrToInt(x.to_smt()), z3.StrToInt(y.to_smt())), x, y)
        >>> unparse_isla(formula)
        '(= (str.to.int x) (str.to.int y))'

        The solution is the singleton set with the variable `y`, which is an
        instantiation of the constant `x` solving the equation.

        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        {Constant("y", "NUM")}

        If we choose a different type of formula (a greater-than relation), we obtain
        an empty solution set.

        >>> formula = SMTFormula(
        ...     z3.StrToInt(x.to_smt()) > z3.StrToInt(y.to_smt()), x, y)
        >>> unparse_isla(formula)
        '(> (str.to.int x) (str.to.int y))'
        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        set()

        With a non-numeric formula, we obtain an AssertionError. This method expects
        to be called only internally, so this should not happen.

        >>> z = Constant("x", "<start>")
        >>> formula = SMTFormula(z3_eq(z.to_smt(), z3.StringVal("x")), z)
        >>> print(unparse_isla(formula))
        const x: <start>;
        <BLANKLINE>
        (= x "x")
        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, z)
        Traceback (most recent call last):
        ...
        AssertionError: Expected numeric solution.
        """

        free_variables = smt_formula.free_variables()
        # With several free variables, a single sample suffices for the equality
        # inference below.
        max_instantiations = (
            self.max_number_free_instantiations if len(free_variables) == 1 else 1
        )

        try:
            solver_result = self.solve_quantifier_free_formula(
                (smt_formula,), max_instantiations=max_instantiations
            )

            solutions: Dict[language.Constant, Set[int]] = {
                c: {
                    int(solution[cast(language.Constant, c)].value)
                    for solution in solver_result
                }
                for c in free_variables
            }
        except ValueError as err:
            # `int(...)` failed, i.e., some variable got a non-numeric value. This is
            # raised explicitly (instead of `assert False`) so that it also happens
            # with disabled assertions; otherwise, `solutions` would be unbound below.
            raise AssertionError("Expected numeric solution.") from err

        if not solver_result:
            # No solution at all. Without this guard, the multi-variable case below
            # would fail on the (then empty) solution sets.
            return set()

        if len(free_variables) == 1:
            return solutions[constant]

        assert all(len(solution) == 1 for solution in solutions.values())
        # In situations with multiple variables, we might have to abstract from
        # concrete values. Currently, we only support simple equality inference
        # (based on one sample...). Note that for supporting *more complex*
        # terms (e.g., additions), we would have to extend the whole
        # infrastructure: Substitutions with complex terms, and complex terms
        # in semantic predicate arguments, are unsupported as of now.
        candidates = {
            c
            for c in solutions
            if c != constant
            and next(iter(solutions[c])) == next(iter(solutions[constant]))
        }

        # Filter working candidates: Keep those for which the formula with
        # `constant` replaced by the candidate is still satisfiable.
        return {
            c
            for c in candidates
            if self.solve_quantifier_free_formula(
                (
                    cast(
                        language.SMTFormula,
                        smt_formula.substitute_variables({constant: c}),
                    ),
                ),
                max_instantiations=1,
            )
        }
1630

1631
    def eliminate_all_semantic_formulas(
        self, state: SolutionState, max_instantiations: Optional[int] = None
    ) -> Maybe[List[SolutionState]]:
        """
        Solves the SMT-LIB formulas occurring as top-level conjuncts of `state`'s
        constraint. SMT-LIB formulas nested elsewhere (e.g., inside a disjunction)
        are not considered, and trivially true SMT-LIB formulas are skipped.

        :param state: The state whose SMT-LIB conjuncts should be solved.
        :param max_instantiations: The number of solutions the SMT solver should be
          asked for.
        :return: Nothing if there is no non-trivial SMT-LIB conjunct; otherwise, the
          solutions computed by :meth:`eliminate_semantic_formula`.
        """

        all_conjuncts = split_conjunction(state.constraint)
        smt_conjuncts = [
            conjunct
            for conjunct in all_conjuncts
            if isinstance(conjunct, language.SMTFormula)
            and not z3.is_true(conjunct.formula)
        ]

        if not smt_conjuncts:
            return Maybe.nothing()

        self.logger.debug(
            "Eliminating semantic formulas [%s]", lazyjoin(", ", smt_conjuncts)
        )

        # The SMT-LIB part is moved to the front of the new constraint; the
        # remaining conjuncts follow in their original order.
        smt_part = reduce(operator.and_, smt_conjuncts, sc.true())
        other_part = reduce(
            operator.and_,
            [conjunct for conjunct in all_conjuncts if conjunct not in smt_conjuncts],
            sc.true(),
        )
        reordered_state = SolutionState(smt_part & other_part, state.tree)

        return Maybe(
            self.eliminate_semantic_formula(
                smt_part, reordered_state, max_instantiations
            )
        )
1672

1673
    def eliminate_all_ready_semantic_predicate_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Evaluates the (possibly negated) semantic predicate formulas in `state`'s
        constraint and eliminates those whose evaluation result is ready.

        Only predicate formulas without bound variables among their free variables
        are considered. A ready boolean result replaces the predicate formula by the
        corresponding SMT atom. Any other ready result is applied as a substitution
        to the constraint, to the tree, and to the predicate formulas still to be
        processed; the predicate formula itself is replaced by `true` (`false` under
        a negation).

        :param state: The state whose constraint is processed.
        :return: Nothing if the constraint contains no such predicate formula;
          `Maybe([new_state])` if at least one formula was eliminated; `Maybe(None)`
          if none of them was ready.
        """
        semantic_predicate_formulas: List[
            language.NegatedFormula | language.SemanticPredicateFormula
        ] = [
            cast(
                language.NegatedFormula | language.SemanticPredicateFormula,
                pred_formula,
            )
            for pred_formula in language.FilterVisitor(
                lambda f: (
                    isinstance(f, language.SemanticPredicateFormula)
                    or isinstance(f, language.NegatedFormula)
                    and isinstance(f.args[0], language.SemanticPredicateFormula)
                ),
                # Do not descend into negated predicate formulas: Their inner
                # predicate formula must not be collected a second time.
                do_continue=lambda f: (
                    not isinstance(f, language.NegatedFormula)
                    or not isinstance(f.args[0], language.SemanticPredicateFormula)
                ),
            ).collect(state.constraint)
            if all(
                not isinstance(var, language.BoundVariable)
                for var in pred_formula.free_variables()
            )
        ]

        # Process formulas by predicate order. Negated formulas get the key
        # `2 * order + 100`, which pushes them back relative to non-negated ones.
        semantic_predicate_formulas = sorted(
            semantic_predicate_formulas,
            key=lambda f: (
                2 * cast(language.SemanticPredicateFormula, f.args[0]).order + 100
                if isinstance(f, language.NegatedFormula)
                else f.order
            ),
        )

        if not semantic_predicate_formulas:
            return Maybe.nothing()

        result = state

        # Becomes True as soon as one formula could be evaluated.
        changed = False
        # NOTE: Entries after `idx` are rewritten in place inside this loop; the
        # iteration relies on `enumerate` reading the updated list entries.
        for idx, possibly_negated_semantic_predicate_formula in enumerate(
            semantic_predicate_formulas
        ):
            negated = isinstance(
                possibly_negated_semantic_predicate_formula, language.NegatedFormula
            )
            semantic_predicate_formula: language.SemanticPredicateFormula = (
                cast(
                    language.NegatedFormula, possibly_negated_semantic_predicate_formula
                ).args[0]
                if negated
                else possibly_negated_semantic_predicate_formula
            )

            evaluation_result = semantic_predicate_formula.evaluate(
                self.graph, negate=negated
            )
            if not evaluation_result.ready():
                # Presumably, the arguments are not instantiated enough yet; the
                # formula is kept in the constraint.
                continue

            self.logger.debug(
                "Eliminating semantic predicate formula %s", semantic_predicate_formula
            )
            changed = True

            if evaluation_result.is_boolean():
                # Replace the predicate formula (under a negation: the formula below
                # it) by its truth value; the tree stays unchanged.
                result = SolutionState(
                    language.replace_formula(
                        result.constraint,
                        semantic_predicate_formula,
                        language.smt_atom(evaluation_result.true()),
                    ),
                    result.tree,
                )
                continue

            # The result is a substitution. The predicate formula is replaced such
            # that the (possibly negated) formula becomes true, and the substitution
            # is applied to the constraint.
            new_constraint = language.replace_formula(
                result.constraint,
                semantic_predicate_formula,
                sc.false() if negated else sc.true(),
            ).substitute_expressions(evaluation_result.result)

            # Keep the not-yet-processed formulas in sync with the updated
            # constraint, so that `replace_formula` still finds them later.
            for k in range(idx + 1, len(semantic_predicate_formulas)):
                semantic_predicate_formulas[k] = cast(
                    language.SemanticPredicateFormula,
                    semantic_predicate_formulas[k].substitute_expressions(
                        evaluation_result.result
                    ),
                )

            result = SolutionState(
                new_constraint, result.tree.substitute(evaluation_result.result)
            )

        return Maybe([result] if changed else None)
1✔
1770

1771
    def eliminate_and_match_first_existential_formula(
        self, state: SolutionState
    ) -> Optional[List[SolutionState]]:
        """
        Processes the first existential formula among the top-level conjuncts of
        `state`'s constraint.

        The existential formula is both matched against the existing tree (via
        :meth:`match_existential_formula`) and, if tree insertion is activated,
        eliminated by tree insertion (via :meth:`eliminate_existential_formula`).

        :param state: The state whose constraint is processed.
        :return: `None` if there is no top-level existential conjunct. If tree
          insertion is deactivated, the list of matching results. Otherwise, the
          matching and insertion results, excluding `state` itself. Insertion
          results are also dropped if they have the same tree as a matching result
          and their constraint propositionally implies that result's constraint.
        """
        # We produce up to two groups of output states: One where the first existential
        # formula, if it can be matched, is matched, and one where the first existential
        # formula is eliminated by tree insertion.
        maybe_first_existential_formula_with_idx = Maybe.from_iterator(
            (idx, conjunct)
            for idx, conjunct in enumerate(split_conjunction(state.constraint))
            if isinstance(conjunct, language.ExistsFormula)
        )

        if not maybe_first_existential_formula_with_idx:
            return None

        existential_formula_idx, existential_formula = (
            maybe_first_existential_formula_with_idx.get()
        )

        first_matched = OrderedSet(
            self.match_existential_formula(existential_formula_idx, state)
        )

        # Tree insertion can be deactivated by setting `self.tree_insertion_methods`
        # to 0.
        if not self.tree_insertion_methods:
            return list(first_matched)

        if first_matched:
            self.logger.debug(
                "Matched first existential formulas, result: [%s]",
                lazyjoin(
                    ", ",
                    # `s=s` binds the current element: The lambdas are only called
                    # when the message is formatted, and a plain closure would then
                    # see the *last* element of `first_matched` for every entry.
                    [
                        lazystr(lambda s=s: f"{s} (hash={hash(s)})")
                        for s in first_matched
                    ],
                ),
            )

        # Eliminate the first existential formula by tree insertion.
        elimination_result = OrderedSet(
            self.eliminate_existential_formula(existential_formula_idx, state)
        )
        # Drop insertion results that are covered by a matching result: Same tree,
        # and the insertion result's constraint propositionally implies the
        # matching result's constraint.
        elimination_result = OrderedSet(
            [
                result
                for result in elimination_result
                if not any(
                    other_result.tree == result.tree
                    and self.propositionally_unsatisfiable(
                        result.constraint & -other_result.constraint
                    )
                    for other_result in first_matched
                )
            ]
        )

        if not elimination_result and not first_matched:
            self.logger.warning(
                "Existential qfr elimination: Could not eliminate existential formula %s "
                "by matching or tree insertion",
                existential_formula,
            )

        if elimination_result:
            self.logger.debug(
                "Eliminated existential formula %s by tree insertion, %d successors",
                existential_formula,
                len(elimination_result),
            )

        return [
            result for result in first_matched | elimination_result if result != state
        ]
1843

1844
    def match_all_universal_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates the top-level universal conjuncts of `state`'s constraint with
        their new matches by delegating to :meth:`match_universal_formulas`.

        :param state: The state whose universal conjuncts should be matched.
        :return: Nothing if there is no top-level universal conjunct;
          `Maybe(None)` if no new match was found; otherwise, a `Maybe` wrapping
          the resulting states.
        """
        forall_conjuncts = [
            conjunct
            for conjunct in split_conjunction(state.constraint)
            if isinstance(conjunct, language.ForallFormula)
        ]
        if not forall_conjuncts:
            return Maybe.nothing()

        matched_states = self.match_universal_formulas(state)
        if not matched_states:
            return Maybe(None)

        self.logger.debug(
            "Matched universal formulas [%s]", lazyjoin(", ", forall_conjuncts)
        )
        return Maybe(matched_states)
1✔
1865

1866
    def expand_tree(
        self,
        state: SolutionState,
        only_universal: bool = True,
        limit: Optional[int] = None,
    ) -> List[SolutionState]:
        """
        Expands those open leaves of `state`'s tree that might be matched by a
        quantified conjunct of `state`'s constraint. Other open leaves are left
        as they are.

        :param state: The current state.
        :param only_universal: If True, only leaves that might match universal
          quantifiers are expanded. Otherwise, leaves that might match any
          quantified conjunct (also existential ones) are expanded.
        :param limit: If unset (or 0), one state per combination of expansion
          alternatives is returned. Otherwise, `limit` states are returned, each
          choosing expansion alternatives at random (duplicates are possible).
        :return: A (possibly empty) list of states with expanded trees; their
          constraints refer to the new subtrees.
        """

        wanted_formula_type = (
            language.ForallFormula if only_universal else language.QuantifiedFormula
        )
        relevant_formulas = [
            formula
            for formula in get_conjuncts(state.constraint)
            if isinstance(formula, wanted_formula_type)
        ]

        # Maps each expandable leaf path to its expansion alternatives. New children
        # are open (None) for nonterminals and closed ([]) for terminals.
        expansions_by_path: Dict[Path, List[List[DerivationTree]]] = {}
        for leaf_path, leaf_node in state.tree.open_leaves():
            if not any(
                self.quantified_formula_might_match(formula, leaf_path, state.tree)
                for formula in relevant_formulas
            ):
                continue

            expansions_by_path[leaf_path] = [
                [
                    DerivationTree(child, None if is_nonterminal(child) else [])
                    for child in alternative
                ]
                for alternative in self.canonical_grammar[leaf_node.value]
            ]

        if not expansions_by_path:
            return []

        chosen_expansions: List[Dict[Path, List[DerivationTree]]]
        if limit:
            chosen_expansions = [
                {
                    path: random.choice(alternatives)
                    for path, alternatives in expansions_by_path.items()
                    if alternatives
                }
                for _ in range(limit)
            ]
        else:
            chosen_expansions = dict_of_lists_to_list_of_dicts(expansions_by_path)
            assert len(chosen_expansions) == math.prod(
                len(alternatives) for alternatives in expansions_by_path.values()
            )

        # Former `if` guard with the negated condition, which was dead code since
        # `expansions_by_path` is non-empty here; kept as a sanity check.
        assert len(chosen_expansions) > 1 or (
            len(chosen_expansions) == 1 and chosen_expansions[0]
        )

        result: List[SolutionState] = []
        for expansion in chosen_expansions:
            new_tree = state.tree
            for path, new_children in expansion.items():
                old_leaf = new_tree.get_subtree(path)
                new_tree = new_tree.replace_path(
                    path, DerivationTree(old_leaf.value, new_children, old_leaf.id)
                )

                assert new_tree is not state.tree
                assert new_tree != state.tree
                assert new_tree.structural_hash() != state.tree.structural_hash()

            # Every subtree on the way from the root to an expanded leaf changed;
            # the constraint has to refer to the new versions.
            replacements = {
                state.tree.get_subtree(path[:depth]): new_tree.get_subtree(path[:depth])
                for path in expansion
                for depth in range(len(path) + 1)
            }
            result.append(
                SolutionState(
                    state.constraint.substitute_expressions(replacements), new_tree
                )
            )

        assert not limit or len(result) <= limit
        return result
1✔
1959

1960
    def match_universal_formulas(self, state: SolutionState) -> List[SolutionState]:
        """
        Instantiates every top-level universal conjunct of `state`'s constraint for
        each of its matches (from `matches_for_quantified_formula`) that the
        formula has not already matched. The matched trees are recorded in the
        universal formula.

        :param state: The state whose universal conjuncts should be matched.
        :return: A singleton list with a state whose constraint conjoins all new
          instances with the (updated) original conjuncts; an empty list if there
          was no new match.
        """
        new_instances: List[language.Formula] = []
        conjuncts = split_conjunction(state.constraint)

        for position, conjunct in enumerate(conjuncts):
            if not isinstance(conjunct, language.ForallFormula):
                continue

            bound_variable = conjunct.bound_variable
            fresh_matches: List[Dict[language.Variable, Tuple[Path, DerivationTree]]] = [
                match
                for match in matches_for_quantified_formula(conjunct, self.grammar)
                if not conjunct.is_already_matched(match[bound_variable][1])
            ]

            marked_formula = conjunct.add_already_matched(
                {match[bound_variable][1] for match in fresh_matches}
            )

            new_instances.extend(
                marked_formula.inner_formula.substitute_expressions(
                    {variable: tree for variable, (_, tree) in match.items()}
                )
                for match in fresh_matches
            )

            if fresh_matches:
                conjuncts[position] = marked_formula

        if not new_instances:
            return []

        return [
            SolutionState(
                sc.conjunction(*new_instances) & sc.conjunction(*conjuncts),
                state.tree,
            )
        ]
1✔
2004

2005
    def match_existential_formula(
        self, existential_formula_idx: int, state: SolutionState
    ) -> List[SolutionState]:
        """
        Instantiates the existential conjunct at `existential_formula_idx` in
        `state`'s constraint with each of its matches (from
        `matches_for_quantified_formula`).

        :param existential_formula_idx: The position of the existential formula
          among the top-level conjuncts of `state`'s constraint.
        :param state: The state containing the existential formula.
        :return: One state per match. Its constraint is the instantiated inner
          formula conjoined with all other conjuncts, and its tree is `state`'s tree.
        """
        conjuncts: ImmutableList[language.Formula] = tuple(
            split_conjunction(state.constraint)
        )
        existential_formula = cast(
            language.ExistsFormula, conjuncts[existential_formula_idx]
        )

        def instantiate(
            match: Dict[language.Variable, Tuple[Path, DerivationTree]]
        ) -> SolutionState:
            instance = existential_formula.inner_formula.substitute_expressions(
                {variable: tree for variable, (_, tree) in match.items()}
            )
            remaining = sc.conjunction(*list_del(conjuncts, existential_formula_idx))
            return SolutionState(instance & remaining, state.tree)

        return [
            instantiate(match)
            for match in matches_for_quantified_formula(
                existential_formula, self.grammar
            )
        ]
1✔
2031

2032
    def eliminate_existential_formula(
1✔
2033
        self, existential_formula_idx: int, state: SolutionState
2034
    ) -> List[SolutionState]:
2035
        conjuncts: ImmutableList[language.Formula] = tuple(
1✔
2036
            split_conjunction(state.constraint)
2037
        )
2038
        existential_formula = cast(
1✔
2039
            language.ExistsFormula, conjuncts[existential_formula_idx]
2040
        )
2041

2042
        inserted_trees_and_bind_paths = (
1✔
2043
            [(DerivationTree(existential_formula.bound_variable.n_type, None), {})]
2044
            if existential_formula.bind_expression is None
2045
            else (
2046
                existential_formula.bind_expression.to_tree_prefix(
2047
                    existential_formula.bound_variable.n_type, self.grammar
2048
                )
2049
            )
2050
        )
2051

2052
        result: List[SolutionState] = []
1✔
2053

2054
        inserted_tree: DerivationTree
2055
        bind_expr_paths: Dict[language.BoundVariable, Path]
2056
        for inserted_tree, bind_expr_paths in inserted_trees_and_bind_paths:
1✔
2057
            self.logger.debug(
1✔
2058
                "insert_tree(self.canonical_grammar, %s, %s, self.graph, %s)",
2059
                lazystr(lambda: repr(inserted_tree)),
2060
                lazystr(lambda: repr(existential_formula.in_variable)),
2061
                self.max_number_tree_insertion_results,
2062
            )
2063

2064
            insertion_results = insert_tree(
1✔
2065
                self.canonical_grammar,
2066
                inserted_tree,
2067
                existential_formula.in_variable,
2068
                graph=self.graph,
2069
                max_num_solutions=self.max_number_tree_insertion_results * 2,
2070
                methods=self.tree_insertion_methods,
2071
            )
2072

2073
            insertion_results = sorted(
1✔
2074
                insertion_results,
2075
                key=lambda t: compute_tree_closing_cost(t, self.graph),
2076
            )
2077
            insertion_results = insertion_results[
1✔
2078
                : self.max_number_tree_insertion_results
2079
            ]
2080

2081
            for insertion_result in insertion_results:
1✔
2082
                replaced_path = state.tree.find_node(existential_formula.in_variable)
1✔
2083
                resulting_tree = state.tree.replace_path(
1✔
2084
                    replaced_path, insertion_result
2085
                )
2086

2087
                tree_substitution: Dict[DerivationTree, DerivationTree] = {}
1✔
2088
                for idx in range(len(replaced_path) + 1):
1✔
2089
                    original_path = replaced_path[: idx - 1]
1✔
2090
                    original_tree = state.tree.get_subtree(original_path)
1✔
2091
                    if (
1✔
2092
                        resulting_tree.is_valid_path(original_path)
2093
                        and original_tree.value
2094
                        == resulting_tree.get_subtree(original_path).value
2095
                        and resulting_tree.get_subtree(original_path) != original_tree
2096
                    ):
2097
                        tree_substitution[original_tree] = resulting_tree.get_subtree(
1✔
2098
                            original_path
2099
                        )
2100

2101
                assert insertion_result.find_node(inserted_tree) is not None
1✔
2102
                variable_substitutions = {
1✔
2103
                    existential_formula.bound_variable: inserted_tree
2104
                }
2105

2106
                if bind_expr_paths:
1✔
2107
                    if assertions_activated():
1✔
2108
                        dangling_bind_expr_vars = [
1✔
2109
                            (var, path)
2110
                            for var, path in bind_expr_paths.items()
2111
                            if (
2112
                                var
2113
                                in existential_formula.bind_expression.bound_variables()
2114
                                and insertion_result.find_node(
2115
                                    inserted_tree.get_subtree(path)
2116
                                )
2117
                                is None
2118
                            )
2119
                        ]
2120
                        assert not dangling_bind_expr_vars, (
1✔
2121
                            f"Bound variables from match expression not found in tree: "
2122
                            f"[{' ,'.join(map(repr, dangling_bind_expr_vars))}]"
2123
                        )
2124

2125
                    variable_substitutions.update(
1✔
2126
                        {
2127
                            var: inserted_tree.get_subtree(path)
2128
                            for var, path in bind_expr_paths.items()
2129
                            if var
2130
                            in existential_formula.bind_expression.bound_variables()
2131
                        }
2132
                    )
2133

2134
                instantiated_formula = (
1✔
2135
                    existential_formula.inner_formula.substitute_expressions(
2136
                        variable_substitutions
2137
                    ).substitute_expressions(tree_substitution)
2138
                )
2139

2140
                instantiated_original_constraint = sc.conjunction(
1✔
2141
                    *list_del(conjuncts, existential_formula_idx)
2142
                ).substitute_expressions(tree_substitution)
2143

2144
                new_tree = resulting_tree.substitute(tree_substitution)
1✔
2145

2146
                new_formula = (
1✔
2147
                    instantiated_formula
2148
                    & self.formula.substitute_expressions(
2149
                        {self.top_constant.get(): new_tree}
2150
                    )
2151
                    & instantiated_original_constraint
2152
                )
2153

2154
                new_state = SolutionState(new_formula, new_tree)
1✔
2155

2156
                assert all(
1✔
2157
                    new_state.tree.find_node(tree) is not None
2158
                    for quantified_formula in split_conjunction(new_state.constraint)
2159
                    if isinstance(quantified_formula, language.QuantifiedFormula)
2160
                    for _, tree in quantified_formula.in_variable.filter(lambda t: True)
2161
                )
2162

2163
                if assertions_activated() or self.debug:
1✔
2164
                    lost_tree_predicate_arguments: List[DerivationTree] = [
1✔
2165
                        arg
2166
                        for invstate in self.establish_invariant(new_state)
2167
                        for predicate_formula in get_conjuncts(invstate.constraint)
2168
                        if isinstance(
2169
                            predicate_formula, language.StructuralPredicateFormula
2170
                        )
2171
                        for arg in predicate_formula.args
2172
                        if isinstance(arg, DerivationTree)
2173
                        and invstate.tree.find_node(arg) is None
2174
                    ]
2175

2176
                    if lost_tree_predicate_arguments:
1✔
2177
                        previous_posititions = [
×
2178
                            state.tree.find_node(t)
2179
                            for t in lost_tree_predicate_arguments
2180
                        ]
2181
                        assert False, (
×
2182
                            f"Dangling subtrees [{', '.join(map(repr, lost_tree_predicate_arguments))}], "
2183
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2184
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2185
                        )
2186

2187
                    lost_semantic_formula_arguments = [
1✔
2188
                        arg
2189
                        for invstate in self.establish_invariant(new_state)
2190
                        for semantic_formula in get_conjuncts(new_state.constraint)
2191
                        if isinstance(semantic_formula, language.SMTFormula)
2192
                        for arg in semantic_formula.substitutions.values()
2193
                        if invstate.tree.find_node(arg) is None
2194
                    ]
2195

2196
                    if lost_semantic_formula_arguments:
1✔
2197
                        previous_posititions = [
×
2198
                            state.tree.find_node(t)
2199
                            for t in lost_semantic_formula_arguments
2200
                        ]
2201
                        previous_posititions = [
×
2202
                            p for p in previous_posititions if p is not None
2203
                        ]
2204
                        assert False, (
×
2205
                            f"Dangling subtrees [{', '.join(map(repr, lost_semantic_formula_arguments))}], "
2206
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2207
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2208
                        )
2209

2210
                result.append(new_state)
1✔
2211

2212
        return result
1✔
2213

2214
    def eliminate_semantic_formula(
        self,
        semantic_formula: language.Formula,
        state: SolutionState,
        max_instantiations: Optional[int] = None,
    ) -> Optional[List[SolutionState]]:
        """
        Solves a semantic formula and, for each solution, substitutes the solution for
        the respective constant in each assignment of the state. Also instantiates all
        "free" constants in the given tree. The SMT solver is passed a regular
        expression approximating the language of the nonterminal of each considered
        constant. Returns an empty list for unsolvable constraints.

        :param semantic_formula: The semantic (i.e., only containing logical connectors
            and SMT Formulas) formula to solve.
        :param state: The original solution state.
        :param max_instantiations: The maximum number of solutions to ask the SMT solver
            for; the limit applies to each cluster of SMT formulas separately.
        :return: A list of instantiated SolutionStates (never :code:`None`). The list
            is empty if any cluster of SMT formulas has no solution.
        """

        assert all(
            isinstance(conjunct, language.SMTFormula)
            for conjunct in get_conjuncts(semantic_formula)
        )

        # NOTE: We need to cluster SMT formulas by tree substitutions. If there are two
        # formulas with a variable $var which is instantiated to different trees, we
        # need two separate solutions. If, however, $var is instantiated with the
        # *same* tree, we need one solution to both formulas together.

        smt_formulas = self.rename_instantiated_variables_in_smt_formulas(
            [
                smt_formula
                for smt_formula in get_conjuncts(semantic_formula)
                if isinstance(smt_formula, language.SMTFormula)
            ]
        )

        # Now, we also cluster formulas by common variables (and instantiated subtrees:
        # One formula might yield an instantiation of a subtree of the instantiation of
        # another formula. They need to appear in the same cluster). The solver can
        # better handle smaller constraints, and those which do not have variables in
        # common can be handled independently.

        def cluster_keys(smt_formula: language.SMTFormula):
            # Free and instantiated variables, plus every subtree of every tree the
            # formula's variables are substituted with.
            return (
                smt_formula.free_variables()
                | smt_formula.instantiated_variables
                | {
                    subtree
                    for tree in smt_formula.substitutions.values()
                    for _, subtree in tree.paths()
                }
            )

        formula_clusters: List[List[language.SMTFormula]] = cluster_by_common_elements(
            smt_formulas, cluster_keys
        )

        assert all(
            not cluster_keys(smt_formula)
            or any(smt_formula in cluster for cluster in formula_clusters)
            for smt_formula in smt_formulas
        )

        # Formulas without any cluster key are collected into one additional cluster
        # (presumably `cluster_by_common_elements` does not place them anywhere).
        formula_clusters = [cluster for cluster in formula_clusters if cluster]
        remaining_clusters = [
            smt_formula for smt_formula in smt_formulas if not cluster_keys(smt_formula)
        ]
        if remaining_clusters:
            formula_clusters.append(remaining_clusters)

        all_solutions: List[
            List[Dict[Union[language.Constant, DerivationTree], DerivationTree]]
        ] = [
            self.solve_quantifier_free_formula(tuple(cluster), max_instantiations)
            for cluster in formula_clusters
        ]

        # These solutions are all independent, such that we can combine each solution
        # with all others. The `{}` initializer keeps the reduction well-defined when
        # there are no clusters: `itertools.product()` then yields one empty tuple,
        # and `reduce` without an initializer would raise a `TypeError`. It also
        # ensures we never hand out the (lru-cached) solution dicts themselves.
        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = [
            functools.reduce(operator.or_, dicts, {})
            for dicts in itertools.product(*all_solutions)
        ]

        solutions_with_subtrees: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = []
        for solution in solutions:
            # We also have to instantiate all subtrees of the substituted element.
            solution_with_subtrees: Dict[
                Union[language.Constant, DerivationTree], DerivationTree
            ] = {}
            for orig, subst in solution.items():
                if isinstance(orig, language.Constant):
                    solution_with_subtrees[orig] = subst
                    continue

                assert isinstance(
                    orig, DerivationTree
                ), f"Expected a DerivationTree, given: {type(orig).__name__}"

                # The candidate list is computed once, before any insertion below.
                for path, tree in [
                    (p, t) for p, t in orig.paths() if t not in solution_with_subtrees
                ]:
                    assert subst.is_valid_path(path), (
                        f"SMT Solution {subst} does not have "
                        f"orig path {path} from tree {orig} (state {hash(state)})"
                    )
                    solution_with_subtrees[tree] = subst.get_subtree(path)

            solutions_with_subtrees.append(solution_with_subtrees)

        results = []
        for solution in solutions_with_subtrees:
            if solution:
                new_state = SolutionState(
                    state.constraint.substitute_expressions(solution),
                    state.tree.substitute(solution),
                )
            else:
                # An empty assignment (e.g., for a trivially valid formula): there is
                # nothing to substitute, so only the semantic formula is dropped.
                new_state = SolutionState(
                    language.replace_formula(
                        state.constraint, semantic_formula, sc.true()
                    ),
                    state.tree,
                )

            results.append(new_state)

        return results
1✔
2350

2351
    @lru_cache(100)
    def solve_quantifier_free_formula(
        self,
        smt_formulas: ImmutableList[language.SMTFormula],
        max_instantiations: Optional[int] = None,
    ) -> List[Dict[language.Constant | DerivationTree, DerivationTree]]:
        """
        Attempts to solve the given SMT-LIB formulas by calling Z3.

        Note that this function does not unify variables pointing to the same derivation
        trees. E.g., a solution may be returned for the formula `var_1 = "a" and
        var_2 = "b"`, though `var_1` and `var_2` point to the same `<var>` tree as
        defined by their substitutions maps. Unification is performed in
        `eliminate_all_semantic_formulas`.

        Results are memoized via :func:`functools.lru_cache`, so the same list object
        may be returned to several callers; callers must not mutate it.

        :param smt_formulas: The SMT-LIB formulas to solve (a hashable sequence, as
            required by the cache).
        :param max_instantiations: The maximum number of instantiations to produce.
            A falsy value (:code:`None` or 0) falls back to
            :code:`self.max_number_smt_instantiations`.
        :return: A (possibly empty) list of solutions. Each maps a constant, or the
            tree the constant is substituted with, to its solution.
        """

        # If any SMT formula refers to *sub*trees in the instantiations of other SMT
        # formulas, we have to instantiate those first.
        priority_formulas = smt_formulas_referring_to_subtrees(smt_formulas)

        if priority_formulas:
            smt_formulas = priority_formulas
            assert not smt_formulas_referring_to_subtrees(smt_formulas)

        tree_substitutions = reduce(
            lambda d1, d2: d1 | d2,
            [smt_formula.substitutions for smt_formula in smt_formulas],
            {},
        )

        constants = reduce(
            lambda d1, d2: d1 | d2,
            [
                smt_formula.free_variables() | smt_formula.instantiated_variables
                for smt_formula in smt_formulas
            ],
            set(),
        )

        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = []
        # String values of the solutions found so far; they are excluded from the
        # subsequent solver queries so that each iteration yields a new solution.
        internal_solutions: List[Dict[language.Constant, z3.StringVal]] = []

        num_instantiations = max_instantiations or self.max_number_smt_instantiations
        for _ in range(num_instantiations):
            (
                solver_result,
                maybe_model,
            ) = self.solve_smt_formulas_with_language_constraints(
                constants,
                tuple(smt_formula.formula for smt_formula in smt_formulas),
                tree_substitutions,
                internal_solutions,
            )

            if solver_result != z3.sat:
                # No (further) solution: return what we have, possibly nothing.
                return solutions

            assert maybe_model is not None

            new_solution = {
                tree_substitutions.get(constant, constant): maybe_model[constant]
                for constant in constants
            }

            new_internal_solution = {
                constant: z3.StringVal(str(maybe_model[constant]))
                for constant in constants
            }

            if new_solution in solutions:
                # This can happen for trivial solutions, i.e., if the formula is
                # logically valid. Then, the assignment for that constant will
                # always be {}
                return solutions

            solutions.append(new_solution)
            if not new_internal_solution:
                # Again, for a trivial solution (e.g., True), the assignment
                # can be empty; there is nothing to exclude in further queries.
                break
            internal_solutions.append(new_internal_solution)

        return solutions
1✔
2444

2445
    def solve_smt_formulas_with_language_constraints(
        self,
        variables: Set[language.Variable],
        smt_formulas: ImmutableList[z3.BoolRef],
        tree_substitutions: Dict[language.Variable, DerivationTree],
        solutions_to_exclude: List[Dict[language.Variable, z3.StringVal]],
    ) -> Tuple[z3.CheckSatResult, Dict[language.Variable, DerivationTree]]:
        """
        Solves :code:`smt_formulas` with Z3.

        "Flexible" variables get the constraints computed by
        :code:`self.generate_language_constraints`. With optimized Z3 queries
        enabled, variables occurring only in :code:`str.len` / :code:`str.to.int`
        contexts are instead replaced by fresh integer variables. Every previous
        solution in :code:`solutions_to_exclude` is ruled out by a negated equation.

        :param variables: The variables to solve for.
        :param smt_formulas: The Z3 formulas to solve.
        :param tree_substitutions: The trees the variables are substituted with.
        :param solutions_to_exclude: Previously found solutions that must not be
            returned again.
        :return: The Z3 check result and, if it is :code:`z3.sat`, a map from each
            variable in :code:`variables` to a derivation tree extracted from the
            model. The map is empty otherwise.
        """

        # We disable optimized Z3 queries if the SMT formulas contain "too concrete"
        # substitutions, that is, substitutions with a tree that is not merely an
        # open leaf. Example: we have a constraint `str.len(<chars>) < 10` and a
        # tree `<char><char>`; only the concrete length "10" is possible then. In fact,
        # we could simply finish off the tree and check the constraint, or restrict the
        # custom tree generation to admissible lengths, but we stay general here. The
        # SMT solution is more robust.

        if self.enable_optimized_z3_queries and not any(
            substitution.children for substitution in tree_substitutions.values()
        ):
            vars_in_context = self.infer_variable_contexts(variables, smt_formulas)
            length_vars = vars_in_context["length"]
            int_vars = vars_in_context["int"]
            flexible_vars = vars_in_context["flexible"]
        else:
            length_vars = set()
            int_vars = set()
            flexible_vars = set(variables)

        # Add language constraints for "flexible" variables
        formulas: List[z3.BoolRef] = self.generate_language_constraints(
            flexible_vars, tree_substitutions
        )

        # Variables that are represented by fresh integer symbols in the query.
        special_vars = length_vars | int_vars

        # Create fresh variables for `str.len` and `str.to.int` variables.
        all_variables = set(variables)
        fresh_var_map: Dict[language.Variable, z3.ExprRef] = {}
        for var in special_vars:
            fresh = fresh_constant(
                all_variables,
                language.Constant(var.name, "NOT-NEEDED"),
            )
            fresh_var_map[var] = z3.Int(fresh.name)

        # In `smt_formulas`, we replace all `str.len(...)` and `str.to.int(...)` terms
        # applied to such variables with the corresponding fresh variable. The set of
        # their Z3 representations is computed once, not per visited sub-expression.
        special_var_smt_exprs = {var.to_smt() for var in special_vars}
        replacement_map: Dict[z3.ExprRef, z3.ExprRef] = {
            expr: fresh_var_map[
                get_elem_by_equivalence(
                    expr.children()[0],
                    special_vars,
                    lambda e1, e2: e1 == e2.to_smt(),
                )
            ]
            for formula in smt_formulas
            for expr in visit_z3_expr(formula)
            if expr.decl().kind() in {z3.Z3_OP_SEQ_LENGTH, z3.Z3_OP_STR_TO_INT}
            and expr.children()[0] in special_var_smt_exprs
        }

        # Perform substitution, add formulas
        formulas.extend(
            [
                cast(z3.BoolRef, z3_subst(formula, replacement_map))
                for formula in smt_formulas
            ]
        )

        # Fresh variables must be non-negative. This includes the fresh variables
        # standing for `str.to.int` terms, not only those for lengths.
        formulas.extend(
            [
                cast(z3.BoolRef, length_var >= z3.IntVal(0))
                for length_var in replacement_map.values()
            ]
        )

        for prev_solution in solutions_to_exclude:
            prev_solution_formula = z3_and(
                [
                    self.previous_solution_formula(
                        var, string_val, fresh_var_map, length_vars, int_vars
                    )
                    for var, string_val in prev_solution.items()
                ]
            )

            formulas.append(z3.Not(prev_solution_formula))

        sat_result, maybe_model = z3_solve(formulas)

        if sat_result != z3.sat:
            return sat_result, {}

        assert maybe_model is not None

        return sat_result, {
            var: self.extract_model_value(
                var, maybe_model, fresh_var_map, length_vars, int_vars
            )
            for var in variables
        }
2544

2545
    @staticmethod
    def previous_solution_formula(
        var: language.Variable,
        string_val: z3.StringVal,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> z3.BoolRef:
        """
        Builds the equation that characterizes an earlier solution
        :code:`var == string_val` of an :class:`~isla.language.SMTFormula`.

        "Length" and "int" variables occurred only as arguments of :code:`str.len`
        or :code:`str.to.int` in the solved formula. They are represented by their
        fresh integer symbol in :code:`fresh_var_map`, and the equation fixes that
        symbol to the length or the integer value of :code:`string_val`. If a
        variable is in both sets, the "int" interpretation wins. All other variables
        are compared to :code:`string_val` directly.

        >>> x = language.Variable("x", "<X>")
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {}, set(), set())
        x == "val"

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {x: z3.Int("x_0")}, {x}, set())
        x_0 == 3

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        A "numeric" variable (of "NUM" type) must always be an "int" variable with an
        entry in :code:`fresh_var_map`; otherwise, an assertion fails:

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {}, set(), set())
        Traceback (most recent call last):
        ...
        AssertionError

        :param var: The variable the solution is for.
        :param string_val: The solution for :code:`var`.
        :param fresh_var_map: Maps "length" and "int" variables to their fresh
                              integer symbols.
        :param length_vars: The "length" variables.
        :param int_vars: The "int" variables.
        :return: An equation describing the previous solution.
        """
        if var not in int_vars and var not in length_vars:
            # Neither "int" nor "length": compare the string value directly. Numeric
            # variables must never end up here.
            assert not var.is_numeric()
            return z3_eq(var.to_smt(), string_val)

        fresh_symbol = fresh_var_map[var]
        solution_text = smt_string_val_to_string(string_val)
        fixed_value = int(solution_text) if var in int_vars else len(solution_text)
        return z3_eq(fresh_symbol, z3.IntVal(fixed_value))
2609

2610
    def safe_create_fixed_length_tree(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
    ) -> DerivationTree:
        """
        Builds a :class:`~isla.derivation_tree.DerivationTree` of the type of
        :code:`var` whose string representation has the length that :code:`model`
        assigns to the fresh integer symbol :code:`fresh_var_map[var]`. Example:

        >>> grammar = {
        ...     "<start>": ["<X>"],
        ...     "<X>": ["x", "x<X>"],
        ... }
        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> solver = ISLaSolver(grammar)
        >>> tree = solver.safe_create_fixed_length_tree(x, model, {x: x_0})
        >>> tree.value
        '<X>'
        >>> str(tree)
        'xxxxx'

        :param var: The variable for which to create a
                    :class:`~isla.derivation_tree.DerivationTree` object.
        :param model: The Z3 model containing a solution to the length constraint.
        :param fresh_var_map: A map including a mapping :code:`var` -> :code:`var_0`,
                              where :code:`var_0` is an integer-valued variable
                              included in :code:`model`.
        :return: A tree of the type of :code:`var` with the length specified in
                 :code:`model`.
        :raises RuntimeError: If no such tree could be created.
        """

        assert var in fresh_var_map
        length_symbol = fresh_var_map[var]
        assert length_symbol.decl() in model.decls()

        target_length = model[length_symbol].as_long()
        fixed_length_tree = create_fixed_length_tree(
            start=var.n_type,
            canonical_grammar=self.canonical_grammar,
            target_length=target_length,
        )

        if fixed_length_tree is not None:
            return fixed_length_tree

        raise RuntimeError(
            f"Could not create a tree with the start symbol '{var.n_type}' "
            f"of length {target_length}; try "
            "to run the solver without optimized Z3 queries or make "
            "sure that lengths are restricted to syntactically valid "
            "ones (according to the grammar)."
        )
1✔
2670

2671
    def extract_model_value(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Converts the value :code:`model` assigns to :code:`var` into a
        :class:`~isla.derivation_tree.DerivationTree`. How the value is obtained
        depends on the kind of variable:

        Numeric Variables
            The result is a single closed node holding the string representation of
            the number found in the model.

        "Length" Variables
            The model only fixes a length (for the fresh symbol in
            :code:`fresh_var_map`). The result is a tree of the variable's type with
            a string of that length, see also
            :meth:`~isla.solver.ISLaSolver.safe_create_fixed_length_tree()`.

        "Int" Variables
            The integer assigned to the fresh symbol in :code:`fresh_var_map` is
            parsed into a tree of the variable's type.

        "Flexible" Variables
            The string assigned to the variable itself is parsed into a tree of the
            variable's type.

        >>> grammar = {
        ...     "<start>": ["<A>"],
        ...     "<A>": ["<X><Y>"],
        ...     "<X>": ["x", "x<X>"],
        ...     "<Y>": ["<digit>", "<digit><Y>"],
        ...     "<digit>": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
        ... }
        >>> solver = ISLaSolver(grammar)

        **Numeric Variables:**

        >>> n = language.Variable("n", language.Variable.NUMERIC_NTYPE)
        >>> f = z3_eq(z3.StrToInt(n.to_smt()), z3.IntVal(15))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(n, model, {}, set(), set())
        DerivationTree('15', (), id=1)

        **"Length" Variables:**

        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(3))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {x: x_0}, {x}, set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxx'

        **"Int" Variables:**

        >>> y = language.Variable("y", "<Y>")
        >>> y_0 = z3.Int("y_0")
        >>> f = z3_eq(y_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(y, model, {y: y_0}, set(), {y})
        DerivationTree('<Y>', (DerivationTree('<digit>', (DerivationTree('5', (), id=1),), id=2),), id=3)

        **"Flexible" Variables:**

        >>> f = z3_eq(x.to_smt(), z3.StringVal("xxxxx"))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {}, set(), set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxxxx'

        :param var: The variable whose value to extract from the model.
        :param model: The model containing the solution.
        :param fresh_var_map: Maps "length" and "int" variables to their fresh
                              integer symbols.
        :param length_vars: The set of "length" variables.
        :param int_vars: The set of "int" variables.
        :return: A :class:`~isla.derivation_tree.DerivationTree` object corresponding
                 to the solution in :code:`model`.
        """

        if not var.is_numeric():
            if var in length_vars:
                return self.safe_create_fixed_length_tree(var, model, fresh_var_map)

            if var in int_vars:
                int_solution = model[fresh_var_map[var]]
                return self.parse(int_solution.as_string(), var.n_type)

            # A "flexible" variable: the model holds its string value directly.
            string_solution = smt_string_val_to_string(model[z3.String(var.name)])
            return self.parse(string_solution, var.n_type)

        # Numeric variable: the value is stored either under the variable's own
        # string symbol or, for "int" variables, under its fresh integer symbol.
        string_symbol = z3.String(var.name)
        if string_symbol.decl() in model.decls():
            model_value = model[string_symbol]
        else:
            assert var in int_vars
            assert var in fresh_var_map
            model_value = model[fresh_var_map[var]]

        string_value = smt_string_val_to_string(model_value)
        assert string_value
        assert string_value.isnumeric() or (
            string_value.startswith("-") and string_value[1:].isnumeric()
        )

        return DerivationTree(string_value, ())
2805

2806
    @staticmethod
1✔
2807
    def infer_variable_contexts(
1✔
2808
        variables: Set[language.Variable], smt_formulas: ImmutableList[z3.BoolRef]
2809
    ) -> Dict[str, Set[language.Variable]]:
2810
        """
2811
        Divides the given variables into
2812

2813
        1. those that occur only in :code:`length(...)` contexts,
2814
        2. those that occur only in :code:`str.to.int(...)` contexts, and
2815
        3. "flexible" constants occurring in other/various contexts.
2816

2817
        >>> x = language.Variable("x", "<X>")
2818
        >>> y = language.Variable("y", "<Y>")
2819

2820
        Two variables in an arbitrary context.
2821

2822
        >>> f = z3_eq(x.to_smt(), y.to_smt())
2823
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
2824
        >>> contexts["length"]
2825
        set()
2826
        >>> contexts["int"]
2827
        set()
2828
        >>> contexts["flexible"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
2829
        True
2830

2831
        Variable x occurs in a length context, variable y in an arbitrary one.
2832

2833
        >>> f = z3.And(
2834
        ...     z3.Length(x.to_smt()) > z3.IntVal(10),
2835
        ...     z3_eq(y.to_smt(), z3.StringVal("y")))
2836
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
2837
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}
2838

2839
        Variable x occurs in a length context, y does not occur.
2840

2841
        >>> f = z3.Length(x.to_smt()) > z3.IntVal(10)
2842
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
2843
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}
2844

2845
        Variables x and y both occur in a length context.
2846

2847
        >>> f = z3.Length(x.to_smt()) > z3.Length(y.to_smt())
2848
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
2849
        >>> contexts["length"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
2850
        True
2851
        >>> contexts["int"]
2852
        set()
2853
        >>> contexts["flexible"]
2854
        set()
2855

2856
        Variable x occurs in a :code:`str.to.int` context.
2857

2858
        >>> f = z3.StrToInt(x.to_smt()) > z3.IntVal(17)
2859
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
2860
        {'length': set(), 'int': {Variable("x", "<X>")}, 'flexible': set()}
2861

2862
        Now, x also occurs in a different context; it's "flexible" now.
2863

2864
        >>> f = z3.And(
2865
        ...     z3.StrToInt(x.to_smt()) > z3.IntVal(17),
2866
        ...     z3_eq(x.to_smt(), z3.StringVal("17")))
2867
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
2868
        {'length': set(), 'int': set(), 'flexible': {Variable("x", "<X>")}}
2869

2870
        :param variables: The constants to divide/filter from.
2871
        :param smt_formulas: The SMT formulas to consider in the filtering.
2872
        :return: A pair of constants occurring in `str.len` contexts, and the
2873
        remaining ones. The union of both sets equals `variables`, and both sets
2874
        are disjoint.
2875
        """
2876

2877
        parent_relationships = reduce(
1✔
2878
            merge_dict_of_sets,
2879
            [parent_relationships_in_z3_expr(formula) for formula in smt_formulas],
2880
            {},
2881
        )
2882

2883
        contexts: Dict[language.Variable, Set[int]] = {
1✔
2884
            var: {
2885
                expr.decl().kind()
2886
                for expr in parent_relationships.get(var.to_smt(), set())
2887
            }
2888
            or {-1}
2889
            for var in variables
2890
        }
2891

2892
        # The set `length_vars` consists of all variables that only occur in
2893
        # `str.len(...)` context.
2894
        length_vars: Set[language.Variable] = {
1✔
2895
            var
2896
            for var in variables
2897
            if all(context == z3.Z3_OP_SEQ_LENGTH for context in contexts[var])
2898
        }
2899

2900
        # The set `int_vars` consists of all variables that only occur in
2901
        # `str.to.int(...)` context.
2902
        int_vars: Set[language.Variable] = {
1✔
2903
            var
2904
            for var in variables
2905
            if all(context == z3.Z3_OP_STR_TO_INT for context in contexts[var])
2906
        }
2907

2908
        # "Flexible" variables are the remaining ones.
2909
        flexible_vars = variables.difference(length_vars).difference(int_vars)
1✔
2910

2911
        return {"length": length_vars, "int": int_vars, "flexible": flexible_vars}
1✔
2912

2913
    def generate_language_constraints(
        self,
        constants: Iterable[language.Variable],
        tree_substitutions: Dict[language.Variable, DerivationTree],
    ) -> List[z3.BoolRef]:
        """
        Creates, for each constant, a z3 constraint :code:`InRe(c, R)` that
        restricts the z3 string variable named like the constant to a regular
        language R:

        * numeric constants (:code:`constant.is_numeric()`) get the language
          ``0|[1-9][0-9]*`` (decimal numbers without leading zeros);
        * constants with an entry in `tree_substitutions` get the concatenation
          of the regexes for the nonterminals and terminal strings occurring in
          the string representation of the substituted tree;
        * all other constants get the regex of their nonterminal type.

        :param constants: The constants to constrain.
        :param tree_substitutions: Known (partial) instantiations of constants.
        :return: One constraint per constant, in the order of `constants`.
        """

        def regex_for(constant):
            if constant.is_numeric():
                # Decimal representation without leading zeros.
                return z3.Union(
                    z3.Re("0"),
                    z3.Concat(z3.Range("1", "9"), z3.Star(z3.Range("0", "9"))),
                )

            if constant not in tree_substitutions:
                return self.extract_regular_expression(constant.n_type)

            # A more concrete shape of the desired instantiation is available:
            # assemble the regex piecewise from nonterminals and terminals.
            pieces = split_str_with_nonterminals(str(tree_substitutions[constant]))
            piece_regexes = [
                self.extract_regular_expression(piece)
                if is_nonterminal(piece)
                else z3.Re(piece)
                for piece in pieces
            ]
            assert piece_regexes
            return (
                z3.Concat(*piece_regexes)
                if len(piece_regexes) > 1
                else piece_regexes[0]
            )

        return [
            z3.InRe(z3.String(constant.name), regex_for(constant))
            for constant in constants
        ]
2943

2944
    def rename_instantiated_variables_in_smt_formulas(self, smt_formulas):
        """
        Returns new SMT formulas in which every variable that has an entry in a
        formula's `substitutions` is replaced by a fresh
        :class:`language.BoundVariable` named ``<tree value>_<tree id>`` (using
        the value and id of the substituted tree). The replacement is applied
        to the z3 formula, to the keys of `substitutions`, and to
        `instantiated_variables`; free variables are passed on unchanged.

        :param smt_formulas: The SMT formulas to rename variables in
            (presumably :class:`language.SMTFormula` objects — they need the
            attributes `formula`, `substitutions`, `instantiated_variables`,
            and `free_variables_`).
        :return: A list of :class:`language.SMTFormula` objects, in input order.
        """

        def with_renamed_variables(smt_formula):
            z3_formula = smt_formula.formula
            substitutions = smt_formula.substitutions
            instantiated = smt_formula.instantiated_variables

            for old_var, tree in smt_formula.substitutions.items():
                fresh_var = language.BoundVariable(
                    f"{tree.value}_{tree.id}", old_var.n_type
                )

                z3_formula = cast(
                    z3.BoolRef,
                    z3_subst(z3_formula, {old_var.to_smt(): fresh_var.to_smt()}),
                )
                substitutions = {
                    (fresh_var if key == old_var else key): value
                    for key, value in substitutions.items()
                }
                instantiated = {
                    fresh_var if var == old_var else var for var in instantiated
                }

            return language.SMTFormula(
                z3_formula,
                *smt_formula.free_variables_,
                instantiated_variables=instantiated,
                substitutions=substitutions,
            )

        return [with_renamed_variables(smt_formula) for smt_formula in smt_formulas]
2981

2982
    def process_new_states(
        self, new_states: List[SolutionState]
    ) -> List[DerivationTree]:
        """
        Runs :meth:`process_new_state` on each given state, in order, and
        concatenates the solution trees it returns.

        :param new_states: The states to process.
        :return: All solution trees found, in the order they were produced.
        """
        solutions = []
        for state in new_states:
            solutions.extend(self.process_new_state(state))
        return solutions
2990

2991
    def process_new_state(self, new_state: SolutionState) -> List[DerivationTree]:
        """
        Normalizes a newly derived state, drops successors found to be
        unsatisfiable, and returns the trees of successors that are solutions.

        Steps:

        1. :meth:`instantiate_structural_predicates` is applied to the state.
        2. :meth:`establish_invariant` splits it into one state per disjunct.
        3. Universal quantifiers that do not / can no longer match are removed
           from each state.
        4. If unsat support is active and no unsat check is already running,
           states with unsatisfiable SMT formulas or with an unsatisfiable
           top-level existential formula are dropped.
        5. Each remaining state is passed to :meth:`state_is_valid_or_enqueue`;
           the trees of those for which it returns True are signalled to the
           cost computer and returned.

        :param new_state: The state to process.
        :return: The trees of all resulting states that are solutions.
        """
        new_state = self.instantiate_structural_predicates(new_state)
        new_states = self.establish_invariant(new_state)
        new_states = [
            self.remove_nonmatching_universal_quantifiers(state)
            for state in new_states
        ]
        new_states = [
            self.remove_infeasible_universal_quantifiers(state)
            for state in new_states
        ]

        if self.activate_unsat_support and not self.currently_unsat_checking:
            # The flag prevents nested unsat checks while the nested
            # `self.solve()` calls below are running.
            self.currently_unsat_checking = True
            try:
                # Filtering (instead of removing from the list in place) makes
                # sure each state is dropped at most once, even if it fails
                # several of the checks.
                new_states = [
                    state
                    for state in new_states
                    if not self._unsat_check_drops_state(state)
                ]
            finally:
                # Reset even if a nested `solve()` raises, so that later states
                # are still checked.
                self.currently_unsat_checking = False

        assert all(
            state.tree.find_node(tree) is not None
            for state in new_states
            for quantified_formula in split_conjunction(state.constraint)
            if isinstance(quantified_formula, language.QuantifiedFormula)
            for _, tree in quantified_formula.in_variable.filter(lambda t: True)
        )

        solution_trees = [
            state.tree for state in new_states if self.state_is_valid_or_enqueue(state)
        ]

        for tree in solution_trees:
            self.cost_computer.signal_tree_output(tree)

        return solution_trees

    def _unsat_check_drops_state(self, state: SolutionState) -> bool:
        """
        Decides whether `state` has to be dropped because its SMT formulas, or
        one of its top-level existential formulas, are unsatisfiable. Logs the
        reason if the state is dropped.

        :param state: The state to check.
        :return: True iff the state should be dropped.
        """
        if state.constraint == sc.true():
            return False

        # Drop states with unsatisfiable SMT-LIB formulas. An empty list of
        # resulting states is mapped to an absent value, i.e., "unsatisfiable".
        if (
            any(
                isinstance(f, language.SMTFormula)
                for f in split_conjunction(state.constraint)
            )
            and not self.eliminate_all_semantic_formulas(state, max_instantiations=1)
            .bind(lambda a: Maybe(a if a else None))
            .is_present()
        ):
            self.logger.debug("Dropping state %s, unsatisfiable SMT formulas", state)
            return True

        # Drop states with unsatisfiable existential formulas.
        existential_formulas = [
            f
            for f in split_conjunction(state.constraint)
            if isinstance(f, language.ExistsFormula)
        ]
        for existential_formula in existential_formulas:
            if self._existential_formula_has_no_solution(
                existential_formula, state.tree
            ):
                self.logger.debug(
                    "Dropping state %s, unsatisfiable existential formula %s",
                    state,
                    existential_formula,
                )
                return True

        return False

    def _existential_formula_has_no_solution(
        self, existential_formula: language.ExistsFormula, tree: DerivationTree
    ) -> bool:
        """
        Runs a nested solver run on a queue that contains only the state
        (`existential_formula`, `tree`), with a timeout of 2. The solver's
        start time, timeout, queue, and solutions are saved before the run and
        restored afterward.

        :param existential_formula: The existential formula to check.
        :param tree: The tree the formula refers to.
        :return: True iff the nested :code:`solve()` raised
            :class:`StopIteration`; other exceptions propagate.
        """
        old_start_time = self.start_time
        old_timeout_seconds = self.timeout_seconds
        old_queue = list(self.queue)
        old_solutions = list(self.solutions)

        self.queue = []
        self.solutions = []
        check_state = SolutionState(existential_formula, tree)
        heapq.heappush(self.queue, (0, check_state))
        self.start_time = int(time.time())
        self.timeout_seconds = 2

        try:
            self.solve()
        except StopIteration:
            return True
        finally:
            self.start_time = old_start_time
            self.timeout_seconds = old_timeout_seconds
            self.queue = old_queue
            self.solutions = old_solutions

        return False
3082

3083
    def state_is_valid_or_enqueue(self, state: SolutionState) -> bool:
        """
        Classifies a processed state and enqueues it if appropriate.

        * A complete state: the expansions in its tree are recorded in
          :code:`self.seen_coverages`, and True is returned (the caller may
          yield the tree).
        * An incomplete state is discarded (False is returned) if its tree is
          already queued (only checked if
          :code:`self.enforce_unique_trees_in_queue` is set), if an equal state
          is already queued, or if its constraint is propositionally
          unsatisfiable.
        * Otherwise, the state is re-created at level
          :code:`self.current_level + 1`, pushed onto the queue with its
          computed cost, and False is returned.

        :param state: The state to classify.
        :return: True iff the state is a complete solution.
        """
        if state.complete():
            for _, node in state.tree.paths():
                if node.children:
                    self.seen_coverages.add(expansion_key(node.value, node.children))

            assert state.formula_satisfied(self.grammar).is_true()
            return True

        self.assert_no_dangling_predicate_argument_trees(state)
        self.assert_no_dangling_smt_formula_argument_trees(state)

        tree_already_queued = (
            self.enforce_unique_trees_in_queue
            and state.tree.structural_hash() in self.tree_hashes_in_queue
        )
        if tree_already_queued:
            # The same structure can be produced both by tree insertion
            # (existential quantifier elimination) and by expansion, and tree
            # insertion can yield different trees with intersecting expansions.
            # Dropping such duplicates yields more diverse solutions (numbers of
            # SMT solutions and free nonterminals are configurable and give
            # more outputs).
            self.logger.debug("Discarding state %s, tree already in queue", state)
            return False

        if hash(state) in self.state_hashes_in_queue:
            self.logger.debug("Discarding state %s, already in queue", state)
            return False

        if self.propositionally_unsatisfiable(state.constraint):
            self.logger.debug("Discarding state %s", state)
            return False

        queued_state = SolutionState(
            state.constraint, state.tree, level=self.current_level + 1
        )

        self.recompute_costs()

        cost = self.compute_cost(queued_state)
        heapq.heappush(self.queue, (cost, queued_state))
        self.tree_hashes_in_queue.add(queued_state.tree.structural_hash())
        self.state_hashes_in_queue.add(hash(queued_state))

        if self.debug:
            self.state_tree[self.current_state].append(queued_state)
            self.costs[queued_state] = cost

        self.logger.debug(
            "Pushing new state (%s, %s) (hash %d, cost %f)",
            queued_state.constraint,
            queued_state.tree.to_string(show_open_leaves=True, show_ids=True),
            hash(queued_state),
            cost,
        )

        queue_length = len(self.queue)
        self.logger.debug("Queue length: %d", queue_length)
        if queue_length % 100 == 0:
            self.logger.info("Queue length: %d", queue_length)

        return False
3156

3157
    def recompute_costs(self):
        """
        Rebuilds the queue with freshly computed costs for all queued states.

        This only happens every 400 solver steps (:code:`self.step_cnt`), and
        at most once per step count. Costs computed when a state was enqueued
        may be outdated, e.g., if the cost computer's internal information has
        changed since.
        """
        if self.step_cnt % 400 != 0 or self.step_cnt <= self.last_cost_recomputation:
            return

        self.last_cost_recomputation = self.step_cnt
        # Lazy %-style arguments instead of an eagerly formatted f-string.
        self.logger.info(
            "Recomputing costs in queue after %s solver steps", self.step_cnt
        )

        old_queue = list(self.queue)
        self.queue = []
        for _, state in old_queue:
            cost = self.compute_cost(state)
            heapq.heappush(self.queue, (cost, state))
3170

3171
    def assert_no_dangling_smt_formula_argument_trees(
        self, state: SolutionState
    ) -> None:
        """
        Sanity check: every derivation tree in the `substitutions` of an SMT
        formula inside :code:`state.constraint` must occur in
        :code:`state.tree`.

        The check only runs if assertions are activated or the solver is in
        debug mode.

        :param state: The state to check.
        :raises AssertionError: If some substituted tree does not occur in
            :code:`state.tree`.
        """
        if not assertions_activated() and not self.debug:
            return

        dangling_smt_formula_argument_trees = [
            (smt_formula, arg)
            for smt_formula in language.FilterVisitor(
                lambda f: isinstance(f, language.SMTFormula)
            ).collect(state.constraint)
            for arg in cast(language.SMTFormula, smt_formula).substitutions.values()
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if dangling_smt_formula_argument_trees:
            # Raise explicitly: an `assert False` is stripped under `python -O`,
            # which would silently disable a check requested via `self.debug`.
            raise AssertionError(
                "Dangling SMT formula arguments: ["
                + ", ".join(
                    str(f) + ", " + repr(a)
                    for f, a in dangling_smt_formula_argument_trees
                )
                + "]"
            )
3196

3197
    def assert_no_dangling_predicate_argument_trees(self, state: SolutionState) -> None:
        """
        Sanity check: every derivation tree passed as an argument to a
        structural predicate inside :code:`state.constraint` must occur in
        :code:`state.tree`.

        The check only runs if assertions are activated or the solver is in
        debug mode.

        :param state: The state to check.
        :raises AssertionError: If some argument tree does not occur in
            :code:`state.tree`.
        """
        if not assertions_activated() and not self.debug:
            return

        dangling_predicate_argument_trees = [
            (predicate_formula, arg)
            for predicate_formula in language.FilterVisitor(
                lambda f: isinstance(f, language.StructuralPredicateFormula)
            ).collect(state.constraint)
            for arg in cast(language.StructuralPredicateFormula, predicate_formula).args
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if dangling_predicate_argument_trees:
            # Raise explicitly: an `assert False` is stripped under `python -O`,
            # which would silently disable a check requested via `self.debug`.
            raise AssertionError(
                "Dangling predicate arguments: ["
                + ", ".join(
                    str(f) + ", " + repr(a)
                    for f, a in dangling_predicate_argument_trees
                )
                + "]"
            )
3217

3218
    def propositionally_unsatisfiable(self, formula: language.Formula) -> bool:
        """
        Cheap unsatisfiability check that only recognizes the literal
        :code:`false` formula.

        NOTE: A full propositional check (translating the formula with
        :code:`language.isla_to_smt_formula(formula,
        replace_untranslatable_with_predicate=True)` and asking a z3 solver for
        :code:`unsat`) was deactivated for performance reasons.

        :param formula: The formula to check.
        :return: True iff `formula` equals :code:`sc.false()`.
        """
        is_literal_false = formula == sc.false()
        return is_literal_false
3226

3227
    def establish_invariant(self, state: SolutionState) -> List[SolutionState]:
        """
        Splits a state into one state per top-level disjunct of its constraint.

        The constraint is first converted to negation normal form and then
        (non-deeply) to disjunctive normal form. All resulting states share
        the tree of `state`.

        :param state: The state to split.
        :return: One state per disjunct.
        """
        nnf_constraint = convert_to_nnf(state.constraint)
        dnf_constraint = convert_to_dnf(nnf_constraint, deep=False)
        return list(
            map(
                lambda disjunct: SolutionState(disjunct, state.tree),
                split_disjunction(dnf_constraint),
            )
        )
3233

3234
    def compute_cost(self, state: SolutionState) -> float:
        """
        Returns the cost of `state`. States whose constraint is :code:`true`
        cost 0; all others are rated by :code:`self.cost_computer`.

        :param state: The state to rate.
        :return: The cost; lower is preferred.
        """
        return (
            0
            if state.constraint == sc.true()
            else self.cost_computer.compute_cost(state)
        )
3239

3240
    def remove_nonmatching_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Removes top-level universal formulas whose quantified tree
        (:code:`in_variable`) is complete but yields no match for the
        formula in the grammar.

        :param state: The state to simplify.
        :return: `state` itself if nothing was removed, otherwise a new state
            with the same tree and the conjunction of the remaining conjuncts.
        """
        conjuncts = list(get_conjuncts(state.constraint))

        # Conjuncts are inspected back to front.
        indices_to_drop = {
            idx
            for idx, conjunct in reversed(list(enumerate(conjuncts)))
            if isinstance(conjunct, language.ForallFormula)
            and conjunct.in_variable.is_complete()
            and not matches_for_quantified_formula(conjunct, self.grammar)
        }

        if not indices_to_drop:
            return state

        remaining = (
            conjunct
            for idx, conjunct in enumerate(conjuncts)
            if idx not in indices_to_drop
        )
        return SolutionState(sc.conjunction(*remaining), state.tree)
3261

3262
    def remove_infeasible_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Removes top-level universal formulas that cannot produce new
        instantiations anymore.

        A universal formula is removed if every current match in its quantified
        tree is already matched (:code:`is_already_matched`) and no open leaf
        of that tree might still give rise to a match.

        :param state: The state to simplify.
        :return: `state` itself if nothing was removed, otherwise a new state
            with the same tree and the remaining conjuncts combined with `&`.
        """
        # Work on a private copy: conjuncts are deleted from this list below,
        # and the sequence returned by `get_conjuncts` must not be mutated.
        conjuncts = list(get_conjuncts(state.constraint))
        one_removed = False

        for idx, universal_formula in reversed(list(enumerate(conjuncts))):
            if not isinstance(universal_formula, language.ForallFormula):
                continue

            matches = matches_for_quantified_formula(universal_formula, self.grammar)
            all_matches_matched = all(
                universal_formula.is_already_matched(
                    match[universal_formula.bound_variable][1]
                )
                for match in matches
            )
            if not all_matches_matched:
                continue

            some_leaf_might_match = any(
                self.quantified_formula_might_match(
                    universal_formula, leaf_path, universal_formula.in_variable
                )
                for leaf_path, _ in universal_formula.in_variable.open_leaves()
            )
            if some_leaf_might_match:
                continue

            one_removed = True
            del conjuncts[idx]

        if not one_removed:
            return state

        return SolutionState(
            reduce(lambda a, b: a & b, conjuncts, sc.true()),
            state.tree,
        )
3301

3302
    def quantified_formula_might_match(
        self,
        qfd_formula: language.QuantifiedFormula,
        path_to_nonterminal: Path,
        tree: DerivationTree,
    ) -> bool:
        """
        Convenience wrapper around the free function
        :code:`quantified_formula_might_match` (not this method). It supplies
        this solver's grammar and the :code:`reachable` relation of its
        grammar graph.

        :param qfd_formula: The quantified formula.
        :param path_to_nonterminal: The path to the nonterminal in `tree` to check.
        :param tree: The tree containing that nonterminal.
        :return: The result of the delegated call.
        """
        grammar = self.grammar
        reachable = self.graph.reachable
        return quantified_formula_might_match(
            qfd_formula, path_to_nonterminal, tree, grammar, reachable
        )
3315

3316
    def extract_regular_expression(self, nonterminal: str) -> z3.ReRef:
        """
        Returns a z3 regular expression for the language of `nonterminal`.

        The regex is computed with :class:`RegexConverter` (with union
        compression and at most :code:`self.grammar_unwinding_threshold`
        expansions) and cached in :code:`self.regex_cache`.

        If assertions are activated, up to 100 distinct words of the regex
        language are sampled with z3. Each must parse with an Earley parser for
        the grammar rooted at `nonterminal`; i.e., L(regex) ⊆ L(grammar) is
        spot-checked.

        :param nonterminal: The nonterminal whose language to convert.
        :return: The z3 regular expression.
        :raises AssertionError: If a sampled word is not in the grammar's
            language (only when assertions are activated).
        """
        if nonterminal in self.regex_cache:
            return self.regex_cache[nonterminal]

        regex_conv = RegexConverter(
            self.grammar,
            compress_unions=True,
            max_num_expansions=self.grammar_unwinding_threshold,
        )
        regex = regex_conv.to_regex(nonterminal, convert_to_z3=False)
        # Lazy %-style arguments: the regex is only rendered if debug logging
        # is enabled.
        self.logger.debug(
            "Computed regular expression for nonterminal %s:\n%s", nonterminal, regex
        )
        z3_regex = regex_to_z3(regex)

        if assertions_activated():
            # Check correctness of the regular expression.
            grammar = self.graph.subgraph(nonterminal).to_grammar()

            # NOTE: L(grammar) ⊆ L(regex) is deliberately *not* checked: if
            # the grammar has to be unwound, that check fails.

            # Check L(regex) ⊆ L(grammar) on sampled words.
            self.logger.debug(
                "Checking L(regex) \\subseteq L(grammar) for nonterminal '%s' and regex '%s'",
                nonterminal,
                regex,
            )
            parser = EarleyParser(grammar)
            c = z3.String("c")
            prev: Set[str] = set()
            for _ in range(100):
                s = z3.Solver()
                s.add(z3.InRe(c, z3_regex))
                for inp in prev:
                    s.add(z3.Not(c == z3.StringVal(inp)))
                if s.check() != z3.sat:
                    self.logger.debug(
                        "Cannot find the %d-th solution for regex %s (timeout).\nThis is *not* a problem "
                        "if there not that many solutions (for regexes with finite language), or if we "
                        "are facing a meaningless timeout of the solver.",
                        len(prev) + 1,
                        regex,
                    )
                    break
                new_inp = smt_string_val_to_string(s.model()[c])
                try:
                    next(parser.parse(new_inp))
                except SyntaxError as err:
                    raise AssertionError(
                        f"Input '{new_inp}' from regex language is not in grammar language."
                    ) from err
                prev.add(new_inp)

        self.regex_cache[nonterminal] = z3_regex

        return z3_regex
3383

3384

3385
class CostComputer(ABC):
    """
    Interface for cost functions that guide the solver's search; states with
    lower cost are preferred. Subclasses are expected to override both
    methods: the base implementations raise :class:`NotImplementedError`.
    """

    def compute_cost(self, state: SolutionState) -> float:
        """
        Computes a cost value for `state`. States with lower cost are
        preferred in the analysis.

        :param state: The state to compute a cost for.
        :return: The cost value.
        """
        raise NotImplementedError

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """
        Called when `tree` is output as a solution. Implementations may use
        this to update internal information for cost computation.

        :param tree: The tree that is output as a solution.
        :return: Nothing.
        """
        raise NotImplementedError
3405

3406

3407
class GrammarBasedBlackboxCostComputer(CostComputer):
    """
    A :class:`CostComputer` combining five cost components:

    * the cost of closing the tree's open leaves,
    * the number and nesting depth of quantifiers in the constraint,
    * the state's level,
    * the (missing) k-path coverage of the state's tree, and
    * how few globally uncovered k-paths the tree (potentially) contributes.

    The components are combined by a weighted geometric mean, using the
    weights in :code:`cost_settings.weight_vector`.
    """

    def __init__(
        self,
        cost_settings: CostSettings,
        graph: gg.GrammarGraph,
        reset_coverage_after_n_round_with_no_coverage: int = 100,
        symbol_costs: Optional[Dict[str, int]] = None,
    ):
        """
        :param cost_settings: Provides the weight vector and the `k` for
            k-path metrics.
        :param graph: The grammar graph used for k-path computations and
            symbol costs.
        :param reset_coverage_after_n_round_with_no_coverage: After this many
            solution outputs that did not add new k-paths, the set of
            globally covered k-paths is reset. (The counter is not reset when
            new coverage is found, so those outputs need not be consecutive.)
        :param symbol_costs: Costs per nonterminal used for the tree closing
            cost. If None, they are computed lazily from `graph` on first use.
        """
        self.cost_settings = cost_settings
        self.graph = graph

        # k-paths covered by the trees output so far (reset periodically).
        self.covered_k_paths: Set[Tuple[gg.Node, ...]] = set()
        self.rounds_with_no_new_coverage = 0
        self.reset_coverage_after_n_round_with_no_coverage = (
            reset_coverage_after_n_round_with_no_coverage
        )
        self.symbol_costs: Optional[Dict[str, int]] = symbol_costs

        self.logger = logging.getLogger(type(self).__name__)

    def __repr__(self):
        return (
            "GrammarBasedBlackboxCostComputer("
            + f"{repr(self.cost_settings)}, "
            + "graph, "
            + f"{self.reset_coverage_after_n_round_with_no_coverage}, "
            + f"{self.symbol_costs})"
        )

    def compute_cost(self, state: SolutionState) -> float:
        """
        Computes the weighted geometric mean of the tree closing cost, the
        constraint (quantifier) cost, the state level, the k-coverage cost,
        and the global k-path coverage cost.

        :param state: The state to rate.
        :return: The combined cost; lower is preferred.
        """
        # How costly is it to finish the tree?
        tree_closing_cost = self.compute_tree_closing_cost(state.tree)

        # Quantifiers are expensive: universal formulas have to be matched, and
        # tree insertion for existential formulas is even more costly. In each
        # quantifier chain, the idx-th quantifier costs idx + 1, or
        # 2 * idx + 1 if it is existential.
        # TODO: Penalize nested quantifiers more.
        constraint_cost = sum(
            idx * (2 if isinstance(f, language.ExistsFormula) else 1) + 1
            for c in get_quantifier_chains(state.constraint)
            for idx, f in enumerate(c)
        )

        # k-Path coverage: Fewer covered -> higher penalty
        k_cov_cost = self._compute_k_coverage_cost(state)

        # Covered k-paths: Fewer contributed -> higher penalty
        global_k_path_cost = self._compute_global_k_coverage_cost(state)

        costs = [
            tree_closing_cost,
            constraint_cost,
            state.level,
            k_cov_cost,
            global_k_path_cost,
        ]
        assert tree_closing_cost >= 0, f"tree_closing_cost == {tree_closing_cost}!"
        assert constraint_cost >= 0, f"constraint_cost == {constraint_cost}!"
        assert state.level >= 0, f"state.level == {state.level}!"
        assert k_cov_cost >= 0, f"k_cov_cost == {k_cov_cost}!"
        assert global_k_path_cost >= 0, f"global_k_path_cost == {global_k_path_cost}!"

        # Compute geometric mean
        result = weighted_geometric_mean(costs, list(self.cost_settings.weight_vector))

        self.logger.debug(
            "Computed cost for state %s:\n%f, individual costs: %s, weights: %s",
            lazystr(lambda: f"({(str(state.constraint)[:50] + '...')}, {state.tree})"),
            result,
            costs,
            self.cost_settings.weight_vector,
        )

        return result

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """Adds the k-paths of the output `tree` to the globally covered ones."""
        self._update_covered_k_paths(tree)

    def _symbol_costs(self):
        """Returns the symbol costs, computing them from the graph on first use."""
        if self.symbol_costs is None:
            self.symbol_costs = compute_symbol_costs(self.graph)
        return self.symbol_costs

    def _update_covered_k_paths(self, tree: DerivationTree):
        """
        Adds the k-paths of `tree` to :code:`self.covered_k_paths` (only if
        the global k-path coverage penalty is positive).

        The covered set is reset once all k-paths of the graph are covered,
        or once :code:`reset_coverage_after_n_round_with_no_coverage` outputs
        without new coverage have accumulated; in the latter case the counter
        is reset as well.
        """
        if not self.cost_settings.weight_vector.low_global_k_path_coverage_penalty > 0:
            return

        num_previously_covered = len(self.covered_k_paths)
        self.covered_k_paths.update(
            tree.k_paths(
                self.graph, self.cost_settings.k, include_potential_paths=False
            )
        )

        # `update` only ever adds elements, so an unchanged size means that no
        # new k-path was covered (avoids copying and comparing the whole set).
        if len(self.covered_k_paths) == num_previously_covered:
            self.rounds_with_no_new_coverage += 1

        graph_paths = self.graph.k_paths(
            self.cost_settings.k, include_terminals=False
        )
        all_paths_covered = self.covered_k_paths == graph_paths
        coverage_stalled = (
            self.rounds_with_no_new_coverage
            >= self.reset_coverage_after_n_round_with_no_coverage
        )

        if coverage_stalled or all_paths_covered:
            if all_paths_covered:
                self.logger.debug("ALL PATHS COVERED")
            else:
                self.logger.debug(
                    "COVERAGE RESET SINCE NO CHANGE IN COVERED PATHS SINCE %d "
                    + "ROUNDS (%d path(s) uncovered)",
                    self.reset_coverage_after_n_round_with_no_coverage,
                    len(graph_paths) - len(self.covered_k_paths),
                )

            self.covered_k_paths = set()

        if coverage_stalled:
            self.rounds_with_no_new_coverage = 0

    def _compute_global_k_coverage_cost(self, state: SolutionState):
        """
        Penalizes trees that contribute few of the globally still uncovered
        k-paths.

        Computes :code:`1 - wgm([actual, potential], [0.2, 0.8])`, where
        `actual` (`potential`) is the fraction of uncovered graph k-paths
        that occur in the tree's k-paths computed without (with)
        :code:`include_potential_paths`, and `wgm` is
        :code:`weighted_geometric_mean`.

        :param state: The state whose tree to rate.
        :return: The cost; 0 if the penalty weight is 0 or if no k-paths are
            missing.
        """
        if self.cost_settings.weight_vector.low_global_k_path_coverage_penalty == 0:
            return 0

        tree_k_paths = state.tree.k_paths(
            self.graph, self.cost_settings.k, include_potential_paths=False
        )
        all_graph_k_paths = self.graph.k_paths(
            self.cost_settings.k, include_terminals=False
        )

        num_contributed_k_paths = len(
            {
                path
                for path in all_graph_k_paths
                if path in tree_k_paths and path not in self.covered_k_paths
            }
        )
        num_missing_k_paths = len(all_graph_k_paths) - len(self.covered_k_paths)

        assert 0 <= num_contributed_k_paths <= num_missing_k_paths, (
            f"num_contributed_k_paths == {num_contributed_k_paths}, "
            f"num_missing_k_paths == {num_missing_k_paths}"
        )

        if not num_missing_k_paths:
            return 0

        potential_tree_k_paths = state.tree.k_paths(
            self.graph, self.cost_settings.k, include_potential_paths=True
        )
        num_contributed_potential_k_paths = len(
            {
                path
                for path in all_graph_k_paths
                if path in potential_tree_k_paths
                and path not in self.covered_k_paths
            }
        )

        return 1 - weighted_geometric_mean(
            [
                num_contributed_k_paths / num_missing_k_paths,
                num_contributed_potential_k_paths / num_missing_k_paths,
            ],
            [0.2, 0.8],
        )

    def _compute_k_coverage_cost(self, state: SolutionState) -> float:
        """
        Geometric mean of :code:`1 - k_coverage` of the state's tree, for all
        k from 1 to :code:`cost_settings.k`; 0 if the penalty weight is 0.
        """
        if self.cost_settings.weight_vector.low_k_coverage_penalty == 0:
            return 0

        coverages = []
        for k in range(1, self.cost_settings.k + 1):
            coverage = state.tree.k_coverage(
                self.graph, k, include_potential_paths=False
            )
            assert 0 <= coverage <= 1, f"coverage == {coverage}"

            coverages.append(1 - coverage)

        return math.prod(coverages) ** (1 / float(self.cost_settings.k))

    def compute_tree_closing_cost(self, tree: DerivationTree) -> float:
        """
        Sums the symbol costs (see :meth:`_symbol_costs`) of the nonterminals
        at the open leaves of `tree`.
        """
        return sum(self._symbol_costs()[leaf.value] for _, leaf in tree.open_leaves())
3636

3637

3638
def smt_formulas_referring_to_subtrees(
    smt_formulas: Sequence[language.SMTFormula],
) -> List[language.SMTFormula]:
    """
    Returns a list of SMT formulas whose solutions address subtrees of other SMT
    formulas, but whose own substitution subtrees are in turn *not* referred to by
    top-level substitution trees of other formulas. Those must be solved first to
    avoid inconsistencies.

    Concretely, a formula is selected if (1) the id of one of its substitution
    trees is the id of a proper subtree of another formula's substitution tree,
    and (2) no proper subtree of its own substitution trees has the id of another
    formula's substitution tree.

    :param smt_formulas: The formulas to search for references to subtrees.
    :return: The list of conflicting formulas that must be solved first, in the
        order in which they appear in ``smt_formulas``.
    """
    # Per formula: ids of its substitution trees (the roots) ...
    root_ids: Dict[language.SMTFormula, Set[int]] = {}
    # ... and ids of all proper subtrees of those trees (roots excluded).
    inner_ids: Dict[language.SMTFormula, Set[int]] = {}
    for formula in smt_formulas:
        trees = list(formula.substitutions.values())
        root_ids[formula] = {tree.id for tree in trees}
        inner_ids[formula] = {
            subtree.id
            for tree in trees
            for _, subtree in tree.paths()
            if subtree.id != tree.id
        }

    selected: List[language.SMTFormula] = []
    for idx, formula in enumerate(smt_formulas):
        others = [
            other for other_idx, other in enumerate(smt_formulas) if other_idx != idx
        ]

        # (1) Must refer to a subtree of some other formula's substitution tree.
        if not any(root_ids[formula] & inner_ids[other] for other in others):
            continue

        # (2) None of its own subtrees may be another formula's root tree.
        if any(root_ids[other] & inner_ids[formula] for other in others):
            continue

        selected.append(formula)

    return selected
3698

3699

3700
def compute_tree_closing_cost(tree: DerivationTree, graph: GrammarGraph) -> float:
    """
    Estimate the cost of closing all open leaves of ``tree``.

    Sums, over all open leaves, the cost of the leaf's nonterminal according to
    ``compute_symbol_costs(graph)``.

    :param tree: The (possibly open) derivation tree.
    :param graph: The grammar graph used to compute symbol costs.
    :return: The summed closing cost; 0 for a tree without open leaves.
    """
    open_symbols = [leaf.value for _, leaf in tree.open_leaves()]
    if not open_symbols:
        return 0

    # Look up the (lru-cached) cost table once rather than once per open leaf;
    # each lookup hashes ``graph``.
    symbol_costs = compute_symbol_costs(graph)
    return sum(symbol_costs[symbol] for symbol in open_symbols)
3705

3706

3707
def get_quantifier_chains(
    formula: language.Formula,
) -> List[Tuple[Union[language.QuantifiedFormula, language.ExistsIntFormula], ...]]:
    """
    Collect all chains of nested quantifiers in ``formula``.

    A chain starts with a top-level quantified formula (as returned by
    ``get_toplevel_quantified_formulas``) and continues with each chain found
    in that formula's ``inner_formula``. A quantifier without nested
    quantifiers forms a chain of length one.

    :param formula: The formula to analyze.
    :return: One tuple per chain, outermost quantifier first.
    """
    chains = []
    for quantified in get_toplevel_quantified_formulas(formula):
        # No nested chains: continue with the empty suffix so that the
        # quantifier itself still forms a chain.
        suffixes = get_quantifier_chains(quantified.inner_formula) or [()]
        chains.extend((quantified,) + suffix for suffix in suffixes)
    return chains
3716

3717

3718
def shortest_derivations(graph: gg.GrammarGraph) -> Dict[str, int]:
    """
    Heuristically estimate, for every nonterminal of ``graph.grammar``, how deep
    a derivation of that nonterminal has to be.

    Values are propagated bottom-up using a worklist seeded with all terminal
    nodes:

    * terminal nodes get 0;
    * choice nodes get the maximum of their children's current values, where
      children without a value yet count as 0;
    * other nonterminal nodes get the ceiling of the geometric mean of those
      children that already have a value, plus 1.

    After a node is (re)computed, its parents are pushed onto the worklist if
    the new value is smaller than the previous one. A missing *or zero* previous
    value is treated as ``sys.maxsize``. Despite the function's name, the result
    is therefore an estimate rather than the exact length of a shortest
    derivation.

    :param graph: The grammar graph to analyze.
    :return: A mapping from each nonterminal in ``graph.grammar`` to its estimate.
    """

    def ceil_geometric_mean(values) -> int:
        # Ignores missing (None) values; a single 0 makes the whole product 0.
        known = [value for value in values if value is not None]
        return math.ceil(math.prod(known) ** (1 / len(known)))

    parents_of: Dict[gg.Node, Set[gg.Node]] = {node: set() for node in graph.all_nodes}
    for parent, child in graph.all_edges:
        parents_of[child].add(parent)

    depth: Dict[gg.Node, int] = {}
    worklist: List[gg.Node] = graph.filter(
        lambda node: isinstance(node, gg.TerminalNode)
    )
    while worklist:
        current = worklist.pop()
        previous = depth.get(current, None)

        # Order matters: choice nodes are also nonterminal nodes.
        if isinstance(current, gg.TerminalNode):
            depth[current] = 0
        elif isinstance(current, gg.ChoiceNode):
            depth[current] = max(depth.get(child, 0) for child in current.children)
        elif isinstance(current, gg.NonterminalNode):
            assert not isinstance(current, gg.ChoiceNode)
            depth[current] = (
                ceil_geometric_mean(
                    depth.get(child, None) for child in current.children
                )
                + 1
            )

        # NOTE: `previous or sys.maxsize` also maps a previous value of 0 to
        # sys.maxsize, so parents of zero-valued nodes are pushed again.
        if (previous or sys.maxsize) > depth[current]:
            worklist.extend(parents_of[current])

    return {symbol: depth[graph.get_node(symbol)] for symbol in graph.grammar}
3758

3759

3760
@lru_cache()
def compute_symbol_costs(graph: GrammarGraph) -> Dict[str, int]:
    """
    Compute a cost for every nonterminal of ``graph``.

    Starts from the estimates of ``shortest_derivations(graph)``. Then, for every
    nonterminal that has some expansion (in the canonical grammar) containing a
    nonterminal, it walks all paths from the graph root to that nonterminal as
    computed by ``all_paths``. Each path is walked backwards, and every
    predecessor gets a cost strictly larger than its successor's.

    The result is cached per ``graph`` by ``lru_cache``, so all callers receive
    the same dict object; it must not be mutated.

    :param graph: The grammar graph.
    :return: A mapping from nonterminal symbols to costs.
    """
    canonical_grammar = canonical(graph.to_grammar())
    costs: Dict[str, int] = shortest_derivations(graph)

    def has_nonterminal_child(nonterminal: str) -> bool:
        return any(
            is_nonterminal(symbol)
            for expansion in canonical_grammar[nonterminal]
            for symbol in expansion
        )

    parents_with_nonterminal_children = []
    for symbol in canonical_grammar:
        if has_nonterminal_child(symbol):
            parents_with_nonterminal_children.append(symbol)

    # Sometimes the shortest-derivation estimates give some nonterminals lower
    # costs than nonterminals reachable from them (but not vice versa), which is
    # undesired. Counteract this by making costs strictly decrease along the
    # (bounded-cycle) paths from the root to every such nonterminal parent.
    for parent_symbol in parents_with_nonterminal_children:
        # noinspection PyTypeChecker
        for path in all_paths(graph, graph.root, graph.get_node(parent_symbol)):
            # Adjacent (source, target) pairs, visited from the end of the path.
            for source, target in reversed(list(zip(path, path[1:]))):
                if costs[source.symbol] <= costs[target.symbol]:
                    costs[source.symbol] = costs[target.symbol] + 1

    return costs
3793

3794

3795
def all_paths(
    graph,
    from_node: gg.NonterminalNode,
    to_node: gg.NonterminalNode,
    cycles_allowed: int = 2,
) -> List[List[gg.NonterminalNode]]:
    """
    Compute paths between two nodes of a grammar graph by breadth-first search.

    Only nonterminal nodes (choice nodes included) are traversed. A single visit
    counter per node is shared by *all* paths. A node may be appended to a path
    as long as it has been appended at most ``cycles_allowed + 1`` times before,
    i.e., at most ``cycles_allowed + 2`` times overall (four times with the
    default). This is not really a bound of ``cycles_allowed`` cycles per path,
    which was the original intention of the parameter: that would require a
    per-path check, and the number of paths would explode. The current
    implementation still provides reasonably good results.

    :param graph: The grammar graph (must provide ``all_nodes``).
    :param from_node: The first node of every path.
    :param to_node: The last node of every path.
    :param cycles_allowed: Controls the global per-node visit budget (see above).
    :return: The discovered paths in BFS order, with choice nodes removed.
    """
    from collections import deque

    result: List[List[gg.NonterminalNode]] = []
    visited: Dict[gg.NonterminalNode, int] = {n: 0 for n in graph.all_nodes}

    # A deque gives O(1) removal from the front; list.pop(0) is O(n).
    queue = deque([[from_node]])
    while queue:
        path = queue.popleft()
        if path[-1] == to_node:
            result.append(path)
            continue

        for child in path[-1].children:
            if (
                not isinstance(child, gg.NonterminalNode)
                or visited[child] > cycles_allowed + 1
            ):
                continue

            visited[child] += 1
            queue.append(path + [child])

    return [[n for n in p if not isinstance(n, gg.ChoiceNode)] for p in result]
3826

3827

3828
def implies(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Check whether ``f1`` implies ``f2`` for inputs of ``grammar``.

    An :class:`ISLaSolver` with unsat support is asked for a solution of
    ``f1 & -f2``, i.e., for a counterexample to the implication.

    :param f1: The premise.
    :param f2: The conclusion.
    :param grammar: The grammar of the inputs.
    :param timeout_seconds: Timeout passed on to the solver.
    :return: True if ``solve()`` raises :class:`StopIteration` (interpreted as
        "no counterexample exists"). False if it returns a solution, or if it
        fails with any other exception captured by ``Exceptional``.
    """
    counterexample_formula = f1 & -f2
    solver = ISLaSolver(
        grammar,
        counterexample_formula,
        activate_unsat_support=True,
        timeout_seconds=timeout_seconds,
    )

    # A solution refutes the implication; StopIteration confirms it.
    solve_outcome = Exceptional.of(solver.solve).map(lambda _: False)
    return solve_outcome.recover(lambda e: isinstance(e, StopIteration)).a
3840

3841

3842
def equivalent(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Check whether ``f1`` and ``f2`` are equivalent for inputs of ``grammar``.

    An :class:`ISLaSolver` with unsat support is asked for a solution of
    ``-(f1 & f2 | -f1 & -f2)``, i.e., for an input on which exactly one of the
    two formulas holds.

    :param f1: The first formula.
    :param f2: The second formula.
    :param grammar: The grammar of the inputs.
    :param timeout_seconds: Timeout passed on to the solver.
    :return: True if ``solve()`` raises :class:`StopIteration` (interpreted as
        "no distinguishing input exists"). False if it returns a solution, or if
        it fails with any other exception captured by ``Exceptional``.
    """
    both_hold = f1 & f2
    neither_holds = -f1 & -f2
    solver = ISLaSolver(
        grammar,
        -(both_hold | neither_holds),
        activate_unsat_support=True,
        timeout_seconds=timeout_seconds,
    )

    # A solution distinguishes the formulas; StopIteration confirms equivalence.
    solve_outcome = Exceptional.of(solver.solve).map(lambda _: False)
    return solve_outcome.recover(lambda e: isinstance(e, StopIteration)).a
3857

3858

3859
def generate_abstracted_trees(
    inp: DerivationTree, participating_paths: Set[Path]
) -> List[DerivationTree]:
    """
    Returns trees that are more and more "abstracted," i.e., pruned, at prefixes
    of the paths specified in `participating_paths`.

    For each participating path, every non-empty prefix (or the empty path
    itself, for an empty path) is a pruning candidate. For every choice of one
    prefix per participating path, and every non-empty subset of those choices,
    the selected subtrees of ``inp`` are replaced by open leaves carrying the
    same value. The choices are normalized via ``eliminate_suffixes``.
    Structurally equal results (same ``structural_hash()``) are kept only once.

    :param inp: The unabstracted input.
    :param participating_paths: The paths to abstract.
    :return: A list (not a generator) of abstracted trees, sorted by decreasing
        ``len(tree)``. This is intended to run from the most concrete to the most
        abstract tree.
    """
    # For each participating path: its prefixes, longest first.
    parent_paths: Set[ImmutableList[Path]] = {
        tuple(
            [tuple(path[:i]) for i in reversed(range(1, len(path) + 1))]
            if path
            else [()]
        )
        for path in participating_paths
    }

    abstraction_candidate_combinations: Set[ImmutableList[Path]] = {
        tuple(eliminate_suffixes(combination))
        for k in range(1, len(participating_paths) + 1)
        for paths in itertools.product(*parent_paths)
        for combination in itertools.combinations(paths, k)
    }

    # Keyed by structural hash to drop structurally identical abstractions.
    result: Dict[int, DerivationTree] = {}
    for paths_to_abstract in abstraction_candidate_combinations:
        # Look up each subtree once and replace it by an open leaf with its value.
        subtrees = [inp.get_subtree(path) for path in paths_to_abstract]
        abstracted_tree = inp.substitute(
            {subtree: DerivationTree(subtree.value) for subtree in subtrees}
        )
        result[abstracted_tree.structural_hash()] = abstracted_tree

    return sorted(result.values(), key=lambda tree: -len(tree))
3900

3901

3902
class EvaluatePredicateFormulasTransformer(NoopFormulaTransformer):
    """
    Formula transformer that evaluates structural predicates and SMT formulas
    against a fixed input tree.

    Structural predicate formulas are replaced by ``true``/``false`` according
    to their evaluation on the input. Conjunctions and disjunctions are rebuilt
    from their arguments using ``Formula.__and__`` / ``Formula.__or__``. SMT
    formulas are instantiated with their substitutions and, depending on the
    result and the negation scope, replaced by ``true`` or kept.
    """

    def __init__(self, inp: DerivationTree):
        super().__init__()
        # Input tree against which structural predicates are evaluated.
        self.inp = inp

    def transform_predicate_formula(
        self, sub_formula: language.StructuralPredicateFormula
    ) -> language.Formula:
        """Replace a structural predicate by its truth value on ``self.inp``."""
        if sub_formula.evaluate(self.inp):
            return sc.true()
        return sc.false()

    def transform_conjunctive_formula(
        self, sub_formula: language.ConjunctiveFormula
    ) -> language.Formula:
        """Rebuild the conjunction by folding ``Formula.__and__`` over its args."""
        return reduce(language.Formula.__and__, sub_formula.args)

    def transform_disjunctive_formula(
        self, sub_formula: language.DisjunctiveFormula
    ) -> language.Formula:
        """Rebuild the disjunction by folding ``Formula.__or__`` over its args."""
        return reduce(language.Formula.__or__, sub_formula.args)

    def transform_smt_formula(
        self, sub_formula: language.SMTFormula
    ) -> language.Formula:
        """
        Evaluate an SMT formula under its own substitutions.

        A deep copy with automatic substitution and evaluation enabled is
        instantiated with ``sub_formula.substitutions``; the result must be
        ``true`` or ``false``. If "evaluated to true" XOR
        ``self.in_negation_scope`` holds, the formula is replaced by ``true``.
        Otherwise, the original formula is kept for later analysis.
        """
        evaluated = copy.deepcopy(sub_formula)
        set_smt_auto_subst(evaluated, True)
        set_smt_auto_eval(evaluated, True)
        evaluated = evaluated.substitute_expressions(
            sub_formula.substitutions, force=True
        )

        assert evaluated in {sc.true(), sc.false()}

        evaluates_to_true = evaluated == sc.true()
        if evaluates_to_true ^ self.in_negation_scope:
            return sc.true()
        return sub_formula
3943

3944

3945
def create_fixed_length_tree(
    start: DerivationTree | str,
    canonical_grammar: CanonicalGrammar,
    target_length: int,
) -> Optional[DerivationTree]:
    """
    Search for a closed derivation tree whose length is exactly ``target_length``.

    Performs a depth-first search over partial derivation trees. Each search
    state holds three things:

    * the current tree;
    * a length estimate: the length of every terminal leaf's value, plus 1 for
      every open leaf whose nonterminal is not in
      ``compute_nullable_nonterminals(canonical_grammar)``;
    * the open leaves together with their paths.

    States whose estimate exceeds ``target_length`` are discarded. For a closed
    tree, the estimate is exactly the summed length of its terminal values.

    For each open leaf, the candidate expansions are the non-terminal expansions
    returned by ``get_expansions`` plus one randomly chosen terminal expansion
    (if any). Candidates with fewer nonterminals are explored first.

    If ``start`` is a tree, only its ``value`` is used: the search starts from a
    single open leaf carrying that value.

    :param start: The start symbol, or a tree whose root value is used.
    :param canonical_grammar: The grammar in canonical form.
    :param target_length: The desired total length of the terminal values.
    :return: The first closed tree found whose length equals ``target_length``,
        or ``None`` if the search space is exhausted. Since a terminal expansion
        is picked at random, results may differ between calls.
    """
    nullable = compute_nullable_nonterminals(canonical_grammar)
    start = DerivationTree(start) if isinstance(start, str) else start
    stack: List[
        Tuple[DerivationTree, int, ImmutableList[Tuple[Path, DerivationTree]]]
    ] = [
        (start, int(start.value not in nullable), (((), start),)),
    ]

    while stack:
        tree, curr_len, open_leaves = stack.pop()

        if not open_leaves:
            if curr_len == target_length:
                return tree
            continue

        # Expanding never lowers the estimate, so this state cannot succeed.
        if curr_len > target_length:
            continue

        # NOTE(review): every open leaf (not only one) is expanded, so the same
        # tree can be reached via different expansion orders. This enlarges the
        # search space; kept as is to preserve the search order.
        idx: int
        path: Path
        leaf: DerivationTree
        for idx, (path, leaf) in reversed(list(enumerate(open_leaves))):
            terminal_expansions, nonterminal_expansions = get_expansions(
                leaf.value, canonical_grammar
            )

            # Only choose one random terminal expansion; keep all nonterminal
            # expansions. Copy first instead of appending to the list returned
            # by get_expansions, which may be shared with other callers.
            expansions = list(nonterminal_expansions)
            if terminal_expansions:
                expansions.append(random.choice(terminal_expansions))

            # Sorted by ascending number of nonterminals and pushed in reverse,
            # so that expansions with the fewest nonterminals are popped first.
            expansions = sorted(
                expansions,
                key=lambda expansion: len(
                    [elem for elem in expansion if is_nonterminal(elem)]
                ),
            )

            for expansion in reversed(expansions):
                new_children = tuple(
                    [
                        DerivationTree(elem, None if is_nonterminal(elem) else ())
                        for elem in expansion
                    ]
                )

                expanded_tree = tree.replace_path(
                    path,
                    DerivationTree(
                        leaf.value,
                        new_children,
                    ),
                )

                stack.append(
                    (
                        expanded_tree,
                        # Add the new children's contributions and remove the
                        # contribution of the leaf that was just expanded.
                        curr_len
                        + sum(
                            [
                                len(child.value)
                                if child.children == ()
                                else (1 if child.value not in nullable else 0)
                                for child in new_children
                            ]
                        )
                        - int(leaf.value not in nullable),
                        # Replace the expanded leaf by its nonterminal children.
                        open_leaves[:idx]
                        + tuple(
                            [
                                (path + (child_idx,), new_child)
                                for child_idx, new_child in enumerate(new_children)
                                if is_nonterminal(new_child.value)
                            ]
                        )
                        + open_leaves[idx + 1 :],
                    )
                )

    return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc