from langchain.agents import Tool, load_tools
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, BaseMessage
from langchain.utilities import SerpAPIWrapper
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from getpass import getpass

import proto.documents_pb2
import init
import doc_client

# Parameters of the vector search issued by search_document().
# NOTE(review): the filter pins every search to user_id 2 -- presumably a
# development placeholder; confirm before using this with other users.
_SEARCH_LIMIT = 5
_SEARCH_FILTER = "user_id == 2"
_SEARCH_OUTPUT_FIELDS = [
    "todo_id",
    "title",
    "source",
    "todo_description",
    "language",
    "text",
    "user_id",
]

# Disabled stub tool, kept for reference:
#
# def fake_result(str: str) -> str:
#     print(str)
#     return "博客名称: iVampireSP.com"


def search_document(question: str) -> str:
    """Look up stored document chunks related to ``question``.

    The question is converted with ``init.text_to_vector`` and handed to
    ``init.collection.search``; the id of every hit in the first result list
    is then resolved to its content via ``doc_client.stub.GetDocumentChunk``.

    Args:
        question: Free-text query (supplied by the agent when it picks this
            tool).

    Returns:
        The content of every chunk that could be fetched, each wrapped as
        ``"=== \\n<content> ===\\n"`` and concatenated in hit order. An empty
        string if the search produced nothing or every fetch failed.
    """
    print("搜索请求:" + question)
    query_vector = init.text_to_vector(question)

    results = init.collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "L2"},
        limit=_SEARCH_LIMIT,
        expr=_SEARCH_FILTER,
        output_fields=_SEARCH_OUTPUT_FIELDS,
    )

    # One query vector is sent, so only the first result list is relevant.
    # Guard against a completely empty response instead of raising IndexError.
    if not results:
        return ""

    sections = []
    for hit in results[0]:
        chunk_id = hit.id
        print("正在获取分块 " + str(chunk_id) + " 的内容...")

        # Best effort: a chunk that cannot be fetched is reported and skipped
        # so the remaining hits still reach the agent.
        try:
            chunk = doc_client.stub.GetDocumentChunk(
                proto.documents_pb2.GetDocumentChunkByIdRequest(id=chunk_id)
            )
            sections.append("=== \n" + chunk.content + " ===\n")
        except Exception as e:
            print(e)

    return "".join(sections)


tools = [
    # Disabled stub tool, kept for reference:
    # Tool(
    #     name="Get Blog Name",
    #     func=fake_result,
    #     description="Get user's blog name from the Internet.",
    # ),
    Tool(
        name="Search user's Library Document",
        func=search_document,
        description="优先使用 Search user's Library Document.",
    )
]

llm = ChatOpenAI(temperature=0)
# Add the "llm-math" tool next to the document search tool.
loaded_tools = load_tools(["llm-math"], llm=llm)
tools.extend(loaded_tools)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Seed the chat history with a standing instruction before the first turn.
memory.chat_memory.messages.append(HumanMessage(content="必须使用中文回复。"))
# memory.clear()

agent_chain = initialize_agent(
    tools,
    llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=False,
    memory=memory,
)

# Interactive loop: read a question, prefix it with the language instruction,
# run it through the agent and print the answer.
while True:
    try:
        question = input("请输入问题:")
    except (EOFError, KeyboardInterrupt):
        # End of input or Ctrl-C at the prompt: leave the loop quietly
        # instead of terminating with a traceback.
        print()
        break
    question = "必须使用中文回复:" + question
    result = agent_chain.run(input=question)
    print(result)