langchain-chat-with-milvus/document_ai/chat.py
2023-11-19 20:54:12 +08:00

122 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from langchain.tools import BaseTool
import proto.documents_pb2
import init
import doc_client
import openai
# Base system prompt sent as the first chat message. In English:
# "Answer questions using the documents, replying in the language of the
# question and in Markdown, and state the sources. Refer to the user as
# 'you'. Discern the content of the documents; some of it is irrelevant."
SYSTEM_PROMPT = """
回答问题使用文档并以提问的语言和Markdown回答并告诉来源。
你得用“你”的身份指代用户。请辨别文档中的内容,有一些是不相干的。
"""

# Running multi-turn conversation history, as a list of role/content dicts.
# It starts with the system prompt; later turns are appended to it.
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
def _search_chunk_hits(question, *, user_id=2, top_k=5):
    """Return the vector-search hits for the chunks closest to *question*.

    The question is embedded with ``init.text_to_vector`` and searched with an
    L2 metric against the ``vector`` field of ``init.collection``, filtered to
    rows whose ``user_id`` equals *user_id*.
    """
    search_param = {
        "data": [init.text_to_vector(question)],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": top_k,
        # int() guarantees the filter expression contains only a number.
        "expr": f"user_id == {int(user_id)}",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    # Exactly one query vector is sent, so only the first result set matters.
    return init.collection.search(**search_param)[0]


def _fetch_chunk_text(chunk_id):
    """Fetch one chunk from the document service and format it for the prompt.

    Best-effort: if anything fails, the error is printed and None is returned,
    so one unavailable chunk does not abort the whole question.
    """
    print("正在获取分块 " + str(chunk_id) + " 的内容...")
    try:
        chunk = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
            id=chunk_id
        ))
        print(chunk)
        # The closing "===" goes on its own line so it is not fused with the content.
        return ("=== \n" + f"文档 ID:{chunk.document.id}\n"
                + f"文档内容: {chunk.content}\n" + "===\n")
    except Exception as e:
        print(e)
        return None


def ask_question(question, *, user_id=2, top_k=5, model="gpt-3.5-turbo"):
    """Answer *question* using retrieved document chunks and the chat history.

    The function:
    1. Appends the question to the module-level ``messages`` history.
    2. Retrieves the text of the most similar chunks from the document service.
    3. Sends the history plus one system message holding that text to the
       chat model.
    4. Appends the model's reply to the history.

    Only the question and the reply are kept in ``messages``. The
    retrieved-document message is sent with this request alone, so a failed
    API call cannot leave stale context in the history.

    Args:
        question: The user's question.
        user_id: Only chunks whose ``user_id`` field equals this value are
            searched (defaults to the previously hard-coded 2).
        top_k: Maximum number of chunks to retrieve.
        model: Chat model name passed to ``openai.ChatCompletion.create``.

    Returns:
        The assistant's reply text.
    """
    messages.append({"role": "user", "content": f"\n{question}\n"})

    sections = []
    for hit in _search_chunk_hits(question, user_id=user_id, top_k=top_k):
        # NOTE(review): "0" looks like a placeholder id; depending on the
        # collection schema ids may be ints, in which case this never matches.
        if hit.id == "0":
            continue
        text = _fetch_chunk_text(hit.id)
        if text is not None:
            sections.append(text)
    plain_text = "".join(sections)
    context_message = {"role": "system", "content": f"\n文档: {plain_text}\n"}

    print("正在调用 LLM...")
    result = openai.ChatCompletion.create(
        messages=messages + [context_message], model=model, temperature=0
    )
    answer = result["choices"][0]["message"].to_dict_recursive()["content"]
    messages.append({"role": "assistant", "content": answer})
    return answer
# Upper bound on the number of messages kept in the conversation history.
MAX_CONTEXT_MESSAGES = 10

if __name__ == "__main__":
    # Interactive REPL: read a question, answer it, print the reply.
    while True:
        print(messages)
        if len(messages) > MAX_CONTEXT_MESSAGES:
            # Trim in place but always keep messages[0], the base system
            # prompt; slicing the last 10 entries used to drop it for good.
            del messages[1:len(messages) - (MAX_CONTEXT_MESSAGES - 1)]
            print("很抱歉,我只能记住最近 10 条上下文数据。让我们重新开始吧。")
        try:
            question = input("请输入问题:")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session instead of dumping a traceback.
            print()
            break
        resp = ask_question(question)
        print(resp)