diff --git a/README.md b/README.md index fe569183e..2aa3cef0f 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ [![python-version](https://img.shields.io/pypi/pyversions/supervision)](https://badge.fury.io/py/supervision) [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow/supervision/blob/main/demo.ipynb) [![gradio](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Roboflow/Annotators) -[![discord](https://img.shields.io/discord/1159501506232451173)](https://discord.gg/GbfgXGJ8Bk) +[![discord](https://img.shields.io/discord/1159501506232451173?logo=discord&label=discord&labelColor=fff&color=5865f2&link=https%3A%2F%2Fdiscord.gg%2FGbfgXGJ8Bk)](https://discord.gg/GbfgXGJ8Bk) [![built-with-material-for-mkdocs](https://img.shields.io/badge/Material_for_MkDocs-526CFE?logo=MaterialForMkDocs&logoColor=white)](https://squidfunk.github.io/mkdocs-material/)
@@ -135,88 +135,88 @@ for path, image, annotation in ds: - load - ```python - dataset = sv.DetectionDataset.from_yolo( - images_directory_path=..., - annotations_directory_path=..., - data_yaml_path=... - ) - - dataset = sv.DetectionDataset.from_pascal_voc( - images_directory_path=..., - annotations_directory_path=... - ) - - dataset = sv.DetectionDataset.from_coco( - images_directory_path=..., - annotations_path=... - ) - ``` + ```python + dataset = sv.DetectionDataset.from_yolo( + images_directory_path=..., + annotations_directory_path=..., + data_yaml_path=... + ) + + dataset = sv.DetectionDataset.from_pascal_voc( + images_directory_path=..., + annotations_directory_path=... + ) + + dataset = sv.DetectionDataset.from_coco( + images_directory_path=..., + annotations_path=... + ) + ``` - split - ```python - train_dataset, test_dataset = dataset.split(split_ratio=0.7) - test_dataset, valid_dataset = test_dataset.split(split_ratio=0.5) + ```python + train_dataset, test_dataset = dataset.split(split_ratio=0.7) + test_dataset, valid_dataset = test_dataset.split(split_ratio=0.5) - len(train_dataset), len(test_dataset), len(valid_dataset) - # (700, 150, 150) - ``` + len(train_dataset), len(test_dataset), len(valid_dataset) + # (700, 150, 150) + ``` - merge - ```python - ds_1 = sv.DetectionDataset(...) - len(ds_1) - # 100 - ds_1.classes - # ['dog', 'person'] - - ds_2 = sv.DetectionDataset(...) - len(ds_2) - # 200 - ds_2.classes - # ['cat'] - - ds_merged = sv.DetectionDataset.merge([ds_1, ds_2]) - len(ds_merged) - # 300 - ds_merged.classes - # ['cat', 'dog', 'person'] - ``` + ```python + ds_1 = sv.DetectionDataset(...) + len(ds_1) + # 100 + ds_1.classes + # ['dog', 'person'] + + ds_2 = sv.DetectionDataset(...) + len(ds_2) + # 200 + ds_2.classes + # ['cat'] + + ds_merged = sv.DetectionDataset.merge([ds_1, ds_2]) + len(ds_merged) + # 300 + ds_merged.classes + # ['cat', 'dog', 'person'] + ``` - save - ```python - dataset.as_yolo( - images_directory_path=..., - annotations_directory_path=..., - data_yaml_path=... - ) - - dataset.as_pascal_voc( - images_directory_path=..., - annotations_directory_path=... - ) - - dataset.as_coco( - images_directory_path=..., - annotations_path=... - ) - ``` + ```python + dataset.as_yolo( + images_directory_path=..., + annotations_directory_path=..., + data_yaml_path=... + ) + + dataset.as_pascal_voc( + images_directory_path=..., + annotations_directory_path=... + ) + + dataset.as_coco( + images_directory_path=..., + annotations_path=... + ) + ``` - convert - ```python - sv.DetectionDataset.from_yolo( - images_directory_path=..., - annotations_directory_path=..., - data_yaml_path=... - ).as_pascal_voc( - images_directory_path=..., - annotations_directory_path=... - ) - ``` + ```python + sv.DetectionDataset.from_yolo( + images_directory_path=..., + annotations_directory_path=..., + data_yaml_path=... + ).as_pascal_voc( + images_directory_path=..., + annotations_directory_path=... + ) + ``` diff --git a/docs/how_to/track_objects.md b/docs/how_to/track_objects.md index 784cb7bf1..1b321e7fe 100644 --- a/docs/how_to/track_objects.md +++ b/docs/how_to/track_objects.md @@ -1,5 +1,6 @@ --- comments: true +status: new --- # Track Objects @@ -13,6 +14,8 @@ take you through the steps to perform inference using the YOLOv8 model via eithe you'll discover how to track these objects efficiently and annotate your video content for a deeper analysis. +## Object Detection & Segmentation + To make it easier for you to follow our tutorial download the video we will use as an example. 
You can do this using [`supervision[assets]`](/latest/assets/) extension. @@ -21,14 +24,13 @@ example. You can do this using from supervision.assets import download_assets, VideoAssets download_assets(VideoAssets.PEOPLE_WALKING) -download_assets(VideoAssets.SKIING) ``` -## Run Inference +### Run Inference First, you'll need to obtain predictions from your object detection or segmentation model. In this tutorial, we are using the YOLOv8 model as an example. However, @@ -41,6 +43,10 @@ by obtaining model predictions and then annotating the frame based on these pred This `callback` function will be essential in the subsequent steps of the tutorial, as it will be modified to include tracking, labeling, and trace annotations. +!!! tip + + Both object detection and segmentation models are supported. Try it with `yolov8n.pt` or `yolov8n-640-seg`! + === "Ultralytics" ```{ .py } @@ -89,7 +95,7 @@ it will be modified to include tracking, labeling, and trace annotations. -## Tracking +### Tracking After running inference and obtaining predictions, the next step is to track the detected objects throughout the video. Utilizing Supervision’s @@ -145,7 +151,7 @@ enabling the continuous following of the object's motion path across different f ) ``` -## Annotate Video with Tracking IDs +### Annotate Video with Tracking IDs Annotating the video with tracking IDs helps in distinguishing and following each object distinctly. With the @@ -227,7 +233,7 @@ offering a clear visual representation of each object's class and unique identif -## Annotate Video with Traces +### Annotate Video with Traces Adding traces to the video involves overlaying the historical paths of the detected objects. This feature, powered by the @@ -315,9 +321,96 @@ movement patterns and interactions between objects in the video. -## Tracking Key Points +## Keypoints + +Models aren't limited to object detection and segmentation. Keypoint detection allows for detailed analysis of body joints and connections, especially valuable for applications like human pose estimation. This section introduces keypoint tracking. We'll walk through the steps of annotating keypoints, converting them into bounding box detections compatible with `ByteTrack`, and applying detection smoothing for enhanced stability. + +To make it easier for you to follow our tutorial, let's download the video we will use as an +example. You can do this using [`supervision[assets]`](/latest/assets/) extension. + +```python +from supervision.assets import download_assets, VideoAssets + +download_assets(VideoAssets.SKIING) +``` + + + +### Keypoint Detection + +First, you'll need to obtain predictions from your keypoint detection model. In this tutorial, we are using the YOLOv8 model as an example. However, +Supervision is versatile and compatible with various models. Check this [link](/latest/keypoint/core/) for guidance on how to plug in other models. + +We will define a `callback` function, which will process each frame of the video by obtaining model predictions and then annotating the frame based on these predictions. + +Let's immediately visualize the results with our [`EdgeAnnotator`](/latest/keypoint/annotators/#supervision.keypoint.annotators.EdgeAnnotator) and [`VertexAnnotator`](https://supervision.roboflow.com/latest/keypoint/annotators/#supervision.keypoint.annotators.VertexAnnotator). 
+ +=== "Ultralytics" + + ```{ .py hl_lines="5 10-11" } + import numpy as np + import supervision as sv + from ultralytics import YOLO -Keypoint tracking is currently supported via the conversion of `KeyPoints` to `Detections`. This is achieved with the [`keypoints_to_detections`](/latest/utils/datatypes/#supervision.utils.datatypes.keypoints_to_detections) function. We'll use a different video as well as [`DetectionsSmoother`](/latest/detection/tools/smoother/) to stabilize the boxes. + model = YOLO("yolov8m-pose.pt") + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model(frame)[0] + key_points = sv.KeyPoints.from_ultralytics(results) + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + return vertex_annotator.annotate( + annotated_frame, key_points=key_points) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + +=== "Inference" + + ```{ .py hl_lines="5-6 11-12" } + import numpy as np + import supervision as sv + from inference.models.utils import get_roboflow_model + + model = get_roboflow_model( + model_id="yolov8m-pose-640", api_key=) + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model.infer(frame)[0] + key_points = sv.KeyPoints.from_inference(results) + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + return vertex_annotator.annotate( + annotated_frame, key_points=key_points) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + + + +### Convert to Detections + +Keypoint tracking is currently supported via the conversion of `KeyPoints` to `Detections`. This is achieved with the [`KeyPoints.as_detections()`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.as_detections) function. + +Let's convert to detections and visualize the results with our [`BoxAnnotator`](/latest/detection/annotators/#supervision.annotators.core.BoxAnnotator). !!! 
tip @@ -325,35 +418,187 @@ Keypoint tracking is currently supported via the conversion of `KeyPoints` to `D === "Ultralytics" - ```{ .py hl_lines="5 7 14-15 17 33" } + ```{ .py hl_lines="8 13 19-20" } import numpy as np import supervision as sv from ultralytics import YOLO model = YOLO("yolov8m-pose.pt") + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model(frame)[0] + key_points = sv.KeyPoints.from_ultralytics(results) + detections = key_points.as_detections() + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) + return box_annotator.annotate( + annotated_frame, detections=detections) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + +=== "Inference" + + ```{ .py hl_lines="9 14 20-21" } + import numpy as np + import supervision as sv + from inference.models.utils import get_roboflow_model + + model = get_roboflow_model( + model_id="yolov8m-pose-640", api_key=) + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model.infer(frame)[0] + key_points = sv.KeyPoints.from_inference(results) + detections = key_points.as_detections() + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) + return box_annotator.annotate( + annotated_frame, detections=detections) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + + + +### Keypoint Tracking + +Now that we have a `Detections` object, we can track it throughout the video. Utilizing Supervision’s [`sv.ByteTrack`](/latest/trackers/#supervision.tracker.byte_tracker.core.ByteTrack) functionality, each detected object is assigned a unique tracker ID, enabling the continuous following of the object's motion path across different frames. We shall visualize the result with `TraceAnnotator`. 
+ +=== "Ultralytics" + + ```{ .py hl_lines="10-11 17 25-26" } + import numpy as np + import supervision as sv + from ultralytics import YOLO + + model = YOLO("yolov8m-pose.pt") + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + + tracker = sv.ByteTrack() + trace_annotator = sv.TraceAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model(frame)[0] + key_points = sv.KeyPoints.from_ultralytics(results) + detections = key_points.as_detections() + detections = tracker.update_with_detections(detections) + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) + annotated_frame = box_annotator.annotate( + annotated_frame, detections=detections) + return trace_annotator.annotate( + annotated_frame, detections=detections) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + +=== "Inference" + + ```{ .py hl_lines="11-12 18 26-27" } + import numpy as np + import supervision as sv + from inference.models.utils import get_roboflow_model + + model = get_roboflow_model( + model_id="yolov8m-pose-640", api_key=) + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + + tracker = sv.ByteTrack() + trace_annotator = sv.TraceAnnotator() + + def callback(frame: np.ndarray, _: int) -> np.ndarray: + results = model.infer(frame)[0] + key_points = sv.KeyPoints.from_inference(results) + detections = key_points.as_detections() + detections = tracker.update_with_detections(detections) + + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) + annotated_frame = box_annotator.annotate( + annotated_frame, detections=detections) + return trace_annotator.annotate( + annotated_frame, detections=detections) + + sv.process_video( + source_path="skiing.mp4", + target_path="result.mp4", + callback=callback + ) + ``` + + + +### Bonus: Smoothing + +We could stop here as we have successfully tracked the object detected by the keypoint model. However, we can further enhance the stability of the boxes by applying [`DetectionsSmoother`](/latest/detection/tools/smoother/). This tool helps in stabilizing the boxes by smoothing the bounding box coordinates across frames. 
It is very simple to use: + +=== "Ultralytics" + + ```{ .py hl_lines="11 19" } + import numpy as np + import supervision as sv + from ultralytics import YOLO + + model = YOLO("yolov8m-pose.pt") + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + tracker = sv.ByteTrack() smoother = sv.DetectionsSmoother() - box_annotator = sv.BoundingBoxAnnotator() - label_annotator = sv.LabelAnnotator() trace_annotator = sv.TraceAnnotator() def callback(frame: np.ndarray, _: int) -> np.ndarray: results = model(frame)[0] - keypoints = sv.KeyPoints.from_ultralytics(results) - detections = sv.keypoints_to_detections(keypoints) + key_points = sv.KeyPoints.from_ultralytics(results) + detections = key_points.as_detections() detections = tracker.update_with_detections(detections) detections = smoother.update_with_detections(detections) - labels = [ - f"#{tracker_id} {results.names[class_id]}" - for class_id, tracker_id - in zip(detections.class_id, detections.tracker_id) - ] - + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) annotated_frame = box_annotator.annotate( - frame.copy(), detections=detections) - annotated_frame = label_annotator.annotate( - annotated_frame, detections=detections, labels=labels) + annotated_frame, detections=detections) return trace_annotator.annotate( annotated_frame, detections=detections) @@ -366,36 +611,34 @@ Keypoint tracking is currently supported via the conversion of `KeyPoints` to `D === "Inference" - ```{ .py hl_lines="5-6 8 15-16 18 34" } + ```{ .py hl_lines="12 20" } import numpy as np import supervision as sv from inference.models.utils import get_roboflow_model model = get_roboflow_model( model_id="yolov8m-pose-640", api_key=) + edge_annotator = sv.EdgeAnnotator() + vertex_annotator = sv.VertexAnnotator() + box_annotator = sv.BoxAnnotator() + tracker = sv.ByteTrack() smoother = sv.DetectionsSmoother() - box_annotator = sv.BoundingBoxAnnotator() - label_annotator = sv.LabelAnnotator() trace_annotator = sv.TraceAnnotator() def callback(frame: np.ndarray, _: int) -> np.ndarray: results = model.infer(frame)[0] - keypoints = sv.KeyPoints.from_inference(results) - detections = sv.keypoints_to_detections(keypoints) + key_points = sv.KeyPoints.from_inference(results) + detections = key_points.as_detections() detections = tracker.update_with_detections(detections) detections = smoother.update_with_detections(detections) - labels = [ - f"#{tracker_id} {results.names[class_id]}" - for class_id, tracker_id - in zip(detections.class_id, detections.tracker_id) - ] - + annotated_frame = edge_annotator.annotate( + frame.copy(), key_points=key_points) + annotated_frame = vertex_annotator.annotate( + annotated_frame, key_points=key_points) annotated_frame = box_annotator.annotate( - frame.copy(), detections=detections) - annotated_frame = label_annotator.annotate( - annotated_frame, detections=detections, labels=labels) + annotated_frame, detections=detections) return trace_annotator.annotate( annotated_frame, detections=detections) @@ -407,9 +650,7 @@ Keypoint tracking is currently supported via the conversion of `KeyPoints` to `D ``` -This structured walkthrough should give a detailed pathway to annotate videos -effectively using Supervision’s various functionalities, including object tracking and -trace annotations. 
+This structured walkthrough should give a detailed pathway to annotate videos effectively using Supervision’s various functionalities, including object tracking and trace annotations. diff --git a/docs/keypoint/core.md b/docs/keypoint/core.md index 6f42c254d..7354babab 100644 --- a/docs/keypoint/core.md +++ b/docs/keypoint/core.md @@ -1,5 +1,6 @@ --- comments: true +status: new --- # Keypoint Detection diff --git a/docs/utils/datatypes.md b/docs/utils/datatypes.md deleted file mode 100644 index 5e0560bd3..000000000 --- a/docs/utils/datatypes.md +++ /dev/null @@ -1,12 +0,0 @@ ---- -comments: true -status: new ---- - -# Data Types Utils - - - -:::supervision.utils.datatypes.keypoints_to_detections diff --git a/mkdocs.yml b/mkdocs.yml index 1ed9fafa3..6d013a730 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -80,7 +80,6 @@ nav: - File: utils/file.md - Draw: utils/draw.md - Geometry: utils/geometry.md - - Datatypes: utils/datatypes.md - Assets: assets.md - Cookbooks: cookbooks.md - Cheatsheet: https://roboflow.github.io/cheatsheet-supervision/ diff --git a/supervision/__init__.py b/supervision/__init__.py index 04813ef22..746b2f67d 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -100,7 +100,6 @@ from supervision.metrics.detection import ConfusionMatrix, MeanAveragePrecision from supervision.tracker.byte_tracker.core import ByteTrack from supervision.utils.conversion import cv2_to_pillow, pillow_to_cv2 -from supervision.utils.datatypes import keypoints_to_detections from supervision.utils.file import list_files_with_extensions from supervision.utils.image import ( ImageSink, diff --git a/supervision/detection/line_zone.py b/supervision/detection/line_zone.py index 422bc9c5c..b63660bc4 100644 --- a/supervision/detection/line_zone.py +++ b/supervision/detection/line_zone.py @@ -42,6 +42,10 @@ class LineZone: to inside. out_count (int): The number of objects that have crossed the line from inside to outside. + in_count_per_class (Dict[int, int]): Number of objects of each class that have + crossed the line from outside to inside. + out_count_per_class (Dict[int, int]): Number of objects of each class that have + crossed the line from inside to outside. Example: ```python @@ -75,7 +79,7 @@ def __init__( Position.BOTTOM_LEFT, Position.BOTTOM_RIGHT, ), - crossing_acceptance_threshold: int = 1, + minimum_crossing_threshold: int = 1, ): """ Args: @@ -86,7 +90,7 @@ def __init__( to consider when deciding on whether the detection has passed the line counter or not. By default, this contains the four corners of the detection's bounding box - crossing_acceptance_threshold (int): Detection needs to be seen + minimum_crossing_threshold (int): Detection needs to be seen on the other side of the line for this many frames to be considered as having crossed the line. This is useful when dealing with unstable bounding boxes or when detections @@ -94,7 +98,7 @@ def __init__( """ self.vector = Vector(start=start, end=end) self.limits = self._calculate_region_of_interest_limits(vector=self.vector) - self.crossing_history_length = max(2, crossing_acceptance_threshold + 1) + self.crossing_history_length = max(2, minimum_crossing_threshold + 1) self.crossing_state_history: Dict[int, Deque[bool]] = defaultdict( lambda: deque(maxlen=self.crossing_history_length) ) @@ -107,34 +111,18 @@ def __init__( @property def in_count(self) -> int: - """ - Number of objects that have crossed the line from - outside to inside. 
- """ return sum(self._in_count_per_class.values()) @property def out_count(self) -> int: - """ - Number of objects that have crossed the line from - inside to outside. - """ return sum(self._out_count_per_class.values()) @property def in_count_per_class(self) -> Dict[int, int]: - """ - Number of objects of each class that have crossed - the line from outside to inside. - """ return dict(self._in_count_per_class) @property def out_count_per_class(self) -> Dict[int, int]: - """ - Number of objects of each class that have crossed the line - from inside to outside. - """ return dict(self._out_count_per_class) def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 3491126c3..f6bcd33bc 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -919,11 +919,27 @@ def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]: merged_metadata: Dict[str, Any] = {} for metadata in metadata_list: for key, value in metadata.items(): - if key in merged_metadata: + if key not in merged_metadata: + merged_metadata[key] = value + continue + + other_value = merged_metadata[key] + if isinstance(value, np.ndarray) and isinstance(other_value, np.ndarray): + if not np.array_equal(merged_metadata[key], value): + raise ValueError( + f"Conflicting metadata for key: '{key}': " + "{type(value)}, {type(other_value)}." + ) + elif isinstance(value, np.ndarray) or isinstance(other_value, np.ndarray): + # Since [] == np.array([]). + raise ValueError( + f"Conflicting metadata for key: '{key}': " + "{type(value)}, {type(other_value)}." + ) + else: + print("hm") if merged_metadata[key] != value: raise ValueError(f"Conflicting metadata for key: '{key}'.") - else: - merged_metadata[key] = value return merged_metadata diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py index 252fb63f3..4b8e9d55b 100644 --- a/supervision/keypoint/core.py +++ b/supervision/keypoint/core.py @@ -2,12 +2,13 @@ from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt from supervision.config import CLASS_NAME_DATA_FIELD +from supervision.detection.core import Detections from supervision.detection.utils import get_data_item, is_data_equal from supervision.validators import validate_keypoints_fields @@ -620,3 +621,67 @@ def is_empty(self) -> bool: empty_keypoints = KeyPoints.empty() empty_keypoints.data = self.data return self == empty_keypoints + + def as_detections( + self, selected_keypoint_indices: Optional[Iterable[int]] = None + ) -> Detections: + """ + Convert a KeyPoints object to a Detections object. This + approximates the bounding box of the detected object by + taking the bounding box that fits all keypoints. + + Arguments: + selected_keypoint_indices (Optional[Iterable[int]]): The + indices of the keypoints to include in the bounding box + calculation. This helps focus on a subset of keypoints, + e.g. when some are occluded. Captures all keypoints by default. + + Returns: + detections (Detections): The converted detections object. + + Example: + ```python + keypoints = sv.KeyPoints.from_inference(...) 
+ detections = keypoints.as_detections() + ``` + """ + if self.is_empty(): + return Detections.empty() + + detections_list = [] + for i, xy in enumerate(self.xy): + if selected_keypoint_indices: + xy = xy[selected_keypoint_indices] + + # [0, 0] used by some frameworks to indicate missing keypoints + xy = xy[~np.all(xy == 0, axis=1)] + if len(xy) == 0: + xyxy = np.array([[0, 0, 0, 0]], dtype=np.float32) + else: + x_min = xy[:, 0].min() + x_max = xy[:, 0].max() + y_min = xy[:, 1].min() + y_max = xy[:, 1].max() + xyxy = np.array([[x_min, y_min, x_max, y_max]], dtype=np.float32) + + if self.confidence is None: + confidence = None + else: + confidence = self.confidence[i] + if selected_keypoint_indices: + confidence = confidence[selected_keypoint_indices] + confidence = np.array([confidence.mean()], dtype=np.float32) + + detections_list.append( + Detections( + xyxy=xyxy, + confidence=confidence, + ) + ) + + detections = Detections.merge(detections_list) + detections.class_id = self.class_id + detections.data = self.data + detections = detections[detections.area > 0] + + return detections diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py index cc8c87a2c..98cb5f265 100644 --- a/supervision/metrics/f1_score.py +++ b/supervision/metrics/f1_score.py @@ -47,10 +47,30 @@ class F1Score(Metric): f1_metric = F1Score() f1_result = f1_metric.update(predictions, targets).compute() - print(f1_result) print(f1_result.f1_50) - print(f1_result.small_objects.f1_50) + # 0.7618 + + print(f1_result) + # F1ScoreResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # F1 @ 50: 0.7618 + # F1 @ 75: 0.7487 + # F1 @ thresh: [0.76175 0.76068 0.76068] + # IoU thresh: [0.5 0.55 0.6 ...] + # F1 per class: + # 0: [0.70968 0.70968 0.70968 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... + + f1_result.plot() ``` + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/f1_plot_example.png\ + ){ align=center width="800" } """ def __init__( @@ -456,11 +476,11 @@ class F1ScoreResult: matched_classes (np.ndarray): the class IDs of all matched classes. Corresponds to the rows of `f1_per_class`. small_objects (Optional[F1ScoreResult]): the F1 metric results - for small objects. + for small objects (area < 32²). medium_objects (Optional[F1ScoreResult]): the F1 metric results - for medium objects. + for medium objects (32² ≤ area < 96²). large_objects (Optional[F1ScoreResult]): the F1 metric results - for large objects. + for large objects (area ≥ 96²). """ metric_target: MetricTarget @@ -490,6 +510,19 @@ def __str__(self) -> str: Example: ```python print(f1_result) + # F1ScoreResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # F1 @ 50: 0.7618 + # F1 @ 75: 0.7487 + # F1 @ thresh: [0.76175 0.76068 0.76068] + # IoU thresh: [0.5 0.55 0.6 ...] + # F1 per class: + # 0: [0.70968 0.70968 0.70968 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... ``` """ out_str = ( @@ -553,6 +586,10 @@ def to_pandas(self) -> "pd.DataFrame": def plot(self): """ Plot the F1 results. 
+ + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/f1_plot_example.png\ + ){ align=center width="800" } """ labels = ["F1@50", "F1@75"] diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index ba37837b3..9e7a30d0e 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -42,10 +42,31 @@ class MeanAveragePrecision(Metric): map_metric = MeanAveragePrecision() map_result = map_metric.update(predictions, targets).compute() - print(map_result) print(map_result.map50_95) + # 0.4674 + + print(map_result) + # MeanAveragePrecisionResult: + # Metric target: MetricTarget.BOXES + # Class agnostic: False + # mAP @ 50:95: 0.4674 + # mAP @ 50: 0.5048 + # mAP @ 75: 0.4796 + # mAP scores: [0.50485 0.50377 0.50377 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # AP per class: + # 0: [0.67699 0.67699 0.67699 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... + map_result.plot() ``` + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/mAP_plot_example.png\ + ){ align=center width="800" } """ def __init__( @@ -419,11 +440,11 @@ class and IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)` matched_classes (np.ndarray): the class IDs of all matched classes. Corresponds to the rows of `ap_per_class`. small_objects (Optional[MeanAveragePrecisionResult]): the mAP results - for small objects. + for small objects (area < 32²). medium_objects (Optional[MeanAveragePrecisionResult]): the mAP results - for medium objects. + for medium objects (32² ≤ area < 96²). large_objects (Optional[MeanAveragePrecisionResult]): the mAP results - for large objects. + for large objects (area ≥ 96²). """ metric_target: MetricTarget @@ -456,6 +477,20 @@ def __str__(self) -> str: Example: ```python print(map_result) + # MeanAveragePrecisionResult: + # Metric target: MetricTarget.BOXES + # Class agnostic: False + # mAP @ 50:95: 0.4674 + # mAP @ 50: 0.5048 + # mAP @ 75: 0.4796 + # mAP scores: [0.50485 0.50377 0.50377 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # AP per class: + # 0: [0.67699 0.67699 0.67699 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... ``` """ @@ -527,6 +562,10 @@ def to_pandas(self) -> "pd.DataFrame": def plot(self): """ Plot the mAP results. + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/mAP_plot_example.png\ + ){ align=center width="800" } """ labels = ["mAP@50:95", "mAP@50", "mAP@75"] diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py index fa6cf2b1a..a5d4011e8 100644 --- a/supervision/metrics/precision.py +++ b/supervision/metrics/precision.py @@ -50,10 +50,30 @@ class Precision(Metric): precision_metric = Precision() precision_result = precision_metric.update(predictions, targets).compute() - print(precision_result) print(precision_result.precision_at_50) + # 0.8099 + + print(precision_result) + # PrecisionResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # P @ 50: 0.8099 + # P @ 75: 0.7969 + # P @ thresh: [0.80992 0.80905 0.80905 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # Precision per class: + # 0: [0.64706 0.64706 0.64706 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... 
+ print(precision_result.small_objects.precision_at_50) ``` + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/precision_plot_example.png\ + ){ align=center width="800" } """ def __init__( @@ -459,11 +479,11 @@ class PrecisionResult: matched_classes (np.ndarray): the class IDs of all matched classes. Corresponds to the rows of `precision_per_class`. small_objects (Optional[PrecisionResult]): the Precision metric results - for small objects. + for small objects (area < 32²). medium_objects (Optional[PrecisionResult]): the Precision metric results - for medium objects. + for medium objects (32² ≤ area < 96²). large_objects (Optional[PrecisionResult]): the Precision metric results - for large objects. + for large objects (area ≥ 96²). """ metric_target: MetricTarget @@ -493,6 +513,19 @@ def __str__(self) -> str: Example: ```python print(precision_result) + # PrecisionResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # P @ 50: 0.8099 + # P @ 75: 0.7969 + # P @ thresh: [0.80992 0.80905 0.80905 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # Precision per class: + # 0: [0.64706 0.64706 0.64706 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... ``` """ out_str = ( @@ -558,6 +591,10 @@ def to_pandas(self) -> "pd.DataFrame": def plot(self): """ Plot the precision results. + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/precision_plot_example.png\ + ){ align=center width="800" } """ labels = ["Precision@50", "Precision@75"] diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py index 21bba1a6e..b3586ff7d 100644 --- a/supervision/metrics/recall.py +++ b/supervision/metrics/recall.py @@ -50,10 +50,31 @@ class Recall(Metric): recall_metric = Recall() recall_result = recall_metric.update(predictions, targets).compute() - print(recall_result) print(recall_result.recall_at_50) - print(recall_result.small_objects.recall_at_50) + # 0.7615 + + print(recall_result) + # RecallResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # R @ 50: 0.7615 + # R @ 75: 0.7462 + # R @ thresh: [0.76151 0.76011 0.76011 0.75732 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # Recall per class: + # 0: [0.78571 0.78571 0.78571 ...] + # ... + # Small objects: ... + # Medium objects: ... + # Large objects: ... + + recall_result.plot() + ``` + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/recall_plot_example.png\ + ){ align=center width="800" } """ def __init__( @@ -457,11 +478,11 @@ class RecallResult: matched_classes (np.ndarray): the class IDs of all matched classes. Corresponds to the rows of `recall_per_class`. small_objects (Optional[RecallResult]): the Recall metric results - for small objects. + for small objects (area < 32²). medium_objects (Optional[RecallResult]): the Recall metric results - for medium objects. + for medium objects (32² ≤ area < 96²). large_objects (Optional[RecallResult]): the Recall metric results - for large objects. + for large objects (area ≥ 96²). """ metric_target: MetricTarget @@ -491,6 +512,19 @@ def __str__(self) -> str: Example: ```python print(recall_result) + # RecallResult: + # Metric target: MetricTarget.BOXES + # Averaging method: AveragingMethod.WEIGHTED + # R @ 50: 0.7615 + # R @ 75: 0.7462 + # R @ thresh: [0.76151 0.76011 0.76011 0.75732 ...] + # IoU thresh: [0.5 0.55 0.6 ...] + # Recall per class: + # 0: [0.78571 0.78571 0.78571 ...] + # ... + # Small objects: ... 
+ # Medium objects: ... + # Large objects: ... ``` """ out_str = ( @@ -556,6 +590,10 @@ def to_pandas(self) -> "pd.DataFrame": def plot(self): """ Plot the recall results. + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/recall_plot_example.png\ + ){ align=center width="800" } """ labels = ["Recall@50", "Recall@75"] diff --git a/supervision/utils/datatypes.py b/supervision/utils/datatypes.py deleted file mode 100644 index 6f2e8879c..000000000 --- a/supervision/utils/datatypes.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Iterable, Optional - -import numpy as np - -from supervision.detection.core import Detections -from supervision.keypoint.core import KeyPoints - - -def keypoints_to_detections( - keypoints: KeyPoints, selected_keypoint_indices: Optional[Iterable[int]] = None -) -> Detections: - """ - Convert a KeyPoints object to a Detections object. This - approximates the bounding box of the detected object by - taking the bounding box that fits all keypoints. - - Arguments: - keypoints (KeyPoints): The keypoints to convert to detections. - selected_keypoint_indices (Optional[Iterable[int]]): The - indices of the keypoints to include in the bounding box - calculation. This helps focus on a subset of keypoints, - e.g. when some are occluded. Captures all keypoints by default. - - Returns: - detections (Detections): The converted detections object. - - Example: - ```python - keypoints = sv.KeyPoints.from_inference(...) - detections = keypoints_to_detections(keypoints) - ``` - """ - if keypoints.is_empty(): - return Detections.empty() - - detections_list = [] - for i, xy in enumerate(keypoints.xy): - if selected_keypoint_indices: - xy = xy[selected_keypoint_indices] - - # [0, 0] used by some frameworks to indicate missing keypoints - xy = xy[~np.all(xy == 0, axis=1)] - if len(xy) == 0: - xyxy = np.array([[0, 0, 0, 0]], dtype=np.float32) - else: - x_min = xy[:, 0].min() - x_max = xy[:, 0].max() - y_min = xy[:, 1].min() - y_max = xy[:, 1].max() - xyxy = np.array([[x_min, y_min, x_max, y_max]], dtype=np.float32) - - if keypoints.confidence is None: - confidence = None - else: - confidence = keypoints.confidence[i] - if selected_keypoint_indices: - confidence = confidence[selected_keypoint_indices] - confidence = np.array([confidence.mean()], dtype=np.float32) - - detections_list.append( - Detections( - xyxy=xyxy, - confidence=confidence, - ) - ) - - detections = Detections.merge(detections_list) - detections.class_id = keypoints.class_id - detections.data = keypoints.data - detections = detections[detections.area > 0] - - return detections diff --git a/test/detection/test_core.py b/test/detection/test_core.py index b857250e0..61796bef2 100644 --- a/test/detection/test_core.py +++ b/test/detection/test_core.py @@ -106,6 +106,26 @@ "never_seen_key": [9], }, ) +TEST_DET_WITH_METADATA = Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, +) + +TEST_DET_WITH_METADATA_2 = Detections( + xyxy=np.array([[30, 30, 40, 40]]), + class_id=np.array([2]), + metadata={"source": "camera1"}, +) +TEST_DET_NO_METADATA = Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), +) +TEST_DET_DIFFERENT_METADATA = Detections( + xyxy=np.array([[50, 50, 60, 60]]), + class_id=np.array([3]), + metadata={"source": "camera2"}, +) @pytest.mark.parametrize( @@ -258,6 +278,11 @@ def test_getitem( TEST_DET_1, DoesNotRaise(), ), # Single detection and empty-array fields + ( + [TEST_DET_ZERO_LENGTH, TEST_DET_ZERO_LENGTH], 
+ TEST_DET_ZERO_LENGTH, + DoesNotRaise(), + ), # Zero-length fields across all Detections ( [ TEST_DET_1, @@ -287,12 +312,190 @@ def test_getitem( Detections.empty(), ], mock_detections( - xyxy=[[10, 10, 20, 20]], + xyxy=np.array([[10, 10, 20, 20]]), class_id=[1], mask=[np.zeros((4, 4), dtype=bool)], ), DoesNotRaise(), ), # Segmentation + Empty + # Metadata + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + Detections.empty(), + ], + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + DoesNotRaise(), + ), # Metadata merge with empty detections + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + Detections(xyxy=np.array([[30, 30, 40, 40]]), class_id=np.array([2])), + ], + None, + pytest.raises(ValueError), + ), # Empty and non-empty metadata + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ) + ], + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + DoesNotRaise(), + ), # Single detection with metadata + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), + class_id=np.array([2]), + metadata={"source": "camera1"}, + ), + ], + Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + class_id=np.array([1, 2]), + metadata={"source": "camera1"}, + ), + DoesNotRaise(), + ), # Multiple metadata entries with identical values + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + class_id=np.array([1]), + metadata={"source": "camera1"}, + ), + Detections( + xyxy=np.array([[50, 50, 60, 60]]), + class_id=np.array([3]), + metadata={"source": "camera2"}, + ), + ], + None, + pytest.raises(ValueError), + ), # Different metadata values + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + metadata={"source": "camera1", "resolution": "1080p"}, + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), + metadata={"source": "camera1", "resolution": "1080p"}, + ), + ], + Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + metadata={"source": "camera1", "resolution": "1080p"}, + ), + DoesNotRaise(), + ), # Large metadata with multiple identical entries + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), metadata={"source": "camera1"} + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), metadata={"source": ["camera1"]} + ), + ], + None, + pytest.raises(ValueError), + ), # Inconsistent types in metadata values + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), metadata={"source": "camera1"} + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), metadata={"location": "indoor"} + ), + ], + None, + pytest.raises(ValueError), + ), # Metadata key mismatch + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + metadata={ + "source": "camera1", + "settings": {"resolution": "1080p", "fps": 30}, + }, + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), + metadata={ + "source": "camera1", + "settings": {"resolution": "1080p", "fps": 30}, + }, + ), + ], + Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + metadata={ + "source": "camera1", + "settings": {"resolution": "1080p", "fps": 30}, + }, + ), + DoesNotRaise(), + ), # multi-field metadata + ( + [ + Detections( + 
xyxy=np.array([[10, 10, 20, 20]]), + metadata={"calibration_matrix": np.array([[1, 0], [0, 1]])}, + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), + metadata={"calibration_matrix": np.array([[1, 0], [0, 1]])}, + ), + ], + Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + metadata={"calibration_matrix": np.array([[1, 0], [0, 1]])}, + ), + DoesNotRaise(), + ), # Identical 2D numpy arrays in metadata + ( + [ + Detections( + xyxy=np.array([[10, 10, 20, 20]]), + metadata={"calibration_matrix": np.array([[1, 0], [0, 1]])}, + ), + Detections( + xyxy=np.array([[30, 30, 40, 40]]), + metadata={"calibration_matrix": np.array([[2, 0], [0, 2]])}, + ), + ], + None, + pytest.raises(ValueError), + ), # Mismatching 2D numpy arrays in metadata ], ) def test_merge( @@ -302,7 +505,7 @@ def test_merge( ) -> None: with exception: result = Detections.merge(detections_list=detections_list) - assert result == expected_result + assert result == expected_result, f"Expected: {expected_result}, Got: {result}" @pytest.mark.parametrize( diff --git a/test/detection/test_line_counter.py b/test/detection/test_line_counter.py index d0ec5fbd2..a140add55 100644 --- a/test/detection/test_line_counter.py +++ b/test/detection/test_line_counter.py @@ -493,7 +493,7 @@ def test_line_zone_multiple_detections( @pytest.mark.parametrize( - "vector, xyxy_sequence, triggering_anchors, crossing_acceptance_threshold, " + "vector, xyxy_sequence, triggering_anchors, minimum_crossing_threshold, " "expected_crossed_in, expected_crossed_out", [ ( # Detection lingers around line, all crosses counted @@ -578,7 +578,7 @@ def test_line_zone_one_detection_long_horizon( vector: Vector, xyxy_sequence: List[List[float]], triggering_anchors: List[Position], - crossing_acceptance_threshold: int, + minimum_crossing_threshold: int, expected_crossed_in: List[bool], expected_crossed_out: List[bool], ) -> None: @@ -586,7 +586,7 @@ def test_line_zone_one_detection_long_horizon( start=vector.start, end=vector.end, triggering_anchors=triggering_anchors, - crossing_acceptance_threshold=crossing_acceptance_threshold, + minimum_crossing_threshold=minimum_crossing_threshold, ) crossed_in_list = [] @@ -609,7 +609,7 @@ def test_line_zone_one_detection_long_horizon( @pytest.mark.parametrize( - "vector, xyxy_sequence, anchors, crossing_acceptance_threshold, " + "vector, xyxy_sequence, anchors, minimum_crossing_threshold, " "expected_crossed_in, expected_crossed_out, expected_count_in, " "expected_count_out, exception", [ @@ -743,7 +743,7 @@ def test_line_zone_long_horizon_disappearing_detections( vector: Vector, xyxy_sequence: List[List[Optional[List[float]]]], anchors: List[Position], - crossing_acceptance_threshold: int, + minimum_crossing_threshold: int, expected_crossed_in: List[List[bool]], expected_crossed_out: List[List[bool]], expected_count_in: List[int], @@ -755,7 +755,7 @@ def test_line_zone_long_horizon_disappearing_detections( start=vector.start, end=vector.end, triggering_anchors=anchors, - crossing_acceptance_threshold=crossing_acceptance_threshold, + minimum_crossing_threshold=minimum_crossing_threshold, ) crossed_in_list = [] crossed_out_list = [] diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py index 77c4cea54..87e50f6a4 100644 --- a/test/detection/test_utils.py +++ b/test/detection/test_utils.py @@ -14,6 +14,7 @@ filter_polygons_by_area, get_data_item, merge_data, + merge_metadata, move_boxes, process_roboflow_result, scale_boxes, @@ -1138,3 +1139,163 @@ def test_xywh_to_xyxy(xywh: np.ndarray, 
expected_result: np.ndarray) -> None: def test_xcycwh_to_xyxy(xcycwh: np.ndarray, expected_result: np.ndarray) -> None: result = xcycwh_to_xyxy(xcycwh) np.testing.assert_array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "metadata_list, expected_result, exception", + [ + # Identical metadata with a single key + ([{"key1": "value1"}, {"key1": "value1"}], {"key1": "value1"}, DoesNotRaise()), + # Identical metadata with multiple keys + ( + [ + {"key1": "value1", "key2": "value2"}, + {"key1": "value1", "key2": "value2"}, + ], + {"key1": "value1", "key2": "value2"}, + DoesNotRaise(), + ), + # Conflicting values for the same key + ([{"key1": "value1"}, {"key1": "value2"}], None, pytest.raises(ValueError)), + # Different sets of keys across dictionaries + ([{"key1": "value1"}, {"key2": "value2"}], None, pytest.raises(ValueError)), + # Empty metadata list + ([], {}, DoesNotRaise()), + # Empty metadata dictionaries + ([{}, {}], {}, DoesNotRaise()), + # Different declaration order for keys + ( + [ + {"key1": "value1", "key2": "value2"}, + {"key2": "value2", "key1": "value1"}, + ], + {"key1": "value1", "key2": "value2"}, + DoesNotRaise(), + ), + # Nested metadata dictionaries + ( + [{"key1": {"sub_key": "sub_value"}}, {"key1": {"sub_key": "sub_value"}}], + {"key1": {"sub_key": "sub_value"}}, + DoesNotRaise(), + ), + # Large metadata dictionaries with many keys + ( + [ + {f"key{i}": f"value{i}" for i in range(100)}, + {f"key{i}": f"value{i}" for i in range(100)}, + ], + {f"key{i}": f"value{i}" for i in range(100)}, + DoesNotRaise(), + ), + # Mixed types in list metadata values + ( + [{"key1": ["value1", 2, True]}, {"key1": ["value1", 2, True]}], + {"key1": ["value1", 2, True]}, + DoesNotRaise(), + ), + # Identical lists across metadata dictionaries + ( + [{"key1": [1, 2, 3]}, {"key1": [1, 2, 3]}], + {"key1": [1, 2, 3]}, + DoesNotRaise(), + ), + # Identical numpy arrays across metadata dictionaries + ( + [{"key1": np.array([1, 2, 3])}, {"key1": np.array([1, 2, 3])}], + {"key1": np.array([1, 2, 3])}, + DoesNotRaise(), + ), + # Identical numpy arrays across metadata dictionaries, different datatype + ( + [ + {"key1": np.array([1, 2, 3], dtype=np.int32)}, + {"key1": np.array([1, 2, 3], dtype=np.int64)}, + ], + {"key1": np.array([1, 2, 3])}, + DoesNotRaise(), + ), + # Conflicting lists for the same key + ([{"key1": [1, 2, 3]}, {"key1": [4, 5, 6]}], None, pytest.raises(ValueError)), + # Conflicting numpy arrays for the same key + ( + [{"key1": np.array([1, 2, 3])}, {"key1": np.array([4, 5, 6])}], + None, + pytest.raises(ValueError), + ), + # Mixed data types: list and numpy array for the same key + ( + [{"key1": [1, 2, 3]}, {"key1": np.array([1, 2, 3])}], + None, + pytest.raises(ValueError), + ), + # Empty lists and numpy arrays for the same key + ([{"key1": []}, {"key1": np.array([])}], None, pytest.raises(ValueError)), + # Identical multi-dimensional lists across metadata dictionaries + ( + [{"key1": [[1, 2], [3, 4]]}, {"key1": [[1, 2], [3, 4]]}], + {"key1": [[1, 2], [3, 4]]}, + DoesNotRaise(), + ), + # Identical multi-dimensional numpy arrays across metadata dictionaries + ( + [ + {"key1": np.arange(4).reshape(2, 2)}, + {"key1": np.arange(4).reshape(2, 2)}, + ], + {"key1": np.arange(4).reshape(2, 2)}, + DoesNotRaise(), + ), + # Conflicting multi-dimensional lists for the same key + ( + [{"key1": [[1, 2], [3, 4]]}, {"key1": [[5, 6], [7, 8]]}], + None, + pytest.raises(ValueError), + ), + # Conflicting multi-dimensional numpy arrays for the same key + ( + [ + {"key1": 
np.arange(4).reshape(2, 2)}, + {"key1": np.arange(4, 8).reshape(2, 2)}, + ], + None, + pytest.raises(ValueError), + ), + # Mixed types with multi-dimensional list and array for the same key + ( + [{"key1": [[1, 2], [3, 4]]}, {"key1": np.arange(4).reshape(2, 2)}], + None, + pytest.raises(ValueError), + ), + # Identical higher-dimensional (3D) numpy arrays across + # metadata dictionaries + ( + [ + {"key1": np.arange(8).reshape(2, 2, 2)}, + {"key1": np.arange(8).reshape(2, 2, 2)}, + ], + {"key1": np.arange(8).reshape(2, 2, 2)}, + DoesNotRaise(), + ), + # Differently-shaped higher-dimensional (3D) numpy arrays + # across metadata dictionaries + ( + [ + {"key1": np.arange(8).reshape(2, 2, 2)}, + {"key1": np.arange(8).reshape(4, 1, 2)}, + ], + None, + pytest.raises(ValueError), + ), + ], +) +def test_merge_metadata(metadata_list, expected_result, exception): + with exception: + result = merge_metadata(metadata_list) + if expected_result is None: + assert result is None, f"Expected an error, but got a result {result}" + for key, value in result.items(): + assert key in expected_result + if isinstance(value, np.ndarray): + np.testing.assert_array_equal(value, expected_result[key]) + else: + assert value == expected_result[key]
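
A few hedged usage sketches may help when reviewing the API changes above; they are not part of the diff. First, the `LineZone` argument renamed to `minimum_crossing_threshold`, together with the newly documented `in_count_per_class` / `out_count_per_class` attributes. The model weights and video path below are placeholders; everything else relies on existing `supervision` APIs (`ByteTrack`, `get_video_frames_generator`, `Detections.from_ultralytics`).

```python
import supervision as sv
from ultralytics import YOLO

# Placeholders: any detection weights and any video with objects crossing a line.
model = YOLO("yolov8n.pt")
tracker = sv.ByteTrack()

# Require a detection to stay on the far side of the line for 3 consecutive
# frames before a crossing is counted, suppressing jittery-box false counts.
line_zone = sv.LineZone(
    start=sv.Point(0, 360),
    end=sv.Point(1280, 360),
    minimum_crossing_threshold=3,
)

for frame in sv.get_video_frames_generator("people-walking.mp4"):
    detections = sv.Detections.from_ultralytics(model(frame)[0])
    detections = tracker.update_with_detections(detections)
    line_zone.trigger(detections)

print(line_zone.in_count, line_zone.out_count)
print(line_zone.in_count_per_class)   # counts keyed by class_id, e.g. {0: 7}
print(line_zone.out_count_per_class)
```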
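Second, `KeyPoints.as_detections()` documents a `selected_keypoint_indices` argument that none of the tutorial snippets exercise. A minimal sketch, assuming COCO-style keypoint ordering from a YOLOv8 pose model; the image path and index list are illustrative only.

```python
import cv2
import supervision as sv
from ultralytics import YOLO

model = YOLO("yolov8m-pose.pt")
frame = cv2.imread("frame.png")  # placeholder image containing people

key_points = sv.KeyPoints.from_ultralytics(model(frame)[0])

# Default: the box spans every non-missing keypoint.
detections = key_points.as_detections()

# Torso-only boxes from shoulders and hips (COCO indices 5, 6, 11, 12),
# useful when arms and legs are frequently occluded.
torso_detections = key_points.as_detections(
    selected_keypoint_indices=[5, 6, 11, 12]
)
```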
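Finally, the reworked `merge_metadata` compares `np.ndarray` values with `np.array_equal` and treats any array/non-array mix as a conflict. A small sketch mirroring the new tests:

```python
import numpy as np
from supervision.detection.utils import merge_metadata

calibration = np.array([[1, 0], [0, 1]])

# Identical arrays (and identical scalar values) merge cleanly.
merged = merge_metadata([
    {"source": "camera1", "calibration_matrix": calibration},
    {"source": "camera1", "calibration_matrix": calibration.copy()},
])
print(merged["source"])  # camera1

# Conflicting arrays, or a list on one side and an array on the other,
# raise ValueError.
try:
    merge_metadata([
        {"calibration_matrix": calibration},
        {"calibration_matrix": np.array([[2, 0], [0, 2]])},
    ])
except ValueError as error:
    print(error)
```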