langchain-chat-with-milvus/query_from_user_ai.py

from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from pymilvus import connections, Collection

MILVUS_HOST = "127.0.0.1"
MILVUS_PORT = "19530"
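
# NOTE: this script assumes a running local Milvus instance and an existing
# "todos" collection whose "vector" field holds OpenAI ada-002 embeddings
# (1536-dim) indexed with the L2 metric, as implied by the search below.
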
question = "Why does this yarn error occur, and how can it be fixed? Reply in spoken language."

# Prepare the embedding model and embed the question into a query vector
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
vec = embeddings.embed_query(question)
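
# Connect to Milvus and get a handle to the collection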
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
collection = Collection("todos")
collection.load()  # the collection must be loaded into memory before searching
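
# ANN search request: the 10 nearest neighbors of the question vector by L2
# distance, filtered to rows where user_id == 2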
search_param = {
    "data": [vec],
    "anns_field": "vector",
    "param": {"metric_type": "L2"},
    "limit": 10,
    "expr": "user_id == 2",
    "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
}
res = collection.search(**search_param)
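
# Wrap each hit as a LangChain Document so the QA chain can cite its source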
docs = []
for hit in res[0]:
    document_content = hit.entity.get("text")
    document_source = hit.entity.get("source")
    docs.append(Document(page_content=document_content, metadata={"source": document_source}))
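
# Run a map_reduce QA-with-sources chain: each document is first answered
# against the question individually, then the partial answers are combined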
print("正在调用 LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
                                   return_intermediate_steps=False, verbose=False)
output = chain({"input_documents": docs, "question": question}, return_only_outputs=True)
print("回复:" + output["output_text"])