diff --git a/models/tts/maskgct/maskgct_s2a.py b/models/tts/maskgct/maskgct_s2a.py index ad9fd099..408aedae 100644 --- a/models/tts/maskgct/maskgct_s2a.py +++ b/models/tts/maskgct/maskgct_s2a.py @@ -11,27 +11,6 @@ from models.tts.maskgct.llama_nar import DiffLlama -def top_k(logits, thres=0.9): - k = math.ceil((1 - thres) * logits.shape[-1]) - val, ind = logits.topk(k, dim=-1) - probs = torch.full_like(logits, float("-inf")) - probs.scatter_(2, ind, val) - return probs - - -def log(t, eps=1e-10): - return torch.log(t + eps) - - -def gumbel_noise(t): - noise = torch.zeros_like(t).uniform_(0, 1) - return -log(-log(noise)) - - -def gumbel_sample(t, temperature=1.0, dim=-1): - return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) - - def top_k(logits, thres=0.9): k = math.ceil((1 - thres) * logits.shape[-1]) val, ind = logits.topk(k, dim=-1)