• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 20241601392

15 Dec 2025 05:34PM UTC coverage: 92.121% (-0.01%) from 92.133%
20241601392

Pull #10244

github

web-flow
Merge 5f2f7fd60 into fd989fecc
Pull Request #10244: feat!: drop Python 3.9 support due to EOL

14123 of 15331 relevant lines covered (92.12%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.45
haystack/components/extractors/named_entity_extractor.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from abc import ABC, abstractmethod
1✔
6
from contextlib import contextmanager
1✔
7
from dataclasses import dataclass, replace
1✔
8
from enum import Enum
1✔
9
from typing import Any, Optional, Union
1✔
10

11
from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
1✔
12
from haystack.lazy_imports import LazyImport
1✔
13
from haystack.utils.auth import Secret, deserialize_secrets_inplace
1✔
14
from haystack.utils.device import ComponentDevice
1✔
15
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
1✔
16

17
with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
1✔
18
    from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
1✔
19
    from transformers import Pipeline as HfPipeline
1✔
20

21
with LazyImport(message="Run 'pip install spacy'") as spacy_import:
1✔
22
    import spacy  # pylint: disable=import-error
1✔
23
    from spacy import Language as SpacyPipeline  # pylint: disable=import-error
×
24

25

26
class NamedEntityExtractorBackend(Enum):
    """
    Enumerates the NLP backends that can perform Named Entity Recognition.
    """

    #: Backed by a Hugging Face model and pipeline.
    HUGGING_FACE = "hugging_face"

    #: Backed by a spaCy model and pipeline.
    SPACY = "spacy"

    def __str__(self):
        # Render a member as its raw string value.
        return self.value

    @staticmethod
    def from_str(string: str) -> "NamedEntityExtractorBackend":
        """
        Look up the backend whose value equals the given string.

        :param string:
            Value of the backend to look up.
        :returns:
            The matching enum member.
        :raises ComponentError:
            If no backend has the given value.
        """
        backends_by_value = {member.value: member for member in NamedEntityExtractorBackend}
        if string in backends_by_value:
            return backends_by_value[string]

        supported = list(backends_by_value.keys())
        raise ComponentError(
            f"Invalid backend '{string}' for named entity extractor. Supported backends are: {supported}"
        )
1✔
54

55

56
@dataclass
class NamedEntityAnnotation:
    """
    A single entity span produced by an NER backend.

    :param entity:
        Label assigned to the entity.
    :param start:
        Index in the document at which the entity starts.
    :param end:
        Index in the document at which the entity ends.
    :param score:
        Score the model computed for the annotation; defaults to `None`
        when no score is provided.
    """

    # Field order is part of the positional constructor signature.
    entity: str
    start: int
    end: int
    score: Optional[float] = None
1✔
75

76

77
@component
class NamedEntityExtractor:
    """
    Annotates named entities in a collection of documents.

    The component supports two backends: Hugging Face and spaCy. The
    former can be used with any token classification model from the
    [Hugging Face model hub](https://huggingface.co/models), while the
    latter can be used with any [spaCy model](https://spacy.io/models)
    that contains an NER component. Annotations are stored as metadata
    in the documents.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.extractors.named_entity_extractor import NamedEntityExtractor

    documents = [
        Document(content="I'm Merlin, the happy pig!"),
        Document(content="My name is Clara and I live in Berkeley, California."),
    ]
    extractor = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")
    extractor.warm_up()
    results = extractor.run(documents=documents)["documents"]
    annotations = [NamedEntityExtractor.get_stored_annotations(doc) for doc in results]
    print(annotations)
    ```
    """

    # Key under which the annotations are stored in a document's metadata.
    _METADATA_KEY = "named_entities"

    def __init__(
        self,
        *,
        backend: Union[str, NamedEntityExtractorBackend],
        model: str,
        pipeline_kwargs: Optional[dict[str, Any]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Create a Named Entity extractor component.

        :param backend:
            Backend to use for NER. Either an enum member or its string value.
        :param model:
            Name of the model or a path to the model on
            the local disk. Dependent on the backend.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. The
            pipeline can override these arguments. Dependent on the backend.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected. If a
            device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter (only applicable to the
            HuggingFace backend).
        :param token:
            The API token to download private models from Hugging Face.
        :raises ComponentError:
            If `backend` is a string that does not name a supported backend,
            or is otherwise not a known backend.
        """
        if isinstance(backend, str):
            backend = NamedEntityExtractorBackend.from_str(backend)

        self._backend: _NerBackend
        self._warmed_up: bool = False
        self.token = token
        device = ComponentDevice.resolve_device(device)

        if backend == NamedEntityExtractorBackend.HUGGING_FACE:
            pipeline_kwargs = resolve_hf_pipeline_kwargs(
                huggingface_pipeline_kwargs=pipeline_kwargs or {},
                model=model,
                task="ner",
                supported_tasks=["ner"],
                device=device,
                token=token,
            )
            self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
        elif backend == NamedEntityExtractorBackend.SPACY:
            self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
        else:
            raise ComponentError(f"Unknown NER backend '{type(backend).__name__}' for extractor")

    def warm_up(self) -> None:
        """
        Initialize the component.

        Does nothing if a previous call already succeeded.

        :raises ComponentError:
            If the backend fails to initialize successfully.
        """
        if self._warmed_up:
            return

        try:
            self._backend.initialize()
            self._warmed_up = True
        except Exception as e:
            # Any backend failure is surfaced uniformly; the original error stays chained as the cause.
            raise ComponentError(
                f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
            ) from e

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
        """
        Annotate named entities in each document and store the annotations in the document's metadata.

        The input documents are left untouched; the returned list contains copies
        whose metadata additionally holds the annotations.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the documents.
        :returns:
            A dictionary with the processed documents under the `documents` key.
        :raises ComponentError:
            If the backend fails to process a document.
        """
        if not self._warmed_up:
            self.warm_up()

        # Documents without content are annotated as empty strings so positions stay aligned.
        texts = [doc.content if doc.content is not None else "" for doc in documents]
        annotations = self._backend.annotate(texts, batch_size=batch_size)

        if len(annotations) != len(documents):
            raise ComponentError(
                "NER backend did not return the correct number of annotations; "
                f"got {len(annotations)} but expected {len(documents)}"
            )

        new_documents = [
            replace(doc, meta={**doc.meta, self._METADATA_KEY: doc_annotations})
            for doc, doc_annotations in zip(documents, annotations)
        ]
        return {"documents": new_documents}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        from copy import deepcopy

        # Serialize a copy of the backend's kwargs: the steps below remove the token and
        # rewrite values in place, and the backend still needs the originals (for example
        # the token) when it is initialized after serialization.
        # NOTE(review): assumes `default_to_dict` may keep a reference to the dict it is
        # given - confirm against its implementation.
        serialization_dict = default_to_dict(
            self,
            backend=self._backend.type.name,
            model=self._backend.model_name,
            device=self._backend.device.to_dict(),
            pipeline_kwargs=deepcopy(self._backend._pipeline_kwargs),
            token=self.token.to_dict() if self.token else None,
        )

        hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
        # The token is serialized separately through `self.token`; never emit it with the kwargs.
        hf_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(hf_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "NamedEntityExtractor":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its `init_parameters` are updated in place.
        :returns:
            Deserialized component.
        :raises DeserializationError:
            If the dictionary cannot be turned into a component.
        """
        try:
            init_params = data["init_parameters"]
            deserialize_secrets_inplace(init_params, keys=["token"])
            if init_params.get("device") is not None:
                init_params["device"] = ComponentDevice.from_dict(init_params["device"])
            # The backend is serialized by enum member name (see `to_dict`).
            init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

            # Treat a missing or explicitly null `pipeline_kwargs` entry as "no kwargs".
            hf_pipeline_kwargs = init_params.get("pipeline_kwargs") or {}
            deserialize_hf_model_kwargs(hf_pipeline_kwargs)
            return default_from_dict(cls, data)
        except Exception as e:
            raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e

    @property
    def initialized(self) -> bool:
        """
        Returns if the extractor is ready to annotate text.
        """
        return self._backend.initialized

    @classmethod
    def get_stored_annotations(cls, document: Document) -> Optional[list[NamedEntityAnnotation]]:
        """
        Returns the document's named entity annotations stored in its metadata, if any.

        :param document:
            Document whose annotations are to be fetched.
        :returns:
            The stored annotations, or `None` if the metadata holds none.
        """
        return document.meta.get(cls._METADATA_KEY)
×
277

278

279
class _NerBackend(ABC):
    """
    Common interface and shared state for all NER backends.
    """

    def __init__(
        self,
        _type: NamedEntityExtractorBackend,
        device: ComponentDevice,
        pipeline_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()

        # Fall back to a fresh dict so the attribute is always a dict.
        if pipeline_kwargs is None:
            pipeline_kwargs = {}

        self._pipeline_kwargs = pipeline_kwargs
        self._device = device
        self._type = _type

    @property
    def type(self) -> NamedEntityExtractorBackend:
        """
        The kind of backend this instance represents.
        """
        return self._type

    @property
    def device(self) -> ComponentDevice:
        """
        The device the backend was configured with.

        :returns:
            The device on which the backend's model is loaded.
        """
        return self._device

    @property
    @abstractmethod
    def model_name(self) -> str:
        """
        The name of the model, or its path on the local disk.
        """

    @property
    @abstractmethod
    def initialized(self) -> bool:
        """
        Whether the backend has been initialized, that is, whether it can annotate text.
        """

    @abstractmethod
    def initialize(self):
        """
        Prepare the backend for use. Typically this means loading models, pipelines, and so on.
        """

    @abstractmethod
    def annotate(self, texts: list[str], *, batch_size: int = 1) -> list[list[NamedEntityAnnotation]]:
        """
        Produce annotations for a collection of texts.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            Size of the text batches that are passed to the model.
        :returns:
            NER annotations.
        """
1✔
346

347

348
class _HfBackend(_NerBackend):
    """
    Hugging Face backend for NER.
    """

    def __init__(
        self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[dict[str, Any]] = None
    ) -> None:
        """
        Construct a Hugging Face NER backend.

        :param model_name_or_path:
            Name of the model or a path to the Hugging Face
            model on the local disk.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected.

            If a device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. The
            pipeline can override these arguments.
        """
        super().__init__(NamedEntityExtractorBackend.HUGGING_FACE, device, pipeline_kwargs)

        transformers_import.check()

        self._model_name_or_path = model_name_or_path
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForTokenClassification] = None
        self.pipeline: Optional[HfPipeline] = None

    def initialize(self) -> None:
        """
        Load the tokenizer and model and build the NER pipeline.

        The tokenizer, model and pipeline attributes are only assigned once all
        of them were created, so a failure part-way through leaves the backend
        uninitialized instead of half-initialized.
        """
        token = self._pipeline_kwargs.get("token", None)
        tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token)
        model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token)

        pipeline_params: dict[str, Any] = {
            "task": "ner",
            "model": model,
            "tokenizer": tokenizer,
            "aggregation_strategy": "simple",
        }
        # User-supplied kwargs only fill in keys that are not fixed above.
        for key, value in self._pipeline_kwargs.items():
            pipeline_params.setdefault(key, value)
        self.device.update_hf_kwargs(pipeline_params, overwrite=False)
        hf_pipeline = pipeline(**pipeline_params)

        # Publish everything together, now that nothing can fail anymore.
        self.tokenizer = tokenizer
        self.model = model
        self.pipeline = hf_pipeline

    def annotate(self, texts: list[str], *, batch_size: int = 1) -> list[list[NamedEntityAnnotation]]:
        """
        Run the Hugging Face pipeline over the texts and convert its output to annotations.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            Batch size forwarded to the pipeline.
        :returns:
            One list of annotations per input text.
        :raises ComponentError:
            If the backend was not initialized.
        """
        hf_pipeline = self.pipeline
        # Explicit check instead of `assert`: assertions are stripped when Python runs with -O.
        if hf_pipeline is None:
            raise ComponentError("Hugging Face NER backend was not initialized - Did you call `warm_up()`?")

        if not texts:
            # Nothing to annotate; skip the pipeline call entirely.
            return []

        outputs = hf_pipeline(texts, batch_size=batch_size)
        return [
            [
                NamedEntityAnnotation(
                    # The label is read from "entity" when present, otherwise from "entity_group".
                    entity=annotation["entity"] if "entity" in annotation else annotation["entity_group"],
                    start=annotation["start"],
                    end=annotation["end"],
                    score=annotation["score"],
                )
                for annotation in annotations
            ]
            for annotations in outputs
        ]

    @property
    def initialized(self) -> bool:
        """
        Whether the pipeline has been built, which is what `annotate` needs to run.
        """
        return self.pipeline is not None

    @property
    def model_name(self) -> str:
        """
        The model name or local path this backend was constructed with.
        """
        return self._model_name_or_path
1✔
422

423

424
class _SpacyBackend(_NerBackend):
    """
    spaCy backend for NER.
    """

    def __init__(
        self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[dict[str, Any]] = None
    ) -> None:
        """
        Construct a spaCy NER backend.

        :param model_name_or_path:
            Name of the model or a path to the spaCy
            model on the local disk.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. The
            pipeline can override these arguments.
        :raises ValueError:
            If the device spans multiple devices.
        """
        super().__init__(NamedEntityExtractorBackend.SPACY, device, pipeline_kwargs)

        spacy_import.check()

        self._model_name_or_path = model_name_or_path
        self.pipeline: Optional[SpacyPipeline] = None

        if self.device.has_multiple_devices:
            raise ValueError("spaCy backend for named entity extractor only supports inference on single devices")

    def initialize(self):
        """
        Load the spaCy pipeline, verify it can do NER and disable the pipes that are not needed.

        The pipeline attribute is only assigned after validation and configuration
        succeeded, so a failed initialization leaves the backend uninitialized.

        :raises ComponentError:
            If the loaded pipeline does not contain an NER component.
        """
        # We need to initialize the model on the GPU if needed.
        with self._select_device():
            nlp = spacy.load(self._model_name_or_path)

        if not nlp.has_pipe("ner"):
            raise ComponentError(f"spaCy pipeline '{self._model_name_or_path}' does not contain an NER component")

        # Disable unnecessary pipes.
        pipes_to_keep = ("ner", "tok2vec", "transformer", "curated_transformer")
        for name in nlp.pipe_names:
            if name not in pipes_to_keep:
                nlp.disable_pipe(name)

        # `texts` and `batch_size` are passed explicitly in `annotate`; drop them here
        # so they cannot be supplied twice.
        self._pipeline_kwargs = {k: v for k, v in self._pipeline_kwargs.items() if k not in ("texts", "batch_size")}
        self.pipeline = nlp

    def annotate(self, texts: list[str], *, batch_size: int = 1) -> list[list[NamedEntityAnnotation]]:
        """
        Run the spaCy pipeline over the texts and convert the found entities to annotations.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            Batch size forwarded to the pipeline.
        :returns:
            One list of annotations per input text. No score is set on the annotations.
        :raises ComponentError:
            If the backend was not initialized.
        """
        nlp = self.pipeline
        # Explicit check instead of `assert`: assertions are stripped when Python runs with -O.
        if nlp is None:
            raise ComponentError("spaCy NER backend was not initialized - Did you call `warm_up()`?")

        with self._select_device():
            # Materialize the results while the device selection is still active.
            outputs = list(nlp.pipe(texts=texts, batch_size=batch_size, **self._pipeline_kwargs))

        return [
            [
                NamedEntityAnnotation(entity=entity.label_, start=entity.start_char, end=entity.end_char)
                for entity in doc.ents
            ]
            for doc in outputs
        ]

    @property
    def initialized(self) -> bool:
        """
        Whether the spaCy pipeline has been loaded.
        """
        return self.pipeline is not None

    @property
    def model_name(self) -> str:
        """
        The model name or local path this backend was constructed with.
        """
        return self._model_name_or_path

    @contextmanager
    def _select_device(self):
        """
        Context manager used to run spaCy models on a specific GPU in a scoped manner.
        """

        # TODO: This won't restore the active device.
        # Since there are no opaque API functions to determine
        # the active device in spaCy/Thinc, we can't do much
        # about it as a consumer unless we start poking into their
        # internals.
        device_id = self._device.to_spacy()
        try:
            # A non-negative id selects that GPU; otherwise nothing is switched.
            if device_id >= 0:
                spacy.require_gpu(device_id)
            yield
        finally:
            if device_id >= 0:
                spacy.require_cpu()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc