月曜日, 6月 30, 2025
月曜日, 6月 30, 2025
- Advertisment -
ホームニューステックニュース初めてのTTSモデルをずんだもんのデータで作ってみた with LLaSA

初めてのTTSモデルをずんだもんのデータで作ってみた with LLaSA


この記事は、音声系素人がLLaSAというTTSモデルをずんだもんのデータでファインチューニングしてみたという記事です。

元々私はLLM関連をメインに開発等を行っている人間なのですが、何となく音声系に興味を持ちました。(深い理由はない)

理論の勉強の前に何か試しに作ってみたいなと思って色々と調べていたところ、LLaSAというモデルを見つけました。

https://arxiv.org/abs/2502.04128

https://huggingface.co/collections/HKUSTAudio/llasa-679b87dbd06ac556cc0e0f44

このモデルの面白いなと思った点は、モデルのアーキテクチャが完全にLLaMA 3.1(8B) / 3.2(1B、3B)と同一だということです。LLMに対して何らかのアーキテクチャ変更を加えて音声合成するのではなく、語彙サイズ以外完全に同一です。学習も単なるNext Token Predictionで、特殊なlossの導入等も不要です。

そのため、LLMを普段触っている身として学習などが非常にやりやすく、今回試すモデルとしてこれを選定しました。

筆者はTTS初心者なので用語や説明に間違いがある可能性があります。もし何かあればご指摘ください。

LLaSAはLLaSA: Scaling Train-Time and Test-Time Compute for LLaMA-based Speech Synthesisという論文で提案されたTTSのフレームワーク・モデルです。

このモデルの特徴として、完成品のモデルのアーキテクチャが語彙サイズ以外完全にLLaMA 3.1 / 3.2と同一です。そのため、元モデルの重みを最大限有効活用することが出来ますし、学習・推論時にそのままLLMのエコシステムに載せることが出来ます。

具体的には、以下のような流れでモデルの準備と学習を行います。

1. 音声データをLLMが扱える離散空間のトークンに変換するためのニューラル音声コーデックを学習

LLM的に言うと、Tokenizerの音声版のようなものでしょうか。音声→トークンの変換とトークン→音声の変換(音声合成時の操作)を両方行うことのできるコーデックを学習します。
これによって出来上がったコーデックがX-Codec2です。1秒間の音声を50フレームに切り出し、それを65536種類のコードにマッピングします。1秒の音声が50トークンに相当する形で変換されます。

2. TTSの学習データを上記コーデックにより変換

TTSの学習に使う音声データをX-Codec2によりコードに変換します。これを次に説明する処理でLLMが扱えるトークンと対応付けて学習します。
このステップで音声系の範囲から完全にLLM側の範囲に落とし込まれます。そのため、次のステップ以降は完全にLLM側の処理になります。

3. 音声トークンを扱えるようにLLMを語彙拡張

X-Codec2が音声を処理して出す結果は単なる整数列なので、これをLLMが扱えるトークンとして語彙拡張します。
具体的には、X-Codec2が出す0から65535のコードを以下のような形式のSpecial Tokensとして取り扱います。それぞれコードの0, 1, … , 65535に対応付けられます。

また、同時に以下のSpecial Tokensも追加しています。テキスト部分と音声部分の切り分けなどに用いられているようです。

この辺りは実際のtokenizer_config.jsonを見ていただけると分かりやすいと思います。

https://huggingface.co/HKUSTAudio/Llasa-1B-Multilingual/blob/main/tokenizer_config.json

4. データの前処理

これまでの処理で、音声データがLLMの扱えるトークン列に全て変換されました。
このタイミングで、テキストトークンと音声トークンを結合して1つの学習データを作成します。LLMで言うとInstruction Tuning用の形式に加工するイメージです。

以下のようにテキスト入力→それに対応する音声トークンのような形式にすることで、TTSを実現します。学習時にはLLMのSFTのようにテキスト入力部分にはloss maskがかけられ、音声トークン部分のみ学習されます。

system

Cutting Knowledge Date: December 2023
Today Date: 27 Jun 2025

user

Convert the text to speech:{input_text}assistant

...

個人的にやや疑問が残っている点は以下2つです。

  1. 元モデルのchat templateを流用しているため等がそのまま使われていますが、あまり流用する意味はない気がします。
  2. 学習コードや推論コードを見るとsystem roleを設定していないため、元モデルにあるデフォルトのsystem promptが使われています。上記の例でのCutting Knowledge Date ...あたりがそれに該当する部分です。これも同様に不要な気がします。

少し調べてみましたが結局よく分からずでした。元モデルのテキストチャット能力を残したマルチモーダルモデルのような拡張を考慮した結果などだったりするんでしょうか。

5. TTSモデルの学習

上記のデータを元に、シンプルなNext Token Predictionタスクとして学習します。lossも通常のCross Entropy Lossです。
TTSモデルの学習ではありますが普通のLLMの学習と同等なので、LLMのエコシステムに載って簡単に学習することが出来ます。HF Trainerで学習することも出来ますし、MegatronやNeMo等のフレームワークでの学習も可能なはずです。Liger-Kernelのような学習効率化手法も利用可能です。

6. 実際の推論

実際の推論は非常にシンプルで、以下のようにテキストトークン部分 + までをモデルに入力します。

system

Cutting Knowledge Date: December 2023
Today Date: 27 Jun 2025

user

Convert the text to speech:{input_text}assistant


すると、Next Token Predictionの文脈で入力テキストに対応する音声トークンに相当する追加されたSpecial Tokensが出力されます。この音声トークンを対応するコードに変換し、それをX-Codec2でdecodeして音声を形成することで、テキスト入力に対しての音声を得ることが出来ます。

X-Codec2によるdecode部分以外は通常のLlamaの推論と完全に同等です。そのため、vLLMのような高速推論フレームワークに載せて利用することも簡単に出来ます。

7. Test-Time Scalingなどの検証

TTSモデルとしてはこれで完成ですが、元論文ではTest-Time Scalingなどの発展技術等に関しても検証されています。私自身理解出来ていない部分も多く、この記事の内容ともあまり関係が無いので、ここでは説明を省きます。気になる方は元論文をご確認ください。

ここでは、音声データを元に実際にLLaSAを学習し、カスタマイズされたTTSモデルを作ってみます。

先述したように、学習はLLMのエコシステムで行うことが出来ます。今回はLLM学習フレームワークであるaxolotlで学習を行ってみます。

実行した環境は以下の通りです。

  • GPU: NVIDIA A100 80GB x 8
  • CUDA version: 12.4
  • axolotl version: axolotlai/axolotl-cloud:main-20250616-py3.11-cu124-2.6.0のDocker Imageを利用

学習データの準備

学習データとして、音声とその文字起こしに相当するテキストのペアが必要です。

今回、学習データとしてSSS合同会社様より公開されているずんだもんのマルチモーダルデータベース ©SSSのボイスデータを利用させていただきます。利用の際には規約をよくご確認ください。
このデータからROHAN4600マルチモーダルデータベースとITAコーパスマルチモーダルデータベースの朗読324、合計4924件をお借りします。規約上データは再頒布出来ないので、加工方法だけ説明します。

まず、上記リンクよりずんだもんのROHAN4600ボイスデータおよびITAコーパスマルチモーダルデータベース 朗読324のボイスデータをダウンロード・解凍します。合計4924件の.wavファイルが得られると思います。

次に、対応する文字起こしを用意します。文字起こしデータは上記データベースではなく、元コーパスのリポジトリに存在するので、そこから用意します。

これらを元に、audiotextの2つのカラムを持つHugging Faceデータセットを用意しておいてください。文字起こし中にある括弧書きの読みは不要だと思います。また、HFにデータをアップロードする場合は必ずprivateにしてください。
なお、このデータにはやや長めの無音区間が含まれるため本来前処理をした方が良いと思われますが、今回はお試しという事でそのまま利用します。

このデータを元に、X-Codec2による音声トークンへの変換・テキストトークンとの結合などを行い、実際の学習データセットを作成します。今回は学習にaxolotlを使うので、axolotlの期待するpre-tokenized datasetの形式に合わせて加工します。

データの前処理には以下のスクリプトを用いました。主にUnslothのNotebookDeep-unlearningのリポジトリを参考にo3に書かせたものです。

変更点として、参考元のスクリプトで行っているpadding処理を削除しています。
これは、axolotl側でのMultipack Samplerを用いたSample Packingによる効率的な学習を実現するためです。元の例では2048トークンまでpaddingしていたため1サンプルで1レコードしか学習できませんでしたが、Packingにより1サンプルに複数レコードを詰め込むことができ、学習の高速化が見込めます。

前処理スクリプト


"""
preprocess_llasa.py
------------------------
・XCodec2 で音声→トークン
・chat template を経由して Llasa 推論時と同じ input_ids を生成
・labels は  以降のみ学習
・可変長 JSONL を Axolotl 用に保存
"""

import argparse, json, os, sys, torch, torchaudio, numpy as np
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from xcodec2.modeling_xcodec2 import XCodec2Model
from tqdm import tqdm


p = argparse.ArgumentParser()
p.add_argument("--data", required=True, help="HF dataset name / local path / JSONL …")
p.add_argument("--split", default="train")
p.add_argument("--output", required=True)
p.add_argument("--base_model", default="HKUSTAudio/Llasa-1B-Multilingual")
p.add_argument("--codec", default="HKUSTAudio/xcodec2")
p.add_argument("--sample_rate", type=int, default=16000)
p.add_argument("--max_len", type=int, default=2048)
p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
args = p.parse_args()


tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
special = [
    "", "",
    "",  ""
]
if not all(t in tok.get_vocab() for t in special):
    tok.add_tokens([t for t in special if t not in tok.get_vocab()])


if "" not in tok.get_vocab():
    tok.add_tokens([f"{i}|>" for i in range(65536)])

codec = XCodec2Model.from_pretrained(args.codec).to(args.device).eval()

BOS, EOS = tok.bos_token_id, tok.eos_token_id
SP_S     = tok.convert_tokens_to_ids("")


ds: Dataset = load_dataset(args.data, split=args.split)


os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
written, skipped = 0, 0

def replace_token(seq, target_id, new_ids):
    """seq(list[int]) の target_id を new_ids に置換(1 回だけ)"""
    try:
        idx = seq.index(target_id)
        return seq[:idx] + new_ids + seq[idx+1:]
    except ValueError:
        return seq

with open(args.output, "w", encoding="utf-8") as fout, torch.inference_mode():
    for ex in tqdm(ds, desc="build jsonl"):
        try:
            
            if not ex.get("text"):
                raise ValueError("no text")
            text_tagged = f"{ex['text']}"
            text_ids = tok.encode(text_tagged, add_special_tokens=False)

            
            if not (ex.get("audio") and "array" in ex["audio"]):
                raise ValueError("no audio")
            wav = torch.tensor(ex["audio"]["array"]).float()
            sr  = ex["audio"].get("sampling_rate", args.sample_rate)
            if sr != args.sample_rate:
                wav = torchaudio.functional.resample(wav, sr, args.sample_rate)
            wav = wav.squeeze()
            if wav.ndim == 0:
                raise ValueError("empty audio")
            wav = wav.unsqueeze(0).to(args.device)  
            codes_np = codec.encode_code(wav)[0][0].cpu().numpy()
            speech_ids = [SP_S] + tok.convert_tokens_to_ids(
                [f"{c}|>" for c in codes_np]
            ) + [tok.convert_tokens_to_ids("")]

            
            chat = [
                {"role": "user",
                 "content": "Convert the text to speech:"},
                {"role": "assistant",
                 "content": ""}
            ]
            
            from inspect import signature
            kw = {}
            if "add_generation_prompt" in signature(tok.apply_chat_template).parameters:
                kw["add_generation_prompt"] = False
            templ_ids = tok.apply_chat_template(chat, tokenize=True, **kw)

            
            seq = replace_token(templ_ids, tok.convert_tokens_to_ids(""), text_ids)
            seq = replace_token(seq, tok.convert_tokens_to_ids(""), speech_ids)

            
            if seq[0] != BOS:
                seq = [BOS] + seq
            if seq[-1] != EOS:
                seq.append(EOS)

            if len(seq) > args.max_len:
                skipped += 1
                continue

            
            labels = seq.copy()
            sp_idx = labels.index(SP_S)
            labels[:sp_idx] = [-100] * sp_idx
            attn = [1] * len(seq)

            
            fout.write(json.dumps({
                "input_ids": seq,
                "attention_mask": attn,
                "labels": labels
            }, ensure_ascii=False) + "\n")
            written += 1

        except Exception as e:
            skipped += 1
            print(e)
            continue

print(f"✅ written : {written}")
print(f"⚠️ skipped : {skipped}")
print(f"➡️ saved   : {args.output}")

xcodec2==0.1.5transformers==4.48torch==2.5.0の環境でこれを実行し、前処理済みのjsonlファイルを書き出すところまで行いました。

学習の実行

先述した手順で前処理したデータを元に、実際にaxolotlによる学習を行います。ライブラリの概要や基本的な使い方に関しては、私が以前書いた以下記事を参考にしてください。

https://zenn.dev/aratako_lm/articles/b58ac364f9c9cd

以下のようなconfigを実際の学習に利用しました。preprocessed.jsonlは前処理済みのjsonlデータです。

学習のconfig
base_model: HKUSTAudio/Llasa-1B-Multilingual
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

hub_model_id: Aratako/Llasa-1B-Zundamon
hub_strategy: "end"
push_dataset_to_hub:
hf_use_auth_token: true


plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
liger_layer_norm: true

load_in_8bit: false
load_in_4bit: false
strict: false

chat_template: tokenizer_default

datasets:
  - path: /workspace/data/preprocessed.jsonl
    ds_type: json
    type:

dataset_processes: 128

shuffle_merged_datasets: true
dataset_prepared_path: /workspace/data/llasa-data
val_set_size: 0.05
output_dir: /workspace/models/Llasa-1B-Zundamon

sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:

wandb_project: Llasa-1B-Multilingual
wandb_entity: aratako-lm
wandb_watch:
wandb_name: Zundamon-LR1e-5-BS16
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 30
optimizer: adamw_torch
lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
learning_rate: 1e-5
max_grad_norm: 1.0

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
bfloat16: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_strategy: epoch
save_total_limit: 2
eval_strategy: epoch
eval_batch_size: 4

save_only_model: true

warmup_steps: 50
debug:
deepspeed: /workspace/axolotl/deepspeed_configs/zero1.json
weight_decay: 0.01
fsdp:
fsdp_config:
metric_for_best_model: eval_loss
load_best_model_at_end: true

このconfigを元に以下のようなコマンドでaxolotl側の前処理と学習を実行します。

axolotl preprocess config.yaml --deepspeed /workspace/axolotl/deepspeed_configs/zero1.json --debug
axolotl train config.yaml --deepspeed /workspace/axolotl/deepspeed_configs/zero1.json


学習の様子

無事学習が終わるとoutput_dirに学習済みモデルが保存されます。

学習後のモデルで実際に推論してみます。

推論コードの例は元モデルのリポジトリに存在するので、それを参考にして以下のようなスクリプトを書いて実行してみます。
model_pathは自身のモデルに差し替えてください。

推論コード例(transformers)
import soundfile as sf
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from xcodec2.modeling_xcodec2 import XCodec2Model


model_path = "Aratako/Llasa-1B-Zundamon"





tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
model.eval().cuda()

codec_path = "HKUSTAudio/xcodec2"

codec_model = XCodec2Model.from_pretrained(codec_path)
codec_model.eval().cuda()

input_text = "ボク、ずんだもんなのだ!よろしくなのだ!"


def ids_to_speech_tokens(speech_ids):

    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):

    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith(") and token_str.endswith("|>"):
            num_str = token_str[4:-2]

            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


with torch.no_grad():

    formatted_text = (
        f"{input_text}"
    )

    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": ""},
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, return_tensors="pt", continue_final_message=True
    )
    input_ids = input_ids.to("cuda")
    speech_end_id = tokenizer.convert_tokens_to_ids("")

    
    outputs = model.generate(
        input_ids,
        max_length=2048,
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,
        temperature=0.8,
    )

    generated_ids = outputs[0][input_ids.shape[1] : -1]

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    gen_wav = codec_model.decode_code(speech_tokens)


sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)

実際の合成結果が以下の通りです。テキストはボク、ずんだもんなのだ!よろしくなのだ!です。

学習前の1Bモデル(HKUSTAudio/Llasa-1B-Multilingual)

https://soundcloud.com/aratako/hkustaudiollasa-1b-multilingual?si=4456785ca67b4d089289047143241602&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

学習後の1Bモデル

https://soundcloud.com/aratako/aratakollasa-1b-zundamon?si=9210b75158e4454ab944e35c11744955&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

学習により、合成音声がずんだもんのようになったことが確認できました!やや発音等がおかしいのは元モデルの日本語データの学習不足が原因だと思われます。

参考までに、3B・8Bでの学習前後のモデルでの合成結果も貼っておきます。

学習前の3Bモデル(HKUSTAudio/Llasa-3B)

https://soundcloud.com/aratako/hkustaudiollasa-3b?si=380a95ba47064d6aab4da3b2ab8b570c&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

学習後の3Bモデル

https://soundcloud.com/aratako/aratakollasa-3b-zundamon?si=b11555d795884e85a2ac7c8cc4bc487f&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

学習前の8Bモデル(HKUSTAudio/Llasa-8B)

https://soundcloud.com/aratako/hkustaudiollasa-8b?si=f7ddaad1e67f4a178ac22f68acfbf21e&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

学習後の8Bモデル

https://soundcloud.com/aratako/aratakollasa-8b-zundamon?si=6e82536a4fea49c18fe9b4be5d79d984&utm_source=clipboard&utm_medium=text&utm_campaign=social_sharing

また、これも参考までにですが、vLLMを使った簡単な推論スクリプトも貼っておきます。

推論コード例(vLLM)
import soundfile as sf
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from xcodec2.modeling_xcodec2 import XCodec2Model






model_path = "Aratako/Llasa-8B-Zundamon"

codec_path = "HKUSTAudio/xcodec2"


def ids_to_speech_tokens(speech_ids):

    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):

    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith(") and token_str.endswith("|>"):
            num_str = token_str[4:-2]

            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def main():
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LLM(model=model_path, max_model_len=2048, gpu_memory_utilization=0.8)
    codec_model = XCodec2Model.from_pretrained(codec_path)
    codec_model.eval().cuda()
    input_text = "ボク、ずんだもんなのだ!よろしくなのだ!"

    formatted_text = (
        f"{input_text}"
    )

    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": ""},
    ]

    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, continue_final_message=True
    )
    print(prompt)
    speech_end_token = ""

    sampling_params = SamplingParams(
        temperature=0.8, top_p=1.0, max_tokens=1024, stop=[speech_end_token]
    )

    
    outputs = model.generate([prompt], sampling_params)

    generated_ids = outputs[0].outputs[0].token_ids

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    gen_wav = codec_model.decode_code(speech_tokens)

    sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)


if __name__ == "__main__":
    main()

この記事では、LLaSAというTTSモデルをずんだもんのデータでファインチューニングしてみました。結果として、ある程度それっぽい音声を合成可能になる事が確認できました。

まだまだ音声系はド素人なので、他のアーキテクチャ含め色々試して勉強していきたいと思います。

仕組み上どんなLLMをベースにしても同じ流れでTTSモデルが作成可能なので、ベースモデルの違いによる差が気になりました。Llamaは日本語が弱い印象なのでもっと日本語に強いモデルから始めた方が良い結果が得られたりするんでしょうか。

また、LLM-BasedなのでASRを上手く併用するとGRPOが出来るらしいです。面白そうなのでいつかやってみたい。

https://www.reddit.com/r/LocalLLaMA/comments/1l7pmua/grpo_can_boost_llmbased_tts_performance/



Source link

Views: 0

RELATED ARTICLES

返事を書く

あなたのコメントを入力してください。
ここにあなたの名前を入力してください

- Advertisment -