• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

CityOfZion / neo3-boa / 5c8eab45-f99f-4697-a4ff-d2c04d4d17ac

01 Feb 2024 09:59PM UTC coverage: 92.107% (+0.5%) from 91.625%
5c8eab45-f99f-4697-a4ff-d2c04d4d17ac

push

circleci

Mirella de Medeiros
Bump version: 1.1.0 → 1.1.1

1 of 1 new or added line in 1 file covered. (100.0%)

302 existing lines in 22 files now uncovered.

20784 of 22565 relevant lines covered (92.11%)

2.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.79
/boa3/internal/analyser/astoptimizer.py
1
import ast
3✔
2
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3✔
3

4
from boa3.internal import constants
3✔
5
from boa3.internal.analyser.astanalyser import IAstAnalyser
3✔
6
from boa3.internal.analyser.model.optimizer import ScopeValue, Undefined
3✔
7
from boa3.internal.analyser.model.optimizer.Operation import Operation
3✔
8
from boa3.internal.exception import CompilerWarning
3✔
9
from boa3.internal.model.builtin.method.builtinmethod import IBuiltinMethod
3✔
10
from boa3.internal.model.method import Method
3✔
11
from boa3.internal.model.module import Module
3✔
12
from boa3.internal.model.operation.binary.binaryoperation import BinaryOperation
3✔
13
from boa3.internal.model.operation.operator import Operator
3✔
14
from boa3.internal.model.operation.unary.unaryoperation import UnaryOperation
3✔
15
from boa3.internal.model.property import Property
3✔
16
from boa3.internal.model.symbol import ISymbol
3✔
17
from boa3.internal.model.type.classes.userclass import UserClass
3✔
18
from boa3.internal.model.type.primitive.primitivetype import PrimitiveType
3✔
19

20

21
class AstOptimizer(IAstAnalyser, ast.NodeTransformer):
    """
    This class is responsible for reducing the generated ast.

    The methods with the name starting with 'visit_' are implementations of methods from the :class:`NodeVisitor` class.
    These methods are used to walk through the Python abstract syntax tree.

    Besides folding constant expressions, it keeps the literal values assigned to variables in `current_scope`, so
    later reads of those variables can be replaced by the value itself.

    :ivar modules: a dictionary with the analysed modules. Empty by default.
    :ivar symbols: a dictionary that maps the global symbols.
    """

    def __init__(self, analyser, log: bool = False, fail_fast: bool = True):
        super().__init__(analyser.ast_tree, filename=analyser.filename, root_folder=analyser.root,
                         log=log, fail_fast=fail_fast)
        self.modules: Dict[str, Module] = {}
        self.symbols: Dict[str, ISymbol] = analyser.symbol_table

        self._is_optimizing: bool = False
        # set by the visitors whenever the tree is rewritten; visit_FunctionDef repeats its pass while it is True
        self.has_changes: bool = False
        # literal values known for the variables at the point currently being visited
        self.current_scope: ScopeValue = ScopeValue()

        # class whose body is being visited; its symbol table is used to find the methods' symbols
        self._current_class: Optional[UserClass] = None

        self.analyse_visit(self._tree)

    @property
    def tree(self) -> ast.AST:
        """
        Gets the analysed abstract syntax tree

        :return: the analysed ast
        """
        return self._tree

    def literal_eval(self, node: ast.AST) -> Any:
        """
        Evaluates an expression node containing a Python literal.

        :param node: the node that will be evaluated
        :return: the evaluated expression if the node is a valid literal. Otherwise, returns Undefined.
        """
        try:
            return ast.literal_eval(node)
        except Exception:
            # any evaluation error means "not a literal"; interrupts like KeyboardInterrupt must still propagate
            return Undefined

    def parse_to_node(self, expression: str, origin: ast.AST = None,
                      is_origin_str: bool = False) -> Union[ast.AST, Sequence[ast.AST]]:
        """
        Parses an expression to an ast and optimizes the parsed node.

        :param expression: string expression to be parsed
        :param origin: an existing ast. If not None, the parsed node will have the same location of origin.
        :param is_origin_str: if True, `expression` is the value of a string, and it is turned into a string
                              literal before being parsed instead of being parsed as source code
        :return: the parsed node
        :rtype: ast.AST or Sequence[ast.AST]
        """
        if is_origin_str:
            # repr escapes quotes, backslashes and control characters, so the parsed literal has exactly the same
            # value; wrapping the raw text in quotes fails for "it's" and changes values containing backslashes
            expression = repr(expression)

        new_node = self.visit(super().parse_to_node(expression, origin))
        if hasattr(new_node, 'op'):
            new_node.op = Operator.get_operation(new_node.op)

        return new_node

    def reset_state(self):
        """
        Resets the values known in the current scope.
        """
        self.current_scope.reset()

    def get_symbol_id(self, node: ast.AST) -> Optional[str]:
        """
        Gets the identifier of a name or attribute chain, with the parts joined by
        `constants.ATTRIBUTE_NAME_SEPARATOR`.

        :param node: the node of the referenced symbol
        :return: the joined identifier. It's an empty string if node is neither a name nor an attribute
        """
        parts = []
        cur_node = node
        while isinstance(cur_node, ast.Attribute):
            parts.insert(0, cur_node.attr)
            cur_node = cur_node.value

        if isinstance(cur_node, ast.Name):
            parts.insert(0, cur_node.id)

        return constants.ATTRIBUTE_NAME_SEPARATOR.join(parts)

    def visit_ClassDef(self, node: ast.ClassDef) -> Any:
        """
        Visitor of a class definition node

        While the class body is visited, its methods are looked up in the class symbols.

        :param node: the python ast class definition node
        """
        previous_class = self._current_class
        if node.name in self.symbols:
            class_symbol = self.symbols[node.name]
            if isinstance(class_symbol, UserClass):
                self._current_class = class_symbol

        self.generic_visit(node)
        # restore instead of clearing, so the enclosing class (if any) is still current after this one
        self._current_class = previous_class
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """
        Visitor of a function definition node

        The function body is visited repeatedly until a full pass doesn't change the tree.

        :param node: the python ast function definition node
        """
        symbols = self.symbols if self._current_class is None else self._current_class.symbols
        method = symbols[node.name]

        if isinstance(method, Property):
            method = method.getter

        if isinstance(method, Method):
            self._is_optimizing = True
            self.has_changes = True

            while self.has_changes:
                self.reset_state()
                self.has_changes = False

                super().generic_visit(node)

        self.end_function_optimization()
        return node

    def end_function_optimization(self):
        """
        Clears the optimization state after finishing a function.
        """
        self.reset_state()
        self.has_changes = False
        self._is_optimizing = False

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        """
        Visitor of an assignment node

        :param node: the python ast assignment node
        """
        super().generic_visit(node)
        self.set_variables_value(node.targets, node.value)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        """
        Visitor of an annotated assignment node

        :param node: the python ast annotated assignment node
        """
        super().generic_visit(node)
        self.set_variables_value([node.target], node.value)
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST:
        """
        Visitor of an augmented assignment node (e.g. `x += y`)

        :param node: the python ast augmented assignment node
        """
        super().generic_visit(node)

        # build `target OP value` directly: parsing and visiting a template like "x+y" would substitute and fold
        # user variables that happen to be named x or y, and flag changes on every pass of the function
        value = ast.BinOp(left=node.target, op=node.op, right=node.value)
        ast.copy_location(value, node)

        self.set_variables_value([node.target], value)
        return node

    def set_variables_value(self, targets: List[ast.AST], value: ast.AST):
        """
        Updates the current scope after `value` is assigned to `targets`.

        A name target keeps the literal value, or loses its known value if `value` isn't a literal. Names bound by
        unpacking targets (tuple, list, starred) always lose their known value, because the value of each element
        is not tracked. Attribute and subscript targets don't change the scope.

        :param targets: the assignment targets
        :param value: the assigned expression
        """
        new_value = self.literal_eval(value)
        for target in targets:
            target_value = new_value if isinstance(target, ast.Name) else Undefined
            for name in self._get_assigned_names(target):
                self._set_scope_value(name, target_value)

    def _get_assigned_names(self, target: ast.AST) -> List[str]:
        """
        Gets the names of the variables bound by an assignment target.

        :param target: a name, tuple, list or starred assignment target
        :return: the identifiers bound in store context, in source order
        """
        if isinstance(target, ast.Name):
            return [target.id] if isinstance(target.ctx, ast.Store) else []
        if isinstance(target, ast.Starred):
            return self._get_assigned_names(target.value)
        if isinstance(target, (ast.Tuple, ast.List)):
            names = []
            for element in target.elts:
                names.extend(self._get_assigned_names(element))
            return names
        # attributes and subscripts don't bind variables
        return []

    def _set_scope_value(self, name: str, value: Any):
        """
        Marks a variable as assigned in the current scope and stores or forgets its literal value.

        :param name: the variable identifier
        :param value: the literal value, or Undefined if it is not known
        """
        self.current_scope.assign(name)

        if value is not Undefined:
            self.current_scope[name] = value
        elif name in self.current_scope:
            self.current_scope.remove(name)

    def visit_BinOp(self, bin_op: ast.BinOp) -> ast.AST:
        """
        Visitor of a binary operation node

        :param bin_op: the python ast binary operation node
        :return: a literal node with the result if the operation could be evaluated, otherwise bin_op
        """
        try:
            super().generic_visit(bin_op)

            left_value = self.literal_eval(bin_op.left)
            right_value = self.literal_eval(bin_op.right)

            if (left_value is Undefined and isinstance(bin_op.left, ast.BinOp)
                    and self.is_symmetric_operation(bin_op.op, bin_op.left.op)):
                left_value, right_value = self.reorder_operations(bin_op, bin_op.left)
            elif (right_value is Undefined and isinstance(bin_op.right, ast.BinOp)
                  and self.is_symmetric_operation(bin_op.op, bin_op.right.op)):
                left_value, right_value = self.reorder_operations(bin_op, bin_op.right)

            value = self._evaluate_binary_operation(left_value, right_value, bin_op.op)
            if value is not None:
                self.has_changes = True
                return self.parse_to_node(str(value), bin_op, isinstance(value, str))
            return bin_op
        except ValueError:
            return bin_op

    def is_symmetric_operation(self, first_op: BinaryOperation, second_op: BinaryOperation) -> bool:
        """
        Checks if two operators are the same symmetric binary operation.

        :return: True if both are BinaryOperation instances of the same type and that operation is symmetric
        """
        if not isinstance(first_op, BinaryOperation) or not isinstance(second_op, BinaryOperation):
            return False

        operation = type(first_op)
        second_operation = type(second_op)

        if operation != second_operation:
            return False

        return first_op.is_symmetric

    def reorder_operations(self, outer_bin_op: ast.BinOp, inner_bin_op: ast.BinOp) -> Tuple[Any, Any]:
        """
        Swaps operands between two nested symmetric operations, so the literal operands end up in the same
        operation and can be folded.

        :param outer_bin_op: the enclosing binary operation
        :param inner_bin_op: the binary operation that is one of the outer operands
        :return: the literal values of the outer operation's left and right operands after the reordering.
                 Each one is Undefined if that operand is not a literal
        """
        def _is_symmetric(op: Any) -> bool:
            return isinstance(op, BinaryOperation) and op.is_symmetric

        if not _is_symmetric(outer_bin_op.op) or not _is_symmetric(inner_bin_op.op):
            # nothing is reordered: report the outer operands as they are
            return self.literal_eval(outer_bin_op.left), self.literal_eval(outer_bin_op.right)

        inner_first_value = self.literal_eval(inner_bin_op.left)
        inner_second_value = self.literal_eval(inner_bin_op.right)

        is_left_operand: bool = inner_bin_op is outer_bin_op.left
        other_value = self.literal_eval(outer_bin_op.right if is_left_operand else outer_bin_op.left)

        if inner_first_value is not Undefined or inner_second_value is not Undefined:
            if inner_first_value is Undefined:
                if other_value is Undefined:
                    if is_left_operand:
                        # (x + 1) + y -> (x + y) + 1
                        inner_bin_op.right, outer_bin_op.right = outer_bin_op.right, inner_bin_op.right
                    else:
                        # y + (x + 1) -> 1 + (x + y)
                        inner_bin_op.right, outer_bin_op.left = outer_bin_op.left, inner_bin_op.right
                else:
                    if is_left_operand:
                        # (x + 2) + 1 -> (1 + 2) + x
                        inner_bin_op.left, outer_bin_op.right = outer_bin_op.right, inner_bin_op.left
                    else:
                        # 1 + (x + 2) -> x + (1 + 2)
                        inner_bin_op.left, outer_bin_op.left = outer_bin_op.left, inner_bin_op.left
            else:
                if other_value is Undefined:
                    if is_left_operand:
                        # (1 + x) + y -> (y + x) + 1
                        inner_bin_op.left, outer_bin_op.right = outer_bin_op.right, inner_bin_op.left
                    else:
                        # y + (1 + x) ->  1 + (y + x)
                        inner_bin_op.left, outer_bin_op.left = outer_bin_op.left, inner_bin_op.left
                else:
                    if is_left_operand:
                        # (2 + x) + 1 -> (2 + 1) + x
                        inner_bin_op.right, outer_bin_op.right = outer_bin_op.right, inner_bin_op.right
                    else:
                        # 1 + (2 + x) -> x + (2 + 1)
                        inner_bin_op.right, outer_bin_op.left = outer_bin_op.left, inner_bin_op.right

        # revisiting folds the inner operation if it now holds only literals
        super().generic_visit(outer_bin_op)

        return self.literal_eval(outer_bin_op.left), self.literal_eval(outer_bin_op.right)

    def _evaluate_binary_operation(self, left: Any, right: Any,
                                   op: Union[ast.operator, BinaryOperation]) -> Optional[Any]:
        """
        Evaluates a binary operation between two literal values.

        :return: the result, or None if the operation is not supported or the evaluation fails
        """
        operator = Operation.get_operation(op)
        try:
            if operator is Operation.Add:
                return left + right
            if operator is Operation.Sub:
                return left - right
            if operator is Operation.Mult:
                return left * right
            if operator is Operation.FloorDiv:
                return left // right
            if operator is Operation.Mod:
                return left % right
            return None
        except Exception:
            # e.g. Undefined operands, unsupported types or division by zero: keep the operation as it is
            return None

    def visit_UnaryOp(self, un_op: ast.UnaryOp) -> ast.AST:
        """
        Visitor of a unary operation node

        :param un_op: the python ast unary operation node
        :return: a literal node with the result if the operation could be evaluated, otherwise un_op
        """
        try:
            # keep the optimized operand: if it were discarded, a folded sub-expression would stay in the tree and
            # keep flagging changes on every pass of visit_FunctionDef
            un_op.operand = self.visit(un_op.operand)

            operand_value = self.literal_eval(un_op.operand)
            if operand_value is Undefined:
                return un_op

            value = self._evaluate_unary_operation(operand_value, un_op.op)
            if value is None:
                return un_op

            self.has_changes = True
            if isinstance(un_op.operand, ast.Constant):
                un_op.operand.value = value
                self.update_line_and_col(un_op.operand, un_op)
                return un_op.operand
            return self.parse_to_node(str(value), un_op, isinstance(value, str))
        except ValueError:
            return un_op

    def _evaluate_unary_operation(self, operand: Any, op: Union[ast.operator, UnaryOperation]) -> Optional[Any]:
        """
        Evaluates a unary operation on a literal value.

        :return: the result, or None if the operation is not supported or the evaluation fails
        """
        operator = Operation.get_operation(op)
        try:
            if operator is Operation.Add:
                return +operand
            elif operator is Operation.Sub:
                return -operand
            return None
        except Exception:
            return None

    def visit_Match(self, match_node: ast.Match) -> ast.AST:
        """
        Visitor of a match statement node

        Each case body is visited in its own scope, and the case scopes are merged back afterwards.

        :param match_node: the python ast match node
        """
        match_node.subject = self.visit(match_node.subject)

        case_scopes = []

        match_scope = self.current_scope

        for case in match_node.cases:
            case_scopes.append(match_scope.new_scope())

            self.current_scope = case_scopes[-1]
            for stmt in case.body:
                self.visit(stmt)

        self.current_scope = match_scope
        self.current_scope.update_values(*case_scopes)

        return match_node

    def visit_If(self, node: ast.If) -> ast.AST:
        """
        Visitor of an if statement node

        The body and the else clause are visited in separate scopes that are merged back afterwards.

        :param node: the python ast if node
        """
        node.test = self.visit(node.test)

        if_scope: ScopeValue = self.current_scope.new_scope()
        else_scope: ScopeValue = self.current_scope.new_scope()

        self.current_scope = if_scope
        for stmt in node.body:
            self.visit(stmt)

        if len(node.orelse) > 0:
            self.current_scope = else_scope
            for stmt in node.orelse:
                self.visit(stmt)

        self.current_scope = self.current_scope.previous_scope()
        self.current_scope.update_values(if_scope, else_scope)

        return node

    def visit_loop_body(self, node: ast.AST):
        """
        Visits the body and the else clause of a loop in a new scope, which is reset before the body.

        For `while` loops the test is visited before the body. For `for` loops the variables bound by the target
        are marked as reassigned.

        :param node: the loop node (ast.For or ast.While)
        :return: the scope used for the loop body. It is left as the current scope
        """
        if not hasattr(node, 'body') or not hasattr(node, 'orelse'):
            # NOTE(review): this doesn't return, so such a node fails below with AttributeError;
            # visit_For and visit_While, the callers in this class, always pass nodes with both fields
            self.generic_visit(node)

        loop_scope: ScopeValue = self.current_scope.new_scope()
        # TODO: substitute the variables only if they're not reassigned inside the loop #2kq0wk3
        loop_scope.reset()

        self.current_scope = loop_scope

        if isinstance(node, ast.While):
            # the test also runs after each iteration, so it must not use values assigned in the body:
            # visit it before the body, right after the loop scope is reset
            node.test = self.visit(node.test)
        elif isinstance(node, ast.For):
            # the target is rebound on every iteration, so the values it had before the loop are no longer valid
            for name in self._get_assigned_names(node.target):
                self._set_scope_value(name, Undefined)

        for stmt in node.body:
            self.visit(stmt)

        if len(node.orelse) > 0:
            else_scope: ScopeValue = self.current_scope.new_scope()
            self.current_scope = else_scope

            for stmt in node.orelse:
                self.visit(stmt)

            self.current_scope = else_scope.previous_scope()
            loop_scope.update_values(else_scope, is_loop_scope=True)

        outer_scope = self.current_scope.previous_scope()
        outer_scope.update_values(loop_scope)

        return loop_scope

    def visit_For(self, node: ast.For) -> ast.AST:
        """
        Visitor of a for loop node

        :param node: the python ast for node
        """
        # the iterable is evaluated once, before the loop starts, so it is optimized with the outer scope
        node.iter = self.visit(node.iter)

        for_scope = self.visit_loop_body(node)
        self.current_scope = for_scope.previous_scope()

        return node

    def visit_While(self, node: ast.While) -> ast.AST:
        """
        Visitor of a while loop node

        :param node: the python ast while node
        """
        # the test is visited by visit_loop_body, before the loop body
        while_scope = self.visit_loop_body(node)

        self.current_scope = while_scope.previous_scope()

        return node

    def visit_Try(self, node: ast.Try) -> ast.AST:
        """
        Visitor of a try statement node

        The try body, each handler and the else clause are visited in separate scopes that are merged back before
        the finally clause is visited.

        :param node: the python ast try node
        """
        outer_scope = self.current_scope
        try_scope: ScopeValue = self.current_scope.new_scope()
        except_scopes: List[ScopeValue] = []

        self.current_scope = try_scope
        for stmt in node.body:
            self.visit(stmt)

        if len(node.handlers) > 0:
            for handler in node.handlers:
                except_scope = outer_scope.new_scope()
                self.current_scope = except_scope

                for stmt in handler.body:
                    self.visit(stmt)

                except_scopes.append(except_scope)

        if len(node.orelse) > 0:
            else_scope = outer_scope.new_scope()
            self.current_scope = else_scope

            for stmt in node.orelse:
                self.visit(stmt)

            except_scopes.append(else_scope)

        self.current_scope = self.current_scope.previous_scope()
        self.current_scope.update_values(try_scope, *except_scopes)

        for stmt in node.finalbody:
            self.visit(stmt)

        return node

    def visit_Name(self, node: ast.Name) -> ast.AST:
        """
        Visitor of a name node

        Reads of variables with a known value of a primitive type are replaced by that value.

        :param node: the python ast name node
        :return: the literal node with the variable value, or node if it can't be replaced
        """
        if (isinstance(node.ctx, ast.Load)
                and node.id in self.current_scope
                and isinstance(self.get_type(self.current_scope[node.id]), PrimitiveType)):
            # only values from int, bool, str and bytes types are going to replace the variable
            # TODO: check if it's worth to replace other types #2kq0zhe
            value = self.current_scope[node.id]
            # strings are quoted by parse_to_node, which escapes their content correctly
            return self.parse_to_node(str(value), node, isinstance(value, str))
        return node

    def visit_Call(self, node: ast.Call) -> ast.AST:
        """
        Visitor of a function call node

        Optimizes the positional arguments. If all of them are literals and the called symbol is a builtin method,
        tries to replace the call by the result of its `evaluate_literal`.

        :param node: the python ast call node
        :return: the literal node with the call result if it could be evaluated, otherwise node
        """
        # TODO: right now only UInt160 and UInt256 constructors are evaluated #2kq12zd
        literal_args = []
        args_are_literal = True

        for index, arg in enumerate(node.args.copy()):
            updated_arg = self.visit(arg)  # first try to optimize the arguments
            if updated_arg != arg:
                node.args[index] = updated_arg

            if args_are_literal:
                value = self.literal_eval(updated_arg)
                if value is Undefined:
                    # don't break if one argument is not literal to make sure that all arguments were checked
                    # if they can be optimized
                    args_are_literal = False

                literal_args.append(value)

        if not args_are_literal:
            return node

        # try to get the result
        try:
            func_id = self.get_symbol_id(node.func)
        except Exception:
            return node
        func = self.get_symbol(func_id)

        if isinstance(func, IBuiltinMethod):
            try:
                result = func.evaluate_literal(*literal_args)
                if result is not Undefined:
                    return self.parse_to_node(str(result), node, is_origin_str=isinstance(result, str))
            except Exception:
                self._log_warning(CompilerWarning.InvalidArgument(
                    node.lineno, node.col_offset
                ))

        return node
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc