• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14039186228

24 Mar 2025 03:26PM UTC coverage: 90.097% (+0.08%) from 90.016%
14039186228

Pull #9099

github

web-flow
Merge 40da60053 into dae8c7bab
Pull Request #9099: refactor: `LLMMetadataExtractor` - adopt `ChatGenerator` protocol

9835 of 10916 relevant lines covered (90.1%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.7
haystack/components/extractors/llm_metadata_extractor.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import copy
1✔
6
import json
1✔
7
import warnings
1✔
8
from concurrent.futures import ThreadPoolExecutor
1✔
9
from enum import Enum
1✔
10
from typing import Any, Dict, List, Optional, Union
1✔
11

12
from jinja2 import meta
1✔
13
from jinja2.sandbox import SandboxedEnvironment
1✔
14

15
from haystack import Document, component, default_from_dict, default_to_dict, logging
1✔
16
from haystack.components.builders import PromptBuilder
1✔
17
from haystack.components.generators.chat import AzureOpenAIChatGenerator, OpenAIChatGenerator
1✔
18
from haystack.components.generators.chat.types import ChatGenerator
1✔
19
from haystack.components.preprocessors import DocumentSplitter
1✔
20
from haystack.core.serialization import import_class_by_name
1✔
21
from haystack.dataclasses import ChatMessage
1✔
22
from haystack.lazy_imports import LazyImport
1✔
23
from haystack.utils import expand_page_range
1✔
24

25
with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack>=1.0.2\"'") as amazon_bedrock_generator:
1✔
26
    from haystack_integrations.components.generators.amazon_bedrock import (  #  pylint: disable=import-error
1✔
27
        AmazonBedrockChatGenerator,
28
    )
29

30
with LazyImport(message="Run 'pip install \"google-vertex-haystack>=2.0.0\"'") as vertex_ai_gemini_generator:
1✔
31
    from haystack_integrations.components.generators.google_vertex.chat.gemini import (  # pylint: disable=import-error
1✔
32
        VertexAIGeminiChatGenerator,
33
    )
34

35

36
logger = logging.getLogger(__name__)
1✔
37

38

39
class LLMProvider(Enum):
    """
    Currently LLM providers supported by `LLMMetadataExtractor`.
    """

    OPENAI = "openai"
    OPENAI_AZURE = "openai_azure"
    AWS_BEDROCK = "aws_bedrock"
    GOOGLE_VERTEX = "google_vertex"

    @staticmethod
    def from_str(string: str) -> "LLMProvider":
        """
        Convert a string to a LLMProvider enum.

        :param string: The provider name, e.g. "openai". Must match one of the enum values.
        :returns: The matching LLMProvider member.
        :raises ValueError: If the string does not correspond to any supported provider.
        """
        provider_map = {e.value: e for e in LLMProvider}
        provider = provider_map.get(string)
        if provider is None:
            # Fixed message: the original ran the two sentences together
            # ("...'{string}'Supported...") with no separator.
            msg = f"Invalid LLMProvider '{string}'. Supported LLMProviders are: {list(provider_map.keys())}"
            raise ValueError(msg)
        return provider
60

61

62
@component
class LLMMetadataExtractor:
    """
    Extracts metadata from documents using a Large Language Model (LLM).

    The metadata is extracted by providing a prompt to an LLM that generates the metadata.

    This component expects as input a list of documents and a prompt. The prompt should have a variable called
    `document` that will point to a single document in the list of documents. So to access the content of the document,
    you can use `{{ document.content }}` in the prompt.

    The component will run the LLM on each document in the list and extract metadata from the document. The metadata
    will be added to the document's metadata field. If the LLM fails to extract metadata from a document, the document
    will be added to the `failed_documents` list. The failed documents will have the keys `metadata_extraction_error` and
    `metadata_extraction_response` in their metadata. These documents can be re-run with another extractor to
    extract metadata by using the `metadata_extraction_response` and `metadata_extraction_error` in the prompt.

    ```python
    from haystack import Document
    from haystack.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
    from haystack.components.generators.chat import OpenAIChatGenerator

    NER_PROMPT = '''
    -Goal-
    Given text and a list of entity types, identify all entities of those types from the text.

    -Steps-
    1. Identify all entities. For each identified entity, extract the following information:
    - entity_name: Name of the entity, capitalized
    - entity_type: One of the following types: [organization, product, service, industry]
    Format each entity as a JSON like: {"entity": <entity_name>, "entity_type": <entity_type>}

    2. Return output in a single list with all the entities identified in steps 1.

    -Examples-
    ######################
    Example 1:
    entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
    text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
    10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
    our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
    base and high cross-border usage.
    We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
    with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
    Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
    United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
    agreement with Emirates Skywards.
    And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
    issuers are equally
    ------------------------
    output:
    {"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
    #############################
    -Real Data-
    ######################
    entity_types: [company, organization, person, country, product, service]
    text: {{ document.content }}
    ######################
    output:
    '''

    docs = [
        Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"),
        Document(content="Hugging Face is a company that was founded in New York, USA and is known for its Transformers library")
    ]

    chat_generator = OpenAIChatGenerator(
        generation_kwargs={
            "max_tokens": 500,
            "temperature": 0.0,
            "seed": 0,
            "response_format": {"type": "json_object"},
        },
        max_retries=1,
        timeout=60.0,
    )

    extractor = LLMMetadataExtractor(
        prompt=NER_PROMPT,
        chat_generator=chat_generator,
        expected_keys=["entities"],
        raise_on_failure=False,
    )

    extractor.warm_up()
    extractor.run(documents=docs)
    >> {'documents': [
        Document(id=.., content: 'deepset was founded in 2018 in Berlin, and is known for its Haystack framework',
        meta: {'entities': [{'entity': 'deepset', 'entity_type': 'company'}, {'entity': 'Berlin', 'entity_type': 'city'},
              {'entity': 'Haystack', 'entity_type': 'product'}]}),
        Document(id=.., content: 'Hugging Face is a company that was founded in New York, USA and is known for its Transformers library',
        meta: {'entities': [
                {'entity': 'Hugging Face', 'entity_type': 'company'}, {'entity': 'New York', 'entity_type': 'city'},
                {'entity': 'USA', 'entity_type': 'country'}, {'entity': 'Transformers', 'entity_type': 'product'}
                ]})
           ]
        'failed_documents': []
       }
    >>
    ```
    """  # noqa: E501

    def __init__(  # pylint: disable=R0917
        self,
        prompt: str,
        generator_api: Optional[Union[str, LLMProvider]] = None,
        generator_api_params: Optional[Dict[str, Any]] = None,
        chat_generator: Optional[ChatGenerator] = None,
        expected_keys: Optional[List[str]] = None,
        page_range: Optional[List[Union[str, int]]] = None,
        raise_on_failure: bool = False,
        max_workers: int = 3,
    ):
        """
        Initializes the LLMMetadataExtractor.

        :param prompt: The prompt to be used for the LLM.
        :param generator_api: The API provider for the LLM. Deprecated. Use chat_generator to configure the LLM.
            Currently supported providers are: "openai", "openai_azure", "aws_bedrock", "google_vertex".
        :param generator_api_params: The parameters for the LLM generator. Deprecated. Use chat_generator to configure
            the LLM.
        :param chat_generator: a ChatGenerator instance which represents the LLM. If provided, this will override
            settings in generator_api and generator_api_params.
        :param expected_keys: The keys expected in the JSON output from the LLM.
        :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
            metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
            ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,11, 12.
            If None, metadata will be extracted from the entire document for each document in the documents list.
            This parameter is optional and can be overridden in the `run` method.
        :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
            validation of the JSON output.
        :param max_workers: The maximum number of workers to use in the thread pool executor.
        :raises ValueError: If the prompt does not have exactly one `document` variable, or if neither
            generator_api nor chat_generator is provided.
        """
        self.prompt = prompt
        # Parse the prompt in a sandboxed Jinja2 environment to discover its template variables.
        ast = SandboxedEnvironment().parse(prompt)
        template_variables = meta.find_undeclared_variables(ast)
        variables = list(template_variables)
        # FIX: use `len(variables) != 1` instead of `> 1` — a prompt with zero variables would
        # otherwise raise an IndexError on `variables[0]` instead of the intended ValueError.
        if len(variables) != 1 or variables[0] != "document":
            raise ValueError(
                f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
            )
        self.builder = PromptBuilder(prompt, required_variables=variables)
        self.raise_on_failure = raise_on_failure
        self.expected_keys = expected_keys or []
        generator_api_params = generator_api_params or {}

        if generator_api is None and chat_generator is None:
            raise ValueError("Either generator_api or chat_generator must be provided.")

        if chat_generator is not None:
            # chat_generator takes precedence over the deprecated generator_api path.
            self._chat_generator = chat_generator
            if generator_api is not None:
                logger.warning(
                    "Both chat_generator and generator_api are provided. "
                    "chat_generator will be used. generator_api/generator_api_params are deprecated and "
                    "will be removed in Haystack 2.13.0."
                )
        else:
            warnings.warn(
                "generator_api and generator_api_params are deprecated and will be removed in Haystack "
                "2.13.0. Use chat_generator instead. For example, change `generator_api=LLMProvider.OPENAI` to "
                "`chat_generator=OpenAIChatGenerator()`.",
                DeprecationWarning,
            )
            assert generator_api is not None  # verified by the checks above
            generator_api = (
                generator_api if isinstance(generator_api, LLMProvider) else LLMProvider.from_str(generator_api)
            )
            self._chat_generator = self._init_generator(generator_api, generator_api_params)

        # Used to split documents into single pages when a page_range is requested.
        self.splitter = DocumentSplitter(split_by="page", split_length=1)
        self.expanded_range = expand_page_range(page_range) if page_range else None
        self.max_workers = max_workers

    @staticmethod
    def _init_generator(
        generator_api: LLMProvider, generator_api_params: Optional[Dict[str, Any]]
    ) -> Union[
        OpenAIChatGenerator, AzureOpenAIChatGenerator, "AmazonBedrockChatGenerator", "VertexAIGeminiChatGenerator"
    ]:
        """
        Initialize the chat generator based on the specified API provider and parameters.
        """

        generator_api_params = generator_api_params or {}

        if generator_api == LLMProvider.OPENAI:
            return OpenAIChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.OPENAI_AZURE:
            return AzureOpenAIChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.AWS_BEDROCK:
            # LazyImport check raises a helpful install message if the integration is missing.
            amazon_bedrock_generator.check()
            return AmazonBedrockChatGenerator(**generator_api_params)
        elif generator_api == LLMProvider.GOOGLE_VERTEX:
            vertex_ai_gemini_generator.check()
            return VertexAIGeminiChatGenerator(**generator_api_params)
        else:
            raise ValueError(f"Unsupported generator API: {generator_api}")

    def warm_up(self):
        """
        Warm up the LLM provider component.
        """
        # Not every ChatGenerator implements warm_up, so probe before calling.
        if hasattr(self._chat_generator, "warm_up"):
            self._chat_generator.warm_up()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """

        return default_to_dict(
            self,
            prompt=self.prompt,
            chat_generator=self._chat_generator.to_dict(),
            expected_keys=self.expected_keys,
            page_range=self.expanded_range,
            raise_on_failure=self.raise_on_failure,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary with serialized data.
        :returns:
            An instance of the component.
        """

        chat_generator_class = import_class_by_name(data["init_parameters"]["chat_generator"]["type"])
        assert hasattr(chat_generator_class, "from_dict")  # we know but mypy doesn't
        chat_generator_instance = chat_generator_class.from_dict(data["init_parameters"]["chat_generator"])
        data["init_parameters"]["chat_generator"] = chat_generator_instance
        return default_from_dict(cls, data)

    def _extract_metadata(self, llm_answer: str) -> Dict[str, Any]:
        """
        Parse the LLM reply as JSON and validate it against `expected_keys`.

        Returns the parsed metadata, or a dict with an "error" key when the reply is not valid JSON
        (unless `raise_on_failure` is set, in which case the JSONDecodeError is re-raised).
        """
        try:
            parsed_metadata = json.loads(llm_answer)
        except json.JSONDecodeError as e:
            logger.warning(
                "Response from the LLM is not valid JSON. Skipping metadata extraction. Received output: {response}",
                response=llm_answer,
            )
            if self.raise_on_failure:
                raise e
            return {"error": "Response is not valid JSON. Received JSONDecodeError: " + str(e)}

        # Missing expected keys are only warned about — extraction continues with whatever was returned.
        if not all(key in parsed_metadata for key in self.expected_keys):
            logger.warning(
                "Expected response from LLM to be a JSON with keys {expected_keys}, got {parsed_json}. "
                "Continuing extraction with received output.",
                expected_keys=self.expected_keys,
                parsed_json=parsed_metadata,
            )

        return parsed_metadata

    def _prepare_prompts(
        self, documents: List[Document], expanded_range: Optional[List[int]] = None
    ) -> List[Union[ChatMessage, None]]:
        """
        Render one ChatMessage prompt per document; None is appended for documents with no content
        so the result stays index-aligned with `documents`.
        """
        all_prompts: List[Union[ChatMessage, None]] = []
        for document in documents:
            if not document.content:
                logger.warning("Document {doc_id} has no content. Skipping metadata extraction.", doc_id=document.id)
                all_prompts.append(None)
                continue

            if expanded_range:
                # Deep-copy so trimming the content to the selected pages never mutates the input document.
                doc_copy = copy.deepcopy(document)
                pages = self.splitter.run(documents=[doc_copy])
                content = ""
                for idx, page in enumerate(pages["documents"]):
                    # Page numbers in expanded_range are 1-based.
                    if idx + 1 in expanded_range:
                        content += page.content
                doc_copy.content = content
            else:
                doc_copy = document

            prompt_with_doc = self.builder.run(template=self.prompt, template_variables={"document": doc_copy})

            # build a ChatMessage with the prompt
            message = ChatMessage.from_user(prompt_with_doc["prompt"])
            all_prompts.append(message)

        return all_prompts

    def _run_on_thread(self, prompt: Optional[ChatMessage]) -> Dict[str, Any]:
        """
        Run the chat generator on a single prompt; executed by the thread pool in `run`.

        Returns the generator result, or a dict with an "error" key if the LLM call failed and
        `raise_on_failure` is False.
        """
        # If prompt is None (document without content), return an empty-JSON reply.
        # FIX: wrap it in a ChatMessage — `run` reads replies via `result["replies"][0].text`,
        # which would raise AttributeError on a plain string.
        if prompt is None:
            return {"replies": [ChatMessage.from_assistant("{}")]}

        # Backward compatibility: older instances stored the generator on `llm_provider`.
        llm = self.llm_provider if hasattr(self, "llm_provider") else self._chat_generator

        try:
            result = llm.run(messages=[prompt])
        except Exception as e:
            if self.raise_on_failure:
                raise e
            logger.error(
                "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
                class_name=llm.__class__.__name__,
                error=e,
            )
            result = {"error": "LLM failed with exception: " + str(e)}
        return result

    @component.output_types(documents=List[Document], failed_documents=List[Document])
    def run(self, documents: List[Document], page_range: Optional[List[Union[str, int]]] = None):
        """
        Extract metadata from documents using a Large Language Model.

        If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
        will split the documents into pages and extract metadata from the specified range of pages. The metadata will be
        extracted from the entire document if `page_range` is not provided.

        The original documents will be returned  updated with the extracted metadata.

        :param documents: List of documents to extract metadata from.
        :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
                           metadata from the first and third pages of each document. It also accepts printable range
                           strings, e.g.: ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10,
                           11, 12.
                           If None, metadata will be extracted from the entire document for each document in the
                           documents list.
        :returns:
            A dictionary with the keys:
            - "documents": A list of documents that were successfully updated with the extracted metadata.
            - "failed_documents": A list of documents that failed to extract metadata. These documents will have
            "metadata_extraction_error" and "metadata_extraction_response" in their metadata. These documents can be
            re-run with the extractor to extract metadata.
        """
        if len(documents) == 0:
            logger.warning("No documents provided. Skipping metadata extraction.")
            return {"documents": [], "failed_documents": []}

        # A page_range passed to run() overrides the one set in __init__.
        expanded_range = self.expanded_range
        if page_range:
            expanded_range = expand_page_range(page_range)

        # Create ChatMessage prompts for each document
        all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)

        # Run the LLM on each prompt
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = executor.map(self._run_on_thread, all_prompts)

        successful_documents = []
        failed_documents = []
        # executor.map preserves input order, so results stays aligned with documents.
        for document, result in zip(documents, results):
            if "error" in result:
                document.meta["metadata_extraction_error"] = result["error"]
                document.meta["metadata_extraction_response"] = None
                failed_documents.append(document)
                continue

            parsed_metadata = self._extract_metadata(result["replies"][0].text)
            if "error" in parsed_metadata:
                document.meta["metadata_extraction_error"] = parsed_metadata["error"]
                document.meta["metadata_extraction_response"] = result["replies"][0]
                failed_documents.append(document)
                continue

            for key in parsed_metadata:
                document.meta[key] = parsed_metadata[key]
                # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
                document.meta.pop("metadata_extraction_error", None)
                document.meta.pop("metadata_extraction_response", None)
            successful_documents.append(document)

        return {"documents": successful_documents, "failed_documents": failed_documents}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc