122 lines
2.7 KiB
Python
122 lines
2.7 KiB
Python
|
import json
|
|||
|
from langchain.tools import BaseTool
|
|||
|
import proto.documents_pb2
|
|||
|
|
|||
|
import init
|
|||
|
import doc_client
|
|||
|
import openai
|
|||
|
|
|||
|
# class Eva
|
|||
|
|
|||
|
|
|||
|
# Multi-turn conversation history shared by the whole script.
#
# It starts with a single system prompt (kept in Chinese because it is sent
# to the model verbatim). The prompt tells the model to answer from the
# supplied documents, in the language of the question, formatted as Markdown,
# to cite the source, to address the user as "你", and to ignore unrelated
# document content.
#
# The per-question user message and the retrieved-document context message
# are appended at call time by ask_question(), so they are not listed here.
messages = [
    {
        "role": "system",
        "content": (
            "\n"
            "回答问题使用文档,并以提问的语言和Markdown回答,并告诉来源。\n"
            "你得用“你”的身份指代用户。请辨别文档中的内容,有一些是不相干的。\n"
        ),
    },
]
|
|||
|
|
|||
|
|
|||
|
def ask_question(question):
    """Answer ``question`` from retrieved document chunks and the chat history.

    Steps:

    1. Append the question to the module-level ``messages`` history.
    2. Embed it with ``init.text_to_vector`` and run a vector search on
       ``init.collection``.
    3. Fetch each hit's chunk through ``doc_client.stub.GetDocumentChunk``
       and concatenate the results into one context string.
    4. Send the history plus a temporary ``system`` message holding that
       context to ``openai.ChatCompletion.create``.
    5. Remove the temporary context message again and record the assistant
       reply, so only user/assistant turns accumulate in the history.

    Args:
        question: The user's question text.

    Returns:
        The ``content`` field of the assistant message returned by the model.

    Raises:
        Any exception raised by the embedding call, the vector search or the
        chat-completion request. Errors while fetching an individual chunk
        are printed and that chunk is skipped.
    """
    messages.append({
        "role": "user",
        "content": f"\n{question}\n",
    })

    question_vec = init.text_to_vector(question)
    search_param = {
        "data": [question_vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        # NOTE(review): the filter is hard-coded to user 2 - confirm this is
        # intended rather than a leftover from testing.
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    # One query vector was submitted, so only the first result set is used.
    hits = init.collection.search(**search_param)[0]

    sections = []
    for hit in hits:
        chunk_id = hit.id

        # NOTE(review): this compares against the *string* "0"; if ids are
        # integers the check never matches - verify against the schema.
        if chunk_id == "0":
            continue

        print("正在获取分块 " + str(chunk_id) + " 的内容...")

        # Best effort: a chunk that cannot be fetched or read is reported
        # and skipped so the remaining hits can still be used.
        try:
            chunk = doc_client.stub.GetDocumentChunk(
                proto.documents_pb2.GetDocumentChunkByIdRequest(id=chunk_id)
            )
            print(chunk)
            sections.append(
                "=== \n"
                + f"文档 ID:{chunk.document.id}\n"
                + f"文档内容: {chunk.content}"
                + "===\n"
            )
        except Exception as e:
            print(e)

    plain_text = "".join(sections)

    # The retrieved documents are only relevant to this turn, so they are
    # added as a temporary system message and removed again below.
    context_message = {
        "role": "system",
        "content": f"\n文档: {plain_text}\n",
    }
    messages.append(context_message)

    print("正在调用 LLM...")
    try:
        result = openai.ChatCompletion.create(
            messages=messages, model="gpt-3.5-turbo", temperature=0
        )
    finally:
        # Remove exactly the context message added above (matched by
        # identity), even when the request fails, so stale document text
        # never leaks into later turns.
        for idx, msg in enumerate(messages):
            if msg is context_message:
                del messages[idx]
                break

    reply = result["choices"][0]["message"].to_dict_recursive()

    # Remember the assistant's answer for follow-up questions.
    messages.append({
        "role": "assistant",
        "content": reply["content"],
    })

    return reply["content"]
|
|||
|
|
|||
|
|
|||
|
# Number of user/assistant messages kept in addition to the system prompt.
MAX_HISTORY = 10

# Interactive loop: read a question from stdin, answer it, print the answer.
while True:

    print(messages)

    # messages[0] is the system prompt and must survive trimming; a plain
    # messages[-10:] would drop it and the model would lose its instructions.
    # Only the conversational turns after it are limited to the most recent
    # MAX_HISTORY entries.
    history = messages[1:]
    if len(history) > MAX_HISTORY:
        messages = [messages[0], *history[-MAX_HISTORY:]]
        print("很抱歉,我只能记住最近 10 条上下文数据。让我们重新开始吧。")

    question = input("请输入问题:")
    resp = ask_question(question)
    print(resp)
|