Skip to content

Commit

Permalink
now able to train directly on raw audio, conditioned on semantic toke…
Browse files Browse the repository at this point in the history
…n ids automatically, if wav2vec module is passed in (or text to semantic module contains wav2vec)
  • Loading branch information
lucidrains committed Jun 26, 2023
1 parent d2f701f commit f2b1d43
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.16',
version = '0.0.17',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
22 changes: 22 additions & 0 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from spear_tts_pytorch import TextToSemantic

from audiolm_pytorch import SoundStream
from audiolm_pytorch import HubertWithKmeans, FairseqVQWav2Vec

from tqdm import tqdm

Expand Down Expand Up @@ -512,6 +513,7 @@ def __init__(
*,
soundstream: Optional[SoundStream] = None,
spear_tts_text_to_semantic: Optional[TextToSemantic] = None,
wav2vec: Optional[Union[HubertWithKmeans, FairseqVQWav2Vec]] = None,
steps = 18,
self_cond = False,
self_cond_train_prob = 0.75,
Expand Down Expand Up @@ -558,10 +560,20 @@ def __init__(
self.text_to_semantic = spear_tts_text_to_semantic

if exists(spear_tts_text_to_semantic) and exists(spear_tts_text_to_semantic.wav2vec):
assert not exists(wav2vec), 'wav2vec model already supplied from the TextToSemantic instance from SpearTTS'
assert not (exists(wav2vec_downsample_factor) or exists(wav2vec_target_sample_hz)), 'wav2vec downsample factor and sampling freq being auto-set from the text-to-semantic module passed in, as it contains the wav2vec instance'

self.wav2vec = spear_tts_text_to_semantic.wav2vec
self.wav2vec_target_sample_hz = maybe_wav2vec.target_sample_hz
self.wav2vec_downsample_factor = maybe_wav2vec.downsample_factor

elif exists(wav2vec):
assert not (exists(wav2vec_downsample_factor) or exists(wav2vec_target_sample_hz)), 'wav2vec downsample factor and sampling freq being auto-set from the text-to-semantic module passed in, as it contains the wav2vec instance'

self.wav2vec = wav2vec
self.wav2vec_target_sample_hz = wav2vec.target_sample_hz
self.wav2vec_downsample_factor = wav2vec.downsample_factor

else:
self.wav2vec = None
self.wav2vec_target_sample_hz = wav2vec_target_sample_hz
Expand Down Expand Up @@ -815,6 +827,16 @@ def forward(

is_raw_audio = x.dtype == torch.float

# if semantic token ids not supplied and conditioning is indicated
# see if wav2vec and raw audio is available

if self.should_condition and not exists(cond_semantic_token_ids) and is_raw_audio:
with torch.no_grad():
self.wav2vec.eval()
cond_semantic_token_ids = self.wav2vec(x, flatten = False)

# derive residual vector quantized ids if raw audio passed in

if is_raw_audio:
assert exists(self.soundstream)
with torch.no_grad():
Expand Down

0 comments on commit f2b1d43

Please sign in to comment.