Related papers

base: foundational, general-purpose research

PS: The idea is similar to the simple LM I put together myself~

Further reading: 2024.12 Byte Latent Transformer: Patches Scale Better Than Tokens | paper code


audio speech: application-scenario research

Further reading

Neural Vocoder

VQ (vector quantization)

VQ will open up your imagination~

VQ is often combined with VAE and GAN frameworks and with codecs (tokenizers). It is also used for vector compression (storage) and ANN search over vector embeddings.

https://github.com/lucidrains/vector-quantize-pytorch
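In its simplest form, the core VQ operation is a nearest-neighbour lookup in a codebook. Below is a minimal pure-Python sketch. The codebook and vectors are made-up toy values. Real implementations, such as the library above, work on batched tensors and learn the codebook.

```python
from typing import List, Tuple

Vector = List[float]

def vq_quantize(vectors: List[Vector], codebook: List[Vector]) -> Tuple[List[int], List[Vector]]:
    """Replace each vector by its nearest codeword (squared L2 distance)."""
    indices, quantized = [], []
    for v in vectors:
        # squared Euclidean distance from v to every codeword
        dists = [sum((a - b) ** 2 for a, b in zip(v, c)) for c in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        indices.append(k)                   # the discrete token that is stored / modelled
        quantized.append(list(codebook[k]))
    return indices, quantized

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
vectors = [[0.9, 1.2], [-0.8, 0.7], [0.1, -0.2]]
idx, q = vq_quantize(vectors, codebook)
print(idx)  # [1, 2, 0]
```

In VQ-VAE-style training, the non-differentiable lookup is usually bypassed with a straight-through estimator, alongside codebook and commitment losses. Neither is shown here.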


VQ-VAEs


VQ-GANs


Audio Codec with VQ


Visual Tokenizer with VQ

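  • 3D-VQ from MAGVIT (VQ-VAE framework): 2023. MAGVIT: Masked Generative Video Transformer | paper code
  • Lookup-Free Quantization (LFQ) from MAGVIT2 (VQ-VAE framework): 2024.3 Language Model Beats Diffusion - Tokenizer is Key to Visual Generation (an excellent paper on LM (discrete latent) vs. DDM (continuous latent)) | magvit2-pytorch code
    • A new video tokenizer that outperforms the previous best video tokenizer on three tasks: visual generation, video compression, and action recognition.
    • A novel lookup-free quantization method that improves the visual generation quality of language models by enabling a large vocabulary to be learned.
    • Experiments show that, given the same training data, equivalent model size, and similar training budget, a language model can outperform diffusion models on ImageNet.
    • In a user study, the video compressor achieved better quality (human evaluation) than HEVC (High Efficiency Video Coding, H.265) and the next-generation codec VVC (Versatile Video Coding, H.266) at similar bitrates. This is the first successful attempt at a visual tokenizer designed for video generation that reaches results comparable to standard codecs.
  • BSQ (Binary Spherical Quantization) from 2024.6 Image and Video Tokenization with Binary Spherical Quantization (BSQ vs. LFQ: BSQ beats LFQ in quantization error, training efficiency, entropy-computation cost, and experimental results, especially on high-dimensional and high-resolution data) | paper code
  • 2024.9 Open-MAGVIT2: An Open-Source Project Toward Democratizing Auto-regressive Visual Generation (worth training yourself to see the results) | paper code. Open-MAGVIT2 is derived from magvit2-pytorch; its repository (SEED-Voken) later added IBQ and became a general Visual Tokenizer project.
  • IBQ (Index Backpropagation Quantization) from 2024.12 Taming Scalable Visual Tokenizer for Autoregressive Image Generation. IBQ addresses the scalability bottleneck of existing VQ methods and enables high-utilization, large-scale visual tokenization with strong image reconstruction and generation:
    • A global update strategy keeps codebook utilization high (about 96%) throughout training.
    • It is the first visual tokenizer successfully trained with a very large codebook (2¹⁸ = 262,144 entries) and a high code dimension (256).
    • On ImageNet it improves both reconstruction (1.00 rFID) and generation (2.05 gFID).
    | paper code (scaling codebook size, code dimension and model size)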
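LFQ and BSQ both drop the explicit codebook search. Each latent dimension is binarized, and the resulting bit pattern itself is the token index, which gives an implicit codebook of 2^d entries. The simplified pure-Python sketch below assumes `v > 0` maps to bit 1. It omits the entropy losses the papers use to train these quantizers, and it does not cover IBQ.

```python
import math
from typing import List, Tuple

def lfq(z: List[float]) -> Tuple[List[float], int]:
    """Lookup-free quantization: each dimension -> -1 or +1; the sign bits form the index."""
    q = [1.0 if v > 0 else -1.0 for v in z]
    index = sum(1 << i for i, v in enumerate(z) if v > 0)
    return q, index

def bsq(z: List[float]) -> Tuple[List[float], int]:
    """Binary spherical quantization: L2-normalize, then binarize to +-1/sqrt(d)."""
    d = len(z)
    norm = math.sqrt(sum(v * v for v in z))          # assumes z is not the zero vector
    u = [v / norm for v in z]                         # project onto the unit sphere
    q = [(1.0 if v > 0 else -1.0) / math.sqrt(d) for v in u]  # nearest corner, still unit norm
    index = sum(1 << i for i, v in enumerate(u) if v > 0)
    return q, index

z = [0.3, -1.2, 0.05, -0.4]
print(lfq(z))  # ([1.0, -1.0, 1.0, -1.0], 5)
print(bsq(z))  # ([0.5, -0.5, 0.5, -0.5], 5)
```

Because the index is read directly from the bits, quantization cost does not grow with codebook size. This is what makes very large vocabularies practical.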

Vector Compression(storage) and Search(ANN) with VQ for Vector embeddings

FishSpeech

Main contributions

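  • Uses LLMs and a dual-AR architecture instead of traditional G2P (grapheme-to-phoneme) conversion, providing robust and scalable multilingual speech synthesis.
  • The FFGAN vocoder integrates multiple vector quantization techniques (Grouped Finite Scalar Vector Quantization, GFSQ) to optimize compression ratio and codebook utilization for high-fidelity speech synthesis.
  • Optimized inference reaches a real-time factor of about 1:5 on a consumer NVIDIA RTX 4060 mobile GPU and 1:15 on a high-end NVIDIA RTX 4090. Latency is 150 ms, far lower than other TTS systems built on DiT and flow architectures.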

Fish-Speech's dual-autoregressive (Dual-AR) architecture

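The Dual-AR architecture improves the stability and computational efficiency of codebook handling during sequence generation, especially when Grouped Finite Scalar Vector Quantization (GFSQ) is used. It consists of two sequential AR transformers (GLM): the Slow Transformer and the Fast Transformer. These correspond to UniAudio's global GPT and local GPT, and both use the Llama architecture.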

[Figure: Fish-Speech Dual-AR architecture]

Slow Transformer

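The Slow Transformer operates at a higher level of abstraction. It processes the input embeddings (text embeddings and codebook embeddings) to encode global linguistic structure and semantic content. This module produces intermediate hidden states and predicts semantic tokens with high accuracy.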

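Given an input token sequence $x = [x_1, x_2, \dots, x_T]$, the Slow Transformer produces hidden states $h \in \mathbb{R}^{T \times D}$ and token logits $z$ through the following transformations: $$ h = SlowTransformer(x) $$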

$$ z = W_{tok} · Norm(h) $$

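where $Norm(\cdot)$ denotes layer normalization and $W_{tok}$ holds the learnable parameters of the token prediction layer. The released fish-speech-1.5 checkpoint uses RMSNorm here, as the module dump at the end of this note shows.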
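Here is a minimal sketch of the token head $z = W_{tok} \cdot Norm(h)$ for a single time step. It uses RMSNorm, as in the released checkpoint. The dimensions and weights are toy values, not the model's.

```python
import math
from typing import List

def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """RMSNorm over the feature dimension: x / sqrt(mean(x^2) + eps) * weight."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale * w for v, w in zip(x, weight)]

def token_logits(h_t: List[float], norm_weight: List[float], w_tok: List[List[float]]) -> List[float]:
    """z_t = W_tok · Norm(h_t); w_tok has one row per vocabulary entry."""
    n = rms_norm(h_t, norm_weight)
    return [sum(w * v for w, v in zip(row, n)) for row in w_tok]

# toy example: hidden size 2, vocabulary of 3 tokens
z = token_logits([3.0, -4.0], [1.0, 1.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
next_token = max(range(len(z)), key=z.__getitem__)  # greedy pick; sampling is also common
print(next_token)  # 0
```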

Fast Transformer

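The Fast Transformer refines the Slow Transformer's output by processing codebook embeddings. It captures the detailed acoustic features that natural speech requires, processes residual information, and optimizes codebook usage.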

The input to the Fast Transformer is the concatenation of the hidden state $h$ and the codebook embeddings $c$: $$ \tilde{h} = [h; c] $$

$$ h^{fast} = FastTransformer(\tilde{h}) $$

$$ y = W_{cbk} · Norm(h^{fast}) $$

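where $[h; c]$ denotes the concatenation of $h$ and $c$, $W_{cbk}$ holds the learnable parameters of the codebook prediction layer, and $y$ is the resulting codebook logits. Because the Fast Transformer is autoregressive, $c$ holds the embeddings of the codes already generated at the current step.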
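The interplay of the two loops can be sketched as a toy. The assumptions are stated up front:

- `slow_step` and `fast_step` are hypothetical stand-ins for the two Transformers. Real ones return logits that are then sampled.
- The Slow step emits the first (semantic) code of each frame.
- The Fast loop fills in the remaining codebooks one at a time from $[h_t; c_1, \dots]$.

```python
from typing import Callable, List, Tuple

def dual_ar_generate(
    slow_step: Callable[[List[List[int]]], Tuple[List[float], int]],
    fast_step: Callable[[List[float], List[int]], int],
    n_frames: int,
    num_codebooks: int,
) -> List[List[int]]:
    frames: List[List[int]] = []
    for _ in range(n_frames):
        # Slow Transformer: one step along the time axis -> hidden state h_t and first code
        h_t, first_code = slow_step(frames)
        codes = [first_code]
        # Fast Transformer: short AR loop along the codebook axis over [h_t; c_1..c_k]
        while len(codes) < num_codebooks:
            codes.append(fast_step(h_t, codes))
        frames.append(codes)          # the whole frame is fed back to the slow model
    return frames

# Hypothetical toy stand-ins so that the loop runs (not real models).
def toy_slow(frames: List[List[int]]) -> Tuple[List[float], int]:
    t = len(frames)
    return [float(t)], (7 * t) % 1024

def toy_fast(h: List[float], codes: List[int]) -> int:
    return (int(h[0]) * 31 + 17 * len(codes) + codes[-1]) % 1024

out = dual_ar_generate(toy_slow, toy_fast, n_frames=3, num_codebooks=8)
print(out[0])
```

Each frame costs one Slow step plus `num_codebooks - 1` Fast steps. The Fast Transformer is much smaller (4 vs. 24 layers in the released config), so the inner loop is relatively cheap.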

Advantages of the Dual-AR structure

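  • Better sequence-generation stability: processing global and local information hierarchically markedly improves the stability of GFSQ in sequence generation.
  • Optimized codebook handling: the Fast Transformer provides an efficient mechanism for processing codebook embeddings. It improves performance without significant extra compute, especially for models of 7B parameters or more.
  • High-fidelity synthesis: the interplay between the Slow and Fast Transformers lets the system synthesize speech with high fidelity and handle complex linguistic phenomena.
  • Multilingual capability: generating linguistic features with an LLM removes the dependency on traditional grapheme-to-phoneme conversion. This simplifies the synthesis pipeline and strengthens multilingual ability. Mixing in text data improves understanding further.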

Firefly-GAN

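Firefly-GAN (FF-GAN) is a substantially improved version of 2024.1 EVA-GAN: Enhanced Various Audio Generation via Scalable Generative Adversarial Networks. It replaces the conventional convolution components of HiFi-GAN with a more efficient design, and it introduces the ParallelBlock in place of the Multi-Receptive Field (MRF) module. With Grouped Finite Scalar Vector Quantization (GFSQ), FF-GAN performs well in sequence-generation stability and in handling linguistic variation. This makes it well suited to multilingual synthesis in AI applications.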

[Figure: Firefly-GAN architecture]

Firefly Generator

FF-GAN uses enhanced convolution structures in place of plain Conv1d layers:

  • depth-wise separable convolution (2017.4 Mobilenets: Efficient convolutional neural networks for mobile vision applications);
  • dilated convolution (2015.11 Multi-Scale Context Aggregation by Dilated Convolutions).

These architectural changes improve the model's ability to capture and synthesize complex audio features.
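Two quick calculations show why these layers help:

- A stack of stride-1 convolutions with kernel sizes $k_i$ and dilations $d_i$ sees $1 + \sum_i (k_i - 1) d_i$ samples.
- A depth-wise separable convolution replaces $C_{in} C_{out} k$ weights with $C_{in} k + C_{in} C_{out}$.

The kernel sizes (3/7/11) and dilations (1, 3, 5) below are taken from the ResBlocks in the Firefly module dump at the end of this note. Biases are ignored.

```python
from typing import Sequence

def receptive_field(kernel_sizes: Sequence[int], dilations: Sequence[int]) -> int:
    """Receptive field (in samples) of stacked stride-1 1D convolutions."""
    return 1 + sum((k - 1) * d for k, d in zip(kernel_sizes, dilations))

def conv1d_weights(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k                  # standard conv, bias ignored

def depthwise_separable_weights(c_in: int, c_out: int, k: int) -> int:
    return c_in * k + c_in * c_out           # one k-tap filter per channel + 1x1 pointwise

for k in (3, 7, 11):
    print(k, receptive_field([k] * 3, [1, 3, 5]))   # prints 3 19 / 7 55 / 11 91
print(conv1d_weights(256, 256, 7), depthwise_separable_weights(256, 256, 7))  # 458752 67328
```

Dilation widens the context without adding weights. The separable factorization cuts the weight count by roughly 6.8x in this 256-channel example.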

The paper does not mention ConvNeXt, so I add it here. It is important and is frequently used in architectures that combine diffusion, such as F5-TTS:

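  • ⭐️ 2022. A ConvNet for the 2020s | paper code. The paper modernizes ResNet (ResNet-50/200) step by step, following the design of Swin Transformer, a ViT (encoder) variant:
    • MSA (multi-head self-attention) becomes a 7x7 depthwise conv2d.
    • The MLP becomes linear layers equivalent to 1x1 conv2d, with the activation changed from ReLU to GELU.
    • BN is replaced by LN.
  The result is a pure ConvNet that matches Swin Transformer in accuracy while being lighter and faster at inference. ConvNeXt may suit some tasks better, such as image classification, object detection, and instance and semantic segmentation. Transformers may be more flexible and generalize better on other tasks, especially those requiring discrete, sparse, or structured outputs. Architecture choice should therefore fit the task at hand while staying as simple as possible.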
  • ⭐️ 2023. ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders (ConvNeXt V2) | paper code
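A ConvNeXt block as it appears in the Firefly backbone dump can be sketched in pure Python. The block applies a depthwise Conv1d with kernel 7, then LayerNorm, a pointwise 4x expansion, GELU, and a pointwise projection, and finally adds the residual. Assumptions: "same" zero padding and no depthwise bias. Layer scale and drop path are omitted, and the real module works on batched tensors.

```python
import math
import random
from typing import List

Matrix = List[List[float]]

def gelu(v: float) -> float:
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))  # exact (erf-based) GELU

def layer_norm(vec: List[float], eps: float = 1e-6) -> List[float]:
    mu = sum(vec) / len(vec)
    var = sum((v - mu) ** 2 for v in vec) / len(vec)
    return [(v - mu) / math.sqrt(var + eps) for v in vec]

def linear(vec: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

def convnext_block_1d(x: Matrix, dw: Matrix, w1: Matrix, b1: List[float],
                      w2: Matrix, b2: List[float]) -> Matrix:
    """x is C x T; returns x + pwconv2(GELU(pwconv1(LayerNorm(dwconv(x)))))."""
    C, T, k = len(x), len(x[0]), len(dw[0])
    pad = k // 2
    # 1) depthwise conv: each channel is filtered by its own kernel ("same" zero padding)
    y = []
    for c in range(C):
        p = [0.0] * pad + x[c] + [0.0] * pad
        y.append([sum(dw[c][j] * p[t + j] for j in range(k)) for t in range(T)])
    out = [row[:] for row in x]
    for t in range(T):
        feat = layer_norm([y[c][t] for c in range(C)])   # 2) LayerNorm over channels
        hid = [gelu(v) for v in linear(feat, w1, b1)]    # 3) pwconv1: C -> 4C, then GELU
        proj = linear(hid, w2, b2)                       # 4) pwconv2: 4C -> C
        for c in range(C):
            out[c][t] += proj[c]                         # 5) residual connection
    return out

rng = random.Random(0)

def rand(rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

C, T, K = 4, 10, 7
x = rand(C, T)
y = convnext_block_1d(x, rand(C, K), rand(4 * C, C), [0.0] * (4 * C), rand(C, 4 * C), [0.0] * C)
```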

PS: the implementation in the code comes from: https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py

The ParallelBlock (from 2024.1 EVA-GAN: Enhanced Various Audio Generation via Scalable Generative Adversarial Networks) replaces the conventional Multi-Receptive Field (MRF) module, with the following improvements:

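  • More efficient processing of the "typo-codebook" input.
  • Configurable kernel sizes and dilation rates, plus a stack-and-average mechanism for the outputs of the three residual blocks (ResBlocks) instead of adding them directly.
  • Better multi-receptive-field (MRF) coverage, stronger feature extraction, and more configurability, which together improve audio synthesis quality.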
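Here is a minimal sketch of stack-and-average combination. The branch functions are hypothetical stand-ins for ResBlocks with different kernel sizes. Mathematically, averaging the stacked outputs equals the sum-then-divide-by-number-of-kernels used in the reference HiFi-GAN code. Any benefit therefore presumably comes from how the branches are organized (for example, run in parallel on one stacked tensor) rather than from the arithmetic itself.

```python
from typing import Callable, List, Sequence

Signal = List[float]

def parallel_block(x: Signal, branches: Sequence[Callable[[Signal], Signal]]) -> Signal:
    """Run every branch on the same input, stack the outputs, and average them."""
    stacked = [branch(x) for branch in branches]          # n_branches x len(x)
    return [sum(col) / len(stacked) for col in zip(*stacked)]

def moving_average(k: int) -> Callable[[Signal], Signal]:
    """Hypothetical ResBlock stand-in: residual + length-k moving average ("same" padding)."""
    def branch(x: Signal) -> Signal:
        pad = k // 2
        p = [0.0] * pad + x + [0.0] * pad
        return [x[t] + sum(p[t:t + k]) / k for t in range(len(x))]
    return branch

x = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0]
y = parallel_block(x, [moving_average(k) for k in (3, 7, 11)])  # kernel sizes as in the dump
print(y)
```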

Grouped Finite Scalar Vector Quantization (GFSQ)

To accommodate the "typo-codebook" task, Grouped Finite Scalar Vector Quantization (GFSQ) is introduced.

The GFSQ procedure in detail:

Let the input tensor be $z \in \mathbb{R}^{B \times C \times L}$. The process consists of the following steps:

Downsampling

A downsampling function $f_{down}$ is applied to the input tensor $z$, giving the downsampled tensor $z_d \in \mathbb{R}^{B \times C_d \times L_d}$: $$ z_d = f_{down}(z) $$

GFSQ process
  • Feature grouping: the downsampled tensor $z_d$ is split along the channel dimension into $G$ groups: $$ z_d = [z_d^{(1)}, z_d^{(2)}, \dots, z_d^{(G)}] $$

  • Scalar quantization: each scalar $z_{b,c,l}^{(g)}$ is quantized independently: $$ \hat{z}_{b,c,l}^{(g)} = Q(z_{b,c,l}^{(g)}) $$

  • Index generation: each quantized scalar is mapped to an index $k_{b,c,l}^{(g)}$.

  • Decoding: $$ \hat{z}_{b,c,l}^{(g)} = Codebook^{(g)}[k_{b,c,l}^{(g)}] $$
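Below is a simplified FSQ-style sketch for one group, covering scalar quantization, index generation, and decoding. Assumptions:

- Each scalar is squashed with tanh into $[0, L-1]$ and rounded. The FSQ paper's bounding function differs slightly for even $L$.
- The per-dimension levels `(8, 5, 5, 5)` are illustrative only; check `firefly_gan_vq.yaml` for the real values.
- The "codebook" is the implicit grid of level combinations, so an index is simply a mixed-radix number.

```python
import math
from typing import List, Sequence

LEVELS = (8, 5, 5, 5)  # illustrative assumption: 8*5*5*5 = 1000 implicit codes

def fsq_quantize(z_group: Sequence[float], levels: Sequence[int]) -> List[int]:
    """Q: squash each scalar into [0, L-1] and round to an integer level."""
    # Python's round() rounds halves to the nearest even integer
    return [int(round((math.tanh(v) + 1.0) / 2.0 * (L - 1))) for v, L in zip(z_group, levels)]

def digits_to_index(digits: Sequence[int], levels: Sequence[int]) -> int:
    """Index generation: mixed radix, k = d0 + L0*(d1 + L1*(d2 + ...))."""
    index, base = 0, 1
    for d, L in zip(digits, levels):
        index += d * base
        base *= L
    return index

def index_to_digits(index: int, levels: Sequence[int]) -> List[int]:
    digits = []
    for L in levels:
        digits.append(index % L)
        index //= L
    return digits

def decode(index: int, levels: Sequence[int]) -> List[float]:
    """Codebook lookup on the implicit grid: index -> values in [-1, 1]."""
    return [2.0 * d / (L - 1) - 1.0 for d, L in zip(index_to_digits(index, levels), levels)]

digits = fsq_quantize([0.0, 2.0, -2.0, 0.3], LEVELS)
k = digits_to_index(digits, LEVELS)
print(digits, k, decode(k, LEVELS))
```

Because the grid is fixed, every index is reachable by construction. This is one reason FSQ-style quantizers tend to reach high codebook utilization.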

Reconstructing the quantized downsampled tensor

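The quantized groups are concatenated along the channel dimension to form the quantized downsampled tensor $z_{q_d} \in \mathbb{R}^{B \times C_d \times L_d}$: $$ z_{q_d}(b, :, l) = [\hat{z}^{(1)}(b, :, l); \hat{z}^{(2)}(b, :, l); \dots; \hat{z}^{(G)}(b, :, l)] $$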

Upsampling

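An upsampling function $f_{up}$ restores the quantized downsampled tensor to its original size, giving the final quantized tensor $z_q \in \mathbb{R}^{B \times C \times L}$: $$ z_q = f_{up}(z_{q_d}) $$ The goal is for $z_q$ to be as close as possible to the original input $z$: $$ z_q \approx z $$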
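The whole downsample, group, quantize, concatenate, and upsample path can be traced with a shape-level sketch. Three stand-ins are used, since the real model uses learned convolutional layers:

- average pooling for $f_{down}$;
- repetition for $f_{up}$;
- rounding to a fixed grid for $Q$.

The batch dimension is dropped.

```python
from typing import List

Tensor2D = List[List[float]]  # channels x time

def f_down(z: Tensor2D, factor: int) -> Tensor2D:
    """Stand-in downsampling: average-pool each channel along time."""
    return [[sum(ch[i:i + factor]) / factor for i in range(0, len(ch), factor)] for ch in z]

def f_up(z: Tensor2D, factor: int) -> Tensor2D:
    """Stand-in upsampling: repeat each time step `factor` times."""
    return [[v for v in ch for _ in range(factor)] for ch in z]

def q(v: float, step: float = 0.25) -> float:
    """Stand-in scalar quantizer: snap to a uniform grid."""
    return round(v / step) * step

def gfsq(z: Tensor2D, n_groups: int, factor: int) -> Tensor2D:
    z_d = f_down(z, factor)                                   # C_d x L_d
    size = len(z_d) // n_groups                               # assumes C_d divisible by G
    groups = [z_d[g * size:(g + 1) * size] for g in range(n_groups)]
    q_groups = [[[q(v) for v in ch] for ch in grp] for grp in groups]
    z_qd = [ch for grp in q_groups for ch in grp]             # concatenate along channels
    return f_up(z_qd, factor)                                 # back to C x L

z = [[0.1 * (c + t) for t in range(8)] for c in range(4)]    # C = 4, L = 8
z_q = gfsq(z, n_groups=2, factor=2)
mse = sum((a - b) ** 2 for ra, rb in zip(z, z_q) for a, b in zip(ra, rb)) / 32
print(round(mse, 4))
```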

💡

For the concrete implementation, see the classes and parameters in the corresponding config:

https://github.com/fishaudio/fish-speech/blob/main/fish_speech/configs/firefly_gan_vq.yaml

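The config defines `dim` and `input_dim`. I find these names hard to understand and ambiguous:

  • `dim` would be clearer as output_channels (the number of features). The downsampling convolution layers shrink the size (length or dimension) of each feature, and the transposed convolution layers enlarge it again.
  • Likewise, `input_dim` would be clearer as input_channels.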

Conclusion

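GFSQ reaches close to 100% codebook utilization. In internal ablation experiments, it obtained better objective and subjective scores than other quantization techniques such as RFSQ, RVQ, and GRFSQ. FF-GAN markedly improves the stability of "typo-codebook" operations and preserves comprehensive intermediate-variable information in multi-emotion and multilingual tasks.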
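Codebook utilization is usually measured as the fraction of codebook entries actually used over a dataset. Perplexity, the exponential of the code-usage entropy, is a common companion metric. The small sketch below uses a toy index list.

```python
import math
from collections import Counter
from typing import Sequence

def codebook_utilization(indices: Sequence[int], codebook_size: int) -> float:
    """Fraction of codebook entries used at least once."""
    return len(set(indices)) / codebook_size

def code_perplexity(indices: Sequence[int]) -> float:
    """exp(entropy) of the empirical code distribution: the effective number of codes used."""
    counts = Counter(indices)
    n = len(indices)
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    return math.exp(entropy)

codes = [0, 1, 1, 3, 3, 3, 3, 2]
print(codebook_utilization(codes, codebook_size=8))  # 0.5
print(code_perplexity(codes))
```

Utilization only says whether a code is ever used. Perplexity also penalizes a skewed distribution in which a few codes dominate.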

Training and Inference

[Figure: Fish-Speech training (Stage 1 and Stage 2) and inference pipeline]

Training

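Fish-Speech uses a three-stage training recipe:

  • pre-training (PT) on large-scale standard data;
  • supervised fine-tuning (SFT) on smaller batches of high-quality data;
  • DPO (Direct Preference Optimization) training on manually labeled pairs of positive and negative samples.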
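For reference, the standard DPO objective on one preference pair is

$$ -\log \sigma\big(\beta[(\log\pi_\theta(y_w) - \log\pi_{ref}(y_w)) - (\log\pi_\theta(y_l) - \log\pi_{ref}(y_l))]\big) $$

where $y_w$ and $y_l$ are the preferred and rejected samples. A scalar sketch follows. The log-probabilities and $\beta = 0.1$ are illustrative assumptions; this note does not give Fish-Speech's DPO settings.

```python
import math

def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))

def dpo_loss(logp_chosen: float, logp_rejected: float,
             ref_logp_chosen: float, ref_logp_rejected: float, beta: float = 0.1) -> float:
    """-log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    margin = beta * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))
    return softplus(-margin)  # -log(sigmoid(m)) == softplus(-m)

# Policy identical to the reference model: loss is log 2.
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy raised the chosen sample's likelihood relative to the reference: loss drops.
print(dpo_loss(-9.0, -12.0, -10.0, -12.0))
```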

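The training infrastructure has two parts (see the figure above):

  • The vocoder (Firefly-GAN) was trained on 8 RTX 4090 GPUs for one week (Training Stage 1).
  • The autoregressive (Dual-AR) model was trained on 8 H100 GPUs (80 GB) for another week (Training Stage 2).

Note that these training times do not include the DPO stage.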

PS: The paper does not describe the tokenizer. It uses a tiktoken BPE vocabulary, extended with codebook_size = 1024 semantic tokens. Each of these corresponds to a speech code index produced by the Firefly encoder module (waveform → mel spectrogram → codes). See the vocabulary and special tokens: https://huggingface.co/fishaudio/fish-speech-1.5/raw/main/tokenizer.tiktoken | https://huggingface.co/fishaudio/fish-speech-1.5/blob/main/special_tokens.json -> embedding

Inference

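The inference process after training is shown in the Inference part of the figure above. The system uses fish-tech acceleration, including KV caching (⭐️2023. Efficiently Scaling Transformer Inference) and PyTorch compilation. With these, it reaches a real-time factor (RTF) of about 1:5 on a consumer NVIDIA RTX 4060 mobile GPU and 1:15 on a high-end NVIDIA RTX 4090. These optimizations significantly reduce inference latency, giving a first-packet latency of 150 ms.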

The system can also stream its output, which makes it easy to integrate with modern AI tools and to use in different scenarios.
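Here is a quick reading of the RTF figures. The sketch assumes "1:5" means 1 second of compute per 5 seconds of generated audio, i.e. RTF = compute time / audio duration = 0.2. The note does not define the convention, so treat this as an interpretation.

```python
def real_time_factor(compute_seconds: float, audio_seconds: float) -> float:
    """RTF below 1 means faster than real time."""
    return compute_seconds / audio_seconds

def compute_time(audio_seconds: float, rtf: float) -> float:
    return audio_seconds * rtf

rtf_4060_mobile = 1 / 5   # reported ~1:5
rtf_4090 = 1 / 15         # reported ~1:15
# roughly 6 s and 2 s of compute for 30 s of audio
print(compute_time(30.0, rtf_4060_mobile), compute_time(30.0, rtf_4090))
```

The 150 ms first-packet latency is a separate metric. With streaming, playback can start once the first chunk is decoded, regardless of the total length of the utterance.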

Dataset

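The training data consists of a large number of speech samples from public sources and from the authors' own data collection. The dataset contains about 720k hours of multilingual speech. English and Mandarin are the main components, at 300k hours each. Each of the other included languages contributes 20k hours: German (Germanic), French and Italian (Romance), Japanese and Korean (East Asian), and Arabic (Semitic).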

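The data across languages was carefully balanced so that the model learns several languages at once, which helps it perform well when generating mixed-language content. The scale and diversity of the dataset markedly improve the naturalness of the model's multilingual output.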
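A quick check confirms that the per-language figures add up to the stated total. The hours are as reported above.

```python
hours = {
    "English": 300_000, "Mandarin": 300_000,
    "German": 20_000, "French": 20_000, "Italian": 20_000,
    "Japanese": 20_000, "Korean": 20_000, "Arabic": 20_000,
}
total = sum(hours.values())
shares = {lang: h / total for lang, h in hours.items()}
print(total)                               # 720000
print(round(shares["English"] * 100, 1))   # 41.7 (% of all hours)
```

English and Mandarin together make up about 83% of the data. Each of the other six languages makes up under 3%.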

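PS: Fish-Speech scales the dataset to 720k hours. For comparison, CosyVoice 2, published later, reported 200k hours for training its speech tokenizer and 166.8k hours for training the CosyVoice 2 LM.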

Experimental Evaluation

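To evaluate the model against baseline models, the authors ran experiments on a speaker (voice) cloning task.

The analysis is restricted to monolingual voice cloning; cross-lingual synthesis is not included in this evaluation. The evaluation corpus consists of 10 distinct speaker identities (covering different languages), with 30 synthesized utterances per speaker. This yields an evaluation set of 300 samples.

The evaluation uses both objective metrics (word error rate, speaker similarity) and subjective metrics (MOS).

The framework is meant to assess how well the model maintains high-fidelity synthesis while preserving speaker identity.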

Word Error Rate Analysis

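| Model | WER (%) |
| --- | --- |
| Ground-truth recordings | 9.22 |
| Fish-Speech | 6.89 |
| Reecho | 11.92 |
| F5-TTS | 13.98 |
| CosyVoice | 22.20 |

Table 1: Word error rate (WER) results on the voice cloning task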

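Table 1 shows that Fish-Speech reaches a WER of 6.89% on the voice cloning task. This is far below the baselines and even below the ground-truth recordings (9.22%). The gap to the competing models (11.92% to 22.20%) points to improved synthesis stability and content fidelity in Fish-Speech's approach.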
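WER is the word-level edit distance (substitutions + deletions + insertions) divided by the number of reference words. In TTS evaluation it is typically computed between an ASR transcript of the audio and the reference text. This note does not specify the ASR system or the text normalization used in the paper. Here is a minimal sketch:

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Levenshtein distance over words divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))            # distances for an empty reference prefix
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,            # deletion
                         cur[j - 1] + 1,         # insertion
                         prev[j - 1] + (r != h)) # substitution (free if words match)
        prev = cur
    return prev[-1] / len(ref)

print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))  # 1 sub + 1 del over 6 words
```

WER can exceed 100% when the hypothesis contains many insertions. For Mandarin, character error rate (splitting into characters instead of words) is normally used.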

Speaker Similarity Analysis

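| Model | Resemblyzer | SpeechBrain |
| --- | --- | --- |
| Ground-truth recordings | 0.921 | 0.770 |
| CosyVoice | 0.936 | 0.813 |
| Fish-Speech | 0.914 | 0.762 |
| F5-TTS | 0.905 | 0.787 |
| Reecho | 0.887 | 0.636 |

Table 2: Speaker similarity scores of different models, including ground-truth recordings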

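Table 2 reports speaker similarity, which the paper uses to assess the effect of the "typo-codebook" strategy. Fish-Speech scores 0.914 (Resemblyzer) and 0.762 (SpeechBrain), close to the ground-truth recordings (0.921 and 0.770). The relative gap is only 0.76% on Resemblyzer and 1.04% on SpeechBrain. The paper takes this as strong evidence that the "typo-codebook" architecture captures acoustic states more completely and thus improves the timbre fidelity of synthesized speech.

Against the baselines the picture is mixed:

  • Fish-Speech beats Reecho on both metrics (0.887 and 0.636).
  • It beats F5-TTS on Resemblyzer (0.905), but F5-TTS scores higher on SpeechBrain (0.787).
  • CosyVoice scores higher on both metrics (0.936 and 0.813).

The table therefore does not support a claim that Fish-Speech clearly outperforms the baselines on speaker similarity. What it does show is that Fish-Speech stays consistently close to the ground truth under both evaluators, i.e. it preserves speaker characteristics well, which matters for high-quality text-to-speech and agent tasks.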
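Both scores measure similarity between speaker embeddings of reference and synthesized audio, and cosine similarity is the usual choice. The sketch below assumes cosine similarity on already-extracted embeddings; the Resemblyzer and SpeechBrain encoders themselves are not called. It also reproduces the relative-gap figures quoted above.

```python
import math
from typing import Sequence

def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def relative_gap(ground_truth: float, score: float) -> float:
    """Gap to the ground-truth score, as a percentage of the ground-truth score."""
    return (ground_truth - score) / ground_truth * 100

print(round(relative_gap(0.921, 0.914), 2))  # 0.76 (Resemblyzer)
print(round(relative_gap(0.770, 0.762), 2))  # 1.04 (SpeechBrain)
```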

Perceptual Quality Assessment

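| Model | MOS |
| --- | --- |
| Ground-truth recordings | 5.00 |
| Fish-Speech | 4.05 |
| CosyVoice | 3.80 |
| F5-TTS | 2.90 |
| Reecho | 3.76 |

Table 3: Five-point mean opinion scores (MOS) for the quality of cloned speech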

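To evaluate the perceived quality of the synthesized audio, a full mean opinion score (MOS) listening test was run. The participants were ordinary listeners without audio-processing experience. The test was double-blind and randomized to keep it fair. Fish-Speech scored significantly higher than the baseline models (p < 0.05), standing out in naturalness and speaker similarity. This human evaluation suggests that Fish-Speech better captures and reproduces the natural characteristics of human speech, particularly in voice cloning.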
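MOS is the mean of listeners' 1 to 5 ratings and is usually reported with a confidence interval. The sketch below uses the normal approximation $\bar{x} \pm 1.96\, s/\sqrt{n}$. The ratings are made up, and the paper's exact significance test is not described here.

```python
import math
import statistics
from typing import Sequence, Tuple

def mos_with_ci(ratings: Sequence[float], z: float = 1.96) -> Tuple[float, float]:
    """Return (mean opinion score, half-width of an approximate 95% confidence interval)."""
    mean = statistics.mean(ratings)
    half_width = z * statistics.stdev(ratings) / math.sqrt(len(ratings))
    return mean, half_width

mos, ci = mos_with_ci([4, 5, 4, 3, 4, 5, 4, 4])
print(f"MOS = {mos:.2f} ± {ci:.2f}")
```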

Conclusion

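The Fish-Speech paper makes notable progress in text-to-speech (TTS). It introduces a new multilingual, multi-emotion stabilization approach that offers more natural, higher-quality speech synthesis for future AI applications. The core innovation is a "typo-codebook" vocoder combined with a dual-autoregressive (Dual-AR) generation structure. This combination is stable during synthesis while preserving acoustic features in the generated speech.

In addition, the paper adopts a non-grapheme-to-phoneme (non-G2P) design. This avoids limitations inherent in traditional phoneme-based systems and provides a solid foundation for cross-lingual and emotionally diverse TTS applications, particularly in AI agent interaction.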


Appendix

Fréchet Inception Distance (FID)

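FID (Fréchet Inception Distance) is a metric for evaluating the quality of images produced by generative models such as GANs and diffusion models. It quantifies similarity by comparing the distributions of generated and real images in a feature space. Concretely, FID is computed as follows:

  1. Feature extraction: extract feature vectors for the real and generated images with a pretrained Inception v3 network.
  2. Statistics: compute the mean and covariance matrix of each set of feature vectors.
  3. Distance: compute the Fréchet distance between the two resulting distributions.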

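FID is given by: $$ FID = \|\mu_r - \mu_g\|^2 + Tr\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right) $$ where $\mu_r$ and $\mu_g$ are the feature means of the real and generated images, and $\Sigma_r$ and $\Sigma_g$ are their covariance matrices.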

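What FID means

  • Lower FID means the generated images are more similar to the real ones.
  • It reflects both the quality and the diversity of the generated images.

FID is one of the most widely used metrics for generative models. It compares the whole feature distribution of generated images with that of real images, rather than scoring samples one by one. The rFID and gFID figures quoted for IBQ above are FID computed on reconstructed images and on generated images, respectively.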
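With full 2048-dimensional Inception covariances, the matrix square root needs a linear-algebra routine (e.g. `scipy.linalg.sqrtm`). To show the formula itself, here is a pure-Python sketch under a simplifying assumption: both covariances are diagonal. In that case the trace term reduces to $\sum_i (\sqrt{\sigma^2_{r,i}} - \sqrt{\sigma^2_{g,i}})^2$. The statistics below are toy values, not Inception features.

```python
import math
from typing import Sequence

def fid_diagonal(mu_r: Sequence[float], var_r: Sequence[float],
                 mu_g: Sequence[float], var_g: Sequence[float]) -> float:
    """FID between two Gaussians with diagonal covariances (variances given per dimension)."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))            # ||mu_r - mu_g||^2
    # Tr(S_r + S_g - 2 (S_r S_g)^{1/2}) for diagonal S_r, S_g
    cov_term = sum(vr + vg - 2.0 * math.sqrt(vr * vg) for vr, vg in zip(var_r, var_g))
    return mean_term + cov_term

print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [4.0, 1.0]))  # 2.0
```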


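Residual information

In deep learning and speech synthesis, residual information is an important concept, especially in complex sequence-generation tasks. It usually describes information that the model has not fully captured or processed, which may contain important details or underused features. The following explains residual information and its use in speech synthesis.

1. Definition of residual information

In deep learning, residual information usually refers to the difference between the input data and the model's current output. This difference may contain:

  • Details the model has not captured: the model may not fully understand some complex features of the input.
  • Noise or errors: noise present in the input, or errors in the model's predictions.
  • Underused features: some input features may not be fully exploited, and the residual provides extra cues for handling them better.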

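2. Residual information in speech synthesis

In text-to-speech (TTS) systems, residual information can be used to improve the quality and naturalness of the generated speech. Specific uses:

2.1 In the dual-autoregressive architecture

In the Fish-Speech framework, the Dual-AR architecture has the Slow Transformer and the Fast Transformer work together on global and local information. Residual information plays an important role in this process:

  1. Slow Transformer output: the hidden states produced by the Slow Transformer carry global semantic information but may not capture every detail. The information not yet fully processed is the residual information.
  2. Fast Transformer refinement: the Fast Transformer takes the Slow Transformer's output and processes this residual information further, filling in details and improving the quality of the generated speech. For example:
    • Adding detail: the Fast Transformer can handle acoustic details in the residual, such as pitch, intonation, and emotional expression.
    • Improving naturalness: by processing the residual, the Fast Transformer produces more natural, fluent speech.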

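2.2 In vector quantization

Residual information also plays an important role in GFSQ (Grouped Finite Scalar Vector Quantization):

  • Quantization residual: during quantization, input features are mapped to the nearest codebook entry (for FSQ, the nearest point on a fixed grid). This mapping can lose some detail, and the part that is not captured is the residual.
  • Better codebook use: by accounting for this residual, the model can adjust how it uses the codebook, improving the naturalness and fidelity of the synthesized speech.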

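2.3 In neural vocoders

In neural vocoders such as FF-GAN, residual information is used to refine the generated audio signal. For example:

  • Filling in spectral detail: by processing residual information, the vocoder restores spectral details and produces more natural audio.
  • High-frequency reconstruction: residual information helps the vocoder reconstruct high-frequency components, which are crucial for the naturalness and clarity of speech.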

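3. Ways to handle residual information

In practice, residual information can be handled in several ways:

  1. Residual connections: these pass a layer's input directly to later layers, preserving details that have not yet been processed. Transformer architectures, for example, use them throughout.
  2. Attention mechanisms: attention lets the model focus dynamically on the important parts of the input, helping it handle residual information.
  3. Multi-scale feature extraction: processing global and local information together lets the model make better use of residual information.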
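The residual connection in point 1 is simply $y = x + F(x)$. Llama-style blocks, such as those in the DualARTransformer dump below, use it in pre-norm form: $x + \mathrm{Attention}(\mathrm{Norm}(x))$, then the same pattern around the feed-forward layer. Here is a minimal sketch with hypothetical stand-in sublayers:

```python
from typing import Callable, List

Vec = List[float]
Fn = Callable[[Vec], Vec]

def pre_norm_residual(x: Vec, norm: Fn, sublayer: Fn) -> Vec:
    """x + sublayer(norm(x)): the input passes through unchanged and the sublayer adds to it."""
    return [a + b for a, b in zip(x, sublayer(norm(x)))]

def block(x: Vec, attn: Fn, ffn: Fn, attn_norm: Fn, ffn_norm: Fn) -> Vec:
    """Llama-style pre-norm block: attention sub-block, then feed-forward sub-block."""
    h = pre_norm_residual(x, attn_norm, attn)
    return pre_norm_residual(h, ffn_norm, ffn)

def identity(v: Vec) -> Vec:
    return list(v)

def zeros(v: Vec) -> Vec:
    return [0.0] * len(v)

x = [1.0, -2.0, 0.5]
print(block(x, zeros, zeros, identity, identity))  # zero sublayers leave x unchanged
```

When both sublayers output zeros, the block reduces to the identity map. One common explanation for why deep residual stacks are easy to optimize is that each block only has to learn a correction to its input.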

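4. Summary

Residual information plays an important role in speech synthesis: it contains the details and features the model has not fully processed. Handling it well can noticeably improve the quality and naturalness of the generated speech. In the Fish-Speech framework, the Dual-AR architecture and GFSQ are the mechanisms through which residual information is exploited, contributing to efficient, high-quality speech synthesis.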


Note: in the fish-speech-1.5 LM config (https://huggingface.co/fishaudio/fish-speech-1.5/blob/main/config.json), the Slow Transformer and the Fast Transformer blocks differ only in the number of layers: n_layer is 24 for the former and n_fast_layer is 4 for the latter.

DualARTransformer(
  (embeddings): Embedding(102048, 1024)
  (codebook_embeddings): Embedding(8192, 1024)
  (layers): ModuleList(
    (0-23): 24 x TransformerBlock(
      (attention): Attention(
        (wqkv): Linear(in_features=1024, out_features=1280, bias=False)
        (wo): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
      )
      (ffn_norm): RMSNorm()
      (attention_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=1024, out_features=102048, bias=False)
  (fast_project_in): Identity()
  (fast_embeddings): Embedding(1024, 1024)
  (fast_layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attention): Attention(
        (wqkv): Linear(in_features=1024, out_features=1280, bias=False)
        (wo): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
      )
      (ffn_norm): RMSNorm()
      (attention_norm): RMSNorm()
    )
  )
  (fast_norm): RMSNorm()
  (fast_output): Linear(in_features=1024, out_features=1024, bias=False)
)
637.92128 M parameters
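The printed total can be reproduced from the layer shapes in the dump. All Linear layers have bias=False, and each RMSNorm is assumed to hold a single weight vector of size `dim`. Two of the readings below are inferences from the shapes:

- `codebook_embeddings` size 8192 is read as 8 codebooks × 1024 entries.
- `wqkv` out_features 1280 is read as 1024 for queries plus 2 × 128 for keys and values. This is consistent with grouped-query attention, e.g. 2 KV heads of size 64.

```python
DIM, VOCAB, FFN, QKV_OUT = 1024, 102_048, 4_096, 1_280

def linear(in_features: int, out_features: int) -> int:
    return in_features * out_features                          # bias=False in the dump

def transformer_block() -> int:
    attention = linear(DIM, QKV_OUT) + linear(DIM, DIM)        # wqkv + wo
    feed_forward = 2 * linear(DIM, FFN) + linear(FFN, DIM)     # w1, w3 (gated FFN) + w2
    norms = 2 * DIM                                            # attention_norm + ffn_norm
    return attention + feed_forward + norms

slow = VOCAB * DIM + 8_192 * DIM + 24 * transformer_block() + DIM + linear(DIM, VOCAB)
fast = 1_024 * DIM + 4 * transformer_block() + DIM + linear(DIM, 1_024)
total = slow + fast
print(total / 1e6, "M parameters")  # 637.92128 M parameters
```

Roughly a third of the parameters (about 209M) sit in the input embedding and output projection over the 102,048-token vocabulary.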


https://github.com/fishaudio/fish-speech/blob/main/fish_speech/configs/firefly_gan_vq.yaml
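Two details in the module dump that follows can be checked with simple arithmetic:

- The `drop_path` rates of the 18 ConvNeXt blocks (depths 3, 3, 9, 3) rise linearly from 0 to 0.2. This matches the linear stochastic-depth schedule of the reference ConvNeXt code.
- The HiFiGAN head's ConvTranspose1d strides (8, 8, 2, 2, 2) give a total upsampling factor of 512, so each input frame becomes 512 waveform samples.

The 44.1 kHz sample rate used for the frame-rate line is an assumption; this note does not state it.

```python
import math

depths = [3, 3, 9, 3]              # ConvNeXt blocks per stage (channels 128, 256, 384, 512)
drop_path_rate = 0.2
n_blocks = sum(depths)             # 18
rates = [drop_path_rate * i / (n_blocks - 1) for i in range(n_blocks)]
print([round(r, 3) for r in rates])  # 0.0, 0.012, 0.024, ..., 0.188, 0.2 as in the dump

upsample_strides = [8, 8, 2, 2, 2]
hop = math.prod(upsample_strides)  # 512 samples per input frame
sample_rate = 44_100               # assumption, not taken from the config
print(hop, sample_rate / hop)      # frames per second at the assumed rate
```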

FireflyArchitecture(
  (backbone): ConvNeXtEncoder(
    (downsample_layers): ModuleList(
      (0): Sequential(
        (0): FishConvNet(
          (conv): Conv1d(160, 128, kernel_size=(7,), stride=(1,))
        )
        (1): LayerNorm()
      )
      (1): Sequential(
        (0): LayerNorm()
        (1): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
      )
      (2): Sequential(
        (0): LayerNorm()
        (1): Conv1d(256, 384, kernel_size=(1,), stride=(1,))
      )
      (3): Sequential(
        (0): LayerNorm()
        (1): Conv1d(384, 512, kernel_size=(1,), stride=(1,))
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(128, 128, kernel_size=(7,), stride=(1,), groups=128)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=512, out_features=128, bias=True)
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(128, 128, kernel_size=(7,), stride=(1,), groups=128)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=512, out_features=128, bias=True)
          (drop_path): DropPath(drop_prob=0.012)
        )
        (2): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(128, 128, kernel_size=(7,), stride=(1,), groups=128)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=512, out_features=128, bias=True)
          (drop_path): DropPath(drop_prob=0.024)
        )
      )
      (1): Sequential(
        (0): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(256, 256, kernel_size=(7,), stride=(1,), groups=256)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1024, out_features=256, bias=True)
          (drop_path): DropPath(drop_prob=0.035)
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(256, 256, kernel_size=(7,), stride=(1,), groups=256)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1024, out_features=256, bias=True)
          (drop_path): DropPath(drop_prob=0.047)
        )
        (2): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(256, 256, kernel_size=(7,), stride=(1,), groups=256)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1024, out_features=256, bias=True)
          (drop_path): DropPath(drop_prob=0.059)
        )
      )
      (2): Sequential(
        (0): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.071)
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.082)
        )
        (2): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.094)
        )
        (3): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.106)
        )
        (4): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.118)
        )
        (5): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.129)
        )
        (6): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.141)
        )
        (7): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.153)
        )
        (8): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(384, 384, kernel_size=(7,), stride=(1,), groups=384)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=1536, out_features=384, bias=True)
          (drop_path): DropPath(drop_prob=0.165)
        )
      )
      (3): Sequential(
        (0): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): DropPath(drop_prob=0.176)
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): DropPath(drop_prob=0.188)
        )
        (2): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): DropPath(drop_prob=0.200)
        )
      )
    )
    (norm): LayerNorm()
  )
  (head): HiFiGANGenerator(
    (conv_pre): FishConvNet(
      (conv): ParametrizedConv1d(
        512, 512, kernel_size=(13,), stride=(1,)
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _WeightNorm()
          )
        )
      )
    )
    (noise_convs): ModuleList()
    (ups): ModuleList(
      (0): FishTransConvNet(
        (conv): ParametrizedConvTranspose1d(
          512, 256, kernel_size=(16,), stride=(8,)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
      )
      (1): FishTransConvNet(
        (conv): ParametrizedConvTranspose1d(
          256, 128, kernel_size=(16,), stride=(8,)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
      )
      (2): FishTransConvNet(
        (conv): ParametrizedConvTranspose1d(
          128, 64, kernel_size=(4,), stride=(2,)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
      )
      (3): FishTransConvNet(
        (conv): ParametrizedConvTranspose1d(
          64, 32, kernel_size=(4,), stride=(2,)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
      )
      (4): FishTransConvNet(
        (conv): ParametrizedConvTranspose1d(
          32, 16, kernel_size=(4,), stride=(2,)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
      )
    )
    (resblocks): ModuleList(
      (0): ParallelBlock(
        (blocks): ModuleList(
          (0): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (1): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (2): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  256, 256, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
        )
      )
      (1): ParallelBlock(
        (blocks): ModuleList(
          (0): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (1): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (2): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  128, 128, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
        )
      )
      (2): ParallelBlock(
        (blocks): ModuleList(
          (0): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (1): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (2): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  64, 64, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
        )
      )
      (3): ParallelBlock(
        (blocks): ModuleList(
          (0): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (1): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (2): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  32, 32, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
        )
      )
      (4): ParallelBlock(
        (blocks): ModuleList(
          (0): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(3,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (1): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(7,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
          (2): ResBlock1(
            (convs1): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
            (convs2): ModuleList(
              (0): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (1): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,), dilation=(3,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
              (2): FishConvNet(
                (conv): ParametrizedConv1d(
                  16, 16, kernel_size=(11,), stride=(1,), dilation=(5,)
                  (parametrizations): ModuleDict(
                    (weight): ParametrizationList(
                      (0): _WeightNorm()
                    )
                  )
                )
              )
            )
          )
        )
      )
    )
    (activation_post): SiLU(inplace=True)
    (conv_post): FishConvNet(
      (conv): ParametrizedConv1d(
        16, 1, kernel_size=(13,), stride=(1,)
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _WeightNorm()
          )
        )
      )
    )
  )
  (quantizer): DownsampleFiniteScalarQuantize(
    (residual_fsq): GroupedResidualFSQ(
      (rvqs): ModuleList(
        (0-7): 8 x ResidualFSQ(
          (project_in): Linear(in_features=64, out_features=4, bias=True)
          (project_out): Linear(in_features=4, out_features=64, bias=True)
          (layers): ModuleList(
            (0): FSQ(
              (project_in): Identity()
              (project_out): Identity()
            )
          )
        )
      )
    )
    (downsample): Sequential(
      (0): Sequential(
        (0): FishConvNet(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): Identity()
        )
      )
      (1): Sequential(
        (0): FishConvNet(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): Identity()
        )
      )
    )
    (upsample): Sequential(
      (0): Sequential(
        (0): FishTransConvNet(
          (conv): ConvTranspose1d(512, 512, kernel_size=(2,), stride=(2,))
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): Identity()
        )
      )
      (1): Sequential(
        (0): FishTransConvNet(
          (conv): ConvTranspose1d(512, 512, kernel_size=(2,), stride=(2,))
        )
        (1): ConvNeXtBlock(
          (dwconv): FishConvNet(
            (conv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), groups=512)
          )
          (norm): LayerNorm()
          (pwconv1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=2048, out_features=512, bias=True)
          (drop_path): Identity()
        )
      )
    )
  )
  (spec_transform): LogMelSpectrogram(
    (spectrogram): LinearSpectrogram()
  )
)
47.065218 M parameters
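The `quantizer` in the printout is a `GroupedResidualFSQ` with 8 groups (`0-7`). Each group projects a 64-dim slice down to 4 scalars (`Linear(in_features=64, out_features=4)`) and quantizes them with a single `FSQ` layer. This is consistent with the 512 channels of `downsample`, since 8 × 64 = 512. The sketch below shows what finite scalar quantization does to one such 4-dim vector: each scalar is squashed with `tanh` and rounded to one of a fixed number of levels, so the codebook is implicit rather than learned. It is a minimal pure-Python sketch, not the fish-speech or vector-quantize-pytorch implementation. The printout does not show the per-dimension levels, so the values `(8, 5, 5, 5)` and `EPS` are assumptions. The straight-through gradient used during training is omitted.

```python
import math

# Assumption: the printout only shows 4 scalar dims per group (Linear 64 -> 4);
# the number of levels per dim is not printed, so (8, 5, 5, 5) is a hypothetical choice.
LEVELS = (8, 5, 5, 5)
EPS = 1e-3  # keeps tanh(...) * half_l strictly inside the rounding range


def bound(z: float, level: int) -> float:
    """Squash z so that rounding yields exactly `level` distinct integers."""
    half_l = (level - 1) * (1 - EPS) / 2
    offset = 0.5 if level % 2 == 0 else 0.0  # even level counts need a half-step shift
    shift = math.atanh(offset / half_l)       # makes bound(0.0) == 0.0 for even levels too
    return math.tanh(z + shift) * half_l - offset


def fsq_codes(z, levels=LEVELS):
    """Map each scalar to a digit in [0, level - 1]; no codebook lookup is needed."""
    codes = []
    for zi, level in zip(z, levels):
        q = round(bound(zi, level))  # integer in [-(level // 2), level - 1 - level // 2]
        codes.append(q + level // 2)  # shift to a non-negative digit
    return codes


def codes_to_index(codes, levels=LEVELS) -> int:
    """Mixed-radix packing: the implicit codebook has prod(levels) entries."""
    index, base = 0, 1
    for c, level in zip(codes, levels):
        index += c * base
        base *= level
    return index


def index_to_values(index: int, levels=LEVELS):
    """Unpack an index and rescale each digit back to roughly [-1, 1]."""
    values = []
    for level in levels:
        digit = index % level
        index //= level
        half_width = level // 2
        values.append((digit - half_width) / half_width)
    return values


if __name__ == "__main__":
    z = [0.3, -2.0, 0.0, 5.0]  # one hypothetical 4-dim projected vector
    codes = fsq_codes(z)
    idx = codes_to_index(codes)
    # expected: [5, 0, 2, 4] 885 [0.25, -1.0, 0.0, 1.0] 1000
    print(codes, idx, index_to_values(idx), math.prod(LEVELS))
```

With these assumed levels, each group has an implicit codebook of 8 · 5 · 5 · 5 = 1000 entries. Each quantized frame would therefore be described by 8 such indices, one per group. The two stride-2 `Conv1d` layers in `downsample` shorten the sequence by 4× around quantization. This assumes, as the module names suggest, that `downsample` runs before `residual_fsq` and `upsample` runs after it.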
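Every `ResBlock1` in the decoder's `ParallelBlock`s shares the same layout. Each has three `convs1`/`convs2` pairs with dilations 1, 3 and 5 (dilation 1 is the default and is not printed). Only the kernel size (3, 7 or 11) and the channel count (256, 128, 64, 32, 16) differ. Because every convolution has stride 1, the receptive field of each branch can be worked out by hand, and the sketch below does that arithmetic. It assumes each stage applies `convs1[i]` and then `convs2[i]` in sequence, followed by a residual add, as in HiFi-GAN's ResBlock1. The printout itself does not show the forward order. Residual adds and activations do not enlarge the receptive field, so only the convolution chain is counted.

```python
def conv_stack_receptive_field(kernel_sizes, dilations) -> int:
    """Receptive field (in samples) of stride-1 Conv1d layers applied one after another."""
    rf = 1
    for k, d in zip(kernel_sizes, dilations):
        rf += (k - 1) * d  # each layer widens the window by (k - 1) * d samples
    return rf


def resblock1_receptive_field(kernel_size: int, dilations=(1, 3, 5)) -> int:
    """Assumed ResBlock1 order: convs1[i] -> convs2[i], both using dilation d_i (as printed)."""
    kernels, dils = [], []
    for d in dilations:
        kernels += [kernel_size, kernel_size]
        dils += [d, d]
    return conv_stack_receptive_field(kernels, dils)


if __name__ == "__main__":
    for k in (3, 7, 11):
        print(k, resblock1_receptive_field(k))  # 37, 109, 181 samples
```

Under that assumption, the three branches of each `ParallelBlock` see 1 + 2(k − 1)(1 + 3 + 5) samples: 37, 109 and 181 at that stage's resolution. The parallel branches therefore cover noticeably different context lengths at the same cost in depth.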