• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 16933015230

13 Aug 2025 09:18AM UTC coverage: 92.184% (+0.2%) from 91.969%
16933015230

Pull #9699

github

web-flow
Merge cfbd602e7 into 8160ea8bf
Pull Request #9699: feat: Update `source_id_meta_field` in `SentenceWindowRetriever` to also accept a list of values

12891 of 13984 relevant lines covered (92.18%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.53
haystack/components/classifiers/zero_shot_document_classifier.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from dataclasses import replace
1✔
6
from typing import Any, Optional
1✔
7

8
from haystack import Document, component, default_from_dict, default_to_dict
1✔
9
from haystack.lazy_imports import LazyImport
1✔
10
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
11
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
1✔
12

13
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
1✔
14
    from transformers import Pipeline as HfPipeline
1✔
15
    from transformers import pipeline
1✔
16

17

18
@component
1✔
19
class TransformersZeroShotDocumentClassifier:
1✔
20
    """
21
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.
22

23
    The component uses a Hugging Face pipeline for zero-shot classification.
24
    Provide the model and the set of labels to be used for categorization during initialization.
25
    Additionally, you can configure the component to allow multiple labels to be true.
26

27
    Classification is run on the document's content field by default. If you want it to run on another field, set the
28
    `classification_field` to one of the document's metadata fields.
29

30
    Available models for the task of zero-shot-classification include:
31
        - `valhalla/distilbart-mnli-12-3`
32
        - `cross-encoder/nli-distilroberta-base`
33
        - `cross-encoder/nli-deberta-v3-xsmall`
34

35
    ### Usage example
36

37
    The following pipeline retrieves documents with a BM25 retriever and then classifies them
38
    using predefined classification labels:
39

40
    ```python
41
    from haystack import Document
42
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
43
    from haystack.document_stores.in_memory import InMemoryDocumentStore
44
    from haystack.core.pipeline import Pipeline
45
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
46

47
    documents = [Document(id="0", content="Today was a nice day!"),
48
                 Document(id="1", content="Yesterday was a bad day!")]
49

50
    document_store = InMemoryDocumentStore()
51
    retriever = InMemoryBM25Retriever(document_store=document_store)
52
    document_classifier = TransformersZeroShotDocumentClassifier(
53
        model="cross-encoder/nli-deberta-v3-xsmall",
54
        labels=["positive", "negative"],
55
    )
56

57
    document_store.write_documents(documents)
58

59
    pipeline = Pipeline()
60
    pipeline.add_component(instance=retriever, name="retriever")
61
    pipeline.add_component(instance=document_classifier, name="document_classifier")
62
    pipeline.connect("retriever", "document_classifier")
63

64
    queries = ["How was your day today?", "How was your day yesterday?"]
65
    expected_predictions = ["positive", "negative"]
66

67
    for idx, query in enumerate(queries):
68
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
69
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
70
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
71
                == expected_predictions[idx])
72
    ```
73
    """
74

75
    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str,
        labels: list[str],
        multi_label: bool = False,
        classification_field: Optional[str] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Creates a TransformersZeroShotDocumentClassifier.

        The Hugging Face pipeline itself is not built here; only its keyword arguments are resolved and stored.
        Call `warm_up()` to build the pipeline.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of the document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text classification.
        """

        # Fail early when the optional torch/transformers dependencies are missing.
        torch_and_transformers_import.check()

        # Plain configuration, stored as given.
        self.labels = labels
        self.multi_label = multi_label
        self.classification_field = classification_field
        self.token = token

        # Merge model, task, device and token into one kwargs dict for the Hugging Face `pipeline` factory.
        self.huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        # Populated by `warm_up()`.
        self.pipeline: Optional[HfPipeline] = None
135

136
    def _get_telemetry_data(self) -> dict[str, Any]:
1✔
137
        """
138
        Data that is sent to Posthog for usage analytics.
139
        """
140
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
×
141
            return {"model": self.huggingface_pipeline_kwargs["model"]}
×
142
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
×
143

144
    def warm_up(self):
1✔
145
        """
146
        Initializes the component.
147
        """
148
        if self.pipeline is None:
1✔
149
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
1✔
150

151
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        All constructor settings are included so that `from_dict(to_dict())` rebuilds an equivalent component:
        `labels`, `model`, `multi_label`, `classification_field`, `huggingface_pipeline_kwargs` and `token`.
        The resolved token value is removed from the serialized pipeline kwargs; the token is only serialized
        through the `token` secret.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            # Without these two, a deserialized component silently falls back to the constructor defaults.
            multi_label=self.multi_label,
            classification_field=self.classification_field,
            # Shallow copy: the `pop` and the in-place serialization below must not alter the kwargs that
            # `warm_up()` uses to build the pipeline.
            # NOTE(review): nested values (for example a `model_kwargs` dict) are still shared — confirm that
            # `serialize_hf_model_kwargs` only rewrites top-level entries.
            huggingface_pipeline_kwargs=dict(self.huggingface_pipeline_kwargs),
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # Never write the resolved token into the serialized kwargs.
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict
171

172
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        The `token` secret and, when present, the Hugging Face pipeline kwargs are deserialized in place inside
        `data["init_parameters"]` before the component is constructed.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_parameters = data["init_parameters"]
        deserialize_secrets_inplace(init_parameters, keys=["token"])

        pipeline_kwargs = init_parameters.get("huggingface_pipeline_kwargs")
        if pipeline_kwargs is not None:
            deserialize_hf_model_kwargs(pipeline_kwargs)

        return default_from_dict(cls, data)
186

187
    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds the result to their metadata.

        The classification results are stored in the `classification` dict within each returned document's
        metadata. It contains the `label` and `score` of the first prediction returned by the pipeline
        (presumably the highest-scoring one — confirm against the Hugging Face pipeline output ordering)
        and, under the `details` key, the score of every label.

        The input documents are not modified; copies with the extended metadata are returned.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        :raises RuntimeError:
            If the component was not warmed up.
        :raises TypeError:
            If `documents` is not a list, or its first element is not a `Document`.
        :raises ValueError:
            If `classification_field` is set and missing from the metadata of at least one document.
        """

        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformersZeroShotDocumentClassifier wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        # Only the first element is inspected; explicit parentheses make the and/or precedence obvious.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )

        # Nothing to classify: skip the model call entirely.
        if not documents:
            return {"documents": []}

        field = self.classification_field
        if field is None:
            texts = [doc.content for doc in documents]
        else:
            # Report every document that lacks the field at once rather than failing on the first one.
            invalid_doc_ids = [doc.id for doc in documents if field not in doc.meta]
            if invalid_doc_ids:
                raise ValueError(
                    f"The following documents do not have the classification field '{field}': "
                    f"{', '.join(invalid_doc_ids)}"
                )
            texts = [doc.meta[field] for doc in documents]

        predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

        new_documents = []
        for prediction, document in zip(predictions, documents):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }
            # Build a new meta dict and a new Document so the caller's objects stay untouched.
            new_meta = {**document.meta, "classification": formatted_prediction}
            new_documents.append(replace(document, meta=new_meta))

        return {"documents": new_documents}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc