• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16470706949

23 Jul 2025 12:31PM UTC coverage: 81.122% (-0.1%) from 81.222%
16470706949

Pull #1861

github

web-flow
Merge c48d10af5 into 83063f920
Pull Request #1861: Fix compatibility with datasets 4.0

1585 of 1965 branches covered (80.66%)

Branch coverage included in aggregate %.

10735 of 13222 relevant lines covered (81.19%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.54
src/unitxt/struct_data_operators.py
1
"""This section describes unitxt operators for structured data.
2

3
These operators are specialized in handling structured data like tables.
4
For tables, expected input format is:
5

6
.. code-block:: text
7

8
    {
9
        "header": ["col1", "col2"],
10
        "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
11
    }
12

13
For triples, expected input format is:
14

15
.. code-block:: text
16

17
    [[ "subject1", "relation1", "object1" ], [ "subject1", "relation2", "object2"]]
18

19
For key-value pairs, expected input format is:
20

21
.. code-block:: text
22

23
    {"key1": "value1", "key2": "value2", "key3": "value3"}
24
"""
25

26
import ast
1✔
27
import json
1✔
28
import random
1✔
29
from abc import ABC, abstractmethod
1✔
30
from typing import (
1✔
31
    Any,
32
    Dict,
33
    List,
34
    Optional,
35
    Tuple,
36
)
37

38
import pandas as pd
1✔
39

40
from .augmentors import TypeDependentAugmentor
1✔
41
from .dict_utils import dict_get
1✔
42
from .error_utils import UnitxtWarning
1✔
43
from .operators import FieldOperator, InstanceOperator
1✔
44
from .random_utils import new_random_generator
1✔
45
from .serializers import ImageSerializer, TableSerializer
1✔
46
from .type_utils import isoftype
1✔
47
from .types import Table, ToolCall
1✔
48
from .utils import recursive_copy
1✔
49

50

51
def shuffle_columns(table: Table, seed=0) -> Table:
    """Randomly permute the columns of ``table`` in place and return it.

    A single permutation is drawn and applied to the header and to every row,
    so each cell stays under its own column name. The generator is created by
    ``new_random_generator`` from both the table contents and ``seed``.

    Args:
        table: Table dict; missing "header"/"rows" entries are treated as empty.
        seed: Extra value mixed into the random generator's seed.

    Returns:
        The same dict object, with its "header" and "rows" entries replaced.
    """
    column_names = table.get("header", [])
    table_rows = table.get("rows", [])

    # Draw the permutation before touching ``table``: the generator is seeded
    # from the (still unmodified) table contents.
    order = list(range(len(column_names)))
    rng = new_random_generator({"table": table, "seed": seed})
    rng.shuffle(order)

    # Build both permuted structures before writing either one back.
    permuted_header = [column_names[pos] for pos in order]
    permuted_rows = [[cells[pos] for pos in order] for cells in table_rows]

    table["header"] = permuted_header
    table["rows"] = permuted_rows
    return table
68

69

70
def shuffle_rows(table: Table, seed=0) -> Table:
    """Randomly reorder the rows of ``table`` in place and return it.

    The generator is created by ``new_random_generator`` from the table
    contents and ``seed``. The existing rows list is shuffled in place and
    stored back under "rows".

    Args:
        table: Table dict; a missing "rows" entry is treated as empty.
        seed: Extra value mixed into the random generator's seed.

    Returns:
        The same dict object.
    """
    rng = new_random_generator({"table": table, "seed": seed})
    table_rows = table.get("rows", [])
    rng.shuffle(table_rows)
    table["rows"] = table_rows
    return table
79

80

81
class SerializeTable(ABC, TableSerializer):
    """Abstract base for serializers that flatten a table into one string.

    Concrete subclasses implement :meth:`serialize_table`. :meth:`serialize`
    first makes a copy of the incoming table with ``recursive_copy``, then
    optionally shuffles its columns and/or rows using ``seed``, and finally
    hands the result to :meth:`serialize_table`.
    """

    seed: int = 0  # seed passed to the optional column/row shuffles
    shuffle_rows: bool = False  # when True, shuffle row order before serializing
    shuffle_columns: bool = False  # when True, shuffle column order before serializing

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        # Inside a method body, ``shuffle_columns``/``shuffle_rows`` resolve to
        # the module-level functions, not to the boolean fields above.
        table = recursive_copy(value)
        if self.shuffle_columns:
            table = shuffle_columns(table=table, seed=self.seed)
        if self.shuffle_rows:
            table = shuffle_rows(table=table, seed=self.seed)
        return self.serialize_table(table)

    @abstractmethod
    def serialize_table(self, table_content: Dict) -> str:
        """Turn a table dict ({"header": [...], "rows": [[...], ...]}) into text."""

    def process_header(self, header: List):
        """Hook for rendering the header; the base implementation returns None."""

    def process_row(self, row: List, row_index: int):
        """Hook for rendering one row; the base implementation returns None."""
113

114

115
# Concrete classes implementing table serializers
116
class SerializeTableAsIndexedRowMajor(SerializeTable):
    """Indexed Row Major Table Serializer.

    Commonly used row major serialization format. The header is prefixed with
    ``col :`` and each row with ``row <i> :`` (1-based). Cells are separated by
    `` | ``, and the header and row segments are separated by single spaces.
    Non-string column names and cell values are rendered with ``str()``.

    Format:  col : col1 | col2 | col3 row 1 : val1 | val2 | val3 row 2 : val1 | ...
    """

    # table_content must be in the prescribed input format
    def serialize_table(self, table_content: Dict) -> str:
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        segments = [self.process_header(header)]
        segments.extend(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )
        # strip() trims trailing whitespace left by e.g. an empty last cell.
        return " ".join(segments).strip()

    # serialize header into "col : name1 | name2 | ..."
    def process_header(self, header: List):
        # str() so numeric column names don't make join() raise TypeError.
        return "col : " + " | ".join(str(col) for col in header)

    # serialize a table row into "row <i> : val1 | val2 | ..."
    def process_row(self, row: List, row_index: int):
        return f"row {row_index} : " + " | ".join(str(value) for value in row)
155

156

157
class SerializeTableAsMarkdown(SerializeTable):
    """Markdown Table Serializer.

    Renders the table as a GitHub-flavoured Markdown table. Non-string column
    names and cell values are rendered with ``str()``.

    Format:

    .. code-block:: text

        |col1|col2|col3|
        |---|---|---|
        |A|4|1|
        |I|2|1|
        ...

    """

    # table_content must be in the prescribed input format.
    def serialize_table(self, table_content: Dict) -> str:
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Each rendered line already ends with a newline; strip() drops the
        # final one.
        body = "".join(
            self.process_row(row, row_index=i) for i, row in enumerate(rows, start=1)
        )
        return (self.process_header(header) + body).strip()

    # serialize header into the column-name line plus the "---" separator line
    def process_header(self, header: List):
        # str() so numeric column names don't make join() raise TypeError.
        names = [str(col) for col in header]
        header_line = "|{}|\n".format("|".join(names))
        separator_line = "|{}|\n".format("|".join(["---"] * len(names)))
        return header_line + separator_line

    # serialize a table row into a single "|v1|v2|...|" line
    def process_row(self, row: List, row_index: int):
        return "|{}|\n".format("|".join(str(cell) for cell in row))
203

204

205
class SerializeTableAsDFLoader(SerializeTable):
    """DFLoader Table Serializer.

    Pandas dataframe based code snippet format serializer. Duplicate column
    names are made unique: the first occurrence keeps its name and later ones
    get a numeric suffix (``a``, ``a_1``, ``a_2``, ...).

    Format(Sample):

    .. code-block:: python

        pd.DataFrame({
            "name" : ["Alex", "Diana", "Donald"],
            "age" : [26, 34, 39]
        },
        index=[0,1,2])
    """

    # table_content must be in the prescribed input format.
    def serialize_table(self, table_content: Dict) -> str:
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Make column names unique in one pass. If a generated name collides
        # with a name already in use (e.g. header ["a", "a", "a_1"]), bump the
        # suffix until it is unique. Otherwise pandas would hold duplicate
        # columns and to_dict() would warn and omit some of them.
        seen_counts = {}
        used_names = set()
        unique_header = []
        for col in header:
            count = seen_counts.get(col, 0)
            seen_counts[col] = count + 1
            name = col if count == 0 else f"{col}_{count}"
            while name in used_names:
                count += 1
                name = f"{col}_{count}"
            used_names.add(name)
            unique_header.append(name)

        df = pd.DataFrame(rows, columns=unique_header)
        data_dict = df.to_dict(orient="list")

        # json.dumps(...)[1:-1] drops the outer braces so the column mapping
        # can be embedded inside the literal "pd.DataFrame({ ... })" text.
        return (
            "pd.DataFrame({\n"
            + json.dumps(data_dict)[1:-1]
            + "},\nindex="
            + str(list(range(len(rows))))
            + ")"
        )
248

249

250
class SerializeTableAsJson(SerializeTable):
    """JSON Table Serializer.

    Emits a JSON object keyed by the 0-based row position. Each value maps
    column names to that row's cells. Cells are paired with column names by
    position.

    Format(Sample):

    .. code-block:: json

        {
            "0":{"name":"Alex","age":26},
            "1":{"name":"Diana","age":34},
            "2":{"name":"Donald","age":39}
        }
    """

    def serialize_table(self, table_content: Dict) -> str:
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Integer row keys become strings when dumped to JSON.
        records = {
            row_number: {header[col_idx]: cell for col_idx, cell in enumerate(row)}
            for row_number, row in enumerate(rows)
        }
        return json.dumps(records)
281

282

283
class SerializeTableAsHTML(SerializeTable):
    """HTML Table Serializer.

    Produces an HTML ``<table>`` with a ``<thead>`` header row and a
    ``<tbody>`` containing one ``<tr>`` per row. Values are inserted verbatim
    (no HTML escaping).

    Format(Sample):

    .. code-block:: html

        <table>
            <thead>
                <tr><th>name</th><th>age</th><th>sex</th></tr>
            </thead>
            <tbody>
                <tr><td>Alice</td><td>26</td><td>F</td></tr>
                <tr><td>Raj</td><td>34</td><td>M</td></tr>
            </tbody>
        </table>
    """

    def serialize_table(self, table_content: Dict) -> str:
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        sections = [
            "<table>",
            self.process_header(header),
            self.process_rows(rows),
            "</table>",
        ]
        return "\n".join(sections).strip()

    def process_header(self, header: List) -> str:
        """Render the ``<thead>`` section with one ``<th>`` per column name."""
        header_cells = "".join(f"<th>{col}</th>" for col in header)
        return f"  <thead>\n    <tr>{header_cells}</tr>\n  </thead>"

    def process_rows(self, rows: List[List]) -> str:
        """Render the ``<tbody>`` section with one ``<tr>`` line per row."""
        lines = ["  <tbody>"]
        for row in rows:
            data_cells = "".join(f"<td>{cell}</td>" for cell in row)
            lines.append(f"    <tr>{data_cells}</tr>")
        lines.append("  </tbody>")
        return "\n".join(lines)
337

338

339
class SerializeTableAsConcatenation(SerializeTable):
    """Concat Serializer.

    Concatenates all table content (header first, then each row) into one
    space-separated string. Every value is rendered with ``str()``.

    Format(Sample):
    name age Alex 26 Diana 34
    """

    def serialize_table(self, table_content: Dict) -> str:
        # .get() like the sibling serializers, so a missing key reports the
        # shared "Incorrect input table format" assertion instead of KeyError.
        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # One segment for the header and one per row, joined by single spaces.
        segments = [" ".join(str(col) for col in header)]
        segments.extend(" ".join(str(cell) for cell in row) for row in rows)
        return " ".join(segments).strip()
363

364

365
class SerializeTableAsImage(SerializeTable):
    """Render a table as a PNG image and serialize it with ``ImageSerializer``.

    The table is optionally shuffled (columns and/or rows, using ``seed``).
    It is then drawn with ``pandas.plotting.table`` on an axis-less matplotlib
    figure and saved as PNG (150 dpi, tight bounding box). The result is
    passed to ``ImageSerializer``. Duplicate column names get numeric suffixes
    so every column is drawn.
    """

    _requirements_list = ["matplotlib", "pillow"]

    def serialize_table(self, table_content: Dict) -> str:
        raise NotImplementedError()

    def serialize(self, value: Table, instance: Dict[str, Any]) -> str:
        import io

        import matplotlib.pyplot as plt
        from PIL import Image

        table_content = recursive_copy(value)
        if self.shuffle_columns:
            table_content = shuffle_columns(table=table_content, seed=self.seed)

        if self.shuffle_rows:
            table_content = shuffle_rows(table=table_content, seed=self.seed)

        header = table_content.get("header", [])
        rows = table_content.get("rows", [])

        assert header and rows, "Incorrect input table format"

        # Make column names unique in one pass; first occurrence keeps its name,
        # and the suffix is bumped if a generated name is already taken.
        seen_counts = {}
        used_names = set()
        unique_header = []
        for col in header:
            count = seen_counts.get(col, 0)
            seen_counts[col] = count + 1
            name = col if count == 0 else f"{col}_{count}"
            while name in used_names:
                count += 1
                name = f"{col}_{count}"
            used_names.add(name)
            unique_header.append(name)

        df = pd.DataFrame(rows, columns=unique_header)

        buf = io.BytesIO()
        # rc_context scopes the font change to this rendering, instead of
        # permanently mutating the process-wide plt.rcParams.
        with plt.rc_context({"font.family": "Serif"}):
            fig, ax = plt.subplots(figsize=(len(header) * 1.5, len(rows) * 0.5))
            try:
                ax.axis("off")  # Turn off the axes
                table = pd.plotting.table(ax, df, loc="center", cellLoc="center")
                table.auto_set_column_width(col=range(len(df.columns)))
                table.scale(1.5, 1.5)
                # Save this figure explicitly rather than the implicit "current" one.
                fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
            finally:
                plt.close(fig)  # always release the figure, even on failure
        buf.seek(0)

        image = Image.open(buf)
        return ImageSerializer().serialize({"image": image, "format": "png"}, instance)
424

425

426
# truncate cell value to maximum allowed length
427
def truncate_cell(cell_value, max_len):
    """Return ``cell_value`` cut to ``max_len`` characters, or None if not truncated.

    Only non-blank strings longer than ``max_len`` are truncated. For every
    other input the function returns None, meaning "leave the cell as it is".
    That covers None, numbers, any other non-string value, blank strings, and
    strings that already fit.

    Args:
        cell_value: The cell content.
        max_len: Maximum allowed length of a string cell.

    Returns:
        The truncated string, or None when no truncation applies.
    """
    # Non-strings (None, int, float, lists, numpy scalars, ...) are never
    # truncated. Previously only None/int/float were handled and anything else
    # crashed on .strip().
    if not isinstance(cell_value, str):
        return None

    if cell_value.strip() == "":
        return None

    if len(cell_value) > max_len:
        return cell_value[:max_len]

    return None
441

442

443
class TruncateTableCells(InstanceOperator):
    """Limit the maximum length of cell values in a table to reduce the overall length.

    Cells are shortened in place using ``truncate_cell``. When ``text_output``
    is set, entries of that answer list that exactly match an original
    (pre-truncation) cell value are replaced with the truncated value, so the
    expected answer still matches the shortened table.

    Args:
        max_length (int) - maximum allowed length of cell values
        table (str) - field holding the table (read with dict_get)
        text_output (str, optional) - field holding the list of answers to keep in sync
    """

    max_length: int = 15
    table: str = None
    text_output: Optional[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        table = dict_get(instance, self.table)
        answers = (
            [] if self.text_output is None else dict_get(instance, self.text_output)
        )
        # The table (and answers) are modified in place inside ``instance``.
        self.truncate_table(table_content=table, answers=answers)
        return instance

    def truncate_table(self, table_content: Dict, answers: Optional[List]):
        """Truncate every row cell in place and mirror replacements into ``answers``."""
        replacements = {}

        for row in table_content.get("rows", []):
            for col_idx, original in enumerate(row):
                shortened = truncate_cell(original, self.max_length)
                if shortened is None:
                    continue
                replacements[original] = shortened
                row[col_idx] = shortened

        if answers is None:
            return

        for pos, answer in enumerate(answers):
            answers[pos] = replacements.get(answer, answer)
486

487

488
class TruncateTableRows(FieldOperator):
    """Limits table rows to specified limit by removing excess rows via random selection.

    The rows to drop are picked with ``new_random_generator`` seeded from the
    table contents. The selection is therefore reproducible, like the other
    table shuffles in this module, instead of depending on the global
    ``random`` state. Surviving rows keep their original order.

    Args:
        rows_to_keep (int): number of rows to keep.
    """

    rows_to_keep: int = 10

    def process_value(self, table: Any) -> Any:
        return self.truncate_table_rows(table_content=table)

    def truncate_table_rows(self, table_content: Dict):
        """Drop random rows until at most ``rows_to_keep`` remain; mutates and returns the table."""
        rows = table_content.get("rows", [])
        num_rows = len(rows)

        # Nothing to do when the table is already small enough.
        if num_rows <= self.rows_to_keep:
            return table_content

        rows_to_delete = num_rows - self.rows_to_keep

        # Seeded generator -> same rows removed for the same table on every run.
        random_generator = new_random_generator({"table": table_content})
        # A set gives O(1) membership checks in the filter below.
        deleted_rows_indices = set(
            random_generator.sample(range(num_rows), rows_to_delete)
        )

        table_content["rows"] = [
            row for i, row in enumerate(rows) if i not in deleted_rows_indices
        ]
        return table_content
522

523

524
class GetNumOfTableCells(FieldOperator):
    """Get the number of cells in the given table (row count times header length)."""

    def process_value(self, table: Any) -> Any:
        # Row count is evaluated before header length, as before.
        return len(table.get("rows")) * len(table.get("header"))
531

532

533
class SerializeTableRowAsText(InstanceOperator):
    """Serializes a table row as text.

    Each listed field is rendered as ``"<field> is <value>, "``. The pieces
    are concatenated in order, so the result keeps a trailing ``", "``.

    Args:
        fields (str) - list of fields (dict_get paths) to be included in serialization.
        to_field (str) - serialized text field name.
        max_cell_length (int) - limits cell length to be considered, optional.
    """

    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        pieces = []
        for field in self.fields:
            value = dict_get(instance, field)
            if self.max_cell_length is not None:
                shortened = truncate_cell(value, self.max_cell_length)
                if shortened is not None:
                    value = shortened
            pieces.append(field + " is " + str(value) + ", ")

        instance[self.to_field] = "".join(pieces)
        return instance
561

562

563
class SerializeTableRowAsList(InstanceOperator):
    """Serializes a table row as list.

    Each listed field is rendered as ``"<field>: <value>, "``. The pieces are
    concatenated in order, so the result keeps a trailing ``", "``.

    Args:
        fields (str) - list of fields (dict_get paths) to be included in serialization.
        to_field (str) - serialized text field name.
        max_cell_length (int) - limits cell length to be considered, optional.
    """

    fields: str
    to_field: str
    max_cell_length: Optional[int] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        pieces = []
        for field in self.fields:
            value = dict_get(instance, field)
            if self.max_cell_length is not None:
                shortened = truncate_cell(value, self.max_cell_length)
                if shortened is not None:
                    value = shortened
            pieces.append(field + ": " + str(value) + ", ")

        instance[self.to_field] = "".join(pieces)
        return instance
591

592

593
class SerializeTriples(FieldOperator):
    """Serializes triples into a flat sequence.

    Each (subject, relation, object) triple becomes
    ``"subject : relation : object"``, with the relation lower-cased.
    Triples are joined with ``" | "``.

    Sample input in expected format:
    [[ "First Clearing", "LOCATION", "On NYS 52 1 Mi. Youngsville" ], [ "On NYS 52 1 Mi. Youngsville", "CITY_OR_TOWN", "Callicoon, New York"]]

    Sample output:
    First Clearing : location : On NYS 52 1 Mi. Youngsville | On NYS 52 1 Mi. Youngsville : city_or_town : Callicoon, New York

    """

    def process_value(self, tripleset: Any) -> Any:
        return self.serialize_triples(tripleset)

    def serialize_triples(self, tripleset) -> str:
        rendered = []
        for subject, relation, obj in tripleset:
            rendered.append(f"{subject} : {relation.lower()} : {obj}")
        return " | ".join(rendered)
611

612

613
class SerializeKeyValPairs(FieldOperator):
    """Serializes key, value pairs into a flat sequence.

    Pairs are rendered as ``"key is value"`` in dict order and joined with
    ``", "``. An empty mapping yields an empty string.

    Sample input in expected format: {"name": "Alex", "age": 31, "sex": "M"}
    Sample output: name is Alex, age is 31, sex is M
    """

    def process_value(self, kvpairs: Any) -> Any:
        return self.serialize_kvpairs(kvpairs)

    def serialize_kvpairs(self, kvpairs) -> str:
        return ", ".join(f"{key} is {value}" for key, value in kvpairs.items())
630

631

632
class ListToKeyValPairs(InstanceOperator):
    """Maps list of keys and values into key:value pairs.

    ``fields[0]`` names the key list and ``fields[1]`` the value list (both
    read with dict_get). They are paired positionally, stopping at the shorter
    list, and the resulting dict is stored in ``to_field``.

    Sample input in expected format: {"keys": ["name", "age", "sex"], "values": ["Alex", 31, "M"]}
    Sample output: {"name": "Alex", "age": 31, "sex": "M"}
    """

    fields: List[str]
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        keys = dict_get(instance, self.fields[0])
        values = dict_get(instance, self.fields[1])
        instance[self.to_field] = dict(zip(keys, values))
        return instance
655

656

657
class ConvertTableColNamesToSequential(FieldOperator):
    """Replaces actual table column names with static sequential names like col_0, col_1,...

    The input table is copied first, so the original is left unchanged.

    .. code-block:: text

        Sample input:
        {
            "header": ["name", "age"],
            "rows": [["Alex", 21], ["Donald", 34]]
        }

        Sample output:
        {
            "header": ["col_0", "col_1"],
            "rows": [["Alex", 21], ["Donald", 34]]
        }
    """

    def process_value(self, table: Any) -> Any:
        table_input = recursive_copy(table)
        return self.replace_header(table_content=table_input)

    # replaces header with sequential column names; returns the (mutated) table dict
    def replace_header(self, table_content: Dict) -> Dict:
        header = table_content.get("header", [])

        assert header, "Input table missing header"

        table_content["header"] = [f"col_{i}" for i in range(len(header))]
        return table_content
690

691

692
class ShuffleTableRows(TypeDependentAugmentor):
    """Shuffles the input table rows randomly.

    Works on a ``recursive_copy`` of the table and delegates to
    ``shuffle_rows`` with this augmentor's ``seed``.

    .. code-block:: text

        Sample Input:
        {
            "header": ["name", "age"],
            "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
        }

        Sample Output:
        {
            "header": ["name", "age"],
            "rows": [["Donald", 39], ["Raj", 34], ["Alex", 26]],
        }
    """

    augmented_type = Table
    seed = 0

    def process_value(self, table: Any) -> Any:
        return shuffle_rows(recursive_copy(table), self.seed)
716

717

718
class ShuffleTableColumns(TypeDependentAugmentor):
    """Shuffles the table columns randomly.

    Works on a ``recursive_copy`` of the table and delegates to
    ``shuffle_columns`` with this augmentor's ``seed``. Header and row cells
    are permuted together.

    .. code-block:: text

        Sample Input:
            {
                "header": ["name", "age"],
                "rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
            }

        Sample Output:
            {
                "header": ["age", "name"],
                "rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
            }
    """

    augmented_type = Table
    seed = 0

    def process_value(self, table: Any) -> Any:
        return shuffle_columns(recursive_copy(table), self.seed)
742

743

744
class LoadJson(FieldOperator):
    """Parse the field value as JSON.

    Parsing always uses ``strict=False``, so literal control characters
    (e.g. raw newlines) inside JSON strings are accepted. With
    ``allow_failure`` set, a ``json.JSONDecodeError`` yields ``failure_value``
    instead of propagating.
    """

    failure_value: Any = None
    allow_failure: bool = False

    def process_value(self, value: str) -> Any:
        if not self.allow_failure:
            return json.loads(value, strict=False)
        # Same leniency as the default path; previously this branch parsed
        # strictly, so inputs that succeeded normally "failed" here.
        try:
            return json.loads(value, strict=False)
        except json.JSONDecodeError:
            return self.failure_value
756

757

758
class PythonCallProcessor(FieldOperator):
    """Parse a Python call expression string such as ``f(a=1, b="x")`` into a tool call.

    The callee must be a plain name. Argument values are evaluated with
    ``ast.literal_eval``. Keyword arguments become entries of "arguments".
    Positional arguments, if any, are stored as a list under "_args".
    """

    def process_value(self, value: str) -> ToolCall:
        call_node = ast.parse(value, mode="eval").body
        # Resolve the callee name first so non-call/attribute inputs fail here.
        callee = call_node.func.id
        arguments = {
            keyword.arg: ast.literal_eval(keyword.value)
            for keyword in call_node.keywords
        }
        if call_node.args:
            arguments["_args"] = [ast.literal_eval(node) for node in call_node.args]
        return {"name": callee, "arguments": arguments}
769

770

771
def extract_possible_json_str(text):
    """Return the slice of ``text`` spanning its outermost JSON-like delimiters.

    The slice starts at the first ``{`` or ``[`` (or at index 0 if neither
    occurs) and ends at the last ``}`` or ``]`` inclusive (or at the end of
    the text if neither occurs). Wrapper tokens such as ``<tool_call>`` around
    a JSON payload are thereby dropped.
    """
    openers = [idx for idx in (text.find("{"), text.find("[")) if idx >= 0]
    closers = [idx for idx in (text.rfind("}"), text.rfind("]")) if idx >= 0]

    first = min(openers, default=0)
    last = max(closers, default=len(text) - 1)
    return text[first : last + 1]
782

783

784
class ToolCallPostProcessor(FieldOperator):
    """Parse a model output into a single tool call.

    The JSON part of the text is isolated with ``extract_possible_json_str``
    and parsed with ``strict=False``. Outcomes:

    - A list with exactly one tool call returns that call.
    - An empty list returns ``failure_value``.
    - A list with several calls emits a warning and returns ``failure_value``.
    - A single tool call object is returned as is.
    - Anything else returns ``failure_value``.

    JSON decode errors propagate unless ``allow_failure`` is set, in which
    case ``failure_value`` is returned.
    """

    failure_value: Any = None
    allow_failure: bool = False

    def process_value(self, value: str) -> ToolCall:
        # clear tokens such as <tool_call> focusing on the call json itself
        value = extract_possible_json_str(value)
        try:
            # strict=False in both modes; previously only the non-failure path
            # was lenient about control characters inside strings.
            result = json.loads(value, strict=False)
        except json.JSONDecodeError:
            if self.allow_failure:
                return self.failure_value
            raise

        if isoftype(result, List[ToolCall]):
            if len(result) > 1:
                UnitxtWarning(f"More than one tool call returned from model: {result}")
                return self.failure_value
            if len(result) == 0:
                return self.failure_value
            return result[0]
        if not isoftype(result, ToolCall):
            return self.failure_value
        return result
809

810

811
class MultipleToolCallPostProcessor(FieldOperator):
    """Parse a model output into a list of tool calls.

    The value is parsed as JSON with ``strict=False``. Outcomes:

    - A list of tool calls is returned as is.
    - A single tool call is wrapped in a one-element list.
    - Anything else returns ``failure_value``.

    JSON decode errors propagate unless ``allow_failure`` is set, in which
    case ``failure_value`` is returned.
    """

    failure_value: Any = None
    allow_failure: bool = False

    def process_value(self, value: str) -> List[ToolCall]:
        try:
            # strict=False in both modes, matching the non-failure path.
            result = json.loads(value, strict=False)
        except json.JSONDecodeError:
            if self.allow_failure:
                return self.failure_value
            raise

        if isoftype(result, List[ToolCall]):
            return result
        if not isoftype(result, ToolCall):
            return self.failure_value
        return [result]
828

829

830
class DumpJson(FieldOperator):
    """Serialize the field value to a JSON string using ``json.dumps`` defaults."""

    def process_value(self, value: str) -> str:
        serialized = json.dumps(value)
        return serialized
833

834

835
class MapHTMLTableToJSON(FieldOperator):
    """Converts HTML table format to the basic one (JSON).

    Header names come from the ``<th>`` cells inside ``<thead>``. Rows come
    from the ``<td>`` cells of each ``<tr>`` inside ``<tbody>``. Tables that
    omit ``<thead>``/``<tbody>`` are accepted too; previously they raised
    AttributeError. In that case header cells are collected from the whole
    table, and only ``<tr>`` elements containing at least one ``<td>`` become
    data rows.

    JSON format:

    .. code-block:: json

        {
            "header": ["col1", "col2"],
            "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
        }
    """

    _requirements_list = ["bs4"]

    def process_value(self, table: Any) -> Any:
        return self.convert_to_json(table_content=table)

    def convert_to_json(self, table_content: str) -> Dict:
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(table_content, "html.parser")

        # Header: prefer <thead>, fall back to the whole document.
        thead = soup.find("thead")
        header_scope = thead if thead is not None else soup
        header = [cell.get_text() for cell in header_scope.find_all("th")]

        tbody = soup.find("tbody")
        if tbody is not None:
            # Every <tr> in <tbody> is a row, as before.
            rows = [
                [cell.get_text() for cell in row.find_all("td")]
                for row in tbody.find_all("tr")
            ]
        else:
            # Without <tbody>, skip header-only rows (no <td> cells).
            rows = []
            for row in soup.find_all("tr"):
                cells = [cell.get_text() for cell in row.find_all("td")]
                if cells:
                    rows.append(cells)

        return {"header": header, "rows": rows}
875

876

877
class MapTableListsToStdTableJSON(FieldOperator):
    """Converts lists table format to the basic one (JSON).

    The input is a sequence of rows whose first element is the header row,
    e.g. ``[["col1", "col2"], ["row11", "row12"], ["row21", "row22"]]``.

    JSON format:

    .. code-block:: json

        {
            "header": ["col1", "col2"],
            "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
        }

    An empty input yields ``{"header": [], "rows": []}``.
    """

    def process_value(self, table: Any) -> Any:
        return self.map_tablelists_to_stdtablejson_util(table_content=table)

    def map_tablelists_to_stdtablejson_util(
        self, table_content: List[List[Any]]
    ) -> Dict:
        # An empty input has no header row. Return an empty table instead of
        # raising IndexError on table_content[0]. len() is used rather than
        # truthiness so array-like inputs do not raise on ambiguous truth values.
        if len(table_content) == 0:
            return {"header": [], "rows": []}
        return {"header": table_content[0], "rows": table_content[1:]}
895

896

897
class ConstructTableFromRowsCols(InstanceOperator):
    """Maps column and row field into single table field encompassing both header and rows.

    field[0] = header string as List
    field[1] = rows string as List[List]
    field[2] = table caption string(optional)

    Header and rows may be given as Python-literal strings (e.g.
    ``"['a', 'b']"``), which are parsed with ``ast.literal_eval``, or as
    already-parsed lists, which are used as-is. All fields are looked up with
    ``dict_get``. The resulting ``{"header", "rows"[, "caption"]}`` dict is
    stored in ``to_field``.
    """

    fields: List[str]
    to_field: str

    @staticmethod
    def _parse_literal(value: Any) -> Any:
        # Only strings need parsing; ast.literal_eval raises ValueError when
        # handed an already-materialized list.
        if isinstance(value, str):
            return ast.literal_eval(value)
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        header = self._parse_literal(dict_get(instance, self.fields[0]))
        rows = self._parse_literal(dict_get(instance, self.fields[1]))

        # Look the caption up the same way as header and rows (the original
        # used plain indexing here, unlike the other two fields).
        caption = dict_get(instance, self.fields[2]) if len(self.fields) >= 3 else None

        output_dict = {"header": header, "rows": rows}
        if caption is not None:
            output_dict["caption"] = caption

        instance[self.to_field] = output_dict
        return instance
932

933

934
class TransposeTable(TypeDependentAugmentor):
    """Swap the rows and columns of a table.

    Every original column becomes a row that starts with the column's name.
    The new header is a blank corner cell followed by the original row
    positions, as strings.

    .. code-block:: text

        Sample Input:
            {
                "header": ["name", "age", "sex"],
                "rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
            }

        Sample Output:
            {
                "header": [" ", "0", "1", "2"],
                "rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
            }

    """

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        return self.transpose_table(table)

    def transpose_table(self, table: Dict) -> Dict:
        original_header, original_rows = table["header"], table["rows"]

        new_header = [" ", *map(str, range(len(original_rows)))]
        new_rows = [
            [column_name, *(row[position] for row in original_rows)]
            for position, column_name in enumerate(original_header)
        ]

        return {"header": new_header, "rows": new_rows}
970

971

972
class DuplicateTableRows(TypeDependentAugmentor):
    """Repeat selected table rows a fixed number of times.

    Args:
        row_indices (List[int]): positions of the rows to repeat

        times(int): how many copies of each selected row appear in the output
    """

    augmented_type = Table

    row_indices: List[int] = []
    times: int = 1

    def process_value(self, table: Any) -> Any:
        header = table["header"]

        # Selected rows are emitted ``times`` times; all others exactly once.
        expanded_rows = []
        for position, row in enumerate(table["rows"]):
            copies = self.times if position in self.row_indices else 1
            expanded_rows.extend([row] * copies)

        return {"header": header, "rows": expanded_rows}
1003

1004

1005
class DuplicateTableColumns(TypeDependentAugmentor):
    """Repeat selected table columns a fixed number of times.

    Args:
        column_indices (List[int]): positions of the columns to repeat

        times(int): how many copies of each selected column appear in the output
    """

    augmented_type = Table

    column_indices: List[int] = []
    times: int = 1

    def process_value(self, table: Any) -> Any:
        header = table["header"]
        rows = table["rows"]

        def widen(cells):
            # Emit cells at selected positions ``times`` times, others once.
            widened = []
            for position, cell in enumerate(cells):
                repeat = self.times if position in self.column_indices else 1
                widened.extend([cell] * repeat)
            return widened

        # The header is widened before the rows, matching the original order.
        return {"header": widen(header), "rows": [widen(row) for row in rows]}
1045

1046

1047
class InsertEmptyTableRows(TypeDependentAugmentor):
    """Inserts empty rows in a table randomly for the given number of times.

    Each empty row has one ``""`` cell per header column. Insertion positions
    are drawn with the module-level ``random`` functions. The input table is
    left unmodified; a new table dict is returned.

    Args:
        times (int): how many empty rows to insert.
    """

    augmented_type = Table

    times: int = 0

    def process_value(self, table: Any) -> Any:
        header = table["header"]
        # Copy the rows list: inserting into table["rows"] directly would
        # mutate the caller's table as a side effect.
        rows = list(table["rows"])

        for _ in range(self.times):
            empty_row = [""] * len(header)
            # randint is inclusive on both ends, so len(rows) appends at the end.
            insert_pos = random.randint(0, len(rows))
            rows.insert(insert_pos, empty_row)

        return {"header": header, "rows": rows}
1075

1076

1077
class MaskColumnsNames(TypeDependentAugmentor):
    """Replace every column name with a placeholder: "Col1", "Col2", and so on.

    The rows are passed through unchanged.
    """

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        column_count = len(table["header"])
        placeholders = [f"Col{number}" for number in range(1, column_count + 1)]
        return {"header": placeholders, "rows": table["rows"]}
1086

1087

1088
class ShuffleColumnsNames(TypeDependentAugmentor):
    """Shuffle table columns names to be displayed in random order.

    Only the header is permuted; the rows are returned unchanged, so the names
    may no longer line up with the columns they originally described. The
    input table is not modified.
    """

    augmented_type = Table

    def process_value(self, table: Any) -> Any:
        # random.shuffle works in place. Shuffle a copy so the caller's header
        # list is not reordered as a side effect.
        shuffled_header = list(table["header"])
        random.shuffle(shuffled_header)

        return {"header": shuffled_header, "rows": table["rows"]}
1098

1099

1100
class JsonStrToDict(FieldOperator):
    """Convert a JSON string representing key-value pairs into a dictionary.

    Keys and values are converted with ``str()``, and entries whose value is
    ``None`` are dropped. If the text cannot be decoded, or does not decode to
    a dictionary with string keys, a ``UnitxtWarning`` is constructed with the
    problem description and an empty dictionary is returned.
    """

    def process_value(self, text: str) -> Dict[str, str]:
        # Best-effort parsing: any decoding failure degrades to an empty dict.
        try:
            dict_value = json.loads(text)
        except Exception as e:
            UnitxtWarning(
                f"Unable to convert input text to json format in JsonStrToDict due to {e}. Text: {text}"
            )
            dict_value = {}
        # Valid JSON that is not an object (e.g. a list or a number) is also rejected.
        if not isoftype(dict_value, Dict[str, Any]):
            UnitxtWarning(
                f"Unable to convert input text to dictionary in JsonStrToDict. Text: {text}"
            )
            dict_value = {}
        return {str(k): str(v) for k, v in dict_value.items() if v is not None}
1121

1122

1123
class ParseCSV(FieldOperator):
    r"""Parse CSV/TSV text content into table format.

    Splits delimiter-separated text into records and returns the standard
    Unitxt table dictionary with ``header`` and ``rows`` fields. Empty input
    yields ``{"header": [], "rows": []}``.

    Args:
        separator (str): Field separator character. Defaults to ','.
        has_header (bool): Whether the first row contains column headers. Defaults to True.
        skip_header (bool): Whether to skip the first row entirely. Defaults to False.
            When True, ``has_header`` is ignored and generic ``col_<i>`` names
            are generated from the width of the first remaining row.

    Example:
        Parsing CSV content

        .. code-block:: python

            ParseCSV(field="csv_content", to_field="table", separator=",")

        Parsing TSV content

        .. code-block:: python

            ParseCSV(field="tsv_content", to_field="table", separator="\t")
    """

    separator: str = ","
    has_header: bool = True
    skip_header: bool = False

    def process_value(self, value: str) -> Dict[str, Any]:
        """Split ``value`` into records and map them onto a header/rows table."""
        import csv
        import io

        records = list(csv.reader(io.StringIO(value), delimiter=self.separator))

        if self.skip_header:
            # Discard the first record without using it as the header.
            records = records[1:]

        if not records:
            return {"header": [], "rows": []}

        if self.has_header and not self.skip_header:
            return {"header": records[0], "rows": records[1:]}

        # No usable header row: name the columns after the first record's width.
        generic_header = [f"col_{index}" for index in range(len(records[0]))]
        return {"header": generic_header, "rows": records}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc