• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 10211089327

02 Aug 2024 06:29AM UTC coverage: 90.091%. Remained the same
10211089327

Pull #8149

github

web-flow
Merge bc2d6e10c into c670f0fbe
Pull Request #8149: Docs: Update AzureOpenAIGenerator docstrings

6864 of 7619 relevant lines covered (90.09%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.5
haystack/components/generators/azure.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Callable, Dict, Optional
1✔
7

8
# pylint: disable=import-error
9
from openai.lib.azure import AzureOpenAI
1✔
10

11
from haystack import component, default_from_dict, default_to_dict, logging
1✔
12
from haystack.components.generators import OpenAIGenerator
1✔
13
from haystack.dataclasses import StreamingChunk
1✔
14
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
15

16
logger = logging.getLogger(__name__)
1✔
17

18

19
@component
class AzureOpenAIGenerator(OpenAIGenerator):
    """
    Generates text using OpenAI's large language models (LLMs) deployed on Azure OpenAI.

    It works with the gpt-4 and gpt-3.5-turbo models.
    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    the OpenAI chat completions endpoint will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators import AzureOpenAIGenerator
    from haystack.utils import Secret
    client = AzureOpenAIGenerator(
        azure_endpoint="<Your Azure endpoint, e.g. https://your-company.azure.openai.com/>",
        api_key=Secret.from_token("<your-api-key>"),
        azure_deployment="<this is a model name, e.g. gpt-35-turbo>")
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)
    ```

    ```
    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model':
    >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```
    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-35-turbo",
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        system_prompt: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the Azure OpenAI Generator.

        :param azure_endpoint: The endpoint of the deployed model, for example
            `https://example-resource.azure.openai.com/`. If not provided, it is read from the
            `AZURE_OPENAI_ENDPOINT` environment variable.
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication. Defaults to the `AZURE_OPENAI_API_KEY`
            environment variable. Either `api_key` or `azure_ad_token` must be provided.
        :param azure_ad_token: [Azure Active Directory token](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
            Defaults to the `AZURE_OPENAI_AD_TOKEN` environment variable.
        :param organization: The Organization ID, defaults to `None`. For help, see
            [Setting up your organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function called when a new token is received from the stream.
            It accepts StreamingChunk as an argument.
        :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is
            omitted, and the default system prompt of the model is used.
        :param timeout: Timeout for AzureOpenAI client. If not set, it is inferred from the
            `OPENAI_TIMEOUT` environment variable or set to 30.
        :param max_retries: Maximum retries to establish contact with AzureOpenAI if it returns an internal error.
            If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
            Pass `0` to disable retries.
        :param generation_kwargs: Other parameters to use for the model, sent directly to
            the OpenAI endpoint. See [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. For example, 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: The number of completions to generate for each prompt. For example, with 3 prompts and n=2,
                the LLM will generate two completions per prompt, resulting in 6 completions total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty applied if a token is already present.
                Higher values make the model less likely to repeat the token.
            - `frequency_penalty`: Penalty applied if a token has already been generated.
                Higher values make the model less likely to repeat the token.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        :raises ValueError: If no Azure endpoint is given (neither as a parameter nor via
            `AZURE_OPENAI_ENDPOINT`), or if both `api_key` and `azure_ad_token` are `None`.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
        # with the API.

        # Why is this here?
        # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
        # None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
        # of passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        if api_key is None and azure_ad_token is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        # The check above makes mypy incorrectly infer that api_key is never None,
        # which propagates the incorrect type.
        self.api_key = api_key  # type: ignore
        self.azure_ad_token = azure_ad_token
        self.generation_kwargs = generation_kwargs or {}
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model: str = azure_deployment or "gpt-35-turbo"

        # Use explicit `is not None` checks so that falsy but meaningful values are honored,
        # most importantly `max_retries=0` (disable retries). With `or`, 0 silently fell back
        # to the environment variable or the default of 5.
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", 30.0))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", 5))

        self.client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key.resolve_value() if api_key is not None else None,
            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            organization=organization,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Secrets are serialized via `Secret.to_dict()`, and the streaming callback is stored by name.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            system_prompt=self.system_prompt,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIGenerator":
        """
        Deserialize this component from a dictionary.

        Restores the `api_key` / `azure_ad_token` secrets and the streaming callback in place
        before delegating to `default_from_dict`.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        # Work on a single reference to the init parameters so every step operates on the same dict.
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, keys=["api_key", "azure_ad_token"])
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc