• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12240140835

09 Dec 2024 04:39PM UTC coverage: 90.335% (+0.001%) from 90.334%
12240140835

Pull #8610

github

web-flow
Merge 3ff0aa0e9 into 6f983a22c
Pull Request #8610: chore: fixing `pylint` issues

8038 of 8898 relevant lines covered (90.33%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

66.67
haystack/components/extractors/named_entity_extractor.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from abc import ABC, abstractmethod
1✔
6
from contextlib import contextmanager
1✔
7
from dataclasses import dataclass
1✔
8
from enum import Enum
1✔
9
from typing import Any, Dict, List, Optional, Union
1✔
10

11
from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
1✔
12
from haystack.lazy_imports import LazyImport
1✔
13
from haystack.utils.device import ComponentDevice
1✔
14

15
with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
1✔
16
    from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
1✔
17
    from transformers import Pipeline as HfPipeline
1✔
18

19
with LazyImport(message="Run 'pip install spacy'") as spacy_import:
1✔
20
    import spacy
1✔
21
    from spacy import Language as SpacyPipeline
1✔
22

23

24
class NamedEntityExtractorBackend(Enum):
    """
    Enumerates the NLP libraries that can power Named Entity Recognition.
    """

    #: Run NER through a Hugging Face model and pipeline.
    HUGGING_FACE = "hugging_face"

    #: Run NER through a spaCy model and pipeline.
    SPACY = "spacy"

    def __str__(self):
        # A member renders as its raw string value (e.g. "spacy").
        return self.value

    @staticmethod
    def from_str(string: str) -> "NamedEntityExtractorBackend":
        """
        Look up the backend member whose value equals `string`.

        :param string:
            Backend value, e.g. `"hugging_face"` or `"spacy"`.
        :returns:
            The matching enum member.
        :raises ComponentError:
            If no member has the given value.
        """
        by_value = {member.value: member for member in NamedEntityExtractorBackend}
        if string in by_value:
            return by_value[string]

        supported = list(by_value.keys())
        raise ComponentError(
            f"Invalid backend '{string}' for named entity extractor. "
            f"Supported backends are: {supported}"
        )
52

53

54
@dataclass
class NamedEntityAnnotation:
    """
    One entity annotation produced by a NER backend.

    :param entity:
        Label the backend assigned to the entity.
    :param start:
        Index in the document text at which the entity begins.
    :param end:
        Index in the document text at which the entity ends.
    :param score:
        Model score for the annotation; left as `None` by backends
        that do not report one.
    """

    # Label string reported by the backend.
    entity: str
    # Span boundaries within the document text.
    start: int
    end: int
    # Optional confidence; not every backend provides it.
    score: Optional[float] = None
73

74

75
@component
class NamedEntityExtractor:
    """
    Annotates named entities in a collection of documents.

    The component supports two backends: Hugging Face and spaCy. The
    former can be used with any token classification model from the
    [Hugging Face model hub](https://huggingface.co/models), while the
    latter can be used with any [spaCy model](https://spacy.io/models)
    that contains an NER component. Annotations are stored as metadata
    in the documents.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.extractors.named_entity_extractor import NamedEntityExtractor

    documents = [
        Document(content="I'm Merlin, the happy pig!"),
        Document(content="My name is Clara and I live in Berkeley, California."),
    ]
    extractor = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")
    extractor.warm_up()
    results = extractor.run(documents=documents)["documents"]
    annotations = [NamedEntityExtractor.get_stored_annotations(doc) for doc in results]
    print(annotations)
    ```
    """

    # Metadata key under which annotations are written to each document.
    _METADATA_KEY = "named_entities"

    def __init__(
        self,
        *,
        backend: Union[str, NamedEntityExtractorBackend],
        model: str,
        pipeline_kwargs: Optional[Dict[str, Any]] = None,
        device: Optional[ComponentDevice] = None,
    ) -> None:
        """
        Create a Named Entity extractor component.

        :param backend:
            Backend to use for NER, either as an enum member or as its
            string value (e.g. `"hugging_face"`, `"spacy"`).
        :param model:
            Name of the model or a path to the model on
            the local disk. Dependent on the backend.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. Their meaning
            depends on the backend.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected. If a
            device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter (only applicable to the
            HuggingFace backend).
        :raises ComponentError:
            If `backend` is a string that names no supported backend,
            or an object that is not a supported backend.
        """

        if isinstance(backend, str):
            backend = NamedEntityExtractorBackend.from_str(backend)

        self._backend: _NerBackend
        self._warmed_up: bool = False
        device = ComponentDevice.resolve_device(device)

        if backend == NamedEntityExtractorBackend.HUGGING_FACE:
            self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
        elif backend == NamedEntityExtractorBackend.SPACY:
            self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
        else:
            # Report the offending value itself, not just its type name.
            raise ComponentError(f"Unknown NER backend '{backend}' ({type(backend).__name__}) for extractor")

    def warm_up(self):
        """
        Initialize the component. Calling it again after a successful warm-up is a no-op.

        :raises ComponentError:
            If the backend fails to initialize successfully.
        """
        if self._warmed_up:
            return

        try:
            self._backend.initialize()
            self._warmed_up = True
        except Exception as e:
            raise ComponentError(
                f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
            ) from e

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], batch_size: int = 1) -> Dict[str, Any]:
        """
        Annotate named entities in each document and store the annotations in the document's metadata.

        The input documents are modified in place: their annotations are written to
        `doc.meta["named_entities"]`. Documents without content are annotated as empty text.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the documents.
        :returns:
            A dictionary with the key `documents` holding the annotated documents.
        :raises RuntimeError:
            If `warm_up()` has not been called successfully.
        :raises ComponentError:
            If the backend fails to process a document.
        """
        if not self._warmed_up:
            msg = "The component NamedEntityExtractor was not warmed up. Call warm_up() before running the component."
            raise RuntimeError(msg)

        texts = [doc.content if doc.content is not None else "" for doc in documents]
        annotations = self._backend.annotate(texts, batch_size=batch_size)

        # Guard against a backend that drops or duplicates outputs, which would misalign zip() below.
        if len(annotations) != len(documents):
            raise ComponentError(
                "NER backend did not return the correct number of annotations; "
                f"got {len(annotations)} but expected {len(documents)}"
            )

        for doc, doc_annotations in zip(documents, annotations):
            doc.meta[self._METADATA_KEY] = doc_annotations

        return {"documents": documents}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            # The enum member *name* is stored; `from_dict` looks it up by name.
            backend=self._backend.type.name,
            model=self._backend.model_name,
            device=self._backend.device.to_dict(),
            pipeline_kwargs=self._backend._pipeline_kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        :raises DeserializationError:
            If the dictionary cannot be turned into a component.
        """
        try:
            # Convert on a copy: mutating the caller's dict would make a second
            # deserialization of the same data fail (values already converted).
            init_params = dict(data["init_parameters"])
            if init_params.get("device") is not None:
                init_params["device"] = ComponentDevice.from_dict(init_params["device"])
            init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
            return default_from_dict(cls, {**data, "init_parameters": init_params})
        except Exception as e:
            raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e

    @property
    def initialized(self) -> bool:
        """
        Returns if the extractor is ready to annotate text.
        """
        return self._backend.initialized

    @classmethod
    def get_stored_annotations(cls, document: Document) -> Optional[List[NamedEntityAnnotation]]:
        """
        Returns the document's named entity annotations stored in its metadata, if any.

        :param document:
            Document whose annotations are to be fetched.
        :returns:
            The stored annotations, or `None` if the document has none.
        """

        return document.meta.get(cls._METADATA_KEY)
250

251

252
class _NerBackend(ABC):
    """
    Abstract interface that every NER backend implementation follows.
    """

    def __init__(
        self,
        _type: NamedEntityExtractorBackend,
        device: ComponentDevice,
        pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store the settings shared by all backends.

        :param _type:
            Which backend the subclass implements.
        :param device:
            Device the backend's model is meant to run on.
        :param pipeline_kwargs:
            Extra keyword arguments for the underlying pipeline;
            `None` is stored as an empty dict.
        """
        super().__init__()

        self._type = _type
        self._device = device
        # Normalise a missing mapping so subclasses can iterate it unconditionally.
        self._pipeline_kwargs = {} if pipeline_kwargs is None else pipeline_kwargs

    @abstractmethod
    def initialize(self):
        """
        Prepare the backend for annotation, typically by loading models and pipelines.
        """

    @abstractmethod
    def annotate(self, texts: List[str], *, batch_size: int = 1) -> List[List[NamedEntityAnnotation]]:
        """
        Predict annotations for a collection of documents.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            How many texts are handed to the model at once.
        :returns:
            One list of NER annotations per input text.
        """

    @property
    @abstractmethod
    def initialized(self) -> bool:
        """
        Whether the backend is ready to annotate text.
        """

    @property
    @abstractmethod
    def model_name(self) -> str:
        """
        The model name, or its path on the local disk.
        """

    @property
    def device(self) -> ComponentDevice:
        """
        The device on which the backend's model is loaded.

        :returns:
            The device on which the backend's model is loaded.
        """
        return self._device

    # Kept last so the `type` property never shadows the builtin elsewhere in this class body.
    @property
    def type(self) -> NamedEntityExtractorBackend:
        """
        The kind of backend this instance implements.
        """
        return self._type
319

320

321
class _HfBackend(_NerBackend):
    """
    Hugging Face backend for NER.
    """

    def __init__(
        self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Construct a Hugging Face NER backend.

        Nothing is loaded here; call `initialize()` to load the model.

        :param model_name_or_path:
            Name of the model or a path to the Hugging Face
            model on the local disk.
        :param device:
            The device on which the model is loaded.

            If a device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. The `task`, `model`,
            `tokenizer` and `aggregation_strategy` entries are set by this
            backend and cannot be overridden through these kwargs.
        """
        super().__init__(NamedEntityExtractorBackend.HUGGING_FACE, device, pipeline_kwargs)

        # Fails early if the optional `transformers` dependency is missing.
        transformers_import.check()

        self._model_name_or_path = model_name_or_path
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForTokenClassification] = None
        self.pipeline: Optional[HfPipeline] = None

    def initialize(self):
        """
        Load the tokenizer and model and build the `ner` pipeline.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path)
        self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path)

        pipeline_params = {
            "task": "ner",
            "model": self.model,
            "tokenizer": self.tokenizer,
            "aggregation_strategy": "simple",
        }
        # User kwargs may add options but never replace the four entries above.
        pipeline_params.update({k: v for k, v in self._pipeline_kwargs.items() if k not in pipeline_params})
        # overwrite=False: device settings already given in pipeline_kwargs win.
        self.device.update_hf_kwargs(pipeline_params, overwrite=False)
        self.pipeline = pipeline(**pipeline_params)

    def annotate(self, texts: List[str], *, batch_size: int = 1) -> List[List[NamedEntityAnnotation]]:
        """
        Run the pipeline over `texts` and convert its output to annotations.

        :raises ComponentError:
            If the backend has not been initialized.
        """
        if not self.initialized:
            raise ComponentError("Hugging Face NER backend was not initialized - Did you call `warm_up()`?")

        assert self.pipeline is not None  # narrowing for type checkers; guaranteed by `initialized`
        outputs = self.pipeline(texts, batch_size=batch_size)
        return [
            [
                NamedEntityAnnotation(
                    # The label may be under `entity` or `entity_group`
                    # (presumably depending on the aggregation strategy).
                    entity=annotation["entity"] if "entity" in annotation else annotation["entity_group"],
                    start=annotation["start"],
                    end=annotation["end"],
                    score=annotation["score"],
                )
                for annotation in annotations
            ]
            for annotations in outputs
        ]

    @property
    def initialized(self) -> bool:
        # `annotate` needs the pipeline itself. A loaded tokenizer/model alone
        # (e.g. when building the pipeline raised) must not count as initialized.
        return self.pipeline is not None

    @property
    def model_name(self) -> str:
        return self._model_name_or_path
394

395

396
class _SpacyBackend(_NerBackend):
    """
    spaCy backend for NER.
    """

    def __init__(
        self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Construct a spaCy NER backend.

        Nothing is loaded here; call `initialize()` to load the model.

        :param model_name_or_path:
            Name of the model or a path to the spaCy
            model on the local disk.
        :param device:
            The device on which the model is loaded. Must be a single device.
        :param pipeline_kwargs:
            Keyword arguments forwarded to the spaCy pipeline's `pipe()` call.
            `texts` and `batch_size` are dropped on initialization because the
            backend supplies them itself.
        :raises ValueError:
            If `device` spans multiple devices.
        """
        super().__init__(NamedEntityExtractorBackend.SPACY, device, pipeline_kwargs)

        # Fails early if the optional `spacy` dependency is missing.
        spacy_import.check()

        self._model_name_or_path = model_name_or_path
        self.pipeline: Optional[SpacyPipeline] = None

        if self.device.has_multiple_devices:
            raise ValueError("spaCy backend for named entity extractor only supports inference on single devices")

    def initialize(self):
        """
        Load the spaCy model, verify it contains an NER pipe and disable pipes NER does not need.

        :raises ComponentError:
            If the loaded pipeline has no `ner` pipe.
        """
        # We need to initialize the model on the GPU if needed.
        with self._select_device():
            nlp = spacy.load(self._model_name_or_path)

        if not nlp.has_pipe("ner"):
            raise ComponentError(f"spaCy pipeline '{self._model_name_or_path}' does not contain an NER component")

        # Disable unnecessary pipes.
        pipes_to_keep = ("ner", "tok2vec", "transformer", "curated_transformer")
        for name in nlp.pipe_names:
            if name not in pipes_to_keep:
                nlp.disable_pipe(name)

        # `annotate` passes these explicitly; duplicating them via **kwargs would raise TypeError.
        self._pipeline_kwargs = {k: v for k, v in self._pipeline_kwargs.items() if k not in ("texts", "batch_size")}

        # Publish the pipeline only after validation so a failed initialization
        # leaves `initialized` False instead of exposing an unusable pipeline.
        self.pipeline = nlp

    def annotate(self, texts: List[str], *, batch_size: int = 1) -> List[List[NamedEntityAnnotation]]:
        """
        Run the spaCy pipeline over `texts` and convert each doc's entities to annotations.

        spaCy provides no score, so `score` is left as `None`.

        :raises ComponentError:
            If the backend has not been initialized.
        """
        if not self.initialized:
            raise ComponentError("spaCy NER backend was not initialized - Did you call `warm_up()`?")

        assert self.pipeline is not None  # narrowing for type checkers; guaranteed by `initialized`
        with self._select_device():
            outputs = list(self.pipeline.pipe(texts=texts, batch_size=batch_size, **self._pipeline_kwargs))

        return [
            [
                NamedEntityAnnotation(entity=entity.label_, start=entity.start_char, end=entity.end_char)
                for entity in doc.ents
            ]
            for doc in outputs
        ]

    @property
    def initialized(self) -> bool:
        return self.pipeline is not None

    @property
    def model_name(self) -> str:
        return self._model_name_or_path

    @contextmanager
    def _select_device(self):
        """
        Context manager used to run spaCy models on a specific GPU in a scoped manner.

        Only a non-negative device id from `ComponentDevice.to_spacy()` switches to a GPU;
        otherwise the body runs without any device change.
        """

        # TODO: This won't restore the previously active device; on exit it
        # always falls back to CPU. Since there are no public API functions
        # to determine the active device in spaCy/Thinc, we can't do much
        # about it as a consumer unless we start poking into their internals.
        device_id = self._device.to_spacy()
        try:
            if device_id >= 0:
                spacy.require_gpu(device_id)
            yield
        finally:
            if device_id >= 0:
                spacy.require_cpu()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc