FlaskでChatGPTのレスポンスをストリーミング表示

PythonでChatGPTのツールを作っていて、回答をストリーミング表示(ヌルヌルと文字が出てくるやつ)させたくて、その調査ログです。

Flaskでストリーミング表示

そもそもFlaskをよく知らないので、GPTに関係なくストリーミング表示させる方法をGPTに聞いてみました。

GPTに投げたお願い
Flaskでストリーミング表示を学んでいます。以下のようなサンプルを書いてください。

* socketで実現する。roomも使う。
* HTMLのボタンを押したら、ボタンを無効にしてPythonに処理が行く。
* Pythonでは、5秒ウェイとする。1秒ごとに経過時間をクライアントに伝え、Webページでそれを表示する。
* 5秒経過したら、Pythonから「終了」というメッセージがHTMLに飛ぶ
* 通信を終了する。ボタンも有効に戻す。
* そのメッセージをHTMLに表示する

結果、PythonとHTMLでそれぞれ以下のような返答が来ました。一発でエラーなく回って感動♪

PIP

pip install flask flask-socketio eventlet

app.py

from flask import Flask, render_template
from flask_socketio import SocketIO, emit, join_room
import time

app = Flask(__name__)
socketio = SocketIO(app)

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('start_process', namespace='/test')
def start_process(message):
    join_room('room')
    for i in range(1, 6):
        time.sleep(1)
        emit('time_update', {'message': f'{i} seconds passed.'}, room='room')
    emit('process_completed', {'message': '終了'}, room='room')

if __name__ == '__main__':
    socketio.run(app, debug=True)

templates/index.html

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <title>Streaming Example</title>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.min.js"></script>
    <script type="text/javascript">
        var socket = io.connect('http://' + document.domain + ':' + location.port + '/test');
        socket.on('connect', function() {
            socket.emit('joined', {});
        });
        socket.on('time_update', function(msg) {
            document.getElementById('status').innerHTML = msg.message;
        });
        socket.on('process_completed', function(msg) {
            document.getElementById('status').innerHTML = msg.message;
            document.getElementById('startButton').disabled = false;
        });
        function startProcess() {
            socket.emit('start_process', {});
            document.getElementById('startButton').disabled = true;
        }
    </script>
</head>
<body>
    <button id="startButton" onclick="startProcess()">Start</button>
    <div id="status"></div>
</body>
</html>

実行結果

Langchainで生成AIを動かす

API_KEYを「.env」というファイル名で保存して、以下のプログラムを実行。

.env ファイルの中身

test_chatgpt.py

import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

llm = ChatOpenAI(
    model_name  = "gpt-3.5-turbo",
    temperature = 0,
    streaming   = False, 
)
response = llm(messages=[HumanMessage(content="こんにちわ")])
print(response)

これを回してターミナルにメッセージが出たら成功です。次はヌルヌル表示にチャレンジします。

GPTの返答をストリーミング表示

まずはライブラリのインストール。ここで注意点!langchainのバージョンが「0.0.142(以降?)」じゃないと、CallbackManagerなる関数が使えない模様。PIPするときに注意してください。

pip install langchain==0.0.142

まずはコールバック関数を定義します。「https://ict-worker.com/ai/langchain-stream.html」の記事が非常にわかりやすかったので、ほぼその中身を流用させていただいております。
ChatOpenAIでStreamingをTrueにすると、ChatGPTが返答トークンを発行するたびに「on_llm_new_token」が呼ばれ、自作コールバック関数(後述)が動くという仕様です。

from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

class mycbhandler(BaseCallbackHandler):

    streaming_handler = None
    def __init__(self, jisaku_callbackfunction):
        #自作のコールバック関数を登録
        self.streaming_handler = jisaku_callbackfunction
        
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.streaming_handler(token)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        class_name = serialized["name"]
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
        print(action)

    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
        print(output)

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
        print(text)

    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
        print(finish.log)

次に、自作コールバック関数を作り(単にプリントするだけ)、ChatOpenAIにそのコールバック関数を渡してあげます。

import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

# 自作コールバック関数(単にプリントするだけ)
def handle_token(token):
    print('\033[36m' + token + '\033[0m')

llm = ChatOpenAI(
    streaming        = True, 
    callback_manager = CallbackManager([mycbhandler(handle_token)]), 
    verbose          = True, 
    temperature      = 0
)
response = llm(messages=[HumanMessage(content="こんにちわ")])
print(response)

これら2つを足し合わせて実行してみてください。例えば、以下のような感じです。(単に上の2つを足し合わせただけです)

from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

#--------------------------------------------------------#
class mycbhandler(BaseCallbackHandler):

    streaming_handler = None
    def __init__(self, jisaku_callbackfunction):
        #自作のコールバック関数を登録
        self.streaming_handler = jisaku_callbackfunction
        
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.streaming_handler(token)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        class_name = serialized["name"]
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
        print(action)

    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
        print(output)

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
        print(text)

    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
        print(finish.log)
#--------------------------------------------------------#

import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

def handle_token(token):
    print('\033[36m' + token + '\033[0m')

if __name__ == '__main__' :
    llm = ChatOpenAI(
        streaming        = True, 
        callback_manager = CallbackManager([mycbhandler(handle_token)]), 
        verbose          = True, 
        temperature      = 0
    )
    response = llm(messages=[HumanMessage(content="こんにちわ")])
    print(response)

ターミナルにヌルヌルと文字が表示されれば成功です。

Webページでヌルヌル実装

いよいよ実装です。
最初の「Flaskでストリーミング表示」にある`start_process`に、以下のような部分があります。

    for i in range(1, 6):
        time.sleep(1)
        emit('time_update', {'message': f'{i} seconds passed.'}, room='room')

ここのemit部分を、自作コールバック関数に使いまわします。

自作コールバック関数

ai_message = ''
def handle_token(token):
    global ai_message
    ai_message = ai_message + token
    emit('time_update', {'message': ai_message}, room='room')

また、上述のfor文はコメントアウトして、以下の2行を付け足します。(llmは外(global)で定義済みという前提)

    global llm
    response = llm(messages=[HumanMessage(content="こんにちわ")])

これでストリーミング表示されるはずです!
ソースコード全文は以下の通りです。(index.htmlは全く変えず)

from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

#--------------------------------------------------------#
class mycbhandler(BaseCallbackHandler):

    streaming_handler = None
    def __init__(self, jisaku_callbackfunction):
        #自作のコールバック関数を登録
        self.streaming_handler = jisaku_callbackfunction
        
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.streaming_handler(token)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        class_name = serialized["name"]
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
        print(action)

    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
        print(output)

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
        print(text)

    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
        print(finish.log)
#--------------------------------------------------------#

import os
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

ai_message = ''
def handle_token(token):
    global ai_message
    ai_message = ai_message + token
    emit('time_update', {'message': ai_message}, room='room')

llm = ChatOpenAI(
        streaming=True, 
        callback_manager=CallbackManager([mycbhandler(handle_token)]), 
        verbose=True, 
        temperature=0
)
#--------------------------------------------------------#

from flask import Flask, render_template
from flask_socketio import SocketIO, emit, join_room
import time


app = Flask(__name__)
socketio = SocketIO(app)

@app.route('/')
def index():
    return render_template('index.html')

@socketio.on('start_process', namespace='/test')
def start_process(message):
    join_room('room')

    # for i in range(1, 6):
    #     time.sleep(1)
    #     emit('time_update', {'message': f'{i} seconds passed.'}, room='room')
    #以下の2行を付け足し
    global llm
    response = llm(messages=[HumanMessage(content="こんにちわ")])

    emit('process_completed', {'message': response.content}, room='room')

if __name__ == '__main__':
    socketio.run(app, debug=True)


おまけ:自作openaiクラス

これらのことを毎回書くのが面倒なので、自作クラスを作りました。上にも書きましたが、langchalnは0.0.142(以降?)ですのでご注意を。

pip install langchain==0.0.142
from typing import Any, Dict, List, Optional, Union
import  os

from langchain.prompts.chat import (
    ChatPromptTemplate          ,
    SystemMessagePromptTemplate ,
    MessagesPlaceholder         ,
    HumanMessagePromptTemplate  ,
)
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationChain

from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage

#--- コールバッククラス -----------------------------------#
class mycbhandler(BaseCallbackHandler):

    streaming_handler = None
    def __init__(self, jisaku_callbackfunction):
        #自作のコールバック関数を登録
        self.streaming_handler = jisaku_callbackfunction
        
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.streaming_handler(token)

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None:
        class_name = serialized["name"]
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_tool_start(self,serialized: Dict[str, Any], input_str: str, **kwargs: Any, ) -> None:
        pass

    def on_agent_action(self, action: AgentAction, color: Optional[str] = None, **kwargs: Any) -> Any:
        print(action)

    def on_tool_end(self, output: str, color: Optional[str] = None, observation_prefix: Optional[str] = None, llm_prefix: Optional[str] = None, **kwargs: Any) -> None:
        print(output)

    def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        pass

    def on_text(self, text: str, color: Optional[str] = None, end: str = "", **kwargs: Optional[str]) -> None:
        print(text)

    def on_agent_finish(self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any) -> None:
        print(finish.log)
#--------------------------------------------------------#

#--- メインクラス -----------------------------------#
class myopenai :
    
    template        = None
    mycallbackfunc  = None

    def __init__(self) :
        pass

    def set_prompt(self, txt:str) :
        self.prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template   (txt)                       ,
            MessagesPlaceholder                         (variable_name="history")   ,
            HumanMessagePromptTemplate.from_template    ("{input}")                 ,
        ])

    def set_mycallbackfunction(self, mycallbackfunc:Any) :
        self.mycallbackfunc = mycallbackfunc

    #会話の読み込みを行う関数を定義
    def load_conversation(self, model:str, streaming:bool=True):
        llm                  = ChatOpenAI(
            model_name       = model,
            temperature      = 0,
            streaming        = streaming, 
            callback_manager = CallbackManager([mycbhandler(self.mycallbackfunc)]), 
            verbose          = True, 
        )

        memory          = ConversationBufferMemory(return_messages=True)
        # print(f'---{self.prompt}---')
        conversation    = ConversationChain(
            memory      = memory,
            prompt      = self.prompt,
            llm         = llm
        )

        return conversation

#--------------------------------------------------------#

if __name__ == '__main__' :
    os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
    mo = myopenai()

    #まず普通に出す(ストリーミングなし)
    mo.set_prompt('あなたは精神科医です。私の悩みを聞いて、適切にアドバイスをしてください。')
    conv = mo.load_conversation(model='gpt-3.5-turbo', streaming=False)
    ans = conv.predict(input='こんにちわ')
    print(ans)

    #ストリーミングあり
    def handle_token(token):
        print('\033[36m' + token + '\033[0m')

    mo.set_mycallbackfunction(handle_token)
    mo.set_prompt('あなたは精神科医です。私の悩みを聞いて、適切にアドバイスをしてください。')
    conv = mo.load_conversation(model='gpt-3.5-turbo', streaming=True)
    ans = conv.predict(input='お腹が痛い')
    print(ans)
    ans = conv.predict(input='別のアドバイスありますか?')
    print(ans)

この記事が気に入ったらサポートをしてみませんか?