Export to ONNX #38
Why ONNX? To start with, ONNX is optimized for inference and works smoothly across platforms. It can even run inference in the browser, optionally accelerated with the WebGPU API.
Can you elaborate on the problems you've run into?
Torch version: 2.2.0.dev20231208+cpu
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten._assert_async.msg']}.
Best
Is there any update on this?
I found a partial solution for this: I used the conventional ONNX export (`torch.onnx.export`) on the code that @mush42 shared, which got rid of the assertion errors. However, the export then complains about the ISTFT operation, so I modified the Vocos head to output the intermediate tensors (`mag`, `x`, `y`) instead of audio, leaving the ISTFT outside the ONNX graph.
You can find an example here, and I also shared the 24 kHz mel model on the HF Hub: https://huggingface.co/wetdog/vocos-mel-24khz-onnx
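A minimal sketch of such a head modification, assuming the stock Vocos `ISTFTHead` projects the backbone features with a `Linear` layer named `out` to `n_fft + 2` channels, splits them into log-magnitude and phase halves, and only then calls the ISTFT. `SplitISTFTHead` is a hypothetical name, not part of Vocos: it repeats that computation and stops right before the ISTFT, returning only real-valued tensors so the exporter never meets a complex op.

```python
import torch
from torch import nn


class SplitISTFTHead(nn.Module):
    """Hypothetical head that stops before the ISTFT and returns (mag, cos, sin).

    Assumes the original head's projection is an nn.Linear to n_fft + 2
    channels (first half: log-magnitude, second half: phase).
    """

    def __init__(self, out: nn.Linear, max_mag: float = 1e2):
        super().__init__()
        self.out = out
        self.max_mag = max_mag

    def forward(self, feats: torch.Tensor):
        # feats: (batch, frames, dim) -> (batch, n_fft + 2, frames)
        h = self.out(feats).transpose(1, 2)
        log_mag, phase = h.chunk(2, dim=1)
        # Exponentiate and cap the magnitude to avoid excessively large values.
        mag = torch.clamp(torch.exp(log_mag), max=self.max_mag)
        # Real and imaginary directions of the phase; no complex tensors needed.
        return mag, torch.cos(phase), torch.sin(phase)
```

In the export script further down the thread, this would be plugged in as `SplitISTFTHead(vocos.head.out)` in place of `vocos.head` inside `VocosGen.forward`, which lines up with its three output names `mag`, `x`, `y`.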
@wetdog I followed the same steps you described, but my model produces white noise instead of speech. Could you post the ONNX conversion code you used?
@sankar-mukherjee Sure, it's mostly the same as the script @mush42 shared, but with the head modification. The other thing is that you have to load the weights into the model; this step was also missing in the earlier script.

# coding: utf-8
import argparse
import logging
import os
import random
from pathlib import Path
import numpy as np
import torch
import yaml
from torch import nn
from vocos.pretrained import Vocos
DEFAULT_OPSET_VERSION = 15
_LOGGER = logging.getLogger("export_onnx")
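
class VocosGen(nn.Module):
    """Runs the Vocos backbone and head; the ISTFT stays outside the graph."""

    def __init__(self, vocos):
        super().__init__()
        self.vocos = vocos

    def forward(self, mels):
        x = self.vocos.backbone(mels)
        # With the modified head this returns (mag, x, y) instead of audio.
        spec = self.vocos.head(x)
        return spec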

def export_generator(config_path, checkpoint_path, output_dir, opset_version):
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Resolve the model class named in the training config.
    class_module, class_name = config["model"]["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    vocos_cls = getattr(module, class_name)

    components = Vocos.from_hparams(config_path)
    params = config["model"]["init_args"]
    vocos = vocos_cls(
        feature_extractor=components.feature_extractor,
        backbone=components.backbone,
        head=components.head,
        sample_rate=params["sample_rate"],
        initial_learning_rate=params["initial_learning_rate"],
        num_warmup_steps=params["num_warmup_steps"],
        mel_loss_coeff=params["mel_loss_coeff"],
        mrd_loss_coeff=params["mrd_loss_coeff"],
    )

    # Load the trained weights (this step was missing in the earlier script).
    if checkpoint_path.endswith(".bin"):
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        vocos.load_state_dict(state_dict, strict=False)
    elif checkpoint_path.endswith(".ckpt"):
        raw_model = torch.load(checkpoint_path, map_location="cpu")
        vocos.load_state_dict(raw_model["state_dict"], strict=False)
    else:
        # Fail loudly instead of silently exporting untrained weights.
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

    model = VocosGen(vocos)
    model.eval()

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    onnx_filename = "mel_spec_24khz.onnx"
    onnx_path = os.path.join(output_dir, onnx_filename)

    dummy_input = torch.rand(1, vocos.backbone.input_channels, 64)
    dynamic_axes = {
        "mels": {0: "batch_size", 2: "time"},
    }

    # Conventional ONNX export
    torch.onnx.export(
        model=model,
        args=dummy_input,
        f=onnx_path,
        input_names=["mels"],
        output_names=["mag", "x", "y"],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        export_params=True,
        do_constant_folding=True,
    )

    # Using the new dynamo export
    # export_output = torch.onnx.dynamo_export(model, dummy_input)
    # export_output.save(onnx_path)

    return onnx_path

def main():
    logging.basicConfig(level=logging.DEBUG)
    parser = argparse.ArgumentParser(
        prog="export_onnx",
        description="Export a vocos checkpoint to onnx",
    )
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--seed", type=int, default=1234, help="random seed")
    parser.add_argument("--opset", type=int, default=DEFAULT_OPSET_VERSION)
    args = parser.parse_args()

    # Seed all RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    _LOGGER.info("Exporting model to ONNX")
    _LOGGER.info(f"Config path: `{args.config}`")
    _LOGGER.info(f"Using checkpoint: `{args.checkpoint}`")
    onnx_path = export_generator(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        opset_version=args.opset,
    )
    _LOGGER.info(f"Exported ONNX model to: `{onnx_path}`")


if __name__ == "__main__":
    main()
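Because the exported graph stops before the ISTFT, the waveform has to be rebuilt outside ONNX from `mag`, `x`, `y`. A NumPy sketch, assuming the head uses a periodic Hann window with "same" padding (trim `(n_fft - hop_length) // 2` samples at each end) and the 24 kHz mel settings `n_fft=1024`, `hop_length=256`; check both against your own `config.yaml`. `istft_same` and `mag_phase_to_audio` are illustrative names, not Vocos functions.

```python
import numpy as np


def hann_periodic(n: int) -> np.ndarray:
    # Periodic Hann window; intended to match torch.hann_window(n) (periodic by default).
    return 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(n) / n)


def istft_same(spec: np.ndarray, n_fft: int, hop_length: int) -> np.ndarray:
    """Inverse STFT with 'same' padding.

    spec: complex array of shape (batch, n_fft // 2 + 1, frames).
    Returns audio of shape (batch, frames * hop_length) when
    n_fft - hop_length is even.
    """
    window = hann_periodic(n_fft)
    pad = (n_fft - hop_length) // 2
    batch, _, n_frames = spec.shape
    # Inverse real FFT per frame, then apply the synthesis window.
    frames = np.fft.irfft(spec, n=n_fft, axis=1) * window[None, :, None]
    out_len = (n_frames - 1) * hop_length + n_fft
    audio = np.zeros((batch, out_len))
    envelope = np.zeros(out_len)
    # Overlap-add the frames and accumulate the squared-window envelope.
    for t in range(n_frames):
        start = t * hop_length
        audio[:, start:start + n_fft] += frames[:, :, t]
        envelope[start:start + n_fft] += window ** 2
    # Drop the padding on both ends, then normalise by the window overlap.
    audio = audio[:, pad:out_len - pad]
    envelope = envelope[pad:out_len - pad]
    if np.any(envelope <= 1e-11):
        raise ValueError("window envelope too small; check n_fft and hop_length")
    return audio / envelope


def mag_phase_to_audio(mag, x, y, n_fft=1024, hop_length=256):
    # Recombine the ONNX outputs into a complex spectrum: mag * (cos + i*sin).
    spec = mag * (x + 1j * y)
    return istft_same(spec, n_fft, hop_length)
```

With `onnxruntime` installed, the three inputs would come from `onnxruntime.InferenceSession(onnx_path).run(None, {"mels": mels.astype(np.float32)})`. If your head was trained with "center" padding instead, a centered inverse such as `torch.istft(..., center=True)` is the closer match.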
Thank you @wetdog, it works.
Hi all, what library are you using for the ISTFT?
@sankar-mukherjee I'm glad it worked for you. @mush42 I'm using the ISTFT from torch, but I guess librosa would also work. BTW, here's our version of the 22 kHz Vocos: https://huggingface.co/BSC-LT/vocos-mel-22khz
@wetdog But does ONNX export of the torch ISTFT already work?
@wetdog I also don't understand how to export to ONNX, since the model contains an ISTFT.
Here's an ISTFT implementation that can be exported to ONNX, because I want to stay with the … It works with Vocos out of the box.
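The linked implementation isn't reproduced here, but a common way to make an ISTFT exportable is to avoid complex tensors entirely: the per-frame inverse real FFT can be written as two real matrix products against precomputed cosine and sine bases. A NumPy sketch of that general technique (not necessarily what the linked code does), checked against `np.fft.irfft`; `irfft_basis` and `irfft_real` are illustrative names, and an even `n_fft` is assumed.

```python
import numpy as np


def irfft_basis(n_fft: int):
    """Real matrices (C, S) with irfft(re + 1j*im, n_fft) == C @ re - S @ im.

    Bins 1..n_fft/2-1 get weight 2 (their mirrored conjugates are folded in);
    DC and Nyquist get weight 1 and no sine term, so their imaginary parts
    are ignored, as np.fft.irfft ignores them.
    """
    n_bins = n_fft // 2 + 1
    n = np.arange(n_fft)[:, None]   # time index (rows)
    k = np.arange(n_bins)[None, :]  # frequency bin (columns)
    weight = np.full(n_bins, 2.0)
    weight[0] = weight[-1] = 1.0
    angle = 2.0 * np.pi * n * k / n_fft
    cos_basis = weight * np.cos(angle) / n_fft
    sin_basis = weight * np.sin(angle) / n_fft
    sin_basis[:, [0, -1]] = 0.0     # sin(0) and sin(pi * n) are exactly zero
    return cos_basis, sin_basis


def irfft_real(re: np.ndarray, im: np.ndarray, n_fft: int) -> np.ndarray:
    # re, im: (..., n_bins, frames) -> time-domain frames (..., n_fft, frames)
    cos_basis, sin_basis = irfft_basis(n_fft)
    return cos_basis @ re - sin_basis @ im
```

With the head outputs discussed above, `re = mag * x` and `im = mag * y`, so no complex multiply is needed either. In a torch module the two matrices would be registered as buffers, and the overlap-add could be done with `torch.nn.functional.conv_transpose1d` using an identity kernel of shape `(n_fft, 1, n_fft)` and `stride=hop_length`, followed by the window-envelope normalisation; these are plain real-valued ops that the conventional exporter generally handles.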
@lumpidu and @Liujingxiu23 The workaround I posted works by exporting the model up to, but not including, the ISTFT operation. However, the code that @mush42 shared seems to export the complete model. I'll try to merge the two models into one. Also, @mush42, you should post your solution to this issue in onnx/onnx#4777.
Hi,
Any plans to enable ONNX export for Vocos?
I developed a script to do it, but it has issues with some PyTorch operators that Vocos uses.