見出し画像

gpt-index(0.2.5)をOpenAI APIなし&日本語で動かす

gpt-indexは長いコンテキストに対してQAを行えるフレームワークです。
デフォルトではOpenAIのAPIを利用するので無邪気に長いコンテキストに質問を投げているとすぐ数$の請求になって焦りますね。

今回はローカルでオープンな日本語モデルを使って動かす方法をご紹介します。
あくまで試みであり、正答率もいまひとつで実用性があるものではありませんが、学習データセットを作るコード、モデル学習コード、gpt-indexを実行するコードはこのリポジトリに置いています。
https://github.com/oshizo/gpt_index_japanese_trial

1/18のツイートで投稿したツリーをもう少し詳しく説明する内容です。

QAモデル

OpenAIのAPIを使う場合、コンテキストに対するQA応答は命令チューニングをした大規模モデル(Text-Davinci-003)で実行します。
これはQAに特化したモデルではありませんが、以下のようなプロンプトを使うだけでうまく応答ができる強さを持っています。

# gpt-index default_prompts.py
DEFAULT_TEXT_QA_PROMPT_TMPL = (
    "Context information is below. \n"
    "---------------------\n"
    "{context_str}"
    "\n---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the question: {query_str}\n"
)

日本語訳すると、このようなプロンプトです。

DEFAULT_PROMPT = """
文脈情報は以下です。
---
{context_str}
---
事前知識ではなく、文脈情報を参考に質問に答えてください。:{query_str}
"""

オープンかつ日本語対応しているモデルでは、promptのみでQAに対応することは難しそうなので、今回はファインチューニングを行うことにします。

QAモデルのファインチューニングの概要

gpt-indexでは、モデルの許容するmax_lengthより長いコンテキストにQA応答するための仕組みがあります。
一例としては、長いコンテキストを複数に分割しておき、埋め込みモデルでQueryに近そうなものをいくつか探して順に問い合わせていく方法があります。

https://gpt-index.readthedocs.io/en/latest/guides/index_guide.html

これを実現するために、前のコンテキストでの回答を踏まえ、次のコンテキストで必要に応じ解答をリファインしていく機能が必要です。
具体的には、以下のようなプロンプトに対する推論ができるモデルを用意します。

# gpt-index default_prompts.py DEFAULT_REFINE_PROMPT_TMPL の日本語版
REFINE_PROMPT = """
質問は以下です。:{query_str}
すでに答えの候補があります。:{existing_answer}
必要な場合のみ、以下の文脈情報を使ってこの答えを改良することができます。
---
{context_msg}
---
この文脈情報により、元の答えを改良して質問に答えてください。
文脈情報が有用でない場合は元の答えをそのまま返してください。
"""

通常の抽出型QAの推論に加えて、このプロンプトに沿って推論できるモデルを学習するデータセットを作っていきます。

データセットの準備

日本語の文脈付きQAデータセットは、知る限り以下の3つがあります。
今回はJSQuADとJaQuADを使いました。

  • JGLUEのJSQuAD … 62,859件

  • JaQuAD … 31,748件

  • TyDi QAの日本語サブセット … 16,288件(抽出回答があるのはうち32%)

リファインの学習のために、2つのQAをランダムに取り出してREFINE_PROMPTに埋め込むことで、実際のクエリ時に起こる様々なパターンをカバーした学習データを作ります。
以下は、誤った解答「商業」を「手塚治虫」にリファインする学習データの構成例です。

このpromptに対し「手塚治虫」を出力するように学習

具体的には、以下のパターンを合計70k件ほど作成しました。

  1. 答えを得ていない(最初のcontext)場合に、適切なcontextで答えを出力する

  2. 答えを得ていない(最初のcontext)場合に、適切でないcontextで答えを出力しない(「分かりません」と出力)

  3. 正しい答えを得ている場合に、適切でないcontextで答えを変更しない

  4. 正しい答えを得ている場合に、適切なcontextで答えを変更しない

  5. 正しい答えを得ていない場合に、適切でないcontextで答えを変更しない

  6. 正しい答えを得ていない場合に、適切なcontextで答えを変更する

学習実施

学習は rinna/japanese-gpt-1b をベースモデルとして、通常のpromptの続きを生成する形での学習をしました。
(今回は抽出型QAのみを行ったので、生成モデルを使う必要はなくBERTなどでもよかったと思います。また、タスクとしてはシンプルなのでmediumサイズで十分だったかもしれません。)

環境は Google Colab Pro の A100 を使って3時間学習させました(500円ぐらい)。

モデルはoshizo/qa-refine-japanese-gpt-1bに置いてはいますが、QAモデルとしての精度評価はしていないので良いQAモデルかは分かりません。
末尾で例を挙げますが、リファインの性能は良くないです(正しい答えを、無関係なコンテキストで誤った答えに修正してしまうことが多い)。

埋め込みモデル

今回は、以前学習させておいたこのモデルを使用しました。日本語のSentence-Transformersモデルが必要です。

gpt-indexの設定

gpt-indexはかなりのカスタマイズ性があり、ほぼ用意されたIFでローカルモデルを使うことができます。
設定の渡し方はv0.2.5時点ではちょっと分かりづらい所があるので少し解説します。

LLMPredictorを作る

QAモデルを使用するLLMPredictorは、promptを引数に答えを返す関数を定義することでカスタマイズできます。
まずQAモデルで答えを生成する関数を作っておきます。

from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
qa_model = AutoModelForCausalLM.from_pretrained("./models/qa-refine-japanese-gpt-1b")

def generate(prompt):

    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    n = len(token_ids[0])

    with torch.no_grad():
        output_ids = qa_model.generate(
            token_ids.to(qa_model.device),
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # 出力から回答部分のみを切り出す
    output = tokenizer.decode(output_ids.tolist()[0][n:])
    return output.replace("</s>", "")

langchain.llms.base.LLMを継承して、_callに先ほどの関数を設定します。
このクラスのインスタンスを渡してLLMPredictorを初期化すればOKです。

from gpt_index import LLMPredictor
from langchain.llms.base import LLM
from typing import Optional, List, Mapping, Any
class CustomLLM(LLM):
        
    @property
    def _llm_type(self) -> str:
        return "custom"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        return generate(prompt)
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name":"custom"}
llm = CustomLLM()
llm_predictor = LLMPredictor(llm)

埋め込みモデル

sentence-transformersを使うためのクラスHuggingFaceEmbeddingsの引数でmodel_idを指定します。
(作った埋め込みモデルのmax_lengthを128にしてしまっており…ちょっと足りなそうなので上書きしています)

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from gpt_index import LangchainEmbedding
from sentence_transformers import SentenceTransformer

# いったんall-mpnet-base-v2で初期化
embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name="oshizo/sbert-jsnli-luke-japanese-base-lite"))
# max_seq_length上書き
embed_model._langchain_embedding.client.max_seq_length = 256

インデックス作成

まずデータを取ってきます(ぼざろのWikipediaページ全体)。
WikipediaのURLをコピペして?action=cirrusdumpを付ければ他のページでも実行できます。

import urllib
import json
with urllib.request.urlopen("https://ja.wikipedia.org/wiki/%E3%81%BC%E3%81%A3%E3%81%A1%E3%83%BB%E3%81%96%E3%83%BB%E3%82%8D%E3%81%A3%E3%81%8F!?action=cirrusdump") as f:
    data = f.read()
text = json.loads(data)[0]["_source"]["text"]

200文字ずつのチャンクに、100文字ずつ開始位置をスライドさせながら分けてDocumentを作成しました。
長くするとembedding計算が不正確になって適切なcontextが選ばれづらくなり、逆に短くすると1つのcontextに同時に含まれる情報が減って質問に答えられなくなるトレードオフがあると思います。

from gpt_index import GPTSimpleVectorIndex
from gpt_index.readers.schema.base import Document

documents = []
for i in range(0, len(text), 200):
    documents.append(Document(text[i:i+200]))
    if i != 0:
        documents.append(Document(text[i-100:i+100]))

約400チャンクに分割されました。
最後に、QAモデル、埋め込みモデル、Documentのリストを与えてインデックスを作ります。

from gpt_index import GPTSimpleVectorIndex

index = GPTSimpleVectorIndex(documents, llm_predictor=llm_predictor, embed_model=embed_model)

クエリ

index.queryには、カスタマイズのための引数を渡すことができます。
v0.2.5の時点ではAPIリファレンスには「**kwargs – Additional kwargs to pass to the index constructor.」と書かれていますが、実際になにが渡せるかはコードを眺める必要があります。

  • text_qa_template

  • refine_template

この二つの引数で、質問応答に使うpromptのテンプレートを渡すことができます。
実装上は、BaseGPTIndex.query, QueryRunner.query, BaseGPTIndexQuery, ResponseBuilder.refine_response_singleと受け渡されてLLMPredictorの呼び出しに使われています。

  • similarity_top_k

この引数で、embed_modelで探してきたコンテキストの類似度topK個を使うように設定できます。
多くすると、embed_modelが最適なコンテキストをtopkの中に見つけられればよくなります。問い合わせ対象のテキストのチャンクが多い場合や、embed_modelの精度に自信がない場合は有効そうです。
(OpenAI APIを使う場合はtext-davinci-003にk回問い合わせを投げるので、増やしすぎに注意)
実装上は、↑の処理中のBaseGPTVectorStoreIndexQueryのコンストラクタで受け取って_get_nodes_for_responseで使われています。


プロンプトテンプレート文を使って、Promptインスタンスを作っておきます。

from gpt_index.prompts.prompts import RefinePrompt, QuestionAnswerPrompt
refine_prompt = RefinePrompt(REFINE_PROMPT)
default_prompt = QuestionAnswerPrompt(DEFAULT_PROMPT)

これらを渡して、index.queryを呼び出します。


response = index.query(
    "虹夏ちゃんのお姉さんの職業はなに?",  
    mode="embedding", 
    verbose=True, 
    embed_model=embed_model,
    text_qa_template=default_prompt,
    refine_template=refine_prompt,
    similarity_top_k=3
)
print(response)

> Top 3 nodes:
> [Node 40aa0d84-e8e6-4102-8bab-ce864fc34bf2] て作曲の経験もあり、ひとりにアドバイスを与えた。 ひとりの動画投稿サイトでの収益を貯金していて、ギターを壊してしまったひとりにそのお金を与えている。 目はいつも何か(髪の毛や、ふたりの手など)に...
> [Node ea6fdcea-cd80-4696-af1f-d9ca8334aa00] 束バンドの後藤ひとり」であることは公にはされていない。 伊地知 虹夏(いじち にじか) 声 - 鈴代紗弓 誕生日:529日 / 血液型:A型 ドラム担当。下北沢高校2年→3年。明るく世話焼きで...
> [Node 7a1d891a-4450-4269-b23c-c7077ee9feed] に遠ざけられ、結局ひとりが作曲したものをリョウが編曲することになった。 虹夏と喜多によると(普段の奇行から忘れがちだが)美少女。文化祭のクラスの出し物でメイド服を着させられたときは、皆に似合って...
> Searching in chunk: て作曲の経験もあり、ひとりにアドバイスを与えた。 ひとりの動画投稿サイトでの収益を貯金していて...
> Searching in chunk: 束バンドの後藤ひとり」であることは公にはされていない。 伊地知 虹夏(いじち にじか) 声 -...
> Searching in chunk: に遠ざけられ、結局ひとりが作曲したものをリョウが編曲することになった。 虹夏と喜多によると(普...
> Initial response: 後藤 美智代
> Refine context: 束バンドの後藤ひとり」であることは公にはされていない。 伊地知 虹夏(いじち にじか) 声 -...
> Refined response: ライブハウスの店長
> Refine context: に遠ざけられ、結局ひとりが作曲したものをリョウが編曲することになった。 虹夏と喜多によると(普...
> Refined response: ライブハウスの店長
> [query] Total LLM token usage: 1436 tokens
> [query] Total embedding token usage: 30 tokens
'ライブハウスの店長'

similarity_top_kを3にすることで、1つ目のコンテキストでは誤った解答をしているところを、2つ目のコンテキストでリファインに成功していることが分かります。

次の例のように、「この作品のあらすじを教えて?」など抽出型QAで答えられない質問は今回のモデルではうまく応答できません。
OpenAI APIを使うとある程度うまくこたえてくれると思うので、面白さに差があると感じます。抽出型以外を含む幅広いQAにうまく応答できるモデルを作るのは…どうすればいいんでしょうか。

response = index.query(
    "この作品のあらすじを教えて?", 
    mode="embedding", 
    verbose=True, 
    embed_model=embed_model,
    text_qa_template=default_prompt,
    refine_template=refine_prompt,
    similarity_top_k=3
)
response.response

> Top 1 nodes:
> [Node bf228da7-5efb-4322-a50f-e220f38b680b] 込む際に、ひとりの会話が視聴者に聞かせるための内容か音として鳴っているだけのものか、精査しながら場面構成を整えていく手法がとられている。また、原作に沿った内容でありながらもギャグの勢いをアニメだ...
> Searching in chunk: 込む際に、ひとりの会話が視聴者に聞かせるための内容か音として鳴っているだけのものか、精査しなが...
> Initial response: ギャグの勢いをアニメだからこそできる表現
> [query] Total LLM token usage: 411 tokens
> [query] Total embedding token usage: 21 tokens
'ギャグの勢いをアニメだからこそできる表現'


次の例ではリファイン処理に失敗して、Initial responseでは正しくこたえられているのに次のコンテキストで誤った答えにしてしまっています。

response = index.query(
    "後藤ひとりがギターに熱中するようになった理由は?", 
    mode="embedding", 
    verbose=True, 
    embed_model=embed_model,
    text_qa_template=default_prompt,
    refine_template=refine_prompt,
    similarity_top_k=3
)
response.response
> Top 3 nodes:
> [Node 43ee6b50-eb6b-444f-ad8c-621809d48f6a] の時に、暗い学生時代から一転してスターとなったバンドマンのインタビューを目にしたことで、父親から借りたギターに没頭する。毎日6時間以上の練習を約3年間欠かさず行ってきたため、個人としての演奏技術...
> [Node 5f1d20a4-5dfa-4335-83cd-e4c5063be3e4] 「心が乙女なおっさん」。好きな音楽ジャンルはパンクロック。 ぽいずん♡やみ 誕生日:214日 / 血液型:B型 音楽情報サイトに寄稿するフリーライター。本名は佐藤 愛子(さとう あいこ)。23...
> [Node 39f627ae-846e-418b-9c9b-348be5e1c7ac] ィング音源を使用しているほか、ひとりが行った奏法で演奏している。この楽曲は演奏中にひとりの1弦が切れてしまうシーンで使用されているため、この展開を踏まえて編曲を行い、ギターソロ後は1弦を使わなく...
> Searching in chunk: の時に、暗い学生時代から一転してスターとなったバンドマンのインタビューを目にしたことで、父親か...
> Searching in chunk: 「心が乙女なおっさん」。好きな音楽ジャンルはパンクロック。 ぽいずん♡やみ 誕生日:214...
> Searching in chunk: ィング音源を使用しているほか、ひとりが行った奏法で演奏している。この楽曲は演奏中にひとりの1...
> Initial response: 暗い学生時代から一転してスターとなったバンドマンのインタビューを目にしたこと
> Refine context: 「心が乙女なおっさん」。好きな音楽ジャンルはパンクロック。 ぽいずん♡やみ 誕生日:214...
> Refined response: 癖
> Refine context: ィング音源を使用しているほか、ひとりが行った奏法で演奏している。この楽曲は演奏中にひとりの1...
> Refined response: 癖
> [query] Total LLM token usage: 1463 tokens
> [query] Total embedding token usage: 35 tokens
'癖'

今回学習したモデルのリファイン性能は体感としてかなり悪く、一度正解しても容易に答えを変更してしまいます。
kを増やすと応答履歴のどこかに正解があってもほとんどの場合最終的に不正解になってしまいます…。

これには二つの理由があると思います。

  • リファイン時には新しいコンテキストしか与えられず、元の答えの時に参照したコンテキストを与えられないので、2つのコンテキストを比較して判断することができない

  • 学習データとしてはランダムなコンテキストをピックアップして学習しているが、推論時には埋め込みモデルにより類似したコンテキストが与えられるので、学習データで答えを変更するパターンと似た状況になりやすい

後者は学習データを作るときに似たコンテキストを選ぶようにすると改善できそうですが、前者はgpt-indexのしくみに依存しているので難しいかもしれません。
text-davinci-003ならコンテキストを比較しなくてもリファイン精度が高いと思うので、このしくみで十分なのかもしれませんね。


この記事が気に入ったらサポートをしてみませんか?