deepset-ai / haystack · build 15276913207

27 May 2025 01:41PM UTC · coverage: 90.41% (+0.02%) from 90.388%

Pull Request #9449: refactor: Refactor hf api chat generator
Merge f55ee47c3 into 3deaa20cb (github, web-flow)

11464 of 12680 relevant lines covered (90.41%) · 0.9 hits per line

Source File: haystack/components/generators/chat/hugging_face_local.py (80.19% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
    Tool,
    Toolset,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    serialize_tools_or_toolset,
)
from haystack.utils import (
    ComponentDevice,
    Secret,
    deserialize_callable,
    deserialize_secrets_inplace,
    serialize_callable,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as torch_and_transformers_import:
    from huggingface_hub import model_info
    from transformers import Pipeline as HfPipeline
    from transformers import StoppingCriteriaList, pipeline
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
        HFTokenStreamingHandler,
        StopWordsCriteria,
        convert_message_to_hf_format,
        deserialize_hf_model_kwargs,
        serialize_hf_model_kwargs,
    )


PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
    r"(?:<tool_call>)?"
    r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
    r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)


def default_tool_parser(text: str) -> Optional[List[ToolCall]]:
    """
    Default implementation for parsing tool calls from model output text.

    Uses DEFAULT_TOOL_PATTERN to extract tool calls.

    :param text: The text to parse for tool calls.
    :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
    """
    try:
        match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
    except re.error:
        logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
        return None

    if not match:
        return None

    name = match.group(1) or match.group(3)
    args_str = match.group(2) or match.group(4)

    try:
        arguments = json.loads(args_str)
        return [ToolCall(tool_name=name, arguments=arguments)]
    except json.JSONDecodeError:
        logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
        return None
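
# For example, given model output like
#   '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}'
# (a hypothetical tool name, used here only for illustration), default_tool_parser
# returns [ToolCall(tool_name="get_weather", arguments={"city": "Berlin"})].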


@component
class HuggingFaceLocalChatGenerator:
    """
    Generates chat responses using models from Hugging Face that run locally.

    Use this component with chat-based models,
    such as `HuggingFaceH4/zephyr-7b-beta` or `meta-llama/Llama-2-7b-chat-hf`.
    LLMs running locally may need powerful hardware.

    ### Usage example

    ```python
    from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]
    print(generator.run(messages))
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a subfield of artificial intelligence that deals
        with the interaction between computers and human language. It enables computers to understand, interpret, and
        generate human language in a valuable way. NLP involves various techniques such as speech recognition, text
        analysis, sentiment analysis, and machine translation. The ultimate goal is to make it easier for computers to
        process and derive meaning from human language, improving communication between humans and machines.")],
        _name=None,
        _meta={'finish_reason': 'stop', 'index': 0, 'model':
              'mistralai/Mistral-7B-Instruct-v0.2',
              'usage': {'completion_tokens': 90, 'prompt_tokens': 19, 'total_tokens': 109}})
              ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        model: str = "HuggingFaceH4/zephyr-7b-beta",
        task: Optional[Literal["text-generation", "text2text-generation"]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        chat_template: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
        stop_words: Optional[List[str]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
        async_executor: Optional[ThreadPoolExecutor] = None,
    ):
        """
        Initializes the HuggingFaceLocalChatGenerator component.

        :param model: The Hugging Face text generation model name or path,
            for example, `mistralai/Mistral-7B-Instruct-v0.2` or `TheBloke/OpenHermes-2.5-Mistral-7B-16k-AWQ`.
            The model must be a chat model supporting the ChatML messaging format.
            If the model is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param task: The task for the Hugging Face pipeline. Possible options:
            - `text-generation`: Supported by decoder models, like GPT.
            - `text2text-generation`: Supported by encoder-decoder models, like T5.
            If the task is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
            If not specified, the component calls the Hugging Face API to infer the task from the model name.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The token to use as HTTP bearer authorization for remote files.
            If the token is specified in `huggingface_pipeline_kwargs`, this parameter is ignored.
        :param chat_template: Specifies an optional Jinja template for formatting chat
            messages. Most high-quality chat models have their own templates, but for models without this
            feature or if you prefer a custom template, use this parameter.
        :param generation_kwargs: A dictionary with keyword arguments to customize text generation.
            Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`.
            See Hugging Face's documentation for more information:
            - [customize-text-generation](https://huggingface.co/docs/transformers/main/en/generation_strategies#customize-text-generation)
            - [GenerationConfig](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig)
            The only `generation_kwargs` set by default is `max_new_tokens`, which is set to 512 tokens.
        :param huggingface_pipeline_kwargs: Dictionary with keyword arguments to initialize the
            Hugging Face pipeline for text generation.
            These keyword arguments provide fine-grained control over the Hugging Face pipeline.
            In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
            For kwargs, see [Hugging Face documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task).
            In this dictionary, you can also include `model_kwargs` to specify the kwargs for
            [model initialization](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
        :param stop_words: A list of stop words. If the model generates a stop word, the generation stops.
            If you provide this parameter, don't specify the `stopping_criteria` in `generation_kwargs`.
            For some chat models, the output includes both the new text and the original prompt.
            In these cases, make sure your prompt has no stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools or a Toolset for which the model can prepare calls.
            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        :param tool_parsing_function:
            A callable that takes a string and returns a list of ToolCall objects or None.
            If None, the default_tool_parser will be used, which extracts tool calls using a predefined pattern.
        :param async_executor:
            Optional ThreadPoolExecutor to use for async calls. If not provided, a single-threaded executor will be
            initialized and used.
        """
        torch_and_transformers_import.check()

        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
        generation_kwargs = generation_kwargs or {}

        self.token = token
        token = token.resolve_value() if token else None

        # check if the huggingface_pipeline_kwargs contain the essential parameters
        # otherwise, populate them with values from other init parameters
        huggingface_pipeline_kwargs.setdefault("model", model)
        huggingface_pipeline_kwargs.setdefault("token", token)

        device = ComponentDevice.resolve_device(device)
        device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)

        # task identification and validation
        if task is None:
            if "task" in huggingface_pipeline_kwargs:
                task = huggingface_pipeline_kwargs["task"]
            elif isinstance(huggingface_pipeline_kwargs["model"], str):
                task = model_info(
                    huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
                ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

        if task not in PIPELINE_SUPPORTED_TASKS:
            raise ValueError(
                f"Task '{task}' is not supported. The supported tasks are: {', '.join(PIPELINE_SUPPORTED_TASKS)}."
            )
        huggingface_pipeline_kwargs["task"] = task

        # if not specified, set return_full_text to False for text-generation
        # only generated text is returned (excluding prompt)
        if task == "text-generation":
            generation_kwargs.setdefault("return_full_text", False)

        if stop_words and "stopping_criteria" in generation_kwargs:
            raise ValueError(
                "Found both the `stop_words` init parameter and the `stopping_criteria` key in `generation_kwargs`. "
                "Please specify only one of them."
            )
        generation_kwargs.setdefault("max_new_tokens", 512)
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])

        self.tool_parsing_function = tool_parsing_function or default_tool_parser
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.generation_kwargs = generation_kwargs
        self.chat_template = chat_template
        self.streaming_callback = streaming_callback
        self.pipeline: Optional[HfPipeline] = None
        self.tools = tools

        self._owns_executor = async_executor is None
        self.executor = (
            ThreadPoolExecutor(thread_name_prefix=f"async-HFLocalChatGenerator-executor-{id(self)}", max_workers=1)
            if async_executor is None
            else async_executor
        )

    def __del__(self):
        """
        Cleanup when the instance is being destroyed.
        """
        if hasattr(self, "_owns_executor") and self._owns_executor and hasattr(self, "executor"):
            self.executor.shutdown(wait=True)

    def shutdown(self):
        """
        Explicitly shut down the executor if we own it.
        """
        if self._owns_executor:
            self.executor.shutdown(wait=True)

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self):
        """
        Initializes the component.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        serialization_dict = default_to_dict(
            self,
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
            token=self.token.to_dict() if self.token else None,
            chat_template=self.chat_template,
            tools=serialize_tools_or_toolset(self.tools),
            tool_parsing_function=serialize_callable(self.tool_parsing_function),
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        torch_and_transformers_import.check()  # leave this, cls method
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        tool_parsing_function = init_params.get("tool_parsing_function")
        if tool_parsing_function:
            init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

        huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
        deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return default_from_dict(cls, data)
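
    # Serialization round-trip sketch (illustrative): `to_dict` drops the resolved token
    # from `huggingface_pipeline_kwargs` and serializes the `Secret` instead, while
    # `from_dict` restores secrets, tools, and callables, so the following is expected
    # to reproduce an equivalent component:
    #   restored = HuggingFaceLocalChatGenerator.from_dict(generator.to_dict())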

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Invoke text generation inference based on the provided messages and generation parameters.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. If set, it will override
            the `tools` parameter provided during initialization. This parameter can accept either a list
            of `Tool` objects or a `Toolset` instance.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        tokenizer = self.pipeline.tokenizer
        # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
        assert tokenizer is not None

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # the pipeline call doesn't support stop_sequences, so we pop it and merge it into stop_words
        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        streaming_callback = select_streaming_callback(
            self.streaming_callback, streaming_callback, requires_async=False
        )
        if streaming_callback:
            num_responses = generation_kwargs.get("num_return_sequences", 1)
            if num_responses > 1:
                msg = (
                    "Streaming is enabled, but the number of responses is set to {num_responses}. "
                    "Streaming is only supported for single response generation. "
                    "Setting the number of responses to 1."
                )
                logger.warning(msg, num_responses=num_responses)
                generation_kwargs["num_return_sequences"] = 1

            # Get component name and type
            component_info = ComponentInfo.from_component(self)
            # the streamer parameter hooks into HF streaming; HFTokenStreamingHandler is an adapter to our streaming
            generation_kwargs["streamer"] = HFTokenStreamingHandler(
                tokenizer=tokenizer,
                stream_handler=streaming_callback,
                stop_words=stop_words,
                component_info=component_info,
            )

        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        if isinstance(tools, Toolset):
            tools = list(tools)

        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in tools] if tools else None,
        )

        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        # Generate responses
        output = self.pipeline(prepared_prompt, **generation_kwargs)
        replies = [o.get("generated_text", "") for o in output]

        # Remove stop words from replies if present
        # (stop_words may be None if validation failed, hence the `or []`)
        for stop_word in stop_words or []:
            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
            )
            for r_index, reply in enumerate(replies)
        ]

        return {"replies": chat_messages}
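
    # Tool-calling sketch (illustrative; `get_weather` and its schema are hypothetical):
    #   tool = Tool(name="get_weather", description="Get the weather for a city.",
    #               parameters={"type": "object", "properties": {"city": {"type": "string"}}},
    #               function=get_weather)
    #   result = generator.run(messages, tools=[tool])
    # If the model emits a tool call, create_message (below) parses it, sets
    # finish_reason to "tool_calls", and returns a ChatMessage whose text is None.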

    def create_message(  # pylint: disable=too-many-positional-arguments
        self,
        text: str,
        index: int,
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        prompt: str,
        generation_kwargs: Dict[str, Any],
        parse_tool_calls: bool = False,
    ) -> ChatMessage:
        """
        Create a ChatMessage instance from the provided text, populated with metadata.

        :param text: The generated text.
        :param index: The index of the generated text.
        :param tokenizer: The tokenizer used for generation.
        :param prompt: The prompt used for generation.
        :param generation_kwargs: The generation parameters.
        :param parse_tool_calls: Whether to attempt parsing tool calls from the text.
        :returns: A ChatMessage instance.
        """

        completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
        total_tokens = prompt_token_count + completion_tokens

        tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

        # Determine finish reason based on context
        if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
            finish_reason = "length"
        elif tool_calls:
            finish_reason = "tool_calls"
        else:
            finish_reason = "stop"

        meta = {
            "finish_reason": finish_reason,
            "index": index,
            "model": self.huggingface_pipeline_kwargs["model"],
            "usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_token_count,
                "total_tokens": total_tokens,
            },
        }

        # If tool calls are detected, don't include the text content since it contains the raw tool call format
        return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)

    @staticmethod
    def _validate_stop_words(stop_words: Optional[List[str]]) -> Optional[List[str]]:
        """
        Validates the provided stop words.

        :param stop_words: A list of stop words to validate.
        :returns: A sanitized list of stop words, or None if validation fails.
        """
        if stop_words and not all(isinstance(word, str) for word in stop_words):
            logger.warning(
                "Invalid stop words provided. Stop words must be specified as a list of strings. "
                "Ignoring stop words: {stop_words}",
                stop_words=stop_words,
            )
            return None

        return list(set(stop_words or []))

    @component.output_types(replies=List[ChatMessage])
    async def run_async(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Asynchronously invokes text generation inference based on the provided messages and generation parameters.

        This is the asynchronous version of the `run` method. It has the same parameters
        and return values but can be used with `await` in async code.

        :param messages: A list of ChatMessage objects representing the input messages.
        :param generation_kwargs: Additional keyword arguments for text generation.
        :param streaming_callback: An optional callable for handling streaming responses.
        :param tools: A list of tools or a Toolset for which the model can prepare calls.
            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        :returns: A dictionary with the following keys:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        if self.pipeline is None:
            raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

        tools = tools or self.tools
        if tools and streaming_callback is not None:
            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
        _check_duplicate_tool_names(list(tools or []))

        tokenizer = self.pipeline.tokenizer
        # initialized text-generation/text2text-generation pipelines always have a non-None tokenizer
        assert tokenizer is not None

        # Check and update generation parameters
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # the pipeline call doesn't support stop_sequences, so we pop it and merge it into stop_words
        stop_words = generation_kwargs.pop("stop_words", []) + generation_kwargs.pop("stop_sequences", [])
        stop_words = self._validate_stop_words(stop_words)

        # Set up stop words criteria if stop words exist
        stop_words_criteria = StopWordsCriteria(tokenizer, stop_words, self.pipeline.device) if stop_words else None
        if stop_words_criteria:
            generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

        # validate and select the streaming callback
        streaming_callback = select_streaming_callback(self.streaming_callback, streaming_callback, requires_async=True)

        if streaming_callback:
            return await self._run_streaming_async(
                messages, tokenizer, generation_kwargs, stop_words, streaming_callback
            )

        return await self._run_non_streaming_async(messages, tokenizer, generation_kwargs, stop_words, tools)

    async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
        self,
        messages: List[ChatMessage],
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: Dict[str, Any],
        stop_words: Optional[List[str]],
        streaming_callback: StreamingCallbackT,
    ):
        """
        Handles async streaming generation of responses.
        """
        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]
        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
        )

        # prepared_prompt is a string since we set tokenize=False https://hf.co/docs/transformers/main/chat_templating
        assert isinstance(prepared_prompt, str)

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        # get the component name and type
        component_info = ComponentInfo.from_component(self)
        generation_kwargs["streamer"] = HFTokenStreamingHandler(
            tokenizer, streaming_callback, stop_words, component_info
        )

        # Generate responses asynchronously
        output = await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed, it was initialized with max_workers=1 in __init__
        )

        replies = [o.get("generated_text", "") for o in output]

        # Remove stop words from replies if present
        for stop_word in stop_words or []:
            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=False)
            for r_index, reply in enumerate(replies)
        ]

        return {"replies": chat_messages}

    async def _run_non_streaming_async(  # pylint: disable=too-many-positional-arguments
        self,
        messages: List[ChatMessage],
        tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
        generation_kwargs: Dict[str, Any],
        stop_words: Optional[List[str]],
        tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        Handles async non-streaming generation of responses.
        """
        # convert messages to HF format
        hf_messages = [convert_message_to_hf_format(message) for message in messages]

        if isinstance(tools, Toolset):
            tools = list(tools)

        prepared_prompt = tokenizer.apply_chat_template(
            hf_messages,
            tokenize=False,
            chat_template=self.chat_template,
            add_generation_prompt=True,
            tools=[tc.tool_spec for tc in tools] if tools else None,
        )

        # prepared_prompt is a string, but transformers has some type issues
        prepared_prompt = cast(str, prepared_prompt)

        # Avoid some unnecessary warnings in the generation pipeline call
        generation_kwargs["pad_token_id"] = (
            generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
        )

        # Generate responses asynchronously
        output = await asyncio.get_running_loop().run_in_executor(
            self.executor,
            lambda: self.pipeline(prepared_prompt, **generation_kwargs),  # type: ignore # if self.executor was not passed, it was initialized with max_workers=1 in __init__
        )

        replies = [o.get("generated_text", "") for o in output]

        # Remove stop words from replies if present
        for stop_word in stop_words or []:
            replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

        chat_messages = [
            self.create_message(
                reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
            )
            for r_index, reply in enumerate(replies)
        ]

        return {"replies": chat_messages}
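

# Minimal usage sketch, assuming `transformers[torch]` is installed and the default
# model can be downloaded; illustrative only, mirroring the class docstring example.
if __name__ == "__main__":
    generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")
    generator.warm_up()
    demo_messages = [ChatMessage.from_user("What's Natural Language Processing? Be brief.")]

    # Synchronous generation
    print(generator.run(demo_messages)["replies"][0].text)

    # Asynchronous generation: run_async offloads the blocking pipeline call to
    # self.executor, so the event loop stays responsive.
    async def _demo() -> None:
        result = await generator.run_async(demo_messages)
        print(result["replies"][0].text)

    asyncio.run(_demo())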