deepset-ai / haystack / build 21203210931 (push, github, web-flow)
21 Jan 2026 08:55AM UTC coverage: 92.211% (-0.007%) from 92.218%

chore: Simplify `to_dict`, `from_dict` with default (de-)serialization of Secrets (#10411)

* drop to_dict and deserialize_secrets_inplace

* add release note

---------

Co-authored-by: David S. Batista <dsbatista@gmail.com>
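
As context for this change, `default_to_dict` and `default_from_dict` now take care of `Secret` fields directly, so components can drop their hand-written Secret handling; the file below simply passes `token=self.token` to `default_to_dict` and calls plain `default_from_dict`. A minimal before/after sketch of a component with one Secret field (the component itself is hypothetical, and the "before" shape is an assumption based on the older `deserialize_secrets_inplace` pattern):

```python
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret


@component
class MyTokenComponent:
    """Hypothetical component with a single Secret field, for illustration only."""

    def __init__(self, token: Secret | None = Secret.from_env_var("MY_API_TOKEN", strict=False)):
        self.token = token

    @component.output_types(ready=bool)
    def run(self) -> dict[str, Any]:
        return {"ready": self.token is not None}

    # Before (assumed older pattern): to_dict called self.token.to_dict() by hand,
    # and from_dict ran deserialize_secrets_inplace(data["init_parameters"], keys=["token"]).

    # After: the default (de-)serialization handles the Secret.
    def to_dict(self) -> dict[str, Any]:
        return default_to_dict(self, token=self.token)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MyTokenComponent":
        return default_from_dict(cls, data)
```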

14396 of 15612 relevant lines covered (92.21%)

0.92 hits per line

Source file: haystack/components/classifiers/zero_shot_document_classifier.py (91.38% covered)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import Pipeline as HfPipeline
    from transformers import pipeline


@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
        - `valhalla/distilbart-mnli-12-3`
        - `cross-encoder/nli-distilroberta-base`
        - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str,
        labels: list[str],
        multi_label: bool = False,
        classification_field: str | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: dict[str, Any] | None = None,
    ):
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of the document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text classification.
        """

        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.pipeline: HfPipeline | None = None

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds the predicted labels to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
        the `details` key within the `classification` dictionary.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        """

        if self.pipeline is None:
            self.warm_up()

        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )

        invalid_doc_ids = []

        for doc in documents:
            if self.classification_field is not None and self.classification_field not in doc.meta:
                invalid_doc_ids.append(doc.id)

        if invalid_doc_ids:
            raise ValueError(
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )

        texts = [
            (doc.content if self.classification_field is None else doc.meta[self.classification_field])
            for doc in documents
        ]

        # mypy doesn't know this is set in warm_up
        predictions = self.pipeline(  # type: ignore[misc]
            texts, self.labels, multi_label=self.multi_label, batch_size=batch_size
        )

        new_documents = []
        for prediction, document in zip(predictions, documents):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }
            new_meta = {**document.meta, "classification": formatted_prediction}
            new_documents.append(replace(document, meta=new_meta))

        return {"documents": new_documents}
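
For reference, a minimal standalone sketch of how the component in this file is used and how its output is read (the model name, labels, and document text are illustrative; the output shape follows the `run` docstring above):

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",
    labels=["positive", "negative"],
)
classifier.warm_up()  # loads the Hugging Face zero-shot classification pipeline

result = classifier.run(documents=[Document(content="Today was a nice day!")])

classification = result["documents"][0].meta["classification"]
print(classification["label"])    # best matching label, for example "positive"
print(classification["score"])    # score of the best matching label
print(classification["details"])  # score for every candidate label

# Serialization round trip: the token Secret and the Hugging Face model kwargs
# are handled by default_to_dict / default_from_dict plus the HF helpers used above.
restored = TransformersZeroShotDocumentClassifier.from_dict(classifier.to_dict())
```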