166 lines
4.1 KiB
Python
166 lines
4.1 KiB
Python
import random
|
|
|
|
import pymysql
|
|
from langchain.docstore.document import Document
|
|
from os import environ
|
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
|
from langchain.vectorstores import Milvus
|
|
from langchain.document_loaders import WebBaseLoader
|
|
from langchain.text_splitter import CharacterTextSplitter
|
|
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
from langchain.llms import OpenAI
|
|
|
|
# Milvus endpoint. Each value can be overridden through the environment
# variable of the same name; the fallbacks are the local settings this script
# previously hard-coded, so behavior is unchanged when nothing is exported.
MILVUS_HOST = environ.get("MILVUS_HOST", "127.0.0.1")
# The port stays a string, exactly as it was before.
MILVUS_PORT = environ.get("MILVUS_PORT", "19530")
|
|
|
|
from pymilvus import (
|
|
connections,
|
|
utility,
|
|
FieldSchema,
|
|
CollectionSchema,
|
|
DataType,
|
|
Collection,
|
|
)
|
|
|
|
# Open the "default" connection alias to the Milvus server.
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Single source of truth for the collection name used below.
collection_name = "todos"

if not utility.has_collection(collection_name):
    # First run: declare the schema, then create the collection and its indexes.

    # Primary key generated by Milvus (auto_id), never supplied on insert.
    pk = FieldSchema(
        name="pk",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=True,
    )

    todo_id = FieldSchema(
        name="todo_id",
        dtype=DataType.INT64,
    )
    todo_title = FieldSchema(
        name="title",
        dtype=DataType.VARCHAR,
        max_length=65535,
        default_value="Unknown",
    )
    source = FieldSchema(
        name="source",
        dtype=DataType.VARCHAR,
        max_length=65535,
        default_value="Unknown",
    )
    todo_description = FieldSchema(
        name="todo_description",
        dtype=DataType.VARCHAR,
        max_length=65535,
        default_value="Unknown",
    )
    todo_language = FieldSchema(
        name="language",
        dtype=DataType.VARCHAR,
        max_length=65535,
        default_value="zh_CN",
    )
    # NOTE(review): the "zh_CN" default on a free-text field looks copied from
    # the language field above; kept as-is so the schema stays unchanged.
    todo_text = FieldSchema(
        name="text",
        dtype=DataType.VARCHAR,
        max_length=65535,
        default_value="zh_CN",
    )
    user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    # Embedding vector; 1536 is the output size of text-embedding-ada-002,
    # the model this script uses when inserting rows.
    todo_intro = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        dim=1536,
    )

    # The field order here is the column order expected by collection.insert().
    schema = CollectionSchema(
        fields=[
            pk,
            todo_id,
            source,
            todo_title,
            todo_description,
            todo_text,
            todo_language,
            user_id,
            todo_intro,
        ],
        description="Test book search",
        enable_dynamic_field=True,
    )

    print("Create collection...")
    collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
    )

    # HNSW index with L2 distance on the embedding field.
    print("Create index: todo_intro...")
    collection.create_index(
        field_name="vector",
        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
    )
    # Scalar index on user_id.
    collection.create_index(
        field_name="user_id",
        index_name="index",
    )
else:
    # The collection already exists: just obtain a handle to it. Without this
    # branch `collection` would be unbound on every run after the first.
    collection = Collection(name=collection_name, using='default')

# Load the collection into memory on both paths.
print("Loading data...")
collection.load()
|
|
|
|
# Open the MySQL connection. Every setting can be overridden through a
# TODO_DB_* environment variable; the fallbacks are the values this script
# previously hard-coded, so behavior is unchanged when nothing is exported.
# NOTE(review): a real password is committed as the fallback below. Rotate it
# and drop the default once deployments set TODO_DB_PASSWORD.
db = pymysql.connect(
    host=environ.get("TODO_DB_HOST", "localhost"),
    port=int(environ.get("TODO_DB_PORT", "64639")),
    user=environ.get("TODO_DB_USER", "root"),
    password=environ.get("TODO_DB_PASSWORD", "6HbuKyjHO5"),
    database=environ.get("TODO_DB_NAME", "go-todo"),
)

# Select every todo row whose vector_id is still NULL.
sql = "SELECT * FROM `todos` WHERE `vector_id` IS NULL"

try:
    # The cursor is closed when the `with` block exits.
    with db.cursor() as cursor:
        cursor.execute(sql)
        # Fetch all matching rows at once.
        results = cursor.fetchall()
finally:
    # Always release the connection, even if the query fails.
    db.close()
|
|
|
|
# Build the embedding client and the collection handle once: neither depends
# on the row being processed. Skipped when there is nothing to insert, so an
# empty result set still does no work at all.
if results:
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    collection = Collection("todos")

# NOTE(review): nothing in this script writes vector_id back to MySQL (the
# connection is already closed), so a re-run selects and inserts the same
# rows again - confirm whether that update happens elsewhere.
for row in results:
    # Rows come from `SELECT *`, so columns are addressed by position.
    # NOTE(review): assumes the first four columns are id, title,
    # description, user_id - confirm against the table definition.
    todo__id = row[0]
    # A NULL text column arrives as None and would make the string building
    # below raise TypeError; treat it as empty text instead.
    todo__title = row[1] or ""
    todo__description = row[2] or ""
    todo__user_id = row[3]

    # Human-readable summary of the row; only used for the final print.
    todoData = f"Id: {todo__id};Title: {todo__title}\n;Content: {todo__description}\n"
    doc = Document(page_content=todoData)

    # Convert the todo's title and description into an embedding vector.
    print("转换为向量")
    vec = embeddings.embed_query(f"{todo__title}\n{todo__description}")

    # One single-element column per field, in schema order (the auto-id
    # primary key is omitted): todo_id, source, title, todo_description,
    # text, language, user_id, vector.
    mr = collection.insert([
        [todo__id],
        ["todo.awa.im"],
        [todo__title],
        [todo__title + todo__description],
        [todo__title + todo__description],
        ["zh_CN"],
        [todo__user_id],
        [vec],
    ])

    print(mr)
    print(doc)
|