1. Introduction

  • Audio separation splits a mixed recording into its components and can be used to extract the sound of a specific target

  • This post walks through WaveFormer, a target sound extraction model built from DCC (Dilated Causal Convolution) layers and a Transformer decoder

2. Algorithm overview

  • Goal:

    • A single neural network model that performs real-time, streaming target sound extraction
  • Model:

    • WaveFormer is a neural network built on an encoder-decoder (Encoder-Decoder) architecture

    • The encoder is a stack of dilated causal convolution (DCC) layers, which cover a large receptive field at low computational cost

    • The decoder is a single Transformer decoder layer, chosen for its computational efficiency

  • Performance:

    • Compared with prior models, evaluations show an SI-SNRi improvement of 2.2-3.3 dB

    • Model size shrinks by a factor of 1.2-4

    • Runtime drops by a factor of 1.5-2

  • Architecture diagram:

    [Model architecture diagram: https://ai-studio-static-online.cdn.bcebos.com/93ee98cd2a3b4202be10d238d0a5e1409afe7f7668284b99a097409e80125b12]

  • Main pipeline (a shape-level sketch follows this list):

    • An input convolution converts the mixed audio into an input chunk

    • The DCC encoder extracts audio features from the input chunk

    • The Transformer decoder attends over the target label embedding and the audio features to produce a mask

    • The input chunk is multiplied element-wise by the mask to give the output chunk

    • An output (transposed) convolution converts the output chunk back into the target audio
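
    The sketch below (random tensors standing in for the real layers) traces the shapes through these five steps, assuming the default sizes from section 3.1: stride L=8, latent dimension 512, and a 576-sample chunk.

import paddle

# Shape-level sketch of the five pipeline steps (random stand-ins only).
B, T, L, D = 1, 576, 8, 512            # batch, samples, stride, latent dim
mixture = paddle.randn((B, 1, T))      # mixed input audio
x = paddle.randn((B, D, T // L))       # 1) input conv -> input chunk [B, D, T/L]
e = paddle.randn((B, D, T // L))       # 2) DCC encoder -> audio features
mask = paddle.randn((B, D, T // L))    # 3) Transformer decoder -> mask
y = x * mask                           # 4) element-wise masking -> output chunk
out = paddle.randn((B, 1, T))          # 5) output transposed conv -> target audio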

3. Implementation

3.1 Model network

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class PositionalEncoding(nn.Layer):

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        self.max_len = max_len
        pe = paddle.zeros((self.max_len, input_size))
        pe.stop_gradient = True
        positions = paddle.arange(0, self.max_len).unsqueeze(1).cast(
            paddle.float32)
        denominator = paddle.exp(
            paddle.arange(0, input_size, 2).cast(paddle.float32) *
            -(math.log(10000.0) / input_size))

        pe[:, 0::2] = paddle.sin(positions * denominator)
        pe[:, 1::2] = paddle.cos(positions * denominator)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """
        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea)
        """
        return self.pe[:, :x.shape[1]].clone().detach()
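

# Illustrative check (not part of the model): the encoding is input
# independent -- forward() just slices the precomputed table to the input's
# time length and broadcasts it over the batch.
_pe = PositionalEncoding(input_size=8, max_len=16)
print(_pe(paddle.zeros((2, 10, 8))).shape)  # [1, 10, 8]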


def mod_pad(x, chunk_size, pad):
    # Pad the input so its length becomes a multiple of chunk_size
    # (plus optional front/back padding), allowing an integer number
    # of inference chunks.
    mod = 0
    if (x.shape[-1] % chunk_size) != 0:
        mod = chunk_size - (x.shape[-1] % chunk_size)

    x = F.pad(x, (0, mod), data_format='NCL')
    x = F.pad(x, pad, data_format='NCL')

    return x, mod
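

# Illustrative check (not part of the model): pad a 100-sample signal to a
# multiple of chunk_size=8 (mod=4 trailing zeros), plus (8, 8) samples of
# front/back padding: 100 + 4 + 8 + 8 = 120.
_x, _mod = mod_pad(paddle.zeros((1, 1, 100)), chunk_size=8, pad=(8, 8))
print(_x.shape, _mod)  # [1, 1, 120] 4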


class LayerNormTransposed(nn.LayerNorm):

    def __init__(self, *args, **kwargs):
        super(LayerNormTransposed, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]
        """
        x = x.transpose((0, 2, 1))  # [B, T, C]
        x = super().forward(x)
        x = x.transpose((0, 2, 1))  # [B, C, T]
        return x


class DepthwiseSeparableConv(nn.Layer):
    """
    Depthwise separable convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation):
        super(DepthwiseSeparableConv, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv1D(in_channels,
                      in_channels,
                      kernel_size,
                      stride,
                      padding,
                      groups=in_channels,
                      dilation=dilation),
            LayerNormTransposed(in_channels),
            nn.ReLU(),
            nn.Conv1D(in_channels,
                      out_channels,
                      kernel_size=1,
                      stride=1,
                      padding=0),
            LayerNormTransposed(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)
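

# Note: for C channels and kernel size k, the depthwise + pointwise split
# costs roughly C*k + C*C weights instead of C*C*k for a dense Conv1D,
# e.g. ~0.26M vs ~0.79M parameters per layer at C=512, k=3.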


class DilatedCausalConvEncoder(nn.Layer):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.
    """

    def __init__(self, channels, num_layers, kernel_size=3):
        super(DilatedCausalConvEncoder, self).__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # Compute buffer lengths for each layer
        # buf_length[i] = (kernel_size - 1) * dilation[i]
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Compute buffer start indices for each layer
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # a context-aware encoded input.
        _dcc_layers = []
        for i in range(num_layers):
            dcc_layer = DepthwiseSeparableConv(channels,
                                               channels,
                                               kernel_size=kernel_size,
                                               stride=1,
                                               padding=0,
                                               dilation=2**i)
            _dcc_layers.append(('dcc_%d' % i, dcc_layer))
        self.dcc_layers = nn.Sequential(*_dcc_layers)

    def init_ctx_buf(self, batch_size):
        """
        Returns an initialized context buffer for a given batch size.
        """
        return paddle.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)))

    def forward(self, x, ctx_buf):
        """
        Encodes input audio `x` into latent space, and aggregates
        contextual information in `ctx_buf`. Also generates new context
        buffer with updated context.
        Args:
            x: [B, channels, T]
                Input audio in the latent space.
            ctx_buf: [B, channels, sum(self.buf_lengths)]
                A single tensor holding the context for every dilated
                causal conv layer, laid out contiguously; layer i's
                slice starts at self.buf_indices[i].
        Returns:
            x: [B, channels, T]
                Encoded output.
            ctx_buf:
                Updated context buffer.
        """
        T = x.shape[-1]  # Sequence length

        for i in range(self.num_layers):
            buf_start_idx = self.buf_indices[i]
            buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]

            # DCC input: stored context followed by the current input
            dcc_in = paddle.concat(
                (ctx_buf[..., buf_start_idx:buf_end_idx], x), axis=-1)

            # Push the most recent samples into the context buffer
            ctx_buf[..., buf_start_idx:buf_end_idx] = \
                dcc_in[..., -self.buf_lengths[i]:]

            # Residual connection
            x = x + self.dcc_layers[i](dcc_in)

        return x, ctx_buf
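

# Illustrative check (not part of the model): with kernel_size=3 the stack's
# receptive field is (3 - 1) * (2**num_layers - 1) + 1 samples -- exactly the
# context held in ctx_buf. Streaming a signal in two halves while threading
# the buffer through should therefore match a single full pass.
_enc = DilatedCausalConvEncoder(channels=4, num_layers=3)
_enc.eval()
_sig = paddle.randn((1, 4, 16))
_full, _ = _enc(_sig, _enc.init_ctx_buf(1))
_buf = _enc.init_ctx_buf(1)
_h1, _buf = _enc(_sig[..., :8], _buf)
_h2, _buf = _enc(_sig[..., 8:], _buf)
print(paddle.allclose(_full, paddle.concat((_h1, _h2), axis=-1),
                      atol=1e-5).item())  # True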


class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """
    Adapted from:
    "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
    "0caf6ad71c46488f76d89845b0123d2550ef792f/"
    "causal_transformer_decoder/model.py#L77"
    """

    def forward(self, tgt, memory=None, chunk_size=1):
        tgt_last_tok = tgt[:, -chunk_size:, :]

        # self attention part
        tmp_tgt = self.self_attn(tgt_last_tok, tgt, tgt, attn_mask=None)
        tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
        tgt_last_tok = self.norm1(tgt_last_tok)

        # encoder-decoder attention
        if memory is not None:
            tmp_tgt = self.cross_attn(tgt_last_tok,
                                      memory,
                                      memory,
                                      attn_mask=None)
            tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
            tgt_last_tok = self.norm2(tgt_last_tok)

        # final feed-forward network
        tmp_tgt = self.linear2(
            self.dropout(self.activation(self.linear1(tgt_last_tok))))
        tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
        tgt_last_tok = self.norm3(tgt_last_tok)
        return tgt_last_tok


class CausalTransformerDecoder(nn.Layer):
    """
    A causal transformer decoder which decodes input vectors using
    precisely `ctx_len` past vectors in the sequence, and no future
    vectors at all.
    """

    def __init__(self, model_dim, ctx_len, chunk_size, num_layers, nhead,
                 use_pos_enc, ff_dim):
        super(CausalTransformerDecoder, self).__init__()
        self.num_layers = num_layers
        self.model_dim = model_dim
        self.ctx_len = ctx_len
        self.chunk_size = chunk_size
        self.nhead = nhead
        self.use_pos_enc = use_pos_enc
        self.unfold = nn.Unfold(kernel_sizes=[ctx_len + chunk_size, 1],
                                strides=chunk_size)
        self.pos_enc = PositionalEncoding(model_dim, max_len=200)
        self.tf_dec_layers = nn.LayerList([
            CausalTransformerDecoderLayer(d_model=model_dim,
                                          nhead=nhead,
                                          dim_feedforward=ff_dim)
            for _ in range(num_layers)
        ])

    def init_ctx_buf(self, batch_size):
        return paddle.zeros(
            (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim))

    def _causal_unfold(self, x):
        """
        Unfolds the sequence into a batch of sequences
        prepended with `ctx_len` previous values.

        Args:
            x: [B, ctx_len + L, C]
                Sequence prepended with `ctx_len` frames of context
        Returns:
            [B * L / chunk_size, ctx_len + chunk_size, C]
        """
        B, T, C = x.shape
        x = x.transpose((0, 2, 1))  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * (ctx_len + chunk_size), -1]
        x = x.transpose((0, 2, 1))
        x = x.reshape((B, -1, C, self.ctx_len + self.chunk_size))
        x = x.reshape((-1, C, self.ctx_len + self.chunk_size))
        x = x.transpose((0, 2, 1))
        return x

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt: [B, model_dim, T]
                Target stream (the mask queries)
            mem: [B, model_dim, T]
                Encoder memory
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
                Per-layer context buffer
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Input sequence length
        B, C, T = tgt.shape

        tgt = tgt.transpose((0, 2, 1))
        mem = mem.transpose((0, 2, 1))

        # Prepend mem with the context
        mem = paddle.concat((ctx_buf[:, 0, :, :], mem), axis=1)
        ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: process at most K windows per
        # attention call so that long sequences don't trigger
        # out-of-memory errors.
        K = 1000

        for i, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Update the tgt with context
            tgt = paddle.concat((ctx_buf[:, i + 1, :, :], tgt), axis=1)
            ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and i == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
            tgt = paddle.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            for j in range(int(math.ceil(tgt.shape[0] / K))):
                tgt[j * K:(j + 1) * K] = tf_dec_layer(
                    tgt_ctx[j * K:(j + 1) * K], mem_ctx[j * K:(j + 1) * K],
                    self.chunk_size)
            tgt = tgt.reshape((B, T, C))

        tgt = tgt.transpose((0, 2, 1))
        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf
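

# Illustrative check (not part of the model): _causal_unfold slides a window
# of ctx_len + chunk_size frames over the sequence with stride chunk_size, so
# a [B, ctx_len + L, C] input yields [B * L / chunk_size, ctx_len +
# chunk_size, C] windows. Hypothetical sizes: ctx_len=4, chunk_size=2, L=6.
_dec = CausalTransformerDecoder(model_dim=16, ctx_len=4, chunk_size=2,
                                num_layers=1, nhead=2, use_pos_enc=False,
                                ff_dim=32)
print(_dec._causal_unfold(paddle.randn((1, 4 + 6, 16))).shape)  # [3, 6, 16]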


class MaskNet(nn.Layer):

    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection,
                 proj):
        super(MaskNet, self).__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
                                                num_layers=num_enc_layers)

        # Project between encoder and decoder dimensions
        self.proj_e2d_e = nn.Sequential(
            nn.Conv1D(enc_dim,
                      dec_dim,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      groups=dec_dim), nn.ReLU())
        self.proj_e2d_l = nn.Sequential(
            nn.Conv1D(enc_dim,
                      dec_dim,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      groups=dec_dim), nn.ReLU())
        self.proj_d2e = nn.Sequential(
            nn.Conv1D(dec_dim,
                      enc_dim,
                      kernel_size=1,
                      stride=1,
                      padding=0,
                      groups=dec_dim), nn.ReLU())

        # Transformer decoder that operates on chunks of
        # `dec_chunk_size` frames with `dec_buf_len` frames of context.
        self.decoder = CausalTransformerDecoder(model_dim=dec_dim,
                                                ctx_len=dec_buf_len,
                                                chunk_size=dec_chunk_size,
                                                num_layers=num_dec_layers,
                                                nhead=8,
                                                use_pos_enc=use_pos_enc,
                                                ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask from the input audio `x` and the label
        embedding `l`.

        Args:
            x: [B, C, T]
                Input audio sequence
            l: [B, C]
                Label embedding
            enc_buf, dec_buf:
                Context buffers maintained by the DCC encoder and
                the Transformer decoder, respectively
        """
        # Encode the input audio
        e, enc_buf = self.encoder(x, enc_buf)

        # Label integration
        l = l.unsqueeze(2) * e

        # Project to `dec_dim` dimensions
        if self.proj:
            e = self.proj_e2d_e(e)
            m = self.proj_e2d_l(l)
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(m, e, dec_buf)
        else:
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(l, e, dec_buf)

        # Project mask to encoder dimensions
        if self.proj:
            m = self.proj_d2e(m)

        # Final mask after residual connection
        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf


class WaveFormer(nn.Layer):

    def __init__(self,
                 label_len,
                 L=8,
                 enc_dim=512,
                 num_enc_layers=10,
                 dec_dim=256,
                 dec_buf_len=100,
                 num_dec_layers=2,
                 dec_chunk_size=72,
                 out_buf_len=2,
                 use_pos_enc=True,
                 skip_connection=True,
                 proj=True,
                 lookahead=True):
        super(WaveFormer, self).__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1D(in_channels=1,
                      out_channels=enc_dim,
                      kernel_size=kernel_size,
                      stride=L,
                      padding=0,
                      bias_attr=False), nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(nn.Linear(label_len, 512),
                                             nn.LayerNorm(512), nn.ReLU(),
                                             nn.Linear(512, enc_dim),
                                             nn.LayerNorm(enc_dim), nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(enc_dim=enc_dim,
                                num_enc_layers=num_enc_layers,
                                dec_dim=dec_dim,
                                dec_buf_len=dec_buf_len,
                                dec_chunk_size=dec_chunk_size,
                                num_dec_layers=num_dec_layers,
                                use_pos_enc=use_pos_enc,
                                skip_connection=skip_connection,
                                proj=proj)

        # Output conv layer
        self.out_conv = nn.Sequential(
            nn.Conv1DTranspose(in_channels=enc_dim,
                               out_channels=1,
                               kernel_size=(out_buf_len + 1) * L,
                               stride=L,
                               padding=out_buf_len * L,
                               bias_attr=False), nn.Tanh())

    def init_buffers(self, batch_size):
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size)
        out_buf = paddle.zeros((batch_size, self.enc_dim, self.out_buf_len))
        return enc_buf, dec_buf, out_buf

    def forward(self,
                x,
                label,
                init_enc_buf=None,
                init_dec_buf=None,
                init_out_buf=None,
                pad=True):
        """
        Extracts the audio corresponding to `label` from the given
        mixture `x`, processing the input chunk by chunk.

        Args:
            x: [B, 1, T]
                Input audio mixture
            label: [B, num_labels]
                Multi-hot label vector
        Returns:
            x: [B, 1, T]
                Extracted audio corresponding to `label`
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
            assert init_enc_buf is None and \
                   init_dec_buf is None and \
                   init_out_buf is None, \
                "All three buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(x.shape[0])
        else:
            enc_buf, dec_buf, out_buf = \
                init_enc_buf, init_dec_buf, init_out_buf

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label)  # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask and decode
        x = x * m
        x = paddle.concat((out_buf, x), axis=-1)
        out_buf = x[..., -self.out_buf_len:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        return x, enc_buf, dec_buf, out_buf
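

# Illustrative smoke test (not part of the model): build a small WaveFormer
# with hypothetical hyperparameters and run a dummy mixture through it; the
# extracted audio should have the same length as the input.
_model = WaveFormer(label_len=41, L=8, enc_dim=32, num_enc_layers=3,
                    dec_dim=16, dec_buf_len=10, num_dec_layers=1,
                    dec_chunk_size=9, out_buf_len=2)
_model.eval()
_mix = paddle.randn((1, 1, 288))
_query = paddle.zeros((1, 41))
_query[:, 0] = 1.0
_out, _, _, _ = _model(_mix, _query)
print(_out.shape)  # [1, 1, 288]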

3.2 Predictor

import os
import json
from typing import List, Union

import librosa
import paddle
import soundfile

class WaveFormerPredictor:

    def __init__(self, model_dir: str) -> None:
        config_file = os.path.join(model_dir, 'config.json')
        ckpt_file = os.path.join(model_dir, 'model.pdparams')
        with open(config_file, 'r', encoding='UTF-8') as f:
            configs = json.load(f)
        self.sample_rate = configs['test_data']['sr']
        state_dict = paddle.load(ckpt_file)

        self.waveformer = WaveFormer(**configs['model_params'])
        self.waveformer.set_state_dict(state_dict)
        self.waveformer.eval()

        self.target_list_en = [
            "Acoustic_guitar", "Applause", "Bark", "Bass_drum", "Burping_or_eructation",
            "Bus", "Cello", "Chime", "Clarinet", "Computer_keyboard", "Cough", "Cowbell",
            "Double_bass", "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
            "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire", "Harmonica",
            "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow", "Microwave_oven", "Oboe",
            "Saxophone", "Scissors", "Shatter", "Snare_drum", "Squeak", "Tambourine", "Tearing",
            "Telephone", "Trumpet", "Violin_or_fiddle", "Writing"
        ]

        self.target_list_cn = [
            "木吉他", "鼓掌", "吠叫", "低音鼓", "打嗝或嗳气",
            "巴士", "大提琴", "钟声", "单簧管", "电脑键盘", "咳嗽", "牛铃",
            "低音提琴", "抽屉打开或关闭", "电子钢琴", "放屁", "打响指",
            "烟花", "长笛", "钟琴", "锣", "枪声或炮声", "口琴",
            "踩镲", "钥匙叮当声", "敲击声", "笑声", "喵星人", "微波炉", "双簧管",
            "萨克斯风", "剪刀", "破碎声", "小鼓", "吱吱", "手鼓", "撕裂",
            "电话", "小号", "小提琴或提琴", "写作"
        ]
        self.target_num = 41

    def extraction(
            self,
            input_file: str,
            targets: List[Union[str, int]],
            output_file: str) -> None:
        with paddle.no_grad():
            mixture, _ = librosa.load(input_file, sr=self.sample_rate)
            mixture = paddle.to_tensor(
                mixture, dtype=paddle.float32).unsqueeze(0).unsqueeze(0)

            target_indexs = []
            for target in targets:
                if target in self.target_list_en:
                    index = self.target_list_en.index(target)
                    target_indexs.append(index)
                elif target in self.target_list_cn:
                    index = self.target_list_cn.index(target)
                    target_indexs.append(index)
                elif isinstance(target, int) and 0 <= target < self.target_num:
                    target_indexs.append(target)
            target_indexs = list(set(target_indexs))

            if len(target_indexs) == 0:
                query = paddle.ones((1, self.target_num))
            else:
                query = paddle.zeros((1, self.target_num))
                query[:, target_indexs] = 1

            output, _, _, _ = self.waveformer(mixture, query)
            output = output.squeeze().numpy()

            soundfile.write(output_file, output, self.sample_rate)
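

# Note on query construction (illustrative): `targets` may mix English labels,
# Chinese labels, and integer indices in [0, 41); duplicates are removed. If
# nothing matches, the query falls back to all ones, i.e. "extract every
# class", rather than raising an error.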

3.3 Model loading

class ModelList:
    E256_10_D128_1 = 'pretrained_models/dcc_tf_ckpt_E256_10_D128_1'
    E256_10_D256_1 = 'pretrained_models/dcc_tf_ckpt_E256_10_D256_1'
    E512_10_D128_1 = 'pretrained_models/dcc_tf_ckpt_E512_10_D128_1'
    E512_10_D256_1 = 'pretrained_models/dcc_tf_ckpt_E512_10_D256_1'
    E256_10_D128_1_multi = 'pretrained_models/dcc_tf_ckpt_E256_10_D128_1_multi'
    E256_10_D256_1_multi = 'pretrained_models/dcc_tf_ckpt_E256_10_D256_1_multi'
    E512_10_D128_1_multi = 'pretrained_models/dcc_tf_ckpt_E512_10_D128_1_multi'
    E512_10_D256_1_multi = 'pretrained_models/dcc_tf_ckpt_E512_10_D256_1_multi'
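
# The checkpoint names presumably encode the architecture:
# E{enc_dim}_{num_enc_layers}_D{dec_dim}_{num_dec_layers}, with a `_multi`
# suffix for variants trained on multi-target queries. E.g. E512_10_D256_1:
# enc_dim=512, 10 DCC encoder layers, dec_dim=256, 1 Transformer decoder layer.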

predictor = WaveFormerPredictor(
    model_dir=ModelList.E512_10_D256_1_multi
)

3.4 Prediction

from IPython.display import Audio, display

target_list_cn = [
    "木吉他", "鼓掌", "吠叫", "低音鼓", "打嗝或嗳气",
    "巴士", "大提琴", "钟声", "单簧管", "电脑键盘", "咳嗽", "牛铃",
    "低音提琴", "抽屉打开或关闭", "电子钢琴", "放屁", "打响指",
    "烟花", "长笛", "钟琴", "锣", "枪声或炮声", "口琴",
    "踩镲", "钥匙叮当声", "敲击声", "笑声", "喵星人", "微波炉", "双簧管",
    "萨克斯风", "剪刀", "破碎声", "小鼓", "吱吱", "手鼓", "撕裂",
    "电话", "小号", "小提琴或提琴", "写作"
]

input_file = 'sample.wav'
targets = ['电脑键盘']
output_file = 'output.wav'

predictor.extraction(
    input_file,
    targets,
    output_file
)

print('Input audio:')
display(Audio(input_file))
print('Output audio:')
display(Audio(output_file))

Input audio:
<IPython.lib.display.Audio object>
Output audio:
<IPython.lib.display.Audio object>

4. Summary

  • Briefly introduced the WaveFormer model

  • Built the WaveFormer network and loaded pretrained weights to run target sound extraction inference
    This article is a repost.
    Original project link
