• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / numba-dpex / 8368790851

21 Mar 2024 02:26AM UTC coverage: 82.668% (-0.4%) from 83.041%
8368790851

push

github

web-flow
Merge pull request #1398 from IntelPython/feature/specialization_device_func

Feature/specialization device func

1633 of 2249 branches covered (72.61%)

Branch coverage included in aggregate %.

11 of 11 new or added lines in 1 file covered. (100.0%)

29 existing lines in 2 files now uncovered.

6628 of 7744 relevant lines covered (85.59%)

0.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

62.37
/numba_dpex/core/kernel_interface/func.py
1
# SPDX-FileCopyrightText: 2022 - 2024 Intel Corporation
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from numba.core import sigutils, types
1✔
6
from numba.core.typing.templates import AbstractTemplate, ConcreteTemplate
1✔
7

8
from numba_dpex.core import config
1✔
9
from numba_dpex.core.caching import LRUCache, NullCache
1✔
10
from numba_dpex.core.compiler import compile_with_dpex
1✔
11
from numba_dpex.core.descriptor import dpex_kernel_target
1✔
12
from numba_dpex.core.utils import (
1✔
13
    build_key,
14
    create_func_hash,
15
    strip_usm_metadata,
16
)
17

18

19
class DpexFunction(object):
    """Class to materialize a specialized dpex function.

    Helper class to eagerly compile a specialized `numba_dpex.func`
    decorated Python function into an LLVM function with the `spir_func`
    calling convention.

    A specialized `numba_dpex.func` decorated Python function is one
    where the user has specified a signature or a list of signatures
    for the function. Such a function is compiled eagerly, once per
    signature (see `compile_func`), instead of being JIT compiled when
    the function is invoked.
    """

    def __init__(self, pyfunc, debug=False):
        """Constructor for `DpexFunction`.

        Args:
            pyfunc (`function`): A Python function to be compiled.
            debug (`bool`, optional): Debug option for compilation.
                Defaults to `False`.
        """
        self._pyfunc = pyfunc
        self._debug = debug

    def compile(self, arg_types, return_types):
        """Compile the wrapped function for one concrete signature.

        Args:
            arg_types (`tuple`): Numba types of the function arguments.
            return_types: Numba type of the return value, as produced by
                `sigutils.normalize_signature`. Despite the plural name it
                holds a single type. It may be `None`, in which case the
                return type is left for the compiler to infer.

        Returns:
            `numba.core.compiler.CompileResult`: The compiled result. The
            LLVM function it contains has been switched to the
            `spir_func` calling convention.
        """
        cres = compile_with_dpex(
            pyfunc=self._pyfunc,
            pyfunc_name=self._pyfunc.__name__,
            return_type=return_types,
            target_context=dpex_kernel_target.target_context,
            typing_context=dpex_kernel_target.typing_context,
            args=arg_types,
            is_kernel=False,
            debug=self._debug,
        )
        # Look up the generated LLVM function by its mangled name and give
        # it the `spir_func` calling convention this class promises.
        func = cres.library.get_function(cres.fndesc.llvm_func_name)
        cres.target_context.set_spir_func_calling_conv(func)

        return cres
70

71

72
class DpexFunctionTemplate(object):
    """Helper class to compile an unspecialized `numba_dpex.func`.

    A helper class to JIT compile an unspecialized `numba_dpex.func`
    decorated Python function into an LLVM function with the `spir_func`
    calling convention. Compilation happens when argument types become
    known, and compiled results are optionally cached.
    """

    def __init__(self, pyfunc, debug=False, enable_cache=True):
        """Constructor for `DpexFunctionTemplate`.

        Args:
            pyfunc (function): A Python function to be compiled.
            debug (bool, optional): Debug option for compilation.
                Defaults to `False`.
            enable_cache (bool, optional): Flag to turn on/off caching.
                Only effective when `config.ENABLE_CACHE` is also set.
                Defaults to `True`.
        """
        self._pyfunc = pyfunc
        self._debug = debug
        self._enable_cache = enable_cache

        # Hash of the Python function; it is one component of every
        # cache key built in `compile`.
        self._func_hash = create_func_hash(pyfunc)

        # An LRU cache is used only when both the global config switch
        # and this object's flag allow it; otherwise fall back to a
        # NullCache.
        if config.ENABLE_CACHE and self._enable_cache:
            self._cache = LRUCache(
                name="DpexFunctionTemplateCache",
                capacity=config.CACHE_SIZE,
                pyfunc=self._pyfunc,
            )
        else:
            self._cache = NullCache()
        self._cache_hits = 0

    @property
    def cache(self):
        """Cache accessor"""
        return self._cache

    @property
    def cache_hits(self):
        """Number of `compile` calls served from the cache."""
        return self._cache_hits

    def compile(self, args):
        """Compile a `numba_dpex.func` decorated function.

        Compile a `numba_dpex.func` decorated Python function with the
        given argument types. When a cached result exists for the
        computed key, it is reused and counted as a cache hit. Otherwise
        the function is compiled, registered with the target context,
        and stored in the cache. With a `NullCache`, the function is
        presumably recompiled on every call.

        Args:
            args (`tuple`): Function argument types in a tuple.

        Returns:
            `numba.core.typing.templates.Signature`: Signature of the
                compiled result.
        """
        argtypes = [
            dpex_kernel_target.typing_context.resolve_argument_type(arg)
            for arg in args
        ]

        # Generate the cache-lookup key from the argument types (USM
        # metadata stripped), the codegen magic tuple and the function
        # hash.
        stripped_argtypes = strip_usm_metadata(argtypes)
        codegen_magic_tuple = (
            dpex_kernel_target.target_context.codegen().magic_tuple()
        )
        key = build_key(stripped_argtypes, codegen_magic_tuple, self._func_hash)

        cres = self._cache.get(key)
        if cres is not None:
            # Only a successful lookup is a cache hit. The previous code
            # incremented this counter on misses (i.e., on every
            # compilation).
            self._cache_hits += 1
        else:
            cres = compile_with_dpex(
                pyfunc=self._pyfunc,
                pyfunc_name=self._pyfunc.__name__,
                return_type=None,
                target_context=dpex_kernel_target.target_context,
                typing_context=dpex_kernel_target.typing_context,
                args=args,
                is_kernel=False,
                debug=self._debug,
            )
            # Give the generated LLVM function the `spir_func` calling
            # convention.
            func = cres.library.get_function(cres.fndesc.llvm_func_name)
            cres.target_context.set_spir_func_calling_conv(func)
            libs = [cres.library]

            # Register the compiled implementation under this object,
            # which is the key of the typing template built in
            # `compile_func_template`.
            cres.target_context.insert_user_function(self, cres.fndesc, libs)
            self._cache.put(key, cres)
        return cres.signature
165

166

167
def compile_func(pyfunc, signature, debug=False):
    """Compiles a specialized `numba_dpex.func`.

    Compiles a `numba_dpex.func` decorated function to a native library
    function for every given signature. It then registers the overload set
    with the dpex typing and target contexts, and returns it wrapped inside
    a `numba_dpex.core.kernel_interface.func.DpexFunction` object.

    Args:
        pyfunc (`function`): A Python function to be compiled.
        signature (`list`): A list of signatures. Each entry is normalized
            with `sigutils.normalize_signature`.
        debug (`bool`, optional): Debug options. Defaults to `False`.

    Raises:
        `ValueError`: If `signature` contains no signatures.

    Returns:
        `numba_dpex.core.kernel_interface.func.DpexFunction`: A `DpexFunction`
         object
    """
    # Materialize once so emptiness can be checked before any compilation.
    # An empty list used to fail later with an opaque IndexError on
    # `cres[0]`.
    signatures = list(signature)
    if not signatures:
        raise ValueError(
            "compile_func requires at least one signature to compile "
            f"{getattr(pyfunc, '__name__', pyfunc)!r}."
        )

    devfn = DpexFunction(pyfunc, debug=debug)

    cres = []
    for sig in signatures:
        arg_types, return_types = sigutils.normalize_signature(sig)
        cres.append(devfn.compile(arg_types, return_types))

    # Concrete overload set: calls must match one of the compiled
    # signatures exactly, without unsafe casting.
    class _function_template(ConcreteTemplate):
        unsafe_casting = False
        exact_match_required = True
        key = devfn
        cases = [c.signature for c in cres]

    # Every result was compiled against `dpex_kernel_target.typing_context`
    # (see `DpexFunction.compile`), so registering the template once via
    # the first result covers all overloads.
    cres[0].typing_context.insert_user_function(devfn, _function_template)

    # Register each compiled overload (descriptor + library) with its
    # target context under the same `devfn` key.
    for c in cres:
        c.target_context.insert_user_function(devfn, c.fndesc, [c.library])

    return devfn
204

205

206
def compile_func_template(pyfunc, debug=False, enable_cache=True):
    """Expose a `numba_dpex.func` function to Numba as an `AbstractTemplate`.

    Wraps a `numba_dpex.func` decorated function in a
    `numba_dpex.core.kernel_interface.func.DpexFunctionTemplate` and
    registers a matching `AbstractTemplate` with the dpex typing context.
    Whenever Numba types a call to the function, the template compiles it
    for the concrete argument types, or reuses a cached result. The
    generated native function uses the `spir_func` calling convention.

    Args:
        pyfunc (`function`): A Python function to be compiled.
        debug (`bool`, optional): Debug options. Defaults to `False`.
        enable_cache (`bool`, optional): Whether compiled results may be
            cached. Only effective when `config.ENABLE_CACHE` is set.
            Defaults to `True`.

    Raises:
        `AssertionError`: Raised by the template's `generic` method when a
            call supplies keyword arguments.

    Returns:
        `numba_dpex.core.kernel_interface.func.DpexFunctionTemplate`:
            The object used as the key of the registered template.
    """
    template_owner = DpexFunctionTemplate(
        pyfunc, debug=debug, enable_cache=enable_cache
    )

    class _function_template(AbstractTemplate):
        unsafe_casting = False
        exact_match_required = True
        key = template_owner

        def generic(self, args, kws):
            # Only positional calls are supported; typing a call compiles
            # the function for the given argument types.
            if not kws:
                return template_owner.compile(args)
            raise AssertionError("No keyword arguments allowed.")

    typing_ctx = dpex_kernel_target.typing_context
    typing_ctx.insert_user_function(template_owner, _function_template)
    return template_owner
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc