見出し画像

rinna-3.6Bをオリジナル小説でLoRAファインチューニングしてみた【RTX3060 (VRAM 12GB)】

動作確認のために、お試しでやってみました。

概要

背景

AITuberを含めた創作活動への活用のためにrinna-3.6Bでのファインチューニングを勉強したかったのですが、せっかくなら持ってるRTX3060を使ってローカルでやりたいと思っていました。

偉大なる先駆者の方々によって方法が開拓されていたので、ありがたく参考にさせていただいた次第です。

本記事でやったこと

・ローカルのRTX3060 (VRAM 12GB)でrinna-3.6BのLoRAファインチューニングを実行。
・オリジナルの小説をデータセットとして利用。文の続きを書けるか試す。

参考記事

以下の記事を参考にさせていただきました。ほとんどそのままです。環境構築など、詳細な手法はこれらの記事を参照してください。

実装

環境構築

上記の参考記事そのままです。

データセットの準備

1)筆者のオリジナルの小説である『ミューズ・クロニクル ―第十七学芸課は眠らない―』(約16万字)を使用しました。(古い下手くそな文章ですが、長編はこれくらいしか書いたことがないので……)

2)本文のみを抽出してテキストファイルに保存。

3)下記コードにてjsonに保存。今回は、5行分を入力として、続きの5行を出力するようにデータを整形しました。

import pandas as pd
import json

with open('./data/N9764DK_text.txt', encoding='utf8') as f:
    raw_text_lines = f.readlines()

# 空行を削除
text_lines = []
for line in raw_text_lines:
    if line != "\n":
        text_lines.append(line)
        
result = []
instruction_text = ""
output_text = ""

instruction_num = 5 # 入力する小説の行数
output_num = 5 # 出力する小説の行数
multiline_num = instruction_num + output_num

instruction_count = 0
instruction_flg = True

for i in range(len(text_lines)):
    if i % multiline_num == 0:
        formatted = {
            "input": instruction_text,
            "completion": output_text
        }
        result.append(formatted)

        instruction_text = ""
        output_text = ""
        instruction_flg = True
        
        instruction_text += text_lines[i]
    else:
        if instruction_flg:
            instruction_count += 1
            instruction_text += text_lines[i]
            if instruction_count == instruction_num-1:
                instruction_flg = False
                instruction_count = 0
        else:
            output_text += text_lines[i]

with open('./data/formatted_muse_chronicle.json', 'w', encoding='utf-8') as f:
    json.dump(result, f, indent=4)

ファインチューニング

下記コードを実行。上記の参考記事をほとんどそのまま使わせていただいております。

モデルは、japanese-gpt-neox-3.6bを使用。

メモリ使用量は最大7.9GB程度。使用率は100%が体感7割くらいで、時々数十~数%に落ちる感じ。30エポックで1時間半程度かかりました。

import os
import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

model_name = "rinna/japanese-gpt-neox-3.6b"
dataset = "./data/formatted_muse_chronicle.json"
peft_name = "lora-rinna-3.6b"
output_dir = "lora-rinna-3.6b-results"


tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

CUTOFF_LEN = 256

def tokenize(prompt, tokenizer):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
    )
    return {
        "input_ids": result["input_ids"],
        "attention_mask": result["attention_mask"],
    }


# データセット
import json

with open(dataset, "r", encoding='utf-8') as f:
    data = json.load(f)

def generate_prompt(data_point):
    result = f"""### 指示:
{data_point["input"]}

### 回答:
{data_point["completion"]}
"""
    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result

train_dataset = []
val_dataset = []

for i in range(len(data)):
    if i % 5 == 0:
        x = tokenize(generate_prompt(data[i]), tokenizer)
        val_dataset.append(x)
    else:
        x = tokenize(generate_prompt(data[i]), tokenizer)
        train_dataset.append(x)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

lora_config = LoraConfig(
    r= 8, 
    lora_alpha=16,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

model = prepare_model_for_int8_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

eval_steps = 200
save_steps = 200
logging_steps = 20

trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    args=transformers.TrainingArguments(
        num_train_epochs=30,
        learning_rate=3e-4,
        logging_steps=logging_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        output_dir=output_dir,
        save_total_limit=3,
        push_to_hub=False,
        auto_find_batch_size=True
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

model.config.use_cache = False
trainer.train() 
model.config.use_cache = True

trainer.model.save_pretrained(peft_name)

推論の実行

下記コードを実行。こちらも上記の参考記事ほぼそのままです。複数文を出力するために、最後のEOSトークンまでデコードしています。

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "rinna/japanese-gpt-neox-3.6b"
peft_name = "lora-rinna-3.6b"
output_dir = "lora-rinna-3.6b-results"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

model = PeftModel.from_pretrained(
    model,
    peft_name,
    # device_map="auto"
)

model.eval()

def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""### 指示:
{data_point["instruction"]}

### 入力:
{data_point["input"]}

### 回答:
"""
    else:
        result = f"""### 指示:
{data_point["instruction"]}

### 回答:
"""

    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result

def generate(instruction, input=None, maxTokens=256) -> str:
    prompt = generate_prompt({'instruction': instruction, 'input': input})
    input_ids = tokenizer(prompt,
                          return_tensors="pt",
                          truncation=True,
                          add_special_tokens=False).input_ids.cuda()
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()

    # 最後のEOSトークンまでデコード
    if tokenizer.eos_token_id in outputs:
        eos_list = [i for i, x in enumerate(outputs) if x == tokenizer.eos_token_id]
        decoded = tokenizer.decode(outputs[:eos_list[len(eos_list) - 1]])

        sentinel = "### 回答:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc + len(sentinel):]
            return result.replace("<NL>", "\n")  # <NL>→改行
        else:
            return 'Warning: Expected prompt template to be emitted.  Ignoring output.'
    else:
        return 'Warning: no <eos> detected ignoring output'

query = """「なら、スカーフが何色だったか分かるかい?」
「……それは、そのー、つまり、優雅で繊細な色合いでしたね」
「私は何色だったかと聞いているんだよ」
「まぁ、それは言うだけ野暮ってもんじゃないですか」
 そこにハルが口を挟んだ。"""
print(query)

for i in range(3):
    print("**********")
    print(generate(query))
    print("**********")

結果

データセットとした小説内に含まれる文を入力にしました。ちなみに、この文は以下のように続きます。

入力部分:
「なら、スカーフが何色だったか分かるかい?」
「……それは、そのー、つまり、優雅で繊細な色合いでしたね」
「私は何色だったかと聞いているんだよ」
「まぁ、それは言うだけ野暮ってもんじゃないですか」
 そこにハルが口を挟んだ。

続き:
「あなたも見えてないんですね?」
「やはりそうか。君達にも見えてなかったんだね」
「先に言ってくださいよ」
 ジムは、どっと襲ってきた疲れで今にも崩れ落ちそうだった。
「何があったんですか?」

この場面では、子供がいるかのように対応する店主に対して、子供が見えなかった主人公ハルと同僚のジムは困惑しています。子供のスカーフの色を質問する店主に、ジムは空気を読んで子供が見えていたかのように振る舞いますが、ハルは店主にも子供が見えていないのだと推察するという流れになっています。

今回は、データセットの整形時に5行を入力、続きの5行を出力としているため、全体の約半分は入力用データになっていません。さらにそこから訓練データとテストデータに分けているので、今回の入力部分がそっくりそのまま学習された確率は低そうです。

とはいえ「スカーフ」という単語はこのエピソードにしか出てこないので、学習されていればうまく引っかかるのではないかと期待しました。また、見えているかのような会話から逆に見えていないことを質問するという流れが再現できるかも気になりました。

結果は以下の通りです。3回繰り返して生成を試しています。

**********
「あなた、何かやましいことでも?」〈NP〉 その言葉に、ペトラは顔を青くして俯き、そして顔を上げ、自らの潔白を証明するように頷いた。〈 NL 〉「いや、そんなことはないけどさ。どうして『色』を聞いたのかしら?」“聞きたくなかった”</NW>(違う、そうじゃない。俺の思い違いだ)< NP> パァン。銃弾が皮膚をえぐるように、朝の空気が鼓膜を揺さぶった。
**********
**********
「あなた、何かやましいことでもしたんですか?」〈NP〉<NA>(なんだと?)< NL 一見、ハルの悪戯だったが、しかしジムは答えるのを躊躇った。もしそれが事実ならハルは学芸員ではない。つまりハル自身が犯人ということになる。しかしどちらも泥臭い匂いのするのはハル自身のはずだった。ジムはその可能性を考えてみたが、同時に妙な焦りも覚えていた。こんな時に出てこないような汚れ仕事は、ジムには勤まらない。きっと本棚の奥にしまってあるはずのメモを見落としたのだ。ここは強気に尋ねるべき場面だろう。”Now”と
**********
**********
「あなた、何かやましいことでも?」〈Nl〉「ち、違うとは言ってませんよ!」< NL > 必死にハルは否定したが、それが嘘だというのはジムには分かっていた。か細い腕がしっかりとハルの背中を掴んでいる。もうジムの言葉も耳に入らないようだった。〈 Nl 〉:#例 : <NM> 声がしたのは、後ろからだったと思う。通路をこちらに振り返り、名を呼ばれた。振り向くまで気付かなかったが、どうやらそこはハルがいた
**********

『「あなた、何かやましいことでも……』から続く文章が生成されましたが、学習データを確認したところ、そのような文章は含まれていませんでした。そのあとの文章も、意味的つながりが弱くなっている印象です。

一方、良い点としては、上手く言い逃れようとしているジムに対してハルが何かを言おうとしているという文脈にも捉えられるので、「何かやましいことでも?」と尋ねる言葉が来る可能性は高いかもしれません。

文体も(他の人には分かりづらいかもしれませんが)筆者らしさが表現されているように感じます。特に私の文章の特徴の一つとして、Twitter小説を執筆していた影響から短い文章で簡潔に表現する傾向があります。

特に下記の表現は簡潔でありつつ比喩的な想像力を感じる文章で、評価できると思います。

『銃弾が皮膚をえぐるように、朝の空気が鼓膜を揺さぶった。』
『きっと本棚の奥にしまってあるはずのメモを見落としたのだ。』

ちなみに「パァン」は作中で銃の発射音として一度使われていました。「ペトラ」は登場人物の一人です。

おわりに

ひとまず動作確認ができてよかったです。

3.6Bなのでまだまだパラメータ数が小さいですが、いずれより大きなパラメータ数のモデルを広く普及しているGPUで使えるようになると思うので、今後も様々な活用方法を探っていきたいと思います。

この記事が気に入ったらサポートをしてみませんか?