# langchain-chat-with-milvus/document_ai/chat.py

import json
from langchain.tools import BaseTool
import proto.documents_pb2
import init
import doc_client
import openai
# class Eva
# Multi-turn conversation: shared message history sent to the chat model.
messages = [
    {
        "role": "system",
        "content": """
        Answer questions using the provided documents. Reply in the language of the
        question, format the answer in Markdown, and cite your sources. Address the
        user in the second person. Some of the document content may be irrelevant,
        so judge for yourself which parts apply.
        """
    },
    # {
    #     "role": "system",
    #     "content": f"""
    #     Context: {plain_text}
    #     """
    # },
    # {
    #     "role": "user",
    #     "content": f"""
    #     {question}
    #     """
    # }
]
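

# --- Hedged sketch (not part of the original file) ---------------------------
# init.text_to_vector is defined elsewhere in this project. Assuming it wraps the
# pre-1.0 OpenAI embeddings API with the "text-embedding-ada-002" model (an
# assumption, not confirmed by this file), it would look roughly like this:
def _text_to_vector_sketch(text):
    # Returns the embedding vector (1536 floats for ada-002) for the given text.
    resp = openai.Embedding.create(model="text-embedding-ada-002", input=[text])
    return resp["data"][0]["embedding"]
# -----------------------------------------------------------------------------
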
def ask_question(question):
    # Record the user's question in the shared history.
    messages.append({
        "role": "user",
        "content": f"""
        {question}
        """
    })

    # Embed the question for vector search in Milvus.
    question_vec = init.text_to_vector(question)
    # ANN search over the chunk collection: L2 distance, top 5 hits,
    # restricted to documents owned by user_id == 2.
    search_param = {
        "data": [question_vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }

    res = init.collection.search(**search_param)
    # Fetch the full content of each matched chunk over gRPC and build a context block.
    plain_text = ""

    for i in range(len(res[0])):
        _chunk_id = res[0][i].id
        if _chunk_id == "0":
            continue
        print("Fetching content of chunk " + str(_chunk_id) + "...")
        try:
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))
            print(_chunk_content)
            plain_text += ("=== \n" + f"Document ID: {_chunk_content.document.id}\n"
                           + f"Document content: {_chunk_content.content}" + "===\n")
        except Exception as e:
            print(e)
    # Attach the retrieved documents as a temporary system message.
    messages.append({
        "role": "system",
        "content": f"""
        Documents: {plain_text}
        """
    })

    print("Calling the LLM...")

    result = openai.ChatCompletion.create(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    res = result["choices"][0]["message"].to_dict_recursive()

    # Keep the assistant's reply in the history.
    messages.append({
        "role": "assistant",
        "content": res["content"]
    })

    # Drop the temporary document system message (keep the initial system prompt at index 0).
    for i in range(len(messages)):
        if messages[i]["role"] == "system":
            if i == 0:
                continue
            messages.pop(i)
            break

    return res["content"]
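

# Hedged note: the history trim in the loop below (messages = messages[-10:]) also
# drops the initial system prompt at messages[0]. A minimal sketch of a trim that
# keeps it, assuming the first entry is always that system prompt (this helper is
# not wired into the loop):
def trim_history_keeping_system(history, max_len=10):
    # Keep the first (system) message plus the most recent turns.
    if len(history) <= max_len:
        return history
    return [history[0]] + history[-(max_len - 1):]
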
while True:
    print(messages)
    if len(messages) > 10:
        messages = messages[-10:]
        print("Sorry, I can only remember the last 10 messages of context. Let's start over.")
    question = input("Enter your question: ")
    resp = ask_question(question)
    print(resp)
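
# --- Hedged sketch (not part of the original file) ---------------------------
# init.collection is also defined elsewhere. Assuming a standard pymilvus setup,
# it is presumably created along these lines (host, port, and collection name are
# placeholders, not taken from this file):
#
#     from pymilvus import connections, Collection
#     connections.connect(host="127.0.0.1", port="19530")
#     collection = Collection("document_chunks")
#     collection.load()
# -----------------------------------------------------------------------------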