deepset-ai / haystack / build 10920861378

18 Sep 2024 11:05AM UTC · coverage: 90.391% (+0.003%) from 90.388%

push · github · web-flow

fix: Prevent `set_output_types` from being called when the `output_types` decorator is used (#8376)

7337 of 8117 relevant lines covered (90.39%)

0.9 hits per line
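(For reference, the headline coverage is simply covered ÷ relevant lines: 7337 / 8117 ≈ 0.90391, i.e. 90.391%, shown rounded as 90.39% above.)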

Source File

haystack/components/classifiers/zero_shot_document_classifier.py (91.07% of lines covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import pipeline


@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
        - `valhalla/distilbart-mnli-12-3`
        - `cross-encoder/nli-distilroberta-base`
        - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(
        self,
        model: str,
        labels: List[str],
        multi_label: bool = False,
        classification_field: Optional[str] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of the document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for zero-shot classification.
        """

        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        # Merge the user-provided kwargs with the model, task, device, and token so the
        # resulting dictionary can be passed directly to the Hugging Face pipeline factory.
        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.pipeline = None

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # Drop the token from the serialized pipeline kwargs so the secret never ends up in the output.
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document], batch_size: int = 1):
        """
        Classifies the documents based on the provided labels and adds the predictions to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
        the `details` key within the `classification` dictionary.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        """

        if self.pipeline is None:
            raise RuntimeError(
                "The component TransformersZeroShotDocumentClassifier wasn't warmed up. "
                "Run 'warm_up()' before calling 'run()'."
            )

        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify a text, please use the TransformersZeroShotTextRouter."
            )

        invalid_doc_ids = []

        for doc in documents:
            if self.classification_field is not None and self.classification_field not in doc.meta:
                invalid_doc_ids.append(doc.id)

        if invalid_doc_ids:
            raise ValueError(
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )

        texts = [
            (doc.content if self.classification_field is None else doc.meta[self.classification_field])
            for doc in documents
        ]

        predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

        # The zero-shot pipeline returns labels sorted by score in descending order,
        # so the first label/score pair is the top prediction.
        for prediction, document in zip(predictions, documents):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"])),
            }
            document.meta["classification"] = formatted_prediction

        return {"documents": documents}
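
To complement the pipeline-based docstring example, here is a minimal sketch of running the component standalone. It relies only on the API shown above (`warm_up`, `run`, and the `classification` metadata field); the model and labels are the illustrative values from the docstring:

```python
from haystack import Document
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

classifier = TransformersZeroShotDocumentClassifier(
    model="cross-encoder/nli-deberta-v3-xsmall",  # any zero-shot NLI model should work
    labels=["positive", "negative"],
)
classifier.warm_up()  # loads the HF pipeline; run() raises RuntimeError otherwise

docs = [Document(content="Today was a nice day!")]
result = classifier.run(documents=docs)

# run() writes the top label, its score, and the per-label scores into doc.meta.
classification = result["documents"][0].meta["classification"]
print(classification["label"])    # e.g. "positive"
print(classification["score"])    # score of the top label
print(classification["details"])  # {label: score} for every candidate label
```

A serialization round-trip, likewise a sketch: `to_dict` pops the token out of the serialized pipeline kwargs, and `from_dict` rebuilds the `Secret` from the environment variables it references.

```python
data = classifier.to_dict()  # secrets are not embedded in the output
restored = TransformersZeroShotDocumentClassifier.from_dict(data)
restored.warm_up()  # ready to classify again
```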