diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py
index fc4458fa2..a7ad3a389 100644
--- a/supervision/detection/utils.py
+++ b/supervision/detection/utils.py
@@ -919,11 +919,27 @@ def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
     merged_metadata: Dict[str, Any] = {}
     for metadata in metadata_list:
         for key, value in metadata.items():
-            if key in merged_metadata:
+            if key not in merged_metadata:
+                merged_metadata[key] = value
+                continue
+
+            other_value = merged_metadata[key]
+            if isinstance(value, np.ndarray) and isinstance(other_value, np.ndarray):
+                if not np.array_equal(other_value, value):
+                    raise ValueError(
+                        f"Conflicting metadata for key: '{key}': "
+                        f"{type(value)}, {type(other_value)}."
+                    )
+            elif isinstance(value, np.ndarray) or isinstance(other_value, np.ndarray):
+                # Exactly one side is an ndarray. `==`/`!=` would compare
+                # elementwise here (e.g. [] vs np.array([])), so reject by type.
+                raise ValueError(
+                    f"Conflicting metadata for key: '{key}': "
+                    f"{type(value)}, {type(other_value)}."
+                )
+            else:
                 if merged_metadata[key] != value:
                     raise ValueError(f"Conflicting metadata for key: '{key}'.")
-            else:
-                merged_metadata[key] = value
 
     return merged_metadata
 
diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py
index 77c4cea54..87e50f6a4 100644
--- a/test/detection/test_utils.py
+++ b/test/detection/test_utils.py
@@ -14,6 +14,7 @@
     filter_polygons_by_area,
     get_data_item,
     merge_data,
+    merge_metadata,
     move_boxes,
     process_roboflow_result,
     scale_boxes,
@@ -1138,3 +1139,163 @@ def test_xywh_to_xyxy(xywh: np.ndarray, expected_result: np.ndarray) -> None:
 def test_xcycwh_to_xyxy(xcycwh: np.ndarray, expected_result: np.ndarray) -> None:
     result = xcycwh_to_xyxy(xcycwh)
     np.testing.assert_array_equal(result, expected_result)
+
+
+@pytest.mark.parametrize(
+    "metadata_list, expected_result, exception",
+    [
+        # Identical metadata with a single key
+        ([{"key1": "value1"}, {"key1": "value1"}], {"key1": "value1"}, DoesNotRaise()),
+        # Identical metadata with multiple keys
+        (
+            [
+                {"key1": "value1", "key2": "value2"},
+                {"key1": "value1", "key2": "value2"},
+            ],
+            {"key1": "value1", "key2": "value2"},
+            DoesNotRaise(),
+        ),
+        # Conflicting values for the same key
+        ([{"key1": "value1"}, {"key1": "value2"}], None, pytest.raises(ValueError)),
+        # Different sets of keys across dictionaries
+        ([{"key1": "value1"}, {"key2": "value2"}], None, pytest.raises(ValueError)),
+        # Empty metadata list
+        ([], {}, DoesNotRaise()),
+        # Empty metadata dictionaries
+        ([{}, {}], {}, DoesNotRaise()),
+        # Different declaration order for keys
+        (
+            [
+                {"key1": "value1", "key2": "value2"},
+                {"key2": "value2", "key1": "value1"},
+            ],
+            {"key1": "value1", "key2": "value2"},
+            DoesNotRaise(),
+        ),
+        # Nested metadata dictionaries
+        (
+            [{"key1": {"sub_key": "sub_value"}}, {"key1": {"sub_key": "sub_value"}}],
+            {"key1": {"sub_key": "sub_value"}},
+            DoesNotRaise(),
+        ),
+        # Large metadata dictionaries with many keys
+        (
+            [
+                {f"key{i}": f"value{i}" for i in range(100)},
+                {f"key{i}": f"value{i}" for i in range(100)},
+            ],
+            {f"key{i}": f"value{i}" for i in range(100)},
+            DoesNotRaise(),
+        ),
+        # Mixed types in list metadata values
+        (
+            [{"key1": ["value1", 2, True]}, {"key1": ["value1", 2, True]}],
+            {"key1": ["value1", 2, True]},
+            DoesNotRaise(),
+        ),
+        # Identical lists across metadata dictionaries
+        (
+            [{"key1": [1, 2, 3]}, {"key1": [1, 2, 3]}],
+            {"key1": [1, 2, 3]},
+            DoesNotRaise(),
+        ),
+        # Identical numpy arrays across metadata dictionaries
+        (
+            [{"key1": np.array([1, 2, 3])}, {"key1": np.array([1, 2, 3])}],
+            {"key1": np.array([1, 2, 3])},
+            DoesNotRaise(),
+        ),
+        # Identical numpy arrays across metadata dictionaries, different datatype
+        (
+            [
+                {"key1": np.array([1, 2, 3], dtype=np.int32)},
+                {"key1": np.array([1, 2, 3], dtype=np.int64)},
+            ],
+            {"key1": np.array([1, 2, 3])},
+            DoesNotRaise(),
+        ),
+        # Conflicting lists for the same key
+        ([{"key1": [1, 2, 3]}, {"key1": [4, 5, 6]}], None, pytest.raises(ValueError)),
+        # Conflicting numpy arrays for the same key
+        (
+            [{"key1": np.array([1, 2, 3])}, {"key1": np.array([4, 5, 6])}],
+            None,
+            pytest.raises(ValueError),
+        ),
+        # Mixed data types: list and numpy array for the same key
+        (
+            [{"key1": [1, 2, 3]}, {"key1": np.array([1, 2, 3])}],
+            None,
+            pytest.raises(ValueError),
+        ),
+        # Empty lists and numpy arrays for the same key
+        ([{"key1": []}, {"key1": np.array([])}], None, pytest.raises(ValueError)),
+        # Identical multi-dimensional lists across metadata dictionaries
+        (
+            [{"key1": [[1, 2], [3, 4]]}, {"key1": [[1, 2], [3, 4]]}],
+            {"key1": [[1, 2], [3, 4]]},
+            DoesNotRaise(),
+        ),
+        # Identical multi-dimensional numpy arrays across metadata dictionaries
+        (
+            [
+                {"key1": np.arange(4).reshape(2, 2)},
+                {"key1": np.arange(4).reshape(2, 2)},
+            ],
+            {"key1": np.arange(4).reshape(2, 2)},
+            DoesNotRaise(),
+        ),
+        # Conflicting multi-dimensional lists for the same key
+        (
+            [{"key1": [[1, 2], [3, 4]]}, {"key1": [[5, 6], [7, 8]]}],
+            None,
+            pytest.raises(ValueError),
+        ),
+        # Conflicting multi-dimensional numpy arrays for the same key
+        (
+            [
+                {"key1": np.arange(4).reshape(2, 2)},
+                {"key1": np.arange(4, 8).reshape(2, 2)},
+            ],
+            None,
+            pytest.raises(ValueError),
+        ),
+        # Mixed types with multi-dimensional list and array for the same key
+        (
+            [{"key1": [[1, 2], [3, 4]]}, {"key1": np.arange(4).reshape(2, 2)}],
+            None,
+            pytest.raises(ValueError),
+        ),
+        # Identical higher-dimensional (3D) numpy arrays across
+        # metadata dictionaries
+        (
+            [
+                {"key1": np.arange(8).reshape(2, 2, 2)},
+                {"key1": np.arange(8).reshape(2, 2, 2)},
+            ],
+            {"key1": np.arange(8).reshape(2, 2, 2)},
+            DoesNotRaise(),
+        ),
+        # Differently-shaped higher-dimensional (3D) numpy arrays
+        # across metadata dictionaries
+        (
+            [
+                {"key1": np.arange(8).reshape(2, 2, 2)},
+                {"key1": np.arange(8).reshape(4, 1, 2)},
+            ],
+            None,
+            pytest.raises(ValueError),
+        ),
+    ],
+)
+def test_merge_metadata(metadata_list, expected_result, exception):
+    with exception:
+        result = merge_metadata(metadata_list)
+        if expected_result is None:
+            assert result is None, f"Expected an error, but got a result {result}"
+        assert result.keys() == expected_result.keys()
+        for key, value in result.items():
+            if isinstance(value, np.ndarray):
+                np.testing.assert_array_equal(value, expected_result[key])
+            else:
+                assert value == expected_result[key]