langchain-chat-with-milvus/document_ai/server.py

import os
from concurrent import futures
import langchain
import proto.document_query_pb2
import proto.document_query_pb2_grpc
import grpc
import proto.documents_pb2
import init
import doc_client
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.cache import InMemoryCache
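
# Cache completions in process memory so a repeated, identical prompt is
# served from the cache instead of a second OpenAI API call.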
langchain.llm_cache = InMemoryCache()
CHUNK_SIZE = 500


class AIServer(proto.document_query_pb2_grpc.DocumentQueryServicer):
    def Query(self, request, context):
        print("New request: " + request.question)

        # Embed the question and prefix it with the answering instruction.
        vec = init.text_to_vector(request.question)
        question = "Reply in spoken language:" + request.question
        # Search Milvus for this user's nearest chunks (L2 distance, top 5).
        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": 5,
            "expr": "user_id == " + str(request.user_id),
            "output_fields": ["document_id", "user_id"],
        }
        res = init.collection.search(**search_param)
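        # `search` returns one hit list per query vector; res[0] holds the hits
        # for the single question above, each exposing the matched chunk's
        # primary key as `.id`.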

        # # at most 5
        # if len(res[0]) > 5:
        #     res[0] = res[0][:5]
        # document_chunk_ids = []

        # Fetch the full content of every matched chunk from the document service.
        real_document = []
        for i in range(len(res[0])):
            _chunk_id = res[0][i].id
            print("Fetching content of chunk " + str(_chunk_id) + "...")

            try:
                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                    id=_chunk_id
                ))
                _doc_content_full = _chunk_content.content

                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
                real_document.append(doc_obj)
            except Exception as e:
                print(e)

        print(real_document)
print("正在调用 LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
return_intermediate_steps=True,
verbose=True)
output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
print("回复:" + output["output_text"])
return proto.document_query_pb2.QueryResponse(
text=output["output_text"]
)
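
    # Note: chain_type="map_reduce" queries the LLM once per retrieved chunk
    # (map) and once more to combine the partial answers (reduce), so a single
    # request can cost up to len(real_document) + 1 completions; the in-memory
    # cache above only saves repeat calls with identical prompts.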
    def Chunk(self, request, context):
        # Split the raw text into overlapping chunks sized for embedding.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=20,
            length_function=len,
            add_start_index=True,
        )
        page_contents = text_splitter.create_documents([
            request.text
        ])

        texts = []
        for page_content in page_contents:
            texts.append(page_content.page_content)

        return proto.document_query_pb2.ChunkResponse(
            texts=texts
        )
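    # With CHUNK_SIZE=500 and chunk_overlap=20, consecutive chunks share up to
    # 20 characters of context, and add_start_index=True records each chunk's
    # offset into the original text in its metadata.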

def serve():
    # BIND overrides the listen address; default to all interfaces, port 50051.
    _ADDR = os.getenv("BIND", "[::]:50051")
    print("Listening on", _ADDR)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()
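

# A minimal client sketch for manual testing. The request message name
# QueryRequest is an assumption about proto/document_query.proto (only the
# `question`, `user_id`, and `text` fields are confirmed by the code above);
# DocumentQueryStub follows the standard gRPC naming for this service.
#
#     import grpc
#     import proto.document_query_pb2
#     import proto.document_query_pb2_grpc
#
#     channel = grpc.insecure_channel("localhost:50051")
#     stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
#     reply = stub.Query(proto.document_query_pb2.QueryRequest(
#         user_id=1, question="What do my documents say about invoices?"))
#     print(reply.text)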