Skip to content

Commit

Permalink
Merge pull request #572 from roboflow/add-timm-classification-loader
Browse files Browse the repository at this point in the history
Add timm data loader for `sv.Classifications`
  • Loading branch information
SkalskiP authored Nov 27, 2023
2 parents 2dcf2cc + 1d66840 commit 1589d8f
Showing 1 changed file with 47 additions and 4 deletions.
51 changes: 47 additions & 4 deletions supervision/classification/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def from_ultralytics(cls, ultralytics_results) -> Classifications:
Args:
ultralytics_results (ultralytics.engine.results.Results):
The output Results instance from ultralytics model
The inference result from ultralytics model.
Returns:
Classifications: A new Classifications object.
Expand All @@ -60,15 +60,58 @@ def from_ultralytics(cls, ultralytics_results) -> Classifications:
>>> image = cv2.imread(SOURCE_IMAGE_PATH)
>>> model = YOLO('yolov8n-cls.pt')
>>> model = YOLO('yolov8s-cls.pt')
>>> result = model(image)[0]
>>> classifications = sv.Classifications.from_ultralytics(result)
>>> output = model(image)[0]
>>> classifications = sv.Classifications.from_ultralytics(output)
```
"""
confidence = ultralytics_results.probs.data.cpu().numpy()
return cls(class_id=np.arange(confidence.shape[0]), confidence=confidence)

@classmethod
def from_timm(cls, timm_results) -> Classifications:
"""
Creates a Classifications instance from a
timm (https://huggingface.co/docs/hub/timm) inference result.
Args:
timm_results: The inference result from timm model.
Returns:
Classifications: A new Classifications object.
Example:
```python
>>> from PIL import Image
>>> import timm
>>> from timm.data import resolve_data_config, create_transform
>>> import supervision as sv
>>> model = timm.create_model(
... model_name='hf-hub:nateraw/resnet50-oxford-iiit-pet',
... pretrained=True
... ).eval()
>>> config = resolve_data_config({}, model=model)
>>> transform = create_transform(**config)
>>> image = Image.open(SOURCE_IMAGE_PATH).convert('RGB')
>>> x = transform(image).unsqueeze(0)
>>> output = model(x)
>>> classifications = sv.Classifications.from_timm(output)
```
"""
confidence = timm_results.cpu().detach().numpy()[0]

if len(confidence) == 0:
return cls(class_id=np.array([]), confidence=np.array([]))

class_id = np.arange(len(confidence))

return cls(class_id=class_id, confidence=confidence)

def get_top_k(self, k: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Retrieve the top k class IDs and confidences,
Expand Down

0 comments on commit 1589d8f

Please sign in to comment.