• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 17399860316

02 Sep 2025 09:52AM UTC coverage: 92.078% (+0.001%) from 92.077%
17399860316

Pull #9740

github

web-flow
Merge 0e44518df into 0fe2f8e45
Pull Request #9740: chore: Bump transformers

12937 of 14050 relevant lines covered (92.08%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.38
haystack/components/routers/transformers_text_router.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Optional
1✔
6

7
from haystack import component, default_from_dict, default_to_dict
1✔
8
from haystack.lazy_imports import LazyImport
1✔
9
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
10

11
# Defer the heavy torch/transformers imports until they are actually needed.
# LazyImport records any ImportError; calling .check() later (in __init__)
# re-raises it with the pip-install hint below if the packages are missing.
with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import AutoConfig, Pipeline, pipeline

    # These helpers also require transformers, so they live inside the lazy block.
    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        deserialize_hf_model_kwargs,
        resolve_hf_pipeline_kwargs,
        serialize_hf_model_kwargs,
    )
20

21
@component
class TransformersTextRouter:
    """
    Routes the text strings to different connections based on a category label.

    The labels are specific to each model and can be found in its description on Hugging Face.

    ### Usage example

    ```python
    from haystack.core.pipeline import Pipeline
    from haystack.components.routers import TransformersTextRouter
    from haystack.components.builders import PromptBuilder
    from haystack.components.generators import HuggingFaceLocalGenerator

    p = Pipeline()
    p.add_component(
        instance=TransformersTextRouter(model="papluca/xlm-roberta-base-language-detection"),
        name="text_router"
    )
    p.add_component(
        instance=PromptBuilder(template="Answer the question: {{query}}\\nAnswer:"),
        name="english_prompt_builder"
    )
    p.add_component(
        instance=PromptBuilder(template="Beantworte die Frage: {{query}}\\nAntwort:"),
        name="german_prompt_builder"
    )

    p.add_component(
        instance=HuggingFaceLocalGenerator(model="DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1"),
        name="german_llm"
    )
    p.add_component(
        instance=HuggingFaceLocalGenerator(model="microsoft/Phi-3-mini-4k-instruct"),
        name="english_llm"
    )

    p.connect("text_router.en", "english_prompt_builder.query")
    p.connect("text_router.de", "german_prompt_builder.query")
    p.connect("english_prompt_builder.prompt", "english_llm.prompt")
    p.connect("german_prompt_builder.prompt", "german_llm.prompt")

    # English Example
    print(p.run({"text_router": {"text": "What is the capital of Germany?"}}))

    # German Example
    print(p.run({"text_router": {"text": "Was ist die Hauptstadt von Deutschland?"}}))
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str,
        labels: Optional[list[str]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Initializes the TransformersTextRouter component.

        :param model: The name or path of a Hugging Face model for text classification.
        :param labels: The list of labels. If not provided, the component fetches the labels
            from the model configuration file hosted on the Hugging Face Hub using
            `transformers.AutoConfig.from_pretrained`.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The API token used to download private models from Hugging Face.
            If `True`, uses either `HF_API_TOKEN` or `HF_TOKEN` environment variables.
            To generate these tokens, run `transformers-cli login`.
        :param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face
            text classification pipeline.
        """
        # Raises a helpful ImportError (with install hint) if torch/transformers are missing.
        torch_and_transformers_import.check()

        self.token = token

        # Normalize user-supplied pipeline kwargs and merge in model/task/device/token.
        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="text-classification",
            supported_tasks=["text-classification"],
            device=device,
            token=token,
        )
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs

        if labels is None:
            # No labels given: read them from the model config on the Hub.
            config = AutoConfig.from_pretrained(
                huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
            )
            self.labels = list(config.label2id.keys())
        else:
            self.labels = labels
        # One output connection (of type str) per label, created dynamically.
        component.set_output_types(self, **dict.fromkeys(self.labels, str))

        # The HF pipeline is created lazily in warm_up(), not here.
        self.pipeline: Optional["Pipeline"] = None

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.

        Loads the Hugging Face pipeline and validates that the configured labels
        match the labels declared in the model configuration.

        :raises ValueError:
            If the provided labels do not match the model's `label2id` mapping.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

        # Verify labels from the model configuration file match provided labels
        label2id = self.pipeline.model.config.label2id
        if label2id is not None:
            labels = set(label2id.keys())
            if set(self.labels) != labels:
                raise ValueError(
                    f"The provided labels do not match the labels in the model configuration file. "
                    f"Provided labels: {self.labels}. Model labels: {labels}"
                )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        # The token is serialized separately (as a Secret) above; drop the raw
        # value from the pipeline kwargs so it never lands in the serialized dict.
        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersTextRouter":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    def run(self, text: str) -> dict[str, str]:
        """
        Routes the text strings to different connections based on a category label.

        :param text: A string of text to route.
        :returns:
            A dictionary with the label as key and the text as value.

        :raises TypeError:
            If the input is not a str.
        :raises RuntimeError:
            If the pipeline has not been loaded because warm_up() was not called before.
        """
        if self.pipeline is None:
            # Fixed: the message previously referenced a nonexistent
            # "TextTransformersRouter" class, misleading users about which
            # component needed warming up.
            raise RuntimeError(
                "The component TransformersTextRouter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
            )

        if not isinstance(text, str):
            raise TypeError("TransformersTextRouter expects a str as input.")

        # NOTE(review): `return_all_scores` is deprecated in recent transformers
        # releases in favor of `top_k` — confirm against the pinned transformers
        # version before changing it here.
        prediction = self.pipeline([text], return_all_scores=False, function_to_apply="none")
        label = prediction[0]["label"]
        return {label: text}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc