見出し画像

Google Colabでの日本語Mambaの事前学習

はじめに

昨年(2023年)末にMambaアーキテクチャが公開されました。
MambaはS4などと同様の状態空間モデルというもので、Transformerと比べて、

  • 高速な推論

  • シーケンス長が伸びた際のメモリ効率の良さ

  • 単純なモデル性能の良さ

で優れている様です。
日本語モデルがないので、日本語Mambaの事前学習のコードを作成しました。Google colabで動くことは確認したもののA100(40B)でも15時間近くかかるので実質最後までは実行できないです。コードの参考にしていただければ幸いです。

Paperspace gradientなどを使ってそのうち最後まで実行してみようと考えています。


1. 準備

Google colabにCheckpointを保存するためにGoogle Driveにマウントします。

# Google ColabとGoogle Driveを接続
from google.colab import drive
drive.mount('/content/drive')

# 事前学習データの保存先
drive_dir = '/content/drive/your_path'

ライブラリ等のインストール。まずは Mambaを動かすためのセットをインストール。

!pip install causal-conv1d>=1.1.0
!pip install mamba-ssm

その他

!pip install --upgrade pip setuptools wheel
!pip install accelerate datasets transformers wandb
!pip install apache-beam

2. データやモデルの準備

ライブラリimport。

from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import torch
import os

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

from datasets import load_dataset

from dataclasses import asdict
import json

transformersのTrainerを使うため、Mambaのモジュールを継承したクラスを作成します。

# MambaConfig をラップするクラス (to_dictとto_json_stringのため)
class MambaConfigForTrainer:
    def __init__(self, **kwargs):
        self.config = MambaConfig(**kwargs)
    
    def to_dict(self):
        return asdict(self.config)

    def to_json_string(self):
        return json.dumps(self.to_dict(), indent=4)
    
    def __getattr__(self, item):
        try:
            return getattr(self.config, item)
        except AttributeError:
            raise AttributeError(f"'MambaConfigForTrainer' object has no attribute '{item}'")
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids)[0]

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

Tokenizerの定義。

# tokenizerの定義
# 後々、事前学習のデータにllm-jpが公開している下記スクリプトを使用しようとしているため、tokenizerもllm-jpのものを使用。実際なんでも良い。
# スクリプト: https://github.com/llm-jp/llm-jp-corpus/tree/v2.0.0
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-v1.0")

Modelの定義。

# MambaConfigの定義 (130M)
mamba_config = MambaConfigForTrainer(
    d_model = 768,
    n_layer = 24,
    vocab_size = len(tokenizer),
)

# modelの定義
model = MambaLMHeadModel(
    config = mamba_config,
    device = "cuda",
)

上記は論文中の最小のモデルサイズとしているがメモリに余裕がある場合は他でも良いです。
Paramaters:

Parameters | Layers | Model dim.
130M           | 24        | 768
370M           | 48        | 1024
790M           | 48        | 1536
1.4B              | 48        | 2048
2.8B              | 64        | 2560

state-spaces / mamba

データ取得。事前学習はllm-jpが公開しているスクリプトを使用しようと考えていますが、一旦ある程度整備されていそうなrange3/wiki40b-jaを使用してみました。

# データ取得
wiki_dataset = load_dataset("range3/wiki40b-ja")
wiki_dataset

3. 学習の実行

データセットのトークン化。

from transformers import DataCollatorForLanguageModeling

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=True)

# データセットをトークン化
tokenized_datasets = wiki_dataset.map(tokenize_function, batched=True, remove_columns=["wikidata_id", "version_id"])

# データコレーターの準備
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Trainingの設定。

args = TrainingArguments(
    output_dir=drive_dir + "/checkpoints",
    report_to="wandb",
    save_strategy="epoch",
    save_total_limit=10,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=10,
)

trainer = MambaTrainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
)

なぜかtrainer.train()時にlibcuda.so not foundというエラーが出たため、以下を実行しておきます。

# 参照:https://zenn.dev/selllous/articles/transformers_pretrain_to_ft
# cudaバージョンを確認する 
!nvcc --version # refer:https://github.com/pytorch/pytorch/issues/107960 

# trainer.train()時のlibcuda.so not foundエラーの回避策 
!ldconfig /usr/lib64-nvidia

学習の実行。

trainer.train()

4.モデル保存

# 最後のモデルを保存
trainer.save_model(drive_dir)

以上です。

参照


この記事が気に入ったらサポートをしてみませんか?