見出し画像

RAG DAY4 Hyde

Hydeとは

HyDE(Hypothetical Document Embeddings:仮の文書の埋め込み)は、入力されたクエリに対して仮の文書を生成し、その文書を埋め込み、検索に使用する手法でQuery Translationの一つです。

典型的な文書検索では、ユーザーが入力したクエリと文書の類似度を計算することが多いですが、クエリと文書が必ずしも類似しているとは限りません。

そこで、Hydeでは、生成モデルを使って仮の回答の文書(Hypothetical Document)を生成し、その文書と検索エンジンに格納された文書の類似度を計算してしまおうという考え方をします。

Hydeの例

例を見ていきましょう。
例えば次のようなクエリ(質問)があるとします

(ユーザーの質問)
社長の名前を教えて

そうするとHydeでは、retrievalに検索をかける前に仮の回答を作成します。

(GPTが作成した仮の回答)
社長の名前は大槻 大地です。

当然この回答はGPTが適当に作ったもので正しくはありません。

この回答を文章検索の入力として類似度検索を行うのが、Hydeとなります。

ここからがコードの説明です。

ソースコードはgithubに公開しています。

下準備

from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import MarkdownHeaderTextSplitter,RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain

#Document load
loader = DirectoryLoader("../datasets/company_documents_dataset_1/", glob="**/*.txt",recursive=True)
raw_docs = loader.load()

# Document split
headers_to_split_on = [
    ("#", "Header 1"),
    ("##", "Header 2"),
    ("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=headers_to_split_on, 
    return_each_line=False,
    strip_headers = False 
)
docs = []
for raw_doc in raw_docs:
    source = raw_doc.metadata["source"]
    spilited_docs = markdown_splitter.split_text(raw_doc.page_content)
    for doc in spilited_docs:
        doc.metadata["source"] = source#metadataにsourceを加える
    docs = docs + spilited_docs
markdown_splited_docs = docs
text_splitter = RecursiveCharacterTextSplitter(chunk_size = 300,chunk_overlap=50)
docs = text_splitter.split_documents(docs)

# Embd
vectorstore = Chroma.from_documents(persist_directory="./vecstore/index", documents=docs, embedding=OpenAIEmbeddings())

 #llm llm = ChatOpenAI(model_name="gpt-3.5-turbo",temperature=0)

# retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

Hyde Prompt

from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate

hyde_template = """Please write a japanese sentence to answer the question
Question: {question}
Passage:"""
prompt_hyde = PromptTemplate.from_template(hyde_template)

hyde_chain = (
    prompt_hyde | ChatOpenAI(temperature=0) | StrOutputParser() 
)

# Run
question = "社長の名前は?"
hyde_chain.invoke({"question":question})
'社長の名前は山田さんです。'

このpromptを使ったchainで質問に対する仮の回答を作成することができます。

Retriever

retrieval_chain = hyde_chain | retriever 
retireved_docs = retrieval_chain.invoke({"question":question})
retireved_docs
[Document(page_content='### 社長  \n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}),
 Document(page_content='### 社長  \n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}),
 Document(page_content='### 社長  \n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}),
 Document(page_content='## 社長のプロフィール  \n### 名前  \n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '社長のプロフィール', 'Header 3': '名前', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'})]

retrieval_chainでは、先ほどのhyde_chainにて出力した回答を使って、ベクトル検索を行います。

関連するドキュメントが取得できているのがわかります。

RAG

from langchain.prompts import ChatPromptTemplate
from langchain.callbacks.tracers import ConsoleCallbackHandler
rag_template = """Answer the following question based on this context:

{context}

Question: {question}
"""

rag_prompt = ChatPromptTemplate.from_template(rag_template)

rag_chain = (
    rag_prompt
    | llm
    | StrOutputParser()
)
handler = ConsoleCallbackHandler()

rag_chain.invoke({"context":retireved_docs,"question":question},{"callbacks":[handler]})

contextにretreaval_chainで作成したdocumentのリストであるretireved_docsを指定して、chainを実行します。

実行した結果がこちらです。

[chain/start] [1:chain:RunnableSequence] Entering Chain run with input:
[inputs]
[chain/start] [1:chain:RunnableSequence > 2:prompt:ChatPromptTemplate] Entering Prompt run with input:
[inputs]
[chain/end] [1:chain:RunnableSequence > 2:prompt:ChatPromptTemplate] [1ms] Exiting Prompt run with output:
{
  "lc": 1,
  "type": "constructor",
  "id": [
    "langchain",
    "prompts",
    "chat",
    "ChatPromptValue"
  ],
  "kwargs": {
    "messages": [
      {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "schema",
          "messages",
          "HumanMessage"
        ],
        "kwargs": {
          "content": "Answer the following question based on this context:\n\n[Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='## 社長のプロフィール  \\n### 名前  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '社長のプロフィール', 'Header 3': '名前', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'})]\n\nQuestion: 社長の名前は?\n",
          "additional_kwargs": {}
        }
      }
    ]
  }
}
[llm/start] [1:chain:RunnableSequence > 3:llm:ChatOpenAI] Entering LLM run with input:
{
  "prompts": [
    "Human: Answer the following question based on this context:\n\n[Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='### 社長  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '会社概要', 'Header 3': '社長', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'}), Document(page_content='## 社長のプロフィール  \\n### 名前  \\n漆黒 花太郎(しっこく かたろう)', metadata={'Header 1': '株式会社架空ブラック 会社情報', 'Header 2': '社長のプロフィール', 'Header 3': '名前', 'source': '../datasets/company_documents_dataset_1/マニュアル/会社情報.txt'})]\n\nQuestion: 社長の名前は?"
  ]
}
[llm/end] [1:chain:RunnableSequence > 3:llm:ChatOpenAI] [1.09s] Exiting LLM run with output:
{
  "generations": [
    [
      {
        "text": "漆黒 花太郎(しっこく かたろう)",
        "generation_info": {
          "finish_reason": "stop",
          "logprobs": null
        },
        "type": "ChatGeneration",
        "message": {
          "lc": 1,
          "type": "constructor",
          "id": [
            "langchain",
            "schema",
            "messages",
            "AIMessage"
          ],
          "kwargs": {
            "content": "漆黒 花太郎(しっこく かたろう)",
            "additional_kwargs": {}
          }
        }
      }
    ]
  ],
  "llm_output": {
    "token_usage": {
      "completion_tokens": 22,
      "prompt_tokens": 463,
      "total_tokens": 485
    },
    "model_name": "gpt-3.5-turbo",
    "system_fingerprint": "fp_69829325d0"
  },
  "run": null
}
[chain/start] [1:chain:RunnableSequence > 4:parser:StrOutputParser] Entering Parser run with input:
[inputs]
[chain/end] [1:chain:RunnableSequence > 4:parser:StrOutputParser] [0ms] Exiting Parser run with output:
{
  "output": "漆黒 花太郎(しっこく かたろう)"
}
[chain/end] [1:chain:RunnableSequence] [1.09s] Exiting Chain run with output:
{
  "output": "漆黒 花太郎(しっこく かたろう)"
}
'漆黒 花太郎(しっこく かたろう)'

しっかりと名前が出力できてるのがわかります。

chainの可視化

rag_chain.get_graph().print_ascii()
 +-------------+       
     | PromptInput |       
     +-------------+       
            *              
            *              
            *              
  +--------------------+   
  | ChatPromptTemplate |   
  +--------------------+   
            *              
            *              
            *              
      +------------+       
      | ChatOpenAI |       
      +------------+       
            *              
            *              
            *              
   +-----------------+     
   | StrOutputParser |     
   +-----------------+     
            *              
            *              
            *              
+-----------------------+  
| StrOutputParserOutput |  
+-----------------------+  


この記事では、RAGのQuery Translationの手法の一つであるHydeを紹介しました。

次はRAG FUSIONを紹介していきます。

この記事が気に入ったらサポートをしてみませんか?