# langchain-chat-with-milvus/milvus.py
# Last modified: 2023-11-13 20:23:15 +08:00
# (header reconstructed as comments; the original paste left these lines as bare
# text, which is not valid Python)

from os import environ
MILVUS_HOST = "127.0.0.1"
MILVUS_PORT = "19530"
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Milvus
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.llms import OpenAI
import random
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
# --- One-time Milvus setup: ensure the "book" collection exists and seed demo rows. ---
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Create the collection only if it does not already exist.
if not utility.has_collection("book"):
    book_id = FieldSchema(
        name="book_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    book_name = FieldSchema(
        name="book_name",
        dtype=DataType.VARCHAR,
        max_length=200,
        # The default value is used if this field is left empty during data
        # inserts or upserts; its type must match `dtype`.
        default_value="Unknown",
    )
    word_count = FieldSchema(
        name="word_count",
        dtype=DataType.INT64,
        # Same default-value semantics as above.
        default_value=9999,
    )
    book_intro = FieldSchema(
        name="book_intro",
        dtype=DataType.FLOAT_VECTOR,
        dim=2,  # demo-sized vector; real OpenAI embeddings would need e.g. 1536 dims
    )
    schema = CollectionSchema(
        fields=[book_id, book_name, word_count, book_intro],
        description="Test book search",
        enable_dynamic_field=True,
    )
    print("Create collection...")
    collection = Collection(
        name="book",
        schema=schema,
        using="default",
        shards_num=2,
    )

# Seed 2000 demo rows; columns are ordered to match the schema fields above.
data = [
    [i for i in range(2000)],                                     # book_id
    [str(i) for i in range(2000)],                                # book_name
    [i for i in range(10000, 12000)],                             # word_count
    [[random.random() for _ in range(2)] for _ in range(2000)],   # book_intro
]
collection = Collection("book")  # Get an existing collection.
mr = collection.insert(data)

# NOTE(review): this early exit makes every statement below unreachable — it looks
# like debugging residue from the seeding step. Remove it to run the LangChain demo.
exit(0)
# --- Load a web page, split it into chunks, embed, and store in Milvus. ---
print("读取文档")
loader = WebBaseLoader([
    "https://ivampiresp.com/2022/10/25/nginx-dynamic-reverse-proxy-expose-intranet-http-service",
])
print("加载文档")
docs = loader.load()

# Split into <=1024-character chunks (no overlap) so each chunk fits one embedding call.
text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
docs = text_splitter.split_documents(docs)

print("转换为向量")
# Embeddings are computed lazily by the vector store below; requires OPENAI_API_KEY.
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

print("存储向量")
# SECURITY: a commented-out Zilliz Cloud connection block with a hard-coded API
# token previously lived here. It has been removed — that token must be rotated,
# and any cloud credentials should be loaded from the environment, never committed.
vector_db = Milvus.from_documents(docs, embedding=embeddings, connection_args={
    "host": MILVUS_HOST, "port": MILVUS_PORT
})
print("存储完成")
#
# print("执行查询")
# query = ""
#
# print("相似度搜索")
# docs = vector_db.similarity_search(query)
#
# print("内容")
# content = docs[0].page_content
# print(content)
print("提出问题")
# Ask a question over the split documents using a map-reduce
# question-answering-with-sources chain (deterministic LLM, temperature 0).
llm = OpenAI(temperature=0)
chain = load_qa_with_sources_chain(
    llm,
    chain_type="map_reduce",
    return_intermediate_steps=True,
)
query = "首页是什么"
chain_input = {"input_documents": docs, "question": query}
output = chain(chain_input, return_only_outputs=True)
print(output)