Skip to content

Commit

Permalink
Remoe Detections.empty metadata field
Browse files Browse the repository at this point in the history
  • Loading branch information
LinasKo committed Nov 4, 2024
1 parent 847f132 commit 5cce8a1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
16 changes: 6 additions & 10 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def from_ncnn(cls, ncnn_results) -> Detections:
)

@classmethod
def empty(cls, metadata: Optional[Dict[str, Any]] = None) -> Detections:
def empty(cls) -> Detections:
"""
Create an empty Detections object with no bounding boxes,
confidences, or class IDs.
Expand All @@ -980,14 +980,10 @@ def empty(cls, metadata: Optional[Dict[str, Any]] = None) -> Detections:
empty_detections = Detections.empty()
```
"""
if metadata is not None and not isinstance(metadata, dict):
raise TypeError("Metadata must be a dictionary.")

return cls(
xyxy=np.empty((0, 4), dtype=np.float32),
confidence=np.array([], dtype=np.float32),
class_id=np.array([], dtype=int),
metadata=metadata if metadata is not None else {},
)

def is_empty(self) -> bool:
Expand All @@ -996,6 +992,7 @@ def is_empty(self) -> bool:
"""
empty_detections = Detections.empty()
empty_detections.data = self.data
empty_detections.metadata = self.metadata
return self == empty_detections

@classmethod
Expand Down Expand Up @@ -1052,16 +1049,12 @@ def merge(cls, detections_list: List[Detections]) -> Detections:
array([0.1, 0.2, 0.3])
```
"""
metadata_list = [detections.metadata for detections in detections_list]

detections_list = [
detections for detections in detections_list if not detections.is_empty()
]

metadata = merge_metadata(metadata_list)

if len(detections_list) == 0:
return Detections.empty(metadata=metadata)
return Detections.empty()

for detections in detections_list:
validate_detections_fields(
Expand Down Expand Up @@ -1093,6 +1086,9 @@ def stack_or_none(name: str):

data = merge_data([d.data for d in detections_list])

metadata_list = [detections.metadata for detections in detections_list]
metadata = merge_metadata(metadata_list)

return cls(
xyxy=xyxy,
mask=mask,
Expand Down
14 changes: 11 additions & 3 deletions supervision/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,10 @@ def is_metadata_equal(metadata_a: Dict[str, Any], metadata_b: Dict[str, Any]) ->
"""
return set(metadata_a.keys()) == set(metadata_b.keys()) and all(
np.array_equal(metadata_a[key], metadata_b[key])
if isinstance(metadata_a[key], np.ndarray)
and isinstance(metadata_b[key], np.ndarray)
if (
isinstance(metadata_a[key], np.ndarray)
and isinstance(metadata_b[key], np.ndarray)
)
else metadata_a[key] == metadata_b[key]
for key in metadata_a
)
Expand All @@ -833,6 +835,9 @@ def merge_data(
"""
Merges the data payloads of a list of Detections instances.
Warning: Assumes that empty detections were filtered-out before passing data to
this function.
Args:
data_list: The data payloads of the Detections instances. Each data payload
is a dictionary with the same keys, and the values are either lists or
Expand Down Expand Up @@ -892,6 +897,9 @@ def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
This function combines the metadata dictionaries. If a key appears in more than one
dictionary, the values must be identical for the merge to succeed.
Warning: Assumes that empty detections were filtered-out before passing metadata to
this function.
Args:
metadata_list (List[Dict[str, Any]]): A list of metadata dictionaries to merge.
Expand All @@ -909,7 +917,7 @@ def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]:
if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets):
raise ValueError("All metadata dictionaries must have the same keys to merge.")

merged_metadata = {}
merged_metadata: Dict[str, Any] = {}
for metadata in metadata_list:
for key, value in metadata.items():
if key in merged_metadata:
Expand Down

0 comments on commit 5cce8a1

Please sign in to comment.