deepset-ai / haystack / 15269350331

27 May 2025 07:35AM UTC coverage: 90.181% (+0.006%) from 90.175%
push · github · web-flow

chore: make mypy run with `--check-untyped-defs`; fix some errors (#9447)

* chore: make mypy run with --check-untyped-defs; fix some errors
* small fixes
* use HfPipeline
* fix license error

11388 of 12628 relevant lines covered (90.18%)

0.9 hits per line

Source File

haystack/components/classifiers/zero_shot_document_classifier.py (91.07% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import Pipeline as HfPipeline
    from transformers import pipeline


@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
        - `valhalla/distilbart-mnli-12-3`
        - `cross-encoder/nli-distilroberta-base`
        - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str,
        labels: List[str],
        multi_label: bool = False,
        classification_field: Optional[str] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text classification.
        """

        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.pipeline: Optional[HfPipeline] = None

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds them to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
        the `details` key within the `classification` dictionary.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        """

        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformerZeroShotDocumentClassifier wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "TransformerZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )

        invalid_doc_ids = []

        for doc in documents:
            if self.classification_field is not None and self.classification_field not in doc.meta:
                invalid_doc_ids.append(doc.id)

        if invalid_doc_ids:
            raise ValueError(
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )

        texts = [
            (doc.content if self.classification_field is None else doc.meta[self.classification_field])
            for doc in documents
        ]

        predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

        for prediction, document in zip(predictions, documents):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }
            document.meta["classification"] = formatted_prediction

        return {"documents": documents}
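
For reference, a minimal sketch of using the component on its own, without a retrieval pipeline, based on the docstrings above: construct the classifier, warm it up, run it on a list of `Document` objects, and read the prediction from each document's `classification` metadata. The model and labels here are illustrative choices, not requirements of the component.

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

# Illustrative model and labels; any NLI zero-shot-classification model works.
classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
)

# Load the underlying Hugging Face pipeline before calling run().
classifier.warm_up()

docs = [Document(content="Today was a nice day!")]
result = classifier.run(documents=docs)

# Each returned document carries the prediction in its metadata.
classification = result["documents"][0].meta["classification"]
print(classification["label"], classification["score"])

# Serialization round trip via to_dict()/from_dict().
restored = TransformersZeroShotDocumentClassifier.from_dict(classifier.to_dict())
```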