WaveFormer: Real-Time Audio Separation
A brief introduction to WaveFormer, the real-time target sound extraction model proposed in the paper "Real-Time Target Sound Extraction", followed by building the network in Paddle, loading the converted pretrained weights, and running model inference.
1. Introduction
- Audio separation splits a mixture of sounds apart and can be used to extract the audio of a specific target.
- This article introduces WaveFormer, an audio separation model based on dilated causal convolutions (DCC) and a Transformer.
2. References
- Paper: Real-Time Target Sound Extraction
3. Algorithm Overview
- Goal:
  - Use a single neural network to perform real-time, streaming target sound extraction.
- Model:
  - The paper proposes WaveFormer, a neural network based on an encoder-decoder architecture.
  - The encoder is a stack of dilated causal convolution (DCC) layers, which cover a large receptive field with high computational efficiency.
  - The decoder is a single Transformer decoder layer, chosen to keep the computational cost low.
- Performance:
  - Compared with previous models, evaluations show that:
    - SI-SNRi improves by 2.2-3.3 dB;
    - the model is 1.2-4x smaller;
    - the runtime is 1.5-2x lower.
- Model architecture:
  (Figure: WaveFormer architecture diagram, https://ai-studio-static-online.cdn.bcebos.com/93ee98cd2a3b4202be10d238d0a5e1409afe7f7668284b99a097409e80125b12)
- Main pipeline (a toy sketch follows this list):
  - An input convolution converts the mixed audio into input chunks (Input chunk).
  - The DCC encoder extracts audio features from the chunks.
  - The Transformer decoder combines the target label with the audio features and turns them into an audio mask (MASK).
  - The mask (MASK) is applied to the input chunk (element-wise multiplication in the latent space) to obtain the output chunk (Output chunk).
  - An output (transposed) convolution converts the output chunks back into the target audio.
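The pipeline can be pictured with a toy sketch (illustrative only; the hop size and latent width mirror the defaults in section 4.1, and a random tensor stands in for the mask predicted by the DCC encoder + Transformer decoder):

import paddle
import paddle.nn as nn

L, enc_dim = 8, 512                                   # hop size and latent width
mixture = paddle.randn((1, 1, 8 * 100))               # [B, 1, T] a short fake mixture

in_conv = nn.Conv1D(1, enc_dim, kernel_size=L, stride=L)            # waveform -> latent chunks
out_conv = nn.Conv1DTranspose(enc_dim, 1, kernel_size=L, stride=L)  # latent chunks -> waveform

latent = in_conv(mixture)                             # [1, 512, 100] input chunks
mask = paddle.rand(latent.shape)                      # stand-in for the predicted MASK
target = out_conv(latent * mask)                      # masked chunks decoded back to audio
print(target.shape)                                   # [1, 1, 800]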
4. Implementation
4.1 Model Network
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class PositionalEncoding(nn.Layer):
def __init__(self, input_size, max_len=2500):
super().__init__()
self.max_len = max_len
pe = paddle.zeros((self.max_len, input_size))
pe.stop_gradient = True
positions = paddle.arange(0, self.max_len).unsqueeze(1).cast(
paddle.float32)
denominator = paddle.exp(
paddle.arange(0, input_size, 2).cast(paddle.float32) *
-(math.log(10000.0) / input_size))
pe[:, 0::2] = paddle.sin(positions * denominator)
pe[:, 1::2] = paddle.cos(positions * denominator)
self.pe = pe.unsqueeze(0)
self.register_buffer("pe", self.pe)
def forward(self, x):
"""
Arguments
---------
x : tensor
Input feature shape (batch, time, fea)
"""
return self.pe[:, :x.shape[1]].clone().detach()
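# Quick illustrative check (not part of the model definition): the module
# returns a [1, time, feature] table that is broadcast-added to a
# [batch, time, feature] input before attention.
_pe_demo = PositionalEncoding(input_size=256)
print(_pe_demo(paddle.randn((2, 72, 256))).shape)   # [1, 72, 256]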
def mod_pad(x, chunk_size, pad):
# Mod pad the input to perform integer number of
# inferences
mod = 0
if (x.shape[-1] % chunk_size) != 0:
mod = chunk_size - (x.shape[-1] % chunk_size)
x = F.pad(x, (0, mod), data_format='NCL')
x = F.pad(x, pad, data_format='NCL')
return x, mod
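# Illustrative example of mod_pad (not part of the model): the time axis is
# rounded up to a multiple of `chunk_size`, then the extra (left, right)
# padding is applied; `mod` records how much was added so it can be trimmed later.
_x_demo, _mod_demo = mod_pad(paddle.zeros((1, 1, 100)), chunk_size=8, pad=(8, 8))
print(_x_demo.shape, _mod_demo)   # [1, 1, 120] 4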
class LayerNormTransposed(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super(LayerNormTransposed, self).__init__(*args, **kwargs)
def forward(self, x):
"""
Args:
x: [B, C, T]
"""
x = x.transpose((0, 2, 1)) # [B, T, C]
x = super().forward(x)
x = x.transpose((0, 2, 1)) # [B, C, T]
return x
class DepthwiseSeparableConv(nn.Layer):
"""
Depthwise separable convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
dilation):
super(DepthwiseSeparableConv, self).__init__()
self.layers = nn.Sequential(
nn.Conv1D(in_channels,
in_channels,
kernel_size,
stride,
padding,
groups=in_channels,
dilation=dilation),
LayerNormTransposed(in_channels),
nn.ReLU(),
nn.Conv1D(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0),
LayerNormTransposed(out_channels),
nn.ReLU(),
)
def forward(self, x):
return self.layers(x)
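# Rough motivation for the depthwise separable design above (weights only,
# biases and LayerNorm parameters ignored): for 512 channels and kernel size 3,
# a regular Conv1D needs 512 * 512 * 3 = 786432 weights, while the depthwise
# (512 * 3) + pointwise (512 * 512) pair needs 263680 -- about 3x fewer.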
class DilatedCausalConvEncoder(nn.Layer):
"""
A dilated causal convolution based encoder for encoding
time domain audio input into latent space.
"""
def __init__(self, channels, num_layers, kernel_size=3):
super(DilatedCausalConvEncoder, self).__init__()
self.channels = channels
self.num_layers = num_layers
self.kernel_size = kernel_size
# Compute buffer lengths for each layer
# buf_length[i] = (kernel_size - 1) * dilation[i]
self.buf_lengths = [(kernel_size - 1) * 2**i
for i in range(num_layers)]
# Compute buffer start indices for each layer
self.buf_indices = [0]
for i in range(num_layers - 1):
self.buf_indices.append(self.buf_indices[-1] + self.buf_lengths[i])
# Dilated causal conv layers aggregate previous context to obtain
        # context-rich encoded input.
_dcc_layers = []
for i in range(num_layers):
dcc_layer = DepthwiseSeparableConv(channels,
channels,
kernel_size=3,
stride=1,
padding=0,
dilation=2**i)
_dcc_layers.append(('dcc_%d' % i, dcc_layer))
self.dcc_layers = nn.Sequential(*_dcc_layers)
def init_ctx_buf(self, batch_size):
"""
Returns an initialized context buffer for a given batch size.
"""
return paddle.zeros(
(batch_size, self.channels,
(self.kernel_size - 1) * (2**self.num_layers - 1)))
def forward(self, x, ctx_buf):
"""
Encodes input audio `x` into latent space, and aggregates
contextual information in `ctx_buf`. Also generates new context
buffer with updated context.
Args:
x: [B, in_channels, T]
Input multi-channel audio.
ctx_buf: {[B, channels, self.buf_length[0]], ...}
A list of tensors holding context for each dilation
causal conv layer. (len(ctx_buf) == self.num_layers)
Returns:
ctx_buf: {[B, channels, self.buf_length[0]], ...}
Updated context buffer with output as the
last element.
"""
T = x.shape[-1] # Sequence length
for i in range(self.num_layers):
buf_start_idx = self.buf_indices[i]
buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]
# DCC input: concatenation of current output and context
dcc_in = paddle.concat(
(ctx_buf[..., buf_start_idx:buf_end_idx], x), axis=-1)
# Push current output to the context buffer
ctx_buf[..., buf_start_idx:buf_end_idx] = \
dcc_in[..., -self.buf_lengths[i]:]
# Residual connection
x = x + self.dcc_layers[i](dcc_in)
return x, ctx_buf
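# Illustrative streaming check with a small encoder (not part of the model):
# feeding two chunks one after another while carrying `ctx_buf` is how the
# encoder runs in real time; the buffer holds exactly the past frames each
# dilated layer still needs, here (3 - 1) * (2**4 - 1) = 30 frames.
_enc_demo = DilatedCausalConvEncoder(channels=16, num_layers=4)
_buf = _enc_demo.init_ctx_buf(batch_size=1)
_y1, _buf = _enc_demo(paddle.randn((1, 16, 10)), _buf)
_y2, _buf = _enc_demo(paddle.randn((1, 16, 10)), _buf)
print(_y1.shape, _y2.shape, _buf.shape)   # [1, 16, 10] [1, 16, 10] [1, 16, 30]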
class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
"""
Adapted from:
"https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
"0caf6ad71c46488f76d89845b0123d2550ef792f/"
"causal_transformer_decoder/model.py#L77"
"""
def forward(self, tgt, memory=None, chunk_size=1):
tgt_last_tok = tgt[:, -chunk_size:, :]
# self attention part
tmp_tgt = self.self_attn(tgt_last_tok, tgt, tgt, attn_mask=None)
tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
tgt_last_tok = self.norm1(tgt_last_tok)
# encoder-decoder attention
if memory is not None:
tmp_tgt = self.cross_attn(tgt_last_tok,
memory,
memory,
attn_mask=None)
tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
tgt_last_tok = self.norm2(tgt_last_tok)
# final feed-forward network
tmp_tgt = self.linear2(
self.dropout(self.activation(self.linear1(tgt_last_tok))))
tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
tgt_last_tok = self.norm3(tgt_last_tok)
return tgt_last_tok
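# Note: unlike a standard decoder layer, the layer above only computes outputs
# for the last `chunk_size` positions of `tgt`, while attending over the whole
# `tgt` (carried context + current chunk). No attention mask is needed because
# frames beyond the current chunk are never part of the input.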
class CausalTransformerDecoder(nn.Layer):
"""
    A causal transformer decoder which decodes input vectors using
precisely `ctx_len` past vectors in the sequence, and using no future
vectors at all.
"""
def __init__(self, model_dim, ctx_len, chunk_size, num_layers, nhead,
use_pos_enc, ff_dim):
super(CausalTransformerDecoder, self).__init__()
self.num_layers = num_layers
self.model_dim = model_dim
self.ctx_len = ctx_len
self.chunk_size = chunk_size
self.nhead = nhead
self.use_pos_enc = use_pos_enc
self.unfold = nn.Unfold(kernel_sizes=[ctx_len + chunk_size, 1],
strides=chunk_size)
self.pos_enc = PositionalEncoding(model_dim, max_len=200)
self.tf_dec_layers = nn.LayerList([
CausalTransformerDecoderLayer(d_model=model_dim,
nhead=nhead,
dim_feedforward=ff_dim)
for _ in range(num_layers)
])
def init_ctx_buf(self, batch_size):
return paddle.zeros(
(batch_size, self.num_layers + 1, self.ctx_len, self.model_dim))
def _causal_unfold(self, x):
"""
Unfolds the sequence into a batch of sequences
prepended with `ctx_len` previous values.
Args:
x: [B, ctx_len + L, C]
ctx_len: int
Returns:
            [B * L / chunk_size, ctx_len + chunk_size, C]
"""
B, T, C = x.shape
x = x.transpose((0, 2, 1)) # [B, C, ctx_len + L]
x = self.unfold(x.unsqueeze(-1)) # [B, C * (ctx_len + chunk_size), -1]
x = x.transpose((0, 2, 1))
x = x.reshape((B, -1, C, self.ctx_len + self.chunk_size))
x = x.reshape((-1, C, self.ctx_len + self.chunk_size))
x = x.transpose((0, 2, 1))
return x
def forward(self, tgt, mem, ctx_buf, probe=False):
"""
Args:
x: [B, model_dim, T]
ctx_buf: [B, num_layers, model_dim, ctx_len]
"""
mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))
# Input sequence length
B, C, T = tgt.shape
tgt = tgt.transpose((0, 2, 1))
mem = mem.transpose((0, 2, 1))
# Prepend mem with the context
mem = paddle.concat((ctx_buf[:, 0, :, :], mem), axis=1)
ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
mem_ctx = self._causal_unfold(mem)
if self.use_pos_enc:
mem_ctx = mem_ctx + self.pos_enc(mem_ctx)
# Attention chunk size: required to ensure the model
# wouldn't trigger an out-of-memory error when working
# on long sequences.
K = 1000
for i, tf_dec_layer in enumerate(self.tf_dec_layers):
# Update the tgt with context
tgt = paddle.concat((ctx_buf[:, i + 1, :, :], tgt), axis=1)
ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]
# Compute encoded output
tgt_ctx = self._causal_unfold(tgt)
if self.use_pos_enc and i == 0:
tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
tgt = paddle.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
for i in range(int(math.ceil(tgt.shape[0] / K))):
tgt[i * K:(i + 1) * K] = tf_dec_layer(
tgt_ctx[i * K:(i + 1) * K], mem_ctx[i * K:(i + 1) * K],
self.chunk_size)
tgt = tgt.reshape((B, T, C))
tgt = tgt.transpose((0, 2, 1))
if mod != 0:
tgt = tgt[..., :-mod]
return tgt, ctx_buf
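# Illustrative shape check with a small decoder (not part of the model): each
# call consumes same-length `tgt`/`mem`, decodes them chunk by chunk, and
# carries `ctx_len` past frames per layer (plus one slot for `mem`) in the
# context buffer.
_dec_demo = CausalTransformerDecoder(model_dim=32, ctx_len=16, chunk_size=8,
                                     num_layers=1, nhead=8, use_pos_enc=True,
                                     ff_dim=64)
_dbuf = _dec_demo.init_ctx_buf(batch_size=1)
_out, _dbuf = _dec_demo(paddle.randn((1, 32, 24)), paddle.randn((1, 32, 24)), _dbuf)
print(_out.shape, _dbuf.shape)   # [1, 32, 24] [1, 2, 16, 32]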
class MaskNet(nn.Layer):
def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection,
proj):
super(MaskNet, self).__init__()
self.skip_connection = skip_connection
self.proj = proj
# Encoder based on dilated causal convolutions.
self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
num_layers=num_enc_layers)
# Project between encoder and decoder dimensions
self.proj_e2d_e = nn.Sequential(
nn.Conv1D(enc_dim,
dec_dim,
kernel_size=1,
stride=1,
padding=0,
groups=dec_dim), nn.ReLU())
self.proj_e2d_l = nn.Sequential(
nn.Conv1D(enc_dim,
dec_dim,
kernel_size=1,
stride=1,
padding=0,
groups=dec_dim), nn.ReLU())
self.proj_d2e = nn.Sequential(
nn.Conv1D(dec_dim,
enc_dim,
kernel_size=1,
stride=1,
padding=0,
groups=dec_dim), nn.ReLU())
# Transformer decoder that operates on chunks of size
# buffer size.
self.decoder = CausalTransformerDecoder(model_dim=dec_dim,
ctx_len=dec_buf_len,
chunk_size=dec_chunk_size,
num_layers=num_dec_layers,
nhead=8,
use_pos_enc=use_pos_enc,
ff_dim=2 * dec_dim)
def forward(self, x, l, enc_buf, dec_buf):
"""
Generates a mask based on encoded input `e` and the one-hot
label `label`.
Args:
x: [B, C, T]
Input audio sequence
l: [B, C]
Label embedding
ctx_buf: {[B, C, <receptive field of the layer>], ...}
List of context buffers maintained by DCC encoder
"""
        # Encode the input audio
e, enc_buf = self.encoder(x, enc_buf)
# Label integration
l = l.unsqueeze(2) * e
# Project to `dec_dim` dimensions
if self.proj:
e = self.proj_e2d_e(e)
m = self.proj_e2d_l(l)
# Cross-attention to predict the mask
m, dec_buf = self.decoder(m, e, dec_buf)
else:
# Cross-attention to predict the mask
m, dec_buf = self.decoder(l, e, dec_buf)
# Project mask to encoder dimensions
if self.proj:
m = self.proj_d2e(m)
# Final mask after residual connection
if self.skip_connection:
m = l + m
return m, enc_buf, dec_buf
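# Note: the label is integrated multiplicatively -- `l.unsqueeze(2) * e` scales
# each channel of the encoded sequence, at every time step, by the corresponding
# component of the label embedding; this conditioned sequence is what the
# Transformer decoder turns into the mask.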
class WaveFormer(nn.Layer):
def __init__(self,
label_len,
L=8,
enc_dim=512,
num_enc_layers=10,
dec_dim=256,
dec_buf_len=100,
num_dec_layers=2,
dec_chunk_size=72,
out_buf_len=2,
use_pos_enc=True,
skip_connection=True,
proj=True,
lookahead=True):
super(WaveFormer, self).__init__()
self.L = L
self.out_buf_len = out_buf_len
self.enc_dim = enc_dim
self.lookahead = lookahead
# Input conv to convert input audio to a latent representation
kernel_size = 3 * L if lookahead else L
self.in_conv = nn.Sequential(
nn.Conv1D(in_channels=1,
out_channels=enc_dim,
kernel_size=kernel_size,
stride=L,
padding=0,
bias_attr=False), nn.ReLU())
# Label embedding layer
self.label_embedding = nn.Sequential(nn.Linear(label_len, 512),
nn.LayerNorm(512), nn.ReLU(),
nn.Linear(512, enc_dim),
nn.LayerNorm(enc_dim), nn.ReLU())
# Mask generator
self.mask_gen = MaskNet(enc_dim=enc_dim,
num_enc_layers=num_enc_layers,
dec_dim=dec_dim,
dec_buf_len=dec_buf_len,
dec_chunk_size=dec_chunk_size,
num_dec_layers=num_dec_layers,
use_pos_enc=use_pos_enc,
skip_connection=skip_connection,
proj=proj)
# Output conv layer
self.out_conv = nn.Sequential(
nn.Conv1DTranspose(in_channels=enc_dim,
out_channels=1,
kernel_size=(out_buf_len + 1) * L,
stride=L,
padding=out_buf_len * L,
bias_attr=False), nn.Tanh())
def init_buffers(self, batch_size):
enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size)
dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size)
out_buf = paddle.zeros((batch_size, self.enc_dim, self.out_buf_len))
return enc_buf, dec_buf, out_buf
def forward(self,
x,
label,
init_enc_buf=None,
init_dec_buf=None,
init_out_buf=None,
pad=True):
"""
Extracts the audio corresponding to the `label` in the given
`mixture`. Generates `chunk_size` samples per iteration.
Args:
mixed: [B, n_mics, T]
input audio mixture
label: [B, num_labels]
one hot label
Returns:
out: [B, n_spk, T]
extracted audio with sounds corresponding to the `label`
"""
mod = 0
if pad:
pad_size = (self.L, self.L) if self.lookahead else (0, 0)
x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)
if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
assert init_enc_buf is None and \
init_dec_buf is None and \
init_out_buf is None, \
"Both buffers have to initialized, or " \
"both of them have to be None."
enc_buf, dec_buf, out_buf = self.init_buffers(x.shape[0])
else:
enc_buf, dec_buf, out_buf = \
init_enc_buf, init_dec_buf, init_out_buf
# Generate latent space representation of the input
x = self.in_conv(x)
# Generate label embedding
l = self.label_embedding(label) # [B, label_len] --> [B, channels]
# Generate mask corresponding to the label
m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)
# Apply mask and decode
x = x * m
x = paddle.concat((out_buf, x), axis=-1)
out_buf = x[..., -self.out_buf_len:]
x = self.out_conv(x)
# Remove mod padding, if present.
if mod != 0:
x = x[:, :, :-mod]
return x, enc_buf, dec_buf, out_buf
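A quick shape sanity check (not from the original project; it assumes the default constructor arguments above and the 41-class label set used in section 4.2):

model = WaveFormer(label_len=41)
model.eval()

mixture = paddle.randn((1, 1, 8000))   # [B, 1, T] random stand-in for audio
query = paddle.zeros((1, 41))
query[:, 0] = 1                        # one-hot target label (class 0)

with paddle.no_grad():
    out, enc_buf, dec_buf, out_buf = model(mixture, query)

print(out.shape)      # [1, 1, 8000] -- same length as the input mixture
print(enc_buf.shape)  # [1, 512, 2046] DCC context
print(dec_buf.shape)  # [1, 3, 100, 256] Transformer context
print(out_buf.shape)  # [1, 512, 2] overlap buffer for the output convolution

The three returned buffers are what a streaming application would feed back into the next call, so the model only ever needs the current block of audio.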
4.2 Model Predictor
import os
import json
from typing import List
import librosa
import paddle
import soundfile
class WaveFormerPredictor:
def __init__(self, model_dir: str) -> None:
config_file = os.path.join(model_dir, 'config.json')
ckpt_file = os.path.join(model_dir, 'model.pdparams')
with open(config_file, 'r', encoding='UTF-8') as f:
configs = json.load(f)
self.sample_rate = configs['test_data']['sr']
state_dict = paddle.load(ckpt_file)
self.waveformer = WaveFormer(**configs['model_params'])
self.waveformer.set_state_dict(state_dict)
self.waveformer.eval()
self.target_list_en = [
"Acoustic_guitar", "Applause", "Bark", "Bass_drum", "Burping_or_eructation",
"Bus", "Cello", "Chime", "Clarinet", "Computer_keyboard", "Cough", "Cowbell",
"Double_bass", "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
"Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire", "Harmonica",
"Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow", "Microwave_oven", "Oboe",
"Saxophone", "Scissors", "Shatter", "Snare_drum", "Squeak", "Tambourine", "Tearing",
"Telephone", "Trumpet", "Violin_or_fiddle", "Writing"
]
self.target_list_cn = [
"木吉他", "鼓掌", "吠叫", "低音鼓", "打嗝或吞咽",
"巴士", "大提琴", "钟声", "单簧管", "电脑键盘", "咳嗽", "牛铃",
"低音提琴", "抽屉打开或关闭", "电子钢琴", "放屁", "手指打响",
"烟花", "长笛", "手风琴", "锣", "枪声或炮声", "口琴",
"踩镲", "按键声", "敲击声", "笑声", "喵星人", "微波炉", "双簧管",
"萨克斯风", "剪刀", "碎片", "小鼓", "吱吱", "手鼓", "撕裂",
"电话", "小号", "小提琴或提琴", "写作"
]
self.target_num = 41
def extraction(
self,
input_file: str,
targets: List[str],
output_file: str) -> None:
with paddle.no_grad():
            mixture, _ = librosa.load(input_file, sr=self.sample_rate)
mixture = paddle.to_tensor(
mixture, dtype=paddle.float32).unsqueeze(0).unsqueeze(0)
target_indexs = []
for target in targets:
if target in self.target_list_en:
index = self.target_list_en.index(target)
target_indexs.append(index)
elif target in self.target_list_cn:
index = self.target_list_cn.index(target)
target_indexs.append(index)
elif (isinstance(target, int)) and (0<=target<self.target_num):
target_indexs.append(target)
target_indexs = list(set(target_indexs))
if len(target_indexs) == 0:
query = paddle.ones((1, self.target_num))
else:
query = paddle.zeros((1, self.target_num))
query[:, target_indexs] = 1
output, _, _, _ = self.waveformer(mixture, query)
output = output.squeeze().numpy()
soundfile.write(output_file, output, self.sample_rate)
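The `targets` argument of `extraction` accepts English labels, Chinese labels, or integer class indices (0-40); duplicates are merged, and if nothing matches, the query falls back to all ones, i.e. every class is kept. A hypothetical illustration of the query it builds:

_targets_demo = ["Meow", "电脑键盘", 3]   # English label, Chinese label, raw index -> 27, 9, 3
_query_demo = paddle.zeros((1, 41))
_query_demo[:, [27, 9, 3]] = 1            # 1 marks the sounds to keep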
4.3 Model Loading
class ModelList:
E256_10_D128_1 = 'pretrained_models/dcc_tf_ckpt_E256_10_D128_1'
E256_10_D256_1 = 'pretrained_models/dcc_tf_ckpt_E256_10_D256_1'
E512_10_D128_1 = 'pretrained_models/dcc_tf_ckpt_E512_10_D128_1'
E512_10_D256_1 = 'pretrained_models/dcc_tf_ckpt_E512_10_D256_1'
E256_10_D128_1_multi = 'pretrained_models/dcc_tf_ckpt_E256_10_D128_1_multi'
E256_10_D256_1_multi = 'pretrained_models/dcc_tf_ckpt_E256_10_D256_1_multi'
E512_10_D128_1_multi = 'pretrained_models/dcc_tf_ckpt_E512_10_D128_1_multi'
E512_10_D256_1_multi = 'pretrained_models/dcc_tf_ckpt_E512_10_D256_1_multi'
predictor = WaveFormerPredictor(
model_dir=ModelList.E512_10_D256_1_multi
)
4.4 Model Prediction
from IPython.display import Audio, display
target_list_cn = [
"木吉他", "鼓掌", "吠叫", "低音鼓", "打嗝或吞咽",
"巴士", "大提琴", "钟声", "单簧管", "电脑键盘", "咳嗽", "牛铃",
"低音提琴", "抽屉打开或关闭", "电子钢琴", "放屁", "手指打响",
"烟花", "长笛", "手风琴", "锣", "枪声或炮声", "口琴",
"踩镲", "按键声", "敲击声", "笑声", "喵星人", "微波炉", "双簧管",
"萨克斯风", "剪刀", "碎片", "小鼓", "吱吱", "手鼓", "撕裂",
"电话", "小号", "小提琴或提琴", "写作"
]
input_file='sample.wav'
targets=['电脑键盘']
output_file='output.wav'
predictor.extraction(
input_file,
targets,
output_file
)
print('Input audio: ')
display(Audio(input_file))
print('Output audio: ')
display(Audio(output_file))
Input audio: 
<IPython.lib.display.Audio object>
Output audio: 
<IPython.lib.display.Audio object>
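`extraction` also accepts several targets at once: their one-hot vectors are merged into a single query, so the extracted track contains all of the requested sounds. For example (the output file name here is only illustrative):

predictor.extraction(
    input_file,
    ['电脑键盘', 'Laughter'],
    'output_multi.wav'
)
display(Audio('output_multi.wav'))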
5. Summary
- Gave a brief introduction to the WaveFormer model.
- Built the WaveFormer network in Paddle and loaded the pretrained weights to run audio-separation inference.
This article is a repost; original project link.