見出し画像

Rinnaのマルチターン対応LLM Youriをベンチマークする

Rinnaからマルチターン対応のYouri-7b-chatが出ていたのでJapanese-MT-benchでベンチマークできるようにプログラムを書いてみた。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import random
import string
import time

def generate_random_string(length):
    letters = string.ascii_letters
    result_str = ''.join(random.choice(letters) for i in range(length))
    return result_str

llm = "rinna/youri-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(llm)
model = AutoModelForCausalLM.from_pretrained(llm, device_map="auto", torch_dtype=torch.float16)
model.to("cuda:0")
template="ユーザー: %s\nシステム: "
sep="\n"
max_length=2048

def think(question, context):
    #print(question)
    input_ids = tokenizer.encode(context + template%question , 
                                 add_special_tokens=False,return_tensors="pt").to("cuda:0")
    output_ids = model.generate(input_ids, max_length=max_length, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def generate_answer(file):
    model_id= llm[llm.index("/") + 1:]

    outfile = llm[llm.index("/") + 1:] + ".jsonl"
    print(outfile)

    with open(outfile,"w") as of:

        with open(file) as f:
            for i in f:
                d=json.loads(i)
                context = "設定: あなたは様々な質問に丁寧に答えるアシスタントです\n"
                turns=[]
                d["model_id"]=model_id
                d["answer_id"]=generate_random_string(22)
                
                prompt=""

                for utterance in d["turns"]:
                    response=think(utterance, context)
                    answer = response[len(context+template%utterance):]
                    print(answer)
                    context =response+sep    
                    #turns.append(utterance)
                    turns.append(answer)
                    #print(context)    
                d["generate_params"]={"prompt":context,
                                      "tstamp":time.time(),
                                      "do_sample": True, "max_new_token": 512, 
                                      "temperature": 0.5, "top_p": 0.9, "repetition_penalty": 1.1
                                      }
                d["choices"]=[{"index:":0,"turns":turns}]
                print(json.dumps(d, ensure_ascii=False))
                of.write(json.dumps(d, ensure_ascii=False)+"\n")
generate_answer("data/japanese_mt_bench/question.jsonl")

これで、実際の出力はこんな感じになる。

設定: あなたは様々な質問に丁寧に答えるアシスタントです
ユーザー: 最新の日本の太陽光発電技術に関する記事のための、科学的に正確かつ興味を引く見出しを4つ提案してください。
システム: 以下は、私が提供できるいくつかのアイデアのリストにすぎません:
1. ソーラーパネルはどのくらい効率的なのでしょうか?
2. CIGSとCdTeのどちらがより多くの電力を生み出すのに優れているのか
3. HITのような新しい材料についてもっと知りたい
4. PV業界における中国の影響力
ユーザー: 2つ目の見出しに以下の修正を加えてみてください:1. トーンをカジュアルに 2. 「スーパーパワー」の商品広告を埋め込む 3. 10個以下の単語を使う。
システム: もちろん、これらの変更を反映させることは可能であると思いますが、あまりにも詳細になってしまうと、読者を魅了したり、混乱を招かないという目的を果たせなくなる懸念がありまあす。",

悪くないように見えるのだがJapanese_MT_benchの結果はイマイチ。
はぐらかし系が多い印象かな。

サイズの割には各種質問の深堀に弱い模様。学習したデータの問題だろうか。

全体のベンチマークはここにあります