Improve GPT processing and replies
parent 3144c49ee1
commit 8b6d718019
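Summary: this commit swaps the LangChain `load_qa_with_sources_chain` pipeline for a direct `openai.ChatCompletion` call in a new `search()` helper, scopes the Milvus search to a single library via a `library_id` filter, returns the matched chunks as `sources` alongside the answer text in `QueryResponse`, and preserves the previous implementation as `document_ai/server.py.bak`.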
@@ -1,4 +1,4 @@
-from threading import Thread
+from threading import Thread, Event
 
 import vector
 import server
@@ -66,6 +66,7 @@ if len(res[0]) > 5:
 # document_chunk_ids = []
 # real_document = []
 plain_text = ""
+document_chunks = []
 
 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
@@ -80,6 +81,7 @@ for i in range(len(res[0])):
 
     # _doc_content_full = _chunk_content.content
     # print("DOC OBJ:" + _doc_content_full)
+    document_chunks.append(_chunk_content.content)
     plain_text += "=== \n" + _chunk_content.content + " ===\n"
 
     # real_document.append(_doc_content)
@@ -116,7 +118,7 @@ messages = [
     {
         "role": "system",
         "content": f"""
-Context: {plain_text}
+{plain_text}
 """
     },
     {
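The hunks above collect the retrieved chunks into `document_chunks` and drop the `Context:` prefix, so the system message now carries the raw delimited chunk text. A minimal sketch of that pattern, with a hypothetical `retrieved_chunks` list standing in for the Milvus hits:

# Sketch only: `retrieved_chunks` and the user question are illustrative, not from this repo.
retrieved_chunks = ["first chunk text", "second chunk text"]

plain_text = ""
document_chunks = []
for chunk in retrieved_chunks:
    document_chunks.append(chunk)              # kept for later use
    plain_text += "=== \n" + chunk + " ===\n"  # delimit chunks for the model

messages = [
    {"role": "system", "content": f"""
{plain_text}
"""},
    {"role": "user", "content": "the user's question"},
]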
document_ai/server.py
@@ -2,6 +2,7 @@ import os
 from concurrent import futures
 
 import langchain
+import openai
 
 import proto.document_query_pb2
 import proto.document_query_pb2_grpc
@@ -9,10 +10,10 @@ import grpc
 import proto.documents_pb2
 import init
 import doc_client
-from langchain.llms.openai import OpenAI
-from langchain.schema.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# from langchain.llms.openai import OpenAI
+# from langchain.schema.document import Document
+# from langchain.embeddings import OpenAIEmbeddings
+# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.cache import InMemoryCache
 
@@ -20,6 +21,8 @@ langchain.llm_cache = InMemoryCache()
 
 CHUNK_SIZE = 500
 
+# openai.api_base = "https://api.openai.com/v1"
+# openai.api_key="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"
 
 class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
     def Query(self, target,
@@ -32,30 +35,25 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
               timeout=None,
               metadata=None):
 
-
         print("新的请求:" + target.question)
         vec = init.text_to_vector(target.question)
 
-        question = "Reply in spoken language:" + target.question
+        question = target.question
 
         search_param = {
             "data": [vec],
             "anns_field": "vector",
             "param": {"metric_type": "L2"},
             "limit": 5,
-            "expr": "user_id == " + str(target.user_id),
-            "output_fields": ["document_id", "user_id"],
+            "expr": "user_id == " + str(target.user_id) + " && library_id == " + str(target.library_id),
+            "output_fields": ["document_id", "user_id", "library_id"],
         }
 
         res = init.collection.search(**search_param)
 
-        # # 最多 5 个
-        # if len(res[0]) > 5:
-        #     res[0] = res[0][:5]
-
-
-        # document_chunk_ids = []
-        real_document = []
+        document_text = ""
+        # real_document = []
+        sources = []
 
         for i in range(len(res[0])):
             _chunk_id = res[0][i].id
@@ -66,27 +64,32 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
                     id=_chunk_id
                 ))
 
-                _doc_content_full = _chunk_content.content
+                print(_chunk_content.document)
 
-                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+                # _doc_content_full = _chunk_content.content
+                document_text += "\n" + _chunk_content.content + "\n"
 
-                real_document.append(doc_obj)
+                # append
+                sources.append({
+                    "text": _chunk_content.content,
+                    "document_id": _chunk_content.document.id
+                })
+
+                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+
+                # real_document.append(doc_obj)
 
             except Exception as e:
                 print(e)
 
-        print(real_document)
-
         print("正在调用 LLM...")
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
-                                           return_intermediate_steps=True,
-                                           verbose=True)
 
-        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-        print("回复:" + output["output_text"])
+        output = search(document_text, question)
 
+        print("完成。")
         return proto.document_query_pb2.QueryResponse(
-            text=output["output_text"]
+            text=output,
+            sources=sources
         )
 
     def Chunk(self,
@@ -135,5 +138,35 @@ def serve():
     server.wait_for_termination()
 
 
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+使用以下文档回答问题,使用Markdown回答你得用“你”的身份指代用户。如果你不知道答案,你可以说你不知道,不要编造答案。总是使用中文回复。
+
+QUESTION: {question}
+
+===文档开始===
+{summaries}
+===文档结束===
+
+FINAL ANSWER:
+"""
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    print(prompt)
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+    return res["content"]
+
 
 if __name__ == '__main__':
     serve()
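The net effect of the `Query` changes is plain retrieval followed by a single chat completion. A self-contained sketch of that flow, assuming the pre-1.0 `openai` Python client that `openai.ChatCompletion.create` comes from (the chunk strings and question are illustrative, not from this repo):

import openai  # pre-1.0 client, matching openai.ChatCompletion.create above

def search(summaries: str, question: str) -> str:
    # One user message carrying the whole prompt; temperature 0 for stable answers.
    prompt = f"QUESTION: {question}\n\n===\n{summaries}\n===\n\nFINAL ANSWER:"
    result = openai.ChatCompletion.create(
        messages=[{"role": "user", "content": prompt}],
        model="gpt-3.5-turbo",
        temperature=0,
    )
    return result["choices"][0]["message"]["content"]

# Hypothetical chunks standing in for the Milvus results:
document_text = "\nchunk one\n" + "\nchunk two\n"
print(search(document_text, "What do these documents cover?"))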
document_ai/server.py.bak (new file, 166 lines)
@@ -0,0 +1,166 @@
+import os
+from concurrent import futures
+
+import langchain
+import openai
+
+import proto.document_query_pb2
+import proto.document_query_pb2_grpc
+import grpc
+import proto.documents_pb2
+import init
+import doc_client
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.cache import InMemoryCache
+
+langchain.llm_cache = InMemoryCache()
+
+CHUNK_SIZE = 500
+
+
+class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
+    def Query(self, target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+
+        print("新的请求:" + target.question)
+        vec = init.text_to_vector(target.question)
+
+        question = "Reply in spoken language:" + target.question
+
+        search_param = {
+            "data": [vec],
+            "anns_field": "vector",
+            "param": {"metric_type": "L2"},
+            "limit": 5,
+            "expr": "user_id == " + str(target.user_id),
+            "output_fields": ["document_id", "user_id"],
+        }
+
+        res = init.collection.search(**search_param)
+
+        # # 最多 5 个
+        # if len(res[0]) > 5:
+        #     res[0] = res[0][:5]
+
+
+        # document_chunk_ids = []
+        real_document = []
+
+        for i in range(len(res[0])):
+            _chunk_id = res[0][i].id
+            print("正在获取分块 " + str(_chunk_id) + " 的内容...")
+
+            try:
+                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+                    id=_chunk_id
+                ))
+
+                _doc_content_full = _chunk_content.content
+
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+
+                real_document.append(doc_obj)
+
+            except Exception as e:
+                print(e)
+
+        print(real_document)
+
+        print("正在调用 LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+        print("回复:" + output["output_text"])
+
+        return proto.document_query_pb2.QueryResponse(
+            text=output["output_text"]
+        )
+
+    def Chunk(self,
+              target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=CHUNK_SIZE,
+            chunk_overlap=20,
+            length_function=len,
+            add_start_index=True,
+        )
+
+        page_contents = text_splitter.create_documents([
+            target.text
+        ])
+
+        texts = []
+
+        for page_content in page_contents:
+            texts.append(page_content.page_content)
+
+        return proto.document_query_pb2.ChunkResponse(
+            texts=texts
+        )
+
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS response with spoken language.
+
+QUESTION: {question}
+=========
+{summaries}
+=========
+FINAL ANSWER:
+"""
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+
+
+if __name__ == '__main__':
+    serve()
Loading…
Reference in New Issue
Block a user