
IBM / unitxt / build 18944482945
30 Oct 2025 02:40PM UTC coverage: 80.795% (-0.1%) from 80.893%

Pull Request #1942: Enable summarization by subsets and groups
Merge 0cdf61df9 into c90534f87

1607 of 2006 branches covered (80.11%); branch coverage included in aggregate %.
10947 of 13532 relevant lines covered (80.9%)
0.81 hits per line

Source File: src/unitxt/evaluate_cli.py (56.19% covered)
# evaluate_cli.py
import argparse
import importlib.metadata
import json
import logging
import os
import platform
import subprocess
import sys
from datetime import datetime, timezone
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset as HFDataset

from .api import _source_to_dataset, evaluate, load_recipe
from .artifact import UnitxtArtifactNotFoundError
from .benchmark import Benchmark

# Use HFAutoModelInferenceEngine for local models
from .inference import (
    CrossProviderInferenceEngine,
    HFAutoModelInferenceEngine,
    InferenceEngine,
)
from .logging_utils import get_logger
from .metric_utils import EvaluationResults
from .parsing_utils import parse_key_equals_value_string_to_dict
from .settings_utils import settings

# Define logger early so it can be used in initial error handling
# Basic config for initial messages, will be reconfigured in main()
logger = get_logger()


def try_parse_json(value: str) -> Union[str, dict, None]:
    """Attempts to parse a string as JSON or key=value pairs.

    Returns the original string if parsing fails
    and the string doesn't look like JSON/kv pairs.
    Raises ArgumentTypeError if it looks like JSON but is invalid.
    """
    if value is None:
        return None
    try:
        # Handle simple key-value pairs like "key=value,key2=value2"
        if "=" in value and "{" not in value:
            parsed_dict = parse_key_equals_value_string_to_dict(value)
            if parsed_dict:
                return parsed_dict

        # Attempt standard JSON parsing
        return json.loads(value)

    except json.JSONDecodeError as e:
        if value.strip().startswith("{") or value.strip().startswith("["):
            raise argparse.ArgumentTypeError(
                f"Invalid JSON: '{value}'. Hint: Use double quotes for JSON strings and check syntax."
            ) from e
        return value  # Return as string if not JSON-like
    except Exception as e:
        logger.error(f"Error parsing argument '{value}': {e}")
        raise argparse.ArgumentTypeError(f"Could not parse argument: '{value}'") from e

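# A minimal usage sketch of try_parse_json (illustrative only, not part of the CLI
# entry points; exact value types for key=value input depend on
# parse_key_equals_value_string_to_dict):
#
#   try_parse_json('{"pretrained": "my_model", "torch_dtype": "float32"}')
#       -> {"pretrained": "my_model", "torch_dtype": "float32"}
#   try_parse_json("temperature=0,top_p=0.1")
#       -> a dict with keys "temperature" and "top_p"
#   try_parse_json("formats.chat_api")
#       -> "formats.chat_api"   (not JSON-like, returned unchanged)
#   try_parse_json('{"broken": ')
#       -> raises argparse.ArgumentTypeError
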
def setup_parser() -> argparse.ArgumentParser:
    """Sets up the argument parser."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="CLI utility for running evaluations with unitxt.",
    )

    # --- Task/Dataset Arguments ---
    parser.add_argument(
        "--tasks",  # Plural, to reflect that it holds a list
        "-t",
        dest="tasks",  # Explicitly set the attribute name to 'tasks'
        type=partial(str.split, sep="+"),  # Split the argument string on '+' into a list of task strings
        required=True,
        help="Plus-separated (+) list of Unitxt task/dataset identifier strings.\n"
        "Each task format: 'card=<card_ref>,template=<template_ref>,...'\n"
        "Example: 'card=cards.mmlu,t=t.mmlu.all+card=cards.hellaswag,t=t.hellaswag.no'",
    )

    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to use (e.g., 'train', 'validation', 'test'). Default: 'test'.",
    )
    parser.add_argument(
        "--num_fewshots",
        type=int,
        default=None,
        help="Number of few-shot demonstrations to use.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=int,
        default=None,
        metavar="N",
        help="Limit the number of examples per task/dataset.",
    )

    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="Batch size for inference when the selected model is 'hf'. Default: 1.",
    )

    # --- Model Arguments (Explicit Types) ---
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="hf",
        choices=["hf", "cross_provider"],
        help="Specifies the model type/engine.\n"
        "- 'hf': Local Hugging Face model via HFAutoModel (default). Requires 'pretrained=...' in --model_args.\n"
        "- 'cross_provider': Remote model via CrossProviderInferenceEngine. Requires 'model_name=...' in --model_args.",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        type=try_parse_json,
        default={},
        help="Comma separated string or JSON formatted arguments for the model/inference engine.\n"
        "Examples:\n"
        "- For --model hf (default): 'pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16,device=cuda'\n"
        "  (Note: 'pretrained' key is REQUIRED. Other args like 'torch_dtype', 'device', generation params are passed too)\n"
        "- For --model cross_provider: 'model_name=llama-3-3-70b-instruct,max_tokens=256,temperature=0.7'\n"
        "  (Note: 'model_name' key is REQUIRED)\n"
        '- JSON format: \'{"pretrained": "my_model", "torch_dtype": "float32"}\' or \'{"model_name": "openai/gpt-4o"}\'',
    )

    parser.add_argument(
        "--gen_kwargs",
        type=try_parse_json,
        default=None,
        help=(
            "Comma delimited string of model generation kwargs for greedy_until tasks,"
            " e.g. temperature=0,top_p=0.1."
        ),
    )

    parser.add_argument(
        "--chat_template_kwargs",
        type=try_parse_json,
        default=None,
        help=(
            "Comma delimited string of tokenizer chat-template kwargs,"
            " e.g. thinking=True (https://github.com/huggingface/transformers/blob/9a1c1fe7edaefdb25ab37116a979832df298d6ea/src/transformers/tokenization_utils_base.py#L1542)"
        ),
    )

    # --- Output and Logging Arguments ---
    parser.add_argument(
        "--output_path",
        "-o",
        type=str,
        default=".",
        help="Directory to save evaluation results and logs. Default: current directory.",
    )
    parser.add_argument(
        "--output_file_prefix",
        type=str,
        default="evaluation_results",
        help="Prefix for the output JSON file names. Default: 'evaluation_results'.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If set, save individual predictions and scores to a separate JSON file.",
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Controls logging verbosity level. Default: INFO.",
    )

    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        default=False,
        help="Format prompts with the chat template (sets format=formats.chat_api).",
    )

    # --- Unitxt Settings ---
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="Allow execution of unverified code from the HuggingFace Hub (used by datasets/unitxt).",
    )
    parser.add_argument(
        "--disable_hf_cache",
        action="store_true",
        default=False,
        help="Disable HuggingFace datasets caching.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Directory for HuggingFace datasets cache (overrides default).",
    )

    return parser

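# Illustrative sketch of how the parser above resolves a command line. The card,
# template, and model names are placeholders, and the parsed value types follow
# try_parse_json / parse_key_equals_value_string_to_dict:
#
#   parser = setup_parser()
#   args = parser.parse_args(
#       [
#           "--tasks", "card=cards.my_card,template=templates.my_template",
#           "--model", "hf",
#           "--model_args", "pretrained=my-org/my-model,torch_dtype=bfloat16",
#           "--limit", "100",
#           "--apply_chat_template",
#       ]
#   )
#   # args.tasks      -> ["card=cards.my_card,template=templates.my_template"]  (split on "+")
#   # args.model_args -> roughly {"pretrained": "my-org/my-model", "torch_dtype": "bfloat16"}
#   # args.limit      -> 100
#   # args.apply_chat_template -> True
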
def setup_logging(verbosity: str) -> None:
    """Configures logging based on verbosity level."""
    logging.basicConfig(
        level=verbosity,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        force=True,  # Ensures reconfiguration works if basicConfig was called before
    )
    # Re-get the logger instance after basicConfig is set
    global logger
    logger = get_logger()
    logger.setLevel(verbosity)


def prepare_output_paths(output_path: str, prefix: str) -> Tuple[str, str]:
    """Creates output directory and defines file paths.

    Args:
        output_path (str): The directory where output files will be saved.
        prefix (str): The prefix for the output file names.

    Returns:
        Tuple[str, str]: A tuple containing the path for the results summary file
                         and the path for the detailed samples file.
    """
    os.makedirs(output_path, exist_ok=True)
    results_file_path = os.path.join(output_path, f"{prefix}.json")
    samples_file_path = os.path.join(output_path, f"{prefix}_samples.json")
    return results_file_path, samples_file_path

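# Worked example for the argument defaults (output_path=".", prefix="evaluation_results"),
# on a POSIX path separator:
#   prepare_output_paths(".", "evaluation_results")
#       -> ("./evaluation_results.json", "./evaluation_results_samples.json")
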
def configure_unitxt_settings(args: argparse.Namespace):
    """Configures unitxt settings and returns a context manager.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        ContextManager: A context manager for applying unitxt settings.
    """
    unitxt_settings_dict = {
        "disable_hf_datasets_cache": args.disable_hf_cache,
        "allow_unverified_code": args.trust_remote_code,
    }
    if args.cache_dir:
        unitxt_settings_dict["hf_cache_dir"] = args.cache_dir
        # Also set environment variables, as some HF parts might read them directly
        os.environ["HF_DATASETS_CACHE"] = args.cache_dir
        os.environ["HF_HOME"] = args.cache_dir
        logger.info(f"Set HF_DATASETS_CACHE to: {args.cache_dir}")

    if args.disable_hf_cache:
        os.environ["UNITXT_DISABLE_HF_DATASETS_CACHE"] = "True"

    logger.info(f"Applying unitxt settings: {unitxt_settings_dict}")
    return settings.context(**unitxt_settings_dict)


def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
    """Loads the dataset based on command line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        HFDataset: The loaded dataset.

    Raises:
        UnitxtArtifactNotFoundError: If the specified card or template artifact is not found.
        FileNotFoundError: If a specified file (e.g., in a local card path) is not found.
        AttributeError: If there's an issue accessing attributes during loading.
        ValueError: If there's a value-related error during loading (e.g., parsing).
    """
    logger.info(
        f"Loading task/dataset using identifier: '{args.tasks}' with split '{args.split}'"
    )

    benchmark_subsets = {}
    for task_str in args.tasks:
        overwrite_args = extract_overwrite_args(args)
        benchmark_subsets[task_str] = load_recipe(
            dataset_query=task_str, **overwrite_args
        )

    # This hack circumvents an issue with multi-level benchmarks (such as Bluebench's translation subset) that fail when wrapped with an additional Benchmark() object.
    if len(benchmark_subsets) == 1 and isinstance(
        next(iter(benchmark_subsets.values())), Benchmark
    ):
        source = next(iter(benchmark_subsets.values()))
    else:
        source = Benchmark(subsets=benchmark_subsets)

    test_dataset = _source_to_dataset(source, split=args.split)
    logger.info(
        f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
    )
    return test_dataset


def extract_overwrite_args(args):
    dataset_args = {}

    if args.limit is not None:
        assert (
            f"max_{args.split}_instances" not in dataset_args
        ), "limit was inputted both as an arg and as a task parameter"
        dataset_args[f"max_{args.split}_instances"] = args.limit
        logger.info(
            f"Applying limit from --limit argument: max_{args.split}_instances={args.limit}"
        )

    if args.num_fewshots:
        assert (
            "num_demos" not in dataset_args
        ), "num_demos was inputted both as an arg and as a task parameter"
        dataset_args["num_demos"] = args.num_fewshots
        dataset_args.update(
            {
                "demos_taken_from": "train",
                "demos_pool_size": -1,
                "demos_removed_from_data": True,
            }
        )
        logger.info(
            f"Applying few-shot setting from --num_fewshots argument: num_demos={args.num_fewshots}"
        )

    if args.apply_chat_template:
        assert (
            "format" not in dataset_args
        ), "format was inputted as a task parameter, but chat_api was requested"
        dataset_args["format"] = "formats.chat_api"
        logger.info(
            "Applying chat template from --apply_chat_template argument: format=formats.chat_api"
        )

    return dataset_args

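# Worked example of the overwrite dict built above, assuming the CLI was called with
# --split test --limit 100 --num_fewshots 5 --apply_chat_template:
#
#   {
#       "max_test_instances": 100,
#       "num_demos": 5,
#       "demos_taken_from": "train",
#       "demos_pool_size": -1,
#       "demos_removed_from_data": True,
#       "format": "formats.chat_api",
#   }
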
def prepare_kwargs(kwargs: dict) -> Dict[str, Any]:
    """Prepares a keyword-argument dictionary parsed from the command line.

    Args:
        kwargs (dict): A value parsed by try_parse_json (e.g., --model_args, --gen_kwargs).

    Returns:
        Dict[str, Any]: The processed keyword-argument dictionary.
    """
    # Ensure the value is a dictionary, handling potential string return from try_parse_json
    kwargs_dict = kwargs if isinstance(kwargs, dict) else {}
    if not isinstance(kwargs, dict) and kwargs is not None:
        logger.warning(
            f"Could not parse kwargs '{kwargs}' as JSON or key-value pairs. Treating as empty."
        )

    logger.info(f"Using kwargs: {kwargs_dict}")
    return kwargs_dict


def initialize_inference_engine(
    args: argparse.Namespace,
    model_args_dict: Dict[str, Any],
    chat_kwargs_dict: Dict[str, Any],
) -> InferenceEngine:
    """Initializes the appropriate inference engine based on arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        model_args_dict (Dict[str, Any]): Processed model arguments.
        chat_kwargs_dict (Dict[str, Any]): Processed chat arguments.

    Returns:
        InferenceEngine: The initialized inference engine instance.

    Raises:
        SystemExit: If required dependencies are missing for the selected model type.
        ValueError: If required keys are missing in model_args for the selected model type.
    """
    inference_model = None
    # --- Local Hugging Face Model (using HFAutoModelInferenceEngine) ---
    if args.model.lower() == "hf":
        if "pretrained" not in model_args_dict:
            logger.error(
                "Missing 'pretrained=<model_id_or_path>' in --model_args for '--model hf'."
            )
            raise ValueError(
                "Argument 'pretrained' is required in --model_args when --model is 'hf'"
            )

        local_model_name = model_args_dict.pop("pretrained")
        logger.info(
            f"Initializing HFAutoModelInferenceEngine for model: {local_model_name}"
        )

        model_args_dict.update({"batch_size": args.batch_size})
        logger.info(f"HFAutoModelInferenceEngine args: {model_args_dict}")

        inference_model = HFAutoModelInferenceEngine(
            model_name=local_model_name,
            **model_args_dict,
            chat_kwargs_dict=chat_kwargs_dict,
        )

        # Keep the actual model name for the results
        args.model = inference_model.model_name
    # --- Remote Model (CrossProviderInferenceEngine) ---
    elif args.model.lower() == "cross_provider":
        if "model_name" not in model_args_dict:
            logger.error(
                "Missing 'model_name=<provider/model_id>' in --model_args for '--model cross_provider'."
            )
            raise ValueError(
                "Argument 'model_name' is required in --model_args when --model is 'cross_provider'"
            )

        remote_model_name = model_args_dict.pop("model_name")
        logger.info(
            f"Initializing CrossProviderInferenceEngine for model: {remote_model_name}"
        )

        if (
            "max_tokens" not in model_args_dict
            and "max_new_tokens" not in model_args_dict
        ):
            logger.warning(
                f"'max_tokens' or 'max_new_tokens' not found in --model_args, {remote_model_name} might require it."
            )

        logger.info(f"CrossProviderInferenceEngine args: {model_args_dict}")

        # Note: CrossProviderInferenceEngine expects 'model' parameter, not 'model_name'
        inference_model = CrossProviderInferenceEngine(
            model=remote_model_name,
            **model_args_dict,
        )

        # Keep the actual model name for the results
        args.model = inference_model.get_engine_id()
    else:
        # This case should not be reached due to argparse choices
        logger.error(
            f"Invalid --model type specified: {args.model}. Use 'hf' or 'cross_provider'."
        )
        sys.exit(1)  # Exit here as it's an invalid configuration

    return inference_model

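# Sketch of the two --model_args shapes the function above expects (model names are
# placeholders; any remaining keys are forwarded to the engine constructor):
#
#   --model hf             -> {"pretrained": "my-org/my-model", "torch_dtype": "bfloat16", ...}
#                             'pretrained' is popped and passed as model_name, and
#                             'batch_size' is added from --batch_size.
#   --model cross_provider -> {"model_name": "llama-3-3-70b-instruct", "max_tokens": 256, ...}
#                             'model_name' is popped and passed as model=...
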
def run_inference(engine: InferenceEngine, dataset: HFDataset) -> List[Any]:
    """Runs inference using the initialized engine.

    Args:
        engine (InferenceEngine): The inference engine instance.
        dataset (HFDataset): The dataset to run inference on.

    Returns:
        List[Any]: A list of predictions.

    Raises:
        Exception: If an error occurs during inference.
    """
    logger.info("Starting inference...")
    try:
        predictions = engine.infer(dataset)
        logger.info("Inference completed.")
        if not predictions:
            logger.warning("Inference returned no predictions.")
            return []  # Return empty list if no predictions
        if len(predictions) != len(dataset):
            logger.error(
                f"Inference returned an unexpected number of predictions ({len(predictions)}). Expected {len(dataset)}."
            )
            # Don't exit, but log error. Evaluation might still work partially or fail later.
        return predictions
    except Exception:
        logger.exception("An error occurred during inference")  # Use logger.exception
        raise  # Re-raise after logging


def run_evaluation(predictions: List[Any], dataset: HFDataset) -> EvaluationResults:
    """Runs evaluation on the predictions.

    Args:
        predictions (List[Any]): The list of predictions from the model.
        dataset (HFDataset): The dataset containing references and other data.

    Returns:
        EvaluationResults: The evaluated dataset (list of instances with scores).

    Raises:
        RuntimeError: If evaluation returns no results or an unexpected type.
        Exception: If any other error occurs during evaluation.
    """
    logger.info("Starting evaluation...")
    if not predictions:
        logger.warning("Skipping evaluation as there are no predictions.")
        return []  # Return empty list if no predictions to evaluate

    try:
        evaluation_results = evaluate(predictions=predictions, data=dataset)
        logger.info("Evaluation completed.")
        if not evaluation_results:
            logger.error("Evaluation returned no results (empty list/None).")
            # Raise an error as this indicates a problem in the evaluation process
            raise RuntimeError("Evaluation returned no results.")
        if not isinstance(evaluation_results, EvaluationResults):
            logger.error(
                f"Evaluation returned unexpected type: {type(evaluation_results)}. Expected EvaluationResults."
            )
            raise RuntimeError(
                f"Evaluation returned unexpected type: {type(evaluation_results)}"
            )

        return evaluation_results
    except Exception:
        logger.exception("An error occurred during evaluation")  # Use logger.exception
        raise  # Re-raise after logging

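# Note on the returned EvaluationResults: the rest of this module relies only on two
# of its attributes, as accessed in process_and_save_results below:
#
#   evaluation_results.subsets_scores   # per-subset scores; exposes a .summary string
#   evaluation_results.instance_scores  # per-instance records, each carrying a "subset" path
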
def _get_unitxt_commit_hash() -> Optional[str]:
    """Tries to get the git commit hash of the installed unitxt package."""
    try:
        # Find the directory of the unitxt package
        # Use the location of this file to be more robust in finding the package path

        current_script_path = os.path.abspath(__file__)
        package_dir = os.path.dirname(current_script_path)

        # Check if it's a git repository and get the commit hash
        # Use absolute path for git command
        git_command = ["git", "-C", os.path.abspath(package_dir), "rev-parse", "HEAD"]
        logger.debug(f"Running git command: {' '.join(git_command)}")
        result = subprocess.run(
            git_command,
            capture_output=True,
            text=True,
            check=False,  # Don't raise error if git command fails
            encoding="utf-8",
            errors="ignore",  # Ignore potential decoding errors
        )
        if result.returncode == 0:
            commit_hash = result.stdout.strip()
            logger.info(f"Found unitxt git commit hash: {commit_hash}")
            # Verify it looks like a hash (e.g., 40 hex chars)
            if len(commit_hash) == 40 and all(
                c in "0123456789abcdef" for c in commit_hash
            ):
                return commit_hash
            logger.warning(
                f"Git command output '{commit_hash}' doesn't look like a valid commit hash."
            )
            return None
        stderr_msg = result.stderr.strip() if result.stderr else "No stderr"
        logger.warning(
            f"Could not get unitxt git commit hash (git command failed with code {result.returncode}): {stderr_msg}"
        )
        return None
    except ImportError:
        logger.warning("unitxt package not found, cannot determine commit hash.")
        return None
    except FileNotFoundError:
        logger.warning(
            "'git' command not found in PATH. Cannot determine unitxt commit hash."
        )
        return None
    except Exception as e:
        logger.warning(
            f"Error getting unitxt commit hash: {e}", exc_info=True
        )  # Log traceback
        return None


def _get_installed_packages() -> Dict[str, str]:
    """Gets a dictionary of installed packages and their versions."""
    packages = {}
    try:
        for dist in importlib.metadata.distributions():
            # Handle potential missing metadata gracefully
            name = dist.metadata.get("Name")
            version = dist.metadata.get("Version")
            if name and version:
                packages[name] = version
            elif name:
                packages[name] = "N/A"  # Record package even if version is missing
                logger.debug(f"Could not find version for package: {name}")

        logger.info(f"Collected versions for {len(packages)} installed packages.")
    except Exception as e:
        logger.warning(f"Could not retrieve installed package list: {e}", exc_info=True)
    return packages


def _get_unitxt_version() -> str:
    """Gets the installed unitxt version using importlib.metadata."""
    try:
        version = importlib.metadata.version("unitxt")
        logger.info(f"Found unitxt version using importlib.metadata: {version}")
        return version
    except importlib.metadata.PackageNotFoundError:
        logger.warning(
            "Could not find 'unitxt' package version using importlib.metadata. Is it installed correctly?"
        )
        return "N/A"
    except Exception as e:
        logger.warning(
            f"Error getting unitxt version using importlib.metadata: {e}", exc_info=True
        )
        return "N/A"


def prepend_timestamp_to_path(original_path, timestamp):
    """Takes a path string and a timestamp string, prepends the timestamp to the filename part of the path, and returns the new path string."""
    directory, filename = os.path.split(original_path)
    # Use an f-string to create the new filename with the timestamp prepended
    new_filename = f"{timestamp}_{filename}"
    # Join the directory and the new filename back together
    return os.path.join(directory, new_filename)

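# Worked example:
#   prepend_timestamp_to_path("results/evaluation_results.json", "2025-01-18T11-37-32")
#       -> "results/2025-01-18T11-37-32_evaluation_results.json"
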
def _save_results_to_disk(
    args: argparse.Namespace,
    global_scores: Dict[str, Any],
    all_samples_data: Dict[str, List[Dict[str, Any]]],
    results_path: str,
    samples_path: str,
) -> None:
    """Saves the configuration, environment info, global scores, and samples to JSON files.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        global_scores (Dict[str, Any]): Dictionary of global scores.
        all_samples_data (Dict[str, List[Dict[str, Any]]]): Mapping from subset name to its processed sample data.
        results_path (str): Path to save the summary results JSON file.
        samples_path (str): Path to save the detailed samples JSON file.
    """
    # --- Gather Configuration ---
    config_to_save = {}
    for k, v in vars(args).items():
        # Ensure complex objects are represented as strings
        if isinstance(v, (str, int, float, bool, list, dict, type(None))):
            config_to_save[k] = v
        else:
            try:
                # Try standard repr first
                config_to_save[k] = repr(v)
            except Exception:
                # Fallback if repr fails
                config_to_save[
                    k
                ] = f"<Object of type {type(v).__name__} could not be represented>"

    # --- Gather Environment Info ---
    unitxt_commit = _get_unitxt_commit_hash()
    # Get version using the dedicated function
    unitxt_pkg_version = _get_unitxt_version()

    environment_info = {
        "timestamp_utc": datetime.utcnow().isoformat() + "Z",
        "command_line_invocation": sys.argv,
        "parsed_arguments": config_to_save,  # Include parsed args here as well
        "unitxt_version": unitxt_pkg_version,  # Use version from importlib.metadata
        "unitxt_commit_hash": unitxt_commit if unitxt_commit else "N/A",
        "python_version": platform.python_version(),
        "system": platform.system(),
        "system_version": platform.version(),
        "installed_packages": _get_installed_packages(),
    }

    # --- Prepare Final Results Structure ---
    results_summary = {
        "environment_info": environment_info,
        "results": global_scores,
    }

    # Prepend the timestamp in UTC (e.g., 2025-01-18T11-37-32) to the file names
    timestamp = datetime.now().astimezone(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")

    results_path = prepend_timestamp_to_path(results_path, timestamp)
    samples_path = prepend_timestamp_to_path(samples_path, timestamp)

    # --- Save Summary ---
    logger.info(f"Saving global results summary to: {results_path}")
    try:
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results_summary, f, indent=4, ensure_ascii=False)
    except OSError as e:
        logger.error(f"Failed to write results summary file {results_path}: {e}")
    except TypeError as e:
        logger.error(
            f"Failed to serialize results summary to JSON: {e}. Check data types."
        )
        # Log the problematic structure if possible (might be large)
        # logger.debug(f"Problematic results_summary structure: {results_summary}")

    # --- Save Samples (if requested) ---
    if args.log_samples:
        logger.info(f"Saving detailed samples to: {samples_path}")
        # Structure samples file with environment info as well for self-containment
        samples_output = {
            "environment_info": environment_info,  # Repeat env info here
            "samples": all_samples_data,
        }
        try:
            with open(samples_path, "w", encoding="utf-8") as f:
                json.dump(samples_output, f, indent=4, ensure_ascii=False)
        except OSError as e:
            logger.error(f"Failed to write samples file {samples_path}: {e}")
        except TypeError as e:
            logger.error(f"Failed to serialize samples to JSON: {e}. Check data types.")

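# Rough shape of the summary file written above (only the keys read back by
# extract_scores are shown; the concrete values here are hypothetical):
#
#   {
#       "environment_info": {
#           "timestamp_utc": "2025-01-18T11:37:32.123456Z",
#           "parsed_arguments": {"model": "my-org/my-model", ...},
#           ...
#       },
#       "results": { ... global / per-subset scores ... }
#   }
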
def process_and_save_results(
    args: argparse.Namespace,
    evaluation_results: EvaluationResults,
    results_path: str,
    samples_path: str,
) -> None:
    """Processes, prints, and saves the evaluation results.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        evaluation_results (EvaluationResults): The evaluation results to process.
        results_path (str): Path to save the summary results JSON file.
        samples_path (str): Path to save the detailed samples JSON file.

    Raises:
        Exception: If an error occurs during result processing or saving (re-raised).
    """
    try:
        subsets_scores = evaluation_results.subsets_scores
        instances_results = evaluation_results.instance_scores

        subset_instances = {}
        for instance in instances_results:
            if instance["subset"][0] not in subset_instances:
                subset_instances[instance["subset"][0]] = []
            del instance["postprocessors"]
            subset_instances[instance["subset"][0]].append(instance)

        logger.info(f"\n{subsets_scores.summary}")

        # --- Save Results ---
        # Pass all necessary data to the saving function
        _save_results_to_disk(
            args, subsets_scores, subset_instances, results_path, samples_path
        )

    except Exception:
        logger.exception(
            "An error occurred during result processing or saving"
        )  # Use logger.exception
        raise  # Re-raise after logging


def main():
    """Main function to parse arguments and run evaluation."""
    parser = setup_parser()
    args = parser.parse_args()

    # Setup logging ASAP
    setup_logging(args.verbosity)

    logger.info("Starting Unitxt Evaluation CLI")
    # Log raw and parsed args at DEBUG level
    logger.debug(f"Raw command line arguments: {sys.argv}")
    logger.debug(f"Parsed arguments: {vars(args)}")  # Log the vars(args) dict
    logger.debug(
        f"Parsed model_args type: {type(args.model_args)}, value: {args.model_args}"
    )

    try:
        results_path, samples_path = prepare_output_paths(
            args.output_path, args.output_file_prefix
        )

        # Apply unitxt settings within a context manager
        with configure_unitxt_settings(args):
            test_dataset = cli_load_dataset(args)
            model_args_dict = prepare_kwargs(args.model_args)
            gen_kwargs_dict = prepare_kwargs(args.gen_kwargs)
            chat_kwargs_dict = prepare_kwargs(args.chat_template_kwargs)

            model_args_dict.update(gen_kwargs_dict)
            inference_model = initialize_inference_engine(
                args, model_args_dict, chat_kwargs_dict
            )
            predictions = run_inference(inference_model, test_dataset)
            evaluation_results = run_evaluation(predictions, test_dataset)
            process_and_save_results(
                args, evaluation_results, results_path, samples_path
            )

    # --- More Specific Error Handling ---
    except (UnitxtArtifactNotFoundError, FileNotFoundError) as e:
        logger.exception(f"Error loading artifact or file: {e}")
        sys.exit(1)
    except (AttributeError, ValueError) as e:
        # Catch issues like missing keys in args, parsing errors, etc.
        logger.exception(f"Configuration or value error: {e}")
        sys.exit(1)
    except ImportError as e:
        # Catch missing optional dependencies
        logger.exception(f"Missing dependency: {e}")
        sys.exit(1)
    except RuntimeError as e:
        # Catch errors explicitly raised during execution (e.g., evaluation failure)
        logger.exception(f"Runtime error during processing: {e}")
        sys.exit(1)
    except Exception as e:
        # Catch any other unexpected errors
        logger.exception(f"An unexpected error occurred: {e}")
        sys.exit(1)

    logger.info("Unitxt Evaluation CLI finished successfully.")

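# Illustrative invocation of this entry point (card, template, and model names are
# placeholders; assumes the file is run as a module so the relative imports resolve):
#
#   python -m unitxt.evaluate_cli \
#       --tasks "card=cards.my_card,template=templates.my_template" \
#       --model hf \
#       --model_args "pretrained=my-org/my-model" \
#       --limit 10 --log_samples -o results/
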
def extract_scores(folder: str, subset: str, group: str):  # pragma: no cover
    import pandas as pd

    def safe_score(d: dict, key="score"):
        na = "N/A"
        return d.get(key, na) if isinstance(d, dict) else na

    def extract_subset(results: dict, subset: str, group: str):
        subset_results = results.get(subset, {})
        row = {subset: safe_score(subset_results)}

        groups = subset_results.get("groups", {})

        if not groups:
            return row

        group_results = groups.get(group) if group else next(iter(groups.values()), {})

        if not isinstance(group_results, dict):
            return row

        row.update(
            {k: safe_score(v) for k, v in group_results.items() if isinstance(v, dict)}
        )
        return row

    def extract_all(results: dict):
        row = {"Average": safe_score(results)}
        row.update(
            {k: safe_score(v) for k, v in results.items() if isinstance(v, dict)}
        )
        return row

    data = []

    for filename in sorted(os.listdir(folder)):
        if not filename.endswith("evaluation_results.json"):
            continue

        file_path = os.path.join(folder, filename)
        try:
            with open(file_path, encoding="utf-8") as f:
                content = json.load(f)

                env_info = content.get("environment_info", {})
                row = {
                    "Model": safe_score(env_info.get("parsed_arguments", {}), "model"),
                    "Timestamp": safe_score(env_info, "timestamp_utc"),
                }

                results = content.get("results", {})

                extra = (
                    extract_subset(results, subset, group)
                    if subset
                    else extract_all(results)
                )
                row.update(extra)
                data.append(row)
        except Exception as e:
            logger.error(f"Error parsing results file {filename}: {e}.")

    return pd.DataFrame(data).sort_values(by="Timestamp", ascending=True)

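# Note: with the default --output_file_prefix, the filename filter above matches both a
# bare "evaluation_results.json" and the timestamped files produced by
# _save_results_to_disk (e.g. "2025-01-18T11-37-32_evaluation_results.json"), since both
# end with "evaluation_results.json". Each matching file contributes one row with
# "Model", "Timestamp", and the selected score columns.
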
def setup_summarization_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="CLI utility for summarizing evaluation results.",
    )

    parser.add_argument(
        "--folder",
        "-f",
        dest="folder",
        type=str,
        default=".",
        help="Directory containing evaluation results json files. Default: current folder.\n",
    )

    parser.add_argument(
        "--subset",
        "-s",
        type=str,
        dest="subset",
        default=None,
        help="Subset to filter results by. Default: none.",
    )

    parser.add_argument(
        "--group",
        "-g",
        type=str,
        dest="group",
        default=None,
        help="Group to filter results by. Requires specifying a subset. Default: first group.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        choices=["markdown", "csv"],
        dest="output",
        default="markdown",
        help="Output format. Can be markdown or csv. Default: markdown.",
    )

    return parser


def summarize_cli():
    parser = setup_summarization_parser()
    args = parser.parse_args()

    df = extract_scores(args.folder, args.subset, args.group)

    if args.output == "markdown":
        logger.info(df.to_markdown(index=False))
    elif args.output == "csv":
        logger.info(df.to_csv(index=False))
    else:
        logger.error(f"Unsupported output format: {args.output}")


if __name__ == "__main__":
    main()