• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IntelPython / numba-dpex / 8303526931

15 Mar 2024 11:38PM UTC coverage: 82.781% (-0.02%) from 82.805%
8303526931

push

github

web-flow
Merge pull request #1385 from IntelPython/feature/inline_threashold

Feature/inline threshold

1574 of 2167 branches covered (72.63%)

Branch coverage included in aggregate %.

4 of 5 new or added lines in 2 files covered. (80.0%)

1 existing line in 1 file now uncovered.

6493 of 7578 relevant lines covered (85.68%)

0.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.0
/numba_dpex/core/kernel_interface/spirv_kernel.py
1
# SPDX-FileCopyrightText: 2022 - 2024 Intel Corporation
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import logging
1✔
6
from types import FunctionType
1✔
7

8
from numba.core import ir
1✔
9

10
from numba_dpex.core import config
1✔
11
from numba_dpex.core.compiler import compile_with_dpex
1✔
12
from numba_dpex.core.exceptions import UncompiledKernelError, UnreachableError
1✔
13
from numba_dpex.kernel_api_impl.spirv import spirv_generator
1✔
14
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
1✔
15

16
from .kernel_base import KernelInterface
1✔
17

18

19
class SpirvKernel(KernelInterface):
    """A kernel compiled to a SPIR-V module from a Python function or Numba IR.

    The LLVM module, device (SPIR-V) module, module name and the
    typing/target contexts are populated only by :meth:`compile`. Accessing
    the corresponding properties before compilation raises
    ``UncompiledKernelError``.
    """

    def __init__(self, func, func_name) -> None:
        """Represents a SPIR-V module compiled for a Python function.

        Args:
            func: The function to be compiled. Can be a Python function or a
            Numba IR object representing a function.
            func_name (str): Name of the function being compiled

        Raises:
            UnreachableError: If ``func`` is neither a Python function nor a
            ``numba.core.ir.FunctionIR``; indicates an unexpected internal
            code path was executed.
        """
        self._llvm_module = None
        self._device_driver_ir_module = None
        self._module_name = None
        self._pyfunc_name = func_name
        self._func = func
        if isinstance(func, FunctionType):
            self._func_ty = FunctionType
        elif isinstance(func, ir.FunctionIR):
            self._func_ty = ir.FunctionIR
        else:
            raise UnreachableError()
        # Both contexts are only set by compile(). Initialise them here so the
        # properties raise UncompiledKernelError (not AttributeError) when
        # they are queried on a kernel that has not been compiled yet.
        self._target_context = None
        self._typing_context = None

    @property
    def llvm_module(self):
        """The LLVM IR Module corresponding to the Kernel instance."""
        if self._llvm_module:
            return self._llvm_module
        else:
            raise UncompiledKernelError(self._pyfunc_name)

    @property
    def device_driver_ir_module(self):
        """The module in a device IR (such as SPIR-V or PTX) format."""
        if self._device_driver_ir_module:
            return self._device_driver_ir_module
        else:
            raise UncompiledKernelError(self._pyfunc_name)

    @property
    def pyfunc_name(self):
        """The Python function name corresponding to the kernel."""
        return self._pyfunc_name

    @property
    def module_name(self):
        """The name of the compiled LLVM module for the kernel."""
        if self._module_name:
            return self._module_name
        else:
            raise UncompiledKernelError(self._pyfunc_name)

    @property
    def target_context(self):
        """Returns the target context that was used to compile the kernel.

        Raises:
            UncompiledKernelError: If the kernel was not yet compiled.

        Returns:
            target context used to compile the kernel
        """
        if self._target_context:
            return self._target_context
        else:
            raise UncompiledKernelError(self._pyfunc_name)

    @property
    def typing_context(self):
        """Returns the typing context that was used to compile the kernel.

        Raises:
            UncompiledKernelError: If the kernel was not yet compiled.

        Returns:
            typing context used to compile the kernel
        """
        if self._typing_context:
            return self._typing_context
        else:
            raise UncompiledKernelError(self._pyfunc_name)

    def compile(
        self,
        target_ctx,
        typing_ctx,
        args,
        debug,
        compile_flags,
    ):
        """Compiles the kernel with ``compile_with_dpex`` and lowers it to SPIR-V.

        On success this populates :attr:`llvm_module`, :attr:`module_name`,
        :attr:`device_driver_ir_module`, :attr:`target_context` and
        :attr:`typing_context`.

        Args:
            target_ctx: Target context forwarded to ``compile_with_dpex``;
                stored and later returned by :attr:`target_context`.
            typing_ctx: Typing context forwarded to ``compile_with_dpex``;
                stored and later returned by :attr:`typing_context`.
            args: Argument types the kernel is compiled for.
            debug: Forwarded as the ``debug`` flag to ``compile_with_dpex``.
            compile_flags: Forwarded as ``extra_compile_flags`` to
                ``compile_with_dpex``.
        """

        # Use a %-placeholder: passing ``args`` without one makes the logging
        # module report a formatting error when the record is emitted.
        logging.debug("compiling SpirvKernel with arg types %s", args)

        self._target_context = target_ctx
        self._typing_context = typing_ctx

        cres = compile_with_dpex(
            self._func,
            self._pyfunc_name,
            args=args,
            return_type=None,
            debug=debug,
            is_kernel=True,
            typing_context=typing_ctx,
            target_context=target_ctx,
            extra_compile_flags=compile_flags,
        )

        func = cres.library.get_function(cres.fndesc.llvm_func_name)
        kernel_targetctx: SPIRVTargetContext = cres.target_context
        kernel = kernel_targetctx.prepare_spir_kernel(func, cres.signature.args)

        # XXX: Setting the inline_threshold in the following way is a temporary
        # workaround till the JitKernel dispatcher is replaced by
        # experimental.dispatcher.KernelDispatcher.
        if config.INLINE_THRESHOLD is not None:
            cres.library.inline_threshold = config.INLINE_THRESHOLD
        else:
            cres.library.inline_threshold = 0

        cres.library._optimize_final_module()
        self._llvm_module = str(kernel.module)
        self._module_name = kernel.name

        # Dump LLVM IR if the DUMP_KERNEL_LLVM config flag is set.
        if config.DUMP_KERNEL_LLVM:
            import hashlib

            # Embed hash of module name in the output file name
            # so that different kernels are written to separate files
            with open(
                "llvm_kernel_"
                + hashlib.sha256(self._module_name.encode()).hexdigest()
                + ".ll",
                "w",
            ) as f:
                f.write(self._llvm_module)

        # FIXME: There is no need to serialize the bitcode. It can be passed to
        # llvm-spirv directly via stdin.

        # FIXME: There is no need for spirv-dis. We could use --to-text
        # (or --spirv-text) to convert SPIRV to text
        self._device_driver_ir_module = spirv_generator.llvm_to_spirv(
            self._target_context, self._llvm_module, kernel.module.as_bitcode()
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc