# document_ai/server.py
# gRPC service that answers questions against documents indexed in Milvus,
# using OpenAI chat completions for the final answer.

import os
from concurrent import futures
import langchain
import openai
import proto.document_query_pb2
import proto.document_query_pb2_grpc
import grpc
import proto.documents_pb2
import init
import doc_client
# from langchain.llms.openai import OpenAI
# from langchain.schema.document import Document
# from langchain.embeddings import OpenAIEmbeddings
# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.cache import InMemoryCache
# In-process LLM response cache: identical prompts within this process reuse
# the cached completion instead of calling the API again.
langchain.llm_cache = InMemoryCache()

# Target chunk size in characters for Chunk() (length_function=len below).
CHUNK_SIZE = 500

# openai.api_base = "https://api.openai.com/v1"
# SECURITY(review): a hard-coded OpenAI API key was committed here in a
# comment; it has been redacted. Rotate that key and load it from the
# environment instead, e.g.:
# openai.api_key = os.getenv("OPENAI_API_KEY")
class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
    """gRPC servicer for the DocumentQuery service.

    NOTE(review): the method signatures below mirror the generated *client*
    stub signature (``target, options=(), ...``) rather than the usual
    servicer shape ``(self, request, context)``.  gRPC invokes servicer
    methods positionally, so ``target`` receives the request message and
    ``options`` the context — confirm against the generated base class.
    """

    def Query(self, target,
              options=(),
              channel_credentials=None,
              call_credentials=None,
              insecure=False,
              compression=None,
              wait_for_ready=None,
              timeout=None,
              metadata=None):
        """Answer ``target.question`` from the user's indexed documents.

        Embeds the question, retrieves the 5 nearest chunks from Milvus
        (restricted to ``target.user_id`` / ``target.library_id``), fetches
        each chunk's text via the document gRPC stub, and asks the LLM to
        answer from the concatenated text.

        Returns a ``QueryResponse`` with the LLM answer and the source
        chunks it was given.
        """
        print("新的请求:" + target.question)  # log: "new request: <question>"
        vec = init.text_to_vector(target.question)

        question = target.question

        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": 5,
            # NOTE(review): filter expression is built by string
            # concatenation; this is safe only while user_id/library_id are
            # numeric proto fields — confirm they cannot be free-form text.
            "expr": "user_id == " + str(target.user_id) + " && library_id == " + str(target.library_id),
            "output_fields": ["document_id", "user_id", "library_id"],
        }

        res = init.collection.search(**search_param)

        document_text = ""
        # real_document = []
        sources = []

        # res[0] holds the hits for the single query vector submitted above.
        for i in range(len(res[0])):
            _chunk_id = res[0][i].id
            print("正在获取分块 " + str(_chunk_id) + " 的内容...")  # log: "fetching chunk <id>..."

            try:
                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                    id=_chunk_id
                ))

                print(_chunk_content.document)

                # _doc_content_full = _chunk_content.content

                document_text += "\n" + _chunk_content.content + "\n"

                # Record each chunk so the response can cite its sources.
                sources.append({
                    "text": _chunk_content.content,
                    "document_id": _chunk_content.document.id
                })

                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
                # real_document.append(doc_obj)
            except Exception as e:
                # Best-effort: a chunk that fails to fetch is skipped, not fatal.
                print(e)

        print("正在调用 LLM...")  # log: "calling LLM..."
        output = search(document_text, question)
        print("完成。")  # log: "done."

        return proto.document_query_pb2.QueryResponse(
            text=output,
            sources=sources
        )

    def Chunk(self,
              target,
              options=(),
              channel_credentials=None,
              call_credentials=None,
              insecure=False,
              compression=None,
              wait_for_ready=None,
              timeout=None,
              metadata=None):
        """Split ``target.text`` into overlapping character chunks.

        Uses a recursive character splitter with CHUNK_SIZE-character
        chunks and a 20-character overlap, and returns the chunk texts in
        a ``ChunkResponse``.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=20,
            length_function=len,
            add_start_index=True,
        )
        page_contents = text_splitter.create_documents([
            target.text
        ])

        texts = []
        for page_content in page_contents:
            texts.append(page_content.page_content)

        return proto.document_query_pb2.ChunkResponse(
            texts=texts
        )
def serve():
    """Start the gRPC server and block until it terminates.

    The bind address comes from the ``BIND`` environment variable,
    defaulting to ``[::]:50051`` (all interfaces, IPv4 and IPv6).
    """
    # os.getenv takes the default directly — no separate None check needed.
    addr = os.getenv("BIND", "[::]:50051")
    print("Listening on", addr)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
    # NOTE(review): insecure port — no TLS; acceptable only behind a trusted
    # network boundary.
    server.add_insecure_port(addr)
    server.start()
    server.wait_for_termination()
def search(summaries: str, question: str) -> str:
    """Ask the chat model to answer *question* using only *summaries*.

    Sends a single-turn prompt to ``gpt-3.5-turbo`` at temperature 0 and
    returns the model's reply text.  The prompt (and the raw reply dict)
    are printed for debugging.
    """
    prompt = f"""
使用以下文档回答问题,使用Markdown回答,你得用“你”的身份指代用户。如果你不知道答案,你可以说你不知道,不要编造答案。总是使用中文回复。
QUESTION: {question}
===文档开始===
{summaries}
===文档结束===
FINAL ANSWER:
"""
    print(prompt)

    response = openai.ChatCompletion.create(
        messages=[{"role": "user", "content": prompt}],
        model="gpt-3.5-turbo",
        temperature=0,
    )

    message = response["choices"][0]["message"].to_dict_recursive()
    print(message)
    return message["content"]
if __name__ == '__main__':
    # Run the gRPC server when executed as a script.
    serve()