add ability to use t5 relative positional bias, addressing #15
lucidrains committed Jun 26, 2023
1 parent f2b1d43 commit 7ef65a9
Showing 3 changed files with 78 additions and 10 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
-    version = '0.0.17',
+    version = '0.0.18',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
soundstorm_pytorch/attend.py (3 changes: 2 additions & 1 deletion)
@@ -132,7 +132,8 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

if self.flash:
-            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            assert not exists(attn_bias)
+            return self.flash_attn(q, k, v, mask = mask)

# similarity

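Note on the attend.py change: the flash branch now rejects a learned bias outright, so the T5 bias introduced below is only ever consumed by the non-flash einsum path. A rough, self-contained sketch of what that path does with it (illustrative only, not the repository's exact code): an additive bias of shape (heads, i, j) is summed onto the similarity logits before the softmax.

import torch

def attend_with_bias(q, k, v, attn_bias = None):
    # q, k, v: (batch, heads, seq, dim_head); attn_bias: (heads, i, j) or None
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale

    if attn_bias is not None:
        sim = sim + attn_bias  # broadcasts over the batch dimension

    attn = sim.softmax(dim = -1)
    return torch.einsum('bhij,bhjd->bhid', attn, v)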
soundstorm_pytorch/soundstorm.py (83 changes: 75 additions & 8 deletions)
@@ -126,6 +126,64 @@ def rotate_half(x):
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())

+# t5 relative positional bias
+
+class T5RelativePositionBias(nn.Module):
+    def __init__(
+        self,
+        scale = 1.,
+        num_buckets = 32,
+        max_distance = 128,
+        heads = 8
+    ):
+        super().__init__()
+        self.scale = scale
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position,
+        num_buckets = 32,
+        max_distance = 128
+    ):
+        ret = 0
+        n = -relative_position
+
+        num_buckets //= 2
+        ret += (n < 0).long() * num_buckets
+        n = torch.abs(n)
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = max_exact + (
+            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+        ).long()
+
+        val_if_large = torch.min(
+            val_if_large,
+            torch.full_like(val_if_large, num_buckets - 1)
+        )
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, n):
+        pos = torch.arange(n, device = self.device).long()
+        rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
+
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+
+        bias = rearrange(values, 'i j h -> h i j')
+        return bias * self.scale
+
# conformer

class Swish(nn.Module):
@@ -213,7 +271,8 @@ def forward(
x,
context = None,
mask = None,
-        rotary_emb = None
+        rotary_emb = None,
+        attn_bias = None
):
n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
context = default(context, x)
@@ -225,7 +284,7 @@
q = apply_rotary_pos_emb(rotary_emb, q)
k = apply_rotary_pos_emb(rotary_emb, k)

-        out = self.attend(q, k, v, mask = mask)
+        out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
@@ -313,10 +372,11 @@ def forward(
self,
x,
mask = None,
-        rotary_emb = None
+        rotary_emb = None,
+        attn_bias = None
):
x = self.ff1(x) + x
-        x = self.attn(x, mask = mask, rotary_emb = rotary_emb) + x
+        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
@@ -339,13 +399,18 @@ def __init__(
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False,
-        attn_flash = True
+        attn_flash = True,
+        t5_rel_pos_bias = False
):
super().__init__()

+        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'

self.dim = dim
self.layers = nn.ModuleList([])

-        self.rotary_emb = RotaryEmbedding(dim_head)
+        self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None
+        self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

for _ in range(depth):
self.layers.append(ConformerBlock(
@@ -361,11 +426,13 @@
))

def forward(self, x):
+        seq_len = x.shape[-2]

-        rotary_emb = self.rotary_emb(x.shape[-2])
+        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None
+        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None

for block in self.layers:
-            x = block(x, rotary_emb = rotary_emb)
+            x = block(x, rotary_emb = rotary_emb, attn_bias = attn_bias)

return x

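For intuition: the T5RelativePositionBias module added above maps every relative offset j - i to one of num_buckets buckets (exact buckets for small offsets, logarithmically sized buckets out to max_distance) and looks up a learned per-head value, yielding a (heads, n, n) additive bias for the attention logits. A small demo, assuming the class is importable from the module path shown in this diff:

import torch
from soundstorm_pytorch.soundstorm import T5RelativePositionBias

rel_pos_bias = T5RelativePositionBias(scale = 1., num_buckets = 32, max_distance = 128, heads = 8)

bias = rel_pos_bias(10)   # learned bias for a length-10 sequence
print(bias.shape)         # torch.Size([8, 10, 10]) -> (heads, i, j)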

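End to end, the feature is opt-in on the Conformer: passing t5_rel_pos_bias = True (together with attn_flash = False, per the assert above) disables rotary embeddings and threads the learned T5 bias through every block's attention. A minimal usage sketch, assuming the constructor arguments shown in this diff and library defaults for everything else:

import torch
from soundstorm_pytorch.soundstorm import Conformer

conformer = Conformer(
    dim = 512,
    depth = 2,
    attn_flash = False,      # flash attention cannot accept the learned bias
    t5_rel_pos_bias = True   # use the T5 relative position bias instead of rotary embeddings
)

x = torch.randn(1, 1024, 512)
out = conformer(x)           # (1, 1024, 512)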