langchain-chat-with-milvus/document_ai/server.py

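"""gRPC service that answers questions over documents indexed in Milvus.

Query embeds the question, retrieves the nearest chunks for the caller's
library, and has an OpenAI chat model answer from them; Chunk splits raw
text into overlapping pieces for indexing.
"""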
import os
from concurrent import futures
import langchain
import openai
import proto.document_query_pb2
import proto.document_query_pb2_grpc
import grpc
import proto.documents_pb2
import init
import doc_client
# from langchain.llms.openai import OpenAI
# from langchain.schema.document import Document
# from langchain.embeddings import OpenAIEmbeddings
# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
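
# Cache LangChain LLM calls in memory so repeated identical prompts skip the API.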
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()

# Maximum number of characters per chunk produced by the Chunk RPC.
CHUNK_SIZE = 500

# openai.api_base = "https://api.openai.com/v1"
# openai.api_key = os.getenv("OPENAI_API_KEY")  # never commit a real key to source

class AIServer(proto.document_query_pb2_grpc.DocumentQueryServicer):
    def Query(self, request, context):
        print("New request: " + request.question)
        vec = init.text_to_vector(request.question)

        question = request.question

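        # Vector search in Milvus: L2 distance over the question embedding,
        # scoped to the caller's user and library.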
        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": 5,
            "expr": "user_id == " + str(request.user_id) + " && library_id == " + str(request.library_id),
            "output_fields": ["document_id", "user_id", "library_id"],
        }
        res = init.collection.search(**search_param)
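
        # Fetch each matched chunk's full text from the document service; it feeds
        # both the LLM prompt and the sources returned to the caller.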
        document_text = ""
        # real_document = []
        sources = []

        for i in range(len(res[0])):
            _chunk_id = res[0][i].id
            print("Fetching content of chunk " + str(_chunk_id) + "...")

            try:
                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                    id=_chunk_id
                ))

                # print(_chunk_content)
                # _doc_content_full = _chunk_content.content
                document_text += "\n" + _chunk_content.content + "\n"

                sources.append({
                    "text": _chunk_content.content,
                    "document_id": _chunk_content.document.id,
                    "title": _chunk_content.document.title
                })

                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
                # real_document.append(doc_obj)
            except Exception as e:
                print(e)
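
        # Have the chat model answer strictly from the retrieved chunk text.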
        print("Calling LLM...")
        output = search(document_text, question)

        print(sources)
        print("Done.")

        return proto.document_query_pb2.QueryResponse(
            text=output,
            sources=sources
        )

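    # Split raw text into fixed-size, slightly overlapping chunks for indexing.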
    def Chunk(self, request, context):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=20,  # overlap keeps sentences that straddle a boundary retrievable
            length_function=len,
            add_start_index=True,
        )

        page_contents = text_splitter.create_documents([
            request.text
        ])

        texts = []
        for page_content in page_contents:
            texts.append(page_content.page_content)

        return proto.document_query_pb2.ChunkResponse(
            texts=texts
        )

def serve():
    _ADDR = os.getenv("BIND", "[::]:50051")
    print("Listening on", _ADDR)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()

def search(summaries: str, question: str) -> str:
    prompt = f"""
Answer the question using the documents below, in Markdown, addressing the user as "您". If you do not know the answer, say you do not know; do not make up an answer. Always reply in Chinese.
QUESTION: {question}
===BEGIN DOCUMENTS===
{summaries}
===END DOCUMENTS===
FINAL ANSWER:
"""
    messages = [
        {
            "role": "user",
            "content": prompt
        }
    ]
    print(prompt)

    result = openai.ChatCompletion.create(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    res = result["choices"][0]["message"].to_dict_recursive()
    print(res)
    return res["content"]
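
# A minimal client sketch for manual testing, kept as a comment. The stub class
# name follows the generated code used above; the request message name
# (QueryRequest) and its fields are assumptions inferred from how Query reads
# the request, not confirmed by this file:
#
#   channel = grpc.insecure_channel("localhost:50051")
#   stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
#   resp = stub.Query(proto.document_query_pb2.QueryRequest(
#       question="...", user_id=1, library_id=1))
#   print(resp.text)
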
if __name__ == '__main__':
    serve()