import os
from concurrent import futures
import langchain
from langchain.text_splitter import RecursiveCharacterTextSplitter
import document_query_pb2
import document_query_pb2_grpc
import grpc
import documents_pb2
import init
import doc_client
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.cache import InMemoryCache

# Install LangChain's in-memory LLM cache for the lifetime of this process.
langchain.llm_cache = InMemoryCache()


class AIServer(document_query_pb2_grpc.DocumentQueryServicer):
    """gRPC servicer that answers a question from the requesting user's documents.

    Subclasses the generated ``DocumentQueryServicer`` (the server-side base
    class) rather than ``DocumentQuery``, which is the generated client-side
    convenience class.
    """

    def Query(self, request, context):
        """Answer ``request.question`` using documents found by vector search.

        Steps: embed the question, search the vector collection restricted to
        ``request.user_id``, fetch each hit's document from the document
        service, split the documents into chunks, and run a map-reduce
        "QA with sources" chain over the chunks.

        Args:
            request: Incoming message; ``question`` and ``user_id`` are read.
            context: gRPC servicer context (unused).

        Returns:
            A ``document_query_pb2.QueryResponse`` whose ``text`` is the
            chain's ``output_text``.
        """
        question_vector = init.text_to_vector(request.question)
        # The literal suffix is part of the prompt: it tells the model that it
        # must reply in Chinese.
        question = request.question + "(必须使用中文回复)"

        search_results = init.collection.search(
            data=[question_vector],
            anns_field="vector",
            param={"metric_type": "L2"},
            limit=10,
            expr="user_id == " + str(request.user_id),
            output_fields=["document_id", "user_id"],
        )

        # One query vector was submitted, so only the first result set is used.
        documents = self._fetch_documents(search_results[0])

        splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=0)
        chunks = splitter.split_documents(documents)
        print("real_document: ", chunks)

        print("正在调用 LLM: " + question + "...")
        # NOTE(review): max_tokens=4097 looks like it reserves an entire
        # 4097-token context window for the completion, leaving no room for
        # the prompt -- confirm against the model actually configured.
        chain = load_qa_with_sources_chain(
            OpenAI(temperature=0, max_tokens=4097),
            chain_type="map_reduce",
            return_intermediate_steps=False,
            verbose=False,
        )
        output = chain(
            {"input_documents": chunks, "question": question},
            return_only_outputs=False,
        )
        print("回复:" + output["output_text"])

        return document_query_pb2.QueryResponse(text=output["output_text"])

    def _fetch_documents(self, hits):
        """Fetch the document behind each search hit as a LangChain ``Document``.

        Best effort: a hit whose document cannot be fetched is reported on
        stdout and skipped, so one failure does not fail the whole query.
        """
        documents = []
        for hit in hits:
            # NOTE(review): the hit's primary key is used as the document id
            # even though "document_id" is requested as an output field --
            # confirm the primary key really is the document id.
            doc_id = hit.id
            print("正在获取 " + str(doc_id) + " 的内容...")
            try:
                fetched = doc_client.stub.GetDocumentById(
                    documents_pb2.GetDocumentByIdRequest(id=doc_id)
                )
                documents.append(
                    Document(
                        page_content=fetched.title + "\n" + fetched.content,
                        metadata={"source": fetched.title},
                    )
                )
            except Exception as e:
                # Deliberately broad: skip this document and keep going.
                print(e)
        return documents


def serve():
    """Start the gRPC server and block until it terminates.

    The listen address is read from the ``BIND`` environment variable and
    falls back to ``[::]:50051`` when the variable is not set.
    """
    bind_address = os.getenv("BIND", "[::]:50051")
    print("Listening on", bind_address)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(bind_address)
    server.start()
    server.wait_for_termination()