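# Console chatbot built on a LangChain conversational agent. It answers in
# Chinese and can search the user's document library: each question is
# embedded, matched against a Milvus vector collection, and the matching
# document chunks are fetched over gRPC from the documents service.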
from langchain.agents import Tool, load_tools, initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, BaseMessage
from langchain.utilities import SerpAPIWrapper
from getpass import getpass

import proto.documents_pb2
import init, doc_client

# def fake_result(str: str) -> str:
#     print(str)
#     return "Blog name: iVampireSP.com"
#

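# search_document: embeds the question with init.text_to_vector, runs a
# vector similarity search against the Milvus collection opened by the
# project-local `init` module (top 5 chunks, filtered to user_id == 2), then
# fetches each chunk's content over gRPC via doc_client.stub and concatenates
# it into plain text for the agent.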
def search_document(question: str) -> str:
    print("Search request: " + question)
    vec = init.text_to_vector(question)

    search_param = {
        "data": [vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    res = init.collection.search(**search_param)

    plain_text = ""

    for hit in res[0]:
        _chunk_id = hit.id
        print("Fetching content of chunk " + str(_chunk_id) + " ...")

        try:
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))
            plain_text += "=== \n" + _chunk_content.content + " ===\n"
        except Exception as e:
            print(e)

    return plain_text

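# Expose search_document to the agent as a LangChain Tool. The description is
# what the model reads when choosing a tool, so it tells the agent to prefer
# the library search.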
tools = [
    # Tool(
    #     name="Get Blog Name",
    #     func=fake_result,
    #     description="Get user's blog name from the Internet.",
    # ),
    Tool(
        name="Search user's Library Document",
        func=search_document,
        description="Prefer using Search user's Library Document.",
    )
]

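# Chat model plus the built-in llm-math tool for arithmetic questions.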
llm = ChatOpenAI(temperature=0)
loaded_tools = load_tools(["llm-math"], llm=llm)
tools.extend(loaded_tools)

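# Conversation memory; memory_key="chat_history" matches what the
# conversational agent prompt expects. A HumanMessage is pre-seeded so the
# agent is told up front to reply in Chinese.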
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
memory.chat_memory.messages.append(HumanMessage(content="You must reply in Chinese."))
# memory.clear()

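# CHAT_CONVERSATIONAL_REACT_DESCRIPTION is a ReAct-style agent for chat
# models: it picks tools based on their descriptions and carries the chat
# history through `memory`.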
agent_chain = initialize_agent(tools, llm,
                               agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
                               verbose=False,
                               memory=memory)

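# Simple REPL: each question is prefixed with the "reply in Chinese"
# instruction again before being handed to the agent; run() returns the
# final answer as a string.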
while True:
    question = input("Enter a question: ")
    question = "You must reply in Chinese: " + question
    result = agent_chain.run(input=question)
    print(result)