2023-11-15 08:20:30 +00:00
|
|
|
import os
|
|
|
|
from concurrent import futures
|
2023-11-15 14:24:03 +00:00
|
|
|
|
|
|
|
import langchain
|
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
2023-11-15 08:20:30 +00:00
|
|
|
import document_query_pb2
|
|
|
|
import document_query_pb2_grpc
|
|
|
|
import grpc
|
|
|
|
import documents_pb2
|
|
|
|
import init
|
|
|
|
import doc_client
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
from langchain.schema.document import Document
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
|
|
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
|
|
|
2023-11-15 14:24:03 +00:00
|
|
|
from langchain.cache import InMemoryCache
|
|
|
|
|
|
|
|
# Install a process-wide in-memory LangChain LLM cache so repeated identical
# LLM calls made through LangChain can be served from memory.
# NOTE(review): entries live for the lifetime of the process — confirm the
# memory growth is acceptable for a long-running server.
_llm_cache = InMemoryCache()
langchain.llm_cache = _llm_cache
|
|
|
|
|
2023-11-15 08:20:30 +00:00
|
|
|
|
|
|
|
class AIServer(document_query_pb2_grpc.DocumentQuery):
    """gRPC servicer that answers questions over a user's own documents.

    Each ``Query`` call embeds the question, runs a vector search filtered to
    the requesting user, fetches the matching documents from the documents
    service, splits them into chunks, and answers with a LangChain map-reduce
    QA-with-sources chain backed by the OpenAI completion LLM.
    """

    # Number of nearest-neighbour hits requested from the vector search.
    _SEARCH_LIMIT = 10

    # Chunk size for RecursiveCharacterTextSplitter (measured in characters by
    # default, not tokens).
    _CHUNK_SIZE = 1500

    # Completion budget per LLM call. OpenAI completion models count
    # prompt + completion against a single context window (about 4k tokens
    # for LangChain's default completion model), so this must leave room for
    # the prompt. The previous value of 4097 made every request overflow.
    # -1 ("fill the remaining window") is not usable here: map_reduce sends
    # several prompts per call and LangChain rejects -1 for batched prompts.
    _MAX_ANSWER_TOKENS = 512

    # Instruction appended to the question so the model replies in Chinese.
    _REPLY_IN_CHINESE = "(必须使用中文回复)"

    # Returned without calling the LLM when no document content was retrieved.
    _NO_DOCUMENTS_REPLY = "未找到相关文档，无法回答该问题。"

    def Query(self, request, context):
        """Answer ``request.question`` using the requesting user's documents.

        Args:
            request: Incoming request message; this method reads its
                ``question`` and ``user_id`` fields.
            context: gRPC servicer context (unused).

        Returns:
            document_query_pb2.QueryResponse whose ``text`` is the LLM answer,
            or a fixed "no documents found" reply when retrieval yields nothing.
        """
        question = request.question + self._REPLY_IN_CHINESE

        doc_ids = self._search_document_ids(request)
        documents = self._fetch_documents(doc_ids)

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=self._CHUNK_SIZE, chunk_overlap=0)
        all_splits = text_splitter.split_documents(documents)
        # Log only the chunk count: printing the chunks would dump the full
        # text of the user's documents to stdout on every request.
        print("文档长度: ", len(all_splits))

        # Without any context the QA chain would produce an ungrounded answer.
        if not all_splits:
            return document_query_pb2.QueryResponse(text=self._NO_DOCUMENTS_REPLY)

        answer = self._answer(question, all_splits)
        return document_query_pb2.QueryResponse(text=answer)

    def _search_document_ids(self, request):
        """Vector-search the user's entries and return hit ids, best match first, without duplicates."""
        vec = init.text_to_vector(request.question)
        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": self._SEARCH_LIMIT,
            # NOTE(review): filter is built by string concatenation; this is only
            # safe if user_id is a numeric proto field — confirm against the .proto.
            "expr": "user_id == " + str(request.user_id),
            "output_fields": ["document_id", "user_id"],
        }
        res = init.collection.search(**search_param)
        # NOTE(review): hit.id is the vector entity's primary key, while a
        # separate "document_id" field is requested above. Confirm the primary
        # key really is the document id expected by GetDocumentById.
        # Deduplicate (keeping rank order) so each document is fetched and sent
        # to the LLM at most once, even if the search repeats an id.
        return list(dict.fromkeys(hit.id for hit in res[0]))

    @staticmethod
    def _fetch_documents(doc_ids):
        """Fetch each id from the documents service as a LangChain ``Document``.

        Best effort: a failed fetch is reported and skipped, so one unavailable
        document does not fail the whole query.
        """
        documents = []
        for doc_id in doc_ids:
            print("正在获取 " + str(doc_id) + " 的内容...")
            try:
                doc = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(id=doc_id))
            except Exception as e:
                print(f"获取 {doc_id} 的内容失败: {e}")
                continue
            documents.append(
                Document(page_content=doc.title + "\n" + doc.content, metadata={"source": doc.title})
            )
        return documents

    def _answer(self, question, splits):
        """Run the map-reduce QA-with-sources chain over ``splits`` and return the answer text."""
        print("正在调用 LLM: " + question + "...")
        chain = load_qa_with_sources_chain(
            OpenAI(temperature=0, max_tokens=self._MAX_ANSWER_TOKENS),
            chain_type="map_reduce",
            return_intermediate_steps=False,
            verbose=False,
        )
        output = chain({"input_documents": splits, "question": question}, return_only_outputs=False)
        answer = output["output_text"]
        print("回复:" + answer)
        return answer
|
|
|
|
|
|
|
|
|
|
|
|
def serve():
    """Run the DocumentQuery gRPC server and block until it terminates.

    The bind address comes from the ``BIND`` environment variable. When the
    variable is unset, it falls back to ``[::]:50051`` (the IPv6 wildcard
    address, port 50051). Requests are handled on a pool of 10 worker threads
    over an insecure (plaintext) port.
    """
    bind_address = os.getenv("BIND", "[::]:50051")
    print("Listening on", bind_address)

    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)
    document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), grpc_server)
    grpc_server.add_insecure_port(bind_address)

    grpc_server.start()
    grpc_server.wait_for_termination()
|