• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 13007270296

28 Jan 2025 09:20AM UTC coverage: 79.211% (-0.3%) from 79.556%
13007270296

Pull #1546

github

web-flow
Merge 477ed98f2 into 49cd166e7
Pull Request #1546: [Draft] Add CollateInstanceByField operator to group data by specific field

1437 of 1809 branches covered (79.44%)

Branch coverage included in aggregate %.

9102 of 11496 relevant lines covered (79.18%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.7
src/unitxt/struct_data_operators.py
1
"""This section describes unitxt operators for structured data.
2

3
These operators are specialized in handling structured data like tables.
4
For tables, expected input format is:
5

6
.. code-block:: text
7

8
    {
9
        "header": ["col1", "col2"],
10
        "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
11
    }
12

13
For triples, expected input format is:
14

15
.. code-block:: text
16

17
    [[ "subject1", "relation1", "object1" ], [ "subject1", "relation2", "object2"]]
18

19
For key-value pairs, expected input format is:
20

21
.. code-block:: text
22

23
    {"key1": "value1", "key2": value2, "key3": "value3"}
24
"""
25

26
import json
1✔
27
import random
1✔
28
from abc import ABC, abstractmethod
1✔
29
from typing import (
1✔
30
    Any,
31
    Dict,
32
    List,
33
    Optional,
34
)
35

36
import pandas as pd
1✔
37

38
from .augmentors import TypeDependentAugmentor
1✔
39
from .dict_utils import dict_get
1✔
40
from .operators import FieldOperator, InstanceOperator
1✔
41
from .random_utils import new_random_generator
1✔
42
from .serializers import ImageSerializer, TableSerializer
1✔
43
from .types import Table
1✔
44
from .utils import recursive_copy
1✔
45

46

47
def shuffle_columns(table: Table, seed=0) -> Table:
    """Permute the columns of ``table`` in place and return the same table.

    One permutation of the column positions is drawn and applied to the
    header and to every row. The random generator is built by
    ``new_random_generator`` from the table together with ``seed``.

    Args:
        table: table dict; missing "header"/"rows" keys are treated as empty.
        seed: value passed along with the table to ``new_random_generator``.

    Returns:
        The input ``table`` object with "header" and "rows" replaced.
    """
    column_names = table.get("header", [])
    body = table.get("rows", [])

    # Draw a single permutation of column positions and reuse it everywhere.
    order = list(range(len(column_names)))
    rng = new_random_generator({"table": table, "seed": seed})
    rng.shuffle(order)

    table["header"] = [column_names[pos] for pos in order]
    table["rows"] = [[record[pos] for pos in order] for record in body]
    return table
1✔
64

65

66
def shuffle_rows(table: Table, seed=0) -> Table:
    """Shuffle the rows of ``table`` in place and return the same table.

    Args:
        table: table dict; a missing "rows" key is treated as an empty list.
        seed: value passed along with the table to ``new_random_generator``.

    Returns:
        The input ``table`` object with its "rows" entry set to the shuffled list.
    """
    row_list = table.get("rows", [])
    rng = new_random_generator({"table": table, "seed": seed})
    # The list object is shuffled in place, then stored back under "rows".
    rng.shuffle(row_list)
    table["rows"] = row_list
    return table
1✔
75

76

77
class SerializeTable(ABC, TableSerializer):
    """Abstract base for serializers that flatten a table into a string.

    ``serialize`` works on a copy of the input table, optionally shuffles its
    columns and/or rows, and then delegates to ``serialize_table``, which each
    concrete serializer must implement.
    """

    seed: int = 0
    shuffle_rows: bool = False
    shuffle_columns: bool = False

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        """Copy ``value``, apply the enabled shuffles and serialize the result."""
        # Work on a copy so the shuffles never touch the caller's table.
        table_copy = recursive_copy(value)

        if self.shuffle_columns:
            table_copy = shuffle_columns(table=table_copy, seed=self.seed)
        if self.shuffle_rows:
            table_copy = shuffle_rows(table=table_copy, seed=self.seed)

        return self.serialize_table(table_copy)

    @abstractmethod
    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table dict; must be implemented by subclasses."""

    def process_header(self, header: List):
        """Hook for serializing the table header; does nothing by default."""
        return None

    def process_row(self, row: List, row_index: int):
        """Hook for serializing one table row; does nothing by default."""
        return None
×
109

110

111
# Concrete classes implementing table serializers
112
class SerializeTableAsIndexedRowMajor(SerializeTable):
    """Indexed Row Major Table Serializer.

    Commonly used row major serialization format.
    Format:  col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
    """

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table given in the prescribed {"header", "rows"} format.

        Args:
            table_content: dict with a non-empty "header" list and a non-empty
                "rows" list of lists.

        Returns:
            The header segment followed by one segment per row (rows are
            numbered from 1), separated by single spaces.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        segments = [self.process_header(header)]
        segments.extend(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )
        return " ".join(segments).strip()

    def process_header(self, header: List):
        """Return the column names separated by ' | ' and prefixed with 'col : '."""
        # Column names are stringified so non-string headers (e.g. years given
        # as ints) do not make str.join raise TypeError.
        return "col : " + " | ".join(str(col) for col in header)

    def process_row(self, row: List, row_index: int):
        """Return the cell values separated by ' | ' and prefixed with 'row <index> : '."""
        cells = " | ".join(str(value) for value in row)
        return f"row {row_index} : {cells}"
1✔
151

152

153
class SerializeTableAsMarkdown(SerializeTable):
    """Markdown Table Serializer.

    Markdown table format is used in GitHub code primarily.
    Format:

    .. code-block:: text

        |col1|col2|col3|
        |---|---|---|
        |A|4|1|
        |I|2|1|
        ...

    """

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table given in the prescribed {"header", "rows"} format.

        Args:
            table_content: dict with a non-empty "header" list and a non-empty
                "rows" list of lists.

        Returns:
            The markdown header lines followed by one line per row, with
            surrounding whitespace stripped.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        lines = [self.process_header(header)]
        lines.extend(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )
        return "".join(lines).strip()

    def process_header(self, header: List):
        """Return the column-name line followed by the '---' separator line."""
        # Column names are stringified (as row cells already are) so that
        # non-string headers do not make str.join raise TypeError.
        names_line = "|{}|\n".format("|".join(str(col) for col in header))
        separator_line = "|{}|\n".format("|".join(["---"] * len(header)))
        return names_line + separator_line

    def process_row(self, row: List, row_index: int):
        """Return one markdown line with the stringified cells of ``row``."""
        return "|{}|\n".format("|".join(str(cell) for cell in row))
1✔
199

200

201
class SerializeTableAsDFLoader(SerializeTable):
    """DFLoader Table Serializer.

    Pandas dataframe based code snippet format serializer.
    Format(Sample):

    .. code-block:: python

        pd.DataFrame({
            "name" : ["Alex", "Diana", "Donald"],
            "age" : [26, 34, 39]
        },
        index=[0,1,2])
    """

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table in the prescribed format as a ``pd.DataFrame(...)`` snippet.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Repeated column names get a "_<k>" suffix, where k is the number of
        # earlier occurrences; the first occurrence keeps its plain name.
        unique_header = []
        for position, col in enumerate(header):
            earlier = header[:position].count(col)
            unique_header.append(f"{col}_{earlier}" if earlier > 0 else col)

        # Go through a DataFrame to obtain a column -> list-of-values mapping.
        frame = pd.DataFrame(rows, columns=unique_header)
        columns_as_lists = frame.to_dict(orient="list")

        # json.dumps yields "{...}"; the outer braces are dropped because the
        # snippet supplies its own.
        inner_json = json.dumps(columns_as_lists)[1:-1]
        index_repr = str(list(range(len(rows))))

        return "".join(
            ["pd.DataFrame({\n", inner_json, "},\nindex=", index_repr, ")"]
        )
244

245

246
class SerializeTableAsJson(SerializeTable):
    """JSON Table Serializer.

    Json format based serializer.
    Format(Sample):

    .. code-block:: json

        {
            "0":{"name":"Alex","age":26},
            "1":{"name":"Diana","age":34},
            "2":{"name":"Donald","age":39}
        }
    """

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table in the prescribed format as a JSON object keyed by row index.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # One {column name: cell} mapping per row, keyed by the 0-based row index.
        records = {
            row_number: {
                header[col_number]: cell for col_number, cell in enumerate(row)
            }
            for row_number, row in enumerate(rows)
        }
        return json.dumps(records)
1✔
277

278

279
class SerializeTableAsHTML(SerializeTable):
    """HTML Table Serializer.

    HTML table format used for rendering tables in web pages.
    Format(Sample):

    .. code-block:: html

        <table>
            <thead>
                <tr><th>name</th><th>age</th><th>sex</th></tr>
            </thead>
            <tbody>
                <tr><td>Alice</td><td>26</td><td>F</td></tr>
                <tr><td>Raj</td><td>34</td><td>M</td></tr>
            </tbody>
        </table>
    """

    def serialize_table(self, table_content: Dict) -> str:
        """Serialize a table in the prescribed format as an HTML ``<table>`` element.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        sections = [
            "<table>",
            self.process_header(header),
            self.process_rows(rows),
            "</table>",
        ]
        return "\n".join(sections).strip()

    def process_header(self, header: List) -> str:
        """Return the ``<thead>`` section with one ``<th>`` per column name."""
        # Values are interpolated as-is; no HTML escaping is applied.
        header_cells = "".join(f"<th>{col}</th>" for col in header)
        return "  <thead>\n    <tr>" + header_cells + "</tr>\n  </thead>"

    def process_rows(self, rows: List[List]) -> str:
        """Return the ``<tbody>`` section with one ``<tr>`` line per row."""
        row_lines = [
            "\n    <tr>" + "".join(f"<td>{cell}</td>" for cell in row) + "</tr>"
            for row in rows
        ]
        return "  <tbody>" + "".join(row_lines) + "\n  </tbody>"
1✔
333

334

335
class SerializeTableAsConcatenation(SerializeTable):
    """Concat Serializer.

    Concat all table content to one string of header and rows.
    Format(Sample):
    name age Alex 26 Diana 34
    """

    def serialize_table(self, table_content: Dict) -> str:
        """Join the header and every row cell into one space-separated string.

        Raises:
            KeyError: if "header" or "rows" is absent from ``table_content``.
            AssertionError: if the header or the rows are empty.
        """
        header = table_content["header"]
        rows = table_content["rows"]

        assert header and rows, "Incorrect input table format"

        # The header comes first, then each row in order, all space separated.
        pieces = [" ".join(str(cell) for cell in header)]
        pieces.extend(" ".join(str(cell) for cell in row) for row in rows)
        return " ".join(pieces).strip()
×
359

360

361
class SerializeTableAsImage(SerializeTable):
    """Render a table as a PNG image with matplotlib and serialize it via ``ImageSerializer``."""

    _requirements_list = ["matplotlib", "pillow"]

    def serialize_table(self, table_content: Dict) -> str:
        """Not supported: this serializer overrides ``serialize`` instead."""
        raise NotImplementedError()

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        """Copy ``value``, apply the enabled shuffles, draw it and serialize the image.

        Raises:
            AssertionError: if the header or the rows are missing or empty.
        """
        import io

        import matplotlib.pyplot as plt
        import pandas as pd
        from PIL import Image

        # Work on a copy so the shuffles never touch the caller's table.
        table_copy = recursive_copy(value)
        if self.shuffle_columns:
            table_copy = shuffle_columns(table=table_copy, seed=self.seed)
        if self.shuffle_rows:
            table_copy = shuffle_rows(table=table_copy, seed=self.seed)

        header = table_copy.get("header", [])
        rows = table_copy.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Repeated column names get a "_<k>" suffix, where k is the number of
        # earlier occurrences; the first occurrence keeps its plain name.
        renamed = []
        for position, col in enumerate(header):
            earlier = header[:position].count(col)
            renamed.append(f"{col}_{earlier}" if earlier > 0 else col)
        header = renamed

        frame = pd.DataFrame(rows, columns=header)

        # Second pass: any label pandas still reports as duplicated gets its
        # position appended.
        still_duplicated = frame.columns.duplicated()
        frame.columns = [
            f"{col}_{position}" if still_duplicated[position] else col
            for position, col in enumerate(frame.columns)
        ]

        # Draw the table on an axis-less figure sized from the table dimensions.
        plt.rcParams["font.family"] = "Serif"
        fig, ax = plt.subplots(figsize=(len(header) * 1.5, len(rows) * 0.5))
        ax.axis("off")

        drawn = pd.plotting.table(ax, frame, loc="center", cellLoc="center")
        drawn.auto_set_column_width(col=range(len(frame.columns)))
        drawn.scale(1.5, 1.5)

        # Render to an in-memory PNG, then close the figure to free memory.
        png_buffer = io.BytesIO()
        plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
        plt.close(fig)
        png_buffer.seek(0)

        image = Image.open(png_buffer)
        return ImageSerializer().serialize({"image": image, "format": "png"}, instance)
×
420

421

422
# truncate cell value to maximum allowed length
423
def truncate_cell(cell_value, max_len):
    """Return ``cell_value`` cut to ``max_len`` characters, or None if left unchanged.

    Args:
        cell_value: the cell content. Only strings are ever truncated.
        max_len: maximum number of characters to keep.

    Returns:
        The first ``max_len`` characters when ``cell_value`` is a non-blank
        string longer than ``max_len``; otherwise None, which signals that the
        caller should keep the original value.
    """
    # None, numbers and any other non-string value (lists, dicts, ...) are
    # never truncated. Previously only None/int/float were guarded, so other
    # types crashed on ``.strip()``.
    if not isinstance(cell_value, str):
        return None

    # Blank cells are left untouched, whatever their length.
    if cell_value.strip() == "":
        return None

    if len(cell_value) > max_len:
        return cell_value[:max_len]

    return None
1✔
437

438

439
class TruncateTableCells(InstanceOperator):
    """Limit the maximum length of cell values in a table to reduce the overall length.

    Args:
        max_length (int) - maximum allowed length of cell values
        table (str) - name of the instance field holding the table
        text_output (str) - optional name of the instance field holding the answers

    For tasks that produce a cell value as answer, truncating a cell value should be replicated
    with truncating the corresponding answer as well: answers equal to a truncated cell
    are replaced by the truncated text.
    """

    max_length: int = 15
    table: str = None
    text_output: Optional[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Truncate the table cells (and matching answers) of ``instance`` in place."""
        table = dict_get(instance, self.table)
        answers = (
            dict_get(instance, self.text_output)
            if self.text_output is not None
            else []
        )
        self.truncate_table(table_content=table, answers=answers)
        return instance

    def truncate_table(self, table_content: Dict, answers: Optional[List]):
        """Truncate over-long cells in place and mirror the change in ``answers``."""
        # Maps each original cell text to its truncated replacement.
        replaced = {}

        for row in table_content.get("rows", []):
            for position, cell in enumerate(row):
                shortened = truncate_cell(cell, self.max_length)
                if shortened is None:
                    # None means the cell was left as it is.
                    continue
                replaced[cell] = shortened
                row[position] = shortened

        if answers is None:
            return

        # Answers that equal a truncated cell are swapped for the truncated text.
        for position, answer in enumerate(answers):
            answers[position] = replaced.get(answer, answer)
×
482

483

484
class TruncateTableRows(FieldOperator):
    """Limits table rows to specified limit by removing excess rows via random selection.

    Args:
        rows_to_keep (int): number of rows to keep.
    """

    rows_to_keep: int = 10

    def process_value(self, table: Any) -> Any:
        """Return ``table`` with at most ``rows_to_keep`` rows."""
        return self.truncate_table_rows(table_content=table)

    def truncate_table_rows(self, table_content: Dict):
        """Randomly drop rows until only ``rows_to_keep`` remain.

        The surviving rows keep their original relative order. The table dict
        is modified in place and also returned. Rows are picked with the
        module-level ``random.sample``.

        Args:
            table_content: table dict; a missing "rows" key is treated as empty.

        Returns:
            The same ``table_content`` object.
        """
        rows = table_content.get("rows", [])
        num_rows = len(rows)

        # Nothing to do when the table is already within the limit.
        if num_rows <= self.rows_to_keep:
            return table_content

        rows_to_delete = num_rows - self.rows_to_keep

        # A set gives O(1) membership tests in the filter below; the sampling
        # call itself is unchanged, so the selected indices are the same.
        deleted_rows_indices = set(random.sample(range(num_rows), rows_to_delete))

        table_content["rows"] = [
            row for i, row in enumerate(rows) if i not in deleted_rows_indices
        ]

        return table_content
1✔
518

519

520
class GetNumOfTableCells(FieldOperator):
    """Get the number of cells in the given table."""

    def process_value(self, table: Any) -> Any:
        """Return number of rows times number of header columns.

        Missing "rows" or "header" keys are treated as empty lists, matching
        the other table helpers in this module, instead of raising TypeError
        on ``len(None)``.
        """
        num_of_rows = len(table.get("rows", []))
        num_of_cols = len(table.get("header", []))
        return num_of_rows * num_of_cols
×
527

528

529
class SerializeTableRowAsText(InstanceOperator):
    """Serializes a table row as text of the form "<field> is <value>, ...".

    Args:
        fields (str) - list of fields to be included in serialization.
        to_field (str) - serialized text field name.
        max_cell_length (int) - limits cell length to be considered, optional.
    """

    # NOTE(review): annotated as str, but process() iterates it as a sequence
    # of field names - confirm the intended type.
    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Write the serialized row into ``instance[self.to_field]`` and return the instance."""
        fragments = []
        for field in self.fields:
            cell = dict_get(instance, field)
            if self.max_cell_length is not None:
                shortened = truncate_cell(cell, self.max_cell_length)
                # None from truncate_cell means "keep the value as it is".
                cell = cell if shortened is None else shortened
            fragments.append(field + " is " + str(cell) + ", ")

        # Each fragment ends with ", ", so the result keeps a trailing separator.
        instance[self.to_field] = "".join(fragments)
        return instance
1✔
557

558

559
class SerializeTableRowAsList(InstanceOperator):
    """Serializes a table row as a list of the form "<field>: <value>, ...".

    Args:
        fields (str) - list of fields to be included in serialization.
        to_field (str) - serialized text field name.
        max_cell_length (int) - limits cell length to be considered, optional.
    """

    # NOTE(review): annotated as str, but process() iterates it as a sequence
    # of field names - confirm the intended type.
    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Write the serialized row into ``instance[self.to_field]`` and return the instance."""
        fragments = []
        for field in self.fields:
            cell = dict_get(instance, field)
            if self.max_cell_length is not None:
                shortened = truncate_cell(cell, self.max_cell_length)
                # None from truncate_cell means "keep the value as it is".
                cell = cell if shortened is None else shortened
            fragments.append(field + ": " + str(cell) + ", ")

        # Each fragment ends with ", ", so the result keeps a trailing separator.
        instance[self.to_field] = "".join(fragments)
        return instance
1✔
587

588

589
class SerializeTriples(FieldOperator):
    """Serializes triples into a flat sequence.

    Each triple becomes "subject : relation : object" with the relation
    lower-cased; triples are separated by " | ".

    Sample input in expected format:
    [[ "First Clearing", "LOCATION", "On NYS 52 1 Mi. Youngsville" ], [ "On NYS 52 1 Mi. Youngsville", "CITY_OR_TOWN", "Callicoon, New York"]]

    Sample output:
    First Clearing : location : On NYS 52 1 Mi. Youngsville | On NYS 52 1 Mi. Youngsville : city_or_town : Callicoon, New York

    """

    def process_value(self, tripleset: Any) -> Any:
        """Return the flat string form of ``tripleset``."""
        return self.serialize_triples(tripleset)

    def serialize_triples(self, tripleset) -> str:
        """Join all triples of ``tripleset`` into one ' | '-separated string."""
        rendered = []
        for subj, rel, obj in tripleset:
            rendered.append(f"{subj} : {rel.lower()} : {obj}")
        return " | ".join(rendered)
607

608

609
class SerializeKeyValPairs(FieldOperator):
    """Serializes key, value pairs into a flat sequence.

    Sample input in expected format: {"name": "Alex", "age": 31, "sex": "M"}
    Sample output: name is Alex, age is 31, sex is M
    """

    def process_value(self, kvpairs: Any) -> Any:
        """Return the flat string form of ``kvpairs``."""
        return self.serialize_kvpairs(kvpairs)

    def serialize_kvpairs(self, kvpairs) -> str:
        """Render every pair as "<key> is <value>" and join them with ", "."""
        # Joining avoids the trailing separator; an empty mapping yields "".
        return ", ".join(f"{key} is {value}" for key, value in kvpairs.items())
1✔
626

627

628
class ListToKeyValPairs(InstanceOperator):
    """Maps list of keys and values into key:value pairs.

    ``fields[0]`` names the instance field with the keys and ``fields[1]`` the
    field with the values; the resulting dict is stored in ``to_field``.

    Sample input in expected format: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
    Sample output: {"name": "Alex", "age": 31, "sex": "M"}
    """

    fields: List[str]
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pair keys with values positionally and store the dict on the instance."""
        keys_source, values_source = self.fields[0], self.fields[1]
        keylist = dict_get(instance, keys_source)
        valuelist = dict_get(instance, values_source)

        # zip stops at the shorter list; a repeated key keeps its last value.
        instance[self.to_field] = dict(zip(keylist, valuelist))
        return instance
1✔
651

652

653
class ConvertTableColNamesToSequential(FieldOperator):
    """Replaces actual table column names with static sequential names like col_0, col_1,...

    .. code-block:: text

        Sample input:
        {
            "header": ["name", "age"],
            "rows": [["Alex", 21], ["Donald", 34]]
        }

        Sample output:
        {
            "header": ["col_0", "col_1"],
            "rows": [["Alex", 21], ["Donald", 34]]
        }
    """

    def process_value(self, table: Any) -> Any:
        """Return a copy of ``table`` whose header holds sequential names."""
        # Copy first so the caller's table keeps its original header.
        table_input = recursive_copy(table)
        return self.replace_header(table_content=table_input)

    def replace_header(self, table_content: Dict) -> Dict:
        """Overwrite the header of ``table_content`` in place with col_0, col_1, ...

        The return annotation used to be ``str`` although the table dict
        itself is returned.

        Args:
            table_content: table dict with a non-empty "header" list.

        Returns:
            The same ``table_content`` dict with its "header" replaced.

        Raises:
            AssertionError: if the header is missing or empty.
        """
        header = table_content.get("header", [])

        assert header, "Input table missing header"

        table_content["header"] = [f"col_{i}" for i in range(len(header))]
        return table_content
1✔
686

687

688
class ShuffleTableRows(TypeDependentAugmentor):
    """Shuffles the input table rows randomly.

    The shuffle is applied to a copy of the table, using ``seed``.

    .. code-block:: text

        Sample Input:
        {
            "header": ["name", "age"],
            "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
        }

        Sample Output:
        {
            "header": ["name", "age"],
            "rows": [["Donald", 39], ["Raj", 34], ["Alex", 26]],
        }
    """

    augmented_type = Table
    seed = 0

    def process_value(self, table: Any) -> Any:
        """Return a row-shuffled copy of ``table``."""
        # shuffle_rows works in place, so hand it a private copy.
        return shuffle_rows(recursive_copy(table), self.seed)
1✔
712

713

714
class ShuffleTableColumns(TypeDependentAugmentor):
    """Shuffles the table columns randomly.

    The shuffle is applied to a copy of the table, using ``seed``.

    .. code-block:: text

        Sample Input:
            {
                "header": ["name", "age"],
                "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
            }

        Sample Output:
            {
                "header": ["age", "name"],
                "rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
            }
    """

    augmented_type = Table
    seed = 0

    def process_value(self, table: Any) -> Any:
        """Return a column-shuffled copy of ``table``."""
        # shuffle_columns works in place, so hand it a private copy.
        return shuffle_columns(recursive_copy(table), self.seed)
1✔
738

739

740
class LoadJson(FieldOperator):
    """Parse a JSON string field into the corresponding Python object.

    When ``allow_failure`` is set, a ``json.JSONDecodeError`` is swallowed and
    ``failure_value`` is returned instead; otherwise decoding errors propagate.
    """

    failure_value: Any = None
    allow_failure: bool = False

    def process_value(self, value: str) -> Any:
        """Return the decoded JSON value, or ``failure_value`` on a tolerated failure."""
        if not self.allow_failure:
            return json.loads(value, strict=False)

        # NOTE(review): this branch parses with the default strict mode while
        # the branch above passes strict=False - confirm the asymmetry is intended.
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return self.failure_value
1✔
752

753

754
class DumpJson(FieldOperator):
    """Serialize a field value to its JSON string form with ``json.dumps``."""

    def process_value(self, value: str) -> str:
        """Return ``value`` encoded as a JSON string."""
        encoded = json.dumps(value)
        return encoded
1✔
757

758

759
class MapHTMLTableToJSON(FieldOperator):
    """Converts HTML table format to the basic one (JSON).

    JSON format:

    .. code-block:: json

        {
            "header": ["col1", "col2"],
            "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
        }
    """

    _requirements_list = ["bs4"]

    def process_value(self, table: Any) -> Any:
        """Return the {"header", "rows"} form of the HTML ``table`` string."""
        return self.convert_to_json(table_content=table)

    def convert_to_json(self, table_content: str) -> Dict:
        """Parse an HTML table and collect its header and row cell texts.

        The header is read from the ``<th>`` cells of the first ``<thead>``
        and the rows from the ``<td>`` cells of each ``<tr>`` in the first
        ``<tbody>``.
        """
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(table_content, "html.parser")

        # NOTE(review): both sections are dereferenced without a None check,
        # so the input presumably always has <thead> and <tbody> - confirm.
        header = [th.get_text() for th in soup.find("thead").find_all("th")]
        rows = [
            [td.get_text() for td in tr.find_all("td")]
            for tr in soup.find("tbody").find_all("tr")
        ]

        return {"header": header, "rows": rows}
1✔
799

800

801
class MapTableListsToStdTableJSON(FieldOperator):
    """Converts lists table format to the basic one (JSON).

    The first element of the input becomes the header and the remaining
    elements become the rows.

    JSON format:

    .. code-block:: json

        {
            "header": ["col1", "col2"],
            "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
        }
    """

    def process_value(self, table: Any) -> Any:
        """Return the {"header", "rows"} form of ``table``."""
        return self.map_tablelists_to_stdtablejson_util(table_content=table)

    def map_tablelists_to_stdtablejson_util(self, table_content: str) -> Dict:
        """Split ``table_content`` into its first element (header) and the rest (rows)."""
        # NOTE(review): the parameter is annotated as str, but it is indexed
        # like a list of rows - confirm the intended type.
        first_entry = table_content[0]
        remaining_entries = table_content[1:]
        return {"header": first_entry, "rows": remaining_entries}
×
819

820

821
class ConstructTableFromRowsCols(InstanceOperator):
    """Maps column and row field into single table field encompassing both header and rows.

    field[0] = header string as List
    field[1] = rows string as List[List]
    field[2] = table caption string(optional)
    """

    fields: List[str]
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the table dict from the source fields and store it in ``to_field``."""
        import ast

        header_text = dict_get(instance, self.fields[0])
        rows_text = dict_get(instance, self.fields[1])

        # The caption is optional and only read when a third field name is given.
        caption = instance[self.fields[2]] if len(self.fields) >= 3 else None

        # Header and rows are parsed from their literal text form.
        table = {
            "header": ast.literal_eval(header_text),
            "rows": ast.literal_eval(rows_text),
        }
        if caption is not None:
            table["caption"] = caption

        instance[self.to_field] = table
        return instance
×
856

857

858
class TransposeTable(TypeDependentAugmentor):
    """Transpose a table.

    .. code-block:: text

        Sample Input:
            {
                "header": ["name", "age", "sex"],
                "rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
            }

        Sample Output:
            {
                "header": [" ", "0", "1", "2"],
                "rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
            }

    """

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        """Return the transposed form of ``table`` as a new dict."""
        return self.transpose_table(table)

    def transpose_table(self, table: Dict) -> Dict:
        """Turn every column into a row that starts with the column's name.

        The new header is a blank label followed by the original row indices
        as strings.
        """
        header = table["header"]
        rows = table["rows"]

        transposed_header = [" "] + [str(row_number) for row_number in range(len(rows))]

        transposed_rows = []
        for col_index, col_name in enumerate(header):
            column_values = [row[col_index] for row in rows]
            transposed_rows.append([col_name] + column_values)

        return {"header": transposed_header, "rows": transposed_rows}
1✔
894

895

896
class DuplicateTableRows(TypeDependentAugmentor):
    """Repeat selected rows of a table a fixed number of times.

    Args:
        row_indices (List[int]): positions of the rows to be repeated

        times(int): how many times each selected row appears in the output
    """

    augmented_type = Table

    row_indices: List[int] = []
    times: int = 1

    def process_value(self, table: Any) -> Any:
        """Return a table dict with the selected rows repeated in place.

        Rows whose position is not listed in ``row_indices`` are emitted
        exactly once; the header is passed through untouched.
        """
        column_names = table["header"]
        source_rows = table["rows"]

        expanded = []
        for position, record in enumerate(source_rows):
            # Selected rows appear ``times`` times, all others once.
            copies = self.times if position in self.row_indices else 1
            expanded.extend([record] * copies)

        return {"header": column_names, "rows": expanded}
1✔
927

928

929
class DuplicateTableColumns(TypeDependentAugmentor):
    """Repeat selected columns of a table a fixed number of times.

    Args:
        column_indices (List[int]): positions of the columns to be repeated

        times(int): how many times each selected column appears in the output
    """

    augmented_type = Table

    column_indices: List[int] = []
    times: int = 1

    def process_value(self, table: Any) -> Any:
        """Return a table dict with the selected columns repeated.

        The same expansion is applied to the header and to every row, so a
        repeated column keeps its name next to each of its copies.
        """
        column_names = table["header"]
        source_rows = table["rows"]

        def widen(cells):
            # Cells at a selected position appear ``times`` times, others once.
            widened = []
            for position, cell in enumerate(cells):
                copies = self.times if position in self.column_indices else 1
                widened.extend([cell] * copies)
            return widened

        widened_header = widen(column_names)
        widened_rows = [widen(record) for record in source_rows]

        return {"header": widened_header, "rows": widened_rows}
1✔
969

970

971
class InsertEmptyTableRows(TypeDependentAugmentor):
    """Inserts empty rows in a table randomly for the given number of times.

    An empty row is a list of empty strings with one cell per header column.
    Insert positions are drawn from the module-level ``random`` generator.

    Args:
        times(int) - how many times to insert
    """

    augmented_type = Table

    times: int = 0

    def process_value(self, table: Any) -> Any:
        """Return a table dict with ``times`` empty rows added at random positions.

        The input table is left untouched: rows are inserted into a shallow
        copy of ``table["rows"]`` rather than into the caller's list.
        """
        header = table["header"]
        # Copy the outer list so that ``insert`` below does not grow the
        # caller's ``rows`` list as a side effect.
        rows = list(table["rows"])

        for _ in range(self.times):
            # A fresh list per insertion keeps the empty rows independent.
            empty_row = [""] * len(header)
            # ``randint`` is inclusive on both ends, so appending after the
            # last row is a possible outcome.
            insert_pos = random.randint(0, len(rows))
            rows.insert(insert_pos, empty_row)

        return {"header": header, "rows": rows}
1✔
999

1000

1001
class MaskColumnsNames(TypeDependentAugmentor):
    """Replace every column name with a positional dummy: "Col1", "Col2", ..."""

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        """Return a table dict with a masked header and the original rows."""
        column_count = len(table["header"])
        # Dummy names are numbered from 1, one per original column.
        dummy_names = [f"Col{number}" for number in range(1, column_count + 1)]

        return {"header": dummy_names, "rows": table["rows"]}
×
1010

1011

1012
class ShuffleColumnsNames(TypeDependentAugmentor):
    """Shuffle table columns names to be displayed in random order.

    Only the header is permuted; the rows are passed through unchanged, so
    after this augmentation a name no longer sits above its original column.
    The permutation is drawn from the module-level ``random`` generator.
    """

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        """Return a table dict with a shuffled header and the original rows.

        The header is copied before shuffling because ``random.shuffle``
        works in place; without the copy the caller's header list would be
        reordered as a side effect.
        """
        shuffled_header = list(table["header"])
        random.shuffle(shuffled_header)

        return {"header": shuffled_header, "rows": table["rows"]}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc