66 lines
1.9 KiB
Python
66 lines
1.9 KiB
Python
import json
|
||
|
||
from langchain import text_splitter
|
||
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
||
from langchain.embeddings import OpenAIEmbeddings
|
||
from langchain.llms.openai import OpenAI
|
||
from langchain.schema.document import Document
|
||
from pymilvus import (
|
||
connections,
|
||
utility,
|
||
FieldSchema,
|
||
CollectionSchema,
|
||
DataType,
|
||
Collection,
|
||
)
|
||
|
||
# Retrieval-augmented QA script: embed a question, fetch the nearest rows from
# a Milvus collection, and ask an OpenAI LLM to answer using those rows as
# source documents.

MILVUS_HOST = "127.0.0.1"
MILVUS_PORT = "19530"

# Name of the Milvus collection that is searched.
COLLECTION_NAME = "todos"
# Only rows whose user_id equals this value are considered by the search.
USER_ID = 2
# Maximum number of hits requested from Milvus.
TOP_K = 10

question = "这个 yarn 为什么会发生错误,该怎么解决?reply in spoken language "

# Prepare the embedding model and turn the question into a query vector.
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
vec = embeddings.embed_query(question)

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

collection = Collection(COLLECTION_NAME)
# Milvus rejects searches on a collection that has not been loaded, so load it
# explicitly rather than relying on an earlier session having done so.
collection.load()

search_param = {
    "data": [vec],
    "anns_field": "vector",
    "param": {"metric_type": "L2"},
    "limit": TOP_K,
    "expr": f"user_id == {USER_ID}",
    "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
}
res = collection.search(**search_param)

# Exactly one query vector was submitted, so res[0] holds all of its hits.
# Each hit becomes a Document whose metadata carries the "source" field used by
# the QA-with-sources chain.
documents = [
    Document(page_content=hit.get("text"), metadata={"source": hit.get("source")})
    for hit in res[0]
]

print("正在调用 LLM...")
chain = load_qa_with_sources_chain(
    OpenAI(temperature=0),
    chain_type="map_reduce",
    return_intermediate_steps=False,
    verbose=False,
)
output = chain({"input_documents": documents, "question": question}, return_only_outputs=True)
print("回复:" + output["output_text"])