
deepset-ai / haystack / build 9568249476 (push, via github / web-flow)

18 Jun 2024 03:52PM UTC coverage: 89.872% (-0.1%) from 89.995%
ci: Add code formatting checks (#7882)

* ruff settings

enable ruff format and re-format outdated files

feat: `EvaluationRunResult` add parameter to specify columns to keep in the comparative `DataFrame` (#7879)

* adding param to explicitly state which cols to keep

* updating tests

* adding release notes

* Update haystack/evaluation/eval_run_result.py

* Update releasenotes/notes/add-keep-columns-to-EvalRunResult-comparative-be3e15ce45de3e0b.yaml

* updating docstring

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

add format-check

fail on format and linting failures

fix string formatting

reformat long lines

fix tests

fix typing

linter

pull from main

* reformat

* lint -> check

6957 of 7741 relevant lines covered (89.87%)

0.9 hits per line
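As a quick check of the headline figure: 6957 / 7741 ≈ 0.89872, which is the reported 89.872% (shown rounded to 89.87% here).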

Source File: haystack/components/generators/hugging_face_api.py (96.1% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import asdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        InferenceClient,
        TextGenerationOutput,
        TextGenerationOutputToken,
        TextGenerationStreamOutput,
    )


logger = logging.getLogger(__name__)

@component
class HuggingFaceAPIGenerator:
    """
    A Generator component that uses Hugging Face APIs to generate text.

    This component can be used to generate text using different Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    Example usage with the free Serverless Inference API:
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    Example usage with paid Inference Endpoints:
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="inference_endpoints",
                                        api_params={"url": "<your-inference-endpoint-url>"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    Example usage with self-hosted Text Generation Inference:
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"})

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```
    """

    def __init__(
        self,
        api_type: Union[HFGenerationAPIType, str],
        api_params: Dict[str, str],
        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Initialize the HuggingFaceAPIGenerator instance.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary containing the following keys:
            - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
              `TEXT_GENERATION_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary containing keyword arguments to customize text generation. Some examples: `max_new_tokens`,
            `temperature`, `top_k`, `top_p`, ...
            See Hugging Face's [documentation](https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation)
            for more information.
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            # this branch is not exercised by the tests in this report
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # handle generation kwargs setup: merge user kwargs with stop words and a default token budget
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])
        generation_kwargs.setdefault("max_new_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

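To make the generation-kwargs setup above concrete, here is a minimal sketch (an editor's example, not part of the file; the stop word and temperature are illustrative, and the serverless path validates the model against the Hub, so it needs network access and an `HF_API_TOKEN` in the environment):

```python
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    stop_words=["###"],
    generation_kwargs={"temperature": 0.7},
)

# Per the constructor logic above, stop_words are folded into
# `stop_sequences` and `max_new_tokens` falls back to 512.
print(generator.generation_kwargs)
# {'temperature': 0.7, 'stop_sequences': ['###'], 'max_new_tokens': 512}
```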
    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIGenerator":
        """
        Deserialize this component from a dictionary.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data["init_parameters"]
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

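A minimal serialization round-trip sketch (an editor's example; it relies on the default env-var token, since a `Secret` built with `Secret.from_token` generally cannot be serialized):

```python
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
)

data = generator.to_dict()
# data["init_parameters"] now holds the api_type string, api_params,
# the env-var token descriptor, and the merged generation_kwargs.

restored = HuggingFaceAPIGenerator.from_dict(data)
assert restored.api_params == generator.api_params
```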
    @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt:
            A string representing the prompt.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :returns:
            A dictionary containing the generated replies and metadata. Both are lists of length n.
            - replies: A list of strings representing the generated replies.
            - meta: A list of dictionaries with the metadata for each reply.
        """
        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        if self.streaming_callback:
            return self._run_streaming(prompt, generation_kwargs)

        return self._run_non_streaming(prompt, generation_kwargs)

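Because `run` merges the per-call kwargs over the defaults stored at init time, callers can override individual settings for a single invocation. A sketch (an editor's example with illustrative values):

```python
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
)

# The per-call kwargs are merged over self.generation_kwargs, so this call
# caps the reply at 64 new tokens without changing the stored defaults.
result = generator.run(
    prompt="What's Natural Language Processing?",
    generation_kwargs={"max_new_tokens": 64},
)
print(result["replies"][0])
```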
    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
            prompt, details=True, stream=True, **generation_kwargs
        )
        chunks: List[StreamingChunk] = []
        # pylint: disable=not-an-iterable
        for chunk in res_chunk:
            token: TextGenerationOutputToken = chunk.token
            if token.special:
                continue  # not covered by tests in this report
            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            stream_chunk = StreamingChunk(token.text, chunk_metadata)
            chunks.append(stream_chunk)
            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
        metadata = {
            "finish_reason": chunks[-1].meta.get("finish_reason", None),
            "model": self._client.model,
            "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)},
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

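Streaming in the method above is driven entirely by `streaming_callback`: each non-special token becomes a `StreamingChunk` whose `.content` holds the token text. A minimal sketch (an editor's example) that prints tokens as they arrive:

```python
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    # Called once per StreamingChunk inside _run_streaming above.
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)

result = generator.run(prompt="What's Natural Language Processing?")
```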
    def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
        tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
        meta = [
            {
                "model": self._client.model,
                "finish_reason": tgr.details.finish_reason if tgr.details else None,
                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
            }
        ]
        return {"replies": [tgr.generated_text], "meta": meta}
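Both `_run_streaming` and `_run_non_streaming` return the same shape: a list of replies plus one metadata dict per reply. An illustrative sketch (all values are made up):

```python
# Illustrative (made-up) return value; both code paths produce this shape.
example_result = {
    "replies": ["Natural Language Processing (NLP) is a branch of AI that ..."],
    "meta": [
        {
            "model": "HuggingFaceH4/zephyr-7b-beta",
            "finish_reason": "length",
            "usage": {"completion_tokens": 512},
        }
    ],
}
```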