From 8b6d71801905d2b5f402e5540e05485ef7dc201d Mon Sep 17 00:00:00 2001
From: "iVampireSP.com"
Date: Wed, 22 Nov 2023 10:13:58 +0800
Subject: [PATCH] Improve GPT handling and replies
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace the LangChain map_reduce QA chain in server.py with a direct
openai.ChatCompletion call, return the matched chunks as sources
alongside the answer, and scope the Milvus search to a single library
via library_id. The previous server.py is kept as server.py.bak.
---
 document_ai/run.py        |   2 +-
 document_ai/search.py     |   4 +-
 document_ai/server.py     |  85 +++++++++++++------
 document_ai/server.py.bak | 166 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 229 insertions(+), 28 deletions(-)
 create mode 100644 document_ai/server.py.bak

diff --git a/document_ai/run.py b/document_ai/run.py
index 076b160..343535b 100644
--- a/document_ai/run.py
+++ b/document_ai/run.py
@@ -1,4 +1,4 @@
-from threading import Thread
+from threading import Thread, Event
 
 import vector
 import server
diff --git a/document_ai/search.py b/document_ai/search.py
index ddb6887..b26a774 100644
--- a/document_ai/search.py
+++ b/document_ai/search.py
@@ -66,6 +66,7 @@ if len(res[0]) > 5:
 # document_chunk_ids = []
 # real_document = []
 plain_text = ""
+document_chunks = []
 
 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
@@ -80,6 +81,7 @@ for i in range(len(res[0])):
     # _doc_content_full = _chunk_content.content
 
     # print("DOC OBJ:" + _doc_content_full)
+    document_chunks.append(_chunk_content.content)
     plain_text += "=== \n" + _chunk_content.content + " ===\n"
 
     # real_document.append(_doc_content)
@@ -116,7 +118,7 @@ messages = [
     {
         "role": "system",
         "content": f"""
-Context: {plain_text}
+{plain_text}
 """
     },
     {
diff --git a/document_ai/server.py b/document_ai/server.py
index e302db5..b8b306e 100644
--- a/document_ai/server.py
+++ b/document_ai/server.py
@@ -2,6 +2,7 @@ import os
 from concurrent import futures
 
 import langchain
+import openai
 
 import proto.document_query_pb2
 import proto.document_query_pb2_grpc
@@ -9,10 +10,10 @@ import grpc
 import proto.documents_pb2
 import init
 import doc_client
-from langchain.llms.openai import OpenAI
-from langchain.schema.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# from langchain.llms.openai import OpenAI
+# from langchain.schema.document import Document
+# from langchain.embeddings import OpenAIEmbeddings
+# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.cache import InMemoryCache
 
@@ -20,6 +21,8 @@ langchain.llm_cache = InMemoryCache()
 
 CHUNK_SIZE = 500
 
+# openai.api_base = "https://api.openai.com/v1"
+# openai.api_key = "sk-..."
 
 class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
     def Query(self, target,
@@ -32,30 +35,25 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
               options=(),
               channel_credentials=None,
               call_credentials=None,
               insecure=False,
               compression=None,
               wait_for_ready=None,
               timeout=None,
               metadata=None):
 
-        print("New request: " + target.question)
         vec = init.text_to_vector(target.question)
 
-        question = "Reply in spoken language:" + target.question
+        question = target.question
 
         search_param = {
             "data": [vec],
             "anns_field": "vector",
             "param": {"metric_type": "L2"},
             "limit": 5,
-            "expr": "user_id == " + str(target.user_id),
-            "output_fields": ["document_id", "user_id"],
+            "expr": "user_id == " + str(target.user_id) + " && library_id == " + str(target.library_id),
+            "output_fields": ["document_id", "user_id", "library_id"],
         }
 
         res = init.collection.search(**search_param)
 
-        # # at most 5
-        # if len(res[0]) > 5:
-        #     res[0] = res[0][:5]
-
-
-        # document_chunk_ids = []
-        real_document = []
+        document_text = ""
+        # real_document = []
+        sources = []
 
         for i in range(len(res[0])):
             _chunk_id = res[0][i].id
@@ -66,27 +64,32 @@
                     id=_chunk_id
                 ))
 
-                _doc_content_full = _chunk_content.content
+                print(_chunk_content.document)
 
-                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+                # _doc_content_full = _chunk_content.content
+                document_text += "\n" + _chunk_content.content + "\n"
 
-                real_document.append(doc_obj)
+                # append the chunk as a source entry for the reply
+                sources.append({
+                    "text": _chunk_content.content,
+                    "document_id": _chunk_content.document.id
+                })
+
+                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+
+                # real_document.append(doc_obj)
 
             except Exception as e:
                 print(e)
 
-        print(real_document)
-
-        print("Calling the LLM...")
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
-                                           return_intermediate_steps=True,
-                                           verbose=True)
-
-        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-        print("Reply: " + output["output_text"])
+        output = search(document_text, question)
+        print("Done.")
 
         return proto.document_query_pb2.QueryResponse(
-            text=output["output_text"]
+            text=output,
+            sources=sources
         )
 
     def Chunk(self,
@@ -135,5 +138,35 @@ def serve():
     server.wait_for_termination()
 
 
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+Answer the question using the documents below, and reply in Markdown. Refer to the user as "you". If you don't know the answer, say you don't know; do not make up an answer. Always reply in Chinese.
+
+QUESTION: {question}
+
+=== DOCUMENT START ===
+{summaries}
+=== DOCUMENT END ===
+
+FINAL ANSWER:
+"""
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    print(prompt)
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+    return res["content"]
+
+
 if __name__ == '__main__':
     serve()
diff --git a/document_ai/server.py.bak b/document_ai/server.py.bak
new file mode 100644
index 0000000..8e5f514
--- /dev/null
+++ b/document_ai/server.py.bak
@@ -0,0 +1,166 @@
+import os
+from concurrent import futures
+
+import langchain
+import openai
+
+import proto.document_query_pb2
+import proto.document_query_pb2_grpc
+import grpc
+import proto.documents_pb2
+import init
+import doc_client
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.cache import InMemoryCache
+
+langchain.llm_cache = InMemoryCache()
+
+CHUNK_SIZE = 500
+
+
+class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
+    def Query(self, target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        print("New request: " + target.question)
+        vec = init.text_to_vector(target.question)
+
+        question = "Reply in spoken language:" + target.question
+
+        search_param = {
+            "data": [vec],
+            "anns_field": "vector",
+            "param": {"metric_type": "L2"},
+            "limit": 5,
+            "expr": "user_id == " + str(target.user_id),
+            "output_fields": ["document_id", "user_id"],
+        }
+
+        res = init.collection.search(**search_param)
+
+        # # at most 5
+        # if len(res[0]) > 5:
+        #     res[0] = res[0][:5]
+
+
+        # document_chunk_ids = []
+        real_document = []
+
+        for i in range(len(res[0])):
+            _chunk_id = res[0][i].id
+
+            print("Fetching content of chunk " + str(_chunk_id) + "...")
+
+            try:
+                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+                    id=_chunk_id
+                ))
+
+                _doc_content_full = _chunk_content.content
+
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+
+                real_document.append(doc_obj)
+
+            except Exception as e:
+                print(e)
+
+        print(real_document)
+
+        print("Calling the LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+        print("Reply: " + output["output_text"])
+
+        return proto.document_query_pb2.QueryResponse(
+            text=output["output_text"]
+        )
+
+    def Chunk(self,
+              target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=CHUNK_SIZE,
+            chunk_overlap=20,
+            length_function=len,
+            add_start_index=True,
+        )
+
+        page_contents = text_splitter.create_documents([
+            target.text
+        ])
+
+        texts = []
+
+        for page_content in page_contents:
+            texts.append(page_content.page_content)
+
+        return proto.document_query_pb2.ChunkResponse(
+            texts=texts
+        )
+
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS respond in spoken language.
+
+QUESTION: {question}
+=========
+{summaries}
+=========
+FINAL ANSWER:
+"""
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+
+
+if __name__ == '__main__':
+    serve()
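
A minimal usage sketch of the new search() helper. It assumes the
openai 0.x SDK that this patch targets, takes the API key from the
OPENAI_API_KEY environment variable rather than a hardcoded value, and
uses an illustrative document snippet and question:

    import os

    import openai

    # The 0.x SDK also reads OPENAI_API_KEY on import; setting it here
    # just makes the dependency explicit.
    openai.api_key = os.environ["OPENAI_API_KEY"]

    from server import search  # assumes document_ai/ is on sys.path

    # In Query() this text is assembled from the Milvus chunk lookup.
    document_text = "\nRefunds are processed within 7 days.\n"
    print(search(document_text, "How long do refunds take?"))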
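
For reference, a sketch of how the updated Milvus search parameters
compose the combined filter; the ids and the embedding below are
placeholders, while the field names mirror the collection schema used
in the diff above:

    # Placeholder embedding; server.py builds it with init.text_to_vector().
    vec = [0.0] * 768
    user_id, library_id = 42, 7

    search_param = {
        "data": [vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        # The new expression scopes results to one user's library.
        "expr": f"user_id == {user_id} && library_id == {library_id}",
        "output_fields": ["document_id", "user_id", "library_id"],
    }
    # res = init.collection.search(**search_param)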
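
The Chunk RPC (unchanged by this patch, visible in the server.py.bak
snapshot) splits text with LangChain's RecursiveCharacterTextSplitter;
a standalone sketch with an illustrative input:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    CHUNK_SIZE = 500  # mirrors the constant in server.py

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=20,
        length_function=len,
        add_start_index=True,
    )

    docs = text_splitter.create_documents(["Some long document text. " * 100])
    texts = [doc.page_content for doc in docs]
    print(len(texts), texts[0][:40])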