• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

rindPHI / isla / 7846461830

09 Feb 2024 03:58PM UTC coverage: 93.504% (-0.2%) from 93.737%
7846461830

Pull #90

github

web-flow
Merge 4e3e41de3 into 3de029959
Pull Request #90: RepairSolver

796 of 884 new or added lines in 7 files covered. (90.05%)

2 existing lines in 2 files now uncovered.

6852 of 7328 relevant lines covered (93.5%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.54
/src/isla/solver.py
1
# Copyright © 2022 CISPA Helmholtz Center for Information Security.
2
# Author: Dominic Steinhöfel.
3
#
4
# This file is part of ISLa.
5
#
6
# ISLa is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# ISLa is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with ISLa.  If not, see <http://www.gnu.org/licenses/>.
18

19
import copy
1✔
20
import functools
1✔
21
import heapq
1✔
22
import itertools
1✔
23
import logging
1✔
24
import math
1✔
25
import operator
1✔
26
import random
1✔
27
import sys
1✔
28
import time
1✔
29
from abc import ABC
1✔
30
from dataclasses import dataclass
1✔
31
from functools import reduce, lru_cache, partial
1✔
32
from typing import (
1✔
33
    Dict,
34
    List,
35
    Set,
36
    Optional,
37
    Tuple,
38
    Union,
39
    cast,
40
    Callable,
41
    Iterable,
42
    Sequence,
43
)
44

45
import pkg_resources
1✔
46
import z3
1✔
47
from grammar_graph import gg
1✔
48
from grammar_graph.gg import GrammarGraph
1✔
49
from grammar_to_regex.cfg2regex import RegexConverter
1✔
50
from grammar_to_regex.regex import regex_to_z3
1✔
51
from orderedset import OrderedSet
1✔
52
from packaging import version
1✔
53
from returns.converters import result_to_maybe
1✔
54
from returns.functions import compose, tap
1✔
55
from returns.maybe import Nothing, Some
1✔
56
from returns.pipeline import flow, is_successful
1✔
57
from returns.pointfree import lash
1✔
58
from returns.result import safe, Success
1✔
59

60
import isla.isla_shortcuts as sc
1✔
61
import isla.three_valued_truth
1✔
62
from isla import language
1✔
63
from isla.derivation_tree import DerivationTree
1✔
64
from isla.evaluator import (
1✔
65
    evaluate,
66
    quantified_formula_might_match,
67
    get_toplevel_quantified_formulas,
68
    eliminate_quantifiers,
69
)
70
from isla.evaluator import matches_for_quantified_formula
1✔
71
from isla.existential_helpers import (
1✔
72
    insert_tree,
73
    DIRECT_EMBEDDING,
74
    SELF_EMBEDDING,
75
    CONTEXT_ADDITION,
76
)
77
from isla.fuzzer import GrammarFuzzer, GrammarCoverageFuzzer, expansion_key
1✔
78
from isla.helpers import (
1✔
79
    delete_unreachable,
80
    shuffle,
81
    dict_of_lists_to_list_of_dicts,
82
    weighted_geometric_mean,
83
    assertions_activated,
84
    split_str_with_nonterminals,
85
    cluster_by_common_elements,
86
    is_nonterminal,
87
    canonical,
88
    lazyjoin,
89
    lazystr,
90
    Maybe,
91
    eliminate_suffixes,
92
    get_elem_by_equivalence,
93
    get_expansions,
94
    list_del,
95
    compute_nullable_nonterminals,
96
    eassert,
97
    merge_dict_of_sets,
98
    list_set,
99
)
100
from isla.isla_predicates import (
1✔
101
    STANDARD_STRUCTURAL_PREDICATES,
102
    STANDARD_SEMANTIC_PREDICATES,
103
    COUNT_PREDICATE,
104
)
105
from isla.language import (
1✔
106
    VariablesCollector,
107
    split_conjunction,
108
    split_disjunction,
109
    convert_to_nnf,
110
    ensure_unique_bound_variables,
111
    parse_isla,
112
    get_conjuncts,
113
    parse_bnf,
114
    ForallIntFormula,
115
    set_smt_auto_eval,
116
    NoopFormulaTransformer,
117
    set_smt_auto_subst,
118
    fresh_constant,
119
    to_dnf_clauses,
120
)
121
from isla.mutator import Mutator
1✔
122
from isla.parser import EarleyParser
1✔
123
from isla.type_defs import Grammar, Path, ImmutableList, CanonicalGrammar
1✔
124
from isla.z3_helpers import (
1✔
125
    z3_solve,
126
    z3_subst,
127
    z3_eq,
128
    z3_and,
129
    visit_z3_expr,
130
    smt_string_val_to_string,
131
    parent_relationships_in_z3_expr,
132
    numeric_intervals_from_regex,
133
    z3_or,
134
)
135

136

137
@dataclass(frozen=True)
class SolutionState:
    """
    A solver state: a constraint together with the derivation tree it refers to.

    Hashing and equality only consider :code:`constraint` and :code:`tree`; the
    :code:`level` field is carried along but takes no part in either.

    :param constraint: The constraint that remains to be solved for :code:`tree`.
    :param tree: The (possibly still open) derivation tree of this state.
    :param level: An integer level attached to the state (defaults to 0).
    """

    constraint: language.Formula
    tree: DerivationTree
    level: int = 0
    # Cache slot for `__hash__`. Because of Python's private name mangling, this
    # field is actually stored as `_SolutionState__hash`.
    __hash: Optional[int] = None

    def formula_satisfied(
        self, grammar: Grammar
    ) -> isla.three_valued_truth.ThreeValuedTruth:
        """
        Evaluates the constraint of this state on its tree.

        :param grammar: The grammar passed on to the evaluator.
        :return: "Unknown" if the tree is still open; otherwise, the result of
          evaluating the constraint on the tree.
        """
        if self.tree.is_open():
            # Have to instantiate variables first
            return isla.three_valued_truth.ThreeValuedTruth.unknown()

        return evaluate(self.constraint, self.tree, grammar)

    def complete(self) -> bool:
        """
        :return: True iff the tree is complete and the constraint is `true`.
        """
        if not self.tree.is_complete():
            return False

        # We assume that any universal quantifier has already been instantiated, if it
        # matches, and is thus satisfied, or another unsatisfied constraint resulted
        # from the instantiation. Existential, predicate, and SMT formulas have to be
        # eliminated first.

        return self.constraint == sc.true()

    # Less-than comparisons are needed for usage in the binary heap queue
    def __lt__(self, other: "SolutionState"):
        return hash(self) < hash(other)

    def __hash__(self):
        """
        Returns the hash of :code:`(constraint, tree)`, computing it at most once
        per instance.
        """
        cached = self.__hash
        if cached is None:
            cached = hash((self.constraint, self.tree))
            # The instance is frozen, so we have to bypass the dataclass's
            # `__setattr__`. The attribute name must be spelled in its *mangled*
            # form: inside the class body, `self.__hash` reads
            # `_SolutionState__hash`, whereas a plain string "__hash" is not
            # mangled and would write to a different attribute, such that the
            # cache would never be hit.
            # NOTE: The cache lives in an init field; code copying fields into a
            # state with a different constraint or tree (e.g., via
            # `dataclasses.replace`) would have to reset it.
            object.__setattr__(self, "_SolutionState__hash", cached)

        return cached

    def __eq__(self, other):
        return (
            isinstance(other, SolutionState)
            and self.constraint == other.constraint
            and self.tree.structurally_equal(other.tree)
        )
182

183

184
@dataclass(frozen=True)
class CostWeightVector:
    """
    The weights consumed by the
    :class:`~isla.solver.GrammarBasedBlackboxCostComputer`. Objects of this class
    can be unpacked and indexed like a five-element tuple.
    """

    tree_closing_cost: float = 0
    constraint_cost: float = 0
    derivation_depth_penalty: float = 0
    low_k_coverage_penalty: float = 0
    low_global_k_path_coverage_penalty: float = 0

    def _components(self) -> list:
        # The weights in declaration order, as a fresh list on each call.
        return [
            self.tree_closing_cost,
            self.constraint_cost,
            self.derivation_depth_penalty,
            self.low_k_coverage_penalty,
            self.low_global_k_path_coverage_penalty,
        ]

    def __iter__(self):
        """
        Enables tuple assignment for weight vectors:

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> a, b, c, d, e = v
        >>> (a, b, c, d, e)
        (1, 2, 3, 4, 5)

        :return: An iterator over the five weights, in declaration order.
        """
        return iter(self._components())

    def __getitem__(self, item: int) -> float:
        """
        Enables tuple-like indexing of weight vectors:

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> v[3]
        4

        :param item: A numeric index (anything but an :code:`int` is rejected by
          an assertion).
        :return: The weight at position :code:`item`.
        """
        assert isinstance(item, int)
        weights = self._components()
        return weights[item]
237

238

239
@dataclass(frozen=True)
class CostSettings:
    """
    Bundles a :class:`~isla.solver.CostWeightVector` with an integer :code:`k`.
    """

    weight_vector: CostWeightVector
    k: int = 3

    def __init__(self, weight_vector: CostWeightVector, k: int = 3):
        """
        Type-checks both arguments via assertions before storing them.

        :param weight_vector: The weight vector; must be a
          :class:`~isla.solver.CostWeightVector`.
        :param k: An integer, 3 by default.
        """
        assert isinstance(weight_vector, CostWeightVector)
        assert isinstance(k, int)

        # The dataclass is frozen, so regular attribute assignment is blocked.
        for field_name, field_value in (("weight_vector", weight_vector), ("k", k)):
            object.__setattr__(self, field_name, field_value)
1✔
249

250

251
# The cost settings that `ISLaSolver.__init__` hands to its default cost computer
# if the caller does not supply one.
STD_COST_SETTINGS = CostSettings(
    weight_vector=CostWeightVector(
        tree_closing_cost=6.5,
        constraint_cost=1,
        derivation_depth_penalty=4,
        low_k_coverage_penalty=2,
        low_global_k_path_coverage_penalty=19,
    ),
    k=3,
)
261

262

263
@dataclass(frozen=True)
class UnknownResultError(Exception):
    """
    Raised by :meth:`~isla.solver.ISLaSolver.check` if evaluating the constraint
    yields an "unknown" result.
    """
1✔
266

267

268
@dataclass(frozen=True)
class SemanticError(Exception):
    """
    Signals that an input does not satisfy the solver's constraint; see
    :meth:`~isla.solver.ISLaSolver.parse`.
    """
1✔
271

272

273
@dataclass(frozen=True)
class SolverDefaults:
    """
    The default values of the optional parameters of the
    :meth:`constructor <isla.solver.ISLaSolver.__init__>` of
    :class:`~isla.solver.ISLaSolver`. The meaning of the individual options is
    documented there.
    """

    # Formula and the predicates available when parsing it.
    formula: language.Formula | str | None = None
    structural_predicates: frozenset[language.StructuralPredicate] = (
        STANDARD_STRUCTURAL_PREDICATES
    )
    semantic_predicates: frozenset[language.SemanticPredicate] = (
        STANDARD_SEMANTIC_PREDICATES
    )

    # Bounds on the number of produced instantiations / insertion results.
    max_number_free_instantiations: int = 10
    max_number_smt_instantiations: int = 10
    max_number_tree_insertion_results: int = 5

    # Queue handling, debugging, costs, and timeout.
    enforce_unique_trees_in_queue: bool = False
    debug: bool = False
    cost_computer: Optional["CostComputer"] = None
    timeout_seconds: int | None = None

    # Fuzzing of unconstrained nonterminals and integer-quantifier handling.
    global_fuzzer: bool = False
    predicates_unique_in_int_arg: Tuple[language.SemanticPredicate, ...] = (
        COUNT_PREDICATE,
    )
    fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = (
        lambda grammar: GrammarCoverageFuzzer(grammar)
    )

    # Tree insertion, unsatisfiability support, and SMT-related options.
    tree_insertion_methods: int | None = None
    activate_unsat_support: bool = False
    grammar_unwinding_threshold: int = 4

    # Where the search starts.
    initial_tree: Maybe[DerivationTree] = Nothing
    enable_optimized_z3_queries: bool = True
    start_symbol: str | None = None
1✔
302

303

304
# Shared instance holding the default option values; the keyword defaults in the
# signature of `ISLaSolver.__init__` are read from it.
_DEFAULTS = SolverDefaults()
1✔
305

306

307
class ISLaSolver:
1✔
308
    """
1✔
309
    The solver class for ISLa formulas/constraints. Its top-level methods are
310

311
    :meth:`~isla.solver.ISLaSolver.solve`
312
      Use to generate solutions for an ISLa constraint.
313

314
    :meth:`~isla.solver.ISLaSolver.check`
315
      Use to check if an ISLa constraint is satisfied for a given input.
316

317
    :meth:`~isla.solver.ISLaSolver.parse`
318
      Use to parse and validate an input.
319

320
    :meth:`~isla.solver.ISLaSolver.repair`
321
      Use to repair an input such that it satisfies a constraint.
322

323
    :meth:`~isla.solver.ISLaSolver.mutate`
324
      Use to mutate an input such that the result satisfies a constraint.
325
    """
326

327
    def __init__(
        self,
        grammar: Grammar | str,
        formula: Optional[language.Formula | str] = _DEFAULTS.formula,
        structural_predicates: Set[
            language.StructuralPredicate
        ] = _DEFAULTS.structural_predicates,
        semantic_predicates: Set[
            language.SemanticPredicate
        ] = _DEFAULTS.semantic_predicates,
        max_number_free_instantiations: int = _DEFAULTS.max_number_free_instantiations,
        max_number_smt_instantiations: int = _DEFAULTS.max_number_smt_instantiations,
        max_number_tree_insertion_results: int = _DEFAULTS.max_number_tree_insertion_results,
        enforce_unique_trees_in_queue: bool = _DEFAULTS.enforce_unique_trees_in_queue,
        debug: bool = _DEFAULTS.debug,
        cost_computer: Optional["CostComputer"] = _DEFAULTS.cost_computer,
        timeout_seconds: Optional[int] = _DEFAULTS.timeout_seconds,
        global_fuzzer: bool = _DEFAULTS.global_fuzzer,
        predicates_unique_in_int_arg: Tuple[
            language.SemanticPredicate, ...
        ] = _DEFAULTS.predicates_unique_in_int_arg,
        fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = _DEFAULTS.fuzzer_factory,
        tree_insertion_methods: Optional[int] = _DEFAULTS.tree_insertion_methods,
        activate_unsat_support: bool = _DEFAULTS.activate_unsat_support,
        grammar_unwinding_threshold: int = _DEFAULTS.grammar_unwinding_threshold,
        initial_tree: Maybe[DerivationTree] = _DEFAULTS.initial_tree,
        enable_optimized_z3_queries: bool = _DEFAULTS.enable_optimized_z3_queries,
        start_symbol: Optional[str] = _DEFAULTS.start_symbol,
    ):
        """
        The constructor of :class:`~isla.solver.ISLaSolver` accepts a large number of
        parameters. However, all but the first one, :code:`grammar`, are *optional.*

        The simplest way to construct an ISLa solver is by providing it with a
        grammar only; it then works like a grammar fuzzer.

        >>> import random
        >>> random.seed(1)

        >>> import string
        >>> LANG_GRAMMAR = {
        ...     "<start>":
        ...         ["<stmt>"],
        ...     "<stmt>":
        ...         ["<assgn> ; <stmt>", "<assgn>"],
        ...     "<assgn>":
        ...         ["<var> := <rhs>"],
        ...     "<rhs>":
        ...         ["<var>", "<digit>"],
        ...     "<var>": list(string.ascii_lowercase),
        ...     "<digit>": list(string.digits)
        ... }
        >>>
        >>> from isla.solver import ISLaSolver
        >>> solver = ISLaSolver(LANG_GRAMMAR)
        >>>
        >>> str(solver.solve())
        'd := 9'
        >>> str(solver.solve())
        'v := n ; s := r'

        :param grammar: The underlying grammar; either, as a "Fuzzing Book" dictionary
          or in BNF syntax.
        :param formula: The formula to solve; either a string or a readily parsed
          formula. If no formula is given, a default `true` constraint is assumed, and
          the solver falls back to a grammar fuzzer. The number of produced solutions
          will then be bound by `max_number_free_instantiations`.
        :param structural_predicates: Structural predicates to use when parsing a
          formula.
        :param semantic_predicates: Semantic predicates to use when parsing a formula.
        :param max_number_free_instantiations: Number of times that nonterminals that
          are not bound by any formula should be expanded by a coverage-based fuzzer.
        :param max_number_smt_instantiations: Number of solutions of SMT formulas that
          should be produced.
        :param max_number_tree_insertion_results: The maximum number of results when
          solving existential quantifiers by tree insertion.
        :param enforce_unique_trees_in_queue: If true, states with the same tree as an
          already existing tree in the queue are discarded, irrespectively of the
          constraint.
        :param debug: If true, debug information about the evolution of states is
          collected, notably in the field state_tree. The root of the tree is in the
          field state_tree_root. The field costs stores the computed cost values for
          all new nodes.
        :param cost_computer: The `CostComputer` class for computing the cost relevant
          to placing states in ISLa's queue.
        :param timeout_seconds: Number of seconds after which the solver will terminate.
        :param global_fuzzer: If set to True, only one coverage-guided grammar fuzzer
          object is used to finish off unconstrained open derivation trees throughout
          the whole generation time. This may be beneficial for some targets; e.g., we
          experienced that CSV works significantly faster. However, the achieved k-path
          coverage can be lower with that setting.
        :param predicates_unique_in_int_arg: This is needed in certain cases for
          instantiating universal integer quantifiers. The supplied predicates should
          have exactly one integer argument, and hold for exactly one integer value
          once all other parameters are fixed.
        :param fuzzer_factory: Constructor of the fuzzer to use for instantiating
          "free" nonterminals.
        :param tree_insertion_methods: Combination of methods to use for existential
          quantifier elimination by tree insertion. Full selection: `DIRECT_EMBEDDING +
          SELF_EMBEDDING + CONTEXT_ADDITION`.
        :param activate_unsat_support: Set to True if you assume that a formula might
          be unsatisfiable. This triggers additional tests for unsatisfiability that
          reduce input generation performance, but might ensure termination (with a
          negative solver result) for unsatisfiable problems for which the solver could
          otherwise diverge.
        :param grammar_unwinding_threshold: When querying the SMT solver, ISLa passes a
          regular expression for the syntax of the involved nonterminals. If this
          syntax is not regular, we unwind the respective part in the reference grammar
          up to a depth of `grammar_unwinding_threshold`. If this is too shallow, it can
          happen that an equation etc. cannot be solved; if it is too deep, it can
          negatively impact performance (and quite tremendously so).
        :param initial_tree: An initial input tree for the queue, if the solver shall
          not start from the tree `(<start>, None)`.
        :param enable_optimized_z3_queries: Enables preprocessing of Z3 queries (mainly
          numeric problems concerning things like length). This can improve performance
          significantly; however, it might happen that certain problems cannot be solved
          anymore. In that case, this option can/should be deactivated.
        :param start_symbol: This is an alternative to `initial_tree` for starting with
          a start symbol different from `<start>`. If `start_symbol` is provided, a tree
          consisting of a single root node with the value of `start_symbol` is chosen as
          initial tree.
        """
        self.logger = logging.getLogger(type(self).__name__)

        # We require at least z3 4.8.13.0. ISLa might work for some older versions, but
        # at least for 4.8.8.0, we have witnessed that certain rather easy constraints,
        # like equations, don't work since z3 cannot handle the restrictions to more
        # complicated regular expressions. This happened in the XML case study with
        # constraints of the kind <id> == <id_no_prefix>. At least with 4.8.13.0, this
        # works flawlessly, but times out for 4.8.8.0.
        #
        # This should be solved using Python requirements, which however cannot be done
        # currently since the fuzzingbook library inflexibly binds z3 to 4.8.8.0. Thus,
        # one has to manually install a newer version and ignore the warning.

        z3_version = pkg_resources.get_distribution("z3-solver").version
        assert version.parse(z3_version) >= version.parse("4.8.13.0"), (
            f"ISLa requires at least z3 4.8.13.0, present: {z3_version}. "
            "Please install a newer z3 version, e.g., using 'pip install z3-solver==4.8.14.0'."
        )

        # A grammar dictionary is deep-copied: the grammar may be updated in place
        # below (if a start symbol is given), and the caller's object shall not be
        # affected by that.
        if isinstance(grammar, str):
            self.grammar = parse_bnf(grammar)
        else:
            self.grammar = copy.deepcopy(grammar)

        assert start_symbol is None or not is_successful(
            initial_tree
        ), "You cannot supply a start symbol *and* an initial tree."

        if start_symbol is not None:
            # Let `<start>` expand to the chosen start symbol only.
            # NOTE(review): `delete_unreachable` presumably prunes the rules that are
            # no longer reachable afterward -- confirm in isla.helpers.
            self.grammar |= {"<start>": [start_symbol]}
            self.grammar = delete_unreachable(self.grammar)

        self.graph = GrammarGraph.from_grammar(self.grammar)
        self.canonical_grammar = canonical(self.grammar)
        self.timeout_seconds = timeout_seconds
        # Set on the first call to `solve` if a timeout is configured.
        self.start_time: Optional[int] = None
        self.global_fuzzer = global_fuzzer
        self.fuzzer = fuzzer_factory(self.grammar)
        self.fuzzer_factory = fuzzer_factory
        self.predicates_unique_in_int_arg: Set[language.SemanticPredicate] = set(
            predicates_unique_in_int_arg
        )
        self.grammar_unwinding_threshold = grammar_unwinding_threshold
        self.enable_optimized_z3_queries = enable_optimized_z3_queries

        # Tree insertion methods: With unsat support and no explicit choice, tree
        # insertion is switched off (0). Otherwise, an explicit choice wins (with a
        # warning on stderr if unsat support is active), and all three methods are
        # enabled if there is no explicit choice.
        if activate_unsat_support and tree_insertion_methods is None:
            self.tree_insertion_methods = 0
        else:
            if activate_unsat_support:
                assert tree_insertion_methods is not None
                print(
                    "With activate_unsat_support set, a 0 value for tree_insertion_methods is recommended, "
                    f"the current value is: {tree_insertion_methods}",
                    file=sys.stderr,
                )

            self.tree_insertion_methods = (
                DIRECT_EMBEDDING + SELF_EMBEDDING + CONTEXT_ADDITION
            )
            if tree_insertion_methods is not None:
                self.tree_insertion_methods = tree_insertion_methods

        self.activate_unsat_support = activate_unsat_support
        self.currently_unsat_checking: bool = False

        self.cost_computer = (
            cost_computer
            if cost_computer is not None
            else GrammarBasedBlackboxCostComputer(STD_COST_SETTINGS, self.graph)
        )

        # No formula: `true`; a string: parse it; otherwise, use the formula as is.
        formula = (
            sc.true()
            if formula is None
            else (
                parse_isla(
                    formula, self.grammar, structural_predicates, semantic_predicates
                )
                if isinstance(formula, str)
                else formula
            )
        )

        self.formula = ensure_unique_bound_variables(formula)

        # The non-numeric constants occurring in the formula; at most one is
        # accepted (see the assertion below).
        top_constants: Set[language.Constant] = set(
            [
                c
                for c in VariablesCollector.collect(self.formula)
                if isinstance(c, language.Constant) and not c.is_numeric()
            ]
        )

        assert len(top_constants) <= 1, (
            "ISLa only accepts up to one constant (free variable), "
            + f'found {len(top_constants)}: {", ".join(map(str, top_constants))}'
        )

        # If a start symbol was supplied, the top constant gets that symbol as its
        # type; otherwise, it keeps its original type.
        only_top_constant = Maybe.from_optional(next(iter(top_constants), None))
        self.top_constant = only_top_constant.map(
            lambda c: language.Constant(
                c.name, Maybe.from_optional(start_symbol).value_or(c.n_type)
            )
        )

        # The retyped constant differs from the one in the formula: replace it there.
        if only_top_constant != self.top_constant:
            assert is_successful(only_top_constant)
            assert is_successful(self.top_constant)
            self.formula = self.formula.substitute_variables(
                {only_top_constant.unwrap(): self.top_constant.unwrap()}
            )

        # Keep only the universal quantifiers of each chain...
        quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            tuple([f for f in c if isinstance(f, language.ForallFormula)])
            for c in get_quantifier_chains(self.formula)
        ]
        # TODO: Remove?
        # ...and drop the chains that are empty afterward.
        self.quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            c for c in quantifier_chains if c
        ]

        self.max_number_free_instantiations: int = max_number_free_instantiations
        self.max_number_smt_instantiations: int = max_number_smt_instantiations
        self.max_number_tree_insertion_results = max_number_tree_insertion_results
        self.enforce_unique_trees_in_queue = enforce_unique_trees_in_queue

        # Initialize Queue
        # The initial tree is chosen with the following priority: (1) the supplied
        # `initial_tree`; (2) a single node labeled with `start_symbol`; (3) a single
        # node labeled with the type of the top constant or, if the formula has no
        # top constant, with "<start>".
        # NOTE(review): `eassert(s, s in self.grammar)` presumably asserts that the
        # start symbol is a nonterminal of the grammar and returns it -- confirm in
        # isla.helpers.
        self.initial_tree = (
            initial_tree.lash(
                lambda _: Maybe.from_optional(start_symbol)
                .map(lambda s: eassert(s, s in self.grammar))
                .map(lambda s: DerivationTree(s, None))
            ).lash(
                lambda _: Some(
                    DerivationTree(
                        self.top_constant.map(lambda c: c.n_type).value_or("<start>"),
                        None,
                    )
                )
            )
        ).unwrap()

        # Replace the top constant, if any, by the initial tree.
        initial_formula = self.top_constant.map(
            lambda c: self.formula.substitute_expressions({c: self.initial_tree})
        ).value_or(self.formula)
        initial_state = SolutionState(initial_formula, self.initial_tree)
        # NOTE(review): `establish_invariant` and `compute_cost` (below) are called
        # while some attributes are not yet assigned; do not reorder the statements
        # in this constructor without checking which fields these methods read.
        initial_states = self.establish_invariant(initial_state)

        # The queue is a binary heap of (cost, state) pairs.
        self.queue: List[Tuple[float, SolutionState]] = []
        self.tree_hashes_in_queue: Set[int] = {self.initial_tree.structural_hash()}
        self.state_hashes_in_queue: Set[int] = {hash(state) for state in initial_states}
        for state in initial_states:
            heapq.heappush(self.queue, (self.compute_cost(state), state))

        self.seen_coverages: Set[str] = set()
        self.current_level: int = 0
        self.step_cnt: int = 0
        self.last_cost_recomputation: int = 0

        self.regex_cache = {}

        # Solutions that were found, but not yet returned by `solve`.
        self.solutions: List[DerivationTree] = []

        # Debugging stuff
        self.debug = debug
        self.state_tree: Dict[SolutionState, List[SolutionState]] = (
            {}
        )  # is only filled if self.debug
        self.state_tree_root = None
        self.current_state = None
        self.costs: Dict[SolutionState, float] = {}

        if self.debug:
            self.state_tree[initial_state] = initial_states
            self.state_tree_root = initial_state
            self.costs[initial_state] = 0
            for state in initial_states:
                self.costs[state] = self.compute_cost(state)
1✔
627

628
    def solve(self) -> DerivationTree:
        """
        Tries to compute one solution to the ISLa formula of this solver and returns
        it. Call this function repeatedly to obtain further solutions. Eventually,
        one of two exceptions is raised: A :class:`StopIteration` if no more
        solution can be found, or a :class:`TimeoutError` if a timeout occurred.
        After that, an exception will be raised every time.

        The timeout is set using the :code:`timeout_seconds` parameter of the
        :meth:`constructor <isla.solver.ISLaSolver.__init__>`.

        :return: A solution for the ISLa formula passed to the
          :class:`isla.solver.ISLaSolver`.
        """

        def next_solution() -> DerivationTree:
            # Hands out the oldest solution that was not yet returned.
            found = self.solutions.pop(0)
            self.logger.debug('Found solution "%s"', found)
            return found

        # The clock starts ticking with the first call.
        if self.start_time is None and self.timeout_seconds is not None:
            self.start_time = int(time.time())

        while self.queue:
            self.step_cnt += 1

            # import dill as pickle
            # state_hash = 9107154106757938105
            # out_file = "/tmp/saved_debug_state"
            # if hash(self.queue[0][1]) == state_hash:
            #     with open(out_file, 'wb') as debug_state_file:
            #         pickle.dump(self, debug_state_file)
            #     print(f"Dumping state to {out_file}")
            #     exit()

            timed_out = (
                self.timeout_seconds is not None
                and int(time.time()) - self.start_time > self.timeout_seconds
            )
            if timed_out:
                self.logger.debug("TIMEOUT")
                raise TimeoutError(self.timeout_seconds)

            # Solutions found in previous steps are returned before any new state
            # is polled.
            if self.solutions:
                return next_solution()

            cost, state = heapq.heappop(self.queue)

            # The polled state leaves the queue, and so does its bookkeeping.
            self.current_level = state.level
            self.tree_hashes_in_queue.discard(state.tree.structural_hash())
            self.state_hashes_in_queue.discard(hash(state))

            if self.debug:
                self.current_state = state
                self.state_tree.setdefault(state, [])
            self.logger.debug(
                "Polling new state (%s, %s) (hash %d, cost %f)",
                state.constraint,
                state.tree.to_string(show_open_leaves=True, show_ids=True),
                hash(state),
                cost,
            )
            self.logger.debug("Queue length: %s", len(self.queue))

            assert not isinstance(state.constraint, language.DisjunctiveFormula)

            # Instantiate all top-level structural predicate formulas.
            state = self.instantiate_structural_predicates(state)

            # The elimination functions, in the order in which they are tried. Only
            # the first applicable one takes effect; the later ones are ignored.
            elimination_functions = [
                self.noop_on_false_constraint,
                self.eliminate_existential_integer_quantifiers,
                self.instantiate_universal_integer_quantifiers,
                self.match_all_universal_formulas,
                self.expand_to_match_quantifiers,
                self.eliminate_all_semantic_formulas,
                self.eliminate_all_ready_semantic_predicate_formulas,
                self.eliminate_and_match_first_existential_formula_and_expand,
                self.assert_remaining_formulas_are_lazy_binding_semantic,
                self.finish_unconstrained_trees,
                self.expand,
            ]

            def try_on_current_state(elimination_function):
                # An elimination function is only called on the current state as
                # long as the result of the pipeline is still `Nothing`.
                return lash(lambda _: elimination_function(state))

            def process_and_extend_solutions(
                result_states: List[SolutionState],
            ) -> Nothing:
                assert result_states is not None
                self.solutions.extend(self.process_new_states(result_states))
                return Nothing

            flow(
                Nothing,
                *map(try_on_current_state, elimination_functions),
            ).bind(process_and_extend_solutions)

        # The queue is exhausted; there might still be solutions left to return.
        if not self.solutions:
            self.logger.debug("UNSAT")
            raise StopIteration()

        return next_solution()
1✔
728

729
    def check(self, inp: DerivationTree | str) -> bool:
        """
        Checks whether the given input satisfies the constraint passed to the
        solver. A string input is handed to :meth:`~isla.solver.ISLaSolver.parse`;
        the result is False if that raises a `SyntaxError` or `SemanticError`, and
        True otherwise. A derivation tree is evaluated directly; in that case, an
        `UnknownResultError` is raised if this could not be evaluated (e.g., because
        of a solver timeout or a semantic predicate that cannot be evaluated).

        :param inp: The input to evaluate, either readily parsed or as a string.
        :return: A truth value.
        """
        if not isinstance(inp, str):
            assert isinstance(inp, DerivationTree)

            verdict = evaluate(self.formula, inp, self.grammar)
            if verdict.is_unknown():
                raise UnknownResultError()

            return bool(verdict)

        try:
            self.parse(inp)
        except (SyntaxError, SemanticError):
            return False
        else:
            return True
1✔
754

755
    def parse(
        self,
        inp: str,
        nonterminal: str = "<start>",
        skip_check: bool = False,
        silent: bool = False,
    ) -> DerivationTree:
        """
        Parses `inp` into a `DerivationTree`.

        A `SyntaxError` is raised if the input does not conform to the grammar, and
        a `SemanticError` if it conforms to the grammar but violates the constraint.
        The constraint is only checked when parsing from "<start>" and when
        `skip_check` is False.

        :param inp: The input to parse.
        :param nonterminal: The nonterminal to start parsing with, if a string
          corresponding to a sub-grammar shall be parsed. We don't check semantic
          correctness in that case.
        :param skip_check: If True, the semantic check is left out.
        :param silent: If True, no error is sent to the log stream in case of a
            failed parse.
        :return: A parsed `DerivationTree`.
        """
        from_start = nonterminal == "<start>"

        # Work on a copy: for sub-grammar parsing, the start rule is redirected.
        grammar = copy.deepcopy(self.grammar)
        if not from_start:
            grammar |= {"<start>": [nonterminal]}
            grammar = delete_unreachable(grammar)

        parser = EarleyParser(grammar)
        try:
            parse_tree = next(parser.parse(inp))
            if not from_start:
                # Strip the artificial "<start>" root added above.
                parse_tree = parse_tree[1][0]
            tree = DerivationTree.from_parse_tree(parse_tree)
        except SyntaxError as err:
            if not silent:
                self.logger.error(
                    f'Error parsing "{inp}" starting with "{nonterminal}"'
                )
            raise err

        if skip_check or not from_start:
            return tree

        if not self.check(tree):
            raise SemanticError()

        return tree
1✔
799

800
    def repair(
1✔
801
        self, inp: DerivationTree | str, fix_timeout_seconds: float = 3
802
    ) -> Maybe[DerivationTree]:
803
        """
804
        Attempts to repair the given input. The input must not violate syntactic
805
        (grammar) constraints. If semantic constraints are violated, this method
806
        gradually abstracts the input and tries to turn it into a valid one.
807
        Note that intensive structural manipulations are not performed; we merely
808
        try to satisfy SMT-LIB and semantic predicate constraints.
809

810
        :param inp: The input to fix.
811
        :param fix_timeout_seconds: A timeout used when calling the solver for an
812
          abstracted input. Usually, a low timeout suffices.
813
        :return: A fixed input (or the original, if it was not broken) or nothing.
814
        """
815

816
        inp = self.parse(inp, skip_check=True) if isinstance(inp, str) else inp
1✔
817

818
        try:
1✔
819
            if self.check(inp) or not is_successful(self.top_constant):
1✔
820
                return Some(inp)
1✔
821
        except UnknownResultError:
1✔
822
            pass
1✔
823

824
        formula = self.top_constant.map(
1✔
825
            lambda c: self.formula.substitute_expressions({c: inp})
826
        ).unwrap()
827

828
        set_smt_auto_eval(formula, False)
1✔
829
        set_smt_auto_subst(formula, False)
1✔
830

831
        qfr_free = eliminate_quantifiers(
1✔
832
            formula,
833
            grammar=self.grammar,
834
            numeric_constants={
835
                c
836
                for c in VariablesCollector.collect(formula)
837
                if isinstance(c, language.Constant) and c.is_numeric()
838
            },
839
        )
840

841
        # We first evaluate all structural predicates; for now, we do not interfere
842
        # with structure.
843
        semantic_only = qfr_free.transform(EvaluatePredicateFormulasTransformer(inp))
1✔
844

845
        if semantic_only == sc.false():
1✔
846
            # This cannot be repaired while preserving structure; for existential
847
            # problems, we could try tree insertion. We leave this for future work.
848
            return Nothing
1✔
849

850
        # We try to satisfy any of the remaining disjunctive elements, in random order
851
        for formula_to_satisfy in shuffle(split_disjunction(semantic_only)):
1✔
852
            # Now, we consider all combinations of 1, 2, ... of the derivation trees
853
            # participating in the formula. We successively prune deeper and deeper
854
            # subtrees until the resulting input evaluates to "unknown" for the given
855
            # formula.
856

857
            participating_paths = {
1✔
858
                inp.find_node(arg) for arg in formula_to_satisfy.tree_arguments()
859
            }
860

861
            def do_complete(tree: DerivationTree) -> Maybe[DerivationTree]:
1✔
862
                return result_to_maybe(
1✔
863
                    safe(
864
                        self.copy_without_queue(
865
                            initial_tree=Some(tree),
866
                            timeout_seconds=Some(fix_timeout_seconds),
867
                        ).solve,
868
                        (UnknownResultError, TimeoutError, StopIteration),
869
                    )()
870
                )
871

872
            # If p1, p2 are in participating_paths, then we consider the following
873
            # path combinations (roughly) in the listed order:
874
            # {p1}, {p2}, {p1, p2}, {p1[:-1]}, {p2[:-1]}, {p1[:-1], p2}, {p1, p2[:-1]},
875
            # {p1[:-1], p2[:-1]}, ...
876
            for abstracted_tree in generate_abstracted_trees(inp, participating_paths):
1✔
877
                match (
1✔
878
                    safe(lambda: self.check(abstracted_tree))()
879
                    .bind(lambda _: Nothing)
880
                    .lash(
881
                        lambda exc: (
882
                            Some(abstracted_tree)
883
                            if isinstance(exc, UnknownResultError)
884
                            else Nothing
885
                        )
886
                    )
887
                    .bind(do_complete)
888
                ):
889
                    case Some(completed):
1✔
890
                        return Some(completed)
1✔
891
                    case _:
1✔
892
                        pass
1✔
893

894
        return Nothing
×
895

896
    def mutate(
        self,
        inp: DerivationTree | str,
        min_mutations: int = 2,
        max_mutations: int = 5,
        fix_timeout_seconds: float = 1,
    ) -> DerivationTree:
        """
        Mutates `inp` such that the result satisfies the constraint.

        Mutants are produced by a `Mutator` and handed to `self.repair`. The loop
        only ends once a mutant that structurally differs from `inp` has been
        repaired successfully; there is no upper bound on the number of attempts.

        :param inp: The input to mutate. A string is parsed first (without
          semantic check).
        :param min_mutations: The minimum number of mutation steps to perform.
        :param max_mutations: The maximum number of mutation steps to perform.
        :param fix_timeout_seconds: A timeout used when calling the solver for fixing
          an abstracted input. Usually, a low timeout suffices.
        :return: A mutated input.
        """

        if isinstance(inp, str):
            inp = self.parse(inp, skip_check=True)

        mutator = Mutator(
            self.grammar,
            min_mutations=min_mutations,
            max_mutations=max_mutations,
            graph=self.graph,
        )

        while True:
            candidate = mutator.mutate(inp)
            if not candidate.structurally_equal(inp):
                repaired = self.repair(candidate, fix_timeout_seconds)
                if is_successful(repaired):
                    return repaired.unwrap()
1✔
929

930
    def copy_without_queue(
        self,
        grammar: Maybe[Grammar | str] = Nothing,
        formula: Maybe[language.Formula | str] = Nothing,
        max_number_free_instantiations: Maybe[int] = Nothing,
        max_number_smt_instantiations: Maybe[int] = Nothing,
        max_number_tree_insertion_results: Maybe[int] = Nothing,
        enforce_unique_trees_in_queue: Maybe[bool] = Nothing,
        debug: Maybe[bool] = Nothing,
        cost_computer: Maybe["CostComputer"] = Nothing,
        timeout_seconds: Maybe[int] = Nothing,
        global_fuzzer: Maybe[bool] = Nothing,
        predicates_unique_in_int_arg: Maybe[
            Tuple[language.SemanticPredicate, ...]
        ] = Nothing,
        fuzzer_factory: Maybe[Callable[[Grammar], GrammarFuzzer]] = Nothing,
        tree_insertion_methods: Maybe[int] = Nothing,
        activate_unsat_support: Maybe[bool] = Nothing,
        grammar_unwinding_threshold: Maybe[int] = Nothing,
        initial_tree: Maybe[DerivationTree] = Nothing,
        enable_optimized_z3_queries: Maybe[bool] = Nothing,
        start_symbol: Optional[str] = None,
    ):
        """
        Creates a new `ISLaSolver` configured like this one.

        Every `Maybe` argument that carries a value overrides the corresponding
        setting of this solver; `Nothing` means "take the value from `self`".
        Two arguments are not inherited and are passed to the constructor exactly
        as given: `initial_tree` and `start_symbol`. The new solver shares this
        solver's `regex_cache` object.

        :return: The newly constructed `ISLaSolver`.
        """
        settings = {
            "grammar": grammar.value_or(self.grammar),
            "formula": formula.value_or(self.formula),
            "max_number_free_instantiations": (
                max_number_free_instantiations.value_or(
                    self.max_number_free_instantiations
                )
            ),
            "max_number_smt_instantiations": max_number_smt_instantiations.value_or(
                self.max_number_smt_instantiations
            ),
            "max_number_tree_insertion_results": (
                max_number_tree_insertion_results.value_or(
                    self.max_number_tree_insertion_results
                )
            ),
            "enforce_unique_trees_in_queue": enforce_unique_trees_in_queue.value_or(
                self.enforce_unique_trees_in_queue
            ),
            "debug": debug.value_or(self.debug),
            "cost_computer": cost_computer.value_or(self.cost_computer),
            "timeout_seconds": timeout_seconds.value_or(self.timeout_seconds),
            "global_fuzzer": global_fuzzer.value_or(self.global_fuzzer),
            "predicates_unique_in_int_arg": predicates_unique_in_int_arg.value_or(
                self.predicates_unique_in_int_arg
            ),
            "fuzzer_factory": fuzzer_factory.value_or(self.fuzzer_factory),
            "tree_insertion_methods": tree_insertion_methods.value_or(
                self.tree_insertion_methods
            ),
            "activate_unsat_support": activate_unsat_support.value_or(
                self.activate_unsat_support
            ),
            "grammar_unwinding_threshold": grammar_unwinding_threshold.value_or(
                self.grammar_unwinding_threshold
            ),
            # Passed through unchanged (still wrapped in `Maybe`), not inherited.
            "initial_tree": initial_tree,
            "enable_optimized_z3_queries": enable_optimized_z3_queries.value_or(
                self.enable_optimized_z3_queries
            ),
            # Passed through unchanged, not inherited.
            "start_symbol": start_symbol,
        }

        clone = ISLaSolver(**settings)

        # Share (not copy) the regex cache with the new solver.
        clone.regex_cache = self.regex_cache

        return clone
1✔
995

996
    @staticmethod
    def noop_on_false_constraint(
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Handles states whose constraint is the constant `false`.

        :param state: The state to inspect.
        :return: `Some([state])` (the state itself, unchanged) if its constraint
          equals `sc.false()`, such that the state can be silently discarded;
          `Nothing` otherwise.
        """
        return Some([state]) if state.constraint == sc.false() else Nothing
1✔
1005

1006
    def expand_to_match_quantifiers(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Expands the tree of `state` if its constraint has a universal formula
        (`language.ForallFormula`) among its top-level conjuncts.

        :param state: The state to expand.
        :return: `Nothing` if no conjunct is a `ForallFormula`; otherwise the
          result of `self.expand_tree(state)`, which is asserted to be non-empty.
        """
        has_universal_conjunct = any(
            isinstance(conjunct, language.ForallFormula)
            for conjunct in get_conjuncts(state.constraint)
        )
        if not has_universal_conjunct:
            return Nothing

        successors = self.expand_tree(state)

        assert len(successors) > 0, f"State {state} will never leave the queue."
        self.logger.debug("Expanding state %s (%d successors)", state, len(successors))

        return Some(successors)
1✔
1024

1025
    def eliminate_and_match_first_existential_formula_and_expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Combines the result of `eliminate_and_match_first_existential_formula` with
        a few tree expansions of the original state.

        :param state: The state to process.
        :return: `Nothing` if the elimination step returned `None`; otherwise the
          elimination results followed by the states obtained from
          `self.expand_tree(state, limit=2, only_universal=False)`.
        """
        eliminated = self.eliminate_and_match_first_existential_formula(state)
        if eliminated is None:
            return Nothing

        # Also add some expansions of the original state, to create a larger
        # solution stream (otherwise, it might be possible that only a small
        # finite number of solutions are generated for existential formulas).
        extra_expansions = self.expand_tree(state, limit=2, only_universal=False)

        return Some(eliminated + extra_expansions)
1039

1040
    def assert_remaining_formulas_are_lazy_binding_semantic(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Sanity check on the shape of the remaining constraint. This step never
        produces successor states.

        Two properties are asserted unless the constraint is `true`:

        1. Every conjunct is a semantic predicate formula or the negation of one.
        2. The (negated) semantic predicate formulas do not bind open leaves of the
           state's tree.

        :param state: The state whose constraint is inspected.
        :return: Always `Nothing`.
        :raises AssertionError: If one of the properties above is violated (only
          when assertions are enabled).
        """
        # SEMANTIC PREDICATE FORMULAS can remain if they bind lazily. In that case, we can choose a random
        # instantiation and let the predicate "fix" the resulting tree.
        assert state.constraint == sc.true() or all(
            isinstance(conjunct, language.SemanticPredicateFormula)
            or (
                isinstance(conjunct, language.NegatedFormula)
                and isinstance(conjunct.args[0], language.SemanticPredicateFormula)
            )
            for conjunct in get_conjuncts(state.constraint)
        ), (
            "Constraint is not true and contains formulas "
            f"other than semantic predicate formulas: {state.constraint}"
        )

        # The `for _, leaf in ...` clause of the second `all(...)` must be *inside*
        # that call. Placed after it, the whole parenthesised condition becomes a
        # generator expression, which is always truthy, so the assertion could
        # never fail.
        # NOTE(review): the two `all(...)` checks are joined with `or`, as in the
        # previous code; `and` may have been intended -- confirm.
        assert (
            state.constraint == sc.true()
            or all(
                not pred_formula.binds_tree(leaf)
                for pred_formula in get_conjuncts(state.constraint)
                if isinstance(pred_formula, language.SemanticPredicateFormula)
                for _, leaf in state.tree.open_leaves()
            )
            or all(
                not cast(
                    language.SemanticPredicateFormula, pred_formula.args[0]
                ).binds_tree(leaf)
                for pred_formula in get_conjuncts(state.constraint)
                if isinstance(pred_formula, language.NegatedFormula)
                and isinstance(pred_formula.args[0], language.SemanticPredicateFormula)
                for _, leaf in state.tree.open_leaves()
            )
        ), (
            "Constraint is not true and contains semantic predicate formulas binding open tree leaves: "
            f"{state.constraint}, leaves: "
            + ", ".join(
                [str(leaf) for _, leaf in state.tree.open_leaves()],
            )
        )

        return Nothing
1✔
1084

1085
    def finish_unconstrained_trees(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Closes the open leaves of a state whose constraint is `true` by fuzzing.

        For such a state, `self.max_number_free_instantiations` closed variants are
        produced; in each of them, every open leaf of the tree is replaced by a
        fuzzer-generated expansion of the leaf's nonterminal.

        :param state: The state to finish.
        :return: `Nothing` if the constraint is not `true`; otherwise the list of
          closed states (each carrying the unchanged constraint).
        """
        # Check applicability first, so that no fuzzer is created (and no coverage
        # information merged) for states this step does not handle.
        if state.constraint != sc.true():
            return Nothing

        fuzzer = (
            self.fuzzer if self.global_fuzzer else self.fuzzer_factory(self.grammar)
        )

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            # Let the fuzzer know about the expansions already covered so far.
            fuzzer.covered_expansions.update(self.seen_coverages)

        closed_results: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            result = state.tree
            for path, leaf in state.tree.open_leaves():
                leaf_inst = fuzzer.expand_tree(DerivationTree(leaf.value, None))
                result = result.replace_path(path, leaf_inst)

            closed_results.append(SolutionState(state.constraint, result))

        return Some(closed_results)
1✔
1109

1110
    def expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates all open leaves of the state's tree with fuzzer-generated
        subtrees, applying the same substitution to the constraint.

        Up to `self.max_number_free_instantiations` successor states are created;
        rounds in which the tree has no open leaves contribute nothing.

        :param state: The state to expand.
        :return: `Some` list of successor states (possibly empty).
        """
        if self.global_fuzzer:
            fuzzer = self.fuzzer
        else:
            fuzzer = self.fuzzer_factory(self.grammar)

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            # Let the fuzzer know about the expansions already covered so far.
            fuzzer.covered_expansions.update(self.seen_coverages)

        successors: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            replacements: Dict[DerivationTree, DerivationTree] = {}
            for _path, open_leaf in state.tree.open_leaves():
                replacements[open_leaf] = fuzzer.expand_tree(
                    DerivationTree(open_leaf.value, None)
                )

            if not replacements:
                continue

            successors.append(
                SolutionState(
                    state.constraint.substitute_expressions(replacements),
                    state.tree.substitute(replacements),
                )
            )

        return Some(successors)
1✔
1137

1138
    def instantiate_structural_predicates(self, state: SolutionState) -> SolutionState:
        """
        Replaces structural predicate formulas in the constraint by their truth
        value.

        Only structural predicate formulas none of whose arguments is a
        `language.Variable` are considered. Each of them is evaluated on the
        state's tree and replaced by the corresponding Boolean SMT formula.

        :param state: The state whose constraint is processed.
        :return: A new state with the rewritten constraint and the same tree.
        """

        def has_no_variable_args(candidate: language.Formula) -> bool:
            return isinstance(
                candidate, language.StructuralPredicateFormula
            ) and not any(isinstance(arg, language.Variable) for arg in candidate.args)

        structural_formulas = language.FilterVisitor(
            lambda f: isinstance(f, language.StructuralPredicateFormula)
        ).collect(state.constraint)

        to_instantiate = [
            candidate
            for candidate in structural_formulas
            if has_no_variable_args(candidate)
        ]

        constraint = state.constraint
        for structural_formula in to_instantiate:
            truth_value = language.SMTFormula(
                z3.BoolVal(structural_formula.evaluate(state.tree))
            )
            self.logger.debug(
                "Eliminating (-> %s) structural predicate formula %s",
                truth_value,
                structural_formula,
            )
            constraint = language.replace_formula(
                constraint, structural_formula, truth_value
            )

        return SolutionState(constraint, state.tree)
1✔
1167

1168
    def eliminate_existential_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Eliminates the existential integer quantifiers among the top-level
        conjuncts of the constraint.

        Each such quantifier is replaced by its inner formula, with the bound
        variable instantiated by a fresh constant. If one of the quantified
        formulas already evaluates to true under the remaining conjuncts, a state
        in which only that formula is replaced by `true` in the *original*
        constraint is returned immediately.

        :param state: The state to process.
        :return: `Nothing` if there is no `ExistsIntFormula` conjunct; otherwise a
          singleton list with the rewritten state.
        """
        quantified_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ExistsIntFormula)
        ]

        if not quantified_conjuncts:
            return Nothing

        constraint = state.constraint
        for quantified in quantified_conjuncts:
            # The following check for validity is not only a performance measure, but required
            # when existential integer formulas are re-inserted. Otherwise, new constants get
            # introduced, and the solver won't terminate.
            other_conjuncts = {
                f for f in split_conjunction(state.constraint) if f != quantified
            }
            already_implied = evaluate(
                quantified,
                state.tree,
                self.grammar,
                assumptions=other_conjuncts,
            ).is_true()

            if already_implied:
                self.logger.debug(
                    "Removing existential integer quantifier '%.30s', already implied "
                    "by tree and existing constraints",
                    quantified,
                )
                # This should simplify the process after quantifier re-insertion.
                simplified = language.replace_formula(
                    state.constraint, quantified, sc.true()
                )
                return Some([SolutionState(simplified, state.tree)])

            self.logger.debug(
                "Eliminating existential integer quantifier %s", quantified
            )

            bound_variable = quantified.bound_variable
            fresh = language.fresh_constant(
                set(VariablesCollector.collect(constraint)),
                language.Constant(bound_variable.name, bound_variable.n_type),
            )
            instantiated = quantified.inner_formula.substitute_variables(
                {bound_variable: fresh}
            )
            constraint = language.replace_formula(constraint, quantified, instantiated)

        return Some([SolutionState(constraint, state.tree)])
1✔
1231

1232
    def instantiate_universal_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates all universal integer quantifiers among the top-level
        conjuncts of the constraint, one after the other.

        Each quantifier is instantiated in every state produced for the previous
        quantifiers, such that the number of states may grow or shrink per step.

        :param state: The state to process.
        :return: `Nothing` if there is no `ForallIntFormula` conjunct; otherwise
          the list of resulting states (possibly empty).
        """
        quantified_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ForallIntFormula)
        ]

        if not quantified_conjuncts:
            return Nothing

        current_states: List[SolutionState] = [state]
        for quantified in quantified_conjuncts:
            next_states: List[SolutionState] = []
            for current_state in current_states:
                next_states.extend(
                    self.instantiate_universal_integer_quantifier(
                        current_state, quantified
                    )
                )
            current_states = next_states

        return Some(current_states)
1✔
1258

1259
    def instantiate_universal_integer_quantifier(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Instantiates a single universal integer quantifier.

        Enumeration is attempted first; the transformation-based approach is only
        used if enumeration yields no (or an empty) result.

        :param state: The state containing the quantified formula.
        :param universal_int_formula: The quantified formula to instantiate.
        :return: The resulting states.
        """
        by_enumeration = self.instantiate_universal_integer_quantifier_by_enumeration(
            state, universal_int_formula
        )

        if not by_enumeration:
            return self.instantiate_universal_integer_quantifier_by_transformation(
                state, universal_int_formula
            )

        return by_enumeration
1272

1273
    def instantiate_universal_integer_quantifier_by_transformation(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Turns a universal integer quantifier into an existential one in a special
        case (see the comment below for the justification).

        The special case applies if the quantified formula has the shape
        `forall int i: exists <A> elem: not phi(...)`, where `phi` is a semantic
        predicate formula whose predicate is in
        `self.predicates_unique_in_int_arg`.

        :param state: The state containing the quantified formula.
        :param universal_int_formula: The quantified formula to transform.
        :return: A singleton list with the transformed state, or an empty list
          (after logging a warning) if the special case does not apply.
        """
        # If the enumeration approach was not successful, we can transform the universal int
        # quantifier to an existential one in a particular situation:
        #
        # Let phi(elem, i) be such that phi(elem) (for fixed first argument) is a unary
        # relation that holds for exactly one argument:
        #
        # forall <A> elem:
        #   exists int i:
        #     phi(elem, i) and
        #     forall int i':
        #       phi(elem, i') <==> i = i'
        #
        # Then, the following transformations are equivalence-preserving:
        #
        # forall int i:
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # <==> (*)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i) &
        #   forall int i':
        #     i != i' ->
        #     exists <A> elem'':
        #       not phi(elem'', i')
        #
        # <==> (+)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # (*)
        # Problematic is only the first inner conjunct. However, for every elem, there
        # has to be an i such that phi(elem, i) holds. If there is no elem in the first
        # place, also the original formula would be unsatisfiable. Without this conjunct,
        # the transformation is a simple "quantifier unwinding."
        #
        # (+)
        # Let i' != i. Choose elem'' := elem': Since phi(elem', i) holds and i != i',
        # "not phi(elem', i')" has to hold.

        quantified_body = universal_int_formula.inner_formula

        special_case_applies = (
            isinstance(quantified_body, language.ExistsFormula)
            and isinstance(quantified_body.inner_formula, language.NegatedFormula)
            and isinstance(
                quantified_body.inner_formula.args[0],
                language.SemanticPredicateFormula,
            )
            and cast(
                language.SemanticPredicateFormula,
                quantified_body.inner_formula.args[0],
            ).predicate
            in self.predicates_unique_in_int_arg
        )

        if not special_case_applies:
            self.logger.warning(
                "Did not find a way to instantiate formula %s!\n"
                "Discarding this state. Please report this to your nearest ISLa developer.",
                universal_int_formula,
            )

            return []

        inner_formula: language.ExistsFormula = quantified_body
        negated_body = cast(language.NegatedFormula, inner_formula.inner_formula)
        predicate_formula: language.SemanticPredicateFormula = cast(
            language.SemanticPredicateFormula, negated_body.args[0]
        )

        # Fresh name for the element variable of the added "phi(elem', i)" conjunct.
        fresh_var = language.fresh_bound_variable(
            language.VariablesCollector().collect(state.constraint),
            inner_formula.bound_variable,
            add=False,
        )

        positive_witness = language.ExistsFormula(
            fresh_var,
            inner_formula.in_variable,
            predicate_formula.substitute_variables(
                {inner_formula.bound_variable: fresh_var}
            ),
        )
        new_formula = language.ExistsIntFormula(
            universal_int_formula.bound_variable,
            positive_witness & inner_formula,
        )

        self.logger.debug(
            "Transforming universal integer quantifier "
            "(special case, see code comments for explanation):\n%s ==> %s",
            universal_int_formula,
            new_formula,
        )

        transformed_constraint = language.replace_formula(
            state.constraint, universal_int_formula, new_formula
        )

        return [SolutionState(transformed_constraint, state.tree)]
×
1387

1388
    def instantiate_universal_integer_quantifier_by_enumeration(
1✔
1389
        self, state: SolutionState, universal_int_formula: ForallIntFormula
1390
    ) -> Optional[List[SolutionState]]:
1391
        constant = language.Constant(
1✔
1392
            universal_int_formula.bound_variable.name,
1393
            universal_int_formula.bound_variable.n_type,
1394
        )
1395

1396
        # noinspection PyTypeChecker
1397
        inner_formula = universal_int_formula.inner_formula.substitute_variables(
1✔
1398
            {universal_int_formula.bound_variable: constant}
1399
        )
1400

1401
        instantiations: List[
1✔
1402
            Dict[
1403
                language.Constant | DerivationTree,
1404
                int | language.Constant | DerivationTree,
1405
            ]
1406
        ] = []
1407

1408
        if isinstance(universal_int_formula.inner_formula, language.DisjunctiveFormula):
1✔
1409
            # In the disjunctive case, we attempt to falsify all SMT formulas in the inner formula
1410
            # (on top level) that contain the bound variable as argument.
1411
            smt_disjuncts = [
1✔
1412
                formula
1413
                for formula in language.split_disjunction(inner_formula)
1414
                if isinstance(formula, language.SMTFormula)
1415
                and constant in formula.free_variables()
1416
            ]
1417

1418
            if smt_disjuncts and len(smt_disjuncts) < len(
1✔
1419
                language.split_disjunction(inner_formula)
1420
            ):
1421
                instantiation_values = (
1✔
1422
                    self.infer_satisfying_assignments_for_smt_formula(
1423
                        -reduce(language.SMTFormula.disjunction, smt_disjuncts),
1424
                        constant,
1425
                    )
1426
                )
1427

1428
                # We also try to falsify (negated) semantic predicate formulas, if present,
1429
                # if there exist any remaining disjuncts.
1430
                semantic_predicate_formulas: List[
1✔
1431
                    Tuple[language.SemanticPredicateFormula, bool]
1432
                ] = [
1433
                    (
1434
                        (pred_formula, False)
1435
                        if isinstance(pred_formula, language.SemanticPredicateFormula)
1436
                        else (cast(language.NegatedFormula, pred_formula).args[0], True)
1437
                    )
1438
                    for pred_formula in language.FilterVisitor(
1439
                        lambda f: (
1440
                            constant in f.free_variables()
1441
                            and (
1442
                                isinstance(f, language.SemanticPredicateFormula)
1443
                                or isinstance(f, language.NegatedFormula)
1444
                                and isinstance(
1445
                                    f.args[0], language.SemanticPredicateFormula
1446
                                )
1447
                            )
1448
                        ),
1449
                        do_continue=lambda f: isinstance(
1450
                            f, language.DisjunctiveFormula
1451
                        ),
1452
                    ).collect(inner_formula)
1453
                    if all(
1454
                        not isinstance(var, language.BoundVariable)
1455
                        for var in pred_formula.free_variables()
1456
                    )
1457
                ]
1458

1459
                if semantic_predicate_formulas and len(
1✔
1460
                    semantic_predicate_formulas
1461
                ) + len(smt_disjuncts) < len(language.split_disjunction(inner_formula)):
1462
                    for value in instantiation_values:
1✔
1463
                        instantiation: Dict[
1✔
1464
                            language.Constant | DerivationTree,
1465
                            int | language.Constant | DerivationTree,
1466
                        ] = {constant: value}
1467
                        for (
1✔
1468
                            semantic_predicate_formula,
1469
                            negated,
1470
                        ) in semantic_predicate_formulas:
1471
                            eval_result = cast(
1✔
1472
                                language.SemanticPredicateFormula,
1473
                                language.substitute(
1474
                                    semantic_predicate_formula, {constant: value}
1475
                                ),
1476
                            ).evaluate(self.graph, negate=not negated)
1477
                            if eval_result.ready() and not eval_result.is_boolean():
1✔
1478
                                instantiation.update(eval_result.result)
1✔
1479
                        instantiations.append(instantiation)
1✔
1480
                else:
1481
                    instantiations.extend(
×
1482
                        [{constant: value} for value in instantiation_values]
1483
                    )
1484

1485
        results: List[SolutionState] = []
1✔
1486
        for instantiation in instantiations:
1✔
1487
            self.logger.debug(
1✔
1488
                "Instantiating universal integer quantifier (%s -> %s) %s",
1489
                universal_int_formula.bound_variable,
1490
                instantiation[constant],
1491
                universal_int_formula,
1492
            )
1493

1494
            formula = language.replace_formula(
1✔
1495
                state.constraint,
1496
                universal_int_formula,
1497
                language.substitute(inner_formula, instantiation),
1498
            )
1499
            formula = language.substitute(formula, instantiation)
1✔
1500

1501
            tree = state.tree.substitute(
1✔
1502
                {
1503
                    tree: subst
1504
                    for tree, subst in instantiation.items()
1505
                    if isinstance(tree, DerivationTree)
1506
                }
1507
            )
1508

1509
            results.append(SolutionState(formula, tree))
1✔
1510

1511
        return results
1✔
1512

1513
    def infer_satisfying_assignments_for_smt_formula(
        self, smt_formula: language.SMTFormula, constant: language.Constant
    ) -> Set[int | language.Constant]:
        """
        This method asks the solver for up to `self.max_number_free_instantiations`
        solutions for the given :class:`~isla.language.SMTFormula` if `constant` is
        the only free variable in `smt_formula`. The given formula must be a numeric
        formula, i.e., all free variables must be numeric. If more than one free
        variables are present, at most one solution is returned (see example below).

        :param smt_formula: The :class:`~isla.language.SMTFormula` to solve. Must
          only contain numeric free variables.
        :param constant: One free variable in `smt_formula`.
        :return: A set of solutions (see explanation above & comment below). The
          set is empty if `smt_formula` does not have any free variables.
        :raises AssertionError: If a solution is not numeric.

        We create a solver with a dummy grammar (it's not needed for this example),
        choosing a value of 5 for `max_number_free_instantiations`.

        >>> solver = ISLaSolver(
        ...     '<start> ::= "x"',  # dummy grammar
        ...     max_number_free_instantiations=5,
        ... )

        The formula we're considering is `x > 10`.

        >>> from isla.language import Constant, SMTFormula, Variable, unparse_isla
        >>> x = Constant("x", Variable.NUMERIC_NTYPE)

        >>> formula = SMTFormula(z3.StrToInt(x.to_smt()) > z3.IntVal(10), x)
        >>> unparse_isla(formula)
        '(< 10 (str.to.int x))'

        We obtain five results (due to our choice of `max_number_free_instantiations`).

        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        >>> len(results)
        5

        All results are `int`s...

        >>> all(isinstance(result, int) for result in results)
        True

        ...and all are strictly greater than 10.

        >>> all(result > 10 for result in results)
        True

        Now, lets consider `x == y`. This formula contains *two* free variables, `x`
        and `y`. It is the only type of formula with more than one variable for which
        this method will return a solution in the current state.

        >>> y = Constant("y", Variable.NUMERIC_NTYPE)
        >>> formula = SMTFormula(
        ...     z3_eq(z3.StrToInt(x.to_smt()), z3.StrToInt(y.to_smt())), x, y)
        >>> unparse_isla(formula)
        '(= (str.to.int x) (str.to.int y))'

        The solution is the singleton set with the variable `y`, which is an
        instantiation of the constant `x` solving the equation.

        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        {Constant("y", "NUM")}

        If we choose a different type of formula (a greater-than relation), we obtain
        an empty solution set.

        >>> formula = SMTFormula(
        ...     z3.StrToInt(x.to_smt()) > z3.StrToInt(y.to_smt()), x, y)
        >>> unparse_isla(formula)
        '(> (str.to.int x) (str.to.int y))'
        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        set()

        With a non-numeric formula, we obtain an AssertionError. It is raised
        explicitly, i.e., also if assertions are disabled. This method expects to be
        called only internally, so this should not happen.

        >>> z = Constant("x", "<start>")
        >>> formula = SMTFormula(z3_eq(z.to_smt(), z3.StringVal("x")), z)
        >>> print(unparse_isla(formula))
        const x: <start>;
        <BLANKLINE>
        (= x "x")
        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, z)
        Traceback (most recent call last):
        ...
        AssertionError: Expected numeric solution.
        """

        free_variables = smt_formula.free_variables()
        max_instantiations = (
            self.max_number_free_instantiations if len(free_variables) == 1 else 1
        )

        try:
            solver_result = self.solve_quantifier_free_formula(
                (smt_formula,), max_instantiations=max_instantiations
            )

            solutions: Dict[language.Constant, Set[int]] = {
                c: {
                    int(solution[cast(language.Constant, c)].value)
                    for solution in solver_result
                }
                for c in free_variables
            }
        except ValueError as exc:
            # Raise explicitly instead of `assert False`: With disabled assertions,
            # the `assert` would be skipped and the code below would fail with an
            # UnboundLocalError for `solutions`.
            raise AssertionError("Expected numeric solution.") from exc

        if not solutions:
            # No free variables, thus nothing to instantiate. Return an empty set
            # (as promised by the return type) instead of an implicit `None`.
            return set()

        if len(free_variables) == 1:
            return solutions[constant]

        assert all(len(solution) == 1 for solution in solutions.values())
        # In situations with multiple variables, we might have to abstract from
        # concrete values. Currently, we only support simple equality inference
        # (based on one sample...). Note that for supporting *more complex*
        # terms (e.g., additions), we would have to extend the whole
        # infrastructure: Substitutions with complex terms, and complex terms
        # in semantic predicate arguments, are unsupported as of now.
        candidates = {
            c
            for c in solutions
            if c != constant
            and next(iter(solutions[c])) == next(iter(solutions[constant]))
        }

        # Filter working candidates: Keep those for which the formula remains
        # solvable after replacing `constant` by the candidate.
        return {
            c
            for c in candidates
            if self.solve_quantifier_free_formula(
                (
                    cast(
                        language.SMTFormula,
                        smt_formula.substitute_variables({constant: c}),
                    ),
                ),
                max_instantiations=1,
            )
        }
1655

1656
    def eliminate_all_semantic_formulas(
        self, state: SolutionState, max_instantiations: Optional[int] = None
    ) -> Maybe[List[SolutionState]]:
        """
        Solves all SMT-LIB formulas that are top-level conjuncts of `state`'s
        constraint in one go.

        SMT-LIB formulas that are trivially true are ignored, and so are SMT-LIB
        formulas that do not occur as conjunctive elements (e.g., those nested inside
        a disjunction).

        :param state: The state in which to solve all SMT-LIB formulas.
        :param max_instantiations: The number of solutions the SMT solver should be
          asked for.
        :return: `Nothing` if there is no SMT-LIB conjunct to eliminate; otherwise,
          the result of `eliminate_semantic_formula` wrapped in `Some`.
        """

        all_conjuncts = split_conjunction(state.constraint)

        smt_conjuncts = []
        for conjunct in all_conjuncts:
            if not isinstance(conjunct, language.SMTFormula):
                continue
            if z3.is_true(conjunct.formula):
                continue
            smt_conjuncts.append(conjunct)

        if not smt_conjuncts:
            return Nothing

        self.logger.debug(
            "Eliminating semantic formulas [%s]", lazyjoin(", ", smt_conjuncts)
        )

        # The conjunction of all SMT-LIB formulas is moved to the front of the
        # constraint; the remaining conjuncts follow in their original order.
        smt_part = reduce(operator.and_, smt_conjuncts, sc.true())
        remaining_part = reduce(
            operator.and_,
            [conjunct for conjunct in all_conjuncts if conjunct not in smt_conjuncts],
            sc.true(),
        )
        reordered_state = SolutionState(smt_part & remaining_part, state.tree)

        return Some(
            self.eliminate_semantic_formula(
                smt_part, reordered_state, max_instantiations
            )
        )
1697

1698
    def eliminate_all_ready_semantic_predicate_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Evaluates the (possibly negated) semantic predicate formulas in `state`'s
        constraint and eliminates all those whose evaluation result is ready.

        Only predicate formulas without bound variables among their free variables
        are considered. They are processed in ascending order of a sort key: the
        predicate's `order` for positive formulas, and twice that `order` plus 100
        for negated ones.

        A ready Boolean result replaces the predicate formula by the corresponding
        SMT atom. Any other ready result is turned into a substitution (via
        `subtree_solutions`) that is applied to the constraint, to the tree, and to
        the predicate formulas that are still pending.

        :param state: The state whose semantic predicate formulas are evaluated.
        :return: `Nothing` if there is no candidate formula or if no candidate was
          ready; otherwise, a singleton list with the resulting state in a `Some`.
        """

        def is_negated_predicate(formula: language.Formula) -> bool:
            return isinstance(formula, language.NegatedFormula) and isinstance(
                formula.args[0], language.SemanticPredicateFormula
            )

        def sort_key(
            formula: language.NegatedFormula | language.SemanticPredicateFormula,
        ) -> int:
            if isinstance(formula, language.NegatedFormula):
                inner = cast(language.SemanticPredicateFormula, formula.args[0])
                return 2 * inner.predicate.order + 100

            return formula.predicate.order

        # The visitor does not continue below negated semantic predicate formulas.
        # Thus, the predicate formula inside a negation is not collected on its own.
        collected = language.FilterVisitor(
            lambda f: (
                isinstance(f, language.SemanticPredicateFormula)
                or is_negated_predicate(f)
            ),
            do_continue=lambda f: not is_negated_predicate(f),
        ).collect(state.constraint)

        pending: List[language.NegatedFormula | language.SemanticPredicateFormula] = [
            cast(language.NegatedFormula | language.SemanticPredicateFormula, formula)
            for formula in collected
            if not any(
                isinstance(var, language.BoundVariable)
                for var in formula.free_variables()
            )
        ]

        if not pending:
            return Nothing

        pending.sort(key=sort_key)

        current = state
        eliminated_any = False
        for idx, candidate in enumerate(pending):
            negated = isinstance(candidate, language.NegatedFormula)

            predicate_formula: language.SemanticPredicateFormula
            if negated:
                predicate_formula = cast(language.NegatedFormula, candidate).args[0]
            else:
                predicate_formula = candidate

            evaluation = predicate_formula.evaluate(self.graph, negate=negated)
            if not evaluation.ready():
                continue

            self.logger.debug(
                "Eliminating semantic predicate formula %s", predicate_formula
            )
            eliminated_any = True

            if evaluation.is_boolean():
                current = SolutionState(
                    language.replace_formula(
                        current.constraint,
                        predicate_formula,
                        language.smt_atom(evaluation.true()),
                    ),
                    current.tree,
                )
            else:
                substitution = subtree_solutions(evaluation.result)

                # The predicate formula itself (not the surrounding negation, if
                # any) is replaced; hence "false" in the negated case.
                replacement = sc.false() if negated else sc.true()
                substituted_constraint = language.replace_formula(
                    current.constraint, predicate_formula, replacement
                ).substitute_expressions(substitution)

                # Keep the formulas that are still to be processed in sync with
                # the updated tree.
                for later_idx in range(idx + 1, len(pending)):
                    pending[later_idx] = cast(
                        language.SemanticPredicateFormula,
                        pending[later_idx].substitute_expressions(substitution),
                    )

                current = SolutionState(
                    substituted_constraint, current.tree.substitute(substitution)
                )
                assert self.graph.tree_is_valid(current.tree)

        if eliminated_any:
            return Some([current])

        return Nothing
1✔
1795

1796
    def eliminate_and_match_first_existential_formula(
        self, state: SolutionState
    ) -> Optional[List[SolutionState]]:
        """
        Processes the first existential formula among the top-level conjuncts of
        `state`'s constraint.

        We produce up to two groups of output states: One where the first existential
        formula, if it can be matched, is matched, and one where the first existential
        formula is eliminated by tree insertion. Results equal to `state` itself are
        dropped from the combined output.

        :param state: The state to process.
        :return: `None` if no conjunct of the constraint is an existential formula;
          otherwise, the (possibly empty) list of successor states.
        """

        def do_eliminate(
            first_existential_formula_with_idx: Tuple[int, language.ExistsFormula]
        ) -> List[SolutionState]:
            formula_idx, existential_formula = first_existential_formula_with_idx

            first_matched = OrderedSet(
                self.match_existential_formula(formula_idx, state)
            )

            # Tree insertion can be deactivated by setting `self.tree_insertion_methods`
            # to 0.
            if not self.tree_insertion_methods:
                return list(first_matched)

            if first_matched:
                self.logger.debug(
                    "Matched first existential formulas, result: [%s]",
                    lazyjoin(
                        ", ",
                        [
                            # `s=s` binds the current state at definition time. The
                            # lambda is only called when the message is rendered;
                            # without the default argument, every entry would show
                            # the *last* matched state (late-binding closure).
                            lazystr(lambda s=s: f"{s} (hash={hash(s)})")
                            for s in first_matched
                        ],
                    ),
                )

            # Eliminate first existential formula by tree insertion.
            elimination_result = OrderedSet(
                self.eliminate_existential_formula(formula_idx, state)
            )

            # Drop insertion results that have the same tree as a matching result
            # and for which `result.constraint & -other_result.constraint` is
            # propositionally unsatisfiable, i.e., results already covered by a
            # matching result.
            elimination_result = OrderedSet(
                [
                    result
                    for result in elimination_result
                    if not any(
                        other_result.tree == result.tree
                        and self.propositionally_unsatisfiable(
                            result.constraint & -other_result.constraint
                        )
                        for other_result in first_matched
                    )
                ]
            )

            if not elimination_result and not first_matched:
                self.logger.warning(
                    "Existential qfr elimination: Could not eliminate existential formula %s "
                    "by matching or tree insertion",
                    existential_formula,
                )

            if elimination_result:
                self.logger.debug(
                    "Eliminated existential formula %s by tree insertion, %d successors",
                    existential_formula,
                    len(elimination_result),
                )

            return [
                result
                for result in first_matched | elimination_result
                if result != state
            ]

        return (
            result_to_maybe(
                safe(
                    lambda: next(
                        (idx, conjunct)
                        for idx, conjunct in enumerate(
                            split_conjunction(state.constraint)
                        )
                        if isinstance(conjunct, language.ExistsFormula)
                    )
                )()
            )
            .map(do_eliminate)
            .value_or(None)
        )
1884

1885
    def match_all_universal_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Matches the universal formulas among the top-level conjuncts of `state`'s
        constraint against `state`'s tree.

        :param state: The state whose universal formulas should be matched.
        :return: `Nothing` if the constraint has no universal formula as a conjunct
          or if `match_universal_formulas` does not produce any state; otherwise, the
          produced states wrapped in `Some`.
        """

        forall_conjuncts = list(
            filter(
                lambda conjunct: isinstance(conjunct, language.ForallFormula),
                split_conjunction(state.constraint),
            )
        )

        # Only attempt matching if there is something to match.
        successors = self.match_universal_formulas(state) if forall_conjuncts else []
        if not successors:
            return Nothing

        self.logger.debug(
            "Matched universal formulas [%s]", lazyjoin(", ", forall_conjuncts)
        )

        return Some(successors)
1✔
1906

1907
    def expand_tree(
        self,
        state: SolutionState,
        only_universal: bool = True,
        limit: Optional[int] = None,
    ) -> List[SolutionState]:
        """
        Expands open leaves of `state`'s tree, but only those that might be matched
        by a quantified formula among the conjuncts of `state`'s constraint (as
        determined by `quantified_formula_might_match`).

        :param state: The current state.
        :param only_universal: If set to True, only leaves that might match universal
          quantifiers are expanded. If set to False, leaves that might match any
          quantified formula (including existential ones) are expanded.
        :param limit: If set to a (nonzero) value, `limit` randomly chosen expansions
          are computed instead of all possible combinations of expansions.
        :return: A (possibly empty) list of states with expanded trees.
        """

        def is_relevant(formula: language.Formula) -> bool:
            if only_universal:
                return isinstance(formula, language.ForallFormula)

            return isinstance(formula, language.QuantifiedFormula)

        def might_be_matched(leaf_path: Path) -> bool:
            return any(
                self.quantified_formula_might_match(formula, leaf_path, state.tree)
                for formula in get_conjuncts(state.constraint)
                if is_relevant(formula)
            )

        def children_alternatives(nonterminal: str) -> List[List[DerivationTree]]:
            # One list of fresh child nodes per expansion alternative; nonterminal
            # children are open (children `None`), all others get an empty list.
            return [
                [
                    DerivationTree(symbol, None if is_nonterminal(symbol) else [])
                    for symbol in alternative
                ]
                for alternative in self.canonical_grammar[nonterminal]
            ]

        nonterminal_expansions: Dict[Path, List[List[DerivationTree]]] = {
            leaf_path: children_alternatives(leaf_node.value)
            for leaf_path, leaf_node in state.tree.open_leaves()
            if might_be_matched(leaf_path)
        }

        if not nonterminal_expansions:
            return []

        possible_expansions: List[Dict[Path, List[DerivationTree]]]
        if limit:
            # Sample `limit` combinations, choosing one alternative per leaf.
            possible_expansions = [
                {
                    path: random.choice(alternatives)
                    for path, alternatives in nonterminal_expansions.items()
                    if alternatives
                }
                for _ in range(limit)
            ]
        else:
            # Consider all combinations of alternatives.
            possible_expansions = dict_of_lists_to_list_of_dicts(nonterminal_expansions)
            assert len(possible_expansions) == math.prod(
                len(values) for values in nonterminal_expansions.values()
            )

        # This replaces a previous `if` statement with the negated condition as guard,
        # which seems to be dead code (the guard can never hold true due to the check
        # of emptiness of `nonterminal_expansions` above). We keep this assertion here
        # to be sure.
        assert (
            len(possible_expansions) > 1
            or len(possible_expansions) == 1
            and possible_expansions[0]
        )

        result: List[SolutionState] = []
        for chosen_expansion in possible_expansions:
            expanded_tree = state.tree
            for path, new_children in chosen_expansion.items():
                old_leaf = expanded_tree.get_subtree(path)
                # The expanded node keeps the value and ID of the original leaf.
                expanded_tree = expanded_tree.replace_path(
                    path, DerivationTree(old_leaf.value, new_children, old_leaf.id)
                )

                assert expanded_tree is not state.tree
                assert expanded_tree != state.tree
                assert expanded_tree.structural_hash() != state.tree.structural_hash()

            # Map all trees on the paths from the root to the expanded leaves to
            # their counterparts in the expanded tree.
            tree_substitution = {}
            for path in chosen_expansion:
                for prefix_length in range(len(path) + 1):
                    prefix = path[:prefix_length]
                    tree_substitution[state.tree.get_subtree(prefix)] = (
                        expanded_tree.get_subtree(prefix)
                    )

            result.append(
                SolutionState(
                    state.constraint.substitute_expressions(tree_substitution),
                    expanded_tree,
                )
            )

        assert not limit or len(result) <= limit
        return result
1✔
2000

2001
    def match_universal_formulas(self, state: SolutionState) -> List[SolutionState]:
        """
        Instantiates the universal formulas among the top-level conjuncts of
        `state`'s constraint for all matches that have not been matched before.

        For each universal formula with new matches, the instantiations of its inner
        formula are collected, and the formula is replaced in the conjuncts by a
        version that records the new matches as already matched.

        :param state: The state whose universal formulas should be matched.
        :return: A singleton list with the state whose constraint is the conjunction
          of the new instantiations and the updated conjuncts (the tree is unchanged),
          or an empty list if there was no new match.
        """

        instantiated_formulas: List[language.Formula] = []
        conjuncts = split_conjunction(state.constraint)

        for idx, universal_formula in enumerate(conjuncts):
            if not isinstance(universal_formula, language.ForallFormula):
                continue

            bound_variable = universal_formula.bound_variable
            matches: List[Dict[language.Variable, Tuple[Path, DerivationTree]]] = [
                match
                for match in matches_for_quantified_formula(
                    universal_formula, self.grammar
                )
                if not universal_formula.is_already_matched(match[bound_variable][1])
            ]

            if not matches:
                # Nothing new for this formula; the conjunct stays as it is.
                continue

            universal_formula_with_matches = universal_formula.add_already_matched(
                {match[bound_variable][1] for match in matches}
            )

            # Updating the conjunct does not depend on the individual match, so we
            # do it once per formula (not once per match).
            conjuncts = list_set(conjuncts, idx, universal_formula_with_matches)

            instantiated_formulas.extend(
                universal_formula_with_matches.inner_formula.substitute_expressions(
                    {variable: match_tree for variable, (_, match_tree) in match.items()}
                )
                for match in matches
            )

        if not instantiated_formulas:
            return []

        return [
            SolutionState(
                sc.conjunction(*instantiated_formulas) & sc.conjunction(*conjuncts),
                state.tree,
            )
        ]
1✔
2045

2046
    def match_existential_formula(
        self, existential_formula_idx: int, state: SolutionState
    ) -> List[SolutionState]:
        """
        Instantiates the existential formula at position `existential_formula_idx`
        of the top-level conjuncts of `state`'s constraint for each of its matches.

        :param existential_formula_idx: The index of the existential formula in the
          conjuncts of `state`'s constraint.
        :param state: The state containing the existential formula.
        :return: One state per match. Its constraint is the instantiated inner
          formula conjoined with all other conjuncts; the tree is left unchanged.
        """

        conjuncts: ImmutableList[language.Formula] = tuple(
            split_conjunction(state.constraint)
        )
        existential_formula = cast(
            language.ExistsFormula, conjuncts[existential_formula_idx]
        )

        def instantiate(
            match: Dict[language.Variable, Tuple[Path, DerivationTree]]
        ) -> SolutionState:
            instantiated = existential_formula.inner_formula.substitute_expressions(
                {variable: match_tree for variable, (_, match_tree) in match.items()}
            )
            # The existential formula itself is removed from the constraint.
            other_conjuncts = sc.conjunction(
                *list_del(conjuncts, existential_formula_idx)
            )
            return SolutionState(instantiated & other_conjuncts, state.tree)

        matches: List[Dict[language.Variable, Tuple[Path, DerivationTree]]] = (
            matches_for_quantified_formula(existential_formula, self.grammar)
        )

        return [instantiate(match) for match in matches]
1✔
2072

2073
    def eliminate_existential_formula(
1✔
2074
        self, existential_formula_idx: int, state: SolutionState
2075
    ) -> List[SolutionState]:
2076
        conjuncts: ImmutableList[language.Formula] = tuple(
1✔
2077
            split_conjunction(state.constraint)
2078
        )
2079
        existential_formula = cast(
1✔
2080
            language.ExistsFormula, conjuncts[existential_formula_idx]
2081
        )
2082

2083
        inserted_trees_and_bind_paths = (
1✔
2084
            [(DerivationTree(existential_formula.bound_variable.n_type, None), {})]
2085
            if existential_formula.bind_expression is None
2086
            else (
2087
                existential_formula.bind_expression.to_tree_prefix(
2088
                    existential_formula.bound_variable.n_type, self.grammar
2089
                )
2090
            )
2091
        )
2092

2093
        result: List[SolutionState] = []
1✔
2094

2095
        inserted_tree: DerivationTree
2096
        bind_expr_paths: Dict[language.BoundVariable, Path]
2097
        for inserted_tree, bind_expr_paths in inserted_trees_and_bind_paths:
1✔
2098
            self.logger.debug(
1✔
2099
                "process_insertion_results(self.canonical_grammar, %s, %s, self.graph, %s)",
2100
                lazystr(lambda: repr(inserted_tree)),
2101
                lazystr(lambda: repr(existential_formula.in_variable)),
2102
                self.max_number_tree_insertion_results,
2103
            )
2104

2105
            insertion_results = insert_tree(
1✔
2106
                self.canonical_grammar,
2107
                inserted_tree,
2108
                existential_formula.in_variable,
2109
                graph=self.graph,
2110
                max_num_solutions=self.max_number_tree_insertion_results * 2,
2111
                methods=self.tree_insertion_methods,
2112
            )
2113

2114
            insertion_results = sorted(
1✔
2115
                insertion_results,
2116
                key=lambda t: compute_tree_closing_cost(t, self.graph),
2117
            )
2118
            insertion_results = insertion_results[
1✔
2119
                : self.max_number_tree_insertion_results
2120
            ]
2121

2122
            for insertion_result in insertion_results:
1✔
2123
                replaced_path = state.tree.find_node(existential_formula.in_variable)
1✔
2124
                resulting_tree = state.tree.replace_path(
1✔
2125
                    replaced_path, insertion_result
2126
                )
2127

2128
                tree_substitution: Dict[DerivationTree, DerivationTree] = {}
1✔
2129
                for idx in range(len(replaced_path) + 1):
1✔
2130
                    original_path = replaced_path[: idx - 1]
1✔
2131
                    original_tree = state.tree.get_subtree(original_path)
1✔
2132
                    if (
1✔
2133
                        resulting_tree.is_valid_path(original_path)
2134
                        and original_tree.value
2135
                        == resulting_tree.get_subtree(original_path).value
2136
                        and resulting_tree.get_subtree(original_path) != original_tree
2137
                    ):
2138
                        tree_substitution[original_tree] = resulting_tree.get_subtree(
1✔
2139
                            original_path
2140
                        )
2141

2142
                assert insertion_result.find_node(inserted_tree) is not None
1✔
2143
                variable_substitutions = {
1✔
2144
                    existential_formula.bound_variable: inserted_tree
2145
                }
2146

2147
                if bind_expr_paths:
1✔
2148
                    if assertions_activated():
1✔
2149
                        dangling_bind_expr_vars = [
1✔
2150
                            (var, path)
2151
                            for var, path in bind_expr_paths.items()
2152
                            if (
2153
                                var
2154
                                in existential_formula.bind_expression.bound_variables()
2155
                                and insertion_result.find_node(
2156
                                    inserted_tree.get_subtree(path)
2157
                                )
2158
                                is None
2159
                            )
2160
                        ]
2161
                        assert not dangling_bind_expr_vars, (
1✔
2162
                            f"Bound variables from match expression not found in tree: "
2163
                            f"[{' ,'.join(map(repr, dangling_bind_expr_vars))}]"
2164
                        )
2165

2166
                    variable_substitutions.update(
1✔
2167
                        {
2168
                            var: inserted_tree.get_subtree(path)
2169
                            for var, path in bind_expr_paths.items()
2170
                            if var
2171
                            in existential_formula.bind_expression.bound_variables()
2172
                        }
2173
                    )
2174

2175
                instantiated_formula = (
1✔
2176
                    existential_formula.inner_formula.substitute_expressions(
2177
                        variable_substitutions
2178
                    ).substitute_expressions(tree_substitution)
2179
                )
2180

2181
                instantiated_original_constraint = sc.conjunction(
1✔
2182
                    *list_del(conjuncts, existential_formula_idx)
2183
                ).substitute_expressions(tree_substitution)
2184

2185
                new_tree = resulting_tree.substitute(tree_substitution)
1✔
2186

2187
                new_formula = (
1✔
2188
                    instantiated_formula
2189
                    & self.formula.substitute_expressions(
2190
                        {self.top_constant.unwrap(): new_tree}
2191
                    )
2192
                    & instantiated_original_constraint
2193
                )
2194

2195
                new_state = SolutionState(new_formula, new_tree)
1✔
2196

2197
                assert all(
1✔
2198
                    new_state.tree.find_node(tree) is not None
2199
                    for quantified_formula in split_conjunction(new_state.constraint)
2200
                    if isinstance(quantified_formula, language.QuantifiedFormula)
2201
                    for _, tree in quantified_formula.in_variable.filter(lambda t: True)
2202
                )
2203

2204
                if assertions_activated() or self.debug:
1✔
2205
                    lost_tree_predicate_arguments: List[DerivationTree] = [
1✔
2206
                        arg
2207
                        for invstate in self.establish_invariant(new_state)
2208
                        for predicate_formula in get_conjuncts(invstate.constraint)
2209
                        if isinstance(
2210
                            predicate_formula, language.StructuralPredicateFormula
2211
                        )
2212
                        for arg in predicate_formula.args
2213
                        if isinstance(arg, DerivationTree)
2214
                        and invstate.tree.find_node(arg) is None
2215
                    ]
2216

2217
                    if lost_tree_predicate_arguments:
1✔
2218
                        previous_posititions = [
×
2219
                            state.tree.find_node(t)
2220
                            for t in lost_tree_predicate_arguments
2221
                        ]
2222
                        assert False, (
×
2223
                            f"Dangling subtrees [{', '.join(map(repr, lost_tree_predicate_arguments))}], "
2224
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2225
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2226
                        )
2227

2228
                    lost_semantic_formula_arguments = [
1✔
2229
                        arg
2230
                        for invstate in self.establish_invariant(new_state)
2231
                        for semantic_formula in get_conjuncts(new_state.constraint)
2232
                        if isinstance(semantic_formula, language.SMTFormula)
2233
                        for arg in semantic_formula.substitutions.values()
2234
                        if invstate.tree.find_node(arg) is None
2235
                    ]
2236

2237
                    if lost_semantic_formula_arguments:
1✔
2238
                        previous_posititions = [
×
2239
                            state.tree.find_node(t)
2240
                            for t in lost_semantic_formula_arguments
2241
                        ]
2242
                        previous_posititions = [
×
2243
                            p for p in previous_posititions if p is not None
2244
                        ]
2245
                        assert False, (
×
2246
                            f"Dangling subtrees [{', '.join(map(repr, lost_semantic_formula_arguments))}], "
2247
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2248
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2249
                        )
2250

2251
                result.append(new_state)
1✔
2252

2253
        return result
1✔
2254

2255
    def eliminate_semantic_formula(
        self,
        semantic_formula: language.Formula,
        state: SolutionState,
        max_instantiations: Optional[int] = None,
    ) -> Optional[List[SolutionState]]:
        """
        Solves a semantic formula with the SMT solver and creates one successor
        state per solution by substituting the solution into both the constraint
        and the tree of :code:`state`. The subtrees of substituted elements are
        instantiated as well (via :code:`subtree_solutions`). If the formula is
        unsolvable, the returned list is empty.

        :param semantic_formula: The semantic formula to solve. All its conjuncts
            must be SMT formulas.
        :param state: The solution state to instantiate.
        :param max_instantiations: The approximate overall number of solutions to
            ask the SMT solver for. Falls back to
            :code:`self.max_number_smt_instantiations` if not given.
        :return: A list of instantiated solution states.
        """

        assert all(
            isinstance(conjunct, language.SMTFormula)
            for conjunct in get_conjuncts(semantic_formula)
        )

        # NOTE: SMT formulas are clustered by their tree substitutions. Two formulas
        # with a variable $var instantiated to *different* trees require separate
        # solutions; if $var is instantiated with the *same* tree, both formulas
        # have to be solved together.
        smt_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(semantic_formula)
            if isinstance(conjunct, language.SMTFormula)
        ]
        smt_formulas = self.rename_instantiated_variables_in_smt_formulas(
            smt_conjuncts
        )

        # Formulas are additionally clustered by common variables and instantiated
        # subtrees (one formula might instantiate a subtree of the instantiation of
        # another one; those must share a cluster). Clusters without common
        # elements can be solved independently, which keeps the individual solver
        # queries small.
        def cluster_keys(smt_formula: language.SMTFormula):
            instantiated_subtrees = {
                subtree
                for tree in smt_formula.substitutions.values()
                for _, subtree in tree.paths()
            }

            return (
                smt_formula.free_variables()
                | smt_formula.instantiated_variables
                | instantiated_subtrees
            )

        formula_clusters: List[List[language.SMTFormula]] = cluster_by_common_elements(
            smt_formulas, cluster_keys
        )

        # Every formula that has at least one cluster key must be in some cluster.
        assert all(
            not cluster_keys(smt_formula)
            or any(smt_formula in cluster for cluster in formula_clusters)
            for smt_formula in smt_formulas
        )

        formula_clusters = list(filter(None, formula_clusters))

        # Formulas without any cluster key are collected in one additional cluster.
        keyless_formulas = [
            smt_formula for smt_formula in smt_formulas if not cluster_keys(smt_formula)
        ]
        if keyless_formulas:
            formula_clusters.append(keyless_formulas)

        # Note: The solutions of all clusters are combined to a product below.
        #       Asking for `max_instantiations` solutions *per cluster* would thus
        #       yield up to 10^4 combinations for 10 instantiations and 4 clusters.
        #       Instead, we ask each cluster for the ceiling of the
        #       #numCluster'th root of the requested number. For example, the
        #       ceiling of the 4th root of 10 is 2, and 2^4 is 16, which is still
        #       within an acceptable range.
        requested_instantiations = (
            max_instantiations or self.max_number_smt_instantiations
        )
        solutions_per_cluster = math.ceil(
            requested_instantiations ** (1 / len(formula_clusters))
        )

        all_solutions: List[
            List[Dict[Union[language.Constant, DerivationTree], DerivationTree]]
        ] = []
        for cluster in formula_clusters:
            all_solutions.append(
                self.solve_quantifier_free_formula(
                    tuple(cluster),
                    solutions_per_cluster,
                )
            )

        # The clusters are independent of each other; thus, each solution of one
        # cluster can be merged with each solution of all other clusters.
        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = [
            functools.reduce(operator.or_, combination)
            for combination in itertools.product(*all_solutions)
        ]

        results = []
        for combined_solution in solutions:
            # We also have to instantiate all subtrees of the substituted element.
            solution = subtree_solutions(combined_solution)

            if not solution:
                # Empty assignment: Only replace the solved formula by "true."
                results.append(
                    SolutionState(
                        language.replace_formula(
                            state.constraint, semantic_formula, sc.true()
                        ),
                        state.tree,
                    )
                )
                continue

            results.append(
                SolutionState(
                    state.constraint.substitute_expressions(solution),
                    state.tree.substitute(solution),
                )
            )

        return results
2380

2381
    @lru_cache(100)
    def solve_quantifier_free_formula(
        self,
        smt_formulas: ImmutableList[language.SMTFormula],
        max_instantiations: Optional[int] = None,
    ) -> List[Dict[language.Constant | DerivationTree, DerivationTree]]:
        """
        Attempts to solve the given SMT-LIB formulas by calling Z3, collecting up
        to :code:`max_instantiations` distinct solutions.

        Note that this function does not unify variables pointing to the same derivation
        trees. E.g., a solution may be returned for the formula `var_1 = "a" and
        var_2 = "b"`, though `var_1` and `var_2` point to the same `<var>` tree as
        defined by their substitutions maps. Unification is performed in
        `eliminate_all_semantic_formulas`.

        Results are memoized by the :code:`lru_cache` decorator; the arguments thus
        have to be hashable, and the returned list is the cached object itself.

        :param smt_formulas: The SMT-LIB formulas to solve.
        :param max_instantiations: The maximum number of instantiations to produce;
            falls back to :code:`self.max_number_smt_instantiations` if not given.
        :return: A (possibly empty) list of solutions.
        """

        # Formulas referring to *sub*trees in the instantiations of other SMT
        # formulas have to be instantiated first; if there are any, we only solve
        # those in this call.
        priority_formulas = smt_formulas_referring_to_subtrees(smt_formulas)
        if priority_formulas:
            smt_formulas = priority_formulas
            assert not smt_formulas_referring_to_subtrees(smt_formulas)

        # Merge the substitution maps of all formulas.
        tree_substitutions = reduce(
            operator.or_,
            [smt_formula.substitutions for smt_formula in smt_formulas],
            {},
        )

        # All free and instantiated variables of all formulas.
        constants = reduce(
            operator.or_,
            [
                smt_formula.free_variables() | smt_formula.instantiated_variables
                for smt_formula in smt_formulas
            ],
            set(),
        )

        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = []
        # Solutions in terms of the original constants, passed to the solver such
        # that it excludes them in subsequent queries.
        internal_solutions: List[Dict[language.Constant, z3.StringVal]] = []

        num_instantiations = max_instantiations or self.max_number_smt_instantiations
        for _ in range(num_instantiations):
            solver_result, maybe_model = (
                self.solve_smt_formulas_with_language_constraints(
                    constants,
                    tuple(smt_formula.formula for smt_formula in smt_formulas),
                    tree_substitutions,
                    internal_solutions,
                )
            )

            if solver_result != z3.sat:
                # No (further) solution: Return what we collected so far.
                return solutions or []

            assert maybe_model is not None

            # Solutions are keyed by the substituted tree, if there is one for the
            # constant, and by the constant itself otherwise.
            new_solution = {
                tree_substitutions.get(constant, constant): maybe_model[constant]
                for constant in constants
            }

            new_internal_solution = {
                constant: z3.StringVal(str(maybe_model[constant]))
                for constant in constants
            }

            if new_solution in solutions:
                # This can happen for trivial solutions, i.e., if the formula is
                # logically valid. Then, the assignment for that constant will
                # always be {}
                return solutions

            solutions.append(new_solution)

            if not new_internal_solution:
                # Again, for a trivial solution (e.g., True), the assignment
                # can be empty; there is nothing to exclude in further queries.
                break

            internal_solutions.append(new_internal_solution)

        return solutions
2474

2475
    def solve_smt_formulas_with_language_constraints(
        self,
        variables: Set[language.Variable],
        smt_formulas: ImmutableList[z3.BoolRef],
        tree_substitutions: Dict[language.Variable, DerivationTree],
        solutions_to_exclude: List[Dict[language.Variable, z3.StringVal]],
    ) -> Tuple[z3.CheckSatResult, Dict[language.Variable, DerivationTree]]:
        """
        Solves the given Z3 formulas together with language constraints for the
        variables and the negations of previously found solutions.

        :param variables: The variables to obtain a solution for.
        :param smt_formulas: The Z3 formulas to solve.
        :param tree_substitutions: A map from variables to the trees they are
            instantiated with.
        :param solutions_to_exclude: Previously found solutions which must not be
            returned again.
        :return: The result of the satisfiability check along with a map from each
            variable in :code:`variables` to the value extracted from the model.
            The map is empty if the result is not :code:`z3.sat`.
        """

        # Optimized Z3 queries are disabled if the SMT formulas contain "too
        # concrete" substitutions, that is, substitutions with a tree that is not
        # merely an open leaf (i.e., a tree with children). For such a tree, only
        # particular lengths / values remain possible. One could finish off the
        # tree and check the constraint, or restrict the custom tree generation,
        # but the general SMT encoding is more robust.
        use_optimized_queries = self.enable_optimized_z3_queries and not any(
            substitution.children for substitution in tree_substitutions.values()
        )

        if use_optimized_queries:
            contexts = self.infer_variable_contexts(variables, smt_formulas)
            length_vars = contexts["length"]
            int_vars = contexts["int"]
            flexible_vars = contexts["flexible"]
        else:
            length_vars, int_vars, flexible_vars = set(), set(), set(variables)

        # The variables abstracted by fresh integer symbols below.
        abstracted_vars = length_vars | int_vars

        # Add language constraints for "flexible" variables
        formulas: List[z3.BoolRef] = self.generate_language_constraints(
            flexible_vars, tree_substitutions
        )

        # Create fresh variables for `str.len` and `str.to.int` variables.
        all_variables = set(variables)
        fresh_var_map: Dict[language.Variable, z3.ExprRef] = {}
        for var in abstracted_vars:
            fresh = fresh_constant(
                all_variables,
                language.Constant(var.name, "NOT-NEEDED"),
            )
            fresh_var_map[var] = z3.Int(fresh.name)

        # In `smt_formulas`, all `str.len(...)` and `str.to.int(...)` terms over
        # abstracted variables are replaced with the corresponding fresh variable.
        abstracted_smt_vars = {var.to_smt() for var in abstracted_vars}
        abstracted_ops = {z3.Z3_OP_SEQ_LENGTH, z3.Z3_OP_STR_TO_INT}

        replacement_map: Dict[z3.ExprRef, z3.ExprRef] = {}
        for formula in smt_formulas:
            for expr in visit_z3_expr(formula):
                if expr.decl().kind() not in abstracted_ops:
                    continue

                argument = expr.children()[0]
                if argument not in abstracted_smt_vars:
                    continue

                replacement_map[expr] = fresh_var_map[
                    get_elem_by_equivalence(
                        argument,
                        abstracted_vars,
                        lambda e1, e2: e1 == e2.to_smt(),
                    )
                ]

        # Perform substitution, add formulas
        substituted_formulas = [
            cast(z3.BoolRef, z3_subst(formula, replacement_map))
            for formula in smt_formulas
        ]
        formulas.extend(substituted_formulas)

        # Lengths must be non-negative
        length_constraints = [
            cast(
                z3.BoolRef,
                replacement_map[z3.Length(length_var.to_smt())] >= z3.IntVal(0),
            )
            for length_var in length_vars
        ]
        formulas.extend(length_constraints)

        def interval_constraint(int_expr, interval):
            # Bounds at (or beyond) +/- `sys.maxsize` are treated as unbounded.
            lower_bound = (
                int_expr >= z3.IntVal(interval[0])
                if interval[0] > -sys.maxsize
                else z3.BoolVal(True)
            )
            upper_bound = (
                int_expr <= z3.IntVal(interval[1])
                if interval[1] < sys.maxsize
                else z3.BoolVal(True)
            )
            return z3.And(lower_bound, upper_bound)

        # Add custom intervals for int variables
        for int_var in int_vars:
            if int_var.n_type == language.Variable.NUMERIC_NTYPE:
                # "NUM" variables range over the full int domain
                continue

            regex = self.extract_regular_expression(int_var.n_type)
            maybe_intervals = numeric_intervals_from_regex(regex)
            repl_var = replacement_map[z3.StrToInt(int_var.to_smt())]

            # If intervals could be computed, the value must lie in one of them.
            maybe_intervals.map(
                tap(
                    lambda intervals: formulas.append(
                        z3_or(
                            [
                                interval_constraint(repl_var, interval)
                                for interval in intervals
                            ]
                        )
                    )
                )
            )

        # Exclude all previously found solutions.
        for prev_solution in solutions_to_exclude:
            equations = [
                self.previous_solution_formula(
                    var, string_val, fresh_var_map, length_vars, int_vars
                )
                for var, string_val in prev_solution.items()
            ]
            formulas.append(z3.Not(z3_and(equations)))

        sat_result, maybe_model = z3_solve(formulas)

        if sat_result != z3.sat:
            return sat_result, {}

        assert maybe_model is not None

        model_values = {
            var: self.extract_model_value(
                var, maybe_model, fresh_var_map, length_vars, int_vars
            )
            for var in variables
        }

        return sat_result, model_values
2610

2611
    @staticmethod
    def previous_solution_formula(
        var: language.Variable,
        string_val: z3.StringVal,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> z3.BoolRef:
        """
        Builds an equation expressing the previously found solution
        :code:`var == string_val` for an :class:`~isla.language.SMTFormula`.
        If :code:`var` is an "int" or a "length" variable (i.e., it only occurred
        in such contexts in the solved formula), the equation is stated over the
        fresh integer symbol for :code:`var` instead.

        >>> x = language.Variable("x", "<X>")
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {}, set(), set())
        x == "val"

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {x: z3.Int("x_0")}, {x}, set())
        x_0 == 3

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        A "numeric" variable (of "NUM" type) is expected to always be an int variable,
        which also needs to be reflected in its inclusion in :code:`fresh_var_map`.

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {}, set(), set())
        Traceback (most recent call last):
        ...
        AssertionError

        :param var: The variable the solution is for.
        :param string_val: The solution for :code:`var`.
        :param fresh_var_map: A map from variables to fresh variables for "length" or
                              "int" variables.
        :param length_vars: The "length" variables.
        :param int_vars: The "int" variables.
        :return: An equation describing the previous solution.
        """

        # "Int" variables take precedence over "length" variables.
        if var in int_vars:
            fresh_var = fresh_var_map[var]
            int_solution = int(smt_string_val_to_string(string_val))
            return z3_eq(fresh_var, z3.IntVal(int_solution))

        if var in length_vars:
            fresh_var = fresh_var_map[var]
            length_solution = len(smt_string_val_to_string(string_val))
            return z3_eq(fresh_var, z3.IntVal(length_solution))

        # Numeric variables must have been handled as "int" variables above.
        assert not var.is_numeric()
        return z3_eq(var.to_smt(), string_val)
2676

2677
    def safe_create_fixed_length_tree(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
    ) -> DerivationTree:
        """
        Creates a :class:`~isla.derivation_tree.DerivationTree` whose type is the
        type of :code:`var` and whose string representation has the length that
        :code:`model` assigns to the fresh variable for :code:`var` in
        :code:`fresh_var_map`. For example:

        >>> grammar = {
        ...     "<start>": ["<X>"],
        ...     "<X>": ["x", "x<X>"],
        ... }
        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> solver = ISLaSolver(grammar)
        >>> tree = solver.safe_create_fixed_length_tree(x, model, {x: x_0})
        >>> tree.value
        '<X>'
        >>> str(tree)
        'xxxxx'

        :param var: The variable to create a
                    :class:`~isla.derivation_tree.DerivationTree` object for.
        :param model: The Z3 model to extract a solution to the length constraint.
        :param fresh_var_map: A map including a mapping :code:`var` -> :code:`var_0`,
                              where :code:`var_0` is an integer-valued variable
                              included in :code:`model`.
        :return: A tree of the type of :code:`var` and length as specified in
                :code:`model`.
        :raises RuntimeError: If no tree of the required length could be created.
        """

        assert var in fresh_var_map
        length_symbol = fresh_var_map[var]
        assert length_symbol.decl() in model.decls()

        # The length Z3 chose for the fresh integer symbol standing for `var`.
        target_length = model[length_symbol].as_long()

        fixed_length_tree = create_fixed_length_tree(
            start=var.n_type,
            canonical_grammar=self.canonical_grammar,
            target_length=target_length,
        )

        if fixed_length_tree is not None:
            return fixed_length_tree

        raise RuntimeError(
            f"Could not create a tree with the start symbol '{var.n_type}' "
            f"of length {target_length}; try "
            "running the solver without optimized Z3 queries or make "
            "sure that lengths are restricted to syntactically valid "
            "ones (according to the grammar)."
        )
2737

2738
    def extract_model_value(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        r"""
        Extracts a value for :code:`var` from :code:`model`. The extraction is
        delegated to a chain of specialized extraction methods, each receiving the
        next one as a fallback, in the order numeric -> "length" -> "int" ->
        "flexible." The following special cases are considered:

        Numeric Variables
            Returns a closed derivation tree of one node with a string representation
            of the numeric solution.

        "Length" Variables
            Returns a string of the length corresponding to the model and
            :code:`fresh_var_map`, see also
            :meth:`~isla.solver.ISLaSolver.safe_create_fixed_length_tree()`.

        "Int" Variables
            Tries to parse the numeric solution from the model (obtained via
            :code:`fresh_var_map`) into the type of :code:`var` and returns the
            corresponding derivation tree.

        >>> grammar = {
        ...     "<start>": ["<A>"],
        ...     "<A>": ["<X><Y>"],
        ...     "<X>": ["x", "x<X>"],
        ...     "<Y>": ["<digit>", "<digit><Y>"],
        ...     "<digit>": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
        ... }
        >>> solver = ISLaSolver(grammar)

        **Numeric Variables:**

        >>> n = language.Variable("n", language.Variable.NUMERIC_NTYPE)
        >>> f = z3_eq(z3.StrToInt(n.to_smt()), z3.IntVal(15))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(n, model, {}, set(), set())
        DerivationTree('15', (), id=1)

        For a trivially true solution on numeric variables, we return a random number:

        >>> f = z3_eq(n.to_smt(), n.to_smt())
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat

        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> random.seed(0)
        >>> solver.extract_model_value(n, model, {n: n.to_smt()}, set(), {n})
        DerivationTree('-2116850434379610162', (), id=1)

        **"Length" Variables:**

        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(3))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {x: x_0}, {x}, set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxx'

        **"Int" Variables:**

        >>> y = language.Variable("y", "<Y>")
        >>> y_0 = z3.Int("y_0")
        >>> f = z3_eq(y_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(y, model, {y: y_0}, set(), {y})
        DerivationTree('<Y>', (DerivationTree('<digit>', (DerivationTree('5', (), id=1),), id=2),), id=3)

        **"Flexible" Variables:**

        >>> f = z3_eq(x.to_smt(), z3.StringVal("xxxxx"))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {}, set(), set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxxxx'

        **Special Number Formats**

        It may happen that the solver returns, e.g., "1" as a solution for an int
        variable, but the grammar only recognizes "+001". We support this case, i.e.,
        an optional "+" and optional 0 padding; an optional 0 padding for negative
        numbers is also supported.

        >>> grammar = {
        ...     "<start>": ["<int>"],
        ...     "<int>": ["<sign>00<leaddigit><digits>"],
        ...     "<sign>": ["-", "+"],
        ...     "<digits>": ["", "<digit><digits>"],
        ...     "<digit>": list("0123456789"),
        ...     "<leaddigit>": list("123456789"),
        ... }
        >>> solver = ISLaSolver(grammar)

        >>> i = language.Variable("i", "<int>")
        >>> i_0 = z3.Int("i_0")
        >>> f = z3_eq(i_0, z3.IntVal(5))

        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat

        The following test works when run from the IDE, but unfortunately not when
        started from CI/the `test_doctests.py` script. Thus, we only show it as code
        block (we added a unit test as a replacement)::

            model = z3_solver.model()
            print(solver.extract_model_value(i, model, {i: i_0}, set(), {i}))
            # Prints: +005

        :param var: The variable for which to extract a solution from the model.
        :param model: The model containing the solution.
        :param fresh_var_map: A map from variables to fresh symbols for "length" and
                              "int" variables.
        :param length_vars: The set of "length" variables.
        :param int_vars: The set of "int" variables.
        :return: A :class:`~isla.derivation_tree.DerivationTree` object corresponding
                 to the solution in :code:`model`.
        """

        # Build the chain from the innermost fallback outwards: Each extractor is
        # bound to the previously built one as its fallback (first argument).
        extractor = self.extract_model_value_flexible_var
        for specialized_extractor in (
            self.extract_model_value_int_var,
            self.extract_model_value_length_var,
            self.extract_model_value_numeric_var,
        ):
            extractor = partial(specialized_extractor, extractor)

        return extractor(var, model, fresh_var_map, length_vars, int_vars)
2894

2895
    # Common signature of the chained ``extract_model_value_*`` handlers (after
    # their ``fallback`` argument has been bound): the positional arguments are
    # the ones forwarded by ``extract_model_value``.
    ExtractModelValueFallbackType = Callable[
        [
            language.Variable,  # var
            z3.ModelRef,  # model
            Dict[language.Variable, z3.ExprRef],  # fresh_var_map
            Set[language.Variable],  # length_vars
            Set[language.Variable],  # int_vars
        ],
        DerivationTree,
    ]
2905

2906
    def extract_model_value_numeric_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles numeric variables on behalf of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        The value is looked up under the variable's own string constant if the
        model declares it, and under its entry in :code:`fresh_var_map` otherwise.
        The resulting (possibly negative) number is wrapped in a childless
        derivation tree.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """
        if not var.is_numeric():
            return fallback(var, model, fresh_var_map, length_vars, int_vars)

        string_const = z3.String(var.name)
        if string_const.decl() not in model.decls():
            # Without a string constant in the model, the variable must have been
            # translated to a fresh integer symbol.
            assert var in int_vars
            assert var in fresh_var_map

            solution = model[fresh_var_map[var]]

            if solution is None:
                # This can happen for universally true formulas, e.g., `x = x`.
                # In that case, we return a random integer.
                solution = z3.IntVal(random.randint(-sys.maxsize, sys.maxsize))
        else:
            solution = model[string_const]

        assert (
            solution is not None
        ), f"No solution for variable {var} found in model {model}"

        number_text = smt_string_val_to_string(solution)
        assert number_text

        # Accept plain digits, optionally preceded by a single minus sign.
        unsigned_text = number_text[1:] if number_text[0] == "-" else number_text
        assert unsigned_text.isnumeric()

        return DerivationTree(number_text, ())
2957

2958
    def extract_model_value_length_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles length variables on behalf of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        Variables contained in :code:`length_vars` are delegated to
        :code:`safe_create_fixed_length_tree`; all others go to :code:`fallback`.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """
        if var in length_vars:
            return self.safe_create_fixed_length_tree(var, model, fresh_var_map)

        return fallback(var, model, fresh_var_map, length_vars, int_vars)
2983

2984
    def extract_model_value_int_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles int variables on behalf of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        The integer found in the model is first parsed as-is. If the grammar
        rejects that spelling, Z3 is asked for an optional "+" sign and a zero
        padding that make the number a word of the variable's nonterminal.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :raises RuntimeError: If the model value is not a number, or if no
            supported spelling of the number is accepted by the grammar.
        """
        if var not in int_vars:
            return fallback(var, model, fresh_var_map, length_vars, int_vars)

        solver_literal = model[fresh_var_map[var]].as_string()

        try:
            number = int(solver_literal)
        except ValueError:
            raise RuntimeError(f"Value {solver_literal} for {var} is not a number")

        nonterminal = var.n_type

        try:
            return self.parse(str(number), nonterminal, silent=True)
        except SyntaxError:
            # This may happen, e.g, with padded values: Only "01" is a valid
            # solution, but not "1". Similarly, a grammar may expect "+1", but
            # "1" is returned by the solver. We support the number format
            # `[+-]0*<digits>`. Whenever the grammar recognizes at least this
            # set for the nonterminal in question, we return a derivation tree.
            # Otherwise, a RuntimeError is raised.

            is_non_negative = number >= 0
            magnitude = solver_literal if is_non_negative else str(-number)

            padding_solver = z3.Solver()
            padding_solver.set("timeout", 300)

            # TODO: Ensure symbols are fresh
            plus_const = z3.String("__plus")
            zeroes_const = z3.String("__padding")

            padding_solver.add(z3.InRe(plus_const, z3.Option(z3.Re("+"))))
            padding_solver.add(z3.InRe(zeroes_const, z3.Star(z3.Re("0"))))

            # Negative numbers get a fixed "-"; otherwise the "+" is optional.
            sign_term = plus_const if is_non_negative else z3.StringVal("-")
            padding_solver.add(
                z3.InRe(
                    z3.Concat(sign_term, zeroes_const, z3.StringVal(magnitude)),
                    self.extract_regular_expression(nonterminal),
                )
            )

            if padding_solver.check() != z3.sat:
                raise RuntimeError(
                    f"Could not parse a numeric solution ({solver_literal}) "
                    f"for variable {var} of type '{nonterminal}'; try "
                    "running the solver without optimized Z3 queries or make "
                    "sure that ranges are restricted to syntactically valid "
                    "ones (according to the grammar).",
                )

            padding_model = padding_solver.model()
            sign = padding_model[plus_const].as_string() if is_non_negative else "-"

            return self.parse(
                sign + padding_model[zeroes_const].as_string() + magnitude,
                nonterminal,
            )
3079

3080
    def extract_model_value_flexible_var(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles "flexible" variables on behalf of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        This is the last handler in the chain and thus takes no fallback. The
        string assigned to the variable's constant in the model is parsed with the
        variable's nonterminal type. The parameters :code:`fresh_var_map`,
        :code:`length_vars`, and :code:`int_vars` are not used here.

        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """

        smt_value = model[z3.String(var.name)]
        solution_text = smt_string_val_to_string(smt_value)
        return self.parse(solution_text, var.n_type)
3105

3106
    @staticmethod
    def infer_variable_contexts(
        variables: Set[language.Variable], smt_formulas: ImmutableList[z3.BoolRef]
    ) -> Dict[str, Set[language.Variable]]:
        """
        Divides the given variables into

        1. those that occur only in :code:`length(...)` contexts,
        2. those that occur only in :code:`str.to.int(...)` contexts, and
        3. "flexible" constants occurring in other/various contexts.

        >>> x = language.Variable("x", "<X>")
        >>> y = language.Variable("y", "<Y>")

        Two variables in an arbitrary context.

        >>> f = z3_eq(x.to_smt(), y.to_smt())
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
        >>> contexts["length"]
        set()
        >>> contexts["int"]
        set()
        >>> contexts["flexible"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
        True

        Variable x occurs in a length context, variable y in an arbitrary one.

        >>> f = z3.And(
        ...     z3.Length(x.to_smt()) > z3.IntVal(10),
        ...     z3_eq(y.to_smt(), z3.StringVal("y")))
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}

        Variable x occurs in a length context, y does not occur.

        >>> f = z3.Length(x.to_smt()) > z3.IntVal(10)
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}

        Variables x and y both occur in a length context.

        >>> f = z3.Length(x.to_smt()) > z3.Length(y.to_smt())
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
        >>> contexts["length"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
        True
        >>> contexts["int"]
        set()
        >>> contexts["flexible"]
        set()

        Variable x occurs in a :code:`str.to.int` context.

        >>> f = z3.StrToInt(x.to_smt()) > z3.IntVal(17)
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
        {'length': set(), 'int': {Variable("x", "<X>")}, 'flexible': set()}

        Now, x also occurs in a different context; it's "flexible" now.

        >>> f = z3.And(
        ...     z3.StrToInt(x.to_smt()) > z3.IntVal(17),
        ...     z3_eq(x.to_smt(), z3.StringVal("17")))
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
        {'length': set(), 'int': set(), 'flexible': {Variable("x", "<X>")}}

        :param variables: The constants to divide/filter from.
        :param smt_formulas: The SMT formulas to consider in the filtering.
        :return: A dictionary with the keys :code:`"length"`, :code:`"int"`, and
            :code:`"flexible"`, each mapping to the set of variables of that kind.
            The three sets are pairwise disjoint, and their union equals
            `variables`.
        """

        per_formula_parents = [
            parent_relationships_in_z3_expr(formula) for formula in smt_formulas
        ]
        parents_of = {}
        for formula_parents in per_formula_parents:
            parents_of = merge_dict_of_sets(parents_of, formula_parents)

        def context_kinds(var: language.Variable) -> Set[int]:
            kinds = {
                parent.decl().kind() for parent in parents_of.get(var.to_smt(), set())
            }
            # The placeholder -1 keeps variables without any parent expression
            # from vacuously qualifying as "length" or "int" variables below.
            return kinds or {-1}

        contexts: Dict[language.Variable, Set[int]] = {
            var: context_kinds(var) for var in variables
        }

        def occurring_only_in(kind: int) -> Set[language.Variable]:
            return {
                var
                for var in variables
                if all(context == kind for context in contexts[var])
            }

        # All variables that only occur in `str.len(...)` context.
        length_vars: Set[language.Variable] = occurring_only_in(z3.Z3_OP_SEQ_LENGTH)

        # All variables that only occur in `str.to.int(...)` context.
        int_vars: Set[language.Variable] = occurring_only_in(z3.Z3_OP_STR_TO_INT)

        # "Flexible" variables are the remaining ones.
        not_length_vars = variables.difference(length_vars)
        flexible_vars = not_length_vars.difference(int_vars)

        return {"length": length_vars, "int": int_vars, "flexible": flexible_vars}
3212

3213
    def generate_language_constraints(
        self,
        constants: Iterable[language.Variable],
        tree_substitutions: Dict[language.Variable, DerivationTree],
    ) -> List[z3.BoolRef]:
        """
        Creates, for each constant, a Z3 formula restricting its string value to a
        regular language.

        Numeric constants are restricted to "0" or a digit sequence without
        leading zero. For constants with an entry in :code:`tree_substitutions`,
        the language is assembled from the terminal and nonterminal parts of the
        substituted tree's string. All other constants use the regular expression
        extracted for their nonterminal type.

        :param constants: The constants to generate membership constraints for.
        :param tree_substitutions: Trees describing a more concrete shape for some
            of the constants.
        :return: One :code:`z3.InRe` formula per constant, in iteration order.
        """

        def language_regex(constant):
            if constant.is_numeric():
                return z3.Union(
                    z3.Re("0"),
                    z3.Concat(z3.Range("1", "9"), z3.Star(z3.Range("0", "9"))),
                )

            if constant not in tree_substitutions:
                return self.extract_regular_expression(constant.n_type)

            # We have a more concrete shape of the desired instantiation available
            parts = [
                (
                    self.extract_regular_expression(part)
                    if is_nonterminal(part)
                    else z3.Re(part)
                )
                for part in split_str_with_nonterminals(
                    str(tree_substitutions[constant])
                )
            ]
            assert parts

            if len(parts) > 1:
                return z3.Concat(*parts)
            return parts[0]

        formulas: List[z3.BoolRef] = []
        for constant in constants:
            membership_regex = language_regex(constant)
            formulas.append(z3.InRe(z3.String(constant.name), membership_regex))

        return formulas
3245

3246
    def rename_instantiated_variables_in_smt_formulas(self, smt_formulas):
        """
        Renames the substituted variables of the given SMT formulas after the
        trees they are substituted with.

        Each variable with a substitution is replaced by a bound variable named
        :code:`<tree value>_<tree id>` of the same nonterminal type: in the Z3
        formula, in the substitution map, and in the set of instantiated
        variables.

        :param smt_formulas: The SMT formulas to process.
        :return: A list with one newly created SMT formula per input formula.
        """
        renamed_formulas = []

        for original in smt_formulas:
            z3_formula: z3.BoolRef = original.formula
            substitutions = original.substitutions
            instantiated = original.instantiated_variables

            for old_var, tree in original.substitutions.items():
                fresh_var = language.BoundVariable(
                    f"{tree.value}_{tree.id}", old_var.n_type
                )

                z3_formula = cast(
                    z3.BoolRef,
                    z3_subst(z3_formula, {old_var.to_smt(): fresh_var.to_smt()}),
                )
                substitutions = {
                    (fresh_var if key == old_var else key): value
                    for key, value in substitutions.items()
                }
                instantiated = {
                    (fresh_var if candidate == old_var else candidate)
                    for candidate in instantiated
                }

            renamed = language.SMTFormula(
                z3_formula,
                *original.free_variables_,
                instantiated_variables=instantiated,
                substitutions=substitutions,
            )
            renamed_formulas.append(renamed)

        return renamed_formulas
3283

3284
    def process_new_states(
        self, new_states: List[SolutionState]
    ) -> List[DerivationTree]:
        """
        Processes each state with :meth:`process_new_state` and concatenates the
        returned solution trees, preserving the order of the given states.

        :param new_states: The states to process.
        :return: The solution trees obtained from all states.
        """
        solution_trees: List[DerivationTree] = []
        for state in new_states:
            solution_trees.extend(self.process_new_state(state))
        return solution_trees
3292

3293
    def process_new_state(self, new_state: SolutionState) -> List[DerivationTree]:
        """
        Post-processes a freshly derived state and returns the solutions it yields.

        The structural predicates of the state are instantiated, and the state is
        split into one state per DNF clause (:meth:`establish_invariant`).
        Universal quantifiers are then pruned from each resulting state by
        :meth:`remove_nonmatching_universal_quantifiers` and
        :meth:`remove_infeasible_universal_quantifiers`. If unsat support is
        activated (and no unsat check is already running), states with
        unsatisfiable SMT or existential formulas are dropped. The remaining
        states are passed to :meth:`state_is_valid_or_enqueue`.

        :param new_state: The state to process.
        :return: The trees of all resulting states for which
            :meth:`state_is_valid_or_enqueue` returned :code:`True`.
        """
        new_state = self.instantiate_structural_predicates(new_state)
        new_states = self.establish_invariant(new_state)
        new_states = [
            self.remove_nonmatching_universal_quantifiers(state)
            for state in new_states
        ]
        new_states = [
            self.remove_infeasible_universal_quantifiers(state)
            for state in new_states
        ]

        if self.activate_unsat_support and not self.currently_unsat_checking:
            # The flag prevents nested unsat checks: the existential check below
            # calls `self.solve()`. It is reset in a `finally` clause such that
            # unsat support stays usable if an exception escapes.
            self.currently_unsat_checking = True

            try:
                # Iterate over a copy since states are removed from `new_states`.
                for candidate in list(new_states):
                    if candidate.constraint == sc.true():
                        continue

                    # Remove states with unsatisfiable SMT-LIB formulas.
                    if any(
                        isinstance(f, language.SMTFormula)
                        for f in split_conjunction(candidate.constraint)
                    ) and not is_successful(
                        self.eliminate_all_semantic_formulas(
                            candidate, max_instantiations=1
                        ).bind(lambda a: Some(a) if a else Nothing)
                    ):
                        new_states.remove(candidate)
                        self.logger.debug(
                            "Dropping state %s, unsatisfiable SMT formulas", candidate
                        )
                        # The state is already gone: checking its existential
                        # formulas would be wasted work and could attempt to
                        # remove it a second time (ValueError).
                        continue

                    # Remove states with unsatisfiable existential formulas.
                    existential_formulas = [
                        f
                        for f in split_conjunction(candidate.constraint)
                        if isinstance(f, language.ExistsFormula)
                    ]
                    for existential_formula in existential_formulas:
                        # Run a separate, short solver pass on the existential
                        # formula alone; the solver state is saved and restored
                        # around it.
                        old_start_time = self.start_time
                        old_timeout_seconds = self.timeout_seconds
                        old_queue = list(self.queue)
                        old_solutions = list(self.solutions)

                        self.queue = []
                        self.solutions = []
                        check_state = SolutionState(
                            existential_formula, candidate.tree
                        )
                        heapq.heappush(self.queue, (0, check_state))
                        self.start_time = int(time.time())
                        self.timeout_seconds = 2

                        try:
                            self.solve()
                        except StopIteration:
                            new_states.remove(candidate)
                            self.logger.debug(
                                "Dropping state %s, unsatisfiable existential "
                                "formula %s",
                                candidate,
                                existential_formula,
                            )
                            break
                        finally:
                            self.start_time = old_start_time
                            self.timeout_seconds = old_timeout_seconds
                            self.queue = old_queue
                            self.solutions = old_solutions
            finally:
                self.currently_unsat_checking = False

        # All trees referenced by quantified formulas must be part of the tree
        # of the respective state.
        assert all(
            state.tree.find_node(tree) is not None
            for state in new_states
            for quantified_formula in split_conjunction(state.constraint)
            if isinstance(quantified_formula, language.QuantifiedFormula)
            for _, tree in quantified_formula.in_variable.filter(lambda t: True)
        )

        solution_trees = [
            state.tree
            for state in new_states
            if self.state_is_valid_or_enqueue(state)
        ]

        for tree in solution_trees:
            self.cost_computer.signal_tree_output(tree)

        return solution_trees
3381

3382
    def state_is_valid_or_enqueue(self, state: SolutionState) -> bool:
        """
        Decides whether the given state is a finished solution, and enqueues it
        otherwise.

        For complete states, the expansions occurring in the tree are recorded in
        :code:`self.seen_coverages`. Incomplete states are pushed to the queue
        unless their tree (if unique trees are enforced) or the state itself is
        already in the queue, or their constraint is propositionally
        unsatisfiable; in these cases, the state is discarded.

        :param state: The state to check.
        :return: :code:`True` if the state is complete, such that it can be
            yielded; :code:`False` if it was enqueued or discarded.
        """

        if state.complete():
            expanded_subtrees = (
                subtree for _, subtree in state.tree.paths() if subtree.children
            )
            for subtree in expanded_subtrees:
                self.seen_coverages.add(
                    expansion_key(subtree.value, subtree.children)
                )

            assert state.formula_satisfied(self.grammar).is_true()
            return True

        # Helps in debugging below assertion:
        # [(predicate_formula, [
        #     arg for arg in predicate_formula.args
        #     if isinstance(arg, DerivationTree) and not state.tree.find_node(arg)])
        #  for predicate_formula in get_conjuncts(state.constraint)
        #  if isinstance(predicate_formula, language.StructuralPredicateFormula)]

        self.assert_no_dangling_predicate_argument_trees(state)
        self.assert_no_dangling_smt_formula_argument_trees(state)

        tree_already_queued = (
            self.enforce_unique_trees_in_queue
            and state.tree.structural_hash() in self.tree_hashes_in_queue
        )
        if tree_already_queued:
            # Some structures can arise as well from tree insertion (existential
            # quantifier elimination) and expansion; also, tree insertion can yield
            # different trees that have intersecting expansions. We drop those to output
            # more diverse solutions (numbers for SMT solutions and free nonterminals
            # are configurable, so you get more outputs by playing with those!).
            self.logger.debug("Discarding state %s, tree already in queue", state)
            return False

        if hash(state) in self.state_hashes_in_queue:
            self.logger.debug("Discarding state %s, already in queue", state)
            return False

        if self.propositionally_unsatisfiable(state.constraint):
            self.logger.debug("Discarding state %s", state)
            return False

        # The enqueued copy lives one level below the state currently processed.
        queued_state = SolutionState(
            state.constraint, state.tree, level=self.current_level + 1
        )

        self.recompute_costs()

        cost = self.compute_cost(queued_state)
        heapq.heappush(self.queue, (cost, queued_state))
        self.tree_hashes_in_queue.add(queued_state.tree.structural_hash())
        self.state_hashes_in_queue.add(hash(queued_state))

        if self.debug:
            self.state_tree[self.current_state].append(queued_state)
            self.costs[queued_state] = cost

        self.logger.debug(
            "Pushing new state (%s, %s) (hash %d, cost %f)",
            queued_state.constraint,
            queued_state.tree.to_string(show_open_leaves=True, show_ids=True),
            hash(queued_state),
            cost,
        )

        queue_length = len(self.queue)
        self.logger.debug("Queue length: %d", queue_length)
        if queue_length % 100 == 0:
            self.logger.info("Queue length: %d", queue_length)

        return False
3455

3456
    def recompute_costs(self):
        """
        Rebuilds the queue with freshly computed costs.

        This only happens if the step counter is a multiple of 400 and no
        recomputation was performed yet at (or after) the current step.
        """
        recomputation_due = (
            self.step_cnt % 400 == 0 and self.step_cnt > self.last_cost_recomputation
        )
        if not recomputation_due:
            return

        self.last_cost_recomputation = self.step_cnt
        self.logger.info(
            f"Recomputing costs in queue after {self.step_cnt} solver steps"
        )

        # Detach the old entries and re-insert each state with its current cost.
        outdated_entries, self.queue = self.queue, []
        for _, state in outdated_entries:
            heapq.heappush(self.queue, (self.compute_cost(state), state))
3469

3470
    def assert_no_dangling_smt_formula_argument_trees(
        self, state: SolutionState
    ) -> None:
        """
        Checks that all trees substituted into SMT formulas of the state's
        constraint are contained in the state's tree.

        The check is skipped unless assertions are activated or the solver is in
        debug mode.

        :param state: The state to check.
        :raises AssertionError: If some substituted tree cannot be found in
            :code:`state.tree`.
        """
        if not assertions_activated() and not self.debug:
            return

        smt_formulas = language.FilterVisitor(
            lambda f: isinstance(f, language.SMTFormula)
        ).collect(state.constraint)

        dangling_smt_formula_argument_trees = [
            (smt_formula, arg)
            for smt_formula in smt_formulas
            for arg in cast(language.SMTFormula, smt_formula).substitutions.values()
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if not dangling_smt_formula_argument_trees:
            return

        message = (
            "Dangling SMT formula arguments: ["
            + ", ".join(
                str(f) + ", " + repr(a)
                for f, a in dangling_smt_formula_argument_trees
            )
            + "]"
        )

        # Raised explicitly: `assert False` is removed under `python -O`, which
        # would silently disable this check in debug mode.
        raise AssertionError(message)
3495

3496
    def assert_no_dangling_predicate_argument_trees(self, state: SolutionState) -> None:
        """
        Checks that all tree arguments of structural predicate formulas in the
        state's constraint are contained in the state's tree.

        The check is skipped unless assertions are activated or the solver is in
        debug mode.

        :param state: The state to check.
        :raises AssertionError: If some tree argument cannot be found in
            :code:`state.tree`.
        """
        if not assertions_activated() and not self.debug:
            return

        predicate_formulas = language.FilterVisitor(
            lambda f: isinstance(f, language.StructuralPredicateFormula)
        ).collect(state.constraint)

        dangling_predicate_argument_trees = [
            (predicate_formula, arg)
            for predicate_formula in predicate_formulas
            for arg in cast(language.StructuralPredicateFormula, predicate_formula).args
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if not dangling_predicate_argument_trees:
            return

        message = (
            "Dangling predicate arguments: ["
            + ", ".join(
                str(f) + ", " + repr(a) for f, a in dangling_predicate_argument_trees
            )
            + "]"
        )

        # Raised explicitly: `assert False` is removed under `python -O`, which
        # would silently disable this check in debug mode.
        raise AssertionError(message)
3516

3517
    def propositionally_unsatisfiable(self, formula: language.Formula) -> bool:
        """
        Returns whether the given formula equals the constant "false" formula.

        :param formula: The formula to check.
        :return: :code:`True` iff :code:`formula == sc.false()`.
        """
        # NOTE: Deactivated propositional check for performance reasons
        # z3_formula = language.isla_to_smt_formula(formula, replace_untranslatable_with_predicate=True)
        # solver = z3.Solver()
        # solver.add(z3_formula)
        # return solver.check() == z3.unsat
        is_false_constant = formula == sc.false()
        return is_false_constant
3525

3526
    def establish_invariant(self, state: SolutionState) -> List[SolutionState]:
        """
        Splits the state into one state per clause of the DNF of its constraint.

        The constraint is converted to negation normal form and then to DNF
        clauses; the elements of each clause are conjoined (starting from
        :code:`sc.true()`) and paired with the unchanged tree of the state.

        :param state: The state to split.
        :return: One state per DNF clause.
        """
        dnf_clauses = to_dnf_clauses(convert_to_nnf(state.constraint))

        clause_states = []
        for clause in dnf_clauses:
            clause_formula = sc.true()
            for element in clause:
                clause_formula = clause_formula & element
            clause_states.append(SolutionState(clause_formula, state.tree))

        return clause_states
3532

3533
    def compute_cost(self, state: SolutionState) -> float:
        """
        Returns the cost of the given state.

        :param state: The state to compute the cost for.
        :return: 0 if the constraint of the state is :code:`sc.true()`; the result
            of the configured cost computer otherwise.
        """
        constraint_is_trivial = state.constraint == sc.true()
        return 0 if constraint_is_trivial else self.cost_computer.compute_cost(state)
3538

3539
    def remove_nonmatching_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Removes universal formulas that have a complete "in" tree, but no matches.

        :param state: The state whose constraint should be pruned.
        :return: The given state itself if nothing was removed; otherwise a new
            state with the conjunction of the remaining conjuncts and the same
            tree.
        """
        remaining = list(get_conjuncts(state.constraint))
        removed_any = False

        # Walk backwards such that deletions do not shift the pending indices.
        for position in range(len(remaining) - 1, -1, -1):
            candidate = remaining[position]
            if not isinstance(candidate, language.ForallFormula):
                continue

            if not candidate.in_variable.is_complete():
                continue

            if matches_for_quantified_formula(candidate, self.grammar):
                continue

            del remaining[position]
            removed_any = True

        if removed_any:
            return SolutionState(sc.conjunction(*remaining), state.tree)

        return state
3560

3561
    def remove_infeasible_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Drops universal conjuncts that can no longer be instantiated.

        A `ForallFormula` conjunct is dropped if (1) every match computed by
        `matches_for_quantified_formula` is reported as already matched by the
        formula and (2) for no open leaf of the formula's `in_variable` does
        `self.quantified_formula_might_match` hold.

        :param state: The state to simplify.
        :return: `state` itself if no conjunct was dropped; otherwise, a new state
            whose constraint combines the remaining conjuncts with `&` (starting
            from `sc.true()`), with the same tree.
        """
        conjuncts = get_conjuncts(state.constraint)
        num_conjuncts_before = len(conjuncts)

        # Iterating from the back keeps the indices of the conjuncts that are
        # still to be visited stable while elements are deleted.
        for idx in range(len(conjuncts) - 1, -1, -1):
            candidate = conjuncts[idx]
            if not isinstance(candidate, language.ForallFormula):
                continue

            matches = matches_for_quantified_formula(candidate, self.grammar)
            every_match_consumed = all(
                candidate.is_already_matched(match[candidate.bound_variable][1])
                for match in matches
            )
            if not every_match_consumed:
                continue

            in_tree = candidate.in_variable
            may_still_match = any(
                self.quantified_formula_might_match(candidate, leaf_path, in_tree)
                for leaf_path, _ in in_tree.open_leaves()
            )
            if not may_still_match:
                del conjuncts[idx]

        if len(conjuncts) == num_conjuncts_before:
            return state

        remaining_constraint = reduce(operator.and_, conjuncts, sc.true())
        return SolutionState(remaining_constraint, state.tree)
3600

3601
    def quantified_formula_might_match(
        self,
        qfd_formula: language.QuantifiedFormula,
        path_to_nonterminal: Path,
        tree: DerivationTree,
    ) -> bool:
        """
        Convenience wrapper around the module-level function
        `quantified_formula_might_match` that supplies this solver's grammar and
        the `reachable` relation of its grammar graph.

        :param qfd_formula: The quantified formula in question.
        :param path_to_nonterminal: A path in `tree`, passed on unchanged.
        :param tree: The derivation tree, passed on unchanged.
        :return: The result of the wrapped function.
        """
        grammar, reachable = self.grammar, self.graph.reachable
        return quantified_formula_might_match(
            qfd_formula, path_to_nonterminal, tree, grammar, reachable
        )
3614

3615
    def extract_regular_expression(self, nonterminal: str) -> z3.ReRef:
        """
        Returns a Z3 regular expression for the given nonterminal. Results are
        stored in and served from `self.regex_cache`.

        The following strategies are tried in order:

        1. Return the cached expression, if any.
        2. If the nonterminal has a single expansion consisting of a single
           nonterminal, reuse the expression of that nonterminal.
        3. If the nonterminal has a single expansion containing nonterminals,
           none of which is or reaches `nonterminal` in `self.graph`, return the
           concatenation of the expressions of the expansion's elements.
        4. Otherwise, convert with a `RegexConverter`, bounded by
           `self.grammar_unwinding_threshold`.

        If assertions are activated, the result of strategy 4 is checked: up to
        100 distinct members of the regex language are obtained from Z3 and parsed
        with an Earley parser for the sub grammar of `nonterminal`.

        :param nonterminal: The nonterminal to compute a regular expression for.
        :return: The Z3 regular expression for `nonterminal`.
        :raises AssertionError: If (with assertions enabled) a sampled member of
            the regex language cannot be parsed, or if the nonterminal directly
            expands to itself only.
        """
        if nonterminal in self.regex_cache:
            return self.regex_cache[nonterminal]

        # For definitions like `<a> ::= <b>`, we only compute the regular expression
        # for `<b>`. That way, we might save some calls if `<b>` is used multiple times
        # (e.g., as in `<byte>`).
        canonical_expansions = self.canonical_grammar[nonterminal]

        if (
            len(canonical_expansions) == 1
            and len(canonical_expansions[0]) == 1
            and is_nonterminal(canonical_expansions[0][0])
        ):
            sub_nonterminal = canonical_expansions[0][0]
            assert (
                nonterminal != sub_nonterminal
            ), f"Expansion {nonterminal} => {sub_nonterminal}: Infinite recursion!"
            return self.regex_cache.setdefault(
                nonterminal, self.extract_regular_expression(sub_nonterminal)
            )

        # Similarly, for definitions like `<a> ::= <b> " x " <c>`, where `<b>` and `<c>`
        # don't reach `<a>`, we only compute the regular expressions for `<b>` and `<c>`
        # and return a concatenation. This also saves us expensive conversions (e.g.,
        # for `<seq> ::= <byte> <byte>`).
        if (
            len(canonical_expansions) == 1
            and any(is_nonterminal(elem) for elem in canonical_expansions[0])
            and all(
                not is_nonterminal(elem)
                or (
                    elem != nonterminal
                    and not self.graph.reachable(elem, nonterminal)
                )
                for elem in canonical_expansions[0]
            )
        ):
            result_elements: List[z3.ReRef] = [
                (
                    z3.Re(elem)
                    if not is_nonterminal(elem)
                    else self.extract_regular_expression(elem)
                )
                for elem in canonical_expansions[0]
            ]
            return self.regex_cache.setdefault(nonterminal, z3.Concat(*result_elements))

        regex_conv = RegexConverter(
            self.grammar,
            compress_unions=True,
            max_num_expansions=self.grammar_unwinding_threshold,
        )
        regex = regex_conv.to_regex(nonterminal, convert_to_z3=False)
        # Lazy %-style arguments: the regex is only converted to a string if
        # debug logging is actually enabled.
        self.logger.debug(
            "Computed regular expression for nonterminal %s:\n%s",
            nonterminal,
            regex,
        )
        z3_regex = regex_to_z3(regex)

        if assertions_activated():
            # Check correctness of regular expression
            grammar = self.graph.subgraph(nonterminal).to_grammar()

            # L(regex) \subseteq L(grammar)
            self.logger.debug(
                "Checking L(regex) \\subseteq L(grammar) for "
                + "nonterminal '%s' and regex '%s'",
                nonterminal,
                regex,
            )
            parser = EarleyParser(grammar)
            c = z3.String("c")
            prev: Set[str] = set()
            for _ in range(100):
                # Ask for a member of the regex language that differs from all
                # members found so far.
                s = z3.Solver()
                s.add(z3.InRe(c, z3_regex))
                for inp in prev:
                    s.add(z3.Not(c == z3.StringVal(inp)))
                if s.check() != z3.sat:
                    self.logger.debug(
                        "Cannot find the %d-th solution for regex %s (timeout).\nThis is *not* a problem "
                        "if there not that many solutions (for regexes with finite language), or if we "
                        "are facing a meaningless timeout of the solver.",
                        len(prev) + 1,
                        regex,
                    )
                    break
                new_inp = smt_string_val_to_string(s.model()[c])
                try:
                    next(parser.parse(new_inp))
                except SyntaxError:
                    assert (
                        False
                    ), f"Input '{new_inp}' from regex language is not in grammar language."
                prev.add(new_inp)

        self.regex_cache[nonterminal] = z3_regex

        return z3_regex
1✔
3712

3713

3714
class CostComputer(ABC):
    """
    Base class of cost computers. Both methods raise `NotImplementedError`
    unless they are overridden.
    """

    def compute_cost(self, state: SolutionState) -> float:
        """
        Calculates the cost of a state. The analysis prefers states with a
        lower cost.

        :param state: The state to rate.
        :return: The cost of the state.
        """
        raise NotImplementedError

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """
        Notifies the cost computer that a tree was output as a solution, such
        that it can update the internal information used for computing costs.

        :param tree: The tree that was output as a solution.
        :return: Nothing.
        """
        raise NotImplementedError
×
3734

3735

3736
class GrammarBasedBlackboxCostComputer(CostComputer):
    """
    Cost computer combining five individual costs by a weighted geometric mean
    (weights: `cost_settings.weight_vector`), in this order: the cost of closing
    the open leaves of the tree, the cost of the quantifiers in the constraint,
    the level of the state, the k-path coverage cost of the tree, and the cost
    for contributing few globally uncovered k-paths.
    """

    def __init__(
        self,
        cost_settings: CostSettings,
        graph: gg.GrammarGraph,
        reset_coverage_after_n_round_with_no_coverage: int = 100,
        symbol_costs: Optional[Dict[str, int]] = None,
    ):
        """
        :param cost_settings: Provides the weight vector and the `k` used for
            k-paths.
        :param graph: The grammar graph used for k-path and symbol cost
            computations.
        :param reset_coverage_after_n_round_with_no_coverage: The set of covered
            k-paths is reset once this many output trees did not add new k-paths.
        :param symbol_costs: Optional precomputed symbol costs. If `None`, they
            are computed with `compute_symbol_costs` on first use.
        """
        self.logger = logging.getLogger(type(self).__name__)

        self.cost_settings = cost_settings
        self.graph = graph
        self.symbol_costs: Optional[Dict[str, int]] = symbol_costs

        self.reset_coverage_after_n_round_with_no_coverage = (
            reset_coverage_after_n_round_with_no_coverage
        )
        self.rounds_with_no_new_coverage = 0
        self.covered_k_paths: Set[Tuple[gg.Node, ...]] = set()

    def __repr__(self):
        return (
            f"GrammarBasedBlackboxCostComputer({self.cost_settings!r}, graph, "
            f"{self.reset_coverage_after_n_round_with_no_coverage}, "
            f"{self.symbol_costs})"
        )

    def compute_cost(self, state: SolutionState) -> float:
        """
        Computes the weighted geometric mean of the individual costs of a state.

        :param state: The state to rate.
        :return: The combined cost.
        """
        # Cost of finishing the tree, i.e., of closing all its open leaves.
        tree_closing_cost = self.compute_tree_closing_cost(state.tree)

        # Quantifiers are expensive (universal formulas have to be matched, tree
        # insertion for existential formulas is even more costly). The deeper a
        # quantifier sits in a chain, the more it contributes.
        # TODO: Penalize nested quantifiers more.
        constraint_cost = 0
        for chain in get_quantifier_chains(state.constraint):
            for depth, quantified in enumerate(chain):
                factor = 2 if isinstance(quantified, language.ExistsFormula) else 1
                constraint_cost += depth * factor + 1

        # k-path coverage of the tree itself: Fewer covered -> higher penalty
        k_cov_cost = self._compute_k_coverage_cost(state)

        # Globally uncovered k-paths: Fewer contributed -> higher penalty
        global_k_path_cost = self._compute_global_k_coverage_cost(state)

        level = state.level
        labeled_costs = (
            ("tree_closing_cost", tree_closing_cost),
            ("constraint_cost", constraint_cost),
            ("state.level", level),
            ("k_cov_cost", k_cov_cost),
            ("global_k_path_cost", global_k_path_cost),
        )
        for label, value in labeled_costs:
            assert value >= 0, f"{label} == {value}!"

        costs = [value for _, value in labeled_costs]
        weights = list(self.cost_settings.weight_vector)
        result = weighted_geometric_mean(costs, weights)

        self.logger.debug(
            "Computed cost for state %s:\n%f, individual costs: %s, weights: %s",
            lazystr(lambda: f"({(str(state.constraint)[:50] + '...')}, {state.tree})"),
            result,
            costs,
            self.cost_settings.weight_vector,
        )

        return result

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """
        Records the k-paths of a tree that was output as a solution.

        :param tree: The tree that was output.
        :return: Nothing.
        """
        self._update_covered_k_paths(tree)

    def _symbol_costs(self):
        """Returns the symbol costs, computing them from the graph if unset."""
        costs = self.symbol_costs
        if costs is None:
            costs = self.symbol_costs = compute_symbol_costs(self.graph)
        return costs

    def _update_covered_k_paths(self, tree: DerivationTree):
        """
        Adds the k-paths of `tree` to the covered k-paths. The covered paths are
        reset when all k-paths of the graph are covered or when the number of
        rounds without new coverage reached the configured limit.

        Does nothing if the weight for the global k-path coverage penalty is not
        positive.
        """
        weights = self.cost_settings.weight_vector
        if not (weights.low_global_k_path_coverage_penalty > 0):
            return

        # `update` only ever adds elements; thus, the set is unchanged if and
        # only if its size is unchanged.
        num_covered_before = len(self.covered_k_paths)
        self.covered_k_paths.update(
            tree.k_paths(self.graph, self.cost_settings.k, include_potential_paths=False)
        )
        if len(self.covered_k_paths) == num_covered_before:
            self.rounds_with_no_new_coverage += 1

        graph_paths = self.graph.k_paths(self.cost_settings.k, include_terminals=False)
        reset_limit = self.reset_coverage_after_n_round_with_no_coverage
        stalled = self.rounds_with_no_new_coverage >= reset_limit
        everything_covered = self.covered_k_paths == graph_paths

        if everything_covered:
            self.logger.debug("ALL PATHS COVERED")
        elif stalled:
            self.logger.debug(
                "COVERAGE RESET SINCE NO CHANGE IN COVERED PATHS SINCE %d "
                "ROUNDS (%d path(s) uncovered)",
                reset_limit,
                len(graph_paths) - len(self.covered_k_paths),
            )

        if everything_covered or stalled:
            self.covered_k_paths = set()

        if stalled:
            self.rounds_with_no_new_coverage = 0

    def _compute_global_k_coverage_cost(self, state: SolutionState):
        """
        Rates how few of the globally still uncovered k-paths the tree of a state
        contributes.

        :param state: The state to rate.
        :return: 0 if the corresponding weight is 0 or if no k-path is missing;
            otherwise, 1 minus the weighted geometric mean (weights 0.2 and 0.8)
            of the fractions of missing k-paths contributed by the actual and by
            the potential k-paths of the tree.
        """
        if self.cost_settings.weight_vector.low_global_k_path_coverage_penalty == 0:
            return 0

        k = self.cost_settings.k
        tree_k_paths = state.tree.k_paths(self.graph, k, include_potential_paths=False)
        all_graph_k_paths = self.graph.k_paths(k, include_terminals=False)

        def num_uncovered_graph_paths_in(paths) -> int:
            # Counts graph k-paths that occur in `paths` and are not yet covered.
            return len(
                {
                    path
                    for path in all_graph_k_paths
                    if path in paths and path not in self.covered_k_paths
                }
            )

        num_contributed_k_paths = num_uncovered_graph_paths_in(tree_k_paths)
        num_missing_k_paths = len(all_graph_k_paths) - len(self.covered_k_paths)

        assert 0 <= num_contributed_k_paths <= num_missing_k_paths, (
            f"num_contributed_k_paths == {num_contributed_k_paths}, "
            f"num_missing_k_paths == {num_missing_k_paths}"
        )

        potential_tree_k_paths = state.tree.k_paths(
            self.graph, k, include_potential_paths=True
        )
        num_contributed_potential_k_paths = num_uncovered_graph_paths_in(
            potential_tree_k_paths
        )

        if not num_missing_k_paths:
            return 0

        contributed_fractions = [
            num_contributed_k_paths / num_missing_k_paths,
            num_contributed_potential_k_paths / num_missing_k_paths,
        ]
        return 1 - weighted_geometric_mean(contributed_fractions, [0.2, 0.8])

    def _compute_k_coverage_cost(self, state: SolutionState) -> float:
        """
        Rates the k-path coverage of the tree of a state.

        :param state: The state to rate.
        :return: 0 if the corresponding weight is 0; otherwise, the product of
            `1 - coverage` for all path lengths from 1 to `cost_settings.k`,
            raised to the power of `1 / cost_settings.k`.
        """
        if self.cost_settings.weight_vector.low_k_coverage_penalty == 0:
            return 0

        max_k = self.cost_settings.k
        uncovered_fractions = []
        for k in range(1, max_k + 1):
            coverage = state.tree.k_coverage(
                self.graph, k, include_potential_paths=False
            )
            assert 0 <= coverage <= 1, f"coverage == {coverage}"
            uncovered_fractions.append(1 - coverage)

        return math.prod(uncovered_fractions) ** (1 / float(max_k))

    def compute_tree_closing_cost(self, tree: DerivationTree) -> float:
        """
        Sums up the symbol costs of the values of all open leaves of a tree.

        :param tree: The tree to rate.
        :return: The sum of the symbol costs; 0 if the tree has no open leaves.
        """
        open_symbols = [leaf.value for _, leaf in tree.open_leaves()]
        return sum(self._symbol_costs()[symbol] for symbol in open_symbols)
1✔
3965

3966

3967
def smt_formulas_referring_to_subtrees(
    smt_formulas: Sequence[language.SMTFormula],
) -> List[language.SMTFormula]:
    """
    Selects those SMT formulas that have to be solved before the others to avoid
    inconsistencies: a formula is selected if one of its substitution trees is a
    proper subtree of a substitution tree of another formula, while none of the
    proper subtrees of its own substitution trees is a substitution tree of
    another formula. Trees are compared by their `id`.

    :param smt_formulas: The formulas to search for references to subtrees.
    :return: The selected formulas, in the order in which they occur in
        `smt_formulas`.
    """
    # IDs of all proper subtrees of the substitution trees of a formula.
    proper_subtree_ids: Dict[language.SMTFormula, Set[int]] = {}
    for formula in smt_formulas:
        ids: Set[int] = set()
        for tree in formula.substitutions.values():
            ids.update(
                subtree.id for _, subtree in tree.paths() if subtree.id != tree.id
            )
        proper_subtree_ids[formula] = ids

    # IDs of the substitution trees themselves.
    top_level_ids: Dict[language.SMTFormula, Set[int]] = {}
    for formula in smt_formulas:
        top_level_ids[formula] = {
            tree.id for tree in formula.substitutions.values()
        }

    selected: List[language.SMTFormula] = []
    for idx, formula in enumerate(smt_formulas):
        others = [
            other
            for other_idx, other in enumerate(smt_formulas)
            if other_idx != idx
        ]

        refers_to_foreign_subtree = any(
            top_level_ids[formula] & proper_subtree_ids[other] for other in others
        )
        if not refers_to_foreign_subtree:
            continue

        depends_on_other_solution = any(
            top_level_ids[other] & proper_subtree_ids[formula] for other in others
        )
        if not depends_on_other_solution:
            selected.append(formula)

    return selected
4027

4028

4029
def compute_tree_closing_cost(tree: DerivationTree, graph: GrammarGraph) -> float:
    """
    Sums up the symbol costs (see `compute_symbol_costs`) of the values of all
    open leaves of a tree.

    :param tree: The tree to rate.
    :param graph: The grammar graph from which the symbol costs are computed.
    :return: The sum of the symbol costs; 0 if the tree has no open leaves.
    """
    open_symbols = [leaf.value for _, leaf in tree.open_leaves()]

    total = 0
    for symbol in open_symbols:
        total += compute_symbol_costs(graph)[symbol]
    return total
4034

4035

4036
def get_quantifier_chains(
    formula: language.Formula,
) -> List[Tuple[Union[language.QuantifiedFormula, language.ExistsIntFormula], ...]]:
    """
    Collects the chains of nested quantified formulas in a formula.

    Each chain starts with a formula returned by
    `get_toplevel_quantified_formulas` and continues with a chain of that
    formula's inner formula. A quantified formula whose inner formula has no
    chains yields the singleton chain.

    :param formula: The formula to analyze.
    :return: All quantifier chains, outermost quantifier first.
    """
    chains = []
    for quantified in get_toplevel_quantified_formulas(formula):
        inner_chains = get_quantifier_chains(quantified.inner_formula) or [()]
        for inner_chain in inner_chains:
            chains.append((quantified,) + inner_chain)
    return chains
4045

4046

4047
def shortest_derivations(graph: gg.GrammarGraph) -> Dict[str, int]:
    """
    Computes an integer derivation measure for each nonterminal of the grammar
    of a graph by a worklist propagation that starts at the terminal nodes.

    Terminal nodes get the value 0. A choice node gets the maximum of the values
    of its children (children without a value count as 0). Any other nonterminal
    node gets the rounded-up geometric mean of the values of those children that
    already have a value, plus 1. After a node was processed, its parents are
    scheduled if the node had no previous value, a previous value of 0, or a
    previous value greater than the new one.

    :param graph: The grammar graph to analyze.
    :return: A map from the nonterminals in `graph.grammar` to their values.
    """

    def ceiled_geometric_mean(values) -> int:
        known = [value for value in values if value is not None]
        return math.ceil(math.prod(known) ** (1 / len(known)))

    parents_of = {node: set() for node in graph.all_nodes}
    for parent, child in graph.all_edges:
        parents_of[child].add(parent)

    value_of: Dict[gg.Node, int] = {}
    worklist: List[gg.Node] = graph.filter(
        lambda candidate: isinstance(candidate, gg.TerminalNode)
    )
    while worklist:
        node = worklist.pop()
        previous_value = value_of.get(node, None)

        if isinstance(node, gg.TerminalNode):
            value_of[node] = 0
        elif isinstance(node, gg.ChoiceNode):
            # Choice nodes have to be handled before the general nonterminal case.
            value_of[node] = max(
                [value_of.get(child, 0) for child in node.children]
            )
        elif isinstance(node, gg.NonterminalNode):
            assert not isinstance(node, gg.ChoiceNode)

            child_values = (value_of.get(child, None) for child in node.children)
            value_of[node] = ceiled_geometric_mean(child_values) + 1

        new_value = value_of[node]
        if (previous_value or sys.maxsize) > new_value:
            worklist.extend(parents_of[node])

    result: Dict[str, int] = {}
    for nonterminal in graph.grammar:
        result[nonterminal] = value_of[graph.get_node(nonterminal)]
    return result
4087

4088

4089
@lru_cache()
def compute_symbol_costs(graph: GrammarGraph) -> Dict[str, int]:
    """
    Computes a cost value for each nonterminal of a grammar graph, starting from
    the values of `shortest_derivations`.

    The result is memoized per graph by `lru_cache`; all callers passing the
    same graph receive the same dictionary object.

    :param graph: The grammar graph to compute the symbol costs for.
    :return: A map from nonterminal symbols to their costs.
    """
    canonical_grammar = canonical(graph.to_grammar())
    costs: Dict[str, int] = shortest_derivations(graph)

    def has_nonterminal_in_expansions(nonterminal: str) -> bool:
        return any(
            is_nonterminal(symbol)
            for expansion in canonical_grammar[nonterminal]
            for symbol in expansion
        )

    nonterminal_parents = list(
        filter(has_nonterminal_in_expansions, canonical_grammar)
    )

    # Sometimes this computation results in some nonterminals having lower cost
    # values than nonterminals that are reachable from those (but not vice versa),
    # which is undesired. We counteract this by assuring that on the paths
    # computed by `all_paths` from the root to any nonterminal parent, the costs
    # are strictly monotonically decreasing.
    for nonterminal_parent in nonterminal_parents:
        target_node = graph.get_node(nonterminal_parent)
        # noinspection PyTypeChecker
        for path in all_paths(graph, graph.root, target_node):
            # Visit the edges of the path from its end to its start.
            edges = zip(reversed(path[:-1]), reversed(path[1:]))
            for source, target in edges:
                if costs[source.symbol] <= costs[target.symbol]:
                    costs[source.symbol] = costs[target.symbol] + 1

    return costs
1✔
4122

4123

4124
def all_paths(
    graph,
    from_node: gg.NonterminalNode,
    to_node: gg.NonterminalNode,
    cycles_allowed: int = 2,
) -> List[List[gg.NonterminalNode]]:
    """
    Computes paths between two nodes by a breadth-first search.

    Note: The search keeps one global counter per node and extends a path by a
    node only as long as that node was enqueued at most `cycles_allowed + 1`
    times before, i.e., each node is enqueued at most `cycles_allowed + 2` times
    over the whole search. This is not really allowing up to `cycles_allowed`
    cycles (which was the original intention of the parameter), since then we
    would have to check per path; yet, the number of paths would explode then
    and the current implementation provides reasonably good results.

    :param graph: The grammar graph; only its `all_nodes` attribute is used.
    :param from_node: The node at which all paths start.
    :param to_node: The node at which all paths end.
    :param cycles_allowed: Bounds how often a node may be enqueued (see above).
    :return: The paths found, in the order of their discovery, with all
        `gg.ChoiceNode` elements removed.
    """
    from collections import deque

    result: List[List[gg.NonterminalNode]] = []
    visited: Dict[gg.NonterminalNode, int] = {n: 0 for n in graph.all_nodes}
    max_visits = cycles_allowed + 1

    # A deque gives O(1) removal at the front; `list.pop(0)` is O(n).
    queue = deque([[from_node]])
    while queue:
        p = queue.popleft()
        if p[-1] == to_node:
            result.append(p)
            continue

        for child in p[-1].children:
            if (
                not isinstance(child, gg.NonterminalNode)
                or visited[child] > max_visits
            ):
                continue

            visited[child] += 1
            queue.append(p + [child])

    return [[n for n in p if not isinstance(n, gg.ChoiceNode)] for p in result]
1✔
4155

4156

4157
def implies(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Checks whether `f1` implies `f2` by asking an `ISLaSolver` (with unsat
    support activated) for a solution of `f1 & -f2`.

    :param f1: The premise.
    :param f2: The conclusion.
    :param grammar: The grammar passed to the solver.
    :param timeout_seconds: The timeout passed to the solver.
    :return: True if `solve` raises a `StopIteration`; False if it returns a
        solution. Any other exception is propagated to the caller.
    """
    solver = ISLaSolver(
        grammar, f1 & -f2, activate_unsat_support=True, timeout_seconds=timeout_seconds
    )

    try:
        solver.solve()
    except StopIteration:
        return True

    return False
4169

4170

4171
def equivalent(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Checks whether `f1` and `f2` are equivalent by asking an `ISLaSolver` (with
    unsat support activated) for a solution of `-(f1 & f2 | -f1 & -f2)`.

    :param f1: The first formula.
    :param f2: The second formula.
    :param grammar: The grammar passed to the solver.
    :param timeout_seconds: The timeout passed to the solver.
    :return: True if `solve` raises a `StopIteration`; False if it returns a
        solution or raises any other `Exception`.
    """
    solver = ISLaSolver(
        grammar,
        -((f1 & f2) | (-f1 & -f2)),
        activate_unsat_support=True,
        timeout_seconds=timeout_seconds,
    )

    try:
        solver.solve()
    except StopIteration:
        return True
    except Exception:
        # Any other failure of the solver counts as "not equivalent."
        return False

    return False
4186

4187

4188
def generate_abstracted_trees(
    inp: DerivationTree, participating_paths: Set[Path]
) -> List[DerivationTree]:
    """
    Computes trees that are "abstracted," i.e., pruned, at prefixes of the paths
    specified in `participating_paths`. A subtree is pruned by replacing it with
    a fresh tree that only carries the value of the subtree's root.

    :param inp: The unabstracted input.
    :param participating_paths: The paths to abstract.
    :return: A list of the abstracted trees, without structural duplicates (by
        `structural_hash`), sorted by descending length (`len`).
    """

    def prefixes_longest_first(path):
        # All non-empty prefixes of the path; the empty path stands for itself.
        if not path:
            return ((),)
        return tuple(tuple(path[:end]) for end in range(len(path), 0, -1))

    parent_paths: Set[ImmutableList[Path]] = {
        prefixes_longest_first(path) for path in participating_paths
    }

    # Choose one prefix per participating path and abstract at any non-empty
    # combination of the chosen prefixes.
    abstraction_candidate_combinations: Set[ImmutableList[Path]] = set()
    for k in range(1, len(participating_paths) + 1):
        for paths in itertools.product(*parent_paths):
            for combination in itertools.combinations(paths, k):
                abstraction_candidate_combinations.add(
                    tuple(eliminate_suffixes(combination))
                )

    trees_by_structure: Dict[int, DerivationTree] = {}
    for paths_to_abstract in abstraction_candidate_combinations:
        replacements = {}
        for path_to_abstract in paths_to_abstract:
            subtree = inp.get_subtree(path_to_abstract)
            replacements[subtree] = DerivationTree(subtree.value)

        abstracted_tree = inp.substitute(replacements)
        trees_by_structure[abstracted_tree.structural_hash()] = abstracted_tree

    return sorted(trees_by_structure.values(), key=lambda tree: -len(tree))
1✔
4229

4230

4231
class EvaluatePredicateFormulasTransformer(NoopFormulaTransformer):
    """
    Formula transformer that evaluates structural predicate formulas on a given
    input tree and replaces SMT formulas by "true" where their instantiation
    permits it.
    """

    def __init__(self, inp: DerivationTree):
        """
        :param inp: The tree on which structural predicates are evaluated.
        """
        super().__init__()
        self.inp = inp

    def transform_predicate_formula(
        self, sub_formula: language.StructuralPredicateFormula
    ) -> language.Formula:
        """Replaces the predicate formula by the result of evaluating it."""
        if sub_formula.evaluate(self.inp):
            return sc.true()
        return sc.false()

    def transform_conjunctive_formula(
        self, sub_formula: language.ConjunctiveFormula
    ) -> language.Formula:
        """Rebuilds the conjunction from its arguments using `Formula.__and__`."""
        conjoin = language.Formula.__and__
        return reduce(conjoin, sub_formula.args)

    def transform_disjunctive_formula(
        self, sub_formula: language.DisjunctiveFormula
    ) -> language.Formula:
        """Rebuilds the disjunction from its arguments using `Formula.__or__`."""
        disjoin = language.Formula.__or__
        return reduce(disjoin, sub_formula.args)

    def transform_smt_formula(
        self, sub_formula: language.SMTFormula
    ) -> language.Formula:
        """
        Instantiates the SMT formula with its substitutions. If the result is
        true (or false inside a negation scope), the formula is replaced by
        "true." Otherwise, it is kept for later analysis.
        """
        instantiated_formula = copy.deepcopy(sub_formula)
        set_smt_auto_subst(instantiated_formula, True)
        set_smt_auto_eval(instantiated_formula, True)
        instantiated_formula = instantiated_formula.substitute_expressions(
            sub_formula.substitutions, force=True
        )

        assert instantiated_formula in {sc.true(), sc.false()}

        evaluates_to_true = instantiated_formula == sc.true()
        if evaluates_to_true ^ self.in_negation_scope:
            return sc.true()

        return sub_formula
        )
4272

4273

4274
def create_fixed_length_tree(
    start: DerivationTree | str,
    canonical_grammar: CanonicalGrammar,
    target_length: int,
) -> Optional[DerivationTree]:
    """Search for a closed derivation tree of exactly ``target_length``.

    The search is a depth-first exploration (LIFO stack) of partial trees
    rooted in ``start``. The tracked length of a partial tree counts the
    characters of its terminal children plus 1 for each open nonterminal
    that is not nullable.

    :param start: The root tree, or a symbol from which a root tree is built.
    :param canonical_grammar: The grammar providing the expansions.
    :param target_length: The length that the resulting tree must have.
    :return: The first tree found that has no open leaves and whose tracked
        length equals ``target_length``; ``None`` if the search space is
        exhausted without finding one.
    """
    nullable = compute_nullable_nonterminals(canonical_grammar)
    if isinstance(start, str):
        start = DerivationTree(start)

    def open_symbol_length(symbol) -> int:
        # An open nonterminal counts 1 unless it is nullable.
        return int(symbol not in nullable)

    # Entries: (partial tree, tracked length, open leaves with their paths).
    worklist: List[
        Tuple[DerivationTree, int, ImmutableList[Tuple[Path, DerivationTree]]]
    ] = [
        (start, open_symbol_length(start.value), (((), start),)),
    ]

    while worklist:
        tree, curr_len, open_leaves = worklist.pop()

        if not open_leaves:
            # Closed tree: either it is a hit, or this branch is dead.
            if curr_len == target_length:
                return tree
            continue

        if curr_len > target_length:
            # The tracked length already exceeds the target; prune.
            continue

        idx: int
        path: Path
        leaf: DerivationTree
        indexed_leaves = list(enumerate(open_leaves))
        for idx, (path, leaf) in indexed_leaves[::-1]:
            terminal_expansions, expansions = get_expansions(
                leaf.value, canonical_grammar
            )

            # Only choose one random terminal expansion; keep all nonterminal
            # expansions.
            # NOTE(review): this appends to the list returned by
            # `get_expansions` in place -- confirm that list is not shared.
            if terminal_expansions:
                expansions.append(random.choice(terminal_expansions))

            # Stable ascending sort by number of nonterminals; pushing in
            # reverse order leaves expansions with fewer nonterminals closer
            # to the top of the stack.
            ordered_expansions = sorted(
                expansions,
                key=lambda expansion: sum(
                    1 for elem in expansion if is_nonterminal(elem)
                ),
            )

            for expansion in ordered_expansions[::-1]:
                # Nonterminal children stay open (children `None`), terminal
                # children are closed (children `()`).
                new_children = tuple(
                    DerivationTree(elem, None if is_nonterminal(elem) else ())
                    for elem in expansion
                )

                expanded_tree = tree.replace_path(
                    path, DerivationTree(leaf.value, new_children)
                )

                children_len = 0
                for child in new_children:
                    if child.children == ():
                        children_len += len(child.value)
                    else:
                        children_len += open_symbol_length(child.value)

                # The expanded leaf no longer contributes; its children do.
                new_len = curr_len + children_len - open_symbol_length(leaf.value)

                # Replace the expanded leaf, in place, by its open children.
                new_open_leaves = tuple(
                    (path + (child_idx,), new_child)
                    for child_idx, new_child in enumerate(new_children)
                    if is_nonterminal(new_child.value)
                )

                worklist.append(
                    (
                        expanded_tree,
                        new_len,
                        open_leaves[:idx] + new_open_leaves + open_leaves[idx + 1 :],
                    )
                )

    return None
1✔
4362

4363

4364
def subtree_solutions(
    solution: Dict[language.Constant | DerivationTree, DerivationTree]
) -> Dict[language.Variable | DerivationTree, DerivationTree]:
    """Extend a solution by mappings for the subtrees of substituted trees.

    Entries keyed by a variable are copied as they are. For an entry keyed
    by a derivation tree, every subtree of the key whose path also exists
    in the substitution is mapped to the subtree of the substitution at
    that path, unless the result already had a mapping for it before this
    entry was processed.

    :param solution: The mapping from variables or trees to substitutions.
    :return: A new mapping additionally containing the subtree mappings.
    """
    expanded: Dict[language.Variable | DerivationTree, DerivationTree] = {}

    for original, replacement in solution.items():
        if isinstance(original, language.Variable):
            expanded[original] = replacement
        else:
            assert isinstance(
                original, DerivationTree
            ), f"Expected a DerivationTree, given: {type(original).__name__}"

            # Note: It can happen that a path in the original tree is not
            #       valid in the substitution, e.g., if we happen to replace
            #       a larger with a smaller tree.
            # The candidates are collected completely before the result is
            # updated, so the membership test does not see entries added for
            # this same original tree.
            candidates = []
            for path, subtree in original.paths():
                if subtree not in expanded and replacement.is_valid_path(path):
                    candidates.append((path, subtree))

            for path, subtree in candidates:
                expanded[subtree] = replacement.get_subtree(path)

    return expanded
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc