• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MITLibraries / transmogrifier / 12319654873

13 Dec 2024 04:48PM UTC coverage: 98.751% (+2.3%) from 96.404%
12319654873

Pull #219

github

ghukill
Update --input-file CLI docstring
Pull Request #219: TIMX 405 - support output to TIMDEX parquet dataset

37 of 45 new or added lines in 2 files covered. (82.22%)

1 existing line in 1 file now uncovered.

1739 of 1761 relevant lines covered (98.75%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.51
/transmogrifier/sources/transformer.py
1
"""Transformer module."""
2

3
# ruff: noqa: D417
4

5
from __future__ import annotations
1✔
6

7
import json
1✔
8
import logging
1✔
9
import os
1✔
10
import re
1✔
11
import uuid
1✔
12
from abc import ABC, abstractmethod
1✔
13
from importlib import import_module
1✔
14
from typing import TYPE_CHECKING, final
1✔
15

16
import smart_open  # type: ignore[import-untyped]
1✔
17
from attrs import asdict
1✔
18
from bs4 import Tag  # type: ignore[import-untyped]
1✔
19
from timdex_dataset_api import (  # type: ignore[import-untyped, import-not-found]
1✔
20
    DatasetRecord,
21
    TIMDEXDataset,
22
)
23

24
import transmogrifier.models as timdex
1✔
25
from transmogrifier.config import SOURCES, get_etl_version
1✔
26
from transmogrifier.exceptions import DeletedRecordEvent, SkippedRecordEvent
1✔
27
from transmogrifier.helpers import generate_citation, validate_date
1✔
28

29
if TYPE_CHECKING:
1✔
30
    from collections.abc import Callable, Iterator
×
31

32
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Recursive alias describing any value that can appear in a parsed JSON document.
type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None

# NOTE(review): not referenced in the visible part of this module; presumably the
# number of records per batch when writing to the parquet dataset -- confirm.
PARQUET_DATASET_BATCH_SIZE = 1_000
37

38

39
class Transformer(ABC):
1✔
40
    """Base transformer class."""
41

42
    @final
    def __init__(
        self,
        source: str,
        source_records: Iterator[dict[str, JSON] | Tag],
        source_file: str | None = None,
        run_id: str | None = None,
    ) -> None:
        """
        Set up source metadata, the source record iterator, and run counters.

        Args:
            source: Source repository label. Must match a source key from config.SOURCES.
            source_records: Iterator of source records to be transformed.
            source_file: Filepath of the input source file.
            run_id: A unique identifier for this invocation of Transmogrifier.
        """
        self.source: str = source

        # source-level settings come from the SOURCES configuration mapping
        source_config = SOURCES[source]
        self.source_base_url: str = source_config["base-url"]
        self.source_name = source_config["name"]

        self.source_records: Iterator[JSON | Tag] = source_records
        self.source_file = source_file

        # counters updated while iterating over source records
        self.processed_record_count: int = 0
        self.transformed_record_count: int = 0
        self.skipped_record_count: int = 0
        self.error_record_count: int = 0
        self.deleted_records: list[str] = []

        # NOTE: FEATURE FLAG: branching logic will be removed after v2 work is complete
        if get_etl_version() == 2:  # noqa: PLR2004
            self.run_data = self.get_run_data(source_file, run_id)
74

75
    @final
1✔
76
    def __iter__(self) -> Iterator[timdex.TimdexRecord | DatasetRecord]:
1✔
77
        """Iterate over transformed records."""
78
        return self
1✔
79

80
    @final
1✔
81
    def __next__(self) -> timdex.TimdexRecord | DatasetRecord:
1✔
82
        """Return next transformed record."""
83
        # NOTE: FEATURE FLAG: branching logic will be removed after v2 work is complete
84
        etl_version = get_etl_version()
1✔
85
        match etl_version:
1✔
86
            case 1:
1✔
87
                return self._etl_v1_next_iter_method()
1✔
88
            case 2:
1✔
89
                return self._etl_v2_next_iter_method()
1✔
90

91
    # NOTE: FEATURE FLAG: branching logic + method removed after v2 work is complete
92
    def _etl_v1_next_iter_method(self) -> timdex.TimdexRecord:
1✔
93
        """Transformer.__next__ behavior for ETL version 1."""
94
        while True:
1✔
95
            source_record = next(self.source_records)
1✔
96
            self.processed_record_count += 1
1✔
97
            try:
1✔
98
                record = self.transform(source_record)
1✔
99
            except DeletedRecordEvent as error:
1✔
100
                self.deleted_records.append(error.timdex_record_id)
1✔
101
                continue
1✔
102
            except SkippedRecordEvent:
1✔
103
                self.skipped_record_count += 1
1✔
104
                continue
1✔
105
            self.transformed_record_count += 1
1✔
106
            return record
1✔
107

108
    # NOTE: FEATURE FLAG: method logic will move directly to __next__ definition
109
    def _etl_v2_next_iter_method(self) -> DatasetRecord:
        """Transformer.__next__ behavior for ETL version 2.

        Consumes exactly one source record and always returns a DatasetRecord for it,
        whatever the outcome of the transformation. The outcome is recorded in the
        "action" value:
            - "index": record was transformed
            - "delete": transform raised DeletedRecordEvent
            - "skip": transform raised SkippedRecordEvent
            - "error": transform raised any other exception (logged as a warning)

        Because every source record yields a return value, no retry loop is needed.
        """
        transformed_record = None
        timdex_record_id = None

        source_record = next(self.source_records)
        self.processed_record_count += 1

        try:
            transformed_record = self.transform(source_record)
            timdex_record_id = transformed_record.timdex_record_id
            self.transformed_record_count += 1
            action = "index"

        except DeletedRecordEvent as error:
            self.deleted_records.append(error.timdex_record_id)
            timdex_record_id = error.timdex_record_id
            action = "delete"

        except SkippedRecordEvent:
            self.skipped_record_count += 1
            action = "skip"

        # deliberate catch-all: one bad record must not abort the whole run
        except Exception as exception:  # noqa: BLE001
            self.error_record_count += 1
            message = f"Unhandled exception during record transformation: {exception}"
            logger.warning(message)
            action = "error"

        return DatasetRecord(
            timdex_record_id=timdex_record_id,
            source_record=self.serialize_source_record(source_record),
            transformed_record=(
                json.dumps(transformed_record.asdict()).encode()
                if transformed_record
                else None
            ),
            action=action,
            **self.run_data,
        )
150

151
    @final
    @classmethod
    def get_transformer(cls, source: str) -> type[Transformer]:
        """
        Look up and import the transformer class configured for a source.

        The source's "transform-class" setting is a dotted path whose final segment
        is the class name and whose preceding segments are the module to import.

        Args:
            source: Source repository label. Must match a source key from config.SOURCES.

        """
        transform_class_path = SOURCES[source]["transform-class"]
        module_name, class_name = transform_class_path.rsplit(".", 1)
        return getattr(import_module(module_name), class_name)
166

167
    @final
1✔
168
    @classmethod
1✔
169
    def load(
1✔
170
        cls, source: str, source_file: str, run_id: str | None = None
171
    ) -> Transformer:
172
        """
173
        Instantiate specified transformer class and populate with source records.
174

175
        Args:
176
            source: Source repository label. Must match a source key from config.SOURCES.
177
            source_file: A file containing source records to be transformed.
178
            run_id: A unique identifier for this invocation of Transmogrifier.
179
        """
180
        transformer_class = cls.get_transformer(source)
1✔
181
        source_records = transformer_class.parse_source_file(source_file)
1✔
182
        return transformer_class(
1✔
183
            source,
184
            source_records,
185
            source_file=source_file,
186
            run_id=run_id,
187
        )
188

189
    @staticmethod
1✔
190
    def get_run_data(source_file: str | None, run_id: str | None) -> dict:
1✔
191
        """Prepare dictionary of run data based on input source filename and CLI args.
192

193
        Args:
194
            - source_file: str
195
                - example: "libguides-2024-06-03-full-extracted-records-to-index.xml"
196
            - run_id: str
197
                - example: "run-abc-123"
198
                - provided as CLI argument or minted if absent
199

200
        Example output:
201
            {
202
                'source': 'libguides',
203
                'run_date': '2024-06-03',
204
                'run_type': 'full',
205
                'run_id': 'run-abc-123'
206
            }
207
        """
208
        if not source_file:
1✔
NEW
209
            message = "source file not set, cannot parse run data"
×
NEW
210
            raise ValueError(message)
×
211

212
        filename = source_file.split("/")[-1]
1✔
213

214
        match_result = re.match(
1✔
215
            r"^([\w\-]+?)-(\d{4}-\d{2}-\d{2})-(\w+)-(\w+)-records-to-(.+?)(?:_(\d+))?\.(\w+)$",
216
            filename,
217
        )
218
        if not match_result:
1✔
NEW
219
            message = f"Provided S3 URI or filename is invalid: {filename}."
×
220
            raise ValueError(message)
×
221

222
        match_keys = [
1✔
223
            "source",
224
            "run_date",
225
            "run_type",
226
            "stage",
227
            "action",
228
            "index",
229
            "file_type",
230
        ]
231
        output_keys = ["source", "run_date", "run_type"]
1✔
232
        try:
1✔
233
            filename_parts = dict(zip(match_keys, match_result.groups(), strict=True))
1✔
234
            run_data = {k: v for k, v in filename_parts.items() if k in output_keys}
1✔
235
        except ValueError as exception:
×
NEW
236
            message = (
×
237
                f"Input S3 URI or filename '{filename}' does not contain required "
238
                f"dataset data: {exception}."
239
            )
UNCOV
240
            raise ValueError(message) from exception
×
241

242
        if not run_id:
1✔
243
            logger.info("explicit run_id not passed, minting new UUID")
1✔
244
            run_id = str(uuid.uuid4())
1✔
245
        message = f"run_id set: '{run_id}'"
1✔
246
        logger.info(message)
1✔
247
        run_data["run_id"] = run_id
1✔
248

249
        return run_data
1✔
250

251
    def serialize_source_record(self, source_record: Tag | dict) -> bytes | None:
1✔
252
        if isinstance(source_record, Tag):
1✔
253
            return source_record.encode()
1✔
254
        if isinstance(source_record, dict):
×
255
            return json.dumps(source_record).encode()
×
256
        return None
×
257

258
    @final
    def transform(self, source_record: dict[str, JSON] | Tag) -> timdex.TimdexRecord:
        """
        Build a TimdexRecord from a single source record.

        The record is created with its required fields (source, source link, TIMDEX
        record id, title). Each available optional field method is then called and its
        result set on the record; exceptions raised by those methods propagate to the
        caller. Finally, derived fields are generated from the values already set.

        May not be overridden.

        Args:
            source_record: A single source record.

        Raises:
            DeletedRecordEvent: If the source record is flagged as deleted.
        """
        if self.record_is_deleted(source_record):
            raise DeletedRecordEvent(self.get_timdex_record_id(source_record))

        required_fields = {
            "source": self.source_name,
            "source_link": self.get_source_link(source_record),
            "timdex_record_id": self.get_timdex_record_id(source_record),
            "title": self.get_valid_title(source_record),
        }
        timdex_record = timdex.TimdexRecord(**required_fields)

        for field_name, field_method in self.get_optional_field_methods():
            field_value = field_method(source_record)
            setattr(timdex_record, field_name, field_value)

        self.generate_derived_fields(timdex_record)
        return timdex_record
292

293
    # NOTE: FEATURE FLAG: method will be removed after v2 work is complete
294
    @final
1✔
295
    def transform_and_write_output_files(self, output_file: str) -> None:
1✔
296
        """Iterates through source records to transform and write to output files.
297

298
        Args:
299
            output_file: The name of the output files.
300
        """
301
        self._write_timdex_records_to_json_file(output_file)
1✔
302
        if self.processed_record_count == 0:
1✔
303
            message = "No records processed from input file, needs investigation"
1✔
304
            raise ValueError(message)
1✔
305
        if deleted_records := self.deleted_records:
1✔
306
            deleted_output_file = output_file.replace("index", "delete").replace(
1✔
307
                "json", "txt"
308
            )
309
            self._write_deleted_records_to_txt_file(deleted_records, deleted_output_file)
1✔
310

311
    # NOTE: FEATURE FLAG: method will be removed after v2 work is complete
312
    @final
    def _write_timdex_records_to_json_file(self, output_file: str) -> int:
        """
        Write TIMDEX records to JSON file.

        Records are pulled from this transformer one at a time and written as a JSON
        array, omitting fields whose value is None. If no record can be produced, the
        output file is not opened at all.

        A status message is logged each time the number of written records reaches a
        multiple of the STATUS_UPDATE_INTERVAL environment variable (default 1000).

        Args:
            output_file: The JSON file used for writing TIMDEX records.

        Returns:
            The number of records written.
        """
        count = 0
        try:
            record: timdex.TimdexRecord = next(self)  # type: ignore[assignment]
        except StopIteration:
            return count

        # Read and parse the interval once, before opening the file, instead of on
        # every record; a malformed value then fails before any output is written.
        status_update_interval = int(os.getenv("STATUS_UPDATE_INTERVAL", "1000"))

        with smart_open.open(output_file, "w") as file:
            file.write("[\n")
            while record:
                file.write(
                    json.dumps(
                        asdict(
                            record,
                            filter=lambda _, value: value is not None,
                        ),
                        indent=2,
                    )
                )
                count += 1
                if count % status_update_interval == 0:
                    logger.info(
                        "Status update: %s records written to output file so far!",
                        count,
                    )
                try:
                    record: timdex.TimdexRecord = next(self)  # type: ignore[no-redef]
                except StopIteration:
                    break
                # separator is written only once another record is known to follow
                file.write(",\n")
            file.write("\n]")
        return count
350

351
    # NOTE: FEATURE FLAG: method will be removed after v2 work is complete
352
    @final
    @staticmethod
    def _write_deleted_records_to_txt_file(
        deleted_records: list[str], output_file: str
    ) -> None:
        """Write the ids of deleted records to a text file, one id per line.

        Args:
            deleted_records: The deleted records to write to file.
            output_file: The text file used for writing deleted records.
        """
        with smart_open.open(output_file, "w") as file:
            file.writelines(f"{record_id}\n" for record_id in deleted_records)
366

367
    def write_to_parquet_dataset(self, dataset_location: str) -> list:
        """Write transformed records to the TIMDEX dataset at the given location.

        This transformer is passed to TIMDEXDataset.write() as the records iterator,
        and the result of that call is returned unchanged.

        Args:
            dataset_location: Location of the TIMDEX dataset to write to.
        """
        return TIMDEXDataset(location=dataset_location).write(records_iter=self)
371

372
    @final
1✔
373
    def get_valid_title(self, source_record: dict[str, JSON] | Tag) -> str:
1✔
374
        """
375
        Retrieves main title(s) from a source record and returns a valid title string.
376

377
        May not be overridden.
378

379
        If the list of main titles retrieved from the source record is empty or the
380
        title element has no string value, inserts standard language to represent a
381
        missing title field.
382

383
        Args:
384
            source_record: A single source record.
385
        """
386
        all_titles = self.get_main_titles(source_record)
1✔
387
        title_count = len(all_titles)
1✔
388
        if title_count > 1:
1✔
389
            logger.warning(
1✔
390
                "Record %s has multiple titles. Using the first title from the "
391
                "following titles found: %s",
392
                self.get_source_record_id(source_record),
393
                all_titles,
394
            )
395
        if title_count >= 1:
1✔
396
            title = all_titles[0]
1✔
397
        else:
398
            logger.warning(
1✔
399
                "Record %s was missing a title, source record should be investigated.",
400
                self.get_source_record_id(source_record),
401
            )
402
            title = "Title not provided"
1✔
403
        return title
1✔
404

405
    @classmethod
1✔
406
    @abstractmethod
1✔
407
    def parse_source_file(cls, source_file: str) -> Iterator[dict[str, JSON] | Tag]:
1✔
408
        """
409
        Parse source file and return source records via an iterator.
410

411
        Must be overridden by format subclasses.
412

413
        Args:
414
            source_file: A file containing source records to be transformed.
415
        """
416

417
    @classmethod
1✔
418
    @abstractmethod
1✔
419
    def get_main_titles(cls, source_record: dict[str, JSON] | Tag) -> list[str]:
1✔
420
        """
421
        Retrieve main title(s) from an source record.
422

423
        Must be overridden by source subclasses.
424

425
        Args:
426
            source_record: A single source record.
427
        """
428

429
    @abstractmethod
1✔
430
    def get_source_link(
1✔
431
        self,
432
        source_record: dict[str, JSON] | Tag,
433
    ) -> str:
434
        """
435
        Class method to set the source link for the item.
436

437
        Must be overridden by source subclasses.
438

439
        Args:
440
            source_record: A single source record.
441
        """
442

443
    @abstractmethod
1✔
444
    def get_timdex_record_id(self, source_record: dict[str, JSON] | Tag) -> str:
1✔
445
        """
446
        Class method to set the TIMDEX record id.
447

448
        Must be overridden by source subclasses.
449

450
        Args:
451
            source_record: A single source record.
452
        """
453

454
    @classmethod
1✔
455
    @abstractmethod
1✔
456
    def get_source_record_id(cls, source_record: dict[str, JSON] | Tag) -> str:
1✔
457
        """
458
        Get or generate a source record ID from a source record.
459

460
        Must be overridden by source subclasses.
461

462
        Args:
463
            source_record: A single source record.
464
        """
465

466
    @classmethod
1✔
467
    @abstractmethod
1✔
468
    def record_is_deleted(cls, source_record: dict[str, JSON] | Tag) -> bool:
1✔
469
        """
470
        Determine whether record has a status of deleted.
471

472
        Must be overridden by source subclasses.
473

474
        Args:
475
            source_record: A single source record.
476
        """
477

478
    @final
    def get_optional_field_methods(self) -> Iterator[tuple[str, Callable]]:
        """
        Yield (field name, method) pairs for optional TIMDEX fields.

        For each optional field name reported by TimdexRecord, a method named
        "get_<field_name>" is looked up on this instance; fields without such a method
        are omitted.

        May not be overridden.
        """
        for field_name in timdex.TimdexRecord.get_optional_field_names():
            field_method = getattr(self, f"get_{field_name}", None)
            if not field_method:
                continue
            yield field_name, field_method
488

489
    @final
    def generate_derived_fields(
        self, timdex_record: timdex.TimdexRecord
    ) -> timdex.TimdexRecord:
        """
        Derive additional field values from values already set on a TIMDEX record.

        The record is modified in place and also returned. The following fields are
        set or extended:
            - dates: list[Date], extended with dates derived from publishers
            - locations: list[Location], extended with locations derived from
              publishers and from spatial subjects
            - citation: str, generated when not already set
            - content_type: str, defaults to ["Not specified"] when not already set

        Empty dates and locations lists are stored as None.

        May not be overridden.
        """
        # dates
        dates = timdex_record.dates or []
        dates.extend(self.create_dates_from_publishers(timdex_record))
        timdex_record.dates = dates or None

        # locations
        locations = timdex_record.locations or []
        for derive_locations in (
            self.create_locations_from_publishers,
            self.create_locations_from_spatial_subjects,
        ):
            locations.extend(derive_locations(timdex_record))
        timdex_record.locations = locations or None

        # citation
        citation = timdex_record.citation
        if not citation:
            citation = generate_citation(timdex_record)
        timdex_record.citation = citation

        # content type
        content_type = timdex_record.content_type
        if not content_type:
            content_type = ["Not specified"]
        timdex_record.content_type = content_type

        return timdex_record
526

527
    @final
1✔
528
    @staticmethod
1✔
529
    def create_dates_from_publishers(
1✔
530
        timdex_record: timdex.TimdexRecord,
531
    ) -> Iterator[timdex.Date]:
532
        """Derive Date objects based on data in publishers field.
533

534
        Args:
535
            timdex_record: A TimdexRecord class instance.
536
        """
537
        if timdex_record.publishers:
1✔
538
            for publisher in timdex_record.publishers:
1✔
539
                if publisher.date and validate_date(
1✔
540
                    publisher.date, timdex_record.timdex_record_id
541
                ):
542
                    yield timdex.Date(kind="Publication date", value=publisher.date)
1✔
543

544
    @final
1✔
545
    @staticmethod
1✔
546
    def create_locations_from_publishers(
1✔
547
        timdex_record: timdex.TimdexRecord,
548
    ) -> Iterator[timdex.Location]:
549
        """Derive Location objects based on data in publishers field.
550

551
        Args:
552
            timdex_record: A TimdexRecord class instance.
553
        """
554
        if timdex_record.publishers:
1✔
555
            for publisher in timdex_record.publishers:
1✔
556
                if publisher.location:
1✔
557
                    yield timdex.Location(
1✔
558
                        kind="Place of Publication", value=publisher.location
559
                    )
560

561
    @final
1✔
562
    @staticmethod
1✔
563
    def create_locations_from_spatial_subjects(
1✔
564
        timdex_record: timdex.TimdexRecord,
565
    ) -> Iterator[timdex.Location]:
566
        """Derive Location objects from a TimdexRecord's spatial subjects.
567

568
        Args:
569
           timdex_record: A TimdexRecord class instance.
570
        """
571
        if timdex_record.subjects:
1✔
572
            spatial_subjects = [
1✔
573
                subject
574
                for subject in timdex_record.subjects
575
                if subject.kind == "Dublin Core; Spatial" and subject.value is not None
576
            ]
577

578
            for subject in spatial_subjects:
1✔
579
                for place_name in subject.value:
1✔
580
                    yield timdex.Location(value=place_name, kind="Place Name")
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc