Improve GPT processing and replies

This commit is contained in:
parent 3144c49ee1
commit 8b6d718019
@@ -1,4 +1,4 @@
-from threading import Thread
+from threading import Thread, Event
 
 import vector
 import server
@@ -66,6 +66,7 @@ if len(res[0]) > 5:
 # document_chunk_ids = []
 # real_document = []
 plain_text = ""
+document_chunks = []
 
 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
@@ -80,6 +81,7 @@ for i in range(len(res[0])):
 
     # _doc_content_full = _chunk_content.content
     # print("DOC OBJ:" + _doc_content_full)
+    document_chunks.append(_chunk_content.content)
    plain_text += "=== \n" + _chunk_content.content + " ===\n"
 
    # real_document.append(_doc_content)
@@ -116,7 +118,7 @@ messages = [
     {
         "role": "system",
         "content": f"""
-Context: {plain_text}
+{plain_text}
 """
     },
     {
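Taken together, the hunks above switch the script to accumulating raw chunk text instead of Document objects: each hit's content is appended to document_chunks and to the plain_text blob that now goes into the system message verbatim (the literal "Context:" label is dropped). A minimal sketch of the resulting flow; fetch_chunk is a hypothetical stand-in for the GetDocumentChunk call made elsewhere in the script:

    # Sketch only: mirrors how the new hunks assemble the system prompt.
    plain_text = ""
    document_chunks = []

    for hit in res[0]:                    # res[0]: hits for the first query vector
        content = fetch_chunk(hit.id)     # hypothetical helper for the chunk-fetch RPC
        document_chunks.append(content)
        plain_text += "=== \n" + content + " ===\n"

    messages = [
        {"role": "system", "content": f"\n{plain_text}\n"},
        # ... user message with the actual question follows, as in the code above
    ]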
document_ai/server.py
@@ -2,6 +2,7 @@ import os
 from concurrent import futures
 
 import langchain
 import openai
 
 import proto.document_query_pb2
 import proto.document_query_pb2_grpc
@@ -9,10 +10,10 @@ import grpc
 import proto.documents_pb2
 import init
 import doc_client
-from langchain.llms.openai import OpenAI
-from langchain.schema.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+# from langchain.llms.openai import OpenAI
+# from langchain.schema.document import Document
+# from langchain.embeddings import OpenAIEmbeddings
+# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.cache import InMemoryCache
 
@@ -20,6 +21,8 @@ langchain.llm_cache = InMemoryCache()
 
 CHUNK_SIZE = 500
 
+# openai.api_base = "https://api.openai.com/v1"
+# openai.api_key="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"
 
 class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
     def Query(self, target,
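The two added comment lines hardcode an API endpoint and key (a key committed like this should be treated as leaked and rotated). A common alternative, not part of this commit, is to read both from the environment, which the pre-1.0 openai SDK used here supports:

    import os
    import openai

    # Assumption: credentials supplied via environment, not source control.
    openai.api_key = os.getenv("OPENAI_API_KEY")
    openai.api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")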
@@ -32,30 +35,25 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
               timeout=None,
               metadata=None):
 
 
         print("新的请求:" + target.question)
         vec = init.text_to_vector(target.question)
 
-        question = "Reply in spoken language:" + target.question
+        question = target.question
 
         search_param = {
             "data": [vec],
             "anns_field": "vector",
             "param": {"metric_type": "L2"},
             "limit": 5,
-            "expr": "user_id == " + str(target.user_id),
-            "output_fields": ["document_id", "user_id"],
+            "expr": "user_id == " + str(target.user_id) + " && library_id == " + str(target.library_id),
+            "output_fields": ["document_id", "user_id", "library_id"],
         }
 
         res = init.collection.search(**search_param)
 
         # # 最多 5 个
         # if len(res[0]) > 5:
         #     res[0] = res[0][:5]
 
 
         # document_chunk_ids = []
-        real_document = []
+        document_text = ""
         # real_document = []
+        sources = []
 
         for i in range(len(res[0])):
             _chunk_id = res[0][i].id
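The reworked expr narrows hits to one user and one library in a single boolean filter. For reference, the equivalent call written directly against a pymilvus Collection would look roughly like this (a sketch; collection, vec, user_id, and library_id are assumed to be in scope):

    # Sketch: Milvus vector search with a scalar boolean filter.
    expr = f"user_id == {user_id} && library_id == {library_id}"
    hits = collection.search(
        data=[vec],                      # query vector(s)
        anns_field="vector",             # vector field to search
        param={"metric_type": "L2"},     # distance metric
        limit=5,
        expr=expr,                       # scalar filter applied alongside the ANN search
        output_fields=["document_id", "user_id", "library_id"],
    )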
@@ -66,27 +64,32 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
                     id=_chunk_id
                 ))
 
-                _doc_content_full = _chunk_content.content
+                print(_chunk_content.document)
 
-                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+                # _doc_content_full = _chunk_content.content
+                document_text += "\n" + _chunk_content.content + "\n"
 
-                real_document.append(doc_obj)
+                # append
+                sources.append({
+                    "text": _chunk_content.content,
+                    "document_id": _chunk_content.document.id
+                })
 
+                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
 
+                # real_document.append(doc_obj)
 
             except Exception as e:
                 print(e)
 
         print(real_document)
 
         print("正在调用 LLM...")
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
-                                           return_intermediate_steps=True,
-                                           verbose=True)
-
-        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-        print("回复:" + output["output_text"])
+        output = search(document_text, question)
 
+        print("完成。")
         return proto.document_query_pb2.QueryResponse(
-            text=output["output_text"]
+            text=output,
+            sources=sources
         )
 
     def Chunk(self,
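This hunk drops the LangChain map_reduce chain in favor of the new search() helper (defined below) and extends the response with the matched chunks. Note that print(real_document) appears to survive as context even though real_document = [] is commented out above it, which would raise a NameError when Query runs. A client-side sketch of consuming the new response shape; the request message and field names are assumptions, since document_query.proto is not shown in this diff:

    # Hypothetical client call; QueryRequest is an assumed message name.
    response = stub.Query(proto.document_query_pb2.QueryRequest(
        question="……", user_id=1, library_id=1,
    ))
    print(response.text)                  # the LLM answer
    for source in response.sources:       # chunks the answer was grounded on
        print(source.document_id, source.text[:80])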
@@ -135,5 +138,35 @@ def serve():
     server.wait_for_termination()
 
 
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+    使用以下文档回答问题,使用 Markdown 回答。你得用“你”的身份指代用户。如果你不知道答案,你可以说你不知道,不要编造答案。总是使用中文回复。
+
+    QUESTION: {question}
+
+    ===文档开始===
+    {summaries}
+    ===文档结束===
+
+    FINAL ANSWER:
+    """
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    print(prompt)
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+    return res["content"]
+
+
 if __name__ == '__main__':
     serve()
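The new search() helper is a single-turn ChatCompletion call with temperature=0 for reproducible answers; the prompt forces Chinese, Markdown-formatted replies that address the user as “你”. It can be exercised on its own, assuming OPENAI_API_KEY is configured for the SDK:

    # Standalone sketch of the helper's contract; the passage is made up.
    answer = search(
        summaries="=== \nMilvus 是一个开源向量数据库。\n ===\n",
        question="Milvus 是什么?",
    )
    print(answer)  # expected: a short Chinese answer grounded in the passage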
document_ai/server.py.bak (new file, 166 lines)
@@ -0,0 +1,166 @@
+import os
+from concurrent import futures
+
+import langchain
+import openai
+
+import proto.document_query_pb2
+import proto.document_query_pb2_grpc
+import grpc
+import proto.documents_pb2
+import init
+import doc_client
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.cache import InMemoryCache
+
+langchain.llm_cache = InMemoryCache()
+
+CHUNK_SIZE = 500
+
+
+class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
+    def Query(self, target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+
+        print("新的请求:" + target.question)
+        vec = init.text_to_vector(target.question)
+
+        question = "Reply in spoken language:" + target.question
+
+        search_param = {
+            "data": [vec],
+            "anns_field": "vector",
+            "param": {"metric_type": "L2"},
+            "limit": 5,
+            "expr": "user_id == " + str(target.user_id),
+            "output_fields": ["document_id", "user_id"],
+        }
+
+        res = init.collection.search(**search_param)
+
+        # # 最多 5 个
+        # if len(res[0]) > 5:
+        #     res[0] = res[0][:5]
+
+
+        # document_chunk_ids = []
+        real_document = []
+
+        for i in range(len(res[0])):
+            _chunk_id = res[0][i].id
+            print("正在获取分块 " + str(_chunk_id) + " 的内容...")
+
+            try:
+                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+                    id=_chunk_id
+                ))
+
+                _doc_content_full = _chunk_content.content
+
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+
+                real_document.append(doc_obj)
+
+            except Exception as e:
+                print(e)
+
+        print(real_document)
+
+        print("正在调用 LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+        print("回复:" + output["output_text"])
+
+        return proto.document_query_pb2.QueryResponse(
+            text=output["output_text"]
+        )
+
+    def Chunk(self,
+              target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=CHUNK_SIZE,
+            chunk_overlap=20,
+            length_function=len,
+            add_start_index=True,
+        )
+
+        page_contents = text_splitter.create_documents([
+            target.text
+        ])
+
+        texts = []
+
+        for page_content in page_contents:
+            texts.append(page_content.page_content)
+
+        return proto.document_query_pb2.ChunkResponse(
+            texts=texts
+        )
+
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+
+def search(summaries: str, question: str) -> str:
+    prompt = f"""
+    Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES").
+    If you don't know the answer, just say that you don't know. Don't try to make up an answer. ALWAYS response with spoken language.
+
+    QUESTION: {question}
+    =========
+    {summaries}
+    =========
+    FINAL ANSWER:
+    """
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    result = openai.ChatCompletion.create(
+        messages=messages, model="gpt-3.5-turbo", temperature=0
+    )
+    res = result["choices"][0]["message"].to_dict_recursive()
+    print(res)
+
+
+if __name__ == '__main__':
+    serve()
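The backup file preserves the previous LangChain pipeline end to end, including the Chunk RPC built on RecursiveCharacterTextSplitter (note its search() variant prints the completion but never returns it). The splitting parameters are easy to try in isolation; a minimal sketch with the same settings:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,        # CHUNK_SIZE in the file
        chunk_overlap=20,      # small overlap so sentences are not cut cold
        length_function=len,
        add_start_index=True,  # records each chunk's offset in metadata
    )
    docs = splitter.create_documents(["……long document text……"])
    texts = [d.page_content for d in docs]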