Skip to content

Commit

Permalink
Unit tests for merge_metadata, fixes involving numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
LinasKo committed Nov 8, 2024
1 parent e4cf743 commit b72ad46
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 3 deletions.
22 changes: 19 additions & 3 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,11 +919,27 @@ def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
merged_metadata: Dict[str, Any] = {}
for metadata in metadata_list:
for key, value in metadata.items():
if key in merged_metadata:
if key not in merged_metadata:
merged_metadata[key] = value
continue

other_value = merged_metadata[key]
if isinstance(value, np.ndarray) and isinstance(other_value, np.ndarray):
if not np.array_equal(merged_metadata[key], value):
raise ValueError(
f"Conflicting metadata for key: '{key}': "
"{type(value)}, {type(other_value)}."
)
elif isinstance(value, np.ndarray) or isinstance(other_value, np.ndarray):
# Since [] == np.array([]).
raise ValueError(
f"Conflicting metadata for key: '{key}': "
"{type(value)}, {type(other_value)}."
)
else:
print("hm")
if merged_metadata[key] != value:
raise ValueError(f"Conflicting metadata for key: '{key}'.")
else:
merged_metadata[key] = value

return merged_metadata

Expand Down
161 changes: 161 additions & 0 deletions test/detection/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
filter_polygons_by_area,
get_data_item,
merge_data,
merge_metadata,
move_boxes,
process_roboflow_result,
scale_boxes,
Expand Down Expand Up @@ -1138,3 +1139,163 @@ def test_xywh_to_xyxy(xywh: np.ndarray, expected_result: np.ndarray) -> None:
def test_xcycwh_to_xyxy(xcycwh: np.ndarray, expected_result: np.ndarray) -> None:
result = xcycwh_to_xyxy(xcycwh)
np.testing.assert_array_equal(result, expected_result)


@pytest.mark.parametrize(
"metadata_list, expected_result, exception",
[
# Identical metadata with a single key
([{"key1": "value1"}, {"key1": "value1"}], {"key1": "value1"}, DoesNotRaise()),
# Identical metadata with multiple keys
(
[
{"key1": "value1", "key2": "value2"},
{"key1": "value1", "key2": "value2"},
],
{"key1": "value1", "key2": "value2"},
DoesNotRaise(),
),
# Conflicting values for the same key
([{"key1": "value1"}, {"key1": "value2"}], None, pytest.raises(ValueError)),
# Different sets of keys across dictionaries
([{"key1": "value1"}, {"key2": "value2"}], None, pytest.raises(ValueError)),
# Empty metadata list
([], {}, DoesNotRaise()),
# Empty metadata dictionaries
([{}, {}], {}, DoesNotRaise()),
# Different declaration order for keys
(
[
{"key1": "value1", "key2": "value2"},
{"key2": "value2", "key1": "value1"},
],
{"key1": "value1", "key2": "value2"},
DoesNotRaise(),
),
# Nested metadata dictionaries
(
[{"key1": {"sub_key": "sub_value"}}, {"key1": {"sub_key": "sub_value"}}],
{"key1": {"sub_key": "sub_value"}},
DoesNotRaise(),
),
# Large metadata dictionaries with many keys
(
[
{f"key{i}": f"value{i}" for i in range(100)},
{f"key{i}": f"value{i}" for i in range(100)},
],
{f"key{i}": f"value{i}" for i in range(100)},
DoesNotRaise(),
),
# Mixed types in list metadata values
(
[{"key1": ["value1", 2, True]}, {"key1": ["value1", 2, True]}],
{"key1": ["value1", 2, True]},
DoesNotRaise(),
),
# Identical lists across metadata dictionaries
(
[{"key1": [1, 2, 3]}, {"key1": [1, 2, 3]}],
{"key1": [1, 2, 3]},
DoesNotRaise(),
),
# Identical numpy arrays across metadata dictionaries
(
[{"key1": np.array([1, 2, 3])}, {"key1": np.array([1, 2, 3])}],
{"key1": np.array([1, 2, 3])},
DoesNotRaise(),
),
# Identical numpy arrays across metadata dictionaries, different datatype
(
[
{"key1": np.array([1, 2, 3], dtype=np.int32)},
{"key1": np.array([1, 2, 3], dtype=np.int64)},
],
{"key1": np.array([1, 2, 3])},
DoesNotRaise(),
),
# Conflicting lists for the same key
([{"key1": [1, 2, 3]}, {"key1": [4, 5, 6]}], None, pytest.raises(ValueError)),
# Conflicting numpy arrays for the same key
(
[{"key1": np.array([1, 2, 3])}, {"key1": np.array([4, 5, 6])}],
None,
pytest.raises(ValueError),
),
# Mixed data types: list and numpy array for the same key
(
[{"key1": [1, 2, 3]}, {"key1": np.array([1, 2, 3])}],
None,
pytest.raises(ValueError),
),
# Empty lists and numpy arrays for the same key
([{"key1": []}, {"key1": np.array([])}], None, pytest.raises(ValueError)),
# Identical multi-dimensional lists across metadata dictionaries
(
[{"key1": [[1, 2], [3, 4]]}, {"key1": [[1, 2], [3, 4]]}],
{"key1": [[1, 2], [3, 4]]},
DoesNotRaise(),
),
# Identical multi-dimensional numpy arrays across metadata dictionaries
(
[
{"key1": np.arange(4).reshape(2, 2)},
{"key1": np.arange(4).reshape(2, 2)},
],
{"key1": np.arange(4).reshape(2, 2)},
DoesNotRaise(),
),
# Conflicting multi-dimensional lists for the same key
(
[{"key1": [[1, 2], [3, 4]]}, {"key1": [[5, 6], [7, 8]]}],
None,
pytest.raises(ValueError),
),
# Conflicting multi-dimensional numpy arrays for the same key
(
[
{"key1": np.arange(4).reshape(2, 2)},
{"key1": np.arange(4, 8).reshape(2, 2)},
],
None,
pytest.raises(ValueError),
),
# Mixed types with multi-dimensional list and array for the same key
(
[{"key1": [[1, 2], [3, 4]]}, {"key1": np.arange(4).reshape(2, 2)}],
None,
pytest.raises(ValueError),
),
# Identical higher-dimensional (3D) numpy arrays across
# metadata dictionaries
(
[
{"key1": np.arange(8).reshape(2, 2, 2)},
{"key1": np.arange(8).reshape(2, 2, 2)},
],
{"key1": np.arange(8).reshape(2, 2, 2)},
DoesNotRaise(),
),
# Differently-shaped higher-dimensional (3D) numpy arrays
# across metadata dictionaries
(
[
{"key1": np.arange(8).reshape(2, 2, 2)},
{"key1": np.arange(8).reshape(4, 1, 2)},
],
None,
pytest.raises(ValueError),
),
],
)
def test_merge_metadata(metadata_list, expected_result, exception):
with exception:
result = merge_metadata(metadata_list)
if expected_result is None:
assert result is None, f"Expected an error, but got a result {result}"
for key, value in result.items():
assert key in expected_result
if isinstance(value, np.ndarray):
np.testing.assert_array_equal(value, expected_result[key])
else:
assert value == expected_result[key]

0 comments on commit b72ad46

Please sign in to comment.