IBM / unitxt, build 14690895430 (push via github, committed by web-flow)
27 Apr 2025 10:03AM UTC coverage: 80.035% (+0.001% from 80.034%)
Commit: Fix relative imports in evaluate cli (#1758)

1619 of 2008 branches covered (80.63%); branch coverage is included in the aggregate %.
10147 of 12693 relevant lines covered (79.94%), 0.8 hits per line.

Source file: src/unitxt/evaluate_cli.py (70.38% covered)
# evaluate_cli.py
import argparse
import importlib.metadata
import json
import logging
import os
import platform
import subprocess
import sys
from datetime import datetime
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import Dataset as HFDataset

from .api import evaluate, load_dataset
from .artifact import UnitxtArtifactNotFoundError
from .benchmark import Benchmark

# Use HFAutoModelInferenceEngine for local models
from .inference import (
    CrossProviderInferenceEngine,
    HFAutoModelInferenceEngine,
    InferenceEngine,
)
from .logging_utils import get_logger
from .metric_utils import EvaluationResults
from .parsing_utils import parse_key_equals_value_string_to_dict
from .settings_utils import settings
from .standard import DatasetRecipe

# Define logger early so it can be used in initial error handling
# Basic config for initial messages, will be reconfigured in main()
logger = get_logger()


def try_parse_json(value: str) -> Union[str, dict, None]:
    """Attempts to parse a string as JSON or key=value pairs.

    Returns the original string if parsing fails
    and the string doesn't look like JSON/kv pairs.
    Raises ArgumentTypeError if it looks like JSON but is invalid.
    """
    if value is None:
        return None
    try:
        # Handle simple key-value pairs like "key=value,key2=value2"
        if "=" in value and "{" not in value:
            parsed_dict = parse_key_equals_value_string_to_dict(value)
            if parsed_dict:
                return parsed_dict

        # Attempt standard JSON parsing
        return json.loads(value)

    except json.JSONDecodeError as e:
        if value.strip().startswith("{") or value.strip().startswith("["):
            raise argparse.ArgumentTypeError(
                f"Invalid JSON: '{value}'. Hint: Use double quotes for JSON strings and check syntax."
            ) from e
        return value  # Return as string if not JSON-like
    except Exception as e:
        logger.error(f"Error parsing argument '{value}': {e}")
        raise argparse.ArgumentTypeError(f"Could not parse argument: '{value}'") from e


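# Illustrative behavior of try_parse_json (a rough sketch; the exact value types
# depend on how parse_key_equals_value_string_to_dict coerces values):
#   try_parse_json("pretrained=my_model,device=cuda")  -> {"pretrained": "my_model", "device": "cuda"}
#   try_parse_json('{"model_name": "openai/gpt-4o"}')  -> {"model_name": "openai/gpt-4o"}
#   try_parse_json("INFO")                             -> "INFO"  (returned unchanged)
#   try_parse_json('{"bad": json}')                    -> raises argparse.ArgumentTypeError
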
def setup_parser() -> argparse.ArgumentParser:
    """Sets up the argument parser."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="CLI utility for running evaluations with unitxt.",
    )

    # --- Task/Dataset Arguments ---
    parser.add_argument(
        "--tasks",  # Plural, to reflect that it holds a list
        "-t",
        dest="tasks",  # Explicitly set the attribute name to 'tasks'
        type=partial(str.split, sep="+"),  # Split the value on '+' during parsing
        required=True,
        help="Plus-separated (+) list of Unitxt task/dataset identifier strings.\n"
        "Each task format: 'card=<card_ref>,template=<template_ref>,...'\n"
        "Example: 'card=cards.mmlu,t=t.mmlu.all+card=cards.hellaswag,t=t.hellaswag.no'",
    )

    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to use (e.g., 'train', 'validation', 'test'). Default: 'test'.",
    )
    parser.add_argument(
        "--num_fewshots",
        type=int,
        default=None,
        help="Number of few-shot demonstrations to use.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=int,
        default=None,
        metavar="N",
        help="Limit the number of examples per task/dataset.",
    )

    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="Batch size used for inference when the selected model is 'hf'. Default: 1.",
    )

    # --- Model Arguments (Explicit Types) ---
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="hf",
        choices=["hf", "cross_provider"],
        help="Specifies the model type/engine.\n"
        "- 'hf': Local Hugging Face model via HFAutoModel (default). Requires 'pretrained=...' in --model_args.\n"
        "- 'cross_provider': Remote model via CrossProviderInferenceEngine. Requires 'model_name=...' in --model_args.",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        type=try_parse_json,
        default={},
        help="Comma separated string or JSON formatted arguments for the model/inference engine.\n"
        "Examples:\n"
        "- For --model hf (default): 'pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16,device=cuda'\n"
        "  (Note: 'pretrained' key is REQUIRED. Other args like 'torch_dtype', 'device', and generation params are passed too)\n"
        "- For --model cross_provider: 'model_name=llama-3-3-70b-instruct,max_tokens=256,temperature=0.7'\n"
        "  (Note: 'model_name' key is REQUIRED)\n"
        '- JSON format: \'{"pretrained": "my_model", "torch_dtype": "float32"}\' or \'{"model_name": "openai/gpt-4o"}\'',
    )

    parser.add_argument(
        "--gen_kwargs",
        type=try_parse_json,
        default=None,
        help=(
            "Comma-delimited string of generation kwargs for greedy_until tasks, "
            "e.g. temperature=0,top_p=0.1."
        ),
    )

    parser.add_argument(
        "--chat_template_kwargs",
        type=try_parse_json,
        default=None,
        help=(
            "Comma-delimited string of tokenizer chat-template kwargs, "
            "e.g. thinking=True (https://github.com/huggingface/transformers/blob/9a1c1fe7edaefdb25ab37116a979832df298d6ea/src/transformers/tokenization_utils_base.py#L1542)"
        ),
    )

    # --- Output and Logging Arguments ---
    parser.add_argument(
        "--output_path",
        "-o",
        type=str,
        default=".",
        help="Directory to save evaluation results and logs. Default: current directory.",
    )
    parser.add_argument(
        "--output_file_prefix",
        type=str,
        default="evaluation_results",
        help="Prefix for the output JSON file names. Default: 'evaluation_results'.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If set, save individual predictions and scores to a separate JSON file.",
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Controls logging verbosity level. Default: INFO.",
    )

    parser.add_argument(
        "--apply_chat_template",
        action="store_true",
        default=False,
        help="Format prompts with the chat API format (formats.chat_api).",
    )

    # --- Unitxt Settings ---
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="Allow execution of unverified code from the HuggingFace Hub (used by datasets/unitxt).",
    )
    parser.add_argument(
        "--disable_hf_cache",
        action="store_true",
        default=False,
        help="Disable HuggingFace datasets caching.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Directory for HuggingFace datasets cache (overrides default).",
    )

    return parser


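# A minimal example invocation built from the flags above (a sketch: the exact
# entry point and the card/template catalog references are assumptions, not part
# of this file):
#
#   python -m unitxt.evaluate_cli \
#       --tasks "card=cards.mmlu,template=templates.mmlu.all" \
#       --model hf \
#       --model_args "pretrained=meta-llama/Llama-3.1-8B-Instruct,torch_dtype=bfloat16" \
#       --limit 100 \
#       --log_samples \
#       --output_path ./results
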
def setup_logging(verbosity: str) -> None:
    """Configures logging based on verbosity level."""
    logging.basicConfig(
        level=verbosity,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        force=True,  # Ensures reconfiguration works if basicConfig was called before
    )
    # Re-get the logger instance after basicConfig is set
    global logger
    logger = get_logger()
    logger.setLevel(verbosity)


def prepare_output_paths(output_path: str, prefix: str) -> Tuple[str, str]:
    """Creates output directory and defines file paths.

    Args:
        output_path (str): The directory where output files will be saved.
        prefix (str): The prefix for the output file names.

    Returns:
        Tuple[str, str]: A tuple containing the path for the results summary file
                         and the path for the detailed samples file.
    """
    os.makedirs(output_path, exist_ok=True)
    results_file_path = os.path.join(output_path, f"{prefix}.json")
    samples_file_path = os.path.join(output_path, f"{prefix}_samples.json")
    return results_file_path, samples_file_path


def configure_unitxt_settings(args: argparse.Namespace):
    """Configures unitxt settings and returns a context manager.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        ContextManager: A context manager for applying unitxt settings.
    """
    unitxt_settings_dict = {
        "disable_hf_datasets_cache": args.disable_hf_cache,
        "allow_unverified_code": args.trust_remote_code,
    }
    if args.cache_dir:
        unitxt_settings_dict["hf_cache_dir"] = args.cache_dir
        # Also set environment variables, as some HF parts might read them directly
        os.environ["HF_DATASETS_CACHE"] = args.cache_dir
        os.environ["HF_HOME"] = args.cache_dir
        logger.info(f"Set HF_DATASETS_CACHE to: {args.cache_dir}")

    if args.disable_hf_cache:
        os.environ["UNITXT_DISABLE_HF_DATASETS_CACHE"] = "True"

    logger.info(f"Applying unitxt settings: {unitxt_settings_dict}")
    return settings.context(**unitxt_settings_dict)


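# For example, running with --disable_hf_cache --cache_dir /tmp/hf_cache would
# roughly yield:
#   unitxt_settings_dict == {
#       "disable_hf_datasets_cache": True,
#       "allow_unverified_code": False,
#       "hf_cache_dir": "/tmp/hf_cache",
#   }
# and set HF_DATASETS_CACHE, HF_HOME and UNITXT_DISABLE_HF_DATASETS_CACHE in os.environ.
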
def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
    """Loads the dataset based on command line arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.

    Returns:
        HFDataset: The loaded dataset.

    Raises:
        UnitxtArtifactNotFoundError: If the specified card or template artifact is not found.
        FileNotFoundError: If a specified file (e.g., in a local card path) is not found.
        AttributeError: If there's an issue accessing attributes during loading.
        ValueError: If there's a value-related error during loading (e.g., parsing).
    """
    logger.info(
        f"Loading task/dataset using identifier: '{args.tasks}' with split '{args.split}'"
    )

    benchmark_subsets = {}
    for task_str in args.tasks:
        dataset_args = task_str_to_dataset_args(task_str, args)

        benchmark_subsets[task_str] = DatasetRecipe(**dataset_args)

    benchmark = Benchmark(subsets=benchmark_subsets)

    test_dataset = load_dataset(benchmark, split=args.split)
    logger.info(
        f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
    )
    return test_dataset


def task_str_to_dataset_args(task_str, args):
    dataset_args = parse_key_equals_value_string_to_dict(task_str)

    if args.limit is not None:
        assert f"max_{args.split}_instances" not in dataset_args, (
            "limit was given both as a CLI argument and as a task parameter"
        )
        dataset_args[f"max_{args.split}_instances"] = args.limit
        logger.info(
            f"Applying limit from --limit argument: max_{args.split}_instances={args.limit}"
        )

    if args.num_fewshots:
        assert "num_demos" not in dataset_args, (
            "num_demos was given both as a CLI argument and as a task parameter"
        )
        dataset_args["num_demos"] = args.num_fewshots
        dataset_args.update(
            {
                "demos_taken_from": "train",
                "demos_pool_size": -1,
                "demos_removed_from_data": True,
            }
        )
        logger.info(
            f"Applying few-shot setting from --num_fewshots argument: num_demos={args.num_fewshots}"
        )

    if args.apply_chat_template:
        assert "format" not in dataset_args, (
            "format was given as a task parameter, but the chat API format was requested"
        )
        dataset_args["format"] = "formats.chat_api"
        logger.info(
            "Applying chat template from --apply_chat_template argument: format=formats.chat_api"
        )

    return dataset_args


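# For instance, the task string "card=cards.mmlu,template=templates.mmlu.all"
# combined with --limit 100 --num_fewshots 5 --split test would roughly produce:
#   {
#       "card": "cards.mmlu",
#       "template": "templates.mmlu.all",
#       "max_test_instances": 100,
#       "num_demos": 5,
#       "demos_taken_from": "train",
#       "demos_pool_size": -1,
#       "demos_removed_from_data": True,
#   }
# (the card/template references are illustrative, and value coercion depends on
# parse_key_equals_value_string_to_dict).
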
def prepare_kwargs(kwargs: dict) -> Dict[str, Any]:
    """Prepares a kwargs dictionary from a parsed argument value.

    Args:
        kwargs (dict): The parsed argument value (e.g., --model_args), ideally a dict.

    Returns:
        Dict[str, Any]: The processed kwargs dictionary.
    """
    # Ensure the value is a dictionary, handling the potential string return from try_parse_json
    kwargs_dict = kwargs if isinstance(kwargs, dict) else {}
    if not isinstance(kwargs, dict) and kwargs is not None:
        logger.warning(
            f"Could not parse kwargs '{kwargs}' as JSON or key-value pairs. Treating as empty."
        )

    logger.info(f"Using kwargs: {kwargs_dict}")
    return kwargs_dict


def initialize_inference_engine(
    args: argparse.Namespace,
    model_args_dict: Dict[str, Any],
    chat_kwargs_dict: Dict[str, Any],
) -> InferenceEngine:
    """Initializes the appropriate inference engine based on arguments.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        model_args_dict (Dict[str, Any]): Processed model arguments.
        chat_kwargs_dict (Dict[str, Any]): Processed chat arguments.

    Returns:
        InferenceEngine: The initialized inference engine instance.

    Raises:
        SystemExit: If an invalid --model type is specified (normally unreachable due to argparse choices).
        ValueError: If required keys are missing in model_args for the selected model type.
    """
    inference_model = None
    # --- Local Hugging Face Model (using HFAutoModelInferenceEngine) ---
    if args.model.lower() == "hf":
        if "pretrained" not in model_args_dict:
            logger.error(
                "Missing 'pretrained=<model_id_or_path>' in --model_args for '--model hf'."
            )
            raise ValueError(
                "Argument 'pretrained' is required in --model_args when --model is 'hf'"
            )

        local_model_name = model_args_dict.pop("pretrained")
        logger.info(
            f"Initializing HFAutoModelInferenceEngine for model: {local_model_name}"
        )

        model_args_dict.update({"batch_size": args.batch_size})
        logger.info(f"HFAutoModelInferenceEngine args: {model_args_dict}")

        inference_model = HFAutoModelInferenceEngine(
            model_name=local_model_name,
            **model_args_dict,
            chat_kwargs_dict=chat_kwargs_dict,
        )

    # --- Remote Model (CrossProviderInferenceEngine) ---
    elif args.model.lower() == "cross_provider":
        if "model_name" not in model_args_dict:
            logger.error(
                "Missing 'model_name=<provider/model_id>' in --model_args for '--model cross_provider'."
            )
            raise ValueError(
                "Argument 'model_name' is required in --model_args when --model is 'cross_provider'"
            )

        remote_model_name = model_args_dict.pop("model_name")
        logger.info(
            f"Initializing CrossProviderInferenceEngine for model: {remote_model_name}"
        )

        if (
            "max_tokens" not in model_args_dict
            and "max_new_tokens" not in model_args_dict
        ):
            logger.warning(
                f"'max_tokens' or 'max_new_tokens' not found in --model_args, {remote_model_name} might require it."
            )

        logger.info(f"CrossProviderInferenceEngine args: {model_args_dict}")

        # Note: CrossProviderInferenceEngine expects the 'model' parameter, not 'model_name'
        inference_model = CrossProviderInferenceEngine(
            model=remote_model_name,
            **model_args_dict,
        )
    else:
        # This case should not be reached due to argparse choices
        logger.error(
            f"Invalid --model type specified: {args.model}. Use 'hf' or 'cross_provider'."
        )
        sys.exit(1)  # Exit here as it's an invalid configuration

    return inference_model


def run_inference(engine: InferenceEngine, dataset: HFDataset) -> List[Any]:
    """Runs inference using the initialized engine.

    Args:
        engine (InferenceEngine): The inference engine instance.
        dataset (HFDataset): The dataset to run inference on.

    Returns:
        List[Any]: A list of predictions.

    Raises:
        Exception: If an error occurs during inference.
    """
    logger.info("Starting inference...")
    try:
        predictions = engine.infer(dataset)
        logger.info("Inference completed.")
        if not predictions:
            logger.warning("Inference returned no predictions.")
            return []  # Return empty list if no predictions
        if len(predictions) != len(dataset):
            logger.error(
                f"Inference returned an unexpected number of predictions ({len(predictions)}). Expected {len(dataset)}."
            )
            # Don't exit, but log error. Evaluation might still work partially or fail later.
        return predictions
    except Exception:
        logger.exception("An error occurred during inference")  # Use logger.exception
        raise  # Re-raise after logging


def run_evaluation(predictions: List[Any], dataset: HFDataset) -> EvaluationResults:
    """Runs evaluation on the predictions.

    Args:
        predictions (List[Any]): The list of predictions from the model.
        dataset (HFDataset): The dataset containing references and other data.

    Returns:
        EvaluationResults: The evaluated dataset (list of instances with scores).

    Raises:
        RuntimeError: If evaluation returns no results or an unexpected type.
        Exception: If any other error occurs during evaluation.
    """
    logger.info("Starting evaluation...")
    if not predictions:
        logger.warning("Skipping evaluation as there are no predictions.")
        return []  # Return empty list if no predictions to evaluate

    try:
        evaluation_results = evaluate(predictions=predictions, data=dataset)
        logger.info("Evaluation completed.")
        if not evaluation_results:
            logger.error("Evaluation returned no results (empty list/None).")
            # Raise an error as this indicates a problem in the evaluation process
            raise RuntimeError("Evaluation returned no results.")
        if not isinstance(evaluation_results, EvaluationResults):
            logger.error(
                f"Evaluation returned unexpected type: {type(evaluation_results)}. Expected EvaluationResults."
            )
            raise RuntimeError(
                f"Evaluation returned unexpected type: {type(evaluation_results)}"
            )

        return evaluation_results
    except Exception:
        logger.exception("An error occurred during evaluation")  # Use logger.exception
        raise  # Re-raise after logging


def _get_unitxt_commit_hash() -> Optional[str]:
    """Tries to get the git commit hash of the installed unitxt package."""
    try:
        # Find the directory of the unitxt package from this module's file path
        current_script_path = os.path.abspath(__file__)
        package_dir = os.path.dirname(current_script_path)

        # Check if it's a git repository and get the commit hash
        # Use absolute path for git command
        git_command = ["git", "-C", os.path.abspath(package_dir), "rev-parse", "HEAD"]
        logger.debug(f"Running git command: {' '.join(git_command)}")
        result = subprocess.run(
            git_command,
            capture_output=True,
            text=True,
            check=False,  # Don't raise error if git command fails
            encoding="utf-8",
            errors="ignore",  # Ignore potential decoding errors
        )
        if result.returncode == 0:
            commit_hash = result.stdout.strip()
            logger.info(f"Found unitxt git commit hash: {commit_hash}")
            # Verify it looks like a hash (e.g., 40 hex chars)
            if len(commit_hash) == 40 and all(
                c in "0123456789abcdef" for c in commit_hash
            ):
                return commit_hash
            logger.warning(
                f"Git command output '{commit_hash}' doesn't look like a valid commit hash."
            )
            return None
        stderr_msg = result.stderr.strip() if result.stderr else "No stderr"
        logger.warning(
            f"Could not get unitxt git commit hash (git command failed with code {result.returncode}): {stderr_msg}"
        )
        return None
    except ImportError:
        logger.warning("unitxt package not found, cannot determine commit hash.")
        return None
    except FileNotFoundError:
        logger.warning(
            "'git' command not found in PATH. Cannot determine unitxt commit hash."
        )
        return None
    except Exception as e:
        logger.warning(
            f"Error getting unitxt commit hash: {e}", exc_info=True
        )  # Log traceback
        return None


def _get_installed_packages() -> Dict[str, str]:
    """Gets a dictionary of installed packages and their versions."""
    packages = {}
    try:
        for dist in importlib.metadata.distributions():
            # Handle potential missing metadata gracefully
            name = dist.metadata.get("Name")
            version = dist.metadata.get("Version")
            if name and version:
                packages[name] = version
            elif name:
                packages[name] = "N/A"  # Record package even if version is missing
                logger.debug(f"Could not find version for package: {name}")

        logger.info(f"Collected versions for {len(packages)} installed packages.")
    except Exception as e:
        logger.warning(f"Could not retrieve installed package list: {e}", exc_info=True)
    return packages


def _get_unitxt_version() -> str:
    """Gets the installed unitxt version using importlib.metadata."""
    try:
        version = importlib.metadata.version("unitxt")
        logger.info(f"Found unitxt version using importlib.metadata: {version}")
        return version
    except importlib.metadata.PackageNotFoundError:
        logger.warning(
            "Could not find 'unitxt' package version using importlib.metadata. Is it installed correctly?"
        )
        return "N/A"
    except Exception as e:
        logger.warning(
            f"Error getting unitxt version using importlib.metadata: {e}", exc_info=True
        )
        return "N/A"


def prepend_timestamp_to_path(original_path, timestamp):
    """Takes a path string and a timestamp string, prepends the timestamp to the filename part of the path, and returns the new path string."""
    directory, filename = os.path.split(original_path)
    # Use an f-string to create the new filename with the timestamp prepended
    new_filename = f"{timestamp}_{filename}"
    # Join the directory and the new filename back together
    return os.path.join(directory, new_filename)


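# For example:
#   prepend_timestamp_to_path("./results/evaluation_results.json", "2025-04-04T11:37:32")
#   -> "./results/2025-04-04T11:37:32_evaluation_results.json"
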
def _save_results_to_disk(
    args: argparse.Namespace,
    global_scores: Dict[str, Any],
    all_samples_data: Dict[str, List[Dict[str, Any]]],
    results_path: str,
    samples_path: str,
) -> None:
    """Saves the configuration, environment info, global scores, and samples to JSON files.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        global_scores (Dict[str, Any]): Dictionary of global scores.
        all_samples_data (Dict[str, List[Dict[str, Any]]]): Mapping from subset name to its processed sample data.
        results_path (str): Path to save the summary results JSON file.
        samples_path (str): Path to save the detailed samples JSON file.
    """
    # --- Gather Configuration ---
    config_to_save = {}
    for k, v in vars(args).items():
        # Ensure complex objects are represented as strings
        if isinstance(v, (str, int, float, bool, list, dict, type(None))):
            config_to_save[k] = v
        else:
            try:
                # Try standard repr first
                config_to_save[k] = repr(v)
            except Exception:
                # Fallback if repr fails
                config_to_save[k] = (
                    f"<Object of type {type(v).__name__} could not be represented>"
                )

    # --- Gather Environment Info ---
    unitxt_commit = _get_unitxt_commit_hash()
    # Get version using the dedicated function
    unitxt_pkg_version = _get_unitxt_version()

    environment_info = {
        "timestamp_utc": datetime.utcnow().isoformat() + "Z",
        "command_line_invocation": sys.argv,
        "parsed_arguments": config_to_save,  # Include parsed args here as well
        "unitxt_version": unitxt_pkg_version,  # Use version from importlib.metadata
        "unitxt_commit_hash": unitxt_commit if unitxt_commit else "N/A",
        "python_version": platform.python_version(),
        "system": platform.system(),
        "system_version": platform.version(),
        "installed_packages": _get_installed_packages(),
    }

    # --- Prepare Final Results Structure ---
    results_summary = {
        "environment_info": environment_info,
        "results": global_scores,
    }

    # Prepend a timestamp like 2025-04-04T11:37:32 to the output file names
    timestamp = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    results_path = prepend_timestamp_to_path(results_path, timestamp)
    samples_path = prepend_timestamp_to_path(samples_path, timestamp)

    # --- Save Summary ---
    logger.info(f"Saving global results summary to: {results_path}")
    try:
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results_summary, f, indent=4, ensure_ascii=False)
    except OSError as e:
        logger.error(f"Failed to write results summary file {results_path}: {e}")
    except TypeError as e:
        logger.error(
            f"Failed to serialize results summary to JSON: {e}. Check data types."
        )
        # Log the problematic structure if possible (might be large)
        # logger.debug(f"Problematic results_summary structure: {results_summary}")

    # --- Save Samples (if requested) ---
    if args.log_samples:
        logger.info(f"Saving detailed samples to: {samples_path}")
        # Structure samples file with environment info as well for self-containment
        samples_output = {
            "environment_info": environment_info,  # Repeat env info here
            "samples": all_samples_data,
        }
        try:
            with open(samples_path, "w", encoding="utf-8") as f:
                json.dump(samples_output, f, indent=4, ensure_ascii=False)
        except OSError as e:
            logger.error(f"Failed to write samples file {samples_path}: {e}")
        except TypeError as e:
            logger.error(f"Failed to serialize samples to JSON: {e}. Check data types.")


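# The resulting summary file (e.g. "2025-04-04T11:37:32_evaluation_results.json")
# therefore has, roughly, this shape:
#   {
#       "environment_info": {
#           "timestamp_utc": "...",
#           "command_line_invocation": [...],
#           "parsed_arguments": {...},
#           "unitxt_version": "...",
#           "unitxt_commit_hash": "...",
#           "python_version": "...",
#           "system": "...",
#           "system_version": "...",
#           "installed_packages": {...}
#       },
#       "results": { ...global scores per subset... }
#   }
# and the optional samples file adds a "samples" mapping keyed by subset name.
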
def process_and_save_results(
    args: argparse.Namespace,
    evaluation_results: EvaluationResults,
    results_path: str,
    samples_path: str,
) -> None:
    """Processes, prints, and saves the evaluation results.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        evaluation_results (EvaluationResults): The list of evaluated instances.
        results_path (str): Path to save the summary results JSON file.
        samples_path (str): Path to save the detailed samples JSON file.

    Raises:
        Exception: If an error occurs during result processing or saving (re-raised).
    """
    try:
        # global_scores, all_samples_data = _extract_scores_and_samples(evaluated_dataset)

        subsets_scores = evaluation_results.subsets_scores
        instances_results = evaluation_results.instance_scores

        subset_instances = {}
        for instance in instances_results:
            if instance["subset"][0] not in subset_instances:
                subset_instances[instance["subset"][0]] = []
            del instance["postprocessors"]
            subset_instances[instance["subset"][0]].append(instance)

        logger.info(f"\n{subsets_scores.summary}")

        # --- Save Results ---
        # Pass all necessary data to the saving function
        _save_results_to_disk(
            args, subsets_scores, subset_instances, results_path, samples_path
        )

    except Exception:
        logger.exception(
            "An error occurred during result processing or saving"
        )  # Use logger.exception
        raise  # Re-raise after logging


def main():
    """Main function to parse arguments and run evaluation."""
    parser = setup_parser()
    args = parser.parse_args()

    # Setup logging ASAP
    setup_logging(args.verbosity)

    logger.info("Starting Unitxt Evaluation CLI")
    # Log raw and parsed args at DEBUG level
    logger.debug(f"Raw command line arguments: {sys.argv}")
    logger.debug(f"Parsed arguments: {vars(args)}")  # Log the vars(args) dict
    logger.debug(
        f"Parsed model_args type: {type(args.model_args)}, value: {args.model_args}"
    )

    try:
        results_path, samples_path = prepare_output_paths(
            args.output_path, args.output_file_prefix
        )

        # Apply unitxt settings within a context manager
        with configure_unitxt_settings(args):
            test_dataset = cli_load_dataset(args)
            model_args_dict = prepare_kwargs(args.model_args)
            gen_kwargs_dict = prepare_kwargs(args.gen_kwargs)
            chat_kwargs_dict = prepare_kwargs(args.chat_template_kwargs)

            model_args_dict.update(gen_kwargs_dict)
            inference_model = initialize_inference_engine(
                args, model_args_dict, chat_kwargs_dict
            )
            predictions = run_inference(inference_model, test_dataset)
            evaluation_results = run_evaluation(predictions, test_dataset)
            process_and_save_results(
                args, evaluation_results, results_path, samples_path
            )

    # --- More Specific Error Handling ---
    except (UnitxtArtifactNotFoundError, FileNotFoundError) as e:
        logger.exception(f"Error loading artifact or file: {e}")
        sys.exit(1)
    except (AttributeError, ValueError) as e:
        # Catch issues like missing keys in args, parsing errors, etc.
        logger.exception(f"Configuration or value error: {e}")
        sys.exit(1)
    except ImportError as e:
        # Catch missing optional dependencies
        logger.exception(f"Missing dependency: {e}")
        sys.exit(1)
    except RuntimeError as e:
        # Catch errors explicitly raised during execution (e.g., evaluation failure)
        logger.exception(f"Runtime error during processing: {e}")
        sys.exit(1)
    except Exception as e:
        # Catch any other unexpected errors
        logger.exception(f"An unexpected error occurred: {e}")
        sys.exit(1)

    logger.info("Unitxt Evaluation CLI finished successfully.")


if __name__ == "__main__":
    main()
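
The same pipeline that main() wires together can also be driven programmatically. A minimal sketch, assuming the module is importable as unitxt.evaluate_cli and that the card, template, and model references exist in your environment:

    from unitxt.evaluate_cli import (
        cli_load_dataset,
        initialize_inference_engine,
        process_and_save_results,
        run_evaluation,
        run_inference,
        setup_parser,
    )

    # Parse an argument list instead of sys.argv (values here are illustrative)
    args = setup_parser().parse_args(
        [
            "--tasks", "card=cards.mmlu,template=templates.mmlu.all",
            "--model", "hf",
            "--model_args", "pretrained=meta-llama/Llama-3.1-8B-Instruct",
            "--limit", "10",
        ]
    )
    dataset = cli_load_dataset(args)
    engine = initialize_inference_engine(args, dict(args.model_args), {})
    predictions = run_inference(engine, dataset)
    results = run_evaluation(predictions, dataset)
    process_and_save_results(args, results, "results.json", "results_samples.json")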