From 80c871e09dfe80f86761e895e697ece0f3e77dd3 Mon Sep 17 00:00:00 2001 From: iVamp Date: Sun, 19 Nov 2023 20:54:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- document_ai/agent.py | 83 ++++++++++++++++++++++++++ document_ai/chat.py | 121 ++++++++++++++++++++++++++++++++++++++ document_ai/search.py | 132 ++++++++++++++++++++++++++++++++++++------ document_ai/server.py | 2 +- requirements.txt | Bin 2118 -> 2118 bytes 5 files changed, 318 insertions(+), 20 deletions(-) create mode 100644 document_ai/agent.py create mode 100644 document_ai/chat.py diff --git a/document_ai/agent.py b/document_ai/agent.py new file mode 100644 index 0000000..37083aa --- /dev/null +++ b/document_ai/agent.py @@ -0,0 +1,83 @@ +from langchain.agents import Tool, load_tools +from langchain.memory import ConversationBufferMemory +from langchain.chat_models import ChatOpenAI +from langchain.schema import HumanMessage, BaseMessage +from langchain.utilities import SerpAPIWrapper +from langchain.agents import initialize_agent +from langchain.agents import AgentType +from getpass import getpass +import proto.documents_pb2 + +import init, doc_client + + +# def fake_result(str: str) -> str: +# print(str) +# return "博客名称: iVampireSP.com" +# + +def search_document(question: str) -> str: + print("搜索请求:" + question) + vec = init.text_to_vector(question) + + search_param = { + "data": [vec], + "anns_field": "vector", + "param": {"metric_type": "L2"}, + "limit": 5, + "expr": "user_id == 2", + "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"], + } + res = init.collection.search(**search_param) + + plain_text = "" + + for i in range(len(res[0])): + _chunk_id = res[0][i].id + print("正在获取分块 " + str(_chunk_id) + " 的内容...") + + try: + _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest( + id=_chunk_id + )) + + plain_text += "=== \n" + _chunk_content.content + " ===\n" + + except Exception as e: + print(e) + return plain_text + + +tools = [ + # Tool( + # name="Get Blog Name", + # func=fake_result, + # description="Get user's blog name from the Internet.", + # ), + Tool( + name="Search user's Library Document", + func=search_document, + description="优先使用 Search user's Library Document.", + ) +] + +llm = ChatOpenAI(temperature=0) +loaded_tools = load_tools(["llm-math"], llm=llm) +tools.extend(loaded_tools) + +memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) +memory.chat_memory.messages.append(HumanMessage(content="必须使用中文回复。")) +# memory.clear() + +agent_chain = initialize_agent(tools, llm, + agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION, + verbose=False, + memory=memory) + +while True: + question = input("请输入问题:") + question = "必须使用中文回复:" + question + result = agent_chain.run(input=question) + print(result) + + diff --git a/document_ai/chat.py b/document_ai/chat.py new file mode 100644 index 0000000..333b28e --- /dev/null +++ b/document_ai/chat.py @@ -0,0 +1,121 @@ +import json +from langchain.tools import BaseTool +import proto.documents_pb2 + +import init +import doc_client +import openai + +# class Eva + + +# 连续对话 +messages = [ + { + "role": "system", + "content": """ +回答问题使用文档,并以提问的语言和Markdown回答,并告诉来源。 +你得用“你”的身份指代用户。请辨别文档中的内容,有一些是不相干的。 +""" + }, + # { + # "role": "system", + # "content": f""" + # Context: {plain_text} + # """ + # }, + # { + # "role": "user", + # "content": f""" + # {question} + # """ + # } +] + + +def ask_question(question): + messages.append({ + "role": "user", + "content": f""" + {question} + """ + }) + + question_vec = init.text_to_vector(question) + search_param = { + "data": [question_vec], + "anns_field": "vector", + "param": {"metric_type": "L2"}, + "limit": 5, + "expr": "user_id == 2", + "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"], + } + res = init.collection.search(**search_param) + + plain_text = "" + + for i in range(len(res[0])): + _chunk_id = res[0][i].id + + if _chunk_id == "0": + continue + + print("正在获取分块 " + str(_chunk_id) + " 的内容...") + + try: + _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest( + id=_chunk_id + )) + + print(_chunk_content) + + plain_text += ("=== \n" + f"文档 ID:{_chunk_content.document.id}\n" + + f"文档内容: {_chunk_content.content}" + "===\n") + + except Exception as e: + print(e) + + messages.append({ + "role": "system", + "content": f""" +文档: {plain_text} +""" + }) + + print("正在调用 LLM...") + result = openai.ChatCompletion.create( + messages=messages, model="gpt-3.5-turbo", temperature=0 + ) + res = result["choices"][0]["message"].to_dict_recursive() + + # add to + messages.append({ + "role": "assistant", + "content": res["content"] + }) + + + # 删除多余的 system 消息 + for i in range(len(messages)): + if messages[i]["role"] == "system": + + if i == 0: + continue + + messages.pop(i) + break + + return res["content"] + + +while True: + + print(messages) + + if len(messages) > 10: + messages = messages[-10:] + print("很抱歉,我只能记住最近 10 条上下文数据。让我们重新开始吧。") + + question = input("请输入问题:") + resp = ask_question(question) + print(resp) diff --git a/document_ai/search.py b/document_ai/search.py index c455254..ddb6887 100644 --- a/document_ai/search.py +++ b/document_ai/search.py @@ -1,9 +1,15 @@ import json + +import openai +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate + import proto.documents_pb2 from langchain import text_splitter from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.embeddings import OpenAIEmbeddings -from langchain.llms.openai import OpenAI +# from langchain.llms.openai import OpenAI; +from langchain.chat_models.openai import ChatOpenAI from langchain.schema.document import Document from pymilvus import ( connections, @@ -16,6 +22,8 @@ from pymilvus import ( import init import doc_client +import openai +from langchain.adapters import openai as lc_openai # # question = """ @@ -31,15 +39,15 @@ import doc_client # """ question = """ -为什么我会在 WHMCS 下开发摸不着头脑 +错误 yarn 什么了遇到我 """ - -vec = init.text_to_vector(question) - -# vec = "" # -# with open("../question_vec.json", "r") as f: -# vec = json.load(f) +# vec = init.text_to_vector(question) + +vec = "" + +with open("../question_vec.json", "r") as f: + vec = json.load(f) search_param = { "data": [vec], @@ -51,8 +59,13 @@ search_param = { } res = init.collection.search(**search_param) -document_chunk_ids = [] -real_document = [] +# 保留 5 个 +if len(res[0]) > 5: + res[0] = res[0][:5] + +# document_chunk_ids = [] +# real_document = [] +plain_text = "" for i in range(len(res[0])): _chunk_id = res[0][i].id @@ -65,23 +78,104 @@ for i in range(len(res[0])): # print(_chunk_content) - _doc_content_full = _chunk_content.content + # _doc_content_full = _chunk_content.content + # print("DOC OBJ:" + _doc_content_full) + plain_text += "=== \n" + _chunk_content.content + " ===\n" # real_document.append(_doc_content) # doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title}) - doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"}) + # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"}) - real_document.append(doc_obj) + # real_document.append(doc_obj) except Exception as e: print(e) -print(real_document) +# print(real_document) print("正在调用 LLM...") -chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True, - verbose=True) -question = "必须使用中文回复:" + question -output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False) -print("回复:" + output["output_text"]) +# prompt_template = f"""Answer questions use the following context and reply in question spoken language and answer +# with Markdown format, you can fix syntax errors in the context, but do not change the meaning of the context. +# you can tell user context errors(syntax or meaning) in answer. +# --- +# {plain_text} +# --- +# Question: {question} +# Answer:""" + +messages = [ + { + "role": "system", + "content": """ +回答问题使用以下上下文,并以提问的语言和Markdown回答,并告诉来源。 +你得用“你”的身份指代用户。如果用户的问题有语法错误或者上下文的意思不对,你可以告诉用户。 +请辨别上下文中的内容,有一些是不相干的。 +""" + }, + { + "role": "system", + "content": f""" +Context: {plain_text} +""" + }, +{ + "role": "user", + "content": f""" +{question} +""" + } +] + +result = openai.ChatCompletion.create( + messages=messages, model="gpt-3.5-turbo", temperature=0 +) +res = result["choices"][0]["message"].to_dict_recursive() +print(res) + + +# prompt_template = f""" +# --- +# {plain_text} +# --- +# Question: {question} +# Answer:""" +# +# print(prompt_template) +# # PROMPT = PromptTemplate( +# # template=prompt_template, input_variables=["real_document", "question"] +# # ) +# +# +# ChatOpenAI + +# llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo") +# # chain = LLMChain(llm=llm, prompt=PROMPT) +# +# output = llm(prompt_template) + +# gpt = openai.Completion.create( +# engine="gpt-3.5-turbo", +# prompt=prompt_template, +# max_tokens=150, +# temperature=0, +# top_p=1, +# frequency_penalty=0, +# presence_penalty=0, +# stop=["==="] +# ) + +# output = gpt["choices"][0]["text"] +# print(output) + +# output = chain({"real_document": real_document, "question": question}, return_only_outputs=True) + + +# print(output) + +# chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True, +# verbose=True) +# +# question = "必须使用中文回复:" + question +# output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False) +# print("回复:" + output["output_text"]) diff --git a/document_ai/server.py b/document_ai/server.py index a725e7b..e302db5 100644 --- a/document_ai/server.py +++ b/document_ai/server.py @@ -42,7 +42,7 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery): "data": [vec], "anns_field": "vector", "param": {"metric_type": "L2"}, - "limit": 10, + "limit": 5, "expr": "user_id == " + str(target.user_id), "output_fields": ["document_id", "user_id"], } diff --git a/requirements.txt b/requirements.txt index f45607199a24d30c04920bc7919f8566a903e4aa..fb048cc09a67c3c2fb81e8264e2e719b8f48e5ff 100644 GIT binary patch delta 20 ccmX>ma7ma7