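"""gRPC service for document question answering.

Query embeds an incoming question, retrieves the nearest document chunks from
the vector collection initialised in init, and asks an OpenAI chat model to
answer using only those chunks. Chunk splits raw text into overlapping pieces
for indexing.
"""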
import os
from concurrent import futures

import grpc
import langchain
import openai

import proto.document_query_pb2
import proto.document_query_pb2_grpc
import proto.documents_pb2

import init
import doc_client

# from langchain.llms.openai import OpenAI
# from langchain.schema.document import Document
# from langchain.embeddings import OpenAIEmbeddings
# from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.cache import InMemoryCache
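
# Cache identical LLM prompts in memory so repeated questions skip the API call.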
langchain.llm_cache = InMemoryCache()

CHUNK_SIZE = 500

# openai.api_base = "https://api.openai.com/v1"
# openai.api_key = "sk-..."  # key redacted; set OPENAI_API_KEY in the environment instead


class AIServer(proto.document_query_pb2_grpc.DocumentQueryServicer):
    def Query(self, request, context):
        print("New request: " + request.question)
        vec = init.text_to_vector(request.question)
        question = request.question
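
        # Top-5 L2 nearest-neighbour search over the "vector" field, scoped to
        # this user's library via a boolean filter expression.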
        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": 5,
            "expr": "user_id == " + str(request.user_id) + " && library_id == " + str(request.library_id),
            "output_fields": ["document_id", "user_id", "library_id"],
        }

        res = init.collection.search(**search_param)

        document_text = ""
        # real_document = []
        sources = []

        for i in range(len(res[0])):
            _chunk_id = res[0][i].id
            print("Fetching content of chunk " + str(_chunk_id) + "...")

            try:
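                # The vector search returns only chunk IDs; fetch the chunk
                # text from the document service over gRPC.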
                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                    id=_chunk_id
                ))

                print(_chunk_content.document)

                # _doc_content_full = _chunk_content.content
                document_text += "\n" + _chunk_content.content + "\n"

                # Record each chunk and its document ID so the response can cite its sources.
                sources.append({
                    "text": _chunk_content.content,
                    "document_id": _chunk_content.document.id
                })

                # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
                # real_document.append(doc_obj)

            except Exception as e:
                # Skip chunks that fail to load instead of failing the whole query.
                print(e)

        print("Calling the LLM...")
        output = search(document_text, question)
        print("Done.")

        return proto.document_query_pb2.QueryResponse(
            text=output,
            sources=sources
        )

    def Chunk(self, request, context):
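        # Recursive splitting (paragraphs, then lines, then words) into
        # ~CHUNK_SIZE-character chunks with a 20-character overlap; each
        # chunk's start offset is kept in its metadata.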
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=20,
            length_function=len,
            add_start_index=True,
        )

        page_contents = text_splitter.create_documents([request.text])

        texts = []
        for page_content in page_contents:
            texts.append(page_content.page_content)

        return proto.document_query_pb2.ChunkResponse(
            texts=texts
        )


def serve():
    # Bind address comes from the BIND environment variable; default to all
    # interfaces on port 50051.
    _ADDR = os.getenv("BIND", "[::]:50051")
    print("Listening on", _ADDR)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)

    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()
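

# Minimal client sketch (the request message name `QueryRequest` is an
# assumption; check the actual proto definition):
#
#   channel = grpc.insecure_channel("localhost:50051")
#   stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
#   resp = stub.Query(proto.document_query_pb2.QueryRequest(
#       question="...", user_id=1, library_id=1))
#   print(resp.text)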


def search(summaries: str, question: str) -> str:
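    """Ask the chat model to answer `question` using only `summaries`."""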
    # Prompt translated from Chinese: pin the model to the retrieved documents
    # and have it always reply in Chinese, addressing the user directly.
    prompt = f"""
Use the following documents to answer the question. Answer in Markdown and address the user as "you". If you do not know the answer, say so; do not make up an answer. Always reply in Chinese.

QUESTION: {question}

===BEGIN DOCUMENTS===
{summaries}
===END DOCUMENTS===

FINAL ANSWER:
"""

    messages = [
        {
            "role": "user",
            "content": prompt
        }
    ]

    print(prompt)

    # temperature=0 keeps the answer deterministic; to_dict_recursive()
    # converts the pre-1.0 openai SDK response object into a plain dict.
    result = openai.ChatCompletion.create(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    res = result["choices"][0]["message"].to_dict_recursive()
    print(res)
    return res["content"]


if __name__ == '__main__':
    serve()
|