Skip to content

Commit

Permalink
Merge pull request #238 from roboflow/feature/yolonas-upload
Browse files Browse the repository at this point in the history
Feature/yolonas upload
  • Loading branch information
probicheaux authored Mar 8, 2024
2 parents 33ae3b0 + 4d2d512 commit 3e85436
Showing 1 changed file with 76 additions and 3 deletions.
79 changes: 76 additions & 3 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import json
import os
import shutil
import sys
import time
import zipfile
Expand Down Expand Up @@ -425,11 +426,15 @@ def deploy(self, model_type: str, model_path: str) -> None:
model_path (str): File path to model weights to be uploaded
"""

supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9"]
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]

if not any(supported_model in model_type for supported_model in supported_models):
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))

if "yolonas" in model_type:
self.deploy_yolonas(model_type, model_path)
return

if "yolov8" in model_type:
try:
import torch
Expand Down Expand Up @@ -516,15 +521,15 @@ def deploy(self, model_type: str, model_path: str) -> None:

torch.save(model["model"].state_dict(), os.path.join(model_path, "state_dict.pt"))

lista_files = [
list_files = [
"results.csv",
"results.png",
"model_artifacts.json",
"state_dict.pt",
]

with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe:
for file in lista_files:
for file in list_files:
if os.path.exists(os.path.join(model_path, file)):
zipMe.write(
os.path.join(model_path, file),
Expand All @@ -535,6 +540,74 @@ def deploy(self, model_type: str, model_path: str) -> None:
if file in ["model_artifacts.json", "state_dict.pt"]:
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))

self.upload_zip(model_type, model_path)

def deploy_yolonas(self, model_type: str, model_path: str) -> None:
try:
import torch
except ImportError:
raise (
"The torch python package is required to deploy yolonas models."
" Please install it with `pip install torch`"
)

model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
class_names = model["processing_params"]["class_names"]

opt_path = os.path.join(model_path, "opt.yaml")
if not os.path.exists(opt_path):
raise RuntimeError(
f"You must create an opt.yaml file at {os.path.join(model_path, '')} of the format:\n"
f"imgsz: <resolution of model>\n"
f"batch_size: <batch size of inference model>\n"
f"architecture: <one of [yolo_nas_s, yolo_nas_m, yolo_nas_l]."
f"s, m, l refer to small, medium, large architecture sizes, respectively>\n"
)
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
opts = yaml.safe_load(stream)
required_keys = ["imgsz", "batch_size", "architecture"]
for key in required_keys:
if key not in opts:
raise RuntimeError(f"{opt_path} lacks required key {key}. Required keys: {required_keys}")

model_artifacts = {
"names": class_names,
"nc": len(class_names),
"args": {
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
"batch": opts["batch_size"],
"architecture": opts["architecture"],
},
"model_type": model_type,
}

with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
json.dump(model_artifacts, fp)

shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))

list_files = [
"results.json",
"results.png",
"model_artifacts.json",
"state_dict.pt",
]

with zipfile.ZipFile(os.path.join(model_path, "roboflow_deploy.zip"), "w") as zipMe:
for file in list_files:
if os.path.exists(os.path.join(model_path, file)):
zipMe.write(
os.path.join(model_path, file),
arcname=file,
compress_type=zipfile.ZIP_DEFLATED,
)
else:
if file in ["model_artifacts.json", "best.pt"]:
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))

self.upload_zip(model_type, model_path)

def upload_zip(self, model_type: str, model_path: str):
res = requests.get(
f"{API_URL}/{self.workspace}/{self.project}/{self.version}"
f"/uploadModel?api_key={self.__api_key}&modelType={model_type}&nocache=true"
Expand Down

0 comments on commit 3e85436

Please sign in to comment.