• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lunarlab-gatech / robotdataprocess / 20172884498

12 Dec 2025 04:18PM UTC coverage: 74.672% (+1.2%) from 73.457%
20172884498

push

github

web-flow
(v0.1.2) Path Metrics & Transformations, Python >=3.8 Support

269 of 364 new or added lines in 8 files covered. (73.9%)

6 existing lines in 4 files now uncovered.

1082 of 1449 relevant lines covered (74.67%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.36
/src/robotdataprocess/data_types/PathData.py
1
from __future__ import annotations
1✔
2

3
from ..conversion_utils import col_to_dec_arr, dec_arr_to_float_arr
1✔
4
import copy
1✔
5
from .Data import Data, CoordinateFrame
1✔
6
from decimal import Decimal
1✔
7
from evo.core import sync, metrics
1✔
8
from evo.core.trajectory import PoseTrajectory3D
1✔
9
import matplotlib.pyplot as plt
1✔
10
import numpy as np
1✔
11
from numpy.typing import NDArray
1✔
12
from pathlib import Path
1✔
13
from ..rosbag.Ros2BagWrapper import Ros2BagWrapper
1✔
14
from rosbags.rosbag2 import Reader as Reader2
1✔
15
from rosbags.typesys.store import Typestore
1✔
16
from scipy.spatial.transform import Rotation as R
1✔
17
from typeguard import typechecked
1✔
18
from typing import Union, Tuple, List
1✔
19
import tqdm
1✔
20

21
class PathData(Data):
    """
    A timestamped sequence of poses (position + orientation quaternion).

    Positions and orientations are stored via ``col_to_dec_arr`` (Decimal arrays)
    and converted to float only at the boundaries (plotting, evo).
    """

    positions: NDArray[Decimal] # meters (x, y, z)
    orientations: NDArray[Decimal] # quaternions (x, y, z, w)
    frame: CoordinateFrame

    @typechecked
    def __init__(self, frame_id: str, timestamps: Union[np.ndarray, list], 
                 positions: Union[np.ndarray, list], orientations: Union[np.ndarray, list], 
                 frame: CoordinateFrame):
        """
        Args:
            frame_id (str): Frame ID of the path; forwarded to ``Data``.
            timestamps (Union[np.ndarray, list]): One timestamp per pose; forwarded to ``Data``.
            positions (Union[np.ndarray, list]): Per-pose (x, y, z) positions in meters.
            orientations (Union[np.ndarray, list]): Per-pose (x, y, z, w) quaternions.
            frame (CoordinateFrame): Coordinate frame of the data.
        """
        super().__init__(frame_id, timestamps)
        self.positions = col_to_dec_arr(positions)
        self.orientations = col_to_dec_arr(orientations)
        self.frame = frame

    # =========================================================================
    # ============================ Class Methods ============================== 
    # =========================================================================  

    @classmethod
    @typechecked
    def from_ros2_bag(cls, bag_path: Union[Path, str], odom_topic: str, frame: CoordinateFrame) -> PathData:
        """
        Creates a class structure from a ROS2 bag file with a Path topic.

        Every pose of every Path message on ``odom_topic`` is collected in the
        order it is first seen; poses whose timestamp was already collected
        (e.g. repeated in a later Path message) are skipped.

        Args:
            bag_path (Union[Path, str]): Path to the ROS2 bag file.
            odom_topic (str): Topic of the Path messages.
            frame (CoordinateFrame): Coordinate frame of the data.
        Returns:
            PathData: Instance of this class, using the header frame_id of the
                first Path message on the topic.
        Raises:
            ValueError: If the topic contains no messages.
        """

        # Get topic message count and typestore
        bag_wrapper = Ros2BagWrapper(bag_path, None)
        typestore: Typestore = bag_wrapper.get_typestore()
        num_msgs: int = bag_wrapper.get_topic_count(odom_topic)

        # Collect into Python lists and convert to arrays once at the end;
        # concatenating per pose and scanning the array for duplicates is O(n^2).
        timestamps = []
        positions = []
        orientations = []
        seen_timestamps = set()

        # NOTE: Currently, this method doesn't track when each Path message 
        # is received, and throws away duplicate poses contained in multiple
        # Path messages.
        frame_id = None
        with Reader2(str(bag_path)) as reader:
            with tqdm.tqdm(total=num_msgs, desc="Extracting Path...", unit=" msgs") as pbar:
                connections = [x for x in reader.connections if x.topic == odom_topic]
                for conn, _, rawdata in reader.messages(connections=connections):
                    msg = typestore.deserialize_cdr(rawdata, conn.msgtype)

                    # The header frame_id of the first message is used for the whole path
                    if frame_id is None:
                        frame_id = msg.header.frame_id

                    # Iterate through each pose in the message
                    for pose in msg.poses:

                        # Skip poses we already have (identified via timestamp).
                        # NOTE(review): assumes extract_timestamp returns a hashable
                        # value (e.g. Decimal) — confirm in Ros2BagWrapper.
                        ts = bag_wrapper.extract_timestamp(pose)
                        if ts in seen_timestamps:
                            continue
                        seen_timestamps.add(ts)

                        timestamps.append(ts)
                        pos = pose.pose.position
                        positions.append([Decimal(pos.x), Decimal(pos.y), Decimal(pos.z)])
                        ori = pose.pose.orientation
                        orientations.append([Decimal(ori.x), Decimal(ori.y), Decimal(ori.z), Decimal(ori.w)])

                    # One tick per Path message, matching the bar's total of num_msgs
                    pbar.update(1)

        if frame_id is None:
            raise ValueError(f"No messages found on topic '{odom_topic}' in bag '{bag_path}'!")

        # reshape keeps the (N, 3) / (N, 4) shapes even when no poses were found
        timestamps_np = np.array(timestamps, dtype=Decimal)
        positions_np = np.array(positions, dtype=Decimal).reshape(-1, 3)
        orientations_np = np.array(orientations, dtype=Decimal).reshape(-1, 4)

        # Create a PathData object
        return cls(frame_id, timestamps_np, positions_np, orientations_np, frame)
    
    @classmethod
    def from_evo(cls, pose_trajectory_3d: PoseTrajectory3D, frame_id: str, frame: CoordinateFrame) -> PathData:
        """
        Creates a PathData object from an evo PoseTrajectory3D object.

        Args:
            pose_trajectory_3d (PoseTrajectory3D): Source trajectory.
            frame_id (str): Frame ID to assign to the new PathData.
            frame (CoordinateFrame): Coordinate frame to assign to the new PathData.
        """

        # evo stores quaternions as wxyz; this class stores xyzw
        orientations_xyzw = pose_trajectory_3d.orientations_quat_wxyz[:, [1, 2, 3, 0]]

        return cls(frame_id=frame_id, 
                   timestamps=pose_trajectory_3d.timestamps, 
                   positions=pose_trajectory_3d.positions_xyz, 
                   orientations=orientations_xyzw,
                   frame=frame)
    
    # =========================================================================
    # ============================ Visualization ============================== 
    # =========================================================================  

    @typechecked
    def visualize(self, otherList: list[PathData], titles: list[str], axes_length: Union[float, list[float]] = 10.0, axes_interval: Union[int, list[int]] = 1000):
        """
        Visualizes this PathData (and all others included in otherList)
        on a single 3D plot, with orientation axes drawn along each path.

        Args:
            otherList (list[PathData]): All other PathData objects whose path should also be visualized on this plot.
            titles (list[str]): Titles for each PathData object, starting with self.
            axes_length (Union[float, list[float]]): Length of the drawn orientation axes,
                either shared by all paths or one value per path (starting with self).
            axes_interval (Union[int, list[int]]): Draw orientation axes at every
                axes_interval-th pose, either shared by all paths or one value per
                path (starting with self).

        Raises:
            ValueError: If titles (or a list-valued axes_length / axes_interval) does
                not have exactly one entry per path, or if any axes_interval is < 1.
        """

        print("Warning! This code has not been unit tested yet!")

        def draw_axes(ax, data: PathData, axes_length: float, axes_interval: int):
            """Helper function that visualizes orientation along the trajectory path with axes."""
            for i in range(0, data.len(), axes_interval):
                pos = data.positions[i].astype(np.float64)
                # scipy's from_quat is scalar-last (x, y, z, w), matching self.orientations
                rot = R.from_quat(data.orientations[i].astype(np.float64))

                # Plot the local X (red), Y (green), Z (blue) unit vectors rotated by the pose
                for unit_vec, color in (([1, 0, 0], 'r'), ([0, 1, 0], 'g'), ([0, 0, 1], 'b')):
                    ax.quiver(*pos, *rot.apply(unit_vec), length=axes_length, color=color, normalize=True, linewidth=0.8)

        all_paths = [self] + list(otherList)
        num_paths = len(all_paths)

        # Validate every argument before creating a figure, so a bad call
        # doesn't leave an empty figure behind.
        if num_paths != len(titles):
            raise ValueError("Length of titles should be one more than length of otherlist!")

        if isinstance(axes_length, list):
            if len(axes_length) != num_paths:
                raise ValueError("If axes_length is a list, it must be the same length as otherList + 1!")
            axes_lengths = axes_length
        else:
            axes_lengths = [axes_length] * num_paths

        if isinstance(axes_interval, list):
            if len(axes_interval) != num_paths:
                raise ValueError("If axes_interval is a list, it must be the same length as otherList + 1!")
            axes_intervals = axes_interval
        else:
            axes_intervals = [axes_interval] * num_paths

        # range() rejects a step of 0 and silently yields nothing for negative steps
        if any(interval < 1 for interval in axes_intervals):
            raise ValueError("axes_interval values must be at least 1!")

        # Build a 3D plot with one line per path
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        for path, title in zip(all_paths, titles):
            ax.plot(path.positions[:, 0].astype(np.float64), 
                    path.positions[:, 1].astype(np.float64), 
                    path.positions[:, 2].astype(np.float64), label=title)

        # Draw orientation axes (X = red, Y = green, Z = blue)
        for path, length, interval in zip(all_paths, axes_lengths, axes_intervals):
            draw_axes(ax, path, length, interval)

        # Set labels
        ax.set_title("Trajectory Comparison with Full Orientation")
        ax.set_xlabel("X (m)")
        ax.set_ylabel("Y (m)")
        ax.set_zlabel("Z (m)")
        ax.legend()

        # Set an equal scale for all axes, centered on the combined extent of all paths
        all_xyz = np.concatenate([path.positions for path in all_paths], axis=0).astype(np.float64)
        mins = all_xyz.min(axis=0)
        maxs = all_xyz.max(axis=0)
        centers = (maxs + mins) / 2
        max_range = (maxs - mins).max() / 2
        ax.set_xlim(centers[0] - max_range, centers[0] + max_range)
        ax.set_ylim(centers[1] - max_range, centers[1] + max_range)
        ax.set_zlim(centers[2] - max_range, centers[2] + max_range)

        # Show the plot
        plt.tight_layout()
        plt.show()
    
    # =========================================================================
    # ============================ Export Methods ============================= 
    # =========================================================================  

    def to_OdometryData(self, new_frame_id: str, new_child_frame_id: str):
        """ 
        Returns an OdometryData object for this class. 

        Parameters:
            new_frame_id: The new frame ID to assign to the OdometryData object. 
            new_child_frame_id: The new child frame ID to assign to the OdometryData object.
        """

        # Imported locally, presumably to avoid a circular import between data types
        from .OdometryData import OdometryData
        return OdometryData(frame_id=new_frame_id,
                            child_frame_id=new_child_frame_id,
                            timestamps=self.timestamps,
                            positions=self.positions,
                            orientations=self.orientations,
                            frame=self.frame)

    def to_evo(self) -> PoseTrajectory3D:
        """
        Returns an evo PoseTrajectory3D object for this class.

        Values are converted with ``dec_arr_to_float_arr``, and orientations are
        reordered from this class's xyzw to evo's wxyz.
        """

        orientations_wxyz = dec_arr_to_float_arr(self.orientations[:, [3, 0, 1, 2]])
        return PoseTrajectory3D(positions_xyz=dec_arr_to_float_arr(self.positions), 
                                orientations_quat_wxyz=orientations_wxyz,
                                timestamps=dec_arr_to_float_arr(self.timestamps))
    
    # =========================================================================
    # ======================= Multi PathData Methods ========================== 
    # ========================================================================= 

    @staticmethod
    def make_start_and_end_times_match(est: list[PathData], gt: list[PathData]) -> tuple[list[PathData], list[PathData]]:
        """ 
        For pairs of lists of PathData objects, extract each pair by index and 
        ensure that the first and last timestamps match by extending the data
        as necessary at the start and end with duplicate values. Used for evaluation
        purposes.
        
        Mimics the behavior found in ROMAN's (https://github.com/lunarlab-gatech/roman) evaluation scripts.

        Parameters:
            est: List of PathData objects that represent estimated paths.
            gt: List of PathData objects that represent ground truth paths.

        Returns:
            (est, gt): Deep copies of the inputs, where each pair shares the same
                first and last timestamp. The inputs are not modified.

        Raises:
            ValueError: If either list is empty or the lists differ in length.
        """

        # Copy the PathData objects so we don't modify the originals
        est = copy.deepcopy(est)
        gt = copy.deepcopy(gt)

        # Check that the lists are the same length
        if len(est) == 0 or len(gt) == 0 or len(est) != len(gt):
            raise ValueError("est and gt lists must be non-empty and of the same length!")

        # For each pair of PathData objects
        for est_i, gt_i in zip(est, gt):

            # Adjust start times: the later-starting path gets its first pose
            # duplicated at the earlier start time.
            if est_i.timestamps[0] < gt_i.timestamps[0]:
                gt_i.timestamps = np.concatenate(([est_i.timestamps[0]], gt_i.timestamps))
                gt_i.positions = np.concatenate(([gt_i.positions[0]], gt_i.positions))
                gt_i.orientations = np.concatenate(([gt_i.orientations[0]], gt_i.orientations))
            elif est_i.timestamps[0] > gt_i.timestamps[0]:
                est_i.timestamps = np.concatenate(([gt_i.timestamps[0]], est_i.timestamps))
                est_i.positions = np.concatenate(([est_i.positions[0]], est_i.positions))
                est_i.orientations = np.concatenate(([est_i.orientations[0]], est_i.orientations))

            # Adjust end times: the earlier-ending path gets its last pose
            # duplicated at the later end time.
            if est_i.timestamps[-1] < gt_i.timestamps[-1]:
                est_i.timestamps = np.concatenate((est_i.timestamps, [gt_i.timestamps[-1]]))
                est_i.positions = np.concatenate((est_i.positions, [est_i.positions[-1]]))
                est_i.orientations = np.concatenate((est_i.orientations, [est_i.orientations[-1]]))
            elif est_i.timestamps[-1] > gt_i.timestamps[-1]:
                gt_i.timestamps = np.concatenate((gt_i.timestamps, [est_i.timestamps[-1]]))
                gt_i.positions = np.concatenate((gt_i.positions, [gt_i.positions[-1]]))
                gt_i.orientations = np.concatenate((gt_i.orientations, [gt_i.orientations[-1]]))
        
        # Return the modified lists
        return est, gt

    @staticmethod
    def concatenate_PathData(path_data_objs: list[PathData]) -> PathData:
        """ 
        Combines multiple PathData objects into a single PathData object. In doing so,
        will shift the timestamps of each subsequent PathData so that their data starts
        one second after the previous PathData ends. Also assumes the frame_id and frame
        of the first PathData object for final PathData object.

        PathData objects with no poses contribute nothing and do not break the shifting
        of the objects that follow them.
        
        Mimics the behavior found in ROMAN's (https://github.com/lunarlab-gatech/roman) evaluation scripts.

        Raises:
            ValueError: If fewer than two PathData objects are given.
        """

        print("Warning! This code has not been unit tested yet!")

        if len(path_data_objs) == 0:
            raise ValueError("path_data_objs list is empty!")
        if len(path_data_objs) == 1:
            raise ValueError("path_data_objs list has only one element; no need to concatenate!")

        # NOTE: Assumes the frame_id and frame of the first object
        frame_id = path_data_objs[0].frame_id
        frame = path_data_objs[0].frame

        # Create all empty arrays to hold concatenated data
        all_timestamps = np.zeros((0,), dtype=Decimal)
        all_positions = np.zeros((0, 3), dtype=Decimal)
        all_orientations = np.zeros((0, 4), dtype=Decimal)

        for path_data in path_data_objs:
            timestamps = path_data.timestamps

            # Shift so this path starts one second after the data gathered so far.
            # Skipped when there is nothing to shift, or nothing to shift against
            # (the first path, or only empty paths so far).
            if len(all_timestamps) > 0 and len(timestamps) > 0:
                timestamps = timestamps - timestamps[0] + all_timestamps[-1] + 1

            # Concatenate data
            all_timestamps = np.concatenate((all_timestamps, timestamps), axis=0)
            all_positions = np.concatenate((all_positions, path_data.positions), axis=0)
            all_orientations = np.concatenate((all_orientations, path_data.orientations), axis=0)

        return PathData(frame_id, all_timestamps, all_positions, all_orientations, frame)

    @staticmethod
    def calculate_trajectory_errors(gt_path: PathData, est_path: PathData, max_diff: float, visualize: bool = False, 
                                    axes_length: Union[float, list[float]] = 10.0, axes_interval: Union[int, list[int]] = 1000) -> dict:
        """
        Utilizing the evo library, calculates a variety of trajectory error metrics
        and returns them in a dictionary.

        The trajectories are first associated by timestamp, then est_path is aligned
        to gt_path without scale correction before any metric is computed.

        Parameters:
            gt_path: Ground truth path.
            est_path: Estimated path.
            max_diff: maximum absolute time difference allowed between associated timestamps
            visualize: If true, will show a 3D plot of the aligned trajectories.
            axes_length: Same as in visualize() method.
            axes_interval: Same as in visualize() method.

        Returns:
            Nested dict ``{metric: {pose_relation: {statistic: value}}}``, keyed by the
            evo class name ("APE" / "RPE"), ``PoseRelation.name`` and ``StatisticsType.name``.
            APE has no ``point_distance_error_ratio`` entry.
        """

        gt_traj: PoseTrajectory3D = gt_path.to_evo()
        est_traj: PoseTrajectory3D = est_path.to_evo()

        gt_traj, est_traj = sync.associate_trajectories(gt_traj, est_traj, max_diff)

        est_traj_align: PoseTrajectory3D = copy.deepcopy(est_traj)
        est_traj_align.align(gt_traj, correct_scale=False, correct_only_scale=False) 

        path_pair: tuple[PoseTrajectory3D, PoseTrajectory3D] = (gt_traj, est_traj_align)

        # Calculate various error metrics using evo, including APE and RPE
        all_pose_relations: list[metrics.PoseRelation] = [metrics.PoseRelation.full_transformation, # dimensionless
                                                          metrics.PoseRelation.translation_part, # meters
                                                          metrics.PoseRelation.rotation_part, # dimensionless
                                                          metrics.PoseRelation.rotation_angle_deg, # degrees
                                                          metrics.PoseRelation.rotation_angle_rad, # radians
                                                          metrics.PoseRelation.point_distance, # meters
                                                          metrics.PoseRelation.point_distance_error_ratio] # percent
        all_statistic_types: list[metrics.StatisticsType] = [metrics.StatisticsType.rmse,
                                                             metrics.StatisticsType.mean,
                                                             metrics.StatisticsType.median,
                                                             metrics.StatisticsType.std,
                                                             metrics.StatisticsType.min,
                                                             metrics.StatisticsType.max,
                                                             metrics.StatisticsType.sse]
        # These are metric classes (instantiated per pose relation below), not instances
        all_metrics: list[type] = [metrics.APE, metrics.RPE]
        dict_all_results: dict = {}
        for metric in all_metrics:
            dict_metric: dict = {}

            for pose_relation in all_pose_relations:
                dict_relation: dict = {}

                # Skip incompatible relation with metric
                if metric is metrics.APE and pose_relation == metrics.PoseRelation.point_distance_error_ratio:
                    continue

                # Defensive copy so each metric computation starts from the same data
                path_pair_copied = copy.deepcopy(path_pair)
                metric_with_relation: metrics.PE = metric(pose_relation)
                metric_with_relation.process_data(path_pair_copied)

                for stat in all_statistic_types:
                    final_stat: float = metric_with_relation.get_statistic(stat)
                    dict_relation[stat.name] = final_stat

                dict_metric[pose_relation.name] = dict_relation
            
            dict_all_results[metric.__name__] = dict_metric

        # Visualize the aligned trajectories if desired
        if visualize:
            
            # Convert est_traj_align back to PathData
            est_traj_align_pathdata: PathData = PathData.from_evo(est_traj_align, gt_path.frame_id, gt_path.frame)

            # Visualize the aligned trajectories with specified axes length and interval
            gt_path.visualize([est_traj_align_pathdata], ['Ground Truth', 'Estimated (Aligned)'], 
                              axes_interval=axes_interval, axes_length=axes_length)

        return dict_all_results
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc