# langchain-chat-with-milvus/query_from_user_ai.py
# Last modified: 2023-11-13 20:23:15 +08:00 (66 lines, 1.9 KiB, Python)
#
# NOTE: the source viewer flagged this file as containing ambiguous Unicode
# characters (characters that might be confused with other characters).
# If that is intentional, the warning can be safely ignored.

import json
from langchain import text_splitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
# Connection settings for the local Milvus instance.
MILVUS_HOST = "127.0.0.1"
MILVUS_PORT = "19530"
# Collection to query and retrieval knobs (previously hard-coded inline
# in the search call).
COLLECTION_NAME = "todos"
USER_ID = 2
TOP_K = 10

question = "这个 yarn 为什么会发生错误该怎么解决reply in spoken language "

# Prepare the embedding model and embed the question. The query vector must
# come from the same model that produced the stored "vector" field —
# presumably text-embedding-ada-002 as well; confirm against the ingestion script.
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
vec = embeddings.embed_query(question)

search_param = {
    "data": [vec],
    "anns_field": "vector",
    "param": {"metric_type": "L2"},
    "limit": TOP_K,
    # Restrict hits to a single user's todos.
    "expr": f"user_id == {USER_ID}",
    # Only "text" and "source" are read below; requesting the other scalar
    # fields would just inflate the response.
    "output_fields": ["text", "source"],
}

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
try:
    collection = Collection(COLLECTION_NAME)
    # Milvus only searches collections that are loaded into memory; load()
    # makes the script work even when nothing else has loaded it yet.
    collection.load()
    res = collection.search(**search_param)
finally:
    # The connection is not needed once the hits are retrieved.
    connections.disconnect("default")

# `res` holds one hit list per query vector; exactly one vector was sent.
# Hits with missing/empty text are skipped: Document requires a string
# page_content and an empty one gives the LLM nothing to work with.
documents = [
    Document(page_content=hit.get("text"), metadata={"source": hit.get("source")})
    for hit in res[0]
    if hit.get("text")
]

if not documents:
    # Nothing retrieved: skip the (billed) LLM call rather than asking it to
    # answer without any context.
    print("未找到相关文档，跳过 LLM 调用。")
else:
    print("正在调用 LLM...")
    # map_reduce: LangChain answers the question against each document
    # separately, then combines the partial answers; temperature=0 minimises
    # randomness in the reply.
    chain = load_qa_with_sources_chain(
        OpenAI(temperature=0),
        chain_type="map_reduce",
        return_intermediate_steps=False,
        verbose=False,
    )
    output = chain({"input_documents": documents, "question": question}, return_only_outputs=True)
    print("回复:" + output["output_text"])