IBM / unitxt / build 15879351923

25 Jun 2025 02:36PM UTC coverage: 79.708% (-0.2%) from 79.887%

Pull Request #1842: Results summarization utility for the CLI
Merge 350b03701 into 7dd0bd678

1698 of 2109 branches covered (80.51%); branch coverage is included in the aggregate percentage.
10565 of 13276 relevant lines covered (79.58%)
0.8 hits per line

Source File: src/unitxt/evaluate_cli.py (56.28% covered)

1
# evaluate_cli.py
2
import argparse
1✔
3
import importlib.metadata
1✔
4
import json
1✔
5
import logging
1✔
6
import os
1✔
7
import platform
1✔
8
import subprocess
1✔
9
import sys
1✔
10
from datetime import datetime
1✔
11
from functools import partial
1✔
12
from typing import Any, Dict, List, Optional, Tuple, Union
1✔
13

14
from datasets import Dataset as HFDataset
1✔
15

16
from .api import _source_to_dataset, evaluate, load_recipe
1✔
17
from .artifact import UnitxtArtifactNotFoundError
1✔
18
from .benchmark import Benchmark
1✔
19

20
# Use HFAutoModelInferenceEngine for local models
21
from .inference import (
1✔
22
    CrossProviderInferenceEngine,
23
    HFAutoModelInferenceEngine,
24
    InferenceEngine,
25
)
26
from .logging_utils import get_logger
1✔
27
from .metric_utils import EvaluationResults
1✔
28
from .parsing_utils import parse_key_equals_value_string_to_dict
1✔
29
from .settings_utils import settings
1✔
30

31
# Define logger early so it can be used in initial error handling
32
# Basic config for initial messages, will be reconfigured in main()
33
logger = get_logger()
1✔
34

35

36
def try_parse_json(value: str) -> Union[str, dict, None]:
1✔
37
    """Attempts to parse a string as JSON or key=value pairs.
38

39
    Returns the original string if parsing fails
40
    and the string doesn't look like JSON/kv pairs.
41
    Raises ArgumentTypeError if it looks like JSON but is invalid.
42
    """
43
    if value is None:
1✔
44
        return None
1✔
45
    try:
1✔
46
        # Handle simple key-value pairs like "key=value,key2=value2"
47
        if "=" in value and "{" not in value:
1✔
48
            parsed_dict = parse_key_equals_value_string_to_dict(value)
1✔
49
            if parsed_dict:
1✔
50
                return parsed_dict
1✔
51

52
        # Attempt standard JSON parsing
53
        return json.loads(value)
1✔
54

55
    except json.JSONDecodeError as e:
1✔
56
        if value.strip().startswith("{") or value.strip().startswith("["):
1✔
57
            raise argparse.ArgumentTypeError(
1✔
58
                f"Invalid JSON: '{value}'. Hint: Use double quotes for JSON strings and check syntax."
59
            ) from e
60
        return value  # Return as string if not JSON-like
1✔
61
    except Exception as e:
×
62
        logger.error(f"Error parsing argument '{value}': {e}")
×
63
        raise argparse.ArgumentTypeError(f"Could not parse argument: '{value}'") from e
×
64

65

66
def setup_parser() -> argparse.ArgumentParser:
1✔
67
    """Sets up the argument parser."""
68
    parser = argparse.ArgumentParser(
1✔
69
        formatter_class=argparse.RawTextHelpFormatter,
70
        description="CLI utility for running evaluations with unitxt.",
71
    )
72

73
    # --- Task/Dataset Arguments ---
74
    parser.add_argument(
1✔
75
        "--tasks",  # Changed to plural to better reflect it holds a list
76
        "-t",
77
        dest="tasks",  # Explicitly set the attribute name to 'tasks'
78
        type=partial(str.split, sep="+"),  # Split the '+'-separated string into a list of task strings
79
        required=True,
80
        help="Plus-separated (+) list of Unitxt task/dataset identifier strings.\n"
81
        "Each task format: 'card=<card_ref>,template=<template_ref>,...'\n"
82
        "Example: 'card=cards.mmlu,t=t.mmlu.all+card=cards.hellaswag,t=t.hellaswag.no'",
83
    )
84

85
    parser.add_argument(
1✔
86
        "--split",
87
        type=str,
88
        default="test",
89
        help="Dataset split to use (e.g., 'train', 'validation', 'test'). Default: 'test'.",
90
    )
91
    parser.add_argument(
1✔
92
        "--num_fewshots",
93
        type=int,
94
        default=None,
95
        help="number of fewshots to use",
96
    )
97
    parser.add_argument(
1✔
98
        "--limit",
99
        "-L",
100
        type=int,
101
        default=None,
102
        metavar="N",
103
        help="Limit the number of examples per task/dataset.",
104
    )
105

106
    parser.add_argument(
1✔
107
        "--batch_size",
108
        "-b",
109
        type=int,
110
        default=1,
111
        help="Batch size for use in inference when selected model is hf. Default 1",
112
    )
113

114
    # --- Model Arguments (Explicit Types) ---
115
    parser.add_argument(
1✔
116
        "--model",
117
        "-m",
118
        type=str,
119
        default="hf",
120
        choices=["hf", "cross_provider"],
121
        help="Specifies the model type/engine.\n"
122
        "- 'hf': Local Hugging Face model via HFAutoModel (default). Requires 'pretrained=...' in --model_args.\n"
123
        "- 'cross_provider': Remote model via CrossProviderInferenceEngine. Requires 'model_name=...' in --model_args.",
124
    )
125
    parser.add_argument(
1✔
126
        "--model_args",
127
        "-a",
128
        type=try_parse_json,
129
        default={},
130
        help="Comma separated string or JSON formatted arguments for the model/inference engine.\n"
131
        "Examples:\n"
132
        "- For --model hf (default): 'pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16,device=cuda'\n"
133
        "  (Note: 'pretrained' key is REQUIRED. Other args like 'torch_dtype', 'device', generation params are passed too)\n"
134
        "- For --model generic_remote: 'model_name=llama-3-3-70b-instruct,max_tokens=256,temperature=0.7'\n"
135
        "  (Note: 'model_name' key is REQUIRED)\n"
136
        '- JSON format: \'{"pretrained": "my_model", "torch_dtype": "float32"}\' or \'{"model_name": "openai/gpt-4o"}\'',
137
    )
138

139
    parser.add_argument(
1✔
140
        "--gen_kwargs",
141
        type=try_parse_json,
142
        default=None,
143
        help=(
144
            "Comma delimited string for model generation on greedy_until tasks,"
145
            """ e.g. temperature=0,top_p=0.1."""
146
        ),
147
    )
148

149
    parser.add_argument(
1✔
150
        "--chat_template_kwargs",
151
        type=try_parse_json,
152
        default=None,
153
        help=(
154
            "Comma delimited string for tokenizer kwargs"
155
            "e.g. thinking=True (https://github.com/huggingface/transformers/blob/9a1c1fe7edaefdb25ab37116a979832df298d6ea/src/transformers/tokenization_utils_base.py#L1542)"
156
        ),
157
    )
158

159
    # --- Output and Logging Arguments ---
160
    parser.add_argument(
1✔
161
        "--output_path",
162
        "-o",
163
        type=str,
164
        default=".",
165
        help="Directory to save evaluation results and logs. Default: current directory.",
166
    )
167
    parser.add_argument(
1✔
168
        "--output_file_prefix",
169
        type=str,
170
        default="evaluation_results",
171
        help="Prefix for the output JSON file names. Default: 'evaluation_results'.",
172
    )
173
    parser.add_argument(
1✔
174
        "--log_samples",
175
        "-s",
176
        action="store_true",
177
        default=False,
178
        help="If True, save individual predictions and scores to a separate JSON file.",
179
    )
180
    parser.add_argument(
1✔
181
        "--verbosity",
182
        "-v",
183
        type=str.upper,
184
        default="INFO",
185
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
186
        help="Controls logging verbosity level. Default: INFO.",
187
    )
188

189
    parser.add_argument(
1✔
190
        "--apply_chat_template",
191
        action="store_true",
192
        default=False,
193
        help="Use the chat API format for prompts (sets format=formats.chat_api).",
    )
194

195
    # --- Unitxt Settings ---
196
    parser.add_argument(
1✔
197
        "--trust_remote_code",
198
        action="store_true",
199
        default=False,
200
        help="Allow execution of unverified code from the HuggingFace Hub (used by datasets/unitxt).",
201
    )
202
    parser.add_argument(
1✔
203
        "--disable_hf_cache",
204
        action="store_true",
205
        default=False,
206
        help="Disable HuggingFace datasets caching.",
207
    )
208
    parser.add_argument(
1✔
209
        "--cache_dir",
210
        type=str,
211
        default=None,
212
        help="Directory for HuggingFace datasets cache (overrides default).",
213
    )
214

215
    return parser
1✔
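# Example invocation of the parser defined above (a sketch: the installed
# console-script name is not shown in this file, so the module is run directly,
# and the card/template identifiers are placeholders):
#
#   python -m unitxt.evaluate_cli \
#       --tasks "card=cards.mmlu,template=templates.mmlu.all" \
#       --model hf \
#       --model_args "pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16" \
#       --split test --limit 100 --batch_size 8 \
#       --output_path ./results --log_samples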
216

217

218
def setup_logging(verbosity: str) -> None:
1✔
219
    """Configures logging based on verbosity level."""
220
    logging.basicConfig(
×
221
        level=verbosity,
222
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
223
        force=True,  # Ensures reconfiguration works if basicConfig was called before
224
    )
225
    # Re-get the logger instance after basicConfig is set
226
    global logger
227
    logger = get_logger()
×
228
    logger.setLevel(verbosity)
×
229

230

231
def prepare_output_paths(output_path: str, prefix: str) -> Tuple[str, str]:
1✔
232
    """Creates output directory and defines file paths.
233

234
    Args:
235
        output_path (str): The directory where output files will be saved.
236
        prefix (str): The prefix for the output file names.
237

238
    Returns:
239
        Tuple[str, str]: A tuple containing the path for the results summary file
240
                         and the path for the detailed samples file.
241
    """
242
    os.makedirs(output_path, exist_ok=True)
1✔
243
    results_file_path = os.path.join(output_path, f"{prefix}.json")
1✔
244
    samples_file_path = os.path.join(output_path, f"{prefix}_samples.json")
1✔
245
    return results_file_path, samples_file_path
1✔
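# For example, with the default prefix:
#
#   prepare_output_paths("./results", "evaluation_results")
#       -> ("./results/evaluation_results.json", "./results/evaluation_results_samples.json")
#
# (The directory is created if it does not exist; the final file names are
# later prefixed with a timestamp by prepend_timestamp_to_path.)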
246

247

248
def configure_unitxt_settings(args: argparse.Namespace):
1✔
249
    """Configures unitxt settings and returns a context manager.
250

251
    Args:
252
        args (argparse.Namespace): Parsed command-line arguments.
253

254
    Returns:
255
        ContextManager: A context manager for applying unitxt settings.
256
    """
257
    unitxt_settings_dict = {
1✔
258
        "disable_hf_datasets_cache": args.disable_hf_cache,
259
        "allow_unverified_code": args.trust_remote_code,
260
    }
261
    if args.cache_dir:
1✔
262
        unitxt_settings_dict["hf_cache_dir"] = args.cache_dir
1✔
263
        # Also set environment variable as some HF parts might read it directly
264
        os.environ["HF_DATASETS_CACHE"] = args.cache_dir
1✔
265
        os.environ["HF_HOME"] = args.cache_dir
1✔
266
        logger.info(f"Set HF_DATASETS_CACHE to: {args.cache_dir}")
1✔
267

268
    if args.disable_hf_cache:
1✔
269
        os.environ["UNITXT_DISABLE_HF_DATASETS_CACHE"] = "True"
1✔
270

271
    logger.info(f"Applying unitxt settings: {unitxt_settings_dict}")
1✔
272
    return settings.context(**unitxt_settings_dict)
1✔
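# Sketch of the flag-to-setting mapping performed above (the cache path is a
# placeholder):
#
#   --trust_remote_code        -> allow_unverified_code=True in settings.context(...)
#   --disable_hf_cache         -> disable_hf_datasets_cache=True in settings.context(...)
#                                 and UNITXT_DISABLE_HF_DATASETS_CACHE="True" in the environment
#   --cache_dir /tmp/hf_cache  -> hf_cache_dir="/tmp/hf_cache" in settings.context(...)
#                                 plus the HF_DATASETS_CACHE and HF_HOME environment variables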
273

274

275
def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
1✔
276
    """Loads the dataset based on command line arguments.
277

278
    Args:
279
        args (argparse.Namespace): Parsed command-line arguments.
280

281
    Returns:
282
        HFDataset: The loaded dataset.
283

284
    Raises:
285
        UnitxtArtifactNotFoundError: If the specified card or template artifact is not found.
286
        FileNotFoundError: If a specified file (e.g., in a local card path) is not found.
287
        AttributeError: If there's an issue accessing attributes during loading.
288
        ValueError: If there's a value-related error during loading (e.g., parsing).
289
    """
290
    logger.info(
×
291
        f"Loading task/dataset using identifier: '{args.tasks}' with split '{args.split}'"
292
    )
293

294
    benchmark_subsets = {}
×
295
    for task_str in args.tasks:
×
296
        overwrite_args = extract_overwrite_args(args)
×
297
        benchmark_subsets[task_str] = load_recipe(
×
298
            dataset_query=task_str, **overwrite_args
299
        )
300

301
    # This hack circumvents an issue with multi-level benchmarks (such as Bluebench's translation subset) that fail when wrapped in an additional Benchmark() object.
302
    if len(benchmark_subsets) == 1:
×
303
        source = next(iter(benchmark_subsets.values()))
×
304
    else:
305
        source = Benchmark(subsets=benchmark_subsets)
×
306

307
    test_dataset = _source_to_dataset(source, split=args.split)
×
308
    logger.info(
×
309
        f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
310
    )
311
    return test_dataset
×
312

313

314
def extract_overwrite_args(args):
1✔
315
    dataset_args = {}
×
316

317
    if args.limit is not None:
×
318
        assert (
×
319
            f"max_{args.split}_instances" not in dataset_args
320
        ), "limit was inputted both as an arg and as a task parameter"
321
        # Check if limit or loader_limit is already present
322
        # dataset_args[f"max_{args.split}_instances"] = args.limit
323
        dataset_args[f"max_{args.split}_instances"] = args.limit
×
324
        # The limit is applied to the recipe as max_<split>_instances
325
        logger.info(
×
326
            f"Applying limit from --limit argument: max_{args.split}_instances={args.limit}"
327
        )
328

329
    if args.num_fewshots:
×
330
        assert (
×
331
            "num_demos" not in dataset_args
332
        ), "num_demos was inputted both as an arg and as a task parameter"
333
        dataset_args["num_demos"] = args.num_fewshots
×
334
        dataset_args.update(
×
335
            {
336
                "demos_taken_from": "train",
337
                "demos_pool_size": -1,
338
                "demos_removed_from_data": True,
339
            }
340
        )  # Demos are drawn from the train split and removed from the evaluation data
341
        logger.info(
×
342
            f"Applying limit from --limit argument: num_demos={args.num_fewshots}"
343
        )
344

345
    if args.apply_chat_template:
×
346
        assert (
×
347
            "format" not in dataset_args
348
        ), "format was inputted as a task parameter, but chat_api was requested"
349
        dataset_args["format"] = "formats.chat_api"
×
350
        logger.info(
×
351
            "Applying chat template from --apply_chat_template argument: format=formats.chat_api"
352
        )
353

354
    return dataset_args
×
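# Sketch of the recipe overrides produced for a hypothetical invocation with
# --split test --limit 100 --num_fewshots 5 --apply_chat_template:
#
#   {
#       "max_test_instances": 100,
#       "num_demos": 5,
#       "demos_taken_from": "train",
#       "demos_pool_size": -1,
#       "demos_removed_from_data": True,
#       "format": "formats.chat_api",
#   }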
355

356

357
def prepare_kwargs(kwargs: dict) -> Dict[str, Any]:
1✔
358
    """Prepares the model arguments dictionary.
359

360
    Args:
361
        kwargs (dict): Parsed command-line arguments.
362

363
    Returns:
364
        Dict[str, Any]: The processed model arguments dictionary.
365
    """
366
    # Ensure model_args is a dictionary, handling potential string return from try_parse_json
367
    kwargs_dict = kwargs if isinstance(kwargs, dict) else {}
1✔
368
    if not isinstance(kwargs, dict) and kwargs is not None:
1✔
369
        logger.warning(
1✔
370
            f"Could not parse kwargs '{kwargs}' as JSON or key-value pairs. Treating as empty."
371
        )
372

373
    logger.info(f"Using kwargs: {kwargs_dict}")
1✔
374
    return kwargs_dict
1✔
375

376

377
def initialize_inference_engine(
1✔
378
    args: argparse.Namespace,
379
    model_args_dict: Dict[str, Any],
380
    chat_kwargs_dict: Dict[str, Any],
381
) -> InferenceEngine:
382
    """Initializes the appropriate inference engine based on arguments.
383

384
    Args:
385
        args (argparse.Namespace): Parsed command-line arguments.
386
        model_args_dict (Dict[str, Any]): Processed model arguments.
387
        chat_kwargs_dict (Dict[str, Any]): Processed chat arguments.
388

389
    Returns:
390
        InferenceEngine: The initialized inference engine instance.
391

392
    Raises:
393
        SystemExit: If required dependencies are missing for the selected model type.
394
        ValueError: If required keys are missing in model_args for the selected model type.
395
    """
396
    inference_model = None
1✔
397
    # --- Local Hugging Face Model (using HFAutoModelInferenceEngine) ---
398
    if args.model.lower() == "hf":
1✔
399
        if "pretrained" not in model_args_dict:
1✔
400
            logger.error(
1✔
401
                "Missing 'pretrained=<model_id_or_path>' in --model_args for '--model hf'."
402
            )
403
            raise ValueError(
1✔
404
                "Argument 'pretrained' is required in --model_args when --model is 'hf'"
405
            )
406

407
        local_model_name = model_args_dict.pop("pretrained")
1✔
408
        logger.info(
1✔
409
            f"Initializing HFAutoModelInferenceEngine for model: {local_model_name}"
410
        )
411

412
        model_args_dict.update({"batch_size": args.batch_size})
1✔
413
        logger.info(f"HFAutoModelInferenceEngine args: {model_args_dict}")
1✔
414

415
        inference_model = HFAutoModelInferenceEngine(
1✔
416
            model_name=local_model_name,
417
            **model_args_dict,
418
            chat_kwargs_dict=chat_kwargs_dict,
419
        )
420

421
        # Keep the actual model name for the results
422
        args.model = inference_model.model_name
1✔
423
    # --- Remote Model (CrossProviderInferenceEngine) ---
424
    elif args.model.lower() == "cross_provider":
1✔
425
        if "model_name" not in model_args_dict:
1✔
426
            logger.error(
1✔
427
                "Missing 'model_name=<provider/model_id>' in --model_args for '--model cross_provider'."
428
            )
429
            raise ValueError(
1✔
430
                "Argument 'model_name' is required in --model_args when --model is 'cross_provider'"
431
            )
432

433
        remote_model_name = model_args_dict.pop("model_name")
1✔
434
        logger.info(
1✔
435
            f"Initializing CrossProviderInferenceEngine for model: {remote_model_name}"
436
        )
437

438
        if (
1✔
439
            "max_tokens" not in model_args_dict
440
            and "max_new_tokens" not in model_args_dict
441
        ):
442
            logger.warning(
×
443
                f"'max_tokens' or 'max_new_tokens' not found in --model_args, {remote_model_name} might require it."
444
            )
445

446
        logger.info(f"CrossProviderInferenceEngine args: {model_args_dict}")
1✔
447

448
        # Note: CrossProviderInferenceEngine expects 'model' parameter, not 'model_name'
449
        inference_model = CrossProviderInferenceEngine(
1✔
450
            model=remote_model_name,
451
            **model_args_dict,
452
        )
453

454
        # Keep the actual model name for the results
455
        args.model = inference_model.engine.model
1✔
456
    else:
457
        # This case should not be reached due to argparse choices
458
        logger.error(
×
459
            f"Invalid --model type specified: {args.model}. Use 'hf' or 'cross_provider'."
460
        )
461
        sys.exit(1)  # Exit here as it's an invalid configuration
×
462

463
    return inference_model
1✔
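# Sketch of how --model_args feeds the two engines above (model identifiers are
# the placeholders from the help text, and key=value coercion follows try_parse_json):
#
#   --model hf --model_args "pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16"
#       -> HFAutoModelInferenceEngine(model_name="meta-llama/Llama-3.1-8B-Instruct",
#                                     torch_dtype="bfloat16",
#                                     batch_size=args.batch_size,
#                                     chat_kwargs_dict=chat_kwargs_dict)
#
#   --model cross_provider --model_args "model_name=llama-3-3-70b-instruct,max_tokens=256"
#       -> CrossProviderInferenceEngine(model="llama-3-3-70b-instruct", max_tokens=256)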
464

465

466
def run_inference(engine: InferenceEngine, dataset: HFDataset) -> List[Any]:
1✔
467
    """Runs inference using the initialized engine.
468

469
    Args:
470
        engine (InferenceEngine): The inference engine instance.
471
        dataset (HFDataset): The dataset to run inference on.
472

473
    Returns:
474
        List[Any]: A list of predictions.
475

476
    Raises:
477
        Exception: If an error occurs during inference.
478
    """
479
    logger.info("Starting inference...")
1✔
480
    try:
1✔
481
        predictions = engine.infer(dataset)
1✔
482
        logger.info("Inference completed.")
1✔
483
        if not predictions:
1✔
484
            logger.warning("Inference returned no predictions.")
×
485
            return []  # Return empty list if no predictions
×
486
        if len(predictions) != len(dataset):
1✔
487
            logger.error(
1✔
488
                f"Inference returned an unexpected number of predictions ({len(predictions)}). Expected {len(dataset)}."
489
            )
490
            # Don't exit, but log error. Evaluation might still work partially or fail later.
491
        return predictions
1✔
492
    except Exception:
1✔
493
        logger.exception("An error occurred during inference")  # Use logger.exception
1✔
494
        raise  # Re-raise after logging
1✔
495

496

497
def run_evaluation(predictions: List[Any], dataset: HFDataset) -> EvaluationResults:
1✔
498
    """Runs evaluation on the predictions.
499

500
    Args:
501
        predictions (List[Any]): The list of predictions from the model.
502
        dataset (HFDataset): The dataset containing references and other data.
503

504
    Returns:
505
        EvaluationResults: The evaluated dataset (list of instances with scores).
506

507
    Raises:
508
        RuntimeError: If evaluation returns no results or an unexpected type.
509
        Exception: If any other error occurs during evaluation.
510
    """
511
    logger.info("Starting evaluation...")
1✔
512
    if not predictions:
1✔
513
        logger.warning("Skipping evaluation as there are no predictions.")
1✔
514
        return []  # Return empty list if no predictions to evaluate
1✔
515

516
    try:
1✔
517
        evaluation_results = evaluate(predictions=predictions, data=dataset)
1✔
518
        logger.info("Evaluation completed.")
1✔
519
        if not evaluation_results:
1✔
520
            logger.error("Evaluation returned no results (empty list/None).")
1✔
521
            # Raise an error as this indicates a problem in the evaluation process
522
            raise RuntimeError("Evaluation returned no results.")
1✔
523
        if not isinstance(evaluation_results, EvaluationResults):
1✔
524
            logger.error(
1✔
525
                f"Evaluation returned unexpected type: {type(evaluation_results)}. Expected list."
526
            )
527
            raise RuntimeError(
1✔
528
                f"Evaluation returned unexpected type: {type(evaluation_results)}"
529
            )
530

531
        return evaluation_results
1✔
532
    except Exception:
1✔
533
        logger.exception("An error occurred during evaluation")  # Use logger.exception
1✔
534
        raise  # Re-raise after logging
1✔
535

536

537
def _get_unitxt_commit_hash() -> Optional[str]:
1✔
538
    """Tries to get the git commit hash of the installed unitxt package."""
539
    try:
×
540
        # Find the directory of the unitxt package
541
        # Use inspect to be more robust finding the package path
542

543
        current_script_path = os.path.abspath(__file__)
×
544
        package_dir = os.path.dirname(current_script_path)
×
545

546
        # Check if it's a git repository and get the commit hash
547
        # Use absolute path for git command
548
        git_command = ["git", "-C", os.path.abspath(package_dir), "rev-parse", "HEAD"]
×
549
        logger.debug(f"Running git command: {' '.join(git_command)}")
×
550
        result = subprocess.run(
×
551
            git_command,
552
            capture_output=True,
553
            text=True,
554
            check=False,  # Don't raise error if git command fails
555
            encoding="utf-8",
556
            errors="ignore",  # Ignore potential decoding errors
557
        )
558
        if result.returncode == 0:
×
559
            commit_hash = result.stdout.strip()
×
560
            logger.info(f"Found unitxt git commit hash: {commit_hash}")
×
561
            # Verify it looks like a hash (e.g., 40 hex chars)
562
            if len(commit_hash) == 40 and all(
×
563
                c in "0123456789abcdef" for c in commit_hash
564
            ):
565
                return commit_hash
×
566
            logger.warning(
×
567
                f"Git command output '{commit_hash}' doesn't look like a valid commit hash."
568
            )
569
            return None
×
570
        stderr_msg = result.stderr.strip() if result.stderr else "No stderr"
×
571
        logger.warning(
×
572
            f"Could not get unitxt git commit hash (git command failed with code {result.returncode}): {stderr_msg}"
573
        )
574
        return None
×
575
    except ImportError:
×
576
        logger.warning("unitxt package not found, cannot determine commit hash.")
×
577
        return None
×
578
    except FileNotFoundError:
×
579
        logger.warning(
×
580
            "'git' command not found in PATH. Cannot determine unitxt commit hash."
581
        )
582
        return None
×
583
    except Exception as e:
×
584
        logger.warning(
×
585
            f"Error getting unitxt commit hash: {e}", exc_info=True
586
        )  # Log traceback
587
        return None
×
588

589

590
def _get_installed_packages() -> Dict[str, str]:
1✔
591
    """Gets a dictionary of installed packages and their versions."""
592
    packages = {}
×
593
    try:
×
594
        for dist in importlib.metadata.distributions():
×
595
            # Handle potential missing metadata gracefully
596
            name = dist.metadata.get("Name")
×
597
            version = dist.metadata.get("Version")
×
598
            if name and version:
×
599
                packages[name] = version
×
600
            elif name:
×
601
                packages[name] = "N/A"  # Record package even if version is missing
×
602
                logger.debug(f"Could not find version for package: {name}")
×
603

604
        logger.info(f"Collected versions for {len(packages)} installed packages.")
×
605
    except Exception as e:
×
606
        logger.warning(f"Could not retrieve installed package list: {e}", exc_info=True)
×
607
    return packages
×
608

609

610
def _get_unitxt_version() -> str:
1✔
611
    """Gets the installed unitxt version using importlib.metadata."""
612
    try:
×
613
        version = importlib.metadata.version("unitxt")
×
614
        logger.info(f"Found unitxt version using importlib.metadata: {version}")
×
615
        return version
×
616
    except importlib.metadata.PackageNotFoundError:
×
617
        logger.warning(
×
618
            "Could not find 'unitxt' package version using importlib.metadata. Is it installed correctly?"
619
        )
620
        return "N/A"
×
621
    except Exception as e:
×
622
        logger.warning(
×
623
            f"Error getting unitxt version using importlib.metadata: {e}", exc_info=True
624
        )
625
        return "N/A"
×
626

627

628
def prepend_timestamp_to_path(original_path, timestamp):
1✔
629
    """Takes a path string and a timestamp string, prepends the timestamp to the filename part of the path, and returns the new path string."""
630
    directory, filename = os.path.split(original_path)
1✔
631
    # Use an f-string to create the new filename with the timestamp prepended
632
    new_filename = f"{timestamp}_{filename}"
1✔
633
    # Join the directory and the new filename back together
634
    return os.path.join(directory, new_filename)
1✔
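# For example:
#
#   prepend_timestamp_to_path("./results/evaluation_results.json", "2025-04-04T11-37-32")
#       -> "./results/2025-04-04T11-37-32_evaluation_results.json"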
635

636

637
def _save_results_to_disk(
1✔
638
    args: argparse.Namespace,
639
    global_scores: Dict[str, Any],
640
    all_samples_data: Dict[str, List[Dict[str, Any]]],
641
    results_path: str,
642
    samples_path: str,
643
) -> None:
644
    """Saves the configuration, environment info, global scores, and samples to JSON files.
645

646
    Args:
647
        args (argparse.Namespace): Parsed command-line arguments.
648
        global_scores (Dict[str, Any]): Dictionary of global scores.
649
        all_samples_data (Dict[str, List[Dict[str, Any]]]): Mapping from subset name to its processed sample data.
650
        results_path (str): Path to save the summary results JSON file.
651
        samples_path (str): Path to save the detailed samples JSON file.
652
    """
653
    # --- Gather Configuration ---
654
    config_to_save = {}
1✔
655
    for k, v in vars(args).items():
1✔
656
        # Ensure complex objects are represented as strings
657
        if isinstance(v, (str, int, float, bool, list, dict, type(None))):
1✔
658
            config_to_save[k] = v
1✔
659
        else:
660
            try:
×
661
                # Try standard repr first
662
                config_to_save[k] = repr(v)
×
663
            except Exception:
×
664
                # Fallback if repr fails
665
                config_to_save[
×
666
                    k
667
                ] = f"<Object of type {type(v).__name__} could not be represented>"
668

669
    # --- Gather Environment Info ---
670
    unitxt_commit = _get_unitxt_commit_hash()
1✔
671
    # Get version using the dedicated function
672
    unitxt_pkg_version = _get_unitxt_version()
1✔
673

674
    environment_info = {
1✔
675
        "timestamp_utc": datetime.utcnow().isoformat() + "Z",
676
        "command_line_invocation": sys.argv,
677
        "parsed_arguments": config_to_save,  # Include parsed args here as well
678
        "unitxt_version": unitxt_pkg_version,  # Use version from importlib.metadata
679
        "unitxt_commit_hash": unitxt_commit if unitxt_commit else "N/A",
680
        "python_version": platform.python_version(),
681
        "system": platform.system(),
682
        "system_version": platform.version(),
683
        "installed_packages": _get_installed_packages(),
684
    }
685

686
    # --- Prepare Final Results Structure ---
687
    results_summary = {
1✔
688
        "environment_info": environment_info,
689
        "results": global_scores,
690
    }
691

692
    # Prepend a timestamp (e.g. 2025-04-04T11-37-32) to the results and samples file names
693

694
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
1✔
695

696
    results_path = prepend_timestamp_to_path(results_path, timestamp)
1✔
697
    samples_path = prepend_timestamp_to_path(samples_path, timestamp)
1✔
698

699
    # --- Save Summary ---
700
    logger.info(f"Saving global results summary to: {results_path}")
1✔
701
    try:
1✔
702
        with open(results_path, "w", encoding="utf-8") as f:
1✔
703
            json.dump(results_summary, f, indent=4, ensure_ascii=False)
1✔
704
    except OSError as e:
×
705
        logger.error(f"Failed to write results summary file {results_path}: {e}")
×
706
    except TypeError as e:
×
707
        logger.error(
×
708
            f"Failed to serialize results summary to JSON: {e}. Check data types."
709
        )
710
        # Log the problematic structure if possible (might be large)
711
        # logger.debug(f"Problematic results_summary structure: {results_summary}")
712

713
    # --- Save Samples (if requested) ---
714
    if args.log_samples:
1✔
715
        logger.info(f"Saving detailed samples to: {samples_path}")
×
716
        # Structure samples file with environment info as well for self-containment
717
        samples_output = {
×
718
            "environment_info": environment_info,  # Repeat env info here
719
            "samples": all_samples_data,
720
        }
721
        try:
×
722
            with open(samples_path, "w", encoding="utf-8") as f:
×
723
                json.dump(samples_output, f, indent=4, ensure_ascii=False)
×
724
        except OSError as e:
×
725
            logger.error(f"Failed to write samples file {samples_path}: {e}")
×
726
        except TypeError as e:
×
727
            logger.error(f"Failed to serialize samples to JSON: {e}. Check data types.")
×
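# Sketch of the summary file written above (field values are illustrative):
#
#   {
#       "environment_info": {
#           "timestamp_utc": "...",
#           "command_line_invocation": ["..."],
#           "parsed_arguments": {"tasks": ["..."], "model": "...", "...": "..."},
#           "unitxt_version": "...",
#           "unitxt_commit_hash": "...",
#           "python_version": "...",
#           "system": "...",
#           "system_version": "...",
#           "installed_packages": {"...": "..."}
#       },
#       "results": {"score": "...", "<subset>": {"score": "..."}}
#   }
#
# The samples file (written only with --log_samples) repeats "environment_info"
# and adds a "samples" mapping from subset name to its instances.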
728

729

730
def process_and_save_results(
1✔
731
    args: argparse.Namespace,
732
    evaluation_results: EvaluationResults,
733
    results_path: str,
734
    samples_path: str,
735
) -> None:
736
    """Processes, prints, and saves the evaluation results.
737

738
    Args:
739
        args (argparse.Namespace): Parsed command-line arguments.
740
        evaluation_results (EvaluationResults): The list of evaluated instances.
741
        results_path (str): Path to save the summary results JSON file.
742
        samples_path (str): Path to save the detailed samples JSON file.
743

744
    Raises:
745
        Exception: If an error occurs during result processing or saving (re-raised).
746
    """
747
    try:
1✔
748
        # global_scores, all_samples_data = _extract_scores_and_samples(evaluated_dataset)
749

750
        subsets_scores = evaluation_results.subsets_scores
1✔
751
        instances_results = evaluation_results.instance_scores
1✔
752

753
        subset_instances = {}
1✔
754
        for instance in instances_results:
1✔
755
            if instance["subset"][0] not in subset_instances:
1✔
756
                subset_instances[instance["subset"][0]] = []
1✔
757
            del instance["postprocessors"]
1✔
758
            subset_instances[instance["subset"][0]].append(instance)
1✔
759

760
        logger.info(f"\n{subsets_scores.summary}")
1✔
761

762
        # --- Save Results ---
763
        # Pass all necessary data to the saving function
764
        _save_results_to_disk(
1✔
765
            args, subsets_scores, subset_instances, results_path, samples_path
766
        )
767

768
    except Exception:
×
769
        logger.exception(
×
770
            "An error occurred during result processing or saving"
771
        )  # Use logger.exception
772
        raise  # Re-raise after logging
×
773

774

775
def main():
1✔
776
    """Main function to parse arguments and run evaluation."""
777
    parser = setup_parser()
1✔
778
    args = parser.parse_args()
1✔
779

780
    # Setup logging ASAP
781
    setup_logging(args.verbosity)
1✔
782

783
    logger.info("Starting Unitxt Evaluation CLI")
1✔
784
    # Log raw and parsed args at DEBUG level
785
    logger.debug(f"Raw command line arguments: {sys.argv}")
1✔
786
    logger.debug(f"Parsed arguments: {vars(args)}")  # Log the vars(args) dict
1✔
787
    logger.debug(
1✔
788
        f"Parsed model_args type: {type(args.model_args)}, value: {args.model_args}"
789
    )
790

791
    try:
1✔
792
        results_path, samples_path = prepare_output_paths(
1✔
793
            args.output_path, args.output_file_prefix
794
        )
795

796
        # Apply unitxt settings within a context manager
797
        with configure_unitxt_settings(args):
1✔
798
            test_dataset = cli_load_dataset(args)
1✔
799
            model_args_dict = prepare_kwargs(args.model_args)
1✔
800
            gen_kwargs_dict = prepare_kwargs(args.gen_kwargs)
1✔
801
            chat_kwargs_dict = prepare_kwargs(args.chat_template_kwargs)
1✔
802

803
            model_args_dict.update(gen_kwargs_dict)
1✔
804
            inference_model = initialize_inference_engine(
1✔
805
                args, model_args_dict, chat_kwargs_dict
806
            )
807
            predictions = run_inference(inference_model, test_dataset)
1✔
808
            evaluation_results = run_evaluation(predictions, test_dataset)
1✔
809
            process_and_save_results(
1✔
810
                args, evaluation_results, results_path, samples_path
811
            )
812

813
    # --- More Specific Error Handling ---
814
    except (UnitxtArtifactNotFoundError, FileNotFoundError) as e:
×
815
        logger.exception(f"Error loading artifact or file: {e}")
×
816
        sys.exit(1)
×
817
    except (AttributeError, ValueError) as e:
×
818
        # Catch issues like missing keys in args, parsing errors, etc.
819
        logger.exception(f"Configuration or value error: {e}")
×
820
        sys.exit(1)
×
821
    except ImportError as e:
×
822
        # Catch missing optional dependencies
823
        logger.exception(f"Missing dependency: {e}")
×
824
        sys.exit(1)
×
825
    except RuntimeError as e:
×
826
        # Catch errors explicitly raised during execution (e.g., evaluation failure)
827
        logger.exception(f"Runtime error during processing: {e}")
×
828
        sys.exit(1)
×
829
    except Exception as e:
×
830
        # Catch any other unexpected errors
831
        logger.exception(f"An unexpected error occurred: {e}")
×
832
        sys.exit(1)
×
833

834
    logger.info("Unitxt Evaluation CLI finished successfully.")
1✔
835

836

837
def extract_scores(directory):
1✔
838
    import pandas as pd
×
839

840
    data = []
×
841

842
    for filename in sorted(os.listdir(directory)):
×
843
        if filename.endswith("evaluation_results.json"):
×
844
            file_path = os.path.join(directory, filename)
×
845
            try:
×
846
                with open(file_path, encoding="utf-8") as f:
×
847
                    content = json.load(f)
×
848

849
                    env_info = content.get("environment_info", {})
×
850
                    timestamp = env_info.get("timestamp_utc", "N/A")
×
851
                    model = env_info.get("parsed_arguments", {}).get("model", "N/A")
×
852
                    results = content.get("results", {})
×
853

854
                    row = {}
×
855
                    row["Model"] = model
×
856
                    row["Timestamp"] = timestamp
×
857
                    row["Average"] = results.get("score", "N/A")
×
858

859
                    for key in results.keys():
×
860
                        if isinstance(results[key], dict):
×
861
                            score = results[key].get("score", "N/A")
×
862
                            row[key] = score
×
863

864
                    data.append(row)
×
865
            except Exception as e:
×
866
                logger.error(f"Error parsing results file {filename}: {e}.")
×
867

868
    return pd.DataFrame(data).sort_values(by="Timestamp", ascending=True)
×
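# The loop above picks up the timestamped summaries written by
# _save_results_to_disk (with the default --output_file_prefix they end in
# "evaluation_results.json") and builds one row per file with columns
# "Model", "Timestamp", "Average", and one column per subset score, e.g.:
#
#   df = extract_scores("./results")   # hypothetical results directory
#   print(df.to_markdown(index=False))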
869

870

871
def summarize_cli():
1✔
872
    if len(sys.argv) != 2:
×
873
        logger.error("Usage: python summarize_cli_results.py <results-directory>")
×
874
        sys.exit(1)
×
875
    directory = sys.argv[1]
×
876
    df = extract_scores(directory)
×
877

878
    logger.info(df.to_markdown(index=False))
×
879

880

881
if __name__ == "__main__":
1✔
882
    main()
×
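# Illustrative use of the summarizer (a sketch: the console-script name for
# summarize_cli is not defined in this file, so it is invoked via Python here):
#
#   python -c "from unitxt.evaluate_cli import summarize_cli; summarize_cli()" ./results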