Showing 38 changed files with 2,196 additions and 89 deletions.
`@@ -0,0 +1,6 @@`

```
*.pyc
__pycache__/
*/__pycache__/
alias_free_cuda/build/
exp/
tmp/
```
Empty file.
`@@ -0,0 +1,63 @@`

```python
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from alias_free_torch.resample import UpSample1d, DownSample1d

# Build/load the fused CUDA kernel; this makes anti_alias_activation_cuda importable.
from alias_free_cuda import load
load.load()


class FusedAntiAliasActivation(torch.autograd.Function):
    """
    Assumes a filter size of 12, replication padding on upsampling, and
    log-scale alpha/beta parameters as inputs.
    """

    @staticmethod
    def forward(ctx, inputs, ftr, alpha, beta):
        import anti_alias_activation_cuda
        activation_results = anti_alias_activation_cuda.forward(inputs, ftr, alpha, beta)
        return activation_results

    @staticmethod
    def backward(ctx, output_grads):
        # TODO: implement the backward pass; the fused path is inference-only for now.
        raise NotImplementedError


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12,
                 fused: bool = True
                 ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)
        self.fused = fused  # whether to use the fused CUDA kernel

    def forward(self, x):
        if not self.fused:
            x = self.upsample(x)
            x = self.act(x)
            x = self.downsample(x)
            return x
        else:
            if self.act.__class__.__name__ == "Snake":
                beta = self.act.alpha.data  # Snake uses the same parameter for alpha and beta
            else:
                beta = self.act.beta.data  # SnakeBeta uses separate alpha and beta parameters
            alpha = self.act.alpha.data
            if not self.act.alpha_logscale:
                # exp() is baked into the CUDA kernel; cancel it out with a log here
                alpha = torch.log(alpha)
                beta = torch.log(beta)
            x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
            x = self.downsample(x)
            return x
```
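To show how `Activation1d` consumes an activation module, here is a minimal usage sketch. The `Snake` class below is a stand-in written for this example (it is not part of this commit); it only needs to expose the attributes the fused path reads above: `alpha`, an `alpha_logscale` flag, and `beta` for SnakeBeta-style variants. The example uses `fused=False`, which exercises only the upsample/act/downsample path.

```python
import torch
import torch.nn as nn

class Snake(nn.Module):
    """Stand-in Snake activation for illustration: x + (1/alpha) * sin^2(alpha * x)."""
    def __init__(self, channels, alpha_logscale=False):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        # one alpha per channel, stored in log-space when alpha_logscale=True
        init = torch.zeros(channels) if alpha_logscale else torch.ones(channels)
        self.alpha = nn.Parameter(init)

    def forward(self, x):
        alpha = self.alpha.exp() if self.alpha_logscale else self.alpha
        alpha = alpha.view(1, -1, 1)  # broadcast over (batch, channels, time)
        return x + (1.0 / (alpha + 1e-9)) * torch.sin(x * alpha) ** 2

# Up- and downsampling by the same ratio preserves the time dimension.
act = Activation1d(activation=Snake(channels=64), fused=False)
y = act(torch.randn(8, 64, 256))  # (batch, channels, time)
print(y.shape)  # torch.Size([8, 64, 256])
```

Note that importing the module above still triggers `load.load()`, so even the non-fused path assumes the CUDA extension can be built; `fused=True` additionally requires a CUDA input tensor.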
`@@ -0,0 +1,48 @@`

```cpp
/* coding=utf-8
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace anti_alias_activation {

// Implemented in the companion .cu file.
torch::Tensor fwd_cuda(torch::Tensor const& input,
                       torch::Tensor const& filter,
                       torch::Tensor const& alpha,
                       torch::Tensor const& beta);

torch::Tensor fwd(torch::Tensor const& input,
                  torch::Tensor const& filter,
                  torch::Tensor const& alpha,
                  torch::Tensor const& beta) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  // AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
  //            (input.scalar_type() == at::ScalarType::BFloat16),
  //            "Only fp16 and bf16 are supported");

  return fwd_cuda(input, filter, alpha, beta);
}

}  // end namespace anti_alias_activation

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &anti_alias_activation::fwd,
        "Anti Alias Activation -- Forward.");
}
```
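The Python module earlier calls `load.load()` before importing `anti_alias_activation_cuda`. That loader is not shown in this excerpt; a plausible sketch using PyTorch's JIT extension builder follows. The source file names and compile flags here are assumptions for illustration, not taken from the commit.

```python
import pathlib
from torch.utils import cpp_extension

def load():
    # JIT-compile the pybind11 binding plus the CUDA kernel into a module named
    # anti_alias_activation_cuda, which becomes importable once this returns.
    srcpath = pathlib.Path(__file__).parent.absolute()
    cpp_extension.load(
        name="anti_alias_activation_cuda",
        sources=[
            str(srcpath / "anti_alias_activation.cpp"),      # the binding shown above
            str(srcpath / "anti_alias_activation_cuda.cu"),  # kernel implementing fwd_cuda (assumed name)
        ],
        extra_cuda_cflags=["-O3"],
        verbose=True,
    )
```

The `m.def("forward", ...)` binding is what `anti_alias_activation_cuda.forward(...)` resolves to in `FusedAntiAliasActivation.forward`.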