# langchain-chat-with-milvus/document_ai/agent.py
# 2023-11-19 20:54:12 +08:00
#
# 84 lines
# 2.4 KiB
# Python

from langchain.agents import Tool, load_tools
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, BaseMessage
from langchain.utilities import SerpAPIWrapper
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from getpass import getpass
import proto.documents_pb2
import init, doc_client
# def fake_result(str: str) -> str:
# print(str)
# return "博客名称: iVampireSP.com"
#
def search_document(question: str, user_id: int = 2, limit: int = 5) -> str:
    """Search a user's document library for chunks relevant to *question*.

    The question is embedded with ``init.text_to_vector`` and used for a
    search on ``init.collection`` over the ``vector`` field with the L2
    metric, filtered to rows whose ``user_id`` matches. For every hit, the
    chunk text is fetched through ``doc_client.stub.GetDocumentChunk``.

    Args:
        question: Query text supplied by the agent.
        user_id: Owner whose documents are searched. Defaults to 2, the value
            this function previously hard-coded.
        limit: Maximum number of chunks requested from the search. Defaults
            to 5, the previously hard-coded value.

    Returns:
        The fetched chunk contents, each formatted as
        ``"=== \\n" + content + " ===\\n"`` and concatenated in hit order.
        Chunks whose fetch fails are skipped (the error is printed). The
        result is an empty string when nothing could be retrieved.
    """
    print("搜索请求:" + question)
    vec = init.text_to_vector(question)
    search_param = {
        "data": [vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": limit,
        # int() guarantees only a number is interpolated into the filter expression.
        "expr": f"user_id == {int(user_id)}",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    res = init.collection.search(**search_param)

    pieces = []
    # Only one query vector was sent, so res[0] holds all of its hits.
    for hit in res[0]:
        chunk_id = hit.id
        print("正在获取分块 " + str(chunk_id) + " 的内容...")
        try:
            chunk = doc_client.stub.GetDocumentChunk(
                proto.documents_pb2.GetDocumentChunkByIdRequest(id=chunk_id)
            )
        except Exception as e:
            # Best effort: one failed fetch should not abort the whole search.
            print(e)
            continue
        pieces.append("=== \n" + chunk.content + " ===\n")
    return "".join(pieces)
# Custom tools exposed to the agent. The agent picks and invokes tools based
# on their name and description. The description therefore says what the
# tool returns and what input to pass, in addition to asking the agent to
# prefer it.
tools = [
    Tool(
        name="Search user's Library Document",
        func=search_document,
        description=(
            "优先使用 Search user's Library Document. "
            "在用户的文档库中检索与问题相关的文档片段并返回其原文。"
            "输入应为要检索的问题或关键词。"
        ),
    ),
]
# Chat model at temperature 0. The same instance backs the agent and the
# built-in "llm-math" tool.
llm = ChatOpenAI(temperature=0)
loaded_tools = load_tools(["llm-math"], llm=llm)
tools.extend(loaded_tools)

# History is exposed to the agent as message objects under "chat_history".
# It is seeded with an instruction to always reply in Chinese.
memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")
history = memory.chat_memory.messages
history.append(HumanMessage(content="必须使用中文回复。"))

agent_chain = initialize_agent(
    tools,
    llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=memory,
    verbose=False,
)
if __name__ == "__main__":
    # Interactive session: read a question, run the agent, print its answer.
    while True:
        try:
            question = input("请输入问题:")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly.
            print()
            break
        if not question.strip():
            # Don't spend an LLM call (and a memory entry) on blank input.
            continue
        # Re-state the reply-language requirement on every turn.
        question = "必须使用中文回复:" + question
        try:
            result = agent_chain.run(input=question)
        except Exception as e:
            # Top-level boundary: report the failure but keep the session alive.
            print(e)
            continue
        print(result)