OpenCOMPES / sed · build 12876831595
20 Jan 2025 10:55PM UTC · coverage: 92.174% (+0.4%) from 91.801%

Pull Request #437: Upgrade to V1
Merge pull request #555 from OpenCOMPES/config_renaming
use user platformdir also for user config

2235 of 2372 new or added lines in 53 files covered (94.22%).
4 existing lines in 1 file now uncovered.
7703 of 8357 relevant lines covered (92.17%).
0.92 hits per line.

Source File: /src/sed/loader/utils.py (87.63% covered)
"""Utilities for loaders
"""
from __future__ import annotations

from collections.abc import Sequence
from glob import glob
from pathlib import Path
from typing import cast

import dask.dataframe
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from h5py import File
from h5py import Group
from natsort import natsorted


def gather_files(
    folder: str,
    extension: str,
    f_start: int = None,
    f_end: int = None,
    f_step: int = 1,
    file_sorting: bool = True,
) -> list[str]:
    """Collects and sorts files with specified extension from a given folder.

    Args:
        folder (str): The folder to search
        extension (str): File extension used for glob.glob().
        f_start (int, optional): Start file id used to construct a file selector.
            Defaults to None.
        f_end (int, optional): End file id used to construct a file selector.
            Defaults to None.
        f_step (int, optional): Step of file id incrementation, used to construct
            a file selector. Defaults to 1.
        file_sorting (bool, optional): Option to sort the files by their names.
            Defaults to True.

    Returns:
        list[str]: List of collected file names.
    """
    try:
        files = glob(folder + "/*." + extension)

        if file_sorting:
            files = cast(list[str], natsorted(files))

        if f_start is not None and f_end is not None:
            files = files[slice(f_start, f_end, f_step)]

    except FileNotFoundError:
        print("No legitimate folder address is specified for file retrieval!")
        raise

    return files
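
# --- Usage sketch (illustrative, not part of the module) ----------------------
# gather_files globs for the given extension and natural-sorts the hits, so
# numerically indexed files come back in numeric rather than lexicographic
# order. The throwaway folder created below is purely hypothetical.
import tempfile

_demo_dir = tempfile.mkdtemp()
for _idx in (10, 2, 1):
    (Path(_demo_dir) / f"scan_{_idx}.h5").touch()

# Expected order: scan_1.h5, scan_2.h5, scan_10.h5 (not 1, 10, 2)
print(gather_files(folder=_demo_dir, extension="h5"))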


def parse_h5_keys(h5_file: File, prefix: str = "") -> list[str]:
    """Helper method which parses the channels present in the h5 file.

    Args:
        h5_file (h5py.File): The H5 file object.
        prefix (str, optional): The prefix for the channel names.
            Defaults to an empty string.

    Returns:
        list[str]: A list of channel names in the H5 file.

    Raises:
        KeyError: If an error occurs while parsing the keys.
    """

    # Initialize an empty list to store the channels
    file_channel_list = []

    # Iterate over the keys in the H5 file
    for key in h5_file.keys():
        try:
            # Check if the object corresponding to the key is a group
            if isinstance(h5_file[key], Group):
                # If it's a group, recursively call the function on the group object
                # and append the returned channels to the file_channel_list
                file_channel_list.extend(
                    parse_h5_keys(h5_file[key], prefix=prefix + "/" + key),
                )
            else:
                # If it's not a group (i.e., it's a dataset), append the key
                # to the file_channel_list
                file_channel_list.append(prefix + "/" + key)
        except KeyError as exception:
            # If an exception occurs, raise a new exception with an error message
            raise KeyError(
                f"Error parsing key: {prefix}/{key}",
            ) from exception

    # Return the list of channels
    return file_channel_list
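
# --- Usage sketch (illustrative, not part of the module) ----------------------
# parse_h5_keys recurses into groups and returns full dataset paths. The file
# name and group layout below are made up for demonstration.
with File("demo_parse_keys.h5", "w") as _h5_out:
    _h5_out.create_dataset("dldPosX", data=np.zeros(3))
    _h5_out.create_dataset("timing/delay", data=np.arange(3))

with File("demo_parse_keys.h5", "r") as _h5_in:
    # Expected: ['/dldPosX', '/timing/delay'] (h5py iterates keys alphabetically)
    print(parse_h5_keys(_h5_in))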


def split_channel_bitwise(
    df: dask.dataframe.DataFrame,
    input_column: str,
    output_columns: Sequence[str],
    bit_mask: int,
    overwrite: bool = False,
    types: Sequence[type] = None,
) -> dask.dataframe.DataFrame:
    """Splits a channel into two channels bitwise.

    This function splits a channel into two channels by separating the n least
    significant bits from the remaining bits. The n least significant bits are
    stored in the first output column, the remaining bits are stored in the
    second output column.

    Args:
        df (dask.dataframe.DataFrame): Dataframe to use.
        input_column (str): Name of the column to split.
        output_columns (Sequence[str]): Names of the columns to create.
        bit_mask (int): Number of bits to use for splitting.
        overwrite (bool, optional): Whether to overwrite existing columns.
            Defaults to False.
        types (Sequence[type], optional): Types of the new columns.

    Returns:
        dask.dataframe.DataFrame: Dataframe with the new columns.
    """
    if len(output_columns) != 2:
        raise ValueError("Exactly two output columns must be given.")
    if input_column not in df.columns:
        raise KeyError(f"Column {input_column} not in dataframe.")
    if output_columns[0] in df.columns and not overwrite:
        raise KeyError(f"Column {output_columns[0]} already in dataframe.")
    if output_columns[1] in df.columns and not overwrite:
        raise KeyError(f"Column {output_columns[1]} already in dataframe.")
    if bit_mask < 0 or not isinstance(bit_mask, int):
        raise ValueError("bit_mask must be a positive integer.")
    if types is None:
        types = [np.int8 if bit_mask < 8 else np.int16, np.int32]
    elif len(types) != 2:
        raise ValueError("Exactly two types must be given.")
    elif not all(isinstance(t, type) for t in types):
        raise ValueError("types must be a sequence of types.")
    df[output_columns[0]] = (df[input_column] % 2**bit_mask).astype(types[0])
    df[output_columns[1]] = (df[input_column] // 2**bit_mask).astype(types[1])
    return df
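
# --- Usage sketch (illustrative, not part of the module) ----------------------
# With bit_mask=3, the value 21 (0b10101) splits into its 3 low bits
# (0b101 = 5) and the remaining high bits (0b10 = 2). The column names here
# are made up.
_demo_ddf = dask.dataframe.from_pandas(
    pd.DataFrame({"raw": [21, 8, 7]}),  # 0b10101, 0b01000, 0b00111
    npartitions=2,
)
_demo_ddf = split_channel_bitwise(
    df=_demo_ddf,
    input_column="raw",
    output_columns=["low_bits", "high_bits"],
    bit_mask=3,
)
# low_bits -> 5, 0, 7 ; high_bits -> 2, 1, 0
print(_demo_ddf.compute())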


def split_dld_time_from_sector_id(
    df: pd.DataFrame | dask.dataframe.DataFrame,
    tof_column: str = None,
    sector_id_column: str = None,
    sector_id_reserved_bits: int = None,
    config: dict = None,
) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]:
    """Converts the 8s time in steps to time in steps and sectorID.

    The 8s detector encodes the dldSectorID in the 3 least significant bits of the
    dldTimeSteps channel.

    Args:
        df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to use.
        tof_column (str, optional): Name of the column containing the
            time-of-flight steps. Defaults to config["columns"]["tof"].
        sector_id_column (str, optional): Name of the column containing the
            sectorID. Defaults to config["columns"]["sector_id"].
        sector_id_reserved_bits (int, optional): Number of bits reserved for the
            sectorID. Defaults to config["sector_id_reserved_bits"].
        config (dict, optional): Dataframe configuration dictionary. Defaults to None.

    Returns:
        tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: Dataframe with the
            new columns and a metadata dictionary.
    """
    if tof_column is None:
        if config is None:
            raise ValueError("Either tof_column or config must be given.")
        tof_column = config["columns"]["tof"]
    if sector_id_column is None:
        if config is None:
            raise ValueError("Either sector_id_column or config must be given.")
        sector_id_column = config["columns"]["sector_id"]
    if sector_id_reserved_bits is None:
        if config is None:
            raise ValueError("Either sector_id_reserved_bits or config must be given.")
        sector_id_reserved_bits = config.get("sector_id_reserved_bits", None)
        if sector_id_reserved_bits is None:
            raise ValueError('No value for "sector_id_reserved_bits" found in config.')

    if sector_id_column in df.columns:
        metadata = {"applied": False, "reason": f"Column {sector_id_column} already in dataframe"}
    else:
        # Split the time-of-flight column into sector ID and time-of-flight steps
        df = split_channel_bitwise(
            df=df,
            input_column=tof_column,
            output_columns=[sector_id_column, tof_column],
            bit_mask=sector_id_reserved_bits,
            overwrite=True,
            types=[np.int8, np.int32],
        )
        metadata = {
            "applied": True,
            "tof_column": tof_column,
            "sector_id_column": sector_id_column,
            "sector_id_reserved_bits": sector_id_reserved_bits,
        }

    return df, {"split_dld_time_from_sector_id": metadata}
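
# --- Usage sketch (illustrative, not part of the module) ----------------------
# A minimal config-driven call. The config dict below is hypothetical; with
# 3 reserved bits, an encoded value of 8001 (1000 * 8 + 1) decodes to
# sectorID 1 and 1000 time-of-flight steps.
_demo_config = {
    "columns": {"tof": "dldTimeSteps", "sector_id": "dldSectorID"},
    "sector_id_reserved_bits": 3,
}
_demo_ddf = dask.dataframe.from_pandas(
    pd.DataFrame({"dldTimeSteps": [8001, 8002, 16005]}),
    npartitions=1,
)
_demo_ddf, _demo_meta = split_dld_time_from_sector_id(_demo_ddf, config=_demo_config)
# dldSectorID -> 1, 2, 5 ; dldTimeSteps -> 1000, 1000, 2000
print(_demo_ddf.compute())
print(_demo_meta)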


def get_stats(meta: pq.FileMetaData) -> dict:
    """
    Extracts the minimum and maximum of all columns from the metadata of a Parquet file.

    Args:
        meta (pq.FileMetaData): The metadata of the Parquet file.

    Returns:
        dict: Dictionary mapping each column name to its minimum and maximum values.
    """
    min_max = {}
    for idx, name in enumerate(meta.schema.names):
        col = []
        for i in range(meta.num_row_groups):
            stats = meta.row_group(i).column(idx).statistics
            if stats is not None:
                if stats.min is not None:
                    col.append(stats.min)
                if stats.max is not None:
                    col.append(stats.max)
        if col:
            min_max[name] = {"min": min(col), "max": max(col)}
    return min_max
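
# --- Usage sketch (illustrative, not part of the module) ----------------------
# Write a tiny Parquet file and read the per-column min/max back from its
# row-group statistics. The file name is hypothetical.
pd.DataFrame({"timeStamp": [10, 30, 20], "dldPosX": [1.0, -2.5, 4.0]}).to_parquet(
    "demo_stats.parquet",
    index=False,
)
# Expected: {'timeStamp': {'min': 10, 'max': 30}, 'dldPosX': {'min': -2.5, 'max': 4.0}}
print(get_stats(pq.read_metadata("demo_stats.parquet")))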


def get_parquet_metadata(file_paths: list[Path]) -> dict[str, dict]:
    """
    Extracts and organizes metadata from a list of Parquet files.

    For each file, the function reads the metadata, adds the filename, and
    extracts the minimum and maximum values of each column. The "row_groups"
    entry is removed from the FileMetaData dictionary.

    Args:
        file_paths (list[Path]): A list of paths to the Parquet files.

    Returns:
        dict[str, dict]: A dictionary with the file index as key and the metadata
            of each file as value.
    """
    organized_metadata = {}
    for i, file_path in enumerate(file_paths):
        # Read the metadata for the file
        file_meta: pq.FileMetaData = pq.read_metadata(file_path)
        # Convert the metadata to a dictionary
        metadata_dict = file_meta.to_dict()
        # Add the filename to the metadata dictionary
        metadata_dict["filename"] = str(file_path.name)

        # Get column min and max values
        metadata_dict["columns"] = get_stats(file_meta)

        # Remove "row_groups" as they contain a lot of info that is not needed
        metadata_dict.pop("row_groups", None)

        # Add the metadata dictionary to the organized_metadata dictionary
        organized_metadata[str(i)] = metadata_dict

    return organized_metadata
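
# --- Usage sketch (illustrative, not part of the module) ----------------------
# Two small Parquet files summarised into one metadata dict, keyed by the
# positional index of each file. The file names are hypothetical.
_demo_paths = [Path("part_0.parquet"), Path("part_1.parquet")]
for _n, _path in enumerate(_demo_paths):
    pd.DataFrame({"timeStamp": [100 * _n, 100 * _n + 50]}).to_parquet(_path, index=False)

_demo_meta = get_parquet_metadata(_demo_paths)
# _demo_meta["0"]["filename"] == "part_0.parquet"
# _demo_meta["1"]["columns"]["timeStamp"] == {"min": 100, "max": 150}
print(_demo_meta["1"]["columns"])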