• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13972131258

20 Mar 2025 02:43PM UTC coverage: 90.021% (-0.03%) from 90.054%
13972131258

Pull #9069

github

web-flow
Merge 8371761b0 into 67ab3788e
Pull Request #9069: refactor!: `ChatMessage` serialization-deserialization updates

9833 of 10923 relevant lines covered (90.02%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.91
haystack/components/classifiers/zero_shot_document_classifier.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Dict, List, Optional
1✔
6

7
from haystack import Document, component, default_from_dict, default_to_dict
1✔
8
from haystack.lazy_imports import LazyImport
1✔
9
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
10
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
1✔
11

12
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
1✔
13
    from transformers import pipeline
1✔
14

15

16
@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
        - `valhalla/distilbart-mnli-12-3`
        - `cross-encoder/nli-distilroberta-base`
        - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str,
        labels: List[str],
        multi_label: bool = False,
        classification_field: Optional[str] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for text classification.
        """

        # Fail early with an installation hint if the optional dependencies are missing.
        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        # Merge the explicit arguments (model, device, token) into a single kwargs dict for the HF pipeline.
        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        # The pipeline is created lazily in `warm_up()`.
        self.pipeline = None

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.

        :returns:
            A dictionary with the model name, or a placeholder naming the model's type when the
            configured model is not a string.
        """
        model = self.huggingface_pipeline_kwargs["model"]
        if isinstance(model, str):
            return {"model": model}
        return {"model": f"[object of type {type(model)}]"}

    def warm_up(self):
        """
        Initializes the component.

        Creates the Hugging Face pipeline from the resolved keyword arguments. Calling this method
        again after the pipeline exists is a no-op.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def _copy_nested_dicts(self, value: Any) -> Any:
        """
        Returns `value` with every dict (at any nesting depth) replaced by a fresh dict.

        Non-dict values are shared with the input, not copied, so large objects are never duplicated.
        """
        if isinstance(value, dict):
            return {key: self._copy_nested_dicts(item) for key, item in value.items()}
        return value

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Work on a copy of the pipeline kwargs: the token is popped and the kwargs are serialized in
        # place below, and in case `default_to_dict` stores the dict by reference that would otherwise
        # strip the token from (and alter) the kwargs `warm_up()` uses to build the pipeline.
        pipeline_kwargs_copy = self._copy_nested_dicts(self.huggingface_pipeline_kwargs)

        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            # `multi_label` and `classification_field` are init parameters that change the behavior of
            # `run()`; they must be serialized so a `to_dict` -> `from_dict` round trip preserves them.
            multi_label=self.multi_label,
            classification_field=self.classification_field,
            huggingface_pipeline_kwargs=pipeline_kwargs_copy,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # The token is serialized separately through the `token` init parameter.
        huggingface_pipeline_kwargs.pop("token", None)

        # Return value is ignored, so this presumably converts the kwargs in place.
        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its `init_parameters` entry is modified in place.
        :returns:
            Deserialized component.
        """
        init_parameters = data["init_parameters"]
        deserialize_secrets_inplace(init_parameters, keys=["token"])
        if init_parameters.get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(init_parameters["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds them to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata. The dict holds the predicted `label`, its `score`, and a `details`
        dict mapping every label to its score. The input documents are updated in place.

        :param documents:
            Documents to process. An empty list is returned unchanged without calling the model.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        :raises RuntimeError:
            If the component was not warmed up.
        :raises TypeError:
            If `documents` is not a list, or its first element is not a `Document`.
        :raises ValueError:
            If `classification_field` is set and missing from the metadata of any document.
        """

        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformersZeroShotDocumentClassifier wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        # Only the first element is inspected; this is a cheap guard against passing plain strings.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )

        # Nothing to classify: skip the model call, which rejects empty input.
        if not documents:
            return {"documents": documents}

        if self.classification_field is None:
            texts = [doc.content for doc in documents]
        else:
            # Report every document lacking the field at once instead of failing on the first one.
            invalid_doc_ids = [doc.id for doc in documents if self.classification_field not in doc.meta]
            if invalid_doc_ids:
                raise ValueError(
                    f"The following documents do not have the classification field '{self.classification_field}': "
                    f"{', '.join(invalid_doc_ids)}"
                )
            texts = [doc.meta[self.classification_field] for doc in documents]

        predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

        for prediction, document in zip(predictions, documents):
            # The first label/score pair is taken as the prediction
            # (assumes the pipeline returns labels ordered by descending score).
            document.meta["classification"] = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }

        return {"documents": documents}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc