deepset-ai / haystack / 14402617714

11 Apr 2025 11:58AM UTC coverage: 90.327% (-0.05%) from 90.373%

Pull Request #9215: feat: Allow OpenAI client config in `OpenAIChatGenerator` and `AzureOpenAIChatGenerator`
Merge dc5acd967 into 8bf41a851 (github / web-flow)

10674 of 11817 relevant lines covered (90.33%)
0.9 hits per line

Source file: haystack/components/extractors/llm_metadata_extractor.py (71.43% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import copy
import json
import warnings
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.components.builders import PromptBuilder
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.preprocessors import DocumentSplitter
from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport
from haystack.utils import (
    deserialize_callable,
    deserialize_chatgenerator_inplace,
    deserialize_secrets_inplace,
    expand_page_range,
)

with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack>=1.0.2\"'") as amazon_bedrock_generator:
    from haystack_integrations.components.generators.amazon_bedrock import (  # pylint: disable=import-error
        AmazonBedrockChatGenerator,
    )

with LazyImport(message="Run 'pip install \"google-vertex-haystack>=2.0.0\"'") as vertex_ai_gemini_generator:
    from haystack_integrations.components.generators.google_vertex.chat.gemini import (  # pylint: disable=import-error
        VertexAIGeminiChatGenerator,
    )
    from vertexai.generative_models import GenerationConfig  # pylint: disable=import-error


logger = logging.getLogger(__name__)


class LLMProvider(Enum):
    """
    LLM providers currently supported by `LLMMetadataExtractor`.
    """

    OPENAI = "openai"
    OPENAI_AZURE = "openai_azure"
    AWS_BEDROCK = "aws_bedrock"
    GOOGLE_VERTEX = "google_vertex"

    @staticmethod
    def from_str(string: str) -> "LLMProvider":
        """
        Convert a string to an LLMProvider enum.
        """
        provider_map = {e.value: e for e in LLMProvider}
        provider = provider_map.get(string)
        if provider is None:
            msg = f"Invalid LLMProvider '{string}'. Supported LLMProviders are: {list(provider_map.keys())}"
            raise ValueError(msg)
        return provider
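    # A minimal usage sketch (illustrative values):
    #     LLMProvider.from_str("openai")          # -> LLMProvider.OPENAI
    #     LLMProvider.from_str("not-a-provider")  # raises ValueError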


@component
class LLMMetadataExtractor:
    """
    Extracts metadata from documents using a Large Language Model (LLM).

    The metadata is extracted by providing a prompt to an LLM that generates the metadata.

    This component expects as input a list of documents and a prompt. The prompt should have a variable called
    `document` that will point to a single document in the list of documents. So to access the content of the document,
    you can use `{{ document.content }}` in the prompt.

    The component will run the LLM on each document in the list and extract metadata from the document. The metadata
    will be added to the document's metadata field. If the LLM fails to extract metadata from a document, the document
    will be added to the `failed_documents` list. The failed documents will have the keys `metadata_extraction_error` and
    `metadata_extraction_response` in their metadata. These documents can be re-run with another extractor to
    extract metadata by using the `metadata_extraction_response` and `metadata_extraction_error` in the prompt.

    ```python
    from haystack import Document
    from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
    from haystack.components.generators.chat import OpenAIChatGenerator

    NER_PROMPT = '''
    -Goal-
    Given text and a list of entity types, identify all entities of those types from the text.

    -Steps-
    1. Identify all entities. For each identified entity, extract the following information:
    - entity_name: Name of the entity, capitalized
    - entity_type: One of the following types: [organization, product, service, industry]
    Format each entity as a JSON object like: {"entity": <entity_name>, "entity_type": <entity_type>}

    2. Return output in a single list with all the entities identified in step 1.

    -Examples-
    ######################
    Example 1:
    entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
    text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
    10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
    our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
    base and high cross-border usage.
    We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
    with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
    Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
    United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
    agreement with Emirates Skywards.
    And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
    issuers are equally
    ------------------------
    output:
    {"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
    #############################
    -Real Data-
    ######################
    entity_types: [company, organization, person, country, product, service]
    text: {{ document.content }}
    ######################
    output:
    '''

    docs = [
        Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
        Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
    ]

    chat_generator = OpenAIChatGenerator(
        generation_kwargs={
            "max_tokens": 500,
            "temperature": 0.0,
            "seed": 0,
            "response_format": {"type": "json_object"},
        },
        max_retries=1,
        timeout=60.0,
    )

    extractor = LLMMetadataExtractor(
        prompt=NER_PROMPT,
        chat_generator=chat_generator,
        expected_keys=["entities"],
        raise_on_failure=False,
    )

    extractor.warm_up()
    extractor.run(documents=docs)
    >> {'documents': [
        Document(id=.., content: 'deepset was founded in 2018 in Berlin, and is known for its Haystack framework',
        meta: {'entities': [{'entity': 'deepset', 'entity_type': 'company'}, {'entity': 'Berlin', 'entity_type': 'city'},
              {'entity': 'Haystack', 'entity_type': 'product'}]}),
        Document(id=.., content: 'Hugging Face is a company that was founded in New York, USA and is known for its Transformers library',
        meta: {'entities': [
                {'entity': 'Hugging Face', 'entity_type': 'company'}, {'entity': 'New York', 'entity_type': 'city'},
                {'entity': 'USA', 'entity_type': 'country'}, {'entity': 'Transformers', 'entity_type': 'product'}
                ]})
           ],
        'failed_documents': []
       }
    ```
    """  # noqa: E501
    def __init__(  # pylint: disable=R0917
        self,
        prompt: str,
        generator_api: Optional[Union[str, LLMProvider]] = None,
        generator_api_params: Optional[Dict[str, Any]] = None,
        chat_generator: Optional[ChatGenerator] = None,
        expected_keys: Optional[List[str]] = None,
        page_range: Optional[List[Union[str, int]]] = None,
        raise_on_failure: bool = False,
        max_workers: int = 3,
    ):
        """
        Initializes the LLMMetadataExtractor.

        :param prompt: The prompt to be used for the LLM.
        :param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
            Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
        :param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
            the LLM.
        :param chat_generator: A ChatGenerator instance which represents the LLM. If provided, this will override
            settings in generator_api and generator_api_params.
        :param expected_keys: The keys expected in the JSON output from the LLM.
        :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
            metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
            ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
            If None, metadata will be extracted from the entire document for each document in the documents list.
            This parameter is optional and can be overridden in the `run` method.
        :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
            validation of the JSON output.
        :param max_workers: The maximum number of workers to use in the thread pool executor.
        """
        self.prompt = prompt
        ast = SandboxedEnvironment().parse(prompt)
        template_variables = meta.find_undeclared_variables(ast)
        variables = list(template_variables)
        if len(variables) != 1 or variables[0] != "document":
            raise ValueError(
                f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
            )
        self.builder = PromptBuilder(prompt, required_variables=variables)
        self.raise_on_failure = raise_on_failure
        self.expected_keys = expected_keys or []
        generator_api_params = generator_api_params or {}

        if generator_api is None and chat_generator is None:
            raise ValueError("Either generator_api or chat_generator must be provided.")

        if chat_generator is not None:
            self._chat_generator = chat_generator
            if generator_api is not None:
                logger.warning(
                    "Both chat_generator and generator_api are provided. "
                    "chat_generator will be used. generator_api/generator_api_params/LLMProvider are deprecated and "
                    "will be removed in Haystack 2.13.0."
                )
        else:
            warnings.warn(
                "generator_api, generator_api_params, and LLMProvider are deprecated and will be removed in Haystack "
                "2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
                "`chat_generator=OpenAIChatGenerator()`.",
                DeprecationWarning,
            )
            assert generator_api is not None  # verified by the checks above
            generator_api = (
                generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
            )
            self._chat_generator = self._init_generator(generator_api, generator_api_params)

        self.splitter = DocumentSplitter(split_by="page", split_length=1)
        self.expanded_range = expand_page_range(page_range) if page_range else None
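        # e.g. page_range=["1-3", "5"] above expands to pages [1, 2, 3, 5]
        # (a restatement of the documented expand_page_range behavior).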
        self.max_workers = max_workers

    @staticmethod
    def _init_generator(
        generator_api: LLMProvider, generator_api_params: Optional[Dict[str, Any]]
    ) -> Union[
        OpenAIChatGenerator, AzureOpenAIChatGenerator, "AmazonBedrockChatGenerator", "VertexAIGeminiChatGenerator"
    ]:
        """
        Initialize the chat generator based on the specified API provider and parameters.
        """

        generator_api_params = generator_api_params or {}

        if generator_api == LLMProvider.OPENAI:
            return OpenAIChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.OPENAI_AZURE:
            return AzureOpenAIChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.AWS_BEDROCK:
            amazon_bedrock_generator.check()
            return AmazonBedrockChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.GOOGLE_VERTEX:
            vertex_ai_gemini_generator.check()
            return VertexAIGeminiChatGenerator(**generator_api_params)
        else:
            raise ValueError(f"Unsupported generator API: {generator_api}")
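    # A hedged sketch of the deprecated path this factory serves (kwargs are illustrative):
    #     LLMMetadataExtractor._init_generator(LLMProvider.OPENAI, {"model": "gpt-4o-mini"})
    #     # -> OpenAIChatGenerator(model="gpt-4o-mini")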

    def warm_up(self):
        """
        Warm up the LLM provider component.
        """
        if hasattr(self._chat_generator, "warm_up"):
            self._chat_generator.warm_up()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """

        return default_to_dict(
            self,
            prompt=self.prompt,
            chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"),
            expected_keys=self.expected_keys,
            page_range=self.expanded_range,
            raise_on_failure=self.raise_on_failure,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary with serialized data.
        :returns:
            An instance of the component.
        """

        init_parameters = data.get("init_parameters", {})

        # new deserialization with chat_generator
        if init_parameters.get("chat_generator") is not None:
            deserialize_chatgenerator_inplace(data["init_parameters"], key="chat_generator")
            return default_from_dict(cls, data)

        # legacy deserialization
        if "generator_api" in init_parameters:
            data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])

        if "generator_api_params" in init_parameters:
            # Check all the keys that need to be deserialized
            azure_openai_keys = ["azure_ad_token"]
            aws_bedrock_keys = [
                "aws_access_key_id",
                "aws_secret_access_key",
                "aws_session_token",
                "aws_region_name",
                "aws_profile_name",
            ]
            deserialize_secrets_inplace(
                data["init_parameters"]["generator_api_params"], keys=["api_key"] + azure_openai_keys + aws_bedrock_keys
            )

            # For VertexAI
            if "generation_config" in init_parameters["generator_api_params"]:
                data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(
                    init_parameters["generator_api_params"]["generation_config"]
                )

            # For AzureOpenAI
            serialized_azure_ad_token_provider = init_parameters["generator_api_params"].get("azure_ad_token_provider")
            if serialized_azure_ad_token_provider:
                data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
                    serialized_azure_ad_token_provider
                )

            # For all
            serialized_callback_handler = init_parameters["generator_api_params"].get("streaming_callback")
            if serialized_callback_handler:
                data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

        return default_from_dict(cls, data)
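    # A round-trip sketch (assumes a configured instance named `extractor`):
    #     data = extractor.to_dict()
    #     restored = LLMMetadataExtractor.from_dict(data)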

    def _extract_metadata(self, llm_answer: str) -> Dict[str, Any]:
        try:
            parsed_metadata = json.loads(llm_answer)
        except json.JSONDecodeError as e:
            logger.warning(
                "Response from the LLM is not valid JSON. Skipping metadata extraction. Received output: {response}",
                response=llm_answer,
            )
            if self.raise_on_failure:
                raise e
            return {"error": "Response is not valid JSON. Received JSONDecodeError: " + str(e)}

        if not all(key in parsed_metadata for key in self.expected_keys):
            logger.warning(
                "Expected response from LLM to be a JSON object with keys {expected_keys}, got {parsed_json}. "
                "Continuing extraction with received output.",
                expected_keys=self.expected_keys,
                parsed_json=parsed_metadata,
            )

        return parsed_metadata
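    # Behavior sketch with expected_keys=["entities"] (illustrative inputs):
    #     _extract_metadata('{"entities": []}')  -> {"entities": []}
    #     _extract_metadata('not json')          -> {"error": "..."} (or re-raises when raise_on_failure=True)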

    def _prepare_prompts(
        self, documents: List[Document], expanded_range: Optional[List[int]] = None
    ) -> List[Union[ChatMessage, None]]:
        all_prompts: List[Union[ChatMessage, None]] = []
        for document in documents:
            if not document.content:
                logger.warning("Document {doc_id} has no content. Skipping metadata extraction.", doc_id=document.id)
                all_prompts.append(None)
                continue

            if expanded_range:
                doc_copy = copy.deepcopy(document)
                pages = self.splitter.run(documents=[doc_copy])
                content = ""
                for idx, page in enumerate(pages["documents"]):
                    if idx + 1 in expanded_range:
                        content += page.content
                doc_copy.content = content
            else:
                doc_copy = document

            prompt_with_doc = self.builder.run(template=self.prompt, template_variables={"document": doc_copy})

            # build a ChatMessage with the prompt
            message = ChatMessage.from_user(prompt_with_doc["prompt"])
            all_prompts.append(message)

        return all_prompts
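    # Note: a None placeholder keeps `all_prompts` index-aligned with `documents`,
    # so run() can zip documents with their LLM results one-to-one.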

    def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
        # If prompt is None, return an empty JSON reply so downstream parsing yields no metadata
        if prompt is None:
            return {"replies": [ChatMessage.from_assistant("{}")]}

        try:
            result = self._chat_generator.run(messages=[prompt])
        except Exception as e:
            if self.raise_on_failure:
                raise e
            logger.error(
                "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
                class_name=self._chat_generator.__class__.__name__,
                error=e,
            )
            result = {"error": "LLM failed with exception: " + str(e)}
        return result

    @component.output_types(documents=List[Document], failed_documents=List[Document])
    def run(self, documents: List[Document], page_range: Optional[List[Union[str, int]]] = None):
        """
        Extract metadata from documents using a Large Language Model.

        If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
        will split the documents into pages and extract metadata from the specified range of pages. The metadata will
        be extracted from the entire document if `page_range` is not provided.

        The original documents will be returned, updated with the extracted metadata.

        :param documents: List of documents to extract metadata from.
        :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
                           metadata from the first and third pages of each document. It also accepts printable range
                           strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8,
                           10, 11, 12.
                           If None, metadata will be extracted from the entire document for each document in the
                           documents list.
        :returns:
            A dictionary with the keys:
            - "documents": A list of documents that were successfully updated with the extracted metadata.
            - "failed_documents": A list of documents that failed to extract metadata. These documents will have
              "metadata_extraction_error" and "metadata_extraction_response" in their metadata. These documents can be
              re-run with the extractor to extract metadata.
        """
        if len(documents) == 0:
            logger.warning("No documents provided. Skipping metadata extraction.")
            return {"documents": [], "failed_documents": []}

        expanded_range = self.expanded_range
        if page_range:
            expanded_range = expand_page_range(page_range)

        # Create ChatMessage prompts for each document
        all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

        # Run the LLM on each prompt
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = executor.map(self._run_on_thread, all_prompts)

        successful_documents = []
        failed_documents = []
        for document, result in zip(documents, results):
            if "error" in result:
                document.meta["metadata_extraction_error"] = result["error"]
                document.meta["metadata_extraction_response"] = None
                failed_documents.append(document)
                continue

            parsed_metadata = self._extract_metadata(result["replies"][0].text)
            if "error" in parsed_metadata:
                document.meta["metadata_extraction_error"] = parsed_metadata["error"]
                document.meta["metadata_extraction_response"] = result["replies"][0]
                failed_documents.append(document)
                continue

            for key in parsed_metadata:
                document.meta[key] = parsed_metadata[key]
                # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
                document.meta.pop("metadata_extraction_error", None)
                document.meta.pop("metadata_extraction_response", None)
            successful_documents.append(document)

        return {"documents": successful_documents, "failed_documents": failed_documents}
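
# A hedged sketch of using this component inside a Pipeline (NER_PROMPT and docs are
# illustrative, as in the class docstring example; not part of this module):
#     from haystack import Pipeline
#     pipe = Pipeline()
#     pipe.add_component("extractor", LLMMetadataExtractor(prompt=NER_PROMPT, chat_generator=OpenAIChatGenerator()))
#     result = pipe.run(data={"extractor": {"documents": docs}})
#     result["extractor"]["documents"], result["extractor"]["failed_documents"]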