• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In
Build has been canceled!

FEniCS / ffcx / 11642291501

02 Nov 2024 11:16AM UTC coverage: 81.168% (+0.5%) from 80.657%
11642291501

push

github

web-flow
Upload to coveralls and docs from CI job running against python 3.12 (#726)

3474 of 4280 relevant lines covered (81.17%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.37
/ffcx/codegeneration/optimizer.py
1
"""Optimizer."""
2

3
from collections import defaultdict
1✔
4
from typing import Union
1✔
5

6
import ffcx.codegeneration.lnodes as L
1✔
7
from ffcx.ir.representationutils import QuadratureRule
1✔
8

9

10
def optimize(code: list[L.LNode], quadrature_rule: QuadratureRule) -> list[L.LNode]:
    """Run the section-level optimization passes over generated code.

    First, the "Coefficient" sections and then the "Jacobian" sections are
    merged via ``fuse_sections``. Afterwards, every top-level ``L.Section`` is
    passed through ``fuse_loops`` if it carries ``L.Annotation.fuse``, and then
    through ``licm`` if it carries ``L.Annotation.licm``.

    Args:
        code: List of LNodes to optimize.
        quadrature_rule: Quadrature rule forwarded unchanged to ``licm``.

    Returns:
        The list of LNodes with optimized sections substituted in place.
    """
    # Merge sections sharing a name, in this fixed order.
    for section_name in ("Coefficient", "Jacobian"):
        code = fuse_sections(code, section_name)

    for position, node in enumerate(code):
        if not isinstance(node, L.Section):
            continue
        # Each pass is driven by the annotations on the section it receives.
        if L.Annotation.fuse in node.annotations:
            node = fuse_loops(node)
        if L.Annotation.licm in node.annotations:
            node = licm(node, quadrature_rule)
        code[position] = node

    return code
32

33

34
def fuse_sections(code: list[L.LNode], name: str) -> list[L.LNode]:
    """Fuse all top-level sections with the same name into a single section.

    The declarations and statements of every top-level ``L.Section`` named
    ``name`` are concatenated in their original order. Inputs and outputs are
    concatenated and de-duplicated, keeping the first occurrence of each. The
    fused section replaces the first matching section, and the remaining
    matching sections are removed. Non-matching nodes keep their relative order.

    Annotations are taken from the *last* matching section; they are not
    checked for agreement between the fused sections.

    Args:
        code: List of LNodes to fuse.
        name: Common name used by the sections that should be fused.

    Returns:
        A new list of LNodes (the input list is not modified).
    """
    statements: list[L.LNode] = []
    indices: list[int] = []
    inputs: list[L.Symbol] = []
    outputs: list[L.Symbol] = []
    declarations: list[L.Declaration] = []
    annotations: list[L.Annotation] = []

    for i, section in enumerate(code):
        if isinstance(section, L.Section) and section.name == name:
            declarations.extend(section.declarations)
            statements.extend(section.statements)
            indices.append(i)
            inputs.extend(section.input)
            outputs.extend(section.output)
            annotations = section.annotations

    if not indices:
        return code.copy()

    # De-duplicate while preserving first-occurrence order. ``list(set(...))``
    # would order symbols by hash, which can vary between interpreter runs
    # (str hashes are randomized) and make generated code non-reproducible.
    inputs = list(dict.fromkeys(inputs))
    outputs = list(dict.fromkeys(outputs))

    fused = L.Section(name, statements, declarations, inputs, outputs, annotations)

    # Replace the first matching section with the fused one and drop the rest.
    result = code.copy()
    result[indices[0]] = fused
    dropped = set(indices[1:])
    return [node for i, node in enumerate(result) if i not in dropped]
76

77

78
def fuse_loops(code: L.Section) -> L.Section:
    """Fuse top-level loops that iterate over the same range.

    Top-level ``L.ForRange`` statements are grouped by their
    ``(index, begin, end)`` triple. The bodies in each group are collected in
    order of appearance into a single ``L.ForRange``. All non-loop statements
    are kept in their original order and placed *before* the fused loops.

    Args:
        code: Section whose top-level loops should be fused.

    Returns:
        A new section with the same name, declarations, inputs, outputs and
        annotations as ``code``, but with fused loop statements.
    """
    loops: dict[tuple, list] = defaultdict(list)
    output_code: list[L.LNode] = []
    for statement in code.statements:
        if isinstance(statement, L.ForRange):
            key = (statement.index, statement.begin, statement.end)
            loops[key].append(statement.body)
        else:
            output_code.append(statement)

    for (index, begin, end), bodies in loops.items():
        output_code.append(L.ForRange(index, begin, end, bodies))

    # Carry the annotations over (as fuse_sections does). Without them,
    # later passes keyed on annotations (e.g. licm in ``optimize``) would
    # silently be skipped for a loop-fused section.
    return L.Section(
        code.name, output_code, code.declarations, code.input, code.output, code.annotations
    )
100

101

102
def get_statements(statement: Union[L.Statement, L.StatementList]) -> list[L.LNode]:
    """Unwrap a statement or statement list into the nodes it wraps.

    Args:
        statement: A single ``L.Statement`` or an ``L.StatementList``.

    Returns:
        The ``.expr`` of the statement as a one-element list, or the ``.expr``
        of each statement in the list, in order.
    """
    if not isinstance(statement, L.StatementList):
        return [statement.expr]
    return [inner.expr for inner in statement.statements]
115

116

117
def _index_depends_on(expr: L.LNode, index: L.Symbol) -> bool:
    """Return True if the index expression ``expr`` references ``index``.

    Recurses through ``Sum``/``Product`` arguments and nested ``ArrayAccess``
    indices. Any other node (e.g. literals) is treated as not referencing
    ``index``.
    """
    if isinstance(expr, L.Symbol):
        return expr == index
    if isinstance(expr, (L.Sum, L.Product)):
        return any(_index_depends_on(arg, index) for arg in expr.args)
    if isinstance(expr, L.ArrayAccess):
        return any(_index_depends_on(i, index) for i in expr.indices)
    return False


def check_dependency(statement: L.Statement, index: L.Symbol) -> bool:
    """Check if an expression depends on a given loop index.

    Args:
        statement: Expression to check (an ``ArrayAccess``, ``Symbol``,
            ``LiteralFloat`` or ``LiteralInt``).
        index: Loop index to look for.

    Returns:
        True if ``statement`` depends on ``index``, False otherwise.

    Raises:
        NotImplementedError: If ``statement`` is of an unsupported type.
    """
    if isinstance(statement, L.ArrayAccess):
        # Look through nested index expressions such as A[2*i + j] or A[map[i]].
        return any(_index_depends_on(i, index) for i in statement.indices)
    if isinstance(statement, L.Symbol):
        # A bare symbol depends on the index exactly when it *is* the index;
        # treating it as independent would let licm hoist the index itself.
        return statement == index
    if isinstance(statement, (L.LiteralFloat, L.LiteralInt)):
        return False
    raise NotImplementedError(f"Statement {statement} not supported.")
143

144

145
def licm(section: L.Section, quadrature_rule: QuadratureRule) -> L.Section:
    """Perform loop invariant code motion on a two-level loop nest.

    Only the first statement of ``section`` is considered. If it is a loop
    nest of depth 2, each statement in the inner loop body must be an
    ``AssignAdd`` whose left-hand side is an ``ArrayAccess`` and whose
    right-hand side is a ``Product``. For each such product, if more than one
    factor does not depend on the inner loop index, those factors are removed
    from the product. They are replaced by a single access ``temp_k[outer]``
    into a new temporary array, which is filled by a loop over the outer range
    inserted before the original statements.

    Args:
        section: Section to optimize; must carry ``L.Annotation.licm``. It is
            modified in place: the product arguments inside the inner loop are
            rewritten and ``section.statements`` is reassigned.
        quadrature_rule: Currently unused by this function.

    Returns:
        The same ``section`` object. It is returned unchanged if it has no
        statements or its first statement is not a depth-2 nest.
    """
    assert L.Annotation.licm in section.annotations

    # Nothing to optimize (and section.statements[0] below would fail).
    if not section.statements:
        return section

    # NOTE(review): temp names restart at temp_0 on every call; if several
    # licm-optimized sections end up in the same generated scope their
    # declarations could clash — confirm against the code generator.
    counter = 0

    # Check depth of loops
    depth = L.depth(section.statements[0])
    if depth != 2:
        return section

    # Get statements in the inner loop
    outer_loop = section.statements[0]
    inner_loop = outer_loop.body.statements[0]

    # Collect all right-hand sides in the inner loop, grouped by their LHS
    expressions = defaultdict(list)
    for body in inner_loop.body.statements:
        statements = get_statements(body)
        assert isinstance(statements, list)
        for statement in statements:
            assert isinstance(statement, L.AssignAdd)  # Expecting AssignAdd
            rhs = statement.rhs
            assert isinstance(rhs, L.Product)  # Expecting Product
            lhs = statement.lhs
            assert isinstance(lhs, L.ArrayAccess)  # Expecting ArrayAccess
            expressions[lhs].append(rhs)

    pre_loop: list[L.LNode] = []
    for lhs, rhs in expressions.items():
        for r in rhs:
            # Factors of the product that do not depend on the inner index.
            hoist_candidates = []
            for arg in r.args:
                dependency = check_dependency(arg, inner_loop.index)
                if not dependency:
                    hoist_candidates.append(arg)
            # Hoisting a single factor would only add a temporary, so require
            # at least two invariant factors.
            if len(hoist_candidates) > 1:
                # create new temp
                name = f"temp_{counter}"
                counter += 1
                temp = L.Symbol(name, L.DataType.SCALAR)
                # The product's argument list is mutated in place.
                for h in hoist_candidates:
                    r.args.remove(h)
                # update expression with new temp
                r.args.append(L.ArrayAccess(temp, [outer_loop.index]))
                # create code for hoisted term
                # NOTE(review): assumes literal loop bounds (uses ``.value``).
                size = outer_loop.end.value - outer_loop.begin.value
                pre_loop.append(L.ArrayDecl(temp, size, [0]))
                body = L.Assign(
                    L.ArrayAccess(temp, [outer_loop.index]), L.Product(hoist_candidates)
                )
                pre_loop.append(
                    L.ForRange(outer_loop.index, outer_loop.begin, outer_loop.end, [body])
                )

    section.statements = pre_loop + section.statements

    return section
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc