Improve GPT handling and replies

iVampireSP.com 2023-11-22 10:13:58 +08:00
parent 3144c49ee1
commit 8b6d718019
4 changed files with 229 additions and 28 deletions


@@ -1,4 +1,4 @@
-from threading import Thread
+from threading import Thread, Event
 import vector
 import server
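Only the import changes in this file; where Event is actually used is not part of the diff. For orientation, a minimal sketch of the usual threading.Event pattern this import enables, cooperative shutdown of a worker thread (the stop_event name and the loop body are illustrative, not from this commit):

from threading import Thread, Event

stop_event = Event()  # hypothetical shared shutdown flag

def worker():
    # wait() sleeps up to 1s, but returns True immediately once set() is called
    while not stop_event.wait(timeout=1.0):
        print("working...")

t = Thread(target=worker)
t.start()
stop_event.set()  # request shutdown
t.join()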


@@ -66,6 +66,7 @@ if len(res[0]) > 5:
 # document_chunk_ids = []
 # real_document = []
 plain_text = ""
+document_chunks = []

 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
@@ -80,6 +81,7 @@ for i in range(len(res[0])):
     # _doc_content_full = _chunk_content.content
     # print("DOC OBJ:" + _doc_content_full)
+    document_chunks.append(_chunk_content.content)
     plain_text += "=== \n" + _chunk_content.content + " ===\n"

     # real_document.append(_doc_content)
@@ -116,7 +118,7 @@ messages = [
     {
         "role": "system",
         "content": f"""
-Context: {plain_text}
+{plain_text}
         """
     },
     {
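The system prompt now injects the retrieved chunks bare, without the "Context:" label, and the chunks are additionally kept in document_chunks. A minimal sketch of what the assembled request plausibly looks like after this change (the user message and the ChatCompletion call are assumptions based on the surrounding code, not shown in this diff; the pre-1.0 openai client reads OPENAI_API_KEY from the environment):

import openai

plain_text = "=== \nfirst chunk ===\n=== \nsecond chunk ===\n"  # built by the loop above
question = "What do the documents say?"  # placeholder user question

messages = [
    {"role": "system", "content": f"""
{plain_text}
"""},
    {"role": "user", "content": question},
]

resp = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, temperature=0)
print(resp["choices"][0]["message"]["content"])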

document_ai/server.py

@@ -2,6 +2,7 @@ import os
 from concurrent import futures
 import langchain
+import openai
 import proto.document_query_pb2
 import proto.document_query_pb2_grpc
@@ -9,10 +10,10 @@ import grpc
 import proto.documents_pb2
 import init
 import doc_client
-from langchain.llms.openai import OpenAI
-from langchain.schema.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# from langchain.llms.openai import OpenAI
+# from langchain.schema.document import Document
+# from langchain.embeddings import OpenAIEmbeddings
+# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.cache import InMemoryCache
@@ -20,6 +21,8 @@ langchain.llm_cache = InMemoryCache()
 CHUNK_SIZE = 500
+# openai.api_base = "https://api.openai.com/v1"
+# openai.api_key="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"

 class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
     def Query(self, target,
@ -32,30 +35,25 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
timeout=None, timeout=None,
metadata=None): metadata=None):
print("新的请求:" + target.question) print("新的请求:" + target.question)
vec = init.text_to_vector(target.question) vec = init.text_to_vector(target.question)
question = "Reply in spoken language:" + target.question question = target.question
search_param = { search_param = {
"data": [vec], "data": [vec],
"anns_field": "vector", "anns_field": "vector",
"param": {"metric_type": "L2"}, "param": {"metric_type": "L2"},
"limit": 5, "limit": 5,
"expr": "user_id == " + str(target.user_id), "expr": "user_id == " + str(target.user_id) + " && library_id == " + str(target.library_id),
"output_fields": ["document_id", "user_id"], "output_fields": ["document_id", "user_id", "library_id"],
} }
res = init.collection.search(**search_param) res = init.collection.search(**search_param)
# # 最多 5 个 document_text = ""
# if len(res[0]) > 5: # real_document = []
# res[0] = res[0][:5] sources = []
# document_chunk_ids = []
real_document = []
for i in range(len(res[0])): for i in range(len(res[0])):
_chunk_id = res[0][i].id _chunk_id = res[0][i].id
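The substantive change in this hunk is the Milvus filter expression: hits are now scoped to a library as well as a user. A standalone sketch of the same pymilvus search (the connection parameters, collection name, ids, and query vector are placeholders):

from pymilvus import Collection, connections

connections.connect(host="localhost", port="19530")  # assumes a local Milvus
collection = Collection("document_chunks")  # hypothetical collection name
vec = [0.0] * 768  # placeholder embedding; dimension must match the schema

res = collection.search(
    data=[vec],
    anns_field="vector",
    param={"metric_type": "L2"},
    limit=5,
    # Milvus boolean expr; && combines the two scalar filters
    expr="user_id == 42 && library_id == 7",
    output_fields=["document_id", "user_id", "library_id"],
)
for hit in res[0]:
    print(hit.id, hit.distance)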
@@ -66,27 +64,32 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
                     id=_chunk_id
                 ))
-                _doc_content_full = _chunk_content.content
-                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
-                real_document.append(doc_obj)
+                print(_chunk_content.document)
+                # _doc_content_full = _chunk_content.content
+                document_text += "\n" + _chunk_content.content + "\n"
+                # append
+                sources.append({
+                    "text": _chunk_content.content,
+                    "document_id": _chunk_content.document.id
+                })
+                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+                # real_document.append(doc_obj)
             except Exception as e:
                 print(e)

-        print(real_document)
         print("Calling LLM...")
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
-                                           return_intermediate_steps=True,
-                                           verbose=True)
-        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-        print("Reply: " + output["output_text"])
+        output = search(document_text, question)
+        print("Done.")

         return proto.document_query_pb2.QueryResponse(
-            text=output["output_text"]
+            text=output,
+            sources=sources
         )

     def Chunk(self,
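QueryResponse now carries the matched chunks alongside the answer, so callers can render citations. A hypothetical client call for illustration (the QueryRequest message name and its field names are assumed from how the handler reads target; only the service name and the response fields appear in this diff):

import grpc
import proto.document_query_pb2
import proto.document_query_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")
stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)

resp = stub.Query(proto.document_query_pb2.QueryRequest(  # message name assumed
    question="How do I reset my password?",
    user_id=42,
    library_id=7,
))
print(resp.text)  # the LLM answer
for s in resp.sources:  # new field: the chunks the answer was grounded on
    print(s.document_id, s.text[:40])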
@@ -135,5 +138,35 @@ def serve():
     server.wait_for_termination()

+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+Answer the question using the documents below. Answer in Markdown, and address the user as "you". If you don't know the answer, say that you don't know; do not make up an answer. Always reply in Chinese.
+
+QUESTION: {question}
+
+===DOCUMENTS START===
+{summaries}
+===DOCUMENTS END===
+
+FINAL ANSWER:
+"""
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+    print(prompt)
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+
+    return res["content"]

 if __name__ == '__main__':
     serve()
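Net effect of this file's changes: the LangChain map_reduce QA chain (one LLM call per retrieved chunk plus a combine step) is replaced by a single stuffed-prompt ChatCompletion call, and the answer is returned together with its sources. The new helper can be exercised directly; a minimal sketch (the module name is inferred from server.py.bak, and importing it assumes its init and doc_client dependencies are reachable):

import os
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]  # assumes the key comes from the environment

import server  # the module patched above; name assumed

answer = server.search(
    summaries="\nMilvus is an open-source vector database.\n",
    question="What is Milvus?",
)
print(answer)  # plain string; Query() returns it as QueryResponse.text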

document_ai/server.py.bak (new file, 166 lines)

@@ -0,0 +1,166 @@
+import os
+from concurrent import futures
+import langchain
+import openai
+import proto.document_query_pb2
+import proto.document_query_pb2_grpc
+import grpc
+import proto.documents_pb2
+import init
+import doc_client
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.cache import InMemoryCache
+
+langchain.llm_cache = InMemoryCache()
+
+CHUNK_SIZE = 500
+
+class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
+    def Query(self, target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+        print("New request: " + target.question)
+        vec = init.text_to_vector(target.question)
+        question = "Reply in spoken language:" + target.question
+
+        search_param = {
+            "data": [vec],
+            "anns_field": "vector",
+            "param": {"metric_type": "L2"},
+            "limit": 5,
+            "expr": "user_id == " + str(target.user_id),
+            "output_fields": ["document_id", "user_id"],
+        }
+        res = init.collection.search(**search_param)
+
+        # # at most 5
+        # if len(res[0]) > 5:
+        #     res[0] = res[0][:5]
+        # document_chunk_ids = []
+        real_document = []
+
+        for i in range(len(res[0])):
+            _chunk_id = res[0][i].id
+            print("Fetching content of chunk " + str(_chunk_id) + "...")
+            try:
+                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+                    id=_chunk_id
+                ))
+                _doc_content_full = _chunk_content.content
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+                real_document.append(doc_obj)
+            except Exception as e:
+                print(e)
+
+        print(real_document)
+        print("Calling LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+        print("Reply: " + output["output_text"])
+
+        return proto.document_query_pb2.QueryResponse(
+            text=output["output_text"]
+        )
+
+    def Chunk(self,
+              target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=CHUNK_SIZE,
+            chunk_overlap=20,
+            length_function=len,
+            add_start_index=True,
+        )
+        page_contents = text_splitter.create_documents([
+            target.text
+        ])
+        texts = []
+        for page_content in page_contents:
+            texts.append(page_content.page_content)
+        return proto.document_query_pb2.ChunkResponse(
+            texts=texts
+        )
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS response with spoken language.
+
+QUESTION: {question}
+=========
+{summaries}
+=========
+FINAL ANSWER:
+"""
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+
+if __name__ == '__main__':
+    serve()