• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

rindPHI / isla / 7758526352

02 Feb 2024 04:08PM UTC coverage: 93.406% (-0.3%) from 93.737%
7758526352

Pull #90

github

web-flow
Merge d2edf2f3d into 3de029959
Pull Request #90: RepairSolver

690 of 774 new or added lines in 7 files covered. (89.15%)

2 existing lines in 2 files now uncovered.

6813 of 7294 relevant lines covered (93.41%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.54
/src/isla/solver.py
1
# Copyright © 2022 CISPA Helmholtz Center for Information Security.
2
# Author: Dominic Steinhöfel.
3
#
4
# This file is part of ISLa.
5
#
6
# ISLa is free software: you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation, either version 3 of the License, or
9
# (at your option) any later version.
10
#
11
# ISLa is distributed in the hope that it will be useful,
12
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14
# GNU General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with ISLa.  If not, see <http://www.gnu.org/licenses/>.
18

19
import copy
1✔
20
import functools
1✔
21
import heapq
1✔
22
import itertools
1✔
23
import logging
1✔
24
import math
1✔
25
import operator
1✔
26
import random
1✔
27
import sys
1✔
28
import time
1✔
29
from abc import ABC
1✔
30
from dataclasses import dataclass
1✔
31
from functools import reduce, lru_cache, partial
1✔
32
from typing import (
1✔
33
    Dict,
34
    List,
35
    Set,
36
    Optional,
37
    Tuple,
38
    Union,
39
    cast,
40
    Callable,
41
    Iterable,
42
    Sequence,
43
)
44

45
import pkg_resources
1✔
46
import z3
1✔
47
from grammar_graph import gg
1✔
48
from grammar_graph.gg import GrammarGraph
1✔
49
from grammar_to_regex.cfg2regex import RegexConverter
1✔
50
from grammar_to_regex.regex import regex_to_z3
1✔
51
from orderedset import OrderedSet
1✔
52
from packaging import version
1✔
53
from returns.converters import result_to_maybe
1✔
54
from returns.functions import compose, tap
1✔
55
from returns.maybe import Nothing, Some
1✔
56
from returns.pipeline import flow, is_successful
1✔
57
from returns.pointfree import lash
1✔
58
from returns.result import safe, Success
1✔
59

60
import isla.isla_shortcuts as sc
1✔
61
import isla.three_valued_truth
1✔
62
from isla import language
1✔
63
from isla.derivation_tree import DerivationTree
1✔
64
from isla.evaluator import (
1✔
65
    evaluate,
66
    quantified_formula_might_match,
67
    get_toplevel_quantified_formulas,
68
    eliminate_quantifiers,
69
)
70
from isla.evaluator import matches_for_quantified_formula
1✔
71
from isla.existential_helpers import (
1✔
72
    insert_tree,
73
    DIRECT_EMBEDDING,
74
    SELF_EMBEDDING,
75
    CONTEXT_ADDITION,
76
)
77
from isla.fuzzer import GrammarFuzzer, GrammarCoverageFuzzer, expansion_key
1✔
78
from isla.helpers import (
1✔
79
    delete_unreachable,
80
    shuffle,
81
    dict_of_lists_to_list_of_dicts,
82
    weighted_geometric_mean,
83
    assertions_activated,
84
    split_str_with_nonterminals,
85
    cluster_by_common_elements,
86
    is_nonterminal,
87
    canonical,
88
    lazyjoin,
89
    lazystr,
90
    Maybe,
91
    eliminate_suffixes,
92
    get_elem_by_equivalence,
93
    get_expansions,
94
    list_del,
95
    compute_nullable_nonterminals,
96
    eassert,
97
    merge_dict_of_sets,
98
    list_set,
99
)
100
from isla.isla_predicates import (
1✔
101
    STANDARD_STRUCTURAL_PREDICATES,
102
    STANDARD_SEMANTIC_PREDICATES,
103
    COUNT_PREDICATE,
104
)
105
from isla.language import (
1✔
106
    VariablesCollector,
107
    split_conjunction,
108
    split_disjunction,
109
    convert_to_nnf,
110
    ensure_unique_bound_variables,
111
    parse_isla,
112
    get_conjuncts,
113
    parse_bnf,
114
    ForallIntFormula,
115
    set_smt_auto_eval,
116
    NoopFormulaTransformer,
117
    set_smt_auto_subst,
118
    fresh_constant,
119
    to_dnf_clauses,
120
)
121
from isla.mutator import Mutator
1✔
122
from isla.parser import EarleyParser
1✔
123
from isla.type_defs import Grammar, Path, ImmutableList, CanonicalGrammar
1✔
124
from isla.z3_helpers import (
1✔
125
    z3_solve,
126
    z3_subst,
127
    z3_eq,
128
    z3_and,
129
    visit_z3_expr,
130
    smt_string_val_to_string,
131
    parent_relationships_in_z3_expr,
132
    numeric_intervals_from_regex,
133
    z3_or,
134
)
135

136

137
@dataclass(frozen=True)
class SolutionState:
    """
    An entry of the solver's work queue: a constraint paired with the derivation
    tree it refers to.

    Hashing and equality only consider :code:`constraint` and :code:`tree`; the
    :code:`level` field takes part in neither.
    """

    constraint: language.Formula
    tree: DerivationTree
    level: int = 0
    # Cache for the result of `__hash__`. Inside the class body, this name is
    # mangled to `_SolutionState__hash`.
    __hash: Optional[int] = None

    def formula_satisfied(
        self, grammar: Grammar
    ) -> isla.three_valued_truth.ThreeValuedTruth:
        """
        Evaluates the constraint of this state on its tree.

        :param grammar: The grammar passed on to the evaluator.
        :return: "Unknown" if the tree is still open; the result of
          :code:`evaluate` otherwise.
        """
        if self.tree.is_open():
            # Have to instantiate variables first
            return isla.three_valued_truth.ThreeValuedTruth.unknown()

        return evaluate(self.constraint, self.tree, grammar)

    def complete(self) -> bool:
        """
        :return: True iff the tree is complete and the constraint equals `true`.
        """
        if not self.tree.is_complete():
            return False

        # We assume that any universal quantifier has already been instantiated, if it
        # matches, and is thus satisfied, or another unsatisfied constraint resulted
        # from the instantiation. Existential, predicate, and SMT formulas have to be
        # eliminated first.

        return self.constraint == sc.true()

    # Less-than comparisons are needed for usage in the binary heap queue
    def __lt__(self, other: "SolutionState"):
        return hash(self) < hash(other)

    def __hash__(self):
        if self.__hash is None:
            result = hash((self.constraint, self.tree))
            # The class is frozen, so we have to bypass its `__setattr__`. The
            # attribute name is given as a string and therefore *not* mangled by the
            # compiler; we must spell out the mangled name ourselves. Writing to
            # "__hash" would create an unrelated attribute and leave the cache empty.
            object.__setattr__(self, "_SolutionState__hash", result)
            return result

        return self.__hash

    def __eq__(self, other):
        return (
            isinstance(other, SolutionState)
            and self.constraint == other.constraint
            and self.tree.structurally_equal(other.tree)
        )
182

183

184
@dataclass(frozen=True)
class CostWeightVector:
    """
    The weights used by the
    :class:`~isla.solver.GrammarBasedBlackboxCostComputer`. Objects of this class
    can be unpacked and indexed like a five-element tuple.
    """

    tree_closing_cost: float = 0
    constraint_cost: float = 0
    derivation_depth_penalty: float = 0
    low_k_coverage_penalty: float = 0
    low_global_k_path_coverage_penalty: float = 0

    def _weights(self) -> list:
        # The weights in declaration order, as a fresh list on each call.
        weights = [self.tree_closing_cost, self.constraint_cost]
        weights.append(self.derivation_depth_penalty)
        weights.append(self.low_k_coverage_penalty)
        weights.append(self.low_global_k_path_coverage_penalty)
        return weights

    def __iter__(self):
        """
        Enables tuple assignment for weight vectors:

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> a, b, c, d, e = v
        >>> (a, b, c, d, e)
        (1, 2, 3, 4, 5)

        :return: An iterator over the five weights, in declaration order.
        """
        return iter(self._weights())

    def __getitem__(self, item: int) -> float:
        """
        Enables tuple-like indexing of weight vectors:

        >>> v = CostWeightVector(1, 2, 3, 4, 5)
        >>> v[3]
        4

        :param item: A numeric index.
        :return: The weight at position :code:`item`.
        """
        assert isinstance(item, int)
        return self._weights()[item]
237

238

239
@dataclass(frozen=True)
class CostSettings:
    """
    Combines a :class:`CostWeightVector` with an integer parameter :code:`k`.
    """

    weight_vector: CostWeightVector
    # NOTE(review): presumably the "k" of the k-path coverage penalties -- confirm.
    k: int = 3

    def __init__(self, weight_vector: CostWeightVector, k: int = 3):
        """
        :param weight_vector: The weights; must be a :class:`CostWeightVector`.
        :param k: An integer parameter, 3 by default.
        """
        assert isinstance(weight_vector, CostWeightVector)
        assert isinstance(k, int)

        # Instances are frozen; the fields must be written past `__setattr__`.
        for field_name, field_value in (("weight_vector", weight_vector), ("k", k)):
            object.__setattr__(self, field_name, field_value)


# The settings `ISLaSolver` falls back to if no cost computer is supplied.
STD_COST_SETTINGS = CostSettings(
    weight_vector=CostWeightVector(
        tree_closing_cost=6.5,
        constraint_cost=1,
        derivation_depth_penalty=4,
        low_k_coverage_penalty=2,
        low_global_k_path_coverage_penalty=19,
    ),
    k=3,
)
261

262

263
class UnknownResultError(Exception):
    """
    Raised if the truth value of a constraint could not be determined, i.e., if
    evaluating it yielded "unknown."

    This is deliberately a plain exception class. A frozen dataclass rejects every
    Python-level attribute assignment, including assignments to
    :code:`__traceback__`, and generates a constructor without parameters such
    that no error message can be passed.
    """
1✔
266

267

268
class SemanticError(Exception):
    """
    Raised if an input conforms to the grammar, but does not satisfy the
    constraint.

    This is deliberately a plain exception class. A frozen dataclass rejects every
    Python-level attribute assignment, including assignments to
    :code:`__traceback__`, and generates a constructor without parameters such
    that no error message can be passed.
    """
1✔
271

272

273
@dataclass(frozen=True)
class SolverDefaults:
    """
    The default values of the optional parameters of the constructor of
    :class:`ISLaSolver`. The meaning of the individual fields is explained in the
    documentation of that constructor.
    """

    # Formula and the predicates used for parsing it
    formula: Optional[language.Formula | str] = None
    structural_predicates: frozenset[language.StructuralPredicate] = (
        STANDARD_STRUCTURAL_PREDICATES
    )
    semantic_predicates: frozenset[language.SemanticPredicate] = (
        STANDARD_SEMANTIC_PREDICATES
    )

    # Bounds on the number of produced instantiations / insertion results
    max_number_free_instantiations: int = 10
    max_number_smt_instantiations: int = 10
    max_number_tree_insertion_results: int = 5

    # Queue handling, debugging, cost computation, and timeout
    enforce_unique_trees_in_queue: bool = False
    debug: bool = False
    cost_computer: Optional["CostComputer"] = None
    timeout_seconds: Optional[int] = None

    # Fuzzer configuration
    global_fuzzer: bool = False
    predicates_unique_in_int_arg: Tuple[language.SemanticPredicate, ...] = (
        COUNT_PREDICATE,
    )
    fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = (
        lambda grammar: GrammarCoverageFuzzer(grammar)
    )

    # Tree insertion and unsatisfiability support
    tree_insertion_methods: Optional[int] = None
    activate_unsat_support: bool = False

    # SMT solver interaction
    grammar_unwinding_threshold: int = 4
    enable_optimized_z3_queries: bool = True

    # Starting point of the search
    initial_tree: Maybe[DerivationTree] = Nothing
    start_symbol: Optional[str] = None


# Shared instance supplying the default arguments of `ISLaSolver.__init__`.
_DEFAULTS = SolverDefaults()
1✔
305

306

307
class ISLaSolver:
1✔
308
    """
1✔
309
    The solver class for ISLa formulas/constraints. Its top-level methods are
310

311
    :meth:`~isla.solver.ISLaSolver.solve`
312
      Use to generate solutions for an ISLa constraint.
313

314
    :meth:`~isla.solver.ISLaSolver.check`
315
      Use to check if an ISLa constraint is satisfied for a given input.
316

317
    :meth:`~isla.solver.ISLaSolver.parse`
318
      Use to parse and validate an input.
319

320
    :meth:`~isla.solver.ISLaSolver.repair`
321
      Use to repair an input such that it satisfies a constraint.
322

323
    :meth:`~isla.solver.ISLaSolver.mutate`
324
      Use to mutate an input such that the result satisfies a constraint.
325
    """
326

327
    def __init__(
        self,
        grammar: Grammar | str,
        formula: Optional[language.Formula | str] = _DEFAULTS.formula,
        structural_predicates: Set[
            language.StructuralPredicate
        ] = _DEFAULTS.structural_predicates,
        semantic_predicates: Set[
            language.SemanticPredicate
        ] = _DEFAULTS.semantic_predicates,
        max_number_free_instantiations: int = _DEFAULTS.max_number_free_instantiations,
        max_number_smt_instantiations: int = _DEFAULTS.max_number_smt_instantiations,
        max_number_tree_insertion_results: int = _DEFAULTS.max_number_tree_insertion_results,
        enforce_unique_trees_in_queue: bool = _DEFAULTS.enforce_unique_trees_in_queue,
        debug: bool = _DEFAULTS.debug,
        cost_computer: Optional["CostComputer"] = _DEFAULTS.cost_computer,
        timeout_seconds: Optional[int] = _DEFAULTS.timeout_seconds,
        global_fuzzer: bool = _DEFAULTS.global_fuzzer,
        predicates_unique_in_int_arg: Tuple[
            language.SemanticPredicate, ...
        ] = _DEFAULTS.predicates_unique_in_int_arg,
        fuzzer_factory: Callable[[Grammar], GrammarFuzzer] = _DEFAULTS.fuzzer_factory,
        tree_insertion_methods: Optional[int] = _DEFAULTS.tree_insertion_methods,
        activate_unsat_support: bool = _DEFAULTS.activate_unsat_support,
        grammar_unwinding_threshold: int = _DEFAULTS.grammar_unwinding_threshold,
        initial_tree: Maybe[DerivationTree] = _DEFAULTS.initial_tree,
        enable_optimized_z3_queries: bool = _DEFAULTS.enable_optimized_z3_queries,
        start_symbol: Optional[str] = _DEFAULTS.start_symbol,
    ):
        """
        The constructor of :class:`~isla.solver.ISLaSolver` accepts a large number of
        parameters. However, all but the first one, :code:`grammar`, are *optional.*

        The simplest way to construct an ISLa solver is by providing it with a
        grammar only; it then works like a grammar fuzzer.

        >>> import random
        >>> random.seed(1)

        >>> import string
        >>> LANG_GRAMMAR = {
        ...     "<start>":
        ...         ["<stmt>"],
        ...     "<stmt>":
        ...         ["<assgn> ; <stmt>", "<assgn>"],
        ...     "<assgn>":
        ...         ["<var> := <rhs>"],
        ...     "<rhs>":
        ...         ["<var>", "<digit>"],
        ...     "<var>": list(string.ascii_lowercase),
        ...     "<digit>": list(string.digits)
        ... }
        >>>
        >>> from isla.solver import ISLaSolver
        >>> solver = ISLaSolver(LANG_GRAMMAR)
        >>>
        >>> str(solver.solve())
        'd := 9'
        >>> str(solver.solve())
        'v := n ; s := r'

        :param grammar: The underlying grammar; either, as a "Fuzzing Book" dictionary
          or in BNF syntax.
        :param formula: The formula to solve; either a string or a readily parsed
          formula. If no formula is given, a default `true` constraint is assumed, and
          the solver falls back to a grammar fuzzer. The number of produced solutions
          will then be bound by `max_number_free_instantiations`.
        :param structural_predicates: Structural predicates to use when parsing a
          formula.
        :param semantic_predicates: Semantic predicates to use when parsing a formula.
        :param max_number_free_instantiations: Number of times that nonterminals that
          are not bound by any formula should be expanded by a coverage-based fuzzer.
        :param max_number_smt_instantiations: Number of solutions of SMT formulas that
          should be produced.
        :param max_number_tree_insertion_results: The maximum number of results when
          solving existential quantifiers by tree insertion.
        :param enforce_unique_trees_in_queue: If true, states with the same tree as an
          already existing tree in the queue are discarded, irrespectively of the
          constraint.
        :param debug: If true, debug information about the evolution of states is
          collected, notably in the field state_tree. The root of the tree is in the
          field state_tree_root. The field costs stores the computed cost values for
          all new nodes.
        :param cost_computer: The `CostComputer` class for computing the cost relevant
          to placing states in ISLa's queue.
        :param timeout_seconds: Number of seconds after which the solver will terminate.
        :param global_fuzzer: If set to True, only one coverage-guided grammar fuzzer
          object is used to finish off unconstrained open derivation trees throughout
          the whole generation time. This may be beneficial for some targets; e.g., we
          experienced that CSV works significantly faster. However, the achieved k-path
          coverage can be lower with that setting.
        :param predicates_unique_in_int_arg: This is needed in certain cases for
          instantiating universal integer quantifiers. The supplied predicates should
          have exactly one integer argument, and hold for exactly one integer value
          once all other parameters are fixed.
        :param fuzzer_factory: Constructor of the fuzzer to use for instantiating
          "free" nonterminals.
        :param tree_insertion_methods: Combination of methods to use for existential
          quantifier elimination by tree insertion. Full selection: `DIRECT_EMBEDDING +
          SELF_EMBEDDING + CONTEXT_ADDITION`.
        :param activate_unsat_support: Set to True if you assume that a formula might
          be unsatisfiable. This triggers additional tests for unsatisfiability that
          reduce input generation performance, but might ensure termination (with a
          negative solver result) for unsatisfiable problems for which the solver could
          otherwise diverge.
        :param grammar_unwinding_threshold: When querying the SMT solver, ISLa passes a
          regular expression for the syntax of the involved nonterminals. If this
          syntax is not regular, we unwind the respective part in the reference grammar
          up to a depth of `grammar_unwinding_threshold`. If this is too shallow, it can
          happen that an equation etc. cannot be solved; if it is too deep, it can
          negatively impact performance (and quite tremendously so).
        :param initial_tree: An initial input tree for the queue, if the solver shall
          not start from the tree `(<start>, None)`.
        :param enable_optimized_z3_queries: Enables preprocessing of Z3 queries (mainly
          numeric problems concerning things like length). This can improve performance
          significantly; however, it might happen that certain problems cannot be solved
          anymore. In that case, this option can/should be deactivated.
        :param start_symbol: This is an alternative to `initial_tree` for starting with
          a start symbol different from `<start>`. If `start_symbol` is provided, a tree
          consisting of a single root node with the value of `start_symbol` is chosen as
          initial tree.
        """
        self.logger = logging.getLogger(type(self).__name__)

        # We require at least z3 4.8.13.0. ISLa might work for some older versions, but
        # at least for 4.8.8.0, we have witnessed that certain rather easy constraints,
        # like equations, don't work since z3 cannot handle the restrictions to more
        # complicated regular expressions. This happened in the XML case study with
        # constraints of the kind <id> == <id_no_prefix>. At least with 4.8.13.0, this
        # works flawlessly, but times out for 4.8.8.0.
        #
        # This should be solved using Python requirements, which however cannot be done
        # currently since the fuzzingbook library inflexibly binds z3 to 4.8.8.0. Thus,
        # one has to manually install a newer version and ignore the warning.

        z3_version = pkg_resources.get_distribution("z3-solver").version
        assert version.parse(z3_version) >= version.parse("4.8.13.0"), (
            f"ISLa requires at least z3 4.8.13.0, present: {z3_version}. "
            "Please install a newer z3 version, e.g., using 'pip install z3-solver==4.8.14.0'."
        )

        # Never work on the caller's grammar object: it is modified below if a start
        # symbol is given.
        if isinstance(grammar, str):
            self.grammar = parse_bnf(grammar)
        else:
            self.grammar = copy.deepcopy(grammar)

        assert start_symbol is None or not is_successful(
            initial_tree
        ), "You cannot supply a start symbol *and* an initial tree."

        if start_symbol is not None:
            # Redirect <start> to the chosen symbol and drop everything that became
            # unreachable.
            self.grammar |= {"<start>": [start_symbol]}
            self.grammar = delete_unreachable(self.grammar)

        self.graph = GrammarGraph.from_grammar(self.grammar)
        self.canonical_grammar = canonical(self.grammar)
        self.timeout_seconds = timeout_seconds
        self.start_time: Optional[int] = None
        self.global_fuzzer = global_fuzzer
        self.fuzzer = fuzzer_factory(self.grammar)
        self.fuzzer_factory = fuzzer_factory
        self.predicates_unique_in_int_arg: Set[language.SemanticPredicate] = set(
            predicates_unique_in_int_arg
        )
        self.grammar_unwinding_threshold = grammar_unwinding_threshold
        self.enable_optimized_z3_queries = enable_optimized_z3_queries

        if activate_unsat_support and tree_insertion_methods is None:
            # Unsat support without an explicit choice: disable tree insertion.
            self.tree_insertion_methods = 0
        else:
            if activate_unsat_support:
                assert tree_insertion_methods is not None
                print(
                    "With activate_unsat_support set, a 0 value for tree_insertion_methods is recommended, "
                    f"the current value is: {tree_insertion_methods}",
                    file=sys.stderr,
                )

            self.tree_insertion_methods = (
                DIRECT_EMBEDDING + SELF_EMBEDDING + CONTEXT_ADDITION
            )
            if tree_insertion_methods is not None:
                self.tree_insertion_methods = tree_insertion_methods

        self.activate_unsat_support = activate_unsat_support
        self.currently_unsat_checking: bool = False

        self.cost_computer = (
            cost_computer
            if cost_computer is not None
            else GrammarBasedBlackboxCostComputer(STD_COST_SETTINGS, self.graph)
        )

        formula = (
            sc.true()
            if formula is None
            else (
                parse_isla(
                    formula, self.grammar, structural_predicates, semantic_predicates
                )
                if isinstance(formula, str)
                else formula
            )
        )

        self.formula = ensure_unique_bound_variables(formula)

        top_constants: Set[language.Constant] = {
            c
            for c in VariablesCollector.collect(self.formula)
            if isinstance(c, language.Constant) and not c.is_numeric()
        }

        assert len(top_constants) <= 1, (
            "ISLa only accepts up to one constant (free variable), "
            + f'found {len(top_constants)}: {", ".join(map(str, top_constants))}'
        )

        # If a start symbol was given, the top-level constant gets that symbol as its
        # type; otherwise, it keeps its own type.
        only_top_constant = Maybe.from_optional(next(iter(top_constants), None))
        self.top_constant = only_top_constant.map(
            lambda c: language.Constant(
                c.name, Maybe.from_optional(start_symbol).value_or(c.n_type)
            )
        )

        if only_top_constant != self.top_constant:
            assert is_successful(only_top_constant)
            assert is_successful(self.top_constant)
            self.formula = self.formula.substitute_variables(
                {only_top_constant.unwrap(): self.top_constant.unwrap()}
            )

        quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            tuple([f for f in c if isinstance(f, language.ForallFormula)])
            for c in get_quantifier_chains(self.formula)
        ]
        # TODO: Remove?
        self.quantifier_chains: List[Tuple[language.ForallFormula, ...]] = [
            c for c in quantifier_chains if c
        ]

        self.max_number_free_instantiations: int = max_number_free_instantiations
        self.max_number_smt_instantiations: int = max_number_smt_instantiations
        self.max_number_tree_insertion_results = max_number_tree_insertion_results
        self.enforce_unique_trees_in_queue = enforce_unique_trees_in_queue

        # Initialize Queue
        # Order of preference for the initial tree: (1) the supplied `initial_tree`,
        # (2) a root node labeled with `start_symbol`, (3) a root node labeled with
        # the type of the top-level constant or, if there is none, with "<start>".
        self.initial_tree = (
            initial_tree.lash(
                lambda _: Maybe.from_optional(start_symbol)
                .map(lambda s: eassert(s, s in self.grammar))
                .map(lambda s: DerivationTree(s, None))
            ).lash(
                lambda _: Some(
                    DerivationTree(
                        self.top_constant.map(lambda c: c.n_type).value_or("<start>"),
                        None,
                    )
                )
            )
        ).unwrap()

        initial_formula = self.top_constant.map(
            lambda c: self.formula.substitute_expressions({c: self.initial_tree})
        ).value_or(self.formula)
        initial_state = SolutionState(initial_formula, self.initial_tree)
        initial_states = self.establish_invariant(initial_state)

        self.queue: List[Tuple[float, SolutionState]] = []
        self.tree_hashes_in_queue: Set[int] = {self.initial_tree.structural_hash()}
        self.state_hashes_in_queue: Set[int] = {hash(state) for state in initial_states}
        for state in initial_states:
            heapq.heappush(self.queue, (self.compute_cost(state), state))

        self.seen_coverages: Set[str] = set()
        self.current_level: int = 0
        self.step_cnt: int = 0
        self.last_cost_recomputation: int = 0

        self.regex_cache = {}

        self.solutions: List[DerivationTree] = []

        # Debugging stuff
        self.debug = debug
        self.state_tree: Dict[
            SolutionState, List[SolutionState]
        ] = {}  # is only filled if self.debug
        self.state_tree_root = None
        self.current_state = None
        self.costs: Dict[SolutionState, float] = {}

        if self.debug:
            self.state_tree[initial_state] = initial_states
            self.state_tree_root = initial_state
            self.costs[initial_state] = 0
            for state in initial_states:
                self.costs[state] = self.compute_cost(state)
1✔
627

628
    def solve(self) -> DerivationTree:
        """
        Attempts to compute a solution to the given ISLa formula. Returns that solution,
        if any. This function can be called repeatedly to obtain more solutions until
        one of two exception types is raised: A :class:`StopIteration` indicates that
        no more solution can be found; a :class:`TimeoutError` is raised if a timeout
        occurred. After that, an exception will be raised every time.

        The timeout can be controlled by the :code:`timeout_seconds`
        :meth:`constructor <isla.solver.ISLaSolver.__init__>` parameter.

        :return: A solution for the ISLa formula passed to the
          :class:`isla.solver.ISLaSolver`.
        :raises StopIteration: If the queue is exhausted and no solution is left.
        :raises TimeoutError: If more than :code:`timeout_seconds` seconds have passed
          since the first call of this method.
        """
        if self.timeout_seconds is not None and self.start_time is None:
            # The clock starts ticking at the first call and is never reset.
            self.start_time = int(time.time())

        def pop_solution() -> DerivationTree:
            # Hands out the oldest solution that has not been returned yet.
            solution = self.solutions.pop(0)
            self.logger.debug('Found solution "%s"', solution)
            return solution

        def process_and_extend_solutions(
            result_states: List[SolutionState],
        ) -> Nothing:
            assert result_states is not None
            self.solutions.extend(self.process_new_states(result_states))
            return Nothing

        # The elimination functions, in the order in which they are tried.
        elimination_functions = [
            self.noop_on_false_constraint,
            self.eliminate_existential_integer_quantifiers,
            self.instantiate_universal_integer_quantifiers,
            self.match_all_universal_formulas,
            self.expand_to_match_quantifiers,
            self.eliminate_all_semantic_formulas,
            self.eliminate_all_ready_semantic_predicate_formulas,
            self.eliminate_and_match_first_existential_formula_and_expand,
            self.assert_remaining_formulas_are_lazy_binding_semantic,
            self.finish_unconstrained_trees,
            self.expand,
        ]

        while self.queue:
            self.step_cnt += 1

            # import dill as pickle
            # state_hash = 9107154106757938105
            # out_file = "/tmp/saved_debug_state"
            # if hash(self.queue[0][1]) == state_hash:
            #     with open(out_file, 'wb') as debug_state_file:
            #         pickle.dump(self, debug_state_file)
            #     print(f"Dumping state to {out_file}")
            #     exit()

            if self.timeout_seconds is not None:
                if int(time.time()) - self.start_time > self.timeout_seconds:
                    self.logger.debug("TIMEOUT")
                    raise TimeoutError(self.timeout_seconds)

            if self.solutions:
                return pop_solution()

            cost: float
            state: SolutionState
            cost, state = heapq.heappop(self.queue)

            self.current_level = state.level
            self.tree_hashes_in_queue.discard(state.tree.structural_hash())
            self.state_hashes_in_queue.discard(hash(state))

            if self.debug:
                self.current_state = state
                self.state_tree.setdefault(state, [])
            self.logger.debug(
                "Polling new state (%s, %s) (hash %d, cost %f)",
                state.constraint,
                state.tree.to_string(show_open_leaves=True, show_ids=True),
                hash(state),
                cost,
            )
            self.logger.debug("Queue length: %s", len(self.queue))

            assert not isinstance(state.constraint, language.DisjunctiveFormula)

            # Instantiate all top-level structural predicate formulas.
            state = self.instantiate_structural_predicates(state)

            # Apply the first elimination function that is applicable.
            # The later ones are ignored: Each function is wrapped with `lash`, such
            # that it is only called as long as the pipeline value is still `Nothing`.
            flow(
                Nothing,
                *map(
                    compose(lambda f: (lambda _: f(state)), lash),
                    elimination_functions,
                ),
            ).bind(process_and_extend_solutions)

        if self.solutions:
            return pop_solution()
        else:
            self.logger.debug("UNSAT")
            raise StopIteration()
1✔
728

729
    def check(self, inp: DerivationTree | str) -> bool:
        """
        Checks whether the given input satisfies the constraint passed to the solver.
        String inputs are checked by parsing them; derivation trees are evaluated
        directly. An `UnknownResultError` is raised if the evaluation of a derivation
        tree is inconclusive (e.g., because of a solver timeout or a semantic
        predicate that cannot be evaluated).

        :param inp: The input to evaluate, either readily parsed or as a string.
        :return: A truth value.
        """
        if isinstance(inp, str):
            # `parse` signals syntactic and semantic problems by exceptions.
            try:
                self.parse(inp)
            except (SyntaxError, SemanticError):
                return False

            return True

        assert isinstance(inp, DerivationTree)

        verdict = evaluate(self.formula, inp, self.grammar)
        if verdict.is_unknown():
            raise UnknownResultError()

        return bool(verdict)
1✔
754

755
    def parse(
        self,
        inp: str,
        nonterminal: str = "<start>",
        skip_check: bool = False,
        silent: bool = False,
    ) -> DerivationTree:
        """
        Turns the string `inp` into a `DerivationTree` using the solver's grammar.

        Raises a `SyntaxError` if the input does not satisfy the grammar and a
        `SemanticError` if it does not satisfy the constraint. The constraint is
        only checked if `nonterminal` is "<start>" and `skip_check` is not set.

        :param inp: The input to parse.
        :param nonterminal: The nonterminal to start parsing with, if a string
          corresponding to a sub-grammar shall be parsed. We don't check semantic
          correctness in that case.
        :param skip_check: If True, the semantic check is left out.
        :param silent: If True, no error is sent to the log stream in case of a
            failed parse.
        :return: A parsed `DerivationTree`.
        """
        parses_sub_grammar = nonterminal != "<start>"

        # The grammar may be modified below, so we work on a private copy.
        grammar = copy.deepcopy(self.grammar)
        if parses_sub_grammar:
            # Redirect the start symbol to the requested nonterminal and drop
            # everything that is no longer reachable.
            grammar |= {"<start>": [nonterminal]}
            grammar = delete_unreachable(grammar)

        parser = EarleyParser(grammar)
        try:
            parse_tree = next(parser.parse(inp))
            if parses_sub_grammar:
                # Skip the artificial "<start>" root; its only child is the
                # tree for `nonterminal`.
                parse_tree = parse_tree[1][0]
            tree = DerivationTree.from_parse_tree(parse_tree)
        except SyntaxError as syntax_error:
            if not silent:
                self.logger.error(
                    f'Error parsing "{inp}" starting with "{nonterminal}"'
                )
            raise syntax_error

        if skip_check or parses_sub_grammar:
            return tree

        if not self.check(tree):
            raise SemanticError()

        return tree
799

800
    def repair(
1✔
801
        self, inp: DerivationTree | str, fix_timeout_seconds: float = 3
802
    ) -> Maybe[DerivationTree]:
803
        """
804
        Attempts to repair the given input. The input must not violate syntactic
805
        (grammar) constraints. If semantic constraints are violated, this method
806
        gradually abstracts the input and tries to turn it into a valid one.
807
        Note that intensive structural manipulations are not performed; we merely
808
        try to satisfy SMT-LIB and semantic predicate constraints.
809

810
        :param inp: The input to fix.
811
        :param fix_timeout_seconds: A timeout used when calling the solver for an
812
          abstracted input. Usually, a low timeout suffices.
813
        :return: A fixed input (or the original, if it was not broken) or nothing.
814
        """
815

816
        inp = self.parse(inp, skip_check=True) if isinstance(inp, str) else inp
1✔
817

818
        try:
1✔
819
            if self.check(inp) or not is_successful(self.top_constant):
1✔
820
                return Some(inp)
1✔
821
        except UnknownResultError:
1✔
822
            pass
1✔
823

824
        formula = self.top_constant.map(
1✔
825
            lambda c: self.formula.substitute_expressions({c: inp})
826
        ).unwrap()
827

828
        set_smt_auto_eval(formula, False)
1✔
829
        set_smt_auto_subst(formula, False)
1✔
830

831
        qfr_free = eliminate_quantifiers(
1✔
832
            formula,
833
            grammar=self.grammar,
834
            numeric_constants={
835
                c
836
                for c in VariablesCollector.collect(formula)
837
                if isinstance(c, language.Constant) and c.is_numeric()
838
            },
839
        )
840

841
        # We first evaluate all structural predicates; for now, we do not interfere
842
        # with structure.
843
        semantic_only = qfr_free.transform(EvaluatePredicateFormulasTransformer(inp))
1✔
844

845
        if semantic_only == sc.false():
1✔
846
            # This cannot be repaired while preserving structure; for existential
847
            # problems, we could try tree insertion. We leave this for future work.
848
            return Nothing
1✔
849

850
        # We try to satisfy any of the remaining disjunctive elements, in random order
851
        for formula_to_satisfy in shuffle(split_disjunction(semantic_only)):
1✔
852
            # Now, we consider all combinations of 1, 2, ... of the derivation trees
853
            # participating in the formula. We successively prune deeper and deeper
854
            # subtrees until the resulting input evaluates to "unknown" for the given
855
            # formula.
856

857
            participating_paths = {
1✔
858
                inp.find_node(arg) for arg in formula_to_satisfy.tree_arguments()
859
            }
860

861
            def do_complete(tree: DerivationTree) -> Maybe[DerivationTree]:
1✔
862
                return result_to_maybe(
1✔
863
                    safe(
864
                        self.copy_without_queue(
865
                            initial_tree=Some(tree),
866
                            timeout_seconds=Some(fix_timeout_seconds),
867
                        ).solve,
868
                        (UnknownResultError, TimeoutError, StopIteration),
869
                    )()
870
                )
871

872
            # If p1, p2 are in participating_paths, then we consider the following
873
            # path combinations (roughly) in the listed order:
874
            # {p1}, {p2}, {p1, p2}, {p1[:-1]}, {p2[:-1]}, {p1[:-1], p2}, {p1, p2[:-1]},
875
            # {p1[:-1], p2[:-1]}, ...
876
            for abstracted_tree in generate_abstracted_trees(inp, participating_paths):
1✔
877
                match (
1✔
878
                    safe(lambda: self.check(abstracted_tree))()
879
                    .bind(lambda _: Nothing)
880
                    .lash(
881
                        lambda exc: Some(abstracted_tree)
882
                        if isinstance(exc, UnknownResultError)
883
                        else Nothing
884
                    )
885
                    .bind(do_complete)
886
                ):
887
                    case Some(completed):
1✔
888
                        return Some(completed)
1✔
889
                    case _:
1✔
890
                        pass
1✔
891

892
        return Nothing
×
893

894
    def mutate(
        self,
        inp: DerivationTree | str,
        min_mutations: int = 2,
        max_mutations: int = 5,
        fix_timeout_seconds: float = 1,
    ) -> DerivationTree:
        """
        Mutates `inp` such that the result satisfies the constraint.

        Mutants that are structurally equal to `inp` are skipped; every other
        mutant is passed to `repair`. The loop only ends once a repair succeeds,
        i.e., there is no upper bound on the number of attempts.

        :param inp: The input to mutate. Strings are parsed first (without the
          semantic check).
        :param min_mutations: The minimum number of mutation steps to perform.
        :param max_mutations: The maximum number of mutation steps to perform.
        :param fix_timeout_seconds: A timeout used when calling the solver for fixing
          an abstracted input. Usually, a low timeout suffices.
        :return: A mutated input.
        """

        if isinstance(inp, str):
            inp = self.parse(inp, skip_check=True)

        mutator = Mutator(
            self.grammar,
            min_mutations=min_mutations,
            max_mutations=max_mutations,
            graph=self.graph,
        )

        while True:
            mutated = mutator.mutate(inp)
            if not mutated.structurally_equal(inp):
                maybe_fixed = self.repair(mutated, fix_timeout_seconds)
                if is_successful(maybe_fixed):
                    return maybe_fixed.unwrap()
927

928
    def copy_without_queue(
        self,
        grammar: Maybe[Grammar | str] = Nothing,
        formula: Maybe[language.Formula | str] = Nothing,
        max_number_free_instantiations: Maybe[int] = Nothing,
        max_number_smt_instantiations: Maybe[int] = Nothing,
        max_number_tree_insertion_results: Maybe[int] = Nothing,
        enforce_unique_trees_in_queue: Maybe[bool] = Nothing,
        debug: Maybe[bool] = Nothing,
        cost_computer: Maybe["CostComputer"] = Nothing,
        timeout_seconds: Maybe[int] = Nothing,
        global_fuzzer: Maybe[bool] = Nothing,
        predicates_unique_in_int_arg: Maybe[
            Tuple[language.SemanticPredicate, ...]
        ] = Nothing,
        fuzzer_factory: Maybe[Callable[[Grammar], GrammarFuzzer]] = Nothing,
        tree_insertion_methods: Maybe[int] = Nothing,
        activate_unsat_support: Maybe[bool] = Nothing,
        grammar_unwinding_threshold: Maybe[int] = Nothing,
        initial_tree: Maybe[DerivationTree] = Nothing,
        enable_optimized_z3_queries: Maybe[bool] = Nothing,
        start_symbol: Optional[str] = None,
    ):
        """
        Creates a new `ISLaSolver` that is configured like this one, except for
        the settings explicitly overridden by the caller.

        Each `Maybe` parameter replaces the corresponding attribute of this solver
        if it carries a value; otherwise, the attribute of this solver is used.
        `initial_tree` and `start_symbol` are exceptions: they are passed to the
        constructor exactly as given and do not fall back to attributes of this
        solver. The new solver shares this solver's `regex_cache` object.

        :return: The newly created solver.
        """
        settings = {
            "grammar": grammar.value_or(self.grammar),
            "formula": formula.value_or(self.formula),
            "max_number_free_instantiations": (
                max_number_free_instantiations.value_or(
                    self.max_number_free_instantiations
                )
            ),
            "max_number_smt_instantiations": (
                max_number_smt_instantiations.value_or(
                    self.max_number_smt_instantiations
                )
            ),
            "max_number_tree_insertion_results": (
                max_number_tree_insertion_results.value_or(
                    self.max_number_tree_insertion_results
                )
            ),
            "enforce_unique_trees_in_queue": (
                enforce_unique_trees_in_queue.value_or(
                    self.enforce_unique_trees_in_queue
                )
            ),
            "debug": debug.value_or(self.debug),
            "cost_computer": cost_computer.value_or(self.cost_computer),
            "timeout_seconds": timeout_seconds.value_or(self.timeout_seconds),
            "global_fuzzer": global_fuzzer.value_or(self.global_fuzzer),
            "predicates_unique_in_int_arg": (
                predicates_unique_in_int_arg.value_or(
                    self.predicates_unique_in_int_arg
                )
            ),
            "fuzzer_factory": fuzzer_factory.value_or(self.fuzzer_factory),
            "tree_insertion_methods": tree_insertion_methods.value_or(
                self.tree_insertion_methods
            ),
            "activate_unsat_support": activate_unsat_support.value_or(
                self.activate_unsat_support
            ),
            "grammar_unwinding_threshold": grammar_unwinding_threshold.value_or(
                self.grammar_unwinding_threshold
            ),
            # Passed through as-is (still wrapped in `Maybe`).
            "initial_tree": initial_tree,
            "enable_optimized_z3_queries": enable_optimized_z3_queries.value_or(
                self.enable_optimized_z3_queries
            ),
            # Passed through as-is; not taken from this solver.
            "start_symbol": start_symbol,
        }

        solver_copy = ISLaSolver(**settings)

        # The regex cache object is shared with the copy, not duplicated.
        solver_copy.regex_cache = self.regex_cache

        return solver_copy
993

994
    @staticmethod
    def noop_on_false_constraint(
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Handles states whose constraint is `false`.

        :param state: The state to inspect.
        :return: The unchanged state, wrapped as a singleton list, if its
          constraint equals `sc.false()`, and `Nothing` otherwise.
        """
        # A state with a false constraint can be silently discarded; we return
        # it unchanged rather than applying any elimination step to it.
        return Some([state]) if state.constraint == sc.false() else Nothing
1003

1004
    def expand_to_match_quantifiers(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Expands the tree of `state` if its constraint has a universal quantifier
        among its conjuncts.

        :param state: The state to expand.
        :return: The result of `expand_tree` for `state`, or `Nothing` if no
          conjunct of the constraint is a `ForallFormula`.
        """
        has_universal_conjunct = any(
            isinstance(conjunct, language.ForallFormula)
            for conjunct in get_conjuncts(state.constraint)
        )
        if not has_universal_conjunct:
            return Nothing

        successors = self.expand_tree(state)

        assert len(successors) > 0, f"State {state} will never leave the queue."
        self.logger.debug("Expanding state %s (%d successors)", state, len(successors))

        return Some(successors)
1022

1023
    def eliminate_and_match_first_existential_formula_and_expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Combines the result of `eliminate_and_match_first_existential_formula`
        with a few tree expansions of the original state.

        :param state: The state to process.
        :return: `Nothing` if the elimination step returns `None`; otherwise, the
          elimination results followed by the results of
          `expand_tree(state, limit=2, only_universal=False)`.
        """
        eliminated = self.eliminate_and_match_first_existential_formula(state)
        if eliminated is None:
            return Nothing

        # Also add some expansions of the original state, to create a larger
        # solution stream (otherwise, it might be possible that only a small
        # finite number of solutions are generated for existential formulas).
        additional_expansions = self.expand_tree(state, limit=2, only_universal=False)

        return Some(eliminated + additional_expansions)
1037

1038
    def assert_remaining_formulas_are_lazy_binding_semantic(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Sanity check on the shape of the remaining constraint. This method never
        produces successor states.

        :param state: The state whose constraint is inspected.
        :return: Always `Nothing`.
        :raises AssertionError: If the constraint is not `true` and has a conjunct
          that is neither a semantic predicate formula nor a negated one.
        """

        def is_possibly_negated_semantic_predicate(formula) -> bool:
            if isinstance(formula, language.SemanticPredicateFormula):
                return True

            return isinstance(formula, language.NegatedFormula) and isinstance(
                formula.args[0], language.SemanticPredicateFormula
            )

        # SEMANTIC PREDICATE FORMULAS can remain if they bind lazily. In that case,
        # we can choose a random instantiation and let the predicate "fix" the
        # resulting tree.
        assert state.constraint == sc.true() or all(
            is_possibly_negated_semantic_predicate(conjunct)
            for conjunct in get_conjuncts(state.constraint)
        ), (
            "Constraint is not true and contains formulas "
            f"other than semantic predicate formulas: {state.constraint}"
        )

        # NOTE(review): Because of its trailing `for _, leaf in ...` clause, the
        # parenthesized expression below is a generator expression. A generator
        # object is always truthy, so this assertion can never fail, and the
        # checks inside it are never executed. It is kept as-is to preserve the
        # current behavior; the intended check (and whether the two `all(...)`
        # parts should be combined with `or` or `and`) needs to be clarified
        # before it is activated.
        assert (
            state.constraint == sc.true()
            or all(
                not pred_formula.binds_tree(leaf)
                for pred_formula in get_conjuncts(state.constraint)
                if isinstance(pred_formula, language.SemanticPredicateFormula)
                for _, leaf in state.tree.open_leaves()
            )
            or all(
                not cast(
                    language.SemanticPredicateFormula, pred_formula.args[0]
                ).binds_tree(leaf)
                for pred_formula in get_conjuncts(state.constraint)
                if isinstance(pred_formula, language.NegatedFormula)
                and isinstance(pred_formula.args[0], language.SemanticPredicateFormula)
            )
            for _, leaf in state.tree.open_leaves()
        ), (
            "Constraint is not true and contains semantic predicate formulas binding open tree leaves: "
            f"{state.constraint}, leaves: "
            + ", ".join(
                [str(leaf) for _, leaf in state.tree.open_leaves()],
            )
        )

        return Nothing
1082

1083
    def finish_unconstrained_trees(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Closes the open leaves of a state whose constraint is `true` using the
        fuzzer.

        :param state: The state to finish.
        :return: `Nothing` if the constraint of `state` is not `sc.true()`.
          Otherwise, `max_number_free_instantiations` states with the constraint
          of `state`, in each of which every open leaf of the tree was replaced
          by a fuzzer-generated expansion.
        """
        if self.global_fuzzer:
            fuzzer = self.fuzzer
        else:
            fuzzer = self.fuzzer_factory(self.grammar)

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            fuzzer.covered_expansions.update(self.seen_coverages)

        if state.constraint != sc.true():
            return Nothing

        def close_open_leaves() -> DerivationTree:
            # Replaces each open leaf of the state's tree, one after another, by
            # a fresh expansion of a tree consisting only of the leaf's symbol.
            closed_tree = state.tree
            for path, leaf in state.tree.open_leaves():
                expanded_leaf = fuzzer.expand_tree(DerivationTree(leaf.value, None))
                closed_tree = closed_tree.replace_path(path, expanded_leaf)

            return closed_tree

        closed_results: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            closed_results.append(SolutionState(state.constraint, close_open_leaves()))

        return Some(closed_results)
1107

1108
    def expand(
        self,
        state: SolutionState,
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates the open leaves of the state's tree with fuzzer-generated
        subtrees, substituting them in both the tree and the constraint.

        :param state: The state to expand.
        :return: A list of up to `max_number_free_instantiations` successor
          states. The list is empty if the tree has no open leaves.
        """
        if self.global_fuzzer:
            fuzzer = self.fuzzer
        else:
            fuzzer = self.fuzzer_factory(self.grammar)

        if isinstance(fuzzer, GrammarCoverageFuzzer):
            fuzzer.covered_expansions.update(self.seen_coverages)

        successors: List[SolutionState] = []
        for _ in range(self.max_number_free_instantiations):
            # Map each open leaf to a fresh expansion of its symbol.
            substitutions: Dict[DerivationTree, DerivationTree] = {}
            for _path, open_leaf in state.tree.open_leaves():
                substitutions[open_leaf] = fuzzer.expand_tree(
                    DerivationTree(open_leaf.value, None)
                )

            if not substitutions:
                continue

            successors.append(
                SolutionState(
                    state.constraint.substitute_expressions(substitutions),
                    state.tree.substitute(substitutions),
                )
            )

        return Some(successors)
1135

1136
    def instantiate_structural_predicates(self, state: SolutionState) -> SolutionState:
        """
        Replaces the structural predicate formulas in the constraint whose
        arguments contain no variables by their truth values.

        :param state: The state whose constraint is processed.
        :return: A state with the same tree in which each such predicate formula
          was replaced by an SMT formula holding the Boolean result of evaluating
          the predicate on the state's tree.
        """
        collected = language.FilterVisitor(
            lambda f: isinstance(f, language.StructuralPredicateFormula)
        ).collect(state.constraint)

        # Only predicate formulas without variable arguments can be evaluated.
        ground_predicate_formulas = [
            candidate
            for candidate in collected
            if isinstance(candidate, language.StructuralPredicateFormula)
            and not any(isinstance(arg, language.Variable) for arg in candidate.args)
        ]

        constraint = state.constraint
        for predicate_formula in ground_predicate_formulas:
            truth_value = language.SMTFormula(
                z3.BoolVal(predicate_formula.evaluate(state.tree))
            )
            self.logger.debug(
                "Eliminating (-> %s) structural predicate formula %s",
                truth_value,
                predicate_formula,
            )
            constraint = language.replace_formula(
                constraint, predicate_formula, truth_value
            )

        return SolutionState(constraint, state.tree)
1165

1166
    def eliminate_existential_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Eliminates the existential integer quantifiers occurring as conjuncts of
        the constraint.

        A quantified formula that already evaluates to true, assuming the other
        conjuncts, is replaced by `true`; in that case, the method returns
        immediately with only this replacement applied to the original
        constraint. Otherwise, each quantifier is dropped and its bound variable
        replaced by a fresh constant.

        :param state: The state whose constraint is processed.
        :return: A singleton list with the resulting state, or `Nothing` if no
          conjunct is an `ExistsIntFormula`.
        """
        quantified_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ExistsIntFormula)
        ]

        if not quantified_conjuncts:
            return Nothing

        constraint = state.constraint
        for existential_int_formula in quantified_conjuncts:
            # The following check for validity is not only a performance measure,
            # but required when existential integer formulas are re-inserted.
            # Otherwise, new constants get introduced, and the solver won't
            # terminate.
            other_conjuncts = {
                f
                for f in split_conjunction(state.constraint)
                if f != existential_int_formula
            }
            already_implied = evaluate(
                existential_int_formula,
                state.tree,
                self.grammar,
                assumptions=other_conjuncts,
            ).is_true()

            if already_implied:
                self.logger.debug(
                    "Removing existential integer quantifier '%.30s', already implied "
                    "by tree and existing constraints",
                    existential_int_formula,
                )
                # This should simplify the process after quantifier re-insertion.
                simplified = language.replace_formula(
                    state.constraint, existential_int_formula, sc.true()
                )
                return Some([SolutionState(simplified, state.tree)])

            self.logger.debug(
                "Eliminating existential integer quantifier %s", existential_int_formula
            )

            # The fresh constant must not clash with any variable already used in
            # the (partially processed) constraint.
            bound_variable = existential_int_formula.bound_variable
            fresh = language.fresh_constant(
                set(VariablesCollector.collect(constraint)),
                language.Constant(bound_variable.name, bound_variable.n_type),
            )
            instantiation = existential_int_formula.inner_formula.substitute_variables(
                {existential_int_formula.bound_variable: fresh}
            )
            constraint = language.replace_formula(
                constraint, existential_int_formula, instantiation
            )

        return Some([SolutionState(constraint, state.tree)])
1229

1230
    def instantiate_universal_integer_quantifiers(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Instantiates all universal integer quantifiers occurring as conjuncts of
        the constraint, one after another.

        :param state: The state whose constraint is processed.
        :return: All states obtained by instantiating each `ForallIntFormula`
          conjunct in every state produced for the previous ones, or `Nothing` if
          there is no such conjunct.
        """
        universal_int_formulas = [
            conjunct
            for conjunct in get_conjuncts(state.constraint)
            if isinstance(conjunct, language.ForallIntFormula)
        ]

        if not universal_int_formulas:
            return Nothing

        results: List[SolutionState] = [state]
        for universal_int_formula in universal_int_formulas:
            # Each quantifier is instantiated in all states obtained so far.
            instantiated: List[SolutionState] = []
            for previous_result in results:
                instantiated.extend(
                    self.instantiate_universal_integer_quantifier(
                        previous_result, universal_int_formula
                    )
                )
            results = instantiated

        return Some(results)
1256

1257
    def instantiate_universal_integer_quantifier(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Instantiates a single universal integer quantifier.

        The enumeration approach is tried first; the transformation approach is
        only used if the enumeration yields no (or an empty) result.

        :param state: The state containing the quantified formula.
        :param universal_int_formula: The formula to instantiate.
        :return: The resulting states.
        """
        return self.instantiate_universal_integer_quantifier_by_enumeration(
            state, universal_int_formula
        ) or self.instantiate_universal_integer_quantifier_by_transformation(
            state, universal_int_formula
        )
1270

1271
    def instantiate_universal_integer_quantifier_by_transformation(
        self, state: SolutionState, universal_int_formula: language.ForallIntFormula
    ) -> List[SolutionState]:
        """
        Turns a universal integer quantifier of a particular shape into an
        existential one (see the comments in the code for the justification).

        The transformation applies to formulas of the shape
        "forall int i: exists <A> elem: not phi(elem, i)", where `phi` is a
        semantic predicate contained in `self.predicates_unique_in_int_arg`.

        :param state: The state containing the quantified formula.
        :param universal_int_formula: The formula to transform.
        :return: A singleton list with the transformed state, or an empty list
          (after logging a warning) if the formula does not have the required
          shape.
        """
        # If the enumeration approach was not successful, we can transform the
        # universal int quantifier to an existential one in a particular situation:
        #
        # Let phi(elem, i) be such that phi(elem) (for fixed first argument) is a unary
        # relation that holds for exactly one argument:
        #
        # forall <A> elem:
        #   exists int i:
        #     phi(elem, i) and
        #     forall int i':
        #       phi(elem, i') <==> i = i'
        #
        # Then, the following transformations are equivalence-preserving:
        #
        # forall int i:
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # <==> (*)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i) &
        #   forall int i':
        #     i != i' ->
        #     exists <A> elem'':
        #       not phi(elem'', i')
        #
        # <==> (+)
        #
        # exists int i:
        #   exists <A> elem':
        #     phi(elem', i) &
        #   exists <A> elem:
        #     not phi(elem, i)
        #
        # (*)
        # Problematic is only the first inner conjunct. However, for every elem, there
        # has to be an i such that phi(elem, i) holds. If there is no such i in the
        # first place, also the original formula would be unsatisfiable. Without this
        # conjunct, the transformation is a simple "quantifier unwinding."
        #
        # (+)
        # Let i' != i. Choose elem'' := elem': Since phi(elem', i) holds and i != i',
        # "not phi(elem', i')" has to hold.

        has_required_shape = (
            isinstance(universal_int_formula.inner_formula, language.ExistsFormula)
            and isinstance(
                universal_int_formula.inner_formula.inner_formula,
                language.NegatedFormula,
            )
            and isinstance(
                universal_int_formula.inner_formula.inner_formula.args[0],
                language.SemanticPredicateFormula,
            )
            and cast(
                language.SemanticPredicateFormula,
                universal_int_formula.inner_formula.inner_formula.args[0],
            ).predicate
            in self.predicates_unique_in_int_arg
        )

        if not has_required_shape:
            self.logger.warning(
                "Did not find a way to instantiate formula %s!\n"
                "Discarding this state. Please report this to your nearest ISLa developer.",
                universal_int_formula,
            )

            return []

        exists_formula: language.ExistsFormula = universal_int_formula.inner_formula
        predicate_formula: language.SemanticPredicateFormula = cast(
            language.SemanticPredicateFormula,
            cast(language.NegatedFormula, exists_formula.inner_formula).args[0],
        )

        # A fresh name for elem' in the explanation above.
        fresh_var = language.fresh_bound_variable(
            language.VariablesCollector().collect(state.constraint),
            exists_formula.bound_variable,
            add=False,
        )

        # "exists <A> elem': phi(elem', i)"
        positive_witness = language.ExistsFormula(
            fresh_var,
            exists_formula.in_variable,
            predicate_formula.substitute_variables(
                {exists_formula.bound_variable: fresh_var}
            ),
        )

        new_formula = language.ExistsIntFormula(
            universal_int_formula.bound_variable,
            positive_witness & exists_formula,
        )

        self.logger.debug(
            "Transforming universal integer quantifier "
            "(special case, see code comments for explanation):\n%s ==> %s",
            universal_int_formula,
            new_formula,
        )

        transformed_constraint = language.replace_formula(
            state.constraint, universal_int_formula, new_formula
        )

        return [SolutionState(transformed_constraint, state.tree)]
1385

1386
    def instantiate_universal_integer_quantifier_by_enumeration(
1✔
1387
        self, state: SolutionState, universal_int_formula: ForallIntFormula
1388
    ) -> Optional[List[SolutionState]]:
1389
        constant = language.Constant(
1✔
1390
            universal_int_formula.bound_variable.name,
1391
            universal_int_formula.bound_variable.n_type,
1392
        )
1393

1394
        # noinspection PyTypeChecker
1395
        inner_formula = universal_int_formula.inner_formula.substitute_variables(
1✔
1396
            {universal_int_formula.bound_variable: constant}
1397
        )
1398

1399
        instantiations: List[
1✔
1400
            Dict[
1401
                language.Constant | DerivationTree,
1402
                int | language.Constant | DerivationTree,
1403
            ]
1404
        ] = []
1405

1406
        if isinstance(universal_int_formula.inner_formula, language.DisjunctiveFormula):
1✔
1407
            # In the disjunctive case, we attempt to falsify all SMT formulas in the inner formula
1408
            # (on top level) that contain the bound variable as argument.
1409
            smt_disjuncts = [
1✔
1410
                formula
1411
                for formula in language.split_disjunction(inner_formula)
1412
                if isinstance(formula, language.SMTFormula)
1413
                and constant in formula.free_variables()
1414
            ]
1415

1416
            if smt_disjuncts and len(smt_disjuncts) < len(
1✔
1417
                language.split_disjunction(inner_formula)
1418
            ):
1419
                instantiation_values = (
1✔
1420
                    self.infer_satisfying_assignments_for_smt_formula(
1421
                        -reduce(language.SMTFormula.disjunction, smt_disjuncts),
1422
                        constant,
1423
                    )
1424
                )
1425

1426
                # We also try to falsify (negated) semantic predicate formulas, if present,
1427
                # if there exist any remaining disjuncts.
1428
                semantic_predicate_formulas: List[
1✔
1429
                    Tuple[language.SemanticPredicateFormula, bool]
1430
                ] = [
1431
                    (pred_formula, False)
1432
                    if isinstance(pred_formula, language.SemanticPredicateFormula)
1433
                    else (cast(language.NegatedFormula, pred_formula).args[0], True)
1434
                    for pred_formula in language.FilterVisitor(
1435
                        lambda f: (
1436
                            constant in f.free_variables()
1437
                            and (
1438
                                isinstance(f, language.SemanticPredicateFormula)
1439
                                or isinstance(f, language.NegatedFormula)
1440
                                and isinstance(
1441
                                    f.args[0], language.SemanticPredicateFormula
1442
                                )
1443
                            )
1444
                        ),
1445
                        do_continue=lambda f: isinstance(
1446
                            f, language.DisjunctiveFormula
1447
                        ),
1448
                    ).collect(inner_formula)
1449
                    if all(
1450
                        not isinstance(var, language.BoundVariable)
1451
                        for var in pred_formula.free_variables()
1452
                    )
1453
                ]
1454

1455
                if semantic_predicate_formulas and len(
1✔
1456
                    semantic_predicate_formulas
1457
                ) + len(smt_disjuncts) < len(language.split_disjunction(inner_formula)):
1458
                    for value in instantiation_values:
1✔
1459
                        instantiation: Dict[
1✔
1460
                            language.Constant | DerivationTree,
1461
                            int | language.Constant | DerivationTree,
1462
                        ] = {constant: value}
1463
                        for (
1✔
1464
                            semantic_predicate_formula,
1465
                            negated,
1466
                        ) in semantic_predicate_formulas:
1467
                            eval_result = cast(
1✔
1468
                                language.SemanticPredicateFormula,
1469
                                language.substitute(
1470
                                    semantic_predicate_formula, {constant: value}
1471
                                ),
1472
                            ).evaluate(self.graph, negate=not negated)
1473
                            if eval_result.ready() and not eval_result.is_boolean():
1✔
1474
                                instantiation.update(eval_result.result)
1✔
1475
                        instantiations.append(instantiation)
1✔
1476
                else:
1477
                    instantiations.extend(
×
1478
                        [{constant: value} for value in instantiation_values]
1479
                    )
1480

1481
        results: List[SolutionState] = []
1✔
1482
        for instantiation in instantiations:
1✔
1483
            self.logger.debug(
1✔
1484
                "Instantiating universal integer quantifier (%s -> %s) %s",
1485
                universal_int_formula.bound_variable,
1486
                instantiation[constant],
1487
                universal_int_formula,
1488
            )
1489

1490
            formula = language.replace_formula(
1✔
1491
                state.constraint,
1492
                universal_int_formula,
1493
                language.substitute(inner_formula, instantiation),
1494
            )
1495
            formula = language.substitute(formula, instantiation)
1✔
1496

1497
            tree = state.tree.substitute(
1✔
1498
                {
1499
                    tree: subst
1500
                    for tree, subst in instantiation.items()
1501
                    if isinstance(tree, DerivationTree)
1502
                }
1503
            )
1504

1505
            results.append(SolutionState(formula, tree))
1✔
1506

1507
        return results
1✔
1508

1509
    def infer_satisfying_assignments_for_smt_formula(
        self, smt_formula: language.SMTFormula, constant: language.Constant
    ) -> Set[int | language.Constant]:
        """
        This method returns `self.max_number_free_instantiations` many solutions for
        the given :class:`~isla.language.SMTFormula` if `constant` is the only free
        variable in `smt_formula`. The given formula must be a numeric formula, i.e.,
        all free variables must be numeric. If more than one free variables are
        present, at most one solution is returned (see example below).

        If no solution can be inferred, the empty set is returned.

        :param smt_formula: The :class:`~isla.language.SMTFormula` to solve. Must
          only contain numeric free variables.
        :param constant: One free variable in `smt_formula`.
        :return: A set of solutions (see explanation above & comment below).

        We create a solver with a dummy grammar (it's not needed for this example),
        choosing a value of 5 for `max_number_free_instantiations`.

        >>> solver = ISLaSolver(
        ...     '<start> ::= "x"',  # dummy grammar
        ...     max_number_free_instantiations=5,
        ... )

        The formula we're considering is `x > 10`.

        >>> from isla.language import Constant, SMTFormula, Variable, unparse_isla
        >>> x = Constant("x", Variable.NUMERIC_NTYPE)

        >>> formula = SMTFormula(z3.StrToInt(x.to_smt()) > z3.IntVal(10), x)
        >>> unparse_isla(formula)
        '(< 10 (str.to.int x))'

        We obtain five results (due to our choice of `max_number_free_instantiations`).

        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        >>> len(results)
        5

        All results are `int`s...

        >>> all(isinstance(result, int) for result in results)
        True

        ...and all are strictly greater than 10.

        >>> all(result > 10 for result in results)
        True

        Now, let's consider `x == y`. This formula contains *two* free variables, `x`
        and `y`. It is the only type of formula with more than one variable for which
        this method will return a solution in the current state.

        >>> y = Constant("y", Variable.NUMERIC_NTYPE)
        >>> formula = SMTFormula(
        ...     z3_eq(z3.StrToInt(x.to_smt()), z3.StrToInt(y.to_smt())), x, y)
        >>> unparse_isla(formula)
        '(= (str.to.int x) (str.to.int y))'

        The solution is the singleton set with the variable `y`, which is an
        instantiation of the constant `x` solving the equation.

        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        {Constant("y", "NUM")}

        If we choose a different type of formula (a greater-than relation), we obtain
        an empty solution set.

        >>> formula = SMTFormula(
        ...     z3.StrToInt(x.to_smt()) > z3.StrToInt(y.to_smt()), x, y)
        >>> unparse_isla(formula)
        '(> (str.to.int x) (str.to.int y))'
        >>> solver.infer_satisfying_assignments_for_smt_formula(formula, x)
        set()

        With a non-numeric formula, we obtain an AssertionError (if assertions are
        enabled). This method expects to be called only internally, so this should
        not happen (with or without activated assertions).

        >>> z = Constant("x", "<start>")
        >>> formula = SMTFormula(z3_eq(z.to_smt(), z3.StringVal("x")), z)
        >>> print(unparse_isla(formula))
        const x: <start>;
        <BLANKLINE>
        (= x "x")
        >>> results = solver.infer_satisfying_assignments_for_smt_formula(formula, z)
        Traceback (most recent call last):
        ...
        AssertionError: Expected numeric solution.
        """

        free_variables = smt_formula.free_variables()
        # Several solutions are only requested if `constant` is the sole free
        # variable; otherwise, a single sample is used for equality inference below.
        max_instantiations = (
            self.max_number_free_instantiations if len(free_variables) == 1 else 1
        )

        try:
            solver_result = self.solve_quantifier_free_formula(
                (smt_formula,), max_instantiations=max_instantiations
            )

            # For each free variable, the set of integer values it takes across
            # all solutions returned by the solver.
            solutions: Dict[language.Constant, Set[int]] = {
                c: {
                    int(solution[cast(language.Constant, c)].value)
                    for solution in solver_result
                }
                for c in free_variables
            }
        except ValueError:
            # `int(...)` failed: A solution value was not numeric.
            assert False, "Expected numeric solution."

        if not solutions:
            # No free variables at all: There is nothing to instantiate. Return an
            # empty set (as promised by the signature) instead of falling off the
            # end of the method.
            return set()

        if len(free_variables) == 1:
            return solutions[constant]

        if any(not values for values in solutions.values()):
            # The solver did not return any solution, so there is no sample from
            # which an equality could be inferred.
            return set()

        assert all(len(solution) == 1 for solution in solutions.values())
        # In situations with multiple variables, we might have to abstract from
        # concrete values. Currently, we only support simple equality inference
        # (based on one sample...). Note that for supporting *more complex*
        # terms (e.g., additions), we would have to extend the whole
        # infrastructure: Substitutions with complex terms, and complex terms
        # in semantic predicate arguments, are unsupported as of now.
        candidates = {
            c
            for c in solutions
            if c != constant
            and next(iter(solutions[c])) == next(iter(solutions[constant]))
        }

        # Filter working candidates
        return {
            c
            for c in candidates
            if self.solve_quantifier_free_formula(
                (
                    cast(
                        language.SMTFormula,
                        smt_formula.substitute_variables({constant: c}),
                    ),
                ),
                max_instantiations=1,
            )
        }
1651

1652
    def eliminate_all_semantic_formulas(
        self, state: SolutionState, max_instantiations: Optional[int] = None
    ) -> Maybe[List[SolutionState]]:
        """
        Solves all SMT-LIB formulas that are top-level conjuncts of the constraint
        of `state`. SMT-LIB formulas that are not top-level conjuncts (e.g., those
        occurring inside a disjunction) are not considered.

        :param state: The state in which to solve all SMT-LIB formulas.
        :param max_instantiations: The number of solutions the SMT solver should be
          asked for.
        :return: `Nothing` if the constraint has no SMT-LIB conjunct other than
          "true"; otherwise, the result of `eliminate_semantic_formula` for the
          conjunction of all these conjuncts, wrapped in `Some`.
        """

        all_conjuncts = split_conjunction(state.constraint)

        # Collect the SMT-LIB conjuncts; trivially true ones need no solving.
        smt_conjuncts = []
        for candidate in all_conjuncts:
            if not isinstance(candidate, language.SMTFormula):
                continue
            if z3.is_true(candidate.formula):
                continue
            smt_conjuncts.append(candidate)

        if not smt_conjuncts:
            return Nothing

        self.logger.debug(
            "Eliminating semantic formulas [%s]", lazyjoin(", ", smt_conjuncts)
        )

        # Conjunction of all SMT-LIB conjuncts...
        smt_part = sc.true()
        for smt_conjunct in smt_conjuncts:
            smt_part = smt_part & smt_conjunct

        # ...and of all the other conjuncts.
        remainder = sc.true()
        for other_conjunct in all_conjuncts:
            if other_conjunct not in smt_conjuncts:
                remainder = remainder & other_conjunct

        # The constraint is reordered such that the SMT-LIB part comes first.
        reordered_state = SolutionState(smt_part & remainder, state.tree)
        successors = self.eliminate_semantic_formula(
            smt_part, reordered_state, max_instantiations
        )

        return Some(successors)
1693

1694
    def eliminate_all_ready_semantic_predicate_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Evaluates the (possibly negated) semantic predicate formulas in the
        constraint of `state` that have no bound variables among their free
        variables, and eliminates those whose evaluation result is ready.

        A ready Boolean result replaces the predicate formula by the corresponding
        SMT atom. Any other ready result is turned into a substitution that is
        applied to the constraint, the tree, and the predicate formulas that are
        still to be processed.

        :param state: The state whose semantic predicate formulas to eliminate.
        :return: `Nothing` if there is no suitable semantic predicate formula or
          none of them was ready; otherwise, a singleton list with the resulting
          state, wrapped in `Some`.
        """

        def is_possibly_negated_predicate(f: language.Formula) -> bool:
            # A semantic predicate formula or the negation of one.
            if isinstance(f, language.SemanticPredicateFormula):
                return True
            return isinstance(f, language.NegatedFormula) and isinstance(
                f.args[0], language.SemanticPredicateFormula
            )

        def descend_into(f: language.Formula) -> bool:
            # Passed as `do_continue`: False exactly for negated semantic
            # predicate formulas.
            return not (
                isinstance(f, language.NegatedFormula)
                and isinstance(f.args[0], language.SemanticPredicateFormula)
            )

        def elimination_rank(
            f: language.NegatedFormula | language.SemanticPredicateFormula,
        ) -> int:
            # Sort key: the predicate order; doubled and shifted by 100 for
            # negated predicate formulas.
            if isinstance(f, language.NegatedFormula):
                inner = cast(language.SemanticPredicateFormula, f.args[0])
                return 2 * inner.predicate.order + 100
            return f.predicate.order

        pending: List[
            language.NegatedFormula | language.SemanticPredicateFormula
        ] = []
        for found_formula in language.FilterVisitor(
            is_possibly_negated_predicate, do_continue=descend_into
        ).collect(state.constraint):
            # Formulas with bound variables cannot be evaluated yet.
            if any(
                isinstance(var, language.BoundVariable)
                for var in found_formula.free_variables()
            ):
                continue

            pending.append(
                cast(
                    language.NegatedFormula | language.SemanticPredicateFormula,
                    found_formula,
                )
            )

        pending.sort(key=elimination_rank)

        if not pending:
            return Nothing

        current = state
        changed = False

        for idx, candidate in enumerate(pending):
            negated = isinstance(candidate, language.NegatedFormula)

            predicate_formula: language.SemanticPredicateFormula
            if negated:
                predicate_formula = cast(language.NegatedFormula, candidate).args[0]
            else:
                predicate_formula = candidate

            evaluation_result = predicate_formula.evaluate(
                self.graph, negate=negated
            )
            if not evaluation_result.ready():
                continue

            self.logger.debug(
                "Eliminating semantic predicate formula %s", predicate_formula
            )
            changed = True

            if evaluation_result.is_boolean():
                # Definite truth value: Replace the predicate formula by it.
                current = SolutionState(
                    language.replace_formula(
                        current.constraint,
                        predicate_formula,
                        language.smt_atom(evaluation_result.true()),
                    ),
                    current.tree,
                )
                continue

            substitution = subtree_solutions(evaluation_result.result)

            # The (un-negated) predicate formula is replaced such that the
            # possibly negated formula becomes true, and the substitution is
            # applied to the whole constraint.
            replacement = sc.false() if negated else sc.true()
            new_constraint = language.replace_formula(
                current.constraint,
                predicate_formula,
                replacement,
            ).substitute_expressions(substitution)

            # The formulas still to be processed must refer to the updated trees.
            for later_idx in range(idx + 1, len(pending)):
                pending[later_idx] = cast(
                    language.SemanticPredicateFormula,
                    pending[later_idx].substitute_expressions(substitution),
                )

            current = SolutionState(
                new_constraint, current.tree.substitute(substitution)
            )
            assert self.graph.tree_is_valid(current.tree)

        return Maybe.from_optional([current] if changed else None)
1791

1792
    def eliminate_and_match_first_existential_formula(
        self, state: SolutionState
    ) -> Optional[List[SolutionState]]:
        """
        Processes the first existential formula among the top-level conjuncts of
        the constraint of `state`.

        We produce up to two groups of output states: One where the first
        existential formula, if it can be matched, is matched, and one where the
        first existential formula is eliminated by tree insertion. Tree insertion
        results with the same tree as a matching result are dropped if their
        constraint conjoined with the negated constraint of the matching result is
        propositionally unsatisfiable.

        :param state: The state whose first existential conjunct is processed.
        :return: `None` if looking up an existential top-level conjunct fails;
          otherwise, the list of successor states. If tree insertion is
          deactivated, these are exactly the matching results; otherwise, the
          matching and tree insertion results that differ from `state`.
        """

        def do_eliminate(
            first_existential_formula_with_idx: Tuple[int, language.ExistsFormula]
        ) -> List[SolutionState]:
            first_matched = OrderedSet(
                self.match_existential_formula(
                    first_existential_formula_with_idx[0], state
                )
            )

            # Tree insertion can be deactivated by setting `self.tree_insertion_methods`
            # to 0.
            if not self.tree_insertion_methods:
                return list(first_matched)

            if first_matched:
                self.logger.debug(
                    "Matched first existential formulas, result: [%s]",
                    lazyjoin(
                        ", ",
                        [
                            # `s` is bound as a default argument: The lambda is
                            # only called when the message is rendered, i.e.,
                            # after the comprehension has finished. Without the
                            # binding, all entries would show the last state.
                            lazystr(lambda s=s: f"{s} (hash={hash(s)})")
                            for s in first_matched
                        ],
                    ),
                )

            # Eliminate first existential formula by tree insertion.
            elimination_result = OrderedSet(
                self.eliminate_existential_formula(
                    first_existential_formula_with_idx[0], state
                )
            )
            # Drop tree insertion results that have the same tree as a matching
            # result and whose constraint, conjoined with the negation of the
            # matching result's constraint, is propositionally unsatisfiable.
            elimination_result = OrderedSet(
                [
                    result
                    for result in elimination_result
                    if not any(
                        other_result.tree == result.tree
                        and self.propositionally_unsatisfiable(
                            result.constraint & -other_result.constraint
                        )
                        for other_result in first_matched
                    )
                ]
            )

            if not elimination_result and not first_matched:
                self.logger.warning(
                    "Existential qfr elimination: Could not eliminate existential formula %s "
                    "by matching or tree insertion",
                    first_existential_formula_with_idx[1],
                )

            if elimination_result:
                self.logger.debug(
                    "Eliminated existential formula %s by tree insertion, %d successors",
                    first_existential_formula_with_idx[1],
                    len(elimination_result),
                )

            # States equal to the input state would not constitute progress.
            return [
                result
                for result in first_matched | elimination_result
                if result != state
            ]

        # `next` fails if there is no existential conjunct; that failure is turned
        # into the `None` result.
        return (
            result_to_maybe(
                safe(
                    lambda: next(
                        (idx, conjunct)
                        for idx, conjunct in enumerate(
                            split_conjunction(state.constraint)
                        )
                        if isinstance(conjunct, language.ExistsFormula)
                    )
                )()
            )
            .map(do_eliminate)
            .value_or(None)
        )
1880

1881
    def match_all_universal_formulas(
        self, state: SolutionState
    ) -> Maybe[List[SolutionState]]:
        """
        Matches the universal formulas that are top-level conjuncts of the
        constraint of `state` (see `match_universal_formulas`).

        :param state: The state whose universal formulas should be matched.
        :return: `Nothing` if there is no universal top-level conjunct or if
          matching yields no state; otherwise, the matching result wrapped in
          `Some`.
        """

        universal_formulas = []
        for conjunct in split_conjunction(state.constraint):
            if isinstance(conjunct, language.ForallFormula):
                universal_formulas.append(conjunct)

        if universal_formulas:
            matched_states = self.match_universal_formulas(state)
            if matched_states:
                self.logger.debug(
                    "Matched universal formulas [%s]",
                    lazyjoin(", ", universal_formulas),
                )

                return Some(matched_states)

        return Nothing
1902

1903
    def expand_tree(
        self,
        state: SolutionState,
        only_universal: bool = True,
        limit: Optional[int] = None,
    ) -> List[SolutionState]:
        """
        Expands the given tree, but not at nonterminals that can be freely
        instantiated or those that directly correspond to the assignment constant.

        :param state: The current state.
        :param only_universal: If set to True, only nonterminals that might match
          universal quantifiers are expanded. If set to False, also nonterminals
          matching existential quantifiers are expanded.
        :param limit: If set to a value, this will return only up to `limit`
          expansions, each of which is chosen randomly.
        :return: A (possibly empty) list of expanded trees.
        """

        # The kind of quantified formula an open leaf has to possibly match for
        # being expanded.
        relevant_formula_type = (
            language.ForallFormula if only_universal else language.QuantifiedFormula
        )

        # Maps the paths of the relevant open leaves to their possible expansions
        # (lists of children).
        nonterminal_expansions: Dict[Path, List[List[DerivationTree]]] = {}
        for leaf_path, leaf_node in state.tree.open_leaves():
            might_match = any(
                self.quantified_formula_might_match(formula, leaf_path, state.tree)
                for formula in get_conjuncts(state.constraint)
                if isinstance(formula, relevant_formula_type)
            )
            if not might_match:
                continue

            nonterminal_expansions[leaf_path] = [
                [
                    DerivationTree(symbol, None if is_nonterminal(symbol) else [])
                    for symbol in expansion
                ]
                for expansion in self.canonical_grammar[leaf_node.value]
            ]

        if not nonterminal_expansions:
            return []

        possible_expansions: List[Dict[Path, List[DerivationTree]]] = []
        if limit:
            # Sample `limit` combinations, choosing one expansion per leaf.
            for _ in range(limit):
                possible_expansions.append(
                    {
                        path: random.choice(expansions)
                        for path, expansions in nonterminal_expansions.items()
                        if expansions
                    }
                )
        else:
            # No limit: All combinations of expansions.
            possible_expansions = dict_of_lists_to_list_of_dicts(nonterminal_expansions)
            assert len(possible_expansions) == math.prod(
                len(values) for values in nonterminal_expansions.values()
            )

        # This replaces a previous `if` statement with the negated condition as guard,
        # which seems to be dead code (the guard can never hold true due to the check
        # of emptiness of `nonterminal_expansions` above). We keep this assertion here
        # to be sure.
        assert (
            len(possible_expansions) > 1
            or len(possible_expansions) == 1
            and possible_expansions[0]
        )

        result: List[SolutionState] = []
        for possible_expansion in possible_expansions:
            expanded_tree = state.tree
            for path, new_children in possible_expansion.items():
                open_leaf = expanded_tree.get_subtree(path)
                expanded_leaf = DerivationTree(
                    open_leaf.value, new_children, open_leaf.id
                )
                expanded_tree = expanded_tree.replace_path(path, expanded_leaf)

                assert expanded_tree is not state.tree
                assert expanded_tree != state.tree
                assert expanded_tree.structural_hash() != state.tree.structural_hash()

            # All trees along the paths to the expanded leaves have changed; the
            # constraint is updated to refer to their new versions.
            changed_subtrees = {}
            for path in possible_expansion:
                for prefix_length in range(len(path) + 1):
                    prefix = path[:prefix_length]
                    original_subtree = state.tree.get_subtree(prefix)
                    changed_subtrees[original_subtree] = expanded_tree.get_subtree(
                        prefix
                    )

            updated_constraint = state.constraint.substitute_expressions(
                changed_subtrees
            )

            result.append(SolutionState(updated_constraint, expanded_tree))

        assert not limit or len(result) <= limit
        return result
1996

1997
    def match_universal_formulas(self, state: SolutionState) -> List[SolutionState]:
        """
        Instantiates the universal formulas that are top-level conjuncts of the
        constraint of `state` for all of their matches that are not yet recorded
        as already matched.

        :param state: The state whose universal formulas should be matched.
        :return: The empty list if there is no new match. Otherwise, a singleton
          list with a state for the same tree whose constraint conjoins the new
          instantiations with the original conjuncts, in which the matched
          universal formulas additionally record the new matches.
        """

        conjuncts = split_conjunction(state.constraint)
        instantiated_formulas: List[language.Formula] = []

        for idx, conjunct in enumerate(conjuncts):
            if not isinstance(conjunct, language.ForallFormula):
                continue

            # Only consider matches of trees that were not matched before.
            new_matches: List[Dict[language.Variable, Tuple[Path, DerivationTree]]] = []
            for match in matches_for_quantified_formula(conjunct, self.grammar):
                matched_tree = match[conjunct.bound_variable][1]
                if not conjunct.is_already_matched(matched_tree):
                    new_matches.append(match)

            marked_formula = conjunct.add_already_matched(
                {match[conjunct.bound_variable][1] for match in new_matches}
            )

            for match in new_matches:
                bindings = {
                    variable: match_tree
                    for variable, (_, match_tree) in match.items()
                }

                instantiated_formulas.append(
                    marked_formula.inner_formula.substitute_expressions(bindings)
                )
                conjuncts = list_set(conjuncts, idx, marked_formula)

        if not instantiated_formulas:
            return []

        new_constraint = sc.conjunction(*instantiated_formulas) & sc.conjunction(
            *conjuncts
        )
        return [SolutionState(new_constraint, state.tree)]
2041

2042
    def match_existential_formula(
        self, existential_formula_idx: int, state: SolutionState
    ) -> List[SolutionState]:
        """
        Instantiates the existential formula at the given position among the
        top-level conjuncts of the constraint of `state` for each of its matches.

        :param existential_formula_idx: The index of the existential formula in
          the list of top-level conjuncts of the constraint of `state`.
        :param state: The state containing the existential formula.
        :return: One state per match, for the same tree as `state`. In its
          constraint, the existential formula is replaced by its inner formula
          instantiated according to the match.
        """

        conjuncts: ImmutableList[language.Formula] = tuple(
            split_conjunction(state.constraint)
        )
        existential_formula = cast(
            language.ExistsFormula, conjuncts[existential_formula_idx]
        )

        successors: List[SolutionState] = []
        for match in matches_for_quantified_formula(
            existential_formula, self.grammar
        ):
            # Map each matched variable to the tree it was matched to.
            bindings = {}
            for variable, (_, match_tree) in match.items():
                bindings[variable] = match_tree

            instantiated = existential_formula.inner_formula.substitute_expressions(
                bindings
            )
            # All conjuncts but the existential formula itself are retained.
            other_conjuncts = sc.conjunction(
                *list_del(conjuncts, existential_formula_idx)
            )
            successors.append(
                SolutionState(instantiated & other_conjuncts, state.tree)
            )

        return successors
2068

2069
    def eliminate_existential_formula(
1✔
2070
        self, existential_formula_idx: int, state: SolutionState
2071
    ) -> List[SolutionState]:
2072
        conjuncts: ImmutableList[language.Formula] = tuple(
1✔
2073
            split_conjunction(state.constraint)
2074
        )
2075
        existential_formula = cast(
1✔
2076
            language.ExistsFormula, conjuncts[existential_formula_idx]
2077
        )
2078

2079
        inserted_trees_and_bind_paths = (
1✔
2080
            [(DerivationTree(existential_formula.bound_variable.n_type, None), {})]
2081
            if existential_formula.bind_expression is None
2082
            else (
2083
                existential_formula.bind_expression.to_tree_prefix(
2084
                    existential_formula.bound_variable.n_type, self.grammar
2085
                )
2086
            )
2087
        )
2088

2089
        result: List[SolutionState] = []
1✔
2090

2091
        inserted_tree: DerivationTree
2092
        bind_expr_paths: Dict[language.BoundVariable, Path]
2093
        for inserted_tree, bind_expr_paths in inserted_trees_and_bind_paths:
1✔
2094
            self.logger.debug(
1✔
2095
                "process_insertion_results(self.canonical_grammar, %s, %s, self.graph, %s)",
2096
                lazystr(lambda: repr(inserted_tree)),
2097
                lazystr(lambda: repr(existential_formula.in_variable)),
2098
                self.max_number_tree_insertion_results,
2099
            )
2100

2101
            insertion_results = insert_tree(
1✔
2102
                self.canonical_grammar,
2103
                inserted_tree,
2104
                existential_formula.in_variable,
2105
                graph=self.graph,
2106
                max_num_solutions=self.max_number_tree_insertion_results * 2,
2107
                methods=self.tree_insertion_methods,
2108
            )
2109

2110
            insertion_results = sorted(
1✔
2111
                insertion_results,
2112
                key=lambda t: compute_tree_closing_cost(t, self.graph),
2113
            )
2114
            insertion_results = insertion_results[
1✔
2115
                : self.max_number_tree_insertion_results
2116
            ]
2117

2118
            for insertion_result in insertion_results:
1✔
2119
                replaced_path = state.tree.find_node(existential_formula.in_variable)
1✔
2120
                resulting_tree = state.tree.replace_path(
1✔
2121
                    replaced_path, insertion_result
2122
                )
2123

2124
                tree_substitution: Dict[DerivationTree, DerivationTree] = {}
1✔
2125
                for idx in range(len(replaced_path) + 1):
1✔
2126
                    original_path = replaced_path[: idx - 1]
1✔
2127
                    original_tree = state.tree.get_subtree(original_path)
1✔
2128
                    if (
1✔
2129
                        resulting_tree.is_valid_path(original_path)
2130
                        and original_tree.value
2131
                        == resulting_tree.get_subtree(original_path).value
2132
                        and resulting_tree.get_subtree(original_path) != original_tree
2133
                    ):
2134
                        tree_substitution[original_tree] = resulting_tree.get_subtree(
1✔
2135
                            original_path
2136
                        )
2137

2138
                assert insertion_result.find_node(inserted_tree) is not None
1✔
2139
                variable_substitutions = {
1✔
2140
                    existential_formula.bound_variable: inserted_tree
2141
                }
2142

2143
                if bind_expr_paths:
1✔
2144
                    if assertions_activated():
1✔
2145
                        dangling_bind_expr_vars = [
1✔
2146
                            (var, path)
2147
                            for var, path in bind_expr_paths.items()
2148
                            if (
2149
                                var
2150
                                in existential_formula.bind_expression.bound_variables()
2151
                                and insertion_result.find_node(
2152
                                    inserted_tree.get_subtree(path)
2153
                                )
2154
                                is None
2155
                            )
2156
                        ]
2157
                        assert not dangling_bind_expr_vars, (
1✔
2158
                            f"Bound variables from match expression not found in tree: "
2159
                            f"[{' ,'.join(map(repr, dangling_bind_expr_vars))}]"
2160
                        )
2161

2162
                    variable_substitutions.update(
1✔
2163
                        {
2164
                            var: inserted_tree.get_subtree(path)
2165
                            for var, path in bind_expr_paths.items()
2166
                            if var
2167
                            in existential_formula.bind_expression.bound_variables()
2168
                        }
2169
                    )
2170

2171
                instantiated_formula = (
1✔
2172
                    existential_formula.inner_formula.substitute_expressions(
2173
                        variable_substitutions
2174
                    ).substitute_expressions(tree_substitution)
2175
                )
2176

2177
                instantiated_original_constraint = sc.conjunction(
1✔
2178
                    *list_del(conjuncts, existential_formula_idx)
2179
                ).substitute_expressions(tree_substitution)
2180

2181
                new_tree = resulting_tree.substitute(tree_substitution)
1✔
2182

2183
                new_formula = (
1✔
2184
                    instantiated_formula
2185
                    & self.formula.substitute_expressions(
2186
                        {self.top_constant.unwrap(): new_tree}
2187
                    )
2188
                    & instantiated_original_constraint
2189
                )
2190

2191
                new_state = SolutionState(new_formula, new_tree)
1✔
2192

2193
                assert all(
1✔
2194
                    new_state.tree.find_node(tree) is not None
2195
                    for quantified_formula in split_conjunction(new_state.constraint)
2196
                    if isinstance(quantified_formula, language.QuantifiedFormula)
2197
                    for _, tree in quantified_formula.in_variable.filter(lambda t: True)
2198
                )
2199

2200
                if assertions_activated() or self.debug:
1✔
2201
                    lost_tree_predicate_arguments: List[DerivationTree] = [
1✔
2202
                        arg
2203
                        for invstate in self.establish_invariant(new_state)
2204
                        for predicate_formula in get_conjuncts(invstate.constraint)
2205
                        if isinstance(
2206
                            predicate_formula, language.StructuralPredicateFormula
2207
                        )
2208
                        for arg in predicate_formula.args
2209
                        if isinstance(arg, DerivationTree)
2210
                        and invstate.tree.find_node(arg) is None
2211
                    ]
2212

2213
                    if lost_tree_predicate_arguments:
1✔
2214
                        previous_posititions = [
×
2215
                            state.tree.find_node(t)
2216
                            for t in lost_tree_predicate_arguments
2217
                        ]
2218
                        assert False, (
×
2219
                            f"Dangling subtrees [{', '.join(map(repr, lost_tree_predicate_arguments))}], "
2220
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2221
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2222
                        )
2223

2224
                    lost_semantic_formula_arguments = [
1✔
2225
                        arg
2226
                        for invstate in self.establish_invariant(new_state)
2227
                        for semantic_formula in get_conjuncts(new_state.constraint)
2228
                        if isinstance(semantic_formula, language.SMTFormula)
2229
                        for arg in semantic_formula.substitutions.values()
2230
                        if invstate.tree.find_node(arg) is None
2231
                    ]
2232

2233
                    if lost_semantic_formula_arguments:
1✔
2234
                        previous_posititions = [
×
2235
                            state.tree.find_node(t)
2236
                            for t in lost_semantic_formula_arguments
2237
                        ]
2238
                        previous_posititions = [
×
2239
                            p for p in previous_posititions if p is not None
2240
                        ]
2241
                        assert False, (
×
2242
                            f"Dangling subtrees [{', '.join(map(repr, lost_semantic_formula_arguments))}], "
2243
                            f"previously at positions [{', '.join(map(str, previous_posititions))}] "
2244
                            f"in tree {repr(state.tree)} (hash: {hash(state)})."
2245
                        )
2246

2247
                result.append(new_state)
1✔
2248

2249
        return result
1✔
2250

2251
    def eliminate_semantic_formula(
        self,
        semantic_formula: language.Formula,
        state: SolutionState,
        max_instantiations: Optional[int] = None,
    ) -> Optional[List[SolutionState]]:
        """
        Solves a semantic formula (a conjunction of SMT formulas) and produces one
        new solution state per solution found.

        The SMT conjuncts are grouped into independent clusters, each cluster is
        solved by :meth:`solve_quantifier_free_formula`, and the per-cluster
        solutions are combined via their Cartesian product. Each combined solution
        is substituted into both the constraint and the tree of :code:`state`. If
        any cluster has no solution, the product is empty and an empty list is
        returned.

        :param semantic_formula: The semantic formula to solve; all of its conjuncts
            must be :class:`~isla.language.SMTFormula` instances.
        :param state: The original solution state.
        :param max_instantiations: The maximum number of solutions to ask the SMT
            solver for; defaults to :code:`self.max_number_smt_instantiations`.
        :return: A list of instantiated SolutionStates.
        """

        assert all(
            isinstance(conjunct, language.SMTFormula)
            for conjunct in get_conjuncts(semantic_formula)
        )

        # NOTE: SMT formulas must be clustered by tree substitutions. Two formulas
        # with a variable $var that is instantiated to *different* trees need two
        # separate solutions; if $var is instantiated with the *same* tree, both
        # formulas need one joint solution.
        smt_conjuncts = [
            conjunct
            for conjunct in get_conjuncts(semantic_formula)
            if isinstance(conjunct, language.SMTFormula)
        ]
        smt_formulas = self.rename_instantiated_variables_in_smt_formulas(
            smt_conjuncts
        )

        # Formulas are additionally clustered by common variables and instantiated
        # subtrees: one formula might instantiate a subtree of the instantiation of
        # another formula, in which case both must end up in the same cluster.
        # Smaller constraints are easier for the solver, and formulas without
        # anything in common can be handled independently.
        def cluster_keys(smt_formula: language.SMTFormula):
            instantiated_subtrees = {
                subtree
                for tree in smt_formula.substitutions.values()
                for _, subtree in tree.paths()
            }
            return (
                smt_formula.free_variables()
                | smt_formula.instantiated_variables
                | instantiated_subtrees
            )

        formula_clusters: List[List[language.SMTFormula]] = cluster_by_common_elements(
            smt_formulas, cluster_keys
        )

        assert all(
            not cluster_keys(smt_formula)
            or any(smt_formula in cluster for cluster in formula_clusters)
            for smt_formula in smt_formulas
        )

        formula_clusters = [cluster for cluster in formula_clusters if cluster]

        # Formulas without any cluster key are collected in one extra cluster.
        keyless_formulas = [
            smt_formula for smt_formula in smt_formulas if not cluster_keys(smt_formula)
        ]
        if keyless_formulas:
            formula_clusters.append(keyless_formulas)

        # Note: We cannot ask for `max_instantiations` solutions for *each* cluster;
        #       since all cluster solutions are combined to a product, this would
        #       yield 10^4 solutions for `max_instantiations` = 10 and 4 clusters.
        #       Instead, we take the ceiling of the #numCluster'th root of
        #       `max_instantiations`. For example, the ceiling of the 4th root of
        #       10 is 2, and 2^4 is 16, which is still within an acceptable range.
        instantiation_budget = max_instantiations or self.max_number_smt_instantiations
        solutions_per_cluster = math.ceil(
            instantiation_budget ** (1 / len(formula_clusters))
        )

        all_solutions: List[
            List[Dict[Union[language.Constant, DerivationTree], DerivationTree]]
        ] = [
            self.solve_quantifier_free_formula(tuple(cluster), solutions_per_cluster)
            for cluster in formula_clusters
        ]

        # The cluster solutions are independent of each other, so each solution can
        # be merged with all solutions of the other clusters.
        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = [
            functools.reduce(operator.or_, combination)
            for combination in itertools.product(*all_solutions)
        ]

        results = []
        # We also have to instantiate all subtrees of the substituted element.
        for solution in map(subtree_solutions, solutions):
            if not solution:
                # Empty assignment: nothing to substitute, so the semantic formula
                # is replaced by `true` and the tree is left untouched.
                results.append(
                    SolutionState(
                        language.replace_formula(
                            state.constraint, semantic_formula, sc.true()
                        ),
                        state.tree,
                    )
                )
                continue

            results.append(
                SolutionState(
                    state.constraint.substitute_expressions(solution),
                    state.tree.substitute(solution),
                )
            )

        return results
1✔
2376

2377
    # NOTE(review): `lru_cache` on a method also keys on `self`, and the cached
    # list is handed out by reference; callers should not mutate the result.
    @lru_cache(100)
    def solve_quantifier_free_formula(
        self,
        smt_formulas: ImmutableList[language.SMTFormula],
        max_instantiations: Optional[int] = None,
    ) -> List[Dict[language.Constant | DerivationTree, DerivationTree]]:
        """
        Attempts to solve the given SMT-LIB formulas by calling Z3.

        Solutions are collected one at a time; each previously found solution is
        passed to the solver as a solution to exclude. The search stops when the
        solver no longer reports :code:`sat`, when a solution repeats, when an
        empty assignment is found, or when the instantiation limit is reached.

        Note that this function does not unify variables pointing to the same
        derivation trees. E.g., a solution may be returned for the formula
        `var_1 = "a" and var_2 = "b"`, though `var_1` and `var_2` point to the same
        `<var>` tree as defined by their substitutions maps. Unification is
        performed in `eliminate_all_semantic_formulas`.

        :param smt_formulas: The SMT-LIB formulas to solve.
        :param max_instantiations: The maximum number of instantiations to produce;
            defaults to :code:`self.max_number_smt_instantiations`.
        :return: A (possibly empty) list of solutions.
        """

        # If any SMT formula refers to *sub*trees in the instantiations of other SMT
        # formulas, we have to instantiate those first.
        priority_formulas = smt_formulas_referring_to_subtrees(smt_formulas)
        if priority_formulas:
            smt_formulas = priority_formulas
            assert not smt_formulas_referring_to_subtrees(smt_formulas)

        # Union of the substitution maps of all considered formulas.
        tree_substitutions = reduce(
            operator.or_,
            [smt_formula.substitutions for smt_formula in smt_formulas],
            {},
        )

        # All free and instantiated variables occurring in the considered formulas.
        constants = reduce(
            operator.or_,
            [
                smt_formula.free_variables() | smt_formula.instantiated_variables
                for smt_formula in smt_formulas
            ],
            set(),
        )

        z3_formulas = tuple([smt_formula.formula for smt_formula in smt_formulas])

        solutions: List[
            Dict[Union[language.Constant, DerivationTree], DerivationTree]
        ] = []
        internal_solutions: List[Dict[language.Constant, z3.StringVal]] = []

        num_instantiations = max_instantiations or self.max_number_smt_instantiations
        for _ in range(num_instantiations):
            (
                solver_result,
                maybe_model,
            ) = self.solve_smt_formulas_with_language_constraints(
                constants,
                z3_formulas,
                tree_substitutions,
                internal_solutions,
            )

            if solver_result != z3.sat:
                # No (further) solution: return whatever was found so far.
                return solutions

            assert maybe_model is not None

            # Solution keyed by the substituted tree where there is one, and by
            # the constant itself otherwise.
            new_solution = {
                tree_substitutions.get(constant, constant): maybe_model[constant]
                for constant in constants
            }

            # The same solution as Z3 string values, used to exclude it from
            # subsequent solver calls.
            new_internal_solution = {
                constant: z3.StringVal(str(maybe_model[constant]))
                for constant in constants
            }

            if new_solution in solutions:
                # This can happen for trivial solutions, i.e., if the formula is
                # logically valid. Then, the assignment for that constant will
                # always be {}
                return solutions

            solutions.append(new_solution)

            if not new_internal_solution:
                # Again, for a trivial solution (e.g., True), the assignment
                # can be empty.
                break

            internal_solutions.append(new_internal_solution)

        return solutions
1✔
2470

2471
    def solve_smt_formulas_with_language_constraints(
        self,
        variables: Set[language.Variable],
        smt_formulas: ImmutableList[z3.BoolRef],
        tree_substitutions: Dict[language.Variable, DerivationTree],
        solutions_to_exclude: List[Dict[language.Variable, z3.StringVal]],
    ) -> Tuple[z3.CheckSatResult, Dict[language.Variable, DerivationTree]]:
        """
        Solves the given Z3 formulas together with language constraints for the
        variables, excluding previously found solutions.

        If optimized Z3 queries are enabled and no tree in
        :code:`tree_substitutions` has children, the variables are classified by
        :meth:`infer_variable_contexts` into "length", "int", and "flexible"
        variables. For "length" and "int" variables, the :code:`str.len(var)` and
        :code:`str.to.int(var)` terms in the formulas are replaced by fresh integer
        variables. Otherwise, all variables are treated as "flexible."

        :param variables: The variables to find values for.
        :param smt_formulas: The Z3 formulas to solve.
        :param tree_substitutions: A map from variables to derivation trees; passed
            on to :meth:`generate_language_constraints` and used to decide whether
            optimized queries are applicable.
        :param solutions_to_exclude: Previously found assignments; the negation of
            each of them is added to the solver query.
        :return: The Z3 result and, if it is :code:`sat`, a map from each variable
            to the value extracted by :meth:`extract_model_value`; an empty map
            otherwise.
        """

        # We disable optimized Z3 queries if the SMT formulas contain "too concrete"
        # substitutions, that is, substitutions with a tree that is not merely an
        # open leaf. Example: we have a constraint `str.len(<chars>) < 10` and a
        # tree `<char><char>`; only the concrete length "10" is possible then. In fact,
        # we could simply finish off the tree and check the constraint, or restrict
        # the custom tree generation to admissible lengths, but we stay general
        # here. The SMT solution is more robust.

        if self.enable_optimized_z3_queries and not any(
            substitution.children for substitution in tree_substitutions.values()
        ):
            vars_in_context = self.infer_variable_contexts(variables, smt_formulas)
            length_vars = vars_in_context["length"]
            int_vars = vars_in_context["int"]
            flexible_vars = vars_in_context["flexible"]
        else:
            length_vars = set()
            int_vars = set()
            flexible_vars = set(variables)

        # Add language constraints for "flexible" variables
        formulas: List[z3.BoolRef] = self.generate_language_constraints(
            flexible_vars, tree_substitutions
        )

        # Create fresh variables for `str.len` and `str.to.int` variables.
        all_variables = set(variables)
        fresh_var_map: Dict[language.Variable, z3.ExprRef] = {}
        for var in length_vars | int_vars:
            # Only the name of the fresh constant is used below, so its
            # nonterminal type is a mere placeholder.
            fresh = fresh_constant(
                all_variables,
                language.Constant(var.name, "NOT-NEEDED"),
            )
            fresh_var_map[var] = z3.Int(fresh.name)

        # In `smt_formulas`, we replace all `length(...)` and `str.to.int(...)`
        # terms for "length" and "int" variables with the corresponding fresh
        # variable.
        replacement_map: Dict[z3.ExprRef, z3.ExprRef] = {
            expr: fresh_var_map[
                get_elem_by_equivalence(
                    expr.children()[0],
                    length_vars | int_vars,
                    lambda e1, e2: e1 == e2.to_smt(),
                )
            ]
            for formula in smt_formulas
            for expr in visit_z3_expr(formula)
            if expr.decl().kind() in {z3.Z3_OP_SEQ_LENGTH, z3.Z3_OP_STR_TO_INT}
            and expr.children()[0] in {var.to_smt() for var in length_vars | int_vars}
        }

        # Perform substitution, add formulas
        formulas.extend(
            [
                cast(z3.BoolRef, z3_subst(formula, replacement_map))
                for formula in smt_formulas
            ]
        )

        # Lengths must be non-negative
        formulas.extend(
            [
                cast(
                    z3.BoolRef,
                    replacement_map[z3.Length(length_var.to_smt())] >= z3.IntVal(0),
                )
                for length_var in length_vars
            ]
        )

        # Add custom intervals for int variables
        for int_var in int_vars:
            if int_var.n_type == language.Variable.NUMERIC_NTYPE:
                # "NUM" variables range over the full int domain
                continue

            regex = self.extract_regular_expression(int_var.n_type)
            maybe_intervals = numeric_intervals_from_regex(regex)
            repl_var = replacement_map[z3.StrToInt(int_var.to_smt())]
            # If intervals could be derived from the regular expression, restrict
            # the fresh variable to their union. Bounds at +/- `sys.maxsize` are
            # treated as unbounded.
            # NOTE(review): presumably `maybe_intervals` is a Maybe-like container
            # whose `map` applies the function right away; confirm.
            maybe_intervals.map(
                tap(
                    lambda intervals: formulas.append(
                        z3_or(
                            [
                                z3.And(
                                    repl_var >= z3.IntVal(interval[0])
                                    if interval[0] > -sys.maxsize
                                    else z3.BoolVal(True),
                                    repl_var <= z3.IntVal(interval[1])
                                    if interval[1] < sys.maxsize
                                    else z3.BoolVal(True),
                                )
                                for interval in intervals
                            ]
                        )
                    )
                )
            )

        # Exclude each previously found solution by adding the negation of the
        # conjunction of its assignments.
        for prev_solution in solutions_to_exclude:
            prev_solution_formula = z3_and(
                [
                    self.previous_solution_formula(
                        var, string_val, fresh_var_map, length_vars, int_vars
                    )
                    for var, string_val in prev_solution.items()
                ]
            )

            formulas.append(z3.Not(prev_solution_formula))

        sat_result, maybe_model = z3_solve(formulas)

        if sat_result != z3.sat:
            return sat_result, {}

        assert maybe_model is not None

        return sat_result, {
            var: self.extract_model_value(
                var, maybe_model, fresh_var_map, length_vars, int_vars
            )
            for var in variables
        }
2602

2603
    @staticmethod
    def previous_solution_formula(
        var: language.Variable,
        string_val: z3.StringVal,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> z3.BoolRef:
        """
        Builds the equation describing a previously found solution
        :code:`var == string_val` for an :class:`~isla.language.SMTFormula`.
        If :code:`var` is an "int" or a "length" variable, the equation is stated
        for the fresh integer variable of :code:`var` and the integer value or the
        length of the solution, respectively.

        >>> x = language.Variable("x", "<X>")
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {}, set(), set())
        x == "val"

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("val"), {x: z3.Int("x_0")}, {x}, set())
        x_0 == 3

        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {x: z3.Int("x_0")}, set(), {x})
        x_0 == 10

        A "numeric" variable (of "NUM" type) is expected to always be an int variable,
        which also needs to be reflected in its inclusion in :code:`fresh_var_map`.

        >>> x = language.Variable("x", language.Variable.NUMERIC_NTYPE)
        >>> ISLaSolver.previous_solution_formula(
        ...     x, z3.StringVal("10"), {}, set(), set())
        Traceback (most recent call last):
        ...
        AssertionError

        :param var: The variable the solution is for.
        :param string_val: The solution for :code:`var`.
        :param fresh_var_map: A map from variables to fresh variables for "length" or
                              "int" variables.
        :param length_vars: The "length" variables.
        :param int_vars: The "int" variables.
        :return: An equation describing the previous solution.
        """

        is_int_var = var in int_vars
        if not is_int_var and var not in length_vars:
            # Ordinary string variable: compare against the string value itself.
            assert not var.is_numeric()
            return z3_eq(var.to_smt(), string_val)

        fresh_var = fresh_var_map[var]
        solution_string = smt_string_val_to_string(string_val)

        # "Int" takes precedence over "length" if `var` is contained in both sets.
        int_solution = int(solution_string) if is_int_var else len(solution_string)
        return z3_eq(fresh_var, z3.IntVal(int_solution))
1✔
2668

2669
    def safe_create_fixed_length_tree(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
    ) -> DerivationTree:
        """
        Creates a :class:`~isla.derivation_tree.DerivationTree` for :code:`var`
        whose type fits :code:`var` and whose string representation has the length
        that :code:`model` assigns to the fresh variable of :code:`var` in
        :code:`fresh_var_map`. For example:

        >>> grammar = {
        ...     "<start>": ["<X>"],
        ...     "<X>": ["x", "x<X>"],
        ... }
        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> solver = ISLaSolver(grammar)
        >>> tree = solver.safe_create_fixed_length_tree(x, model, {x: x_0})
        >>> tree.value
        '<X>'
        >>> str(tree)
        'xxxxx'

        :param var: The variable to create a
                    :class:`~isla.derivation_tree.DerivationTree` object for.
        :param model: The Z3 model to extract a solution to the length constraint.
        :param fresh_var_map: A map including a mapping :code:`var` -> :code:`var_0`,
                              where :code:`var_0` is an integer-valued variable
                              included in :code:`model`.
        :return: A tree of the type of :code:`var` and length as specified in
                :code:`model`.
        :raises RuntimeError: If no tree of the required length could be created.
        """

        assert var in fresh_var_map
        fresh_var = fresh_var_map[var]
        assert fresh_var.decl() in model.decls()

        target_length = model[fresh_var].as_long()
        fixed_length_tree = create_fixed_length_tree(
            start=var.n_type,
            canonical_grammar=self.canonical_grammar,
            target_length=target_length,
        )

        if fixed_length_tree is not None:
            return fixed_length_tree

        raise RuntimeError(
            f"Could not create a tree with the start symbol '{var.n_type}' "
            f"of length {target_length}; try "
            "running the solver without optimized Z3 queries or make "
            "sure that lengths are restricted to syntactically valid "
            "ones (according to the grammar).",
        )
1✔
2729

2730
    def extract_model_value(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        r"""
        Extracts a value for :code:`var` from :code:`model`. The extraction is
        delegated to a chain of handlers, each of which receives the next one as
        its fallback: numeric, "length", "int", and finally "flexible" variables.
        The following special cases are considered:

        Numeric Variables
            Returns a closed derivation tree of one node with a string representation
            of the numeric solution.

        "Length" Variables
            Returns a tree whose string has the length corresponding to the model
            and :code:`fresh_var_map`, see also
            :meth:`~isla.solver.ISLaSolver.safe_create_fixed_length_tree()`.

        "Int" Variables
            Tries to parse the numeric solution from the model (obtained via
            :code:`fresh_var_map`) into the type of :code:`var` and returns the
            corresponding derivation tree.

        >>> grammar = {
        ...     "<start>": ["<A>"],
        ...     "<A>": ["<X><Y>"],
        ...     "<X>": ["x", "x<X>"],
        ...     "<Y>": ["<digit>", "<digit><Y>"],
        ...     "<digit>": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
        ... }
        >>> solver = ISLaSolver(grammar)

        **Numeric Variables:**

        >>> n = language.Variable("n", language.Variable.NUMERIC_NTYPE)
        >>> f = z3_eq(z3.StrToInt(n.to_smt()), z3.IntVal(15))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(n, model, {}, set(), set())
        DerivationTree('15', (), id=1)

        For a trivially true solution on numeric variables, we return a random number:

        >>> f = z3_eq(n.to_smt(), n.to_smt())
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat

        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> random.seed(0)
        >>> solver.extract_model_value(n, model, {n: n.to_smt()}, set(), {n})
        DerivationTree('-2116850434379610162', (), id=1)

        **"Length" Variables:**

        >>> x = language.Variable("x", "<X>")
        >>> x_0 = z3.Int("x_0")
        >>> f = z3_eq(x_0, z3.IntVal(3))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {x: x_0}, {x}, set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxx'

        **"Int" Variables:**

        >>> y = language.Variable("y", "<Y>")
        >>> y_0 = z3.Int("y_0")
        >>> f = z3_eq(y_0, z3.IntVal(5))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> DerivationTree.next_id = 1
        >>> solver.extract_model_value(y, model, {y: y_0}, set(), {y})
        DerivationTree('<Y>', (DerivationTree('<digit>', (DerivationTree('5', (), id=1),), id=2),), id=3)

        **"Flexible" Variables:**

        >>> f = z3_eq(x.to_smt(), z3.StringVal("xxxxx"))
        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat
        >>> model = z3_solver.model()
        >>> result = solver.extract_model_value(x, model, {}, set(), set())
        >>> result.value
        '<X>'
        >>> str(result)
        'xxxxx'

        **Special Number Formats**

        It may happen that the solver returns, e.g., "1" as a solution for an int
        variable, but the grammar only recognizes "+001". We support this case, i.e.,
        an optional "+" and optional 0 padding; an optional 0 padding for negative
        numbers is also supported.

        >>> grammar = {
        ...     "<start>": ["<int>"],
        ...     "<int>": ["<sign>00<leaddigit><digits>"],
        ...     "<sign>": ["-", "+"],
        ...     "<digits>": ["", "<digit><digits>"],
        ...     "<digit>": list("0123456789"),
        ...     "<leaddigit>": list("123456789"),
        ... }
        >>> solver = ISLaSolver(grammar)

        >>> i = language.Variable("i", "<int>")
        >>> i_0 = z3.Int("i_0")
        >>> f = z3_eq(i_0, z3.IntVal(5))

        >>> z3_solver = z3.Solver()
        >>> z3_solver.add(f)
        >>> z3_solver.check()
        sat

        The following test works when run from the IDE, but unfortunately not when
        started from CI/the `test_doctests.py` script. Thus, we only show it as code
        block (we added a unit test as a replacement)::

            model = z3_solver.model()
            print(solver.extract_model_value(i, model, {i: i_0}, set(), {i}))
            # Prints: +005

        :param var: The variable for which to extract a solution from the model.
        :param model: The model containing the solution.
        :param fresh_var_map: A map from variables to fresh symbols for "length" and
                              "int" variables.
        :param length_vars: The set of "length" variables.
        :param int_vars: The set of "int" variables.
        :return: A :class:`~isla.derivation_tree.DerivationTree` object corresponding
                 to the solution in :code:`model`.
        """

        # Build the handler chain from the innermost fallback outwards; the last
        # handler wrapped is the one that is tried first.
        extractor = self.extract_model_value_flexible_var
        for handler in (
            self.extract_model_value_int_var,
            self.extract_model_value_length_var,
            self.extract_model_value_numeric_var,
        ):
            extractor = partial(handler, extractor)

        return extractor(var, model, fresh_var_map, length_vars, int_vars)
1✔
2886

2887
    # Signature shared by the chained model-value extractors below. The
    # ``extract_model_value_*_var`` methods that take a ``fallback`` argument
    # expect a callable of this shape and invoke it when they are not
    # responsible for the given variable.
    ExtractModelValueFallbackType = Callable[
        [
            language.Variable,  # var
            z3.ModelRef,  # model
            Dict[language.Variable, z3.ExprRef],  # fresh_var_map
            Set[language.Variable],  # length_vars
            Set[language.Variable],  # int_vars
        ],
        DerivationTree,
    ]
2897

2898
    def extract_model_value_numeric_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles the numeric-variable case of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        Variables that are not numeric are delegated to :code:`fallback`. For
        numeric variables, the value is taken from the string symbol named
        after the variable if the model declares it, and otherwise from the
        variable's entry in :code:`fresh_var_map`. The result is a childless
        derivation tree whose value is the (possibly negative) number string.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """
        if not var.is_numeric():
            return fallback(var, model, fresh_var_map, length_vars, int_vars)

        string_symbol = z3.String(var.name)
        if string_symbol.decl() in model.decls():
            model_value = model[string_symbol]
        else:
            # Without a string symbol in the model, the variable must have been
            # replaced by a fresh integer symbol.
            assert var in int_vars
            assert var in fresh_var_map

            model_value = model[fresh_var_map[var]]

            if model_value is None:
                # This can happen for universally true formulas, e.g., `x = x`.
                # In that case, we return a random integer.
                model_value = z3.IntVal(random.randint(-sys.maxsize, sys.maxsize))

        assert (
            model_value is not None
        ), f"No solution for variable {var} found in model {model}"

        string_value = smt_string_val_to_string(model_value)
        assert string_value

        # Accept plain digit strings and digit strings with a single leading "-".
        unsigned_part = string_value[1:] if string_value[0] == "-" else string_value
        assert unsigned_part.isnumeric()

        return DerivationTree(string_value, ())
1✔
2949

2950
    def extract_model_value_length_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles the length-variable case of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        Variables contained in :code:`length_vars` are turned into a tree by
        :code:`safe_create_fixed_length_tree`; all others are delegated to
        :code:`fallback`.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """
        if var in length_vars:
            return self.safe_create_fixed_length_tree(var, model, fresh_var_map)

        return fallback(var, model, fresh_var_map, length_vars, int_vars)
1✔
2975

2976
    def extract_model_value_int_var(
        self,
        fallback: ExtractModelValueFallbackType,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Addresses the case of int variables from
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        The integer found in the model for the fresh symbol of :code:`var` is
        first parsed as-is. If the grammar rejects that plain representation,
        a representation of the form :code:`[+-]0*<digits>` accepted by the
        regular expression of the variable's nonterminal is searched for with
        an auxiliary Z3 query and parsed instead.

        :param fallback: The function to call if this function is not responsible.
        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :raises RuntimeError: If the model value is not a number, or if no
            padded/signed representation is accepted for the variable's type.
        """
        if var not in int_vars:
            return fallback(var, model, fresh_var_map, length_vars, int_vars)

        str_model_value = model[fresh_var_map[var]].as_string()

        try:
            int_model_value = int(str_model_value)
        except ValueError as err:
            # Chain the original error so the cause stays visible in tracebacks.
            raise RuntimeError(
                f"Value {str_model_value} for {var} is not a number"
            ) from err

        var_type = var.n_type

        try:
            return self.parse(
                str(int_model_value),
                var_type,
                silent=True,
            )
        except SyntaxError:
            # This may happen, e.g, with padded values: Only "01" is a valid
            # solution, but not "1". Similarly, a grammar may expect "+1", but
            # "1" is returned by the solver. We support the number format
            # `[+-]0*<digits>`. Whenever the grammar recognizes at least this
            # set for the nonterminal in question, we return a derivation tree.
            # Otherwise, a RuntimeError is raised.

            non_negative = int_model_value >= 0
            # The digits without sign; the sign is handled separately below.
            digits = str_model_value if non_negative else str(-int_model_value)

            z3_solver = z3.Solver()
            z3_solver.set("timeout", 300)

            # TODO: Ensure symbols are fresh
            maybe_plus_var = z3.String("__plus")
            zeroes_padding_var = z3.String("__padding")

            z3_solver.add(z3.InRe(maybe_plus_var, z3.Option(z3.Re("+"))))
            z3_solver.add(z3.InRe(zeroes_padding_var, z3.Star(z3.Re("0"))))

            # Non-negative numbers may carry an optional "+"; negative ones
            # always start with a literal "-".
            sign_expr = maybe_plus_var if non_negative else z3.StringVal("-")
            z3_solver.add(
                z3.InRe(
                    z3.Concat(sign_expr, zeroes_padding_var, z3.StringVal(digits)),
                    self.extract_regular_expression(var_type),
                )
            )

            if z3_solver.check() != z3.sat:
                raise RuntimeError(
                    "Could not parse a numeric solution "
                    + f"({str_model_value}) for variable "
                    + f"{var} of type '{var.n_type}'; try "
                    + "running the solver without optimized Z3 queries or make "
                    + "sure that ranges are restricted to syntactically valid "
                    + "ones (according to the grammar).",
                )

            # Query the auxiliary model once and reuse it for sign and padding.
            format_model = z3_solver.model()
            sign = format_model[maybe_plus_var].as_string() if non_negative else "-"
            padding = format_model[zeroes_padding_var].as_string()

            return self.parse(sign + padding + digits, var_type)
3071

3072
    def extract_model_value_flexible_var(
        self,
        var: language.Variable,
        model: z3.ModelRef,
        fresh_var_map: Dict[language.Variable, z3.ExprRef],
        length_vars: Set[language.Variable],
        int_vars: Set[language.Variable],
    ) -> DerivationTree:
        """
        Handles the "flexible"-variable case of
        :meth:`~isla.solver.ISLaSolver.extract_model_value`.

        This is the last handler of the chain and therefore takes no fallback:
        the string value the model assigns to the symbol named after
        :code:`var` is parsed with the variable's nonterminal type. The
        arguments :code:`fresh_var_map`, :code:`length_vars`, and
        :code:`int_vars` are accepted for signature compatibility only and are
        not used.

        :param var: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param model: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param fresh_var_map: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param length_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :param int_vars: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        :return: See :meth:`~isla.solver.ISLaSolver.extract_model_value`.
        """

        solution_string = smt_string_val_to_string(model[z3.String(var.name)])
        return self.parse(solution_string, var.n_type)
3097

3098
    @staticmethod
    def infer_variable_contexts(
        variables: Set[language.Variable], smt_formulas: ImmutableList[z3.BoolRef]
    ) -> Dict[str, Set[language.Variable]]:
        """
        Divides the given variables into

        1. those that occur only in :code:`length(...)` contexts,
        2. those that occur only in :code:`str.to.int(...)` contexts, and
        3. "flexible" constants occurring in other/various contexts.

        >>> x = language.Variable("x", "<X>")
        >>> y = language.Variable("y", "<Y>")

        Two variables in an arbitrary context.

        >>> f = z3_eq(x.to_smt(), y.to_smt())
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
        >>> contexts["length"]
        set()
        >>> contexts["int"]
        set()
        >>> contexts["flexible"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
        True

        Variable x occurs in a length context, variable y in an arbitrary one.

        >>> f = z3.And(
        ...     z3.Length(x.to_smt()) > z3.IntVal(10),
        ...     z3_eq(y.to_smt(), z3.StringVal("y")))
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}

        Variable x occurs in a length context, y does not occur.

        >>> f = z3.Length(x.to_smt()) > z3.IntVal(10)
        >>> ISLaSolver.infer_variable_contexts({x, y}, (f,))
        {'length': {Variable("x", "<X>")}, 'int': set(), 'flexible': {Variable("y", "<Y>")}}

        Variables x and y both occur in a length context.

        >>> f = z3.Length(x.to_smt()) > z3.Length(y.to_smt())
        >>> contexts = ISLaSolver.infer_variable_contexts({x, y}, (f,))
        >>> contexts["length"] == {language.Variable("x", "<X>"), language.Variable("y", "<Y>")}
        True
        >>> contexts["int"]
        set()
        >>> contexts["flexible"]
        set()

        Variable x occurs in a :code:`str.to.int` context.

        >>> f = z3.StrToInt(x.to_smt()) > z3.IntVal(17)
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
        {'length': set(), 'int': {Variable("x", "<X>")}, 'flexible': set()}

        Now, x also occurs in a different context; it's "flexible" now.

        >>> f = z3.And(
        ...     z3.StrToInt(x.to_smt()) > z3.IntVal(17),
        ...     z3_eq(x.to_smt(), z3.StringVal("17")))
        >>> ISLaSolver.infer_variable_contexts({x}, (f,))
        {'length': set(), 'int': set(), 'flexible': {Variable("x", "<X>")}}

        :param variables: The constants to divide/filter from.
        :param smt_formulas: The SMT formulas to consider in the filtering.
        :return: A dictionary with the keys :code:`"length"`, :code:`"int"`, and
            :code:`"flexible"`, mapping to the variables of the respective
            category. The three sets are pairwise disjoint, and their union
            equals :code:`variables`.
        """

        # Maps each sub-expression to the expressions it directly occurs in,
        # merged over all formulas.
        parents_by_expr = reduce(
            merge_dict_of_sets,
            map(parent_relationships_in_z3_expr, smt_formulas),
            {},
        )

        # For each variable, the operator kinds of all its parent expressions.
        # Variables without any parent get the dummy kind -1, which makes them
        # end up in the "flexible" category.
        kinds_by_var: Dict[language.Variable, Set[int]] = {}
        for var in variables:
            parent_kinds = {
                parent.decl().kind()
                for parent in parents_by_expr.get(var.to_smt(), set())
            }
            kinds_by_var[var] = parent_kinds or {-1}

        def occurring_only_in(op_kind: int) -> Set[language.Variable]:
            """Returns the variables whose parents all have the kind `op_kind`."""
            return {
                var
                for var in variables
                if all(kind == op_kind for kind in kinds_by_var[var])
            }

        # Variables that only occur in `str.len(...)` contexts.
        length_vars = occurring_only_in(z3.Z3_OP_SEQ_LENGTH)

        # Variables that only occur in `str.to.int(...)` contexts.
        int_vars = occurring_only_in(z3.Z3_OP_STR_TO_INT)

        # "Flexible" variables are the remaining ones.
        flexible_vars = variables.difference(length_vars).difference(int_vars)

        return {"length": length_vars, "int": int_vars, "flexible": flexible_vars}
1✔
3204

3205
    def generate_language_constraints(
        self,
        constants: Iterable[language.Variable],
        tree_substitutions: Dict[language.Variable, DerivationTree],
    ) -> List[z3.BoolRef]:
        """
        Creates, for each given constant, a Z3 regular-expression membership
        constraint on the string symbol named after the constant.

        :param constants: The constants to create constraints for.
        :param tree_substitutions: Trees giving a more concrete shape of the
            desired instantiation for some of the constants.
        :return: One :code:`z3.InRe` formula per constant, in iteration order.
        """

        def regex_for(constant: language.Variable):
            if constant.is_numeric():
                # "0" or a digit sequence without leading zero.
                return z3.Union(
                    z3.Re("0"),
                    z3.Concat(z3.Range("1", "9"), z3.Star(z3.Range("0", "9"))),
                )

            if constant not in tree_substitutions:
                return self.extract_regular_expression(constant.n_type)

            # We have a more concrete shape of the desired instantiation
            # available: nonterminals contribute their regular expression,
            # terminal parts are matched literally.
            parts = []
            for token in split_str_with_nonterminals(str(tree_substitutions[constant])):
                if is_nonterminal(token):
                    parts.append(self.extract_regular_expression(token))
                else:
                    parts.append(z3.Re(token))

            assert parts
            return z3.Concat(*parts) if len(parts) > 1 else parts[0]

        return [
            z3.InRe(z3.String(constant.name), regex_for(constant))
            for constant in constants
        ]
1✔
3235

3236
    def rename_instantiated_variables_in_smt_formulas(self, smt_formulas):
        """
        Renames the substituted variables of the given SMT formulas.

        Each variable that a formula substitutes by a tree is replaced by a
        bound variable named :code:`<tree value>_<tree id>` of the same type;
        the replacement is applied to the Z3 formula, the substitution map,
        and the set of instantiated variables alike.

        :param smt_formulas: The SMT formulas to process.
        :return: A new list with one renamed SMT formula per input formula.
        """
        renamed_formulas = []

        for original in smt_formulas:
            z3_formula: z3.BoolRef = original.formula
            substitutions = original.substitutions
            instantiated_variables = original.instantiated_variables

            old_var: language.Variable
            tree: DerivationTree
            for old_var, tree in original.substitutions.items():
                fresh_var = language.BoundVariable(
                    f"{tree.value}_{tree.id}", old_var.n_type
                )

                z3_formula = cast(
                    z3.BoolRef,
                    z3_subst(z3_formula, {old_var.to_smt(): fresh_var.to_smt()}),
                )
                substitutions = {
                    (fresh_var if key == old_var else key): value
                    for key, value in substitutions.items()
                }
                instantiated_variables = {
                    fresh_var if variable == old_var else variable
                    for variable in instantiated_variables
                }

            renamed_formulas.append(
                language.SMTFormula(
                    z3_formula,
                    *original.free_variables_,
                    instantiated_variables=instantiated_variables,
                    substitutions=substitutions,
                )
            )

        return renamed_formulas
1✔
3273

3274
    def process_new_states(
        self, new_states: List[SolutionState]
    ) -> List[DerivationTree]:
        """
        Processes each of the given states with :code:`process_new_state` and
        concatenates the resulting solution trees in order.

        :param new_states: The states to process.
        :return: All solution trees obtained from the given states.
        """
        solution_trees: List[DerivationTree] = []
        for pending_state in new_states:
            solution_trees.extend(self.process_new_state(pending_state))

        return solution_trees
3282

3283
    def process_new_state(self, new_state: SolutionState) -> List[DerivationTree]:
        """
        Normalizes a new state and returns the solution trees it gives rise to.

        Structural predicates are instantiated, the state is split into one
        state per clause by :code:`establish_invariant`, and universal
        quantifiers that cannot match (anymore) are removed. If unsat support
        is activated, states with unsatisfiable SMT or existential formulas
        are dropped. The remaining states are either returned as solutions or
        enqueued (see :code:`state_is_valid_or_enqueue`).

        :param new_state: The state to process.
        :return: The trees of those resulting states that are valid solutions.
        """
        new_state = self.instantiate_structural_predicates(new_state)
        new_states = self.establish_invariant(new_state)
        new_states = [
            self.remove_nonmatching_universal_quantifiers(split_state)
            for split_state in new_states
        ]
        new_states = [
            self.remove_infeasible_universal_quantifiers(split_state)
            for split_state in new_states
        ]

        if self.activate_unsat_support and not self.currently_unsat_checking:
            self.currently_unsat_checking = True

            try:
                # Iterate over a copy since states are removed from `new_states`.
                for candidate in list(new_states):
                    if candidate.constraint == sc.true():
                        continue

                    # Remove states with unsatisfiable SMT-LIB formulas.
                    if any(
                        isinstance(f, language.SMTFormula)
                        for f in split_conjunction(candidate.constraint)
                    ) and not is_successful(
                        self.eliminate_all_semantic_formulas(
                            candidate, max_instantiations=1
                        ).bind(lambda a: Some(a) if a else Nothing)
                    ):
                        new_states.remove(candidate)
                        self.logger.debug(
                            "Dropping state %s, unsatisfiable SMT formulas",
                            candidate,
                        )
                        # The state is gone already; checking its existential
                        # formulas could only attempt to remove it a second
                        # time, which fails or removes an equal other state.
                        continue

                    # Remove states with unsatisfiable existential formulas.
                    existential_formulas = [
                        f
                        for f in split_conjunction(candidate.constraint)
                        if isinstance(f, language.ExistsFormula)
                    ]
                    for existential_formula in existential_formulas:
                        # Run a nested, time-limited solver pass on a private
                        # queue; the outer solver state is restored afterward.
                        old_start_time = self.start_time
                        old_timeout_seconds = self.timeout_seconds
                        old_queue = list(self.queue)
                        old_solutions = list(self.solutions)

                        self.queue = []
                        self.solutions = []
                        check_state = SolutionState(
                            existential_formula, candidate.tree
                        )
                        heapq.heappush(self.queue, (0, check_state))
                        self.start_time = int(time.time())
                        self.timeout_seconds = 2

                        try:
                            self.solve()
                        except StopIteration:
                            # No solution found: treat the formula as
                            # unsatisfiable and drop the state.
                            new_states.remove(candidate)
                            self.logger.debug(
                                "Dropping state %s, unsatisfiable existential formula %s",
                                candidate,
                                existential_formula,
                            )
                            break
                        finally:
                            self.start_time = old_start_time
                            self.timeout_seconds = old_timeout_seconds
                            self.queue = old_queue
                            self.solutions = old_solutions
            finally:
                # Reset the guard even if a check raises; otherwise unsat
                # checking would stay disabled for all subsequent states.
                self.currently_unsat_checking = False

        assert all(
            state.tree.find_node(tree) is not None
            for state in new_states
            for quantified_formula in split_conjunction(state.constraint)
            if isinstance(quantified_formula, language.QuantifiedFormula)
            for _, tree in quantified_formula.in_variable.filter(lambda t: True)
        )

        solution_trees = [
            state.tree
            for state in new_states
            if self.state_is_valid_or_enqueue(state)
        ]

        for tree in solution_trees:
            self.cost_computer.signal_tree_output(tree)

        return solution_trees
1✔
3371

3372
    def state_is_valid_or_enqueue(self, state: SolutionState) -> bool:
        """
        Decides whether the given state can be yielded as a solution.

        Returns True if the state is complete. Otherwise, False is returned,
        and the state is either discarded (its tree or the state itself is
        already in the queue, or its constraint is unsatisfiable) or pushed
        onto the queue with its computed cost.

        :param state: The state to check.
        :return: True iff the state is complete and can be yielded.
        """

        if state.complete():
            # Record the expansions used in the solution.
            for _, subtree in state.tree.paths():
                if not subtree.children:
                    continue

                self.seen_coverages.add(expansion_key(subtree.value, subtree.children))

            assert state.formula_satisfied(self.grammar).is_true()
            return True

        # Helps in debugging below assertion:
        # [(predicate_formula, [
        #     arg for arg in predicate_formula.args
        #     if isinstance(arg, DerivationTree) and not state.tree.find_node(arg)])
        #  for predicate_formula in get_conjuncts(state.constraint)
        #  if isinstance(predicate_formula, language.StructuralPredicateFormula)]

        self.assert_no_dangling_predicate_argument_trees(state)
        self.assert_no_dangling_smt_formula_argument_trees(state)

        if self.enforce_unique_trees_in_queue:
            if state.tree.structural_hash() in self.tree_hashes_in_queue:
                # Some structures can arise as well from tree insertion
                # (existential quantifier elimination) and expansion; also, tree
                # insertion can yield different trees that have intersecting
                # expansions. We drop those to output more diverse solutions
                # (numbers for SMT solutions and free nonterminals are
                # configurable, so you get more outputs by playing with those!).
                self.logger.debug("Discarding state %s, tree already in queue", state)
                return False

        if hash(state) in self.state_hashes_in_queue:
            self.logger.debug("Discarding state %s, already in queue", state)
            return False

        if self.propositionally_unsatisfiable(state.constraint):
            self.logger.debug("Discarding state %s", state)
            return False

        # The enqueued state lives one level below the current one.
        queued_state = SolutionState(
            state.constraint, state.tree, level=self.current_level + 1
        )

        self.recompute_costs()

        queued_cost = self.compute_cost(queued_state)
        heapq.heappush(self.queue, (queued_cost, queued_state))
        self.tree_hashes_in_queue.add(queued_state.tree.structural_hash())
        self.state_hashes_in_queue.add(hash(queued_state))

        if self.debug:
            self.state_tree[self.current_state].append(queued_state)
            self.costs[queued_state] = queued_cost

        self.logger.debug(
            "Pushing new state (%s, %s) (hash %d, cost %f)",
            queued_state.constraint,
            queued_state.tree.to_string(show_open_leaves=True, show_ids=True),
            hash(queued_state),
            queued_cost,
        )

        queue_length = len(self.queue)
        self.logger.debug("Queue length: %d", queue_length)
        if queue_length % 100 == 0:
            self.logger.info("Queue length: %d", queue_length)

        return False
1✔
3445

3446
    def recompute_costs(self):
        """
        Recomputes the costs of all queued states every 400 solver steps.

        Nothing happens unless the step counter is a multiple of 400 and
        larger than the step of the last recomputation. Otherwise, the queue
        is rebuilt from scratch with freshly computed costs.
        """
        if self.step_cnt % 400 != 0 or self.step_cnt <= self.last_cost_recomputation:
            return

        self.last_cost_recomputation = self.step_cnt
        # Lazy %-formatting, as in the other logging calls of this class.
        self.logger.info(
            "Recomputing costs in queue after %s solver steps", self.step_cnt
        )

        # The old list is only read after the swap, so no copy is needed.
        old_queue, self.queue = self.queue, []
        for _, state in old_queue:
            heapq.heappush(self.queue, (self.compute_cost(state), state))
1✔
3459

3460
    def assert_no_dangling_smt_formula_argument_trees(
        self, state: SolutionState
    ) -> None:
        """
        Checks that all trees substituted in SMT formulas of the state's
        constraint are contained in the state's tree.

        The check is skipped unless assertions are activated or the solver is
        in debug mode.

        :param state: The state to check.
        :raises AssertionError: If a substituted tree is not found in the
            state's tree.
        """
        if not assertions_activated() and not self.debug:
            return

        dangling_smt_formula_argument_trees = [
            (smt_formula, arg)
            for smt_formula in language.FilterVisitor(
                lambda f: isinstance(f, language.SMTFormula)
            ).collect(state.constraint)
            for arg in cast(language.SMTFormula, smt_formula).substitutions.values()
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if dangling_smt_formula_argument_trees:
            details = ", ".join(
                str(f) + ", " + repr(a)
                for f, a in dangling_smt_formula_argument_trees
            )
            # Raised explicitly: an `assert` statement is removed when Python
            # runs with optimizations, which would silence the check in debug
            # mode.
            raise AssertionError("Dangling SMT formula arguments: [" + details + "]")
×
3485

3486
    def assert_no_dangling_predicate_argument_trees(self, state: SolutionState) -> None:
        """
        Checks that all tree arguments of structural predicate formulas in the
        state's constraint are contained in the state's tree.

        The check is skipped unless assertions are activated or the solver is
        in debug mode.

        :param state: The state to check.
        :raises AssertionError: If a tree argument is not found in the state's
            tree.
        """
        if not assertions_activated() and not self.debug:
            return

        dangling_predicate_argument_trees = [
            (predicate_formula, arg)
            for predicate_formula in language.FilterVisitor(
                lambda f: isinstance(f, language.StructuralPredicateFormula)
            ).collect(state.constraint)
            for arg in cast(language.StructuralPredicateFormula, predicate_formula).args
            if isinstance(arg, DerivationTree) and state.tree.find_node(arg) is None
        ]

        if dangling_predicate_argument_trees:
            details = ", ".join(
                str(f) + ", " + repr(a) for f, a in dangling_predicate_argument_trees
            )
            # Raised explicitly: an `assert` statement is removed when Python
            # runs with optimizations, which would silence the check in debug
            # mode.
            raise AssertionError("Dangling predicate arguments: [" + details + "]")
×
3506

3507
    def propositionally_unsatisfiable(self, formula: language.Formula) -> bool:
        """
        Returns True iff the given formula equals the constant "false" formula.

        :param formula: The formula to check.
        :return: The result of comparing :code:`formula` to :code:`sc.false()`.
        """
        false_formula = sc.false()
        return formula == false_formula

        # NOTE: Deactivated propositional check for performance reasons
        # z3_formula = language.isla_to_smt_formula(formula, replace_untranslatable_with_predicate=True)
        # solver = z3.Solver()
        # solver.add(z3_formula)
        # return solver.check() == z3.unsat
3515

3516
    def establish_invariant(self, state: SolutionState) -> List[SolutionState]:
        """
        Splits the state into one state per DNF clause of its constraint.

        The constraint is converted to negation normal form and then into DNF
        clauses; each clause is conjoined into the constraint of a new state
        that keeps the tree of the given state.

        :param state: The state to split.
        :return: One state per DNF clause.
        """
        nnf_constraint = convert_to_nnf(state.constraint)

        split_states: List[SolutionState] = []
        for clause in to_dnf_clauses(nnf_constraint):
            clause_constraint = reduce(operator.and_, clause, sc.true())
            split_states.append(SolutionState(clause_constraint, state.tree))

        return split_states
3522

3523
    def compute_cost(self, state: SolutionState) -> float:
        """
        Computes the cost of the given state.

        :param state: The state to compute the cost for.
        :return: 0 if the state's constraint equals :code:`sc.true()`, and
            otherwise the cost reported by the cost computer.
        """
        constraint_is_true = state.constraint == sc.true()
        return 0 if constraint_is_true else self.cost_computer.compute_cost(state)
1✔
3528

3529
    def remove_nonmatching_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Removes universal formulas that have no match in a complete tree.

        A universal conjunct is dropped if its "in" tree is complete and
        :code:`matches_for_quantified_formula` reports no match for it.

        :param state: The state whose constraint should be simplified.
        :return: The given state if nothing was removed, and otherwise a new
            state over the same tree with the remaining conjuncts.
        """
        conjuncts = list(get_conjuncts(state.constraint))
        dropped_positions = set()

        for position, conjunct in reversed(list(enumerate(conjuncts))):
            if not isinstance(conjunct, language.ForallFormula):
                continue

            if not conjunct.in_variable.is_complete():
                continue

            if matches_for_quantified_formula(conjunct, self.grammar):
                continue

            dropped_positions.add(position)

        if not dropped_positions:
            return state

        remaining = [
            conjunct
            for position, conjunct in enumerate(conjuncts)
            if position not in dropped_positions
        ]

        return SolutionState(sc.conjunction(*remaining), state.tree)
1✔
3550

3551
    def remove_infeasible_universal_quantifiers(
        self, state: SolutionState
    ) -> SolutionState:
        """
        Removes universal formulas that cannot yield any further match.

        A universal conjunct is dropped if all its current matches are already
        matched and no open leaf of its "in" tree might lead to a match.

        :param state: The state whose constraint should be simplified.
        :return: The given state if nothing was removed, and otherwise a new
            state over the same tree with the remaining conjuncts.
        """
        # Work on a copy, as `remove_nonmatching_universal_quantifiers` does:
        # the list returned by `get_conjuncts` is not ours to modify.
        conjuncts = list(get_conjuncts(state.constraint))
        one_removed = False

        # Iterate backward so that deletions do not shift pending indices.
        for idx, universal_formula in reversed(list(enumerate(conjuncts))):
            if not isinstance(universal_formula, language.ForallFormula):
                continue

            matches = matches_for_quantified_formula(universal_formula, self.grammar)

            all_matches_matched = all(
                universal_formula.is_already_matched(
                    match[universal_formula.bound_variable][1]
                )
                for match in matches
            )
            if not all_matches_matched:
                continue

            some_leaf_might_match = any(
                self.quantified_formula_might_match(
                    universal_formula, leaf_path, universal_formula.in_variable
                )
                for leaf_path, _ in universal_formula.in_variable.open_leaves()
            )
            if some_leaf_might_match:
                continue

            one_removed = True
            del conjuncts[idx]

        if not one_removed:
            return state

        return SolutionState(
            reduce(lambda a, b: a & b, conjuncts, sc.true()),
            state.tree,
        )
3590

3591
    def quantified_formula_might_match(
        self,
        qfd_formula: language.QuantifiedFormula,
        path_to_nonterminal: Path,
        tree: DerivationTree,
    ) -> bool:
        """
        Convenience wrapper around the module-level function of the same name
        that supplies the grammar of this object and the `reachable` function of
        its grammar graph.

        :param qfd_formula: The quantified formula in question.
        :param path_to_nonterminal: A path in `tree`.
        :param tree: The tree in which to look for potential matches.
        :return: The result of the module-level `quantified_formula_might_match`.
        """
        grammar, reachable = self.grammar, self.graph.reachable
        return quantified_formula_might_match(
            qfd_formula, path_to_nonterminal, tree, grammar, reachable
        )
3604

3605
    def extract_regular_expression(self, nonterminal: str) -> z3.ReRef:
        """
        Computes a z3 regular expression for the given nonterminal. Results are
        memoized in `self.regex_cache`.

        Two shortcuts avoid a full conversion of the grammar: Unit expansions
        `<a> ::= <b>` reuse the expression for `<b>`, and single, non-recursive
        expansions are translated to a concatenation of the expressions of their
        elements. In all other cases, a `RegexConverter` is used, with
        `self.grammar_unwinding_threshold` passed as `max_num_expansions`.

        If assertions are activated, up to 100 distinct strings from the language
        of the computed expression are obtained from z3 and parsed with the sub
        grammar of `nonterminal`; an `AssertionError` is raised if one of them
        cannot be parsed. Only the direction L(regex) \\subseteq L(grammar) is
        sampled this way.

        :param nonterminal: The nonterminal for which to compute an expression.
        :return: The z3 regular expression for `nonterminal`.
        """
        if nonterminal in self.regex_cache:
            return self.regex_cache[nonterminal]

        # For definitions like `<a> ::= <b>`, we only compute the regular expression
        # for `<b>`. That way, we might save some calls if `<b>` is used multiple times
        # (e.g., as in `<byte>`).
        canonical_expansions = self.canonical_grammar[nonterminal]

        if (
            len(canonical_expansions) == 1
            and len(canonical_expansions[0]) == 1
            and is_nonterminal(canonical_expansions[0][0])
        ):
            sub_nonterminal = canonical_expansions[0][0]
            assert (
                nonterminal != sub_nonterminal
            ), f"Expansion {nonterminal} => {sub_nonterminal}: Infinite recursion!"
            return self.regex_cache.setdefault(
                nonterminal, self.extract_regular_expression(sub_nonterminal)
            )

        # Similarly, for definitions like `<a> ::= <b> " x " <c>`, where `<b>` and `<c>`
        # don't reach `<a>`, we only compute the regular expressions for `<b>` and `<c>`
        # and return a concatenation. This also saves us expensive conversions (e.g.,
        # for `<seq> ::= <byte> <byte>`).
        if (
            len(canonical_expansions) == 1
            and any(is_nonterminal(elem) for elem in canonical_expansions[0])
            and all(
                not is_nonterminal(elem)
                or elem != nonterminal
                and not self.graph.reachable(elem, nonterminal)
                for elem in canonical_expansions[0]
            )
        ):
            result_elements: List[z3.ReRef] = [
                z3.Re(elem)
                if not is_nonterminal(elem)
                else self.extract_regular_expression(elem)
                for elem in canonical_expansions[0]
            ]
            return self.regex_cache.setdefault(nonterminal, z3.Concat(*result_elements))

        regex_conv = RegexConverter(
            self.grammar,
            compress_unions=True,
            max_num_expansions=self.grammar_unwinding_threshold,
        )
        regex = regex_conv.to_regex(nonterminal, convert_to_z3=False)
        # Lazy %-style arguments: The (potentially huge) regular expression is
        # only converted to a string if debug logging is actually enabled.
        self.logger.debug(
            "Computed regular expression for nonterminal %s:\n%s",
            nonterminal,
            regex,
        )
        z3_regex = regex_to_z3(regex)

        if assertions_activated():
            # Check correctness of regular expression
            grammar = self.graph.subgraph(nonterminal).to_grammar()

            # L(regex) \subseteq L(grammar)
            self.logger.debug(
                "Checking L(regex) \\subseteq L(grammar) for "
                + "nonterminal '%s' and regex '%s'",
                nonterminal,
                regex,
            )
            parser = EarleyParser(grammar)
            c = z3.String("c")
            prev: Set[str] = set()
            for _ in range(100):
                # Ask z3 for a word in the regex language that we did not see yet.
                s = z3.Solver()
                s.add(z3.InRe(c, z3_regex))
                for inp in prev:
                    s.add(z3.Not(c == z3.StringVal(inp)))
                if s.check() != z3.sat:
                    self.logger.debug(
                        "Cannot find the %d-th solution for regex %s (timeout).\nThis is *not* a problem "
                        "if there not that many solutions (for regexes with finite language), or if we "
                        "are facing a meaningless timeout of the solver.",
                        len(prev) + 1,
                        regex,
                    )
                    break
                new_inp = smt_string_val_to_string(s.model()[c])
                try:
                    next(parser.parse(new_inp))
                except SyntaxError:
                    assert (
                        False
                    ), f"Input '{new_inp}' from regex language is not in grammar language."
                prev.add(new_inp)

        self.regex_cache[nonterminal] = z3_regex

        return z3_regex
1✔
3700

3701

3702
class CostComputer(ABC):
    """
    Interface of the cost functions used to rank solution states. Both methods
    raise `NotImplementedError` in this base class.
    """

    def compute_cost(self, state: SolutionState) -> float:
        """
        Returns the cost of the given state. The lower the cost of a state, the
        more it is preferred in the analysis.

        :param state: The state to assess.
        :return: The cost of `state`.
        """
        raise NotImplementedError

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """
        Notifies the cost computer that a tree was output as a solution, such
        that it can update the internal information used for computing costs.

        :param tree: The tree that is output as a solution.
        :return: Nothing.
        """
        raise NotImplementedError
×
3722

3723

3724
class GrammarBasedBlackboxCostComputer(CostComputer):
    """
    A cost computer combining five cost components by a weighted geometric mean
    (weights: `cost_settings.weight_vector`): the cost of closing the tree, the
    cost of the quantifiers in the constraint, the level of the state, and two
    penalties based on k-path coverage of the grammar graph.
    """

    def __init__(
        self,
        cost_settings: CostSettings,
        graph: gg.GrammarGraph,
        reset_coverage_after_n_round_with_no_coverage: int = 100,
        symbol_costs: Optional[Dict[str, int]] = None,
    ):
        """
        :param cost_settings: Provides the weight vector and the `k` used for
            k-path coverage.
        :param graph: The grammar graph of the underlying grammar.
        :param reset_coverage_after_n_round_with_no_coverage: The set of globally
            covered k-paths is reset once this number of output trees did not
            contribute any new k-path.
        :param symbol_costs: Optional precomputed symbol costs. If `None`, they
            are computed from `graph` when first needed.
        """
        self.cost_settings = cost_settings
        self.graph = graph

        self.covered_k_paths: Set[Tuple[gg.Node, ...]] = set()
        self.rounds_with_no_new_coverage = 0
        self.reset_coverage_after_n_round_with_no_coverage = (
            reset_coverage_after_n_round_with_no_coverage
        )
        self.symbol_costs: Optional[Dict[str, int]] = symbol_costs

        self.logger = logging.getLogger(type(self).__name__)

    def __repr__(self):
        return (
            "GrammarBasedBlackboxCostComputer("
            + f"{repr(self.cost_settings)}, "
            + "graph, "
            + f"{self.reset_coverage_after_n_round_with_no_coverage}, "
            + f"{self.symbol_costs})"
        )

    def compute_cost(self, state: SolutionState) -> float:
        """
        Computes the weighted geometric mean of the individual cost components.

        :param state: The state for which to compute a cost.
        :return: The cost value.
        """
        # How costly is it to finish the tree?
        tree_closing_cost = self.compute_tree_closing_cost(state.tree)

        # Quantifiers are expensive (universal formulas have to be matched, tree insertion for existential
        # formulas is even more costly). TODO: Penalize nested quantifiers more.
        constraint_cost = sum(
            idx * (2 if isinstance(f, language.ExistsFormula) else 1) + 1
            for c in get_quantifier_chains(state.constraint)
            for idx, f in enumerate(c)
        )

        # k-Path coverage: Fewer covered -> higher penalty
        k_cov_cost = self._compute_k_coverage_cost(state)

        # Covered k-paths: Fewer contributed -> higher penalty
        global_k_path_cost = self._compute_global_k_coverage_cost(state)

        costs = [
            tree_closing_cost,
            constraint_cost,
            state.level,
            k_cov_cost,
            global_k_path_cost,
        ]
        assert tree_closing_cost >= 0, f"tree_closing_cost == {tree_closing_cost}!"
        assert constraint_cost >= 0, f"constraint_cost == {constraint_cost}!"
        assert state.level >= 0, f"state.level == {state.level}!"
        assert k_cov_cost >= 0, f"k_cov_cost == {k_cov_cost}!"
        assert global_k_path_cost >= 0, f"global_k_path_cost == {global_k_path_cost}!"

        # Compute geometric mean
        result = weighted_geometric_mean(costs, list(self.cost_settings.weight_vector))

        self.logger.debug(
            "Computed cost for state %s:\n%f, individual costs: %s, weights: %s",
            lazystr(lambda: f"({(str(state.constraint)[:50] + '...')}, {state.tree})"),
            result,
            costs,
            self.cost_settings.weight_vector,
        )

        return result

    def signal_tree_output(self, tree: DerivationTree) -> None:
        """
        Records the k-paths of an output tree as globally covered.

        :param tree: The tree that is output as a solution.
        """
        self._update_covered_k_paths(tree)

    def _symbol_costs(self):
        """Returns the symbol costs, computing them from the graph on first use."""
        if self.symbol_costs is None:
            self.symbol_costs = compute_symbol_costs(self.graph)
        return self.symbol_costs

    def _update_covered_k_paths(self, tree: DerivationTree):
        """
        Adds the k-paths of `tree` to the set of globally covered k-paths and
        resets that set if all k-paths of the graph are covered or if there was
        no new coverage for the configured number of rounds. Does nothing if the
        weight of the global k-path coverage penalty is not positive.

        :param tree: The tree that was output as a solution.
        """
        if not self.cost_settings.weight_vector.low_global_k_path_coverage_penalty > 0:
            return

        # `update` only ever adds elements. Thus, the set has changed if, and
        # only if, it has grown; no copy of the old set is needed.
        num_covered_before = len(self.covered_k_paths)
        self.covered_k_paths.update(
            tree.k_paths(
                self.graph, self.cost_settings.k, include_potential_paths=False
            )
        )

        if len(self.covered_k_paths) == num_covered_before:
            self.rounds_with_no_new_coverage += 1

        graph_paths = self.graph.k_paths(
            self.cost_settings.k, include_terminals=False
        )

        no_progress_for_too_long = (
            self.rounds_with_no_new_coverage
            >= self.reset_coverage_after_n_round_with_no_coverage
        )
        all_paths_covered = self.covered_k_paths == graph_paths

        if no_progress_for_too_long or all_paths_covered:
            if all_paths_covered:
                self.logger.debug("ALL PATHS COVERED")
            else:
                self.logger.debug(
                    "COVERAGE RESET SINCE NO CHANGE IN COVERED PATHS SINCE %d "
                    + "ROUNDS (%d path(s) uncovered)",
                    self.reset_coverage_after_n_round_with_no_coverage,
                    len(graph_paths) - len(self.covered_k_paths),
                )

            self.covered_k_paths = set()

        # The round counter is only reset when its threshold was reached (not
        # when the coverage was reset because all paths were covered).
        if no_progress_for_too_long:
            self.rounds_with_no_new_coverage = 0

    def _compute_global_k_coverage_cost(self, state: SolutionState):
        """
        Computes a penalty that is lower the more not yet globally covered
        k-paths the tree of the state contributes. Actual and potential k-paths
        of the tree are weighted 0.2 and 0.8, respectively.

        :param state: The state to assess.
        :return: 0 if the corresponding weight is 0 or if no k-path is missing;
            otherwise, one minus the weighted geometric mean of the fractions of
            missing k-paths that the tree contributes (actually / potentially).
        """
        if self.cost_settings.weight_vector.low_global_k_path_coverage_penalty == 0:
            return 0

        all_graph_k_paths = self.graph.k_paths(
            self.cost_settings.k, include_terminals=False
        )

        def num_uncovered_paths_in(candidate_paths) -> int:
            # Number of graph k-paths in `candidate_paths` not yet covered.
            return len(
                {
                    path
                    for path in all_graph_k_paths
                    if path in candidate_paths and path not in self.covered_k_paths
                }
            )

        tree_k_paths = state.tree.k_paths(
            self.graph, self.cost_settings.k, include_potential_paths=False
        )

        num_contributed_k_paths = num_uncovered_paths_in(tree_k_paths)
        num_missing_k_paths = len(all_graph_k_paths) - len(self.covered_k_paths)

        assert 0 <= num_contributed_k_paths <= num_missing_k_paths, (
            f"num_contributed_k_paths == {num_contributed_k_paths}, "
            f"num_missing_k_paths == {num_missing_k_paths}"
        )

        # Nothing is missing: Return before computing the potential k-paths,
        # which would only be needed for the divisions below.
        if not num_missing_k_paths:
            return 0

        potential_tree_k_paths = state.tree.k_paths(
            self.graph, self.cost_settings.k, include_potential_paths=True
        )
        num_contributed_potential_k_paths = num_uncovered_paths_in(
            potential_tree_k_paths
        )

        return 1 - weighted_geometric_mean(
            [
                num_contributed_k_paths / num_missing_k_paths,
                num_contributed_potential_k_paths / num_missing_k_paths,
            ],
            [0.2, 0.8],
        )

    def _compute_k_coverage_cost(self, state: SolutionState) -> float:
        """
        Computes the geometric mean of the uncovered fractions `1 - coverage` of
        the tree of the state for all path lengths from 1 to `cost_settings.k`.

        :param state: The state to assess.
        :return: 0 if the corresponding weight is 0; the penalty otherwise.
        """
        if self.cost_settings.weight_vector.low_k_coverage_penalty == 0:
            return 0

        coverages = []
        for k in range(1, self.cost_settings.k + 1):
            coverage = state.tree.k_coverage(
                self.graph, k, include_potential_paths=False
            )
            assert 0 <= coverage <= 1, f"coverage == {coverage}"

            coverages.append(1 - coverage)

        return math.prod(coverages) ** (1 / float(self.cost_settings.k))

    def compute_tree_closing_cost(self, tree: DerivationTree) -> float:
        """
        Sums up the symbol costs of the values of all open leaves of the tree.

        :param tree: The tree to assess.
        :return: The sum of the symbol costs (0 if there are no open leaves).
        """
        nonterminals = [leaf.value for _, leaf in tree.open_leaves()]
        if not nonterminals:
            # As before, symbol costs are not computed if they are not needed.
            return 0

        symbol_costs = self._symbol_costs()
        return sum(symbol_costs[nonterminal] for nonterminal in nonterminals)
1✔
3953

3954

3955
def smt_formulas_referring_to_subtrees(
    smt_formulas: Sequence[language.SMTFormula],
) -> List[language.SMTFormula]:
    """
    Selects the SMT formulas that must be solved first to avoid inconsistencies.

    A formula is selected if (1) one of its top-level substitution trees is a
    proper subtree of a substitution tree of another formula, and (2) none of
    the proper subtrees of its own substitution trees is a top-level substitution
    tree of another formula. Trees are identified by their IDs.

    :param smt_formulas: The formulas to search for references to subtrees.
    :return: The list of conflicting formulas that must be solved first.
    """

    # IDs of all proper subtrees of the substitution trees, per formula.
    proper_subtree_ids: Dict[language.SMTFormula, Set[int]] = {
        formula: {
            subtree.id
            for tree in formula.substitutions.values()
            for _, subtree in tree.paths()
            if subtree.id != tree.id
        }
        for formula in smt_formulas
    }

    # IDs of the substitution trees themselves, per formula.
    toplevel_tree_ids: Dict[language.SMTFormula, Set[int]] = {
        formula: {tree.id for tree in formula.substitutions.values()}
        for formula in smt_formulas
    }

    selected: List[language.SMTFormula] = []
    for idx, formula in enumerate(smt_formulas):
        other_formulas = [
            other_formula
            for other_idx, other_formula in enumerate(smt_formulas)
            if other_idx != idx
        ]

        refers_to_subtree_of_other = any(
            not toplevel_tree_ids[formula].isdisjoint(proper_subtree_ids[other])
            for other in other_formulas
        )
        if not refers_to_subtree_of_other:
            continue

        is_referred_to_by_other = any(
            not toplevel_tree_ids[other].isdisjoint(proper_subtree_ids[formula])
            for other in other_formulas
        )
        if not is_referred_to_by_other:
            selected.append(formula)

    return selected
4015

4016

4017
def compute_tree_closing_cost(tree: DerivationTree, graph: GrammarGraph) -> float:
    """
    Sums up the symbol costs (see `compute_symbol_costs`) of the values of all
    open leaves of the given tree.

    :param tree: The tree to assess.
    :param graph: The grammar graph from which the symbol costs are computed.
    :return: The sum of the symbol costs (0 if there are no open leaves).
    """
    nonterminals = [leaf.value for _, leaf in tree.open_leaves()]
    if not nonterminals:
        # As before, no symbol costs are computed if there is nothing to sum up.
        return 0

    # Look up the (cached) cost table once instead of once per open leaf.
    symbol_costs = compute_symbol_costs(graph)
    return sum(symbol_costs[nonterminal] for nonterminal in nonterminals)
4022

4023

4024
def get_quantifier_chains(
    formula: language.Formula,
) -> List[Tuple[Union[language.QuantifiedFormula, language.ExistsIntFormula], ...]]:
    """
    Computes the chains of nested quantified formulas in the given formula.

    Each chain starts with a top-level quantified formula of `formula` (as
    returned by `get_toplevel_quantified_formulas`) and is continued with a chain
    of its inner formula. A quantified formula whose inner formula has no
    chains contributes a chain of length one.

    :param formula: The formula to analyze.
    :return: The list of quantifier chains, each from outermost to innermost.
    """
    chains = []

    for quantified in get_toplevel_quantified_formulas(formula):
        inner_chains = get_quantifier_chains(quantified.inner_formula) or [()]
        for inner_chain in inner_chains:
            chains.append((quantified,) + inner_chain)

    return chains
4033

4034

4035
def shortest_derivations(graph: gg.GrammarGraph) -> Dict[str, int]:
    """
    Assigns a derivation length estimate to each nonterminal of the grammar of
    the given graph, using a worklist that starts at the terminal nodes and
    propagates values towards the parents.

    Terminal nodes get the value 0. Choice nodes get the maximum of the values of
    their children (children without a value count as 0). All other nonterminal
    nodes get the rounded-up geometric mean of the values of those children that
    already have a value, plus 1. The parents of a node are (re-)visited whenever
    the node had no previous value, a previous value of 0, or a previous value
    greater than the new one.

    :param graph: The grammar graph to analyze.
    :return: A map from the nonterminals in `graph.grammar` to their values.
    """

    def rounded_geometric_mean(values) -> int:
        known_values = [value for value in values if value is not None]
        return math.ceil(math.prod(known_values) ** (1 / len(known_values)))

    parents_of = {node: set() for node in graph.all_nodes}
    for parent, child in graph.all_edges:
        parents_of[child].add(parent)

    derivation_lengths: Dict[gg.Node, int] = {}
    worklist: List[gg.Node] = graph.filter(
        lambda candidate: isinstance(candidate, gg.TerminalNode)
    )
    while worklist:
        node = worklist.pop()

        previous_value = derivation_lengths.get(node, None)

        if isinstance(node, gg.TerminalNode):
            derivation_lengths[node] = 0
        elif isinstance(node, gg.ChoiceNode):
            derivation_lengths[node] = max(
                derivation_lengths.get(child, 0) for child in node.children
            )
        elif isinstance(node, gg.NonterminalNode):
            assert not isinstance(node, gg.ChoiceNode)

            mean_of_children = rounded_geometric_mean(
                derivation_lengths.get(child, None) for child in node.children
            )
            derivation_lengths[node] = mean_of_children + 1

        # Note that a previous value of 0 is treated like a missing value here.
        if (previous_value or sys.maxsize) > derivation_lengths[node]:
            worklist.extend(parents_of[node])

    return {
        nonterminal: derivation_lengths[graph.get_node(nonterminal)]
        for nonterminal in graph.grammar
    }
4075

4076

4077
@lru_cache()
def compute_symbol_costs(graph: GrammarGraph) -> Dict[str, int]:
    """
    Computes a cost value for each nonterminal of the grammar of the graph.

    The initial values are taken from `shortest_derivations`. They are then
    adjusted along paths (see `all_paths`) from the root to each nonterminal that
    has at least one nonterminal in one of its expansions: Walking each path
    backwards, the cost of a node is raised to the cost of its successor plus
    one if it is not already greater.

    :param graph: The grammar graph to analyze.
    :return: A map from nonterminal symbols to their costs.
    """
    canonical_grammar = canonical(graph.to_grammar())

    costs: Dict[str, int] = shortest_derivations(graph)

    def expands_to_some_nonterminal(nonterminal: str) -> bool:
        return any(
            is_nonterminal(symbol)
            for expansion in canonical_grammar[nonterminal]
            for symbol in expansion
        )

    nonterminal_parents = [
        nonterminal
        for nonterminal in canonical_grammar
        if expands_to_some_nonterminal(nonterminal)
    ]

    # Sometimes this computation results in some nonterminals having lower cost values
    # than nonterminals that are reachable from those (but not vice versa), which is
    # undesired. We counteract this by assuring that on paths with at most one cycle
    # from the root to any nonterminal parent, the costs are strictly monotonically
    # decreasing.
    for nonterminal_parent in nonterminal_parents:
        # noinspection PyTypeChecker
        for path in all_paths(graph, graph.root, graph.get_node(nonterminal_parent)):
            # Visit the edges of the path from its end to its start.
            for source, target in reversed(list(zip(path, path[1:]))):
                if costs[source.symbol] <= costs[target.symbol]:
                    costs[source.symbol] = costs[target.symbol] + 1

    return costs
1✔
4110

4111

4112
def all_paths(
    graph,
    from_node: gg.NonterminalNode,
    to_node: gg.NonterminalNode,
    cycles_allowed: int = 2,
) -> List[List[gg.NonterminalNode]]:
    """
    Computes paths between two nodes by a breadth-first search.

    Note: The search does not really allow up to `cycles_allowed` cycles per path
    (which was the original intention of the parameter), since then we would have
    to check per path; yet, the number of paths would explode then and the
    current implementation provides reasonably good results. Instead, there is
    one counter per node that is shared by all paths: A node is no longer
    appended to any path once it has been appended more than `cycles_allowed + 1`
    times, i.e., each node is appended at most `cycles_allowed + 2` times in
    total.

    :param graph: The grammar graph in which to search.
    :param from_node: The node at which all paths start.
    :param to_node: The node at which all paths end.
    :param cycles_allowed: Controls how often a node may be appended (see above).
    :return: The found paths, with choice nodes removed.
    """
    # Function-scope import: A deque pops from the left in O(1), while
    # `list.pop(0)` has to shift all remaining elements.
    from collections import deque

    result: List[List[gg.NonterminalNode]] = []
    visited: Dict[gg.NonterminalNode, int] = {n: 0 for n in graph.all_nodes}

    queue = deque([[from_node]])
    while queue:
        p = queue.popleft()
        if p[-1] == to_node:
            result.append(p)
            continue

        for child in p[-1].children:
            if (
                not isinstance(child, gg.NonterminalNode)
                or visited[child] > cycles_allowed + 1
            ):
                continue

            visited[child] += 1
            queue.append(p + [child])

    return [[n for n in p if not isinstance(n, gg.ChoiceNode)] for p in result]
1✔
4143

4144

4145
def implies(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Checks whether `f1` implies `f2` by trying to solve `f1 & -f2` with an
    `ISLaSolver` with activated unsat support.

    :param f1: The premise.
    :param f2: The conclusion.
    :param grammar: The grammar passed to the solver.
    :param timeout_seconds: The timeout passed to the solver.
    :return: `True` if `solve` raises a `StopIteration`, `False` if it returns
        a solution.
    """
    counterexample_solver = ISLaSolver(
        grammar, f1 & -f2, activate_unsat_support=True, timeout_seconds=timeout_seconds
    )

    solve_attempt = safe(counterexample_solver.solve, exceptions=(StopIteration,))()

    # A found solution becomes `False`; a failed attempt becomes `True`.
    verdict = solve_attempt.map(lambda _: False).lash(lambda _: Success(True))
    return verdict.unwrap()
4157

4158

4159
def equivalent(
    f1: language.Formula, f2: language.Formula, grammar: Grammar, timeout_seconds=5
) -> bool:
    """
    Checks whether `f1` and `f2` are equivalent by trying to solve the negation
    of `f1 & f2 | -f1 & -f2` with an `ISLaSolver` with activated unsat support.

    :param f1: The first formula.
    :param f2: The second formula.
    :param grammar: The grammar passed to the solver.
    :param timeout_seconds: The timeout passed to the solver.
    :return: `False` if `solve` returns a solution. If the attempt fails, the
        result is `True` exactly if the failure is a `StopIteration`.
    """
    difference_solver = ISLaSolver(
        grammar,
        -(f1 & f2 | -f1 & -f2),
        activate_unsat_support=True,
        timeout_seconds=timeout_seconds,
    )

    solve_attempt = safe(difference_solver.solve)()

    verdict = solve_attempt.map(lambda _: False).lash(
        lambda e: Success(isinstance(e, StopIteration))
    )
    return verdict.unwrap()
4174

4175

4176
def generate_abstracted_trees(
    inp: DerivationTree, participating_paths: Set[Path]
) -> List[DerivationTree]:
    """
    Computes trees that are "abstracted," i.e., pruned, at prefixes of the paths
    specified in `participating_paths`. A subtree is pruned by replacing it with
    a tree that only consists of a node with the value of the subtree root.

    :param inp: The unabstracted input.
    :param participating_paths: The paths to abstract.
    :return: A list of the abstracted trees (one per structural hash), sorted by
        descending `len`.
    """
    # For each participating path: All its non-empty prefixes, longest first
    # (or just the empty path if the path itself is empty).
    prefix_chains: Set[ImmutableList[Path]] = set()
    for path in participating_paths:
        if path:
            chain = tuple(
                [tuple(path[:length]) for length in range(len(path), 0, -1)]
            )
        else:
            chain = ((),)

        prefix_chains.add(chain)

    # Choose one prefix per participating path, and consider all non-empty
    # subsets of each such choice.
    abstraction_candidates: Set[ImmutableList[Path]] = set()
    for k in range(1, len(participating_paths) + 1):
        for chosen_prefixes in itertools.product(*prefix_chains):
            for combination in itertools.combinations(chosen_prefixes, k):
                abstraction_candidates.add(tuple(eliminate_suffixes(combination)))

    trees_by_hash: Dict[int, DerivationTree] = {}
    for paths_to_abstract in abstraction_candidates:
        replacements = {}
        for path_to_abstract in paths_to_abstract:
            subtree = inp.get_subtree(path_to_abstract)
            replacements[subtree] = DerivationTree(subtree.value)

        abstracted_tree = inp.substitute(replacements)
        trees_by_hash[abstracted_tree.structural_hash()] = abstracted_tree

    return sorted(trees_by_hash.values(), key=len, reverse=True)
1✔
4217

4218

4219
class EvaluatePredicateFormulasTransformer(NoopFormulaTransformer):
    """
    A formula transformer that replaces structural predicate formulas by the
    result of evaluating them on a given input, and SMT formulas by "true" where
    their instantiation evaluates accordingly.
    """

    def __init__(self, inp: DerivationTree):
        """
        :param inp: The input tree on which predicate formulas are evaluated.
        """
        super().__init__()
        self.inp = inp

    def transform_predicate_formula(
        self, sub_formula: language.StructuralPredicateFormula
    ) -> language.Formula:
        """Replaces the predicate formula by its truth value for `self.inp`."""
        if sub_formula.evaluate(self.inp):
            return sc.true()

        return sc.false()

    def transform_conjunctive_formula(
        self, sub_formula: language.ConjunctiveFormula
    ) -> language.Formula:
        """Re-assembles the conjunction from its arguments using `__and__`."""
        return functools.reduce(language.Formula.__and__, sub_formula.args)

    def transform_disjunctive_formula(
        self, sub_formula: language.DisjunctiveFormula
    ) -> language.Formula:
        """Re-assembles the disjunction from its arguments using `__or__`."""
        return functools.reduce(language.Formula.__or__, sub_formula.args)

    def transform_smt_formula(
        self, sub_formula: language.SMTFormula
    ) -> language.Formula:
        """
        Instantiates the SMT formula with its own substitutions and replaces it
        by "true" if the instance evaluates to true (or to false, if we are in a
        negation scope). Otherwise, the formula is kept for later analysis.
        """
        # Work on a copy, since automatic substitution and evaluation are
        # switched on for the instantiation.
        instance = copy.deepcopy(sub_formula)
        set_smt_auto_subst(instance, True)
        set_smt_auto_eval(instance, True)
        instance = instance.substitute_expressions(
            sub_formula.substitutions, force=True
        )

        assert instance in {sc.true(), sc.false()}

        evaluates_to_true = instance == sc.true()
        if evaluates_to_true ^ self.in_negation_scope:
            return sc.true()

        return sub_formula
4260

4261

4262
def create_fixed_length_tree(
    start: DerivationTree | str,
    canonical_grammar: CanonicalGrammar,
    target_length: int,
) -> Optional[DerivationTree]:
    """
    Searches depth-first for a closed derivation tree of a given length.

    Along with each partial tree, the search tracks a length estimate: the
    summed lengths of the terminal symbols added so far, plus one for each open
    leaf whose nonterminal is not nullable. Partial trees whose estimate exceeds
    `target_length` are discarded. For each open leaf, all expansions containing
    nonterminals are tried, but only one randomly chosen purely terminal
    expansion.

    :param start: The start tree, or a symbol from which a start tree is created.
    :param canonical_grammar: The grammar in canonical form.
    :param target_length: The length that the resulting tree must have.
    :return: The first found tree without open leaves whose tracked length
        equals `target_length`, or `None` if the search space is exhausted.
    """
    nullable = compute_nullable_nonterminals(canonical_grammar)
    start = DerivationTree(start) if isinstance(start, str) else start

    # Entries: (partial tree, length estimate, open leaves with their paths)
    stack: List[
        Tuple[DerivationTree, int, ImmutableList[Tuple[Path, DerivationTree]]]
    ] = [
        (start, int(start.value not in nullable), (((), start),)),
    ]

    def num_nonterminals(expansion) -> int:
        return len([elem for elem in expansion if is_nonterminal(elem)])

    while stack:
        tree, curr_len, open_leaves = stack.pop()

        if not open_leaves:
            if curr_len == target_length:
                return tree

            continue

        if curr_len > target_length:
            continue

        idx: int
        path: Path
        leaf: DerivationTree
        for idx, (path, leaf) in reversed(list(enumerate(open_leaves))):
            terminal_expansions, expansions = get_expansions(
                leaf.value, canonical_grammar
            )

            # Only choose one random terminal expansion; keep all nonterminal expansions
            if terminal_expansions:
                expansions.append(random.choice(terminal_expansions))

            expansions = sorted(expansions, key=num_nonterminals)

            # What the open leaf itself contributed to the length estimate.
            leaf_contribution = int(leaf.value not in nullable)

            for expansion in reversed(expansions):
                new_children = tuple(
                    [
                        DerivationTree(elem, None if is_nonterminal(elem) else ())
                        for elem in expansion
                    ]
                )

                expanded_tree = tree.replace_path(
                    path, DerivationTree(leaf.value, new_children)
                )

                # Terminal children count with their length, open children with
                # 1 unless they are nullable.
                children_contribution = sum(
                    [
                        len(child.value)
                        if child.children == ()
                        else int(child.value not in nullable)
                        for child in new_children
                    ]
                )

                new_open_leaves = tuple(
                    [
                        (path + (child_idx,), new_child)
                        for child_idx, new_child in enumerate(new_children)
                        if is_nonterminal(new_child.value)
                    ]
                )

                stack.append(
                    (
                        expanded_tree,
                        curr_len + children_contribution - leaf_contribution,
                        open_leaves[:idx] + new_open_leaves + open_leaves[idx + 1 :],
                    )
                )

    return None
1✔
4348

4349

4350
def subtree_solutions(
    solution: Dict[language.Constant | DerivationTree, DerivationTree]
) -> Dict[language.Variable | DerivationTree, DerivationTree]:
    """
    Extends a solution by entries for the subtrees of its tree-valued keys.

    Entries with a variable as key are copied. For an entry with a tree as key,
    each subtree of the key (including the key itself) is mapped to the subtree
    of the substitution at the same path, provided that the path is valid in the
    substitution and the subtree was not already mapped by an earlier entry.

    :param solution: The solution to extend.
    :return: The extended solution.
    """
    extended_solution: Dict[language.Variable | DerivationTree, DerivationTree] = {}

    for orig, subst in solution.items():
        if isinstance(orig, language.Variable):
            extended_solution[orig] = subst
            continue

        assert isinstance(
            orig, DerivationTree
        ), f"Expected a DerivationTree, given: {type(orig).__name__}"

        # Note: It can happen that a path in the original tree is not valid in the
        #       substitution, e.g., if we happen to replace a larger with a smaller
        #       tree.
        # The candidates are collected before any of them is added to the result.
        candidates = []
        for path, subtree in orig.paths():
            if subtree in extended_solution:
                continue

            if not subst.is_valid_path(path):
                continue

            candidates.append((path, subtree))

        for path, subtree in candidates:
            extended_solution[subtree] = subst.get_subtree(path)

    return extended_solution
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc