• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14064199728

25 Mar 2025 03:52PM UTC coverage: 90.154% (+0.08%) from 90.07%
14064199728

Pull #9055

github

web-flow
Merge eaafb5e56 into e64db6197
Pull Request #9055: Added retries parameters to pipeline.draw()

9898 of 10979 relevant lines covered (90.15%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.81
haystack/utils/device.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from dataclasses import dataclass, field
1✔
7
from enum import Enum
1✔
8
from typing import Any, Dict, Optional, Tuple, Union
1✔
9

10
from haystack.lazy_imports import LazyImport
1✔
11

12
with LazyImport(
1✔
13
    message="PyTorch must be installed to use torch.device or use GPU support in HuggingFace transformers. "
14
    "Run 'pip install \"transformers[torch]\"'"
15
) as torch_import:
16
    import torch
1✔
17

18

19
class DeviceType(Enum):
    """
    Device types known to Haystack.

    Besides devices that run models, this enumeration also lists the disk device, which is only meant to appear inside
    device maps for frameworks that can offload model weights to disk.
    """

    CPU = "cpu"
    GPU = "cuda"
    DISK = "disk"
    MPS = "mps"

    def __str__(self):
        # The string form of a device type is its raw value, e.g. "cuda".
        return str(self.value)

    @staticmethod
    def from_str(string: str) -> "DeviceType":
        """
        Look up the device type whose value equals the given string.

        :param string:
            The string to convert, e.g. "cpu" or "cuda".
        :returns:
            The matching device type.
        :raises ValueError:
            If no device type has the given value.
        """
        by_value = {member.value: member for member in DeviceType}
        if string not in by_value:
            raise ValueError(f"Unknown device type string '{string}'")
        return by_value[string]
50

51

52
@dataclass
class Device:
    """
    Generic, framework-agnostic description of a device.

    :param type:
        The kind of device.
    :param id:
        Optional index of the device.
    """

    type: DeviceType
    id: Optional[int] = field(default=None)

    def __init__(self, type: DeviceType, id: Optional[int] = None):  # noqa:A002
        """
        Build a generic device.

        :param type:
            The kind of device.
        :param id:
            Optional index of the device; must be non-negative when given.
        :raises ValueError:
            If `id` is negative.
        """
        has_invalid_id = id is not None and id < 0
        if has_invalid_id:
            raise ValueError(f"Device id must be >= 0, got {id}")

        self.type, self.id = type, id

    def __str__(self):
        # "cpu" when there is no index, "cuda:1" style otherwise.
        return str(self.type) if self.id is None else f"{self.type}:{self.id}"

    @staticmethod
    def cpu() -> "Device":
        """
        Build a generic CPU device.

        :returns:
            A CPU device without an index.
        """
        return Device(type=DeviceType.CPU)

    @staticmethod
    def gpu(id: int = 0) -> "Device":  # noqa:A002
        """
        Build a generic GPU device.

        :param id:
            Index of the GPU, defaulting to the first one.
        :returns:
            A GPU device with the given index.
        """
        return Device(DeviceType.GPU, id=id)

    @staticmethod
    def disk() -> "Device":
        """
        Build a generic disk device.

        :returns:
            A disk device without an index.
        """
        return Device(type=DeviceType.DISK)

    @staticmethod
    def mps() -> "Device":
        """
        Build a generic Apple Metal Performance Shaders device.

        :returns:
            An MPS device without an index.
        """
        return Device(type=DeviceType.MPS)

    @staticmethod
    def from_str(string: str) -> "Device":
        """
        Parse a device string such as "cpu" or "cuda:0" into a generic device.

        :param string:
            The device string; an optional integer index follows a ':'.
        :returns:
            The parsed device.
        :raises ValueError:
            If the type is unknown, the index is not an integer, or the index is negative.
        """
        type_name, index = _split_device_string(string)
        device_type = DeviceType.from_str(type_name)
        return Device(device_type, index)
140

141

142
@dataclass
class DeviceMap:
    """
    A generic mapping from strings to devices.

    What the string keys mean depends on the target framework. The main use is deploying HuggingFace models across
    multiple devices.

    :param mapping:
        Dictionary from keys to devices.
    """

    mapping: Dict[str, Device] = field(default_factory=dict, hash=False)

    def __getitem__(self, key: str) -> Device:
        return self.mapping[key]

    def __setitem__(self, key: str, value: Device):
        self.mapping[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self.mapping

    def __len__(self) -> int:
        return len(self.mapping)

    def __iter__(self):
        # Iterating a device map yields (key, device) pairs, not just keys.
        return iter(self.mapping.items())

    def to_dict(self) -> Dict[str, str]:
        """
        Serialize the device map into a JSON-serializable dictionary.

        :returns:
            A dictionary from each key to the string form of its device.
        """
        serialized: Dict[str, str] = {}
        for name, device in self.mapping.items():
            serialized[name] = str(device)
        return serialized

    @property
    def first_device(self) -> Optional[Device]:
        """
        The first device in insertion order, or `None` for an empty map.

        :returns:
            The first device, if any.
        """
        return next(iter(self.mapping.values()), None)

    @staticmethod
    def from_dict(dict: Dict[str, str]) -> "DeviceMap":  # noqa:A002
        """
        Rebuild a generic device map from its JSON-serialized dictionary form.

        :param dict:
            Dictionary from keys to device strings.
        :returns:
            The generic device map.
        """
        return DeviceMap({name: Device.from_str(device_str) for name, device_str in dict.items()})

    @staticmethod
    def from_hf(hf_device_map: Dict[str, Union[int, str, "torch.device"]]) -> "DeviceMap":
        """
        Convert a HuggingFace device map into a generic device map.

        Integer values are treated as GPU indices. String values are parsed like "cpu" or "cuda:0". `torch.device`
        values are converted from their type and index.

        :param hf_device_map:
            The HuggingFace device map.
        :returns:
            The converted device map.
        :raises ValueError:
            If a value is of an unsupported kind or names an unknown device type.
        """
        converted: Dict[str, Device] = {}
        for name, raw in hf_device_map.items():
            if isinstance(raw, int):
                converted[name] = Device(DeviceType.GPU, raw)
                continue

            if isinstance(raw, str):
                type_name, index = _split_device_string(raw)
            elif isinstance(raw, torch.device):
                type_name, index = raw.type, raw.index
            else:
                raise ValueError(
                    f"Couldn't convert HuggingFace device map - unexpected device '{str(raw)}' for '{name}'"
                )
            converted[name] = Device(DeviceType.from_str(type_name), index)
        return DeviceMap(converted)
234

235

236
@dataclass(frozen=True)
class ComponentDevice:
    """
    A representation of a device for a component.

    This can be either a single device or a device map. Exactly one of the two fields is expected to be set; methods
    that depend on this call `_validate()` and raise a `ValueError` otherwise.
    """

    # Set when the component uses exactly one device.
    _single_device: Optional[Device] = field(default=None)
    # Set when the component uses a device map (multiple devices).
    _multiple_devices: Optional[DeviceMap] = field(default=None)

    @classmethod
    def from_str(cls, device_str: str) -> "ComponentDevice":
        """
        Create a component device representation from a device string.

        The device string can only represent a single device.

        :param device_str:
            The device string, e.g. "cpu" or "cuda:0".
        :returns:
            The component device representation.
        :raises ValueError:
            If the string cannot be parsed or refers to the disk device.
        """
        device = Device.from_str(device_str)
        return cls.from_single(device)

    @classmethod
    def from_single(cls, device: Device) -> "ComponentDevice":
        """
        Create a component device representation from a single device.

        Disks cannot be used as single devices.

        :param device:
            The device.
        :returns:
            The component device representation.
        :raises ValueError:
            If `device` is the disk device.
        """
        if device.type == DeviceType.DISK:
            raise ValueError("The disk device can only be used as a part of device maps")

        return cls(_single_device=device)

    @classmethod
    def from_multiple(cls, device_map: DeviceMap) -> "ComponentDevice":
        """
        Create a component device representation from a device map.

        :param device_map:
            The device map.
        :returns:
            The component device representation.
        """
        return cls(_multiple_devices=device_map)

    def _validate(self):
        """
        Validate the component device representation.

        :raises ValueError:
            If neither or both of the single device and the device map are set.
        """
        # Exactly one of the two representations must be present (logical XOR).
        if not (self._single_device is not None) ^ (self._multiple_devices is not None):
            raise ValueError(
                "The component device can neither be empty nor contain both a single device and a device map"
            )

    def to_torch(self) -> "torch.device":
        """
        Convert the component device representation to PyTorch format.

        Device maps are not supported.

        :returns:
            The PyTorch device representation.
        :raises ValueError:
            If the representation is invalid or holds a device map.
        :raises ImportError:
            If PyTorch is not installed.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to PyTorch format")

        torch_import.check()
        assert self._single_device is not None
        return torch.device(str(self._single_device))

    def to_torch_str(self) -> str:
        """
        Convert the component device representation to PyTorch string format.

        Device maps are not supported.

        :returns:
            The PyTorch device string representation, e.g. "cuda:0".
        :raises ValueError:
            If the representation is invalid or holds a device map.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to PyTorch format")

        assert self._single_device is not None
        return str(self._single_device)

    def to_spacy(self) -> int:
        """
        Convert the component device representation to spaCy format.

        Device maps are not supported.

        :returns:
            The GPU id for GPU devices, and -1 for any other device type.
        :raises ValueError:
            If the representation is invalid or holds a device map.
        """
        self._validate()

        if self._single_device is None:
            raise ValueError("Only single devices can be converted to spaCy format")

        assert self._single_device is not None
        if self._single_device.type == DeviceType.GPU:
            assert self._single_device.id is not None
            return self._single_device.id
        else:
            return -1

    def to_hf(self) -> Union[Union[int, str], Dict[str, Union[int, str]]]:
        """
        Convert the component device representation to HuggingFace format.

        A single device becomes its device string. A device map becomes a dictionary in which GPU devices are given
        as their integer id and all other devices as device strings.

        :returns:
            The HuggingFace device representation.
        :raises ValueError:
            If the representation is invalid.
        """
        self._validate()

        def convert_device(device: Device, *, gpu_id_only: bool = False) -> Union[int, str]:
            if gpu_id_only and device.type == DeviceType.GPU:
                assert device.id is not None
                return device.id
            else:
                return str(device)

        if self._single_device is not None:
            return convert_device(self._single_device)

        assert self._multiple_devices is not None
        return {key: convert_device(device, gpu_id_only=True) for key, device in self._multiple_devices.mapping.items()}

    def update_hf_kwargs(self, hf_kwargs: Dict[str, Any], *, overwrite: bool) -> Dict[str, Any]:
        """
        Convert the component device representation to HuggingFace format.

        The result is stored in `hf_kwargs` in place, under "device_map" for device maps or "device" otherwise.

        :param hf_kwargs:
            The HuggingFace keyword arguments dictionary.
        :param overwrite:
            Whether to overwrite existing device arguments. When False and either "device" or "device_map" is
            already present, `hf_kwargs` is returned unchanged.
        :returns:
            The HuggingFace keyword arguments dictionary.
        """
        self._validate()

        if not overwrite and any(x in hf_kwargs for x in ("device", "device_map")):
            return hf_kwargs

        converted = self.to_hf()
        key = "device_map" if self.has_multiple_devices else "device"
        hf_kwargs[key] = converted
        return hf_kwargs

    @property
    def has_multiple_devices(self) -> bool:
        """
        Whether this component device representation contains multiple devices.

        :raises ValueError:
            If the representation is invalid.
        """
        self._validate()

        return self._multiple_devices is not None

    @property
    def first_device(self) -> Optional["ComponentDevice"]:
        """
        Return either the single device or the first device in the device map, if any.

        :returns:
            The first device, or `None` when the device map is empty.
        :raises ValueError:
            If the representation is invalid, or the first device of the map is the disk device.
        """
        self._validate()

        if self._single_device is not None:
            return self.from_single(self._single_device)

        assert self._multiple_devices is not None
        first = self._multiple_devices.first_device
        # An empty device map has no first device; honor the Optional return type instead of failing an assert.
        if first is None:
            return None
        return self.from_single(first)

    @staticmethod
    def resolve_device(device: Optional["ComponentDevice"] = None) -> "ComponentDevice":
        """
        Select a device for a component. If a device is specified, it's used. Otherwise, the default device is used.

        :param device:
            The provided device, if any.
        :returns:
            The resolved device.
        :raises ValueError:
            If `device` is neither None nor a ComponentDevice.
        """
        if not isinstance(device, ComponentDevice) and device is not None:
            raise ValueError(
                f"Invalid component device type '{type(device).__name__}'. Must either be None or ComponentDevice."
            )

        if device is None:
            device = ComponentDevice.from_single(_get_default_device())

        return device

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the component device representation to a JSON-serializable dictionary.

        :returns:
            The dictionary representation.
        :raises ValueError:
            If the representation is invalid (neither or both fields set).
        """
        # Validate like every other accessor so an invalid instance is never silently serialized.
        self._validate()

        if self._single_device is not None:
            return {"type": "single", "device": str(self._single_device)}

        assert self._multiple_devices is not None
        return {"type": "multiple", "device_map": self._multiple_devices.to_dict()}

    @classmethod
    def from_dict(cls, dict: Dict[str, Any]) -> "ComponentDevice":  # noqa:A002
        """
        Create a component device representation from a JSON-serialized dictionary.

        :param dict:
            The serialized representation, as produced by `to_dict`.
        :returns:
            The deserialized component device.
        :raises ValueError:
            If the "type" entry is neither "single" nor "multiple".
        """
        if dict["type"] == "single":
            return cls.from_str(dict["device"])
        elif dict["type"] == "multiple":
            return cls.from_multiple(DeviceMap.from_dict(dict["device_map"]))
        else:
            raise ValueError(f"Unknown component device type '{dict['type']}' in serialized data")
478

479

480
def _get_default_device() -> Device:
    """
    Return the default device for Haystack.

    Precedence:
        GPU > MPS > CPU. If PyTorch is not installed, only CPU is available.
        MPS is skipped when the `HAYSTACK_MPS_ENABLED` environment variable is set to "false" (case-insensitive).

    :returns:
        The default device.
    """
    try:
        torch_import.check()

        # Normalize so that "False", "FALSE" or " false " also disable MPS, not only the exact string "false".
        mps_enabled = os.getenv("HAYSTACK_MPS_ENABLED", "true").strip().lower() != "false"
        has_mps = mps_enabled and hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        has_cuda = torch.cuda.is_available()
    except ImportError:
        # PyTorch is not installed: fall back to CPU.
        has_mps = False
        has_cuda = False

    if has_cuda:
        return Device.gpu()
    elif has_mps:
        return Device.mps()
    else:
        return Device.cpu()
509

510

511
def _split_device_string(string: str) -> Tuple[str, Optional[int]]:
1✔
512
    """
513
    Split a device string into device type and device id.
514

515
    :param string:
516
        The device string to split.
517
    :returns:
518
        The device type and device id, if any.
519
    """
520
    if ":" in string:
1✔
521
        device_type, device_id_str = string.split(":")
1✔
522
        try:
1✔
523
            device_id = int(device_id_str)
1✔
524
        except ValueError:
×
525
            raise ValueError(f"Device id must be an integer, got {device_id_str}")
×
526
    else:
527
        device_type = string
1✔
528
        device_id = None
1✔
529
    return device_type, device_id
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc