• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15703017100

17 Jun 2025 09:07AM UTC coverage: 90.143% (-0.001%) from 90.144%
15703017100

Pull #9470

github

web-flow
Merge bc17a8c2b into 7dbac5b3c
Pull Request #9470: feat: adding support for torch xpu device

11550 of 12813 relevant lines covered (90.14%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.72
haystack/utils/device.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from dataclasses import dataclass, field
1✔
7
from enum import Enum
1✔
8
from typing import Any, Dict, Optional, Tuple, Union
1✔
9

10
from haystack.lazy_imports import LazyImport
1✔
11

12
with LazyImport(
1✔
13
    message="PyTorch must be installed to use torch.device or use GPU support in HuggingFace transformers. "
14
    "Run 'pip install \"transformers[torch]\"'"
15
) as torch_import:
16
    import torch
1✔
17

18

19
class DeviceType(Enum):
    """
    Enumerates the kinds of devices Haystack knows about.

    Not every member is a compute device: `DISK` is only meaningful inside device maps, for frameworks
    that can offload model weights to disk.
    """

    CPU = "cpu"
    GPU = "cuda"
    DISK = "disk"
    MPS = "mps"
    XPU = "xpu"

    def __str__(self):
        # Render as the raw value (e.g. "cuda") so the member can be embedded directly in device strings.
        return self.value

    @staticmethod
    def from_str(string: str) -> "DeviceType":
        """
        Look up the device type whose value equals `string`.

        :param string:
            The device type string, e.g. ``"cuda"`` or ``"cpu"``. Matching is exact (case-sensitive).
        :returns:
            The matching device type.
        :raises ValueError:
            If no device type has that value.
        """
        by_value = {member.value: member for member in DeviceType}
        if string not in by_value:
            raise ValueError(f"Unknown device type string '{string}'")
        return by_value[string]
51

52

53
@dataclass
class Device:
    """
    A generic representation of a device.

    :param type:
        The device type.
    :param id:
        The optional device id.
    """

    type: DeviceType
    id: Optional[int] = field(default=None)

    def __init__(self, type: DeviceType, id: Optional[int] = None):  # noqa:A002
        """
        Create a generic device.

        :param type:
            The device type.
        :param id:
            The device id, or `None` when no index is specified.
        :raises ValueError:
            If `id` is negative.
        """
        if id is not None and id < 0:
            raise ValueError(f"Device id must be >= 0, got {id}")

        self.type = type
        self.id = id

    def __str__(self):
        # Produces "<type>" or "<type>:<id>", e.g. "cpu" or "cuda:0".
        if self.id is None:
            return str(self.type)
        else:
            return f"{self.type}:{self.id}"

    @staticmethod
    def cpu() -> "Device":
        """
        Create a generic CPU device.

        :returns:
            The CPU device.
        """
        return Device(DeviceType.CPU)

    @staticmethod
    def gpu(id: int = 0) -> "Device":  # noqa:A002
        """
        Create a generic GPU device.

        :param id:
            The GPU id.
        :returns:
            The GPU device.
        """
        return Device(DeviceType.GPU, id)

    @staticmethod
    def disk() -> "Device":
        """
        Create a generic disk device.

        :returns:
            The disk device.
        """
        return Device(DeviceType.DISK)

    @staticmethod
    def mps() -> "Device":
        """
        Create a generic Apple Metal Performance Shader device.

        :returns:
            The MPS device.
        """
        return Device(DeviceType.MPS)

    @staticmethod
    def xpu(id: Optional[int] = None) -> "Device":  # noqa:A002
        """
        Create a generic Intel XPU device (PyTorch's backend for Intel GPUs).

        :param id:
            The XPU id. Defaults to `None`, which yields the unindexed ``"xpu"`` device, the same result
            as calling this method without arguments. Pass an index (e.g. ``1``) to target a specific XPU.
        :returns:
            The XPU device.
        """
        return Device(DeviceType.XPU, id)

    @staticmethod
    def from_str(string: str) -> "Device":
        """
        Create a generic device from a string.

        :param string:
            A device string such as ``"cpu"``, ``"cuda:0"`` or ``"xpu:1"``.
        :returns:
            The device.
        :raises ValueError:
            If the device type is unknown, or the id is not a non-negative integer.
        """
        device_type_str, device_id = _split_device_string(string)
        return Device(DeviceType.from_str(device_type_str), device_id)
151

152

153
@dataclass
class DeviceMap:
    """
    A generic mapping from strings to devices.

    What the string keys mean is decided by the target framework; the main use is spreading HuggingFace
    models across multiple devices.

    :param mapping:
        Dictionary mapping strings to devices.
    """

    mapping: Dict[str, Device] = field(default_factory=dict, hash=False)

    def __getitem__(self, key: str) -> Device:
        return self.mapping[key]

    def __setitem__(self, key: str, value: Device) -> None:
        self.mapping[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.mapping

    def __len__(self) -> int:
        return len(self.mapping)

    def __iter__(self):
        # Iteration yields (key, device) pairs, not bare keys.
        return iter(self.mapping.items())

    def to_dict(self) -> Dict[str, str]:
        """
        Turn the mapping into a JSON-serializable dictionary of device strings.

        :returns:
            The serialized mapping.
        """
        return {name: str(dev) for name, dev in self.mapping.items()}

    @property
    def first_device(self) -> Optional[Device]:
        """
        Give back the first device in insertion order, or `None` for an empty mapping.

        :returns:
            The first device.
        """
        return next(iter(self.mapping.values()), None)

    @staticmethod
    def from_dict(dict: Dict[str, str]) -> "DeviceMap":  # noqa:A002
        """
        Rebuild a generic device map from its JSON-serialized form.

        :param dict:
            The serialized mapping, as produced by `to_dict`.
        :returns:
            The generic device map.
        """
        return DeviceMap({name: Device.from_str(serialized) for name, serialized in dict.items()})

    @staticmethod
    def from_hf(hf_device_map: Dict[str, Union[int, str, "torch.device"]]) -> "DeviceMap":
        """
        Build a generic device map from a HuggingFace device map.

        :param hf_device_map:
            The HuggingFace device map. Values may be ints (treated as CUDA indices), device strings,
            or `torch.device` objects.
        :returns:
            The deserialized device map.
        :raises ValueError:
            If a value has any other type, or names an unknown device type.
        """
        converted: Dict[str, Device] = {}
        for key, device in hf_device_map.items():
            if isinstance(device, int):
                # Bare integers are mapped to CUDA devices with that index.
                converted[key] = Device(DeviceType.GPU, device)
                continue
            if isinstance(device, str):
                converted[key] = Device.from_str(device)
                continue
            if isinstance(device, torch.device):
                converted[key] = Device(DeviceType.from_str(device.type), device.index)
                continue
            raise ValueError(
                f"Couldn't convert HuggingFace device map - unexpected device '{str(device)}' for '{key}'"
            )
        return DeviceMap(converted)
245

246

247
@dataclass(frozen=True)
class ComponentDevice:
    """
    A representation of a device for a component.

    This can be either a single device or a device map. Exactly one of the two fields is expected
    to be set; `_validate()` enforces this.
    """

    _single_device: Optional[Device] = field(default=None)
    _multiple_devices: Optional[DeviceMap] = field(default=None)

    @classmethod
    def from_str(cls, device_str: str) -> "ComponentDevice":
        """
        Create a component device representation from a device string.

        The device string can only represent a single device.

        :param device_str:
            The device string.
        :returns:
            The component device representation.
        """
        device = Device.from_str(device_str)
        return cls.from_single(device)

    @classmethod
    def from_single(cls, device: Device) -> "ComponentDevice":
        """
        Create a component device representation from a single device.

        Disks cannot be used as single devices.

        :param device:
            The device.
        :returns:
            The component device representation.
        :raises ValueError:
            If `device` is the disk device.
        """
        if device.type == DeviceType.DISK:
            raise ValueError("The disk device can only be used as a part of device maps")

        return cls(_single_device=device)

    @classmethod
    def from_multiple(cls, device_map: DeviceMap) -> "ComponentDevice":
        """
        Create a component device representation from a device map.

        :param device_map:
            The device map.
        :returns:
            The component device representation.
        """
        return cls(_multiple_devices=device_map)

    def _validate(self):
        """
        Validate the component device representation.

        :raises ValueError:
            If neither or both of the single device and the device map are set.
        """
        # XOR: exactly one of the two representations must be present.
        if not (self._single_device is not None) ^ (self._multiple_devices is not None):
            raise ValueError(
                "The component device can neither be empty nor contain both a single device and a device map"
            )

    def to_torch(self) -> "torch.device":
        """
        Convert the component device representation to PyTorch format.

        Device maps are not supported.

        :returns:
            The PyTorch device representation.
        :raises ValueError:
            If this representation holds a device map.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to PyTorch format")

        torch_import.check()
        assert self._single_device is not None
        return torch.device(str(self._single_device))

    def to_torch_str(self) -> str:
        """
        Convert the component device representation to PyTorch string format.

        Device maps are not supported.

        :returns:
            The PyTorch device string representation.
        :raises ValueError:
            If this representation holds a device map.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to PyTorch format")

        assert self._single_device is not None
        return str(self._single_device)

    def to_spacy(self) -> int:
        """
        Convert the component device representation to spaCy format.

        Device maps are not supported.

        :returns:
            The spaCy device representation: the GPU index for CUDA devices (``0`` when the device has
            no explicit index, matching the default of `Device.gpu()`), ``-1`` for every other device.
        :raises ValueError:
            If this representation holds a device map.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to spaCy format")

        device = self._single_device
        if device.type != DeviceType.GPU:
            return -1
        # A device parsed from a bare "cuda" string carries no index. Fall back to the first GPU instead of
        # failing an assertion (or silently returning None when assertions are disabled).
        return device.id if device.id is not None else 0

    def to_hf(self) -> Union[Union[int, str], Dict[str, Union[int, str]]]:
        """
        Convert the component device representation to HuggingFace format.

        :returns:
            The HuggingFace device representation: a device string for a single device, or a dictionary
            for a device map in which indexed GPUs are given as bare integers and everything else as strings.
        """
        self._validate()

        def convert_device(device: Device, *, gpu_id_only: bool = False) -> Union[int, str]:
            # Device-map entries use bare integers for GPUs, but only when an index is known; an unindexed
            # GPU keeps its string form (e.g. "cuda") instead of tripping an assertion.
            if gpu_id_only and device.type == DeviceType.GPU and device.id is not None:
                return device.id
            return str(device)

        if self._single_device is not None:
            return convert_device(self._single_device)

        assert self._multiple_devices is not None
        return {key: convert_device(device, gpu_id_only=True) for key, device in self._multiple_devices.mapping.items()}

    def update_hf_kwargs(self, hf_kwargs: Dict[str, Any], *, overwrite: bool) -> Dict[str, Any]:
        """
        Convert the component device representation to HuggingFace format.

        Add them as canonical keyword arguments to the keyword arguments dictionary.

        :param hf_kwargs:
            The HuggingFace keyword arguments dictionary. It is modified in place and also returned.
        :param overwrite:
            Whether to overwrite existing device arguments.
        :returns:
            The HuggingFace keyword arguments dictionary.
        """
        self._validate()

        # Leave caller-supplied device settings alone unless explicitly asked to overwrite them.
        if not overwrite and any(x in hf_kwargs for x in ("device", "device_map")):
            return hf_kwargs

        converted = self.to_hf()
        key = "device_map" if self.has_multiple_devices else "device"
        hf_kwargs[key] = converted
        return hf_kwargs

    @property
    def has_multiple_devices(self) -> bool:
        """
        Whether this component device representation contains multiple devices.
        """
        self._validate()

        return self._multiple_devices is not None

    @property
    def first_device(self) -> Optional["ComponentDevice"]:
        """
        Return either the single device or the first device in the device map, if any.

        :returns:
            The first device.
        """
        self._validate()

        if self._single_device is not None:
            return self.from_single(self._single_device)

        assert self._multiple_devices is not None
        assert self._multiple_devices.first_device is not None
        return self.from_single(self._multiple_devices.first_device)

    @staticmethod
    def resolve_device(device: Optional["ComponentDevice"] = None) -> "ComponentDevice":
        """
        Select a device for a component. If a device is specified, it's used. Otherwise, the default device is used.

        :param device:
            The provided device, if any.
        :returns:
            The resolved device.
        :raises ValueError:
            If `device` is neither `None` nor a `ComponentDevice`.
        """
        if not isinstance(device, ComponentDevice) and device is not None:
            raise ValueError(
                f"Invalid component device type '{type(device).__name__}'. Must either be None or ComponentDevice."
            )

        if device is None:
            device = ComponentDevice.from_single(_get_default_device())

        return device

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the component device representation to a JSON-serializable dictionary.

        :returns:
            The dictionary representation.
        :raises ValueError:
            If the representation is empty or holds both a single device and a device map.
        """
        # Validate up front so an inconsistent instance fails loudly instead of being serialized partially
        # (or, under `python -O`, as None).
        self._validate()

        if self._single_device is not None:
            return {"type": "single", "device": str(self._single_device)}

        assert self._multiple_devices is not None
        return {"type": "multiple", "device_map": self._multiple_devices.to_dict()}

    @classmethod
    def from_dict(cls, dict: Dict[str, Any]) -> "ComponentDevice":  # noqa:A002
        """
        Create a component device representation from a JSON-serialized dictionary.

        :param dict:
            The serialized representation.
        :returns:
            The deserialized component device.
        :raises ValueError:
            If the ``"type"`` entry is neither ``"single"`` nor ``"multiple"``.
        """
        if dict["type"] == "single":
            return cls.from_str(dict["device"])
        elif dict["type"] == "multiple":
            return cls.from_multiple(DeviceMap.from_dict(dict["device_map"]))
        else:
            raise ValueError(f"Unknown component device type '{dict['type']}' in serialized data")
489

490

491
def _get_default_device() -> Device:
    """
    Pick the device Haystack uses when a component is not given one explicitly.

    Accelerators are preferred in the order CUDA GPU, Intel XPU, Apple MPS; CPU is the fallback.
    XPU and MPS are skipped when `HAYSTACK_XPU_ENABLED` / `HAYSTACK_MPS_ENABLED` is set to exactly
    ``"false"``. When PyTorch is not installed, the CPU device is returned.

    :returns:
        The default device.
    """

    def _not_disabled(env_var: str) -> bool:
        # Only the literal lowercase value "false" opts out; unset or any other value keeps the backend enabled.
        return os.getenv(env_var, "true") != "false"

    try:
        torch_import.check()

        mps_ok = (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and _not_disabled("HAYSTACK_MPS_ENABLED")
        )
        cuda_ok = torch.cuda.is_available()
        # The hasattr guards tolerate PyTorch builds without `torch.xpu` or its `is_available`.
        xpu_ok = (
            hasattr(torch, "xpu")
            and hasattr(torch.xpu, "is_available")
            and torch.xpu.is_available()
            and _not_disabled("HAYSTACK_XPU_ENABLED")
        )
    except ImportError:
        # PyTorch is missing, so no accelerator can be used.
        mps_ok = cuda_ok = xpu_ok = False

    candidates = ((cuda_ok, Device.gpu), (xpu_ok, Device.xpu), (mps_ok, Device.mps))
    for available, make_device in candidates:
        if available:
            return make_device()
    return Device.cpu()
529

530

531
def _split_device_string(string: str) -> Tuple[str, Optional[int]]:
1✔
532
    """
533
    Split a device string into device type and device id.
534

535
    :param string:
536
        The device string to split.
537
    :returns:
538
        The device type and device id, if any.
539
    """
540
    if ":" in string:
1✔
541
        device_type, device_id_str = string.split(":")
1✔
542
        try:
1✔
543
            device_id = int(device_id_str)
1✔
544
        except ValueError:
×
545
            raise ValueError(f"Device id must be an integer, got {device_id_str}")
×
546
    else:
547
        device_type = string
1✔
548
        device_id = None
1✔
549
    return device_type, device_id
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc