Skip to content

Commit

Permalink
handle the conditioning token
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 17, 2023
1 parent 1dcbc48 commit 7e85f47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
12 changes: 10 additions & 2 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import Tensor, nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from beartype import beartype
Expand Down Expand Up @@ -50,9 +50,17 @@ def __init__(

def forward(
self,
x
x,
cond = None
):
x = reduce(x, 'b (n h) d -> b n d', h = self.num_tokens_reduce)

if exists(cond):
if cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')

x = x + cond

logits = self.conformer(x)
out = self.heads(logits)
return out
Expand Down

0 comments on commit 7e85f47

Please sign in to comment.