74 lines
1.7 KiB
Python
74 lines
1.7 KiB
Python
import os
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings
|
|
from pymilvus import (
|
|
connections,
|
|
utility,
|
|
FieldSchema,
|
|
CollectionSchema,
|
|
DataType,
|
|
Collection,
|
|
)
|
|
|
|
# init
# Connection settings for the Milvus server. `or` is used instead of a
# getenv default so that an empty environment variable also falls back to
# the local default.
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"

# Single source of truth for the collection name used below.
COLLECTION_NAME = "leaf_documents"

# Model used to embed text, and the length of the vectors it produces.
# text-embedding-ada-002 returns 1536-dimensional embeddings, so the vector
# field must be declared with that dimension or inserts of real embeddings
# are rejected by Milvus.
EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_DIM = 1536

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Create the collection and its indexes only on first run; an existing
# collection is left untouched (including its schema).
if not utility.has_collection(COLLECTION_NAME):
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True,
    )
    print("Create collection...")
    _collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using="default",
        shards_num=2,
    )
    # HNSW index with L2 distance on the embedding field.
    _collection.create_index(
        field_name="vector",
        index_params={
            "metric_type": "L2",
            "M": 8,
            "efConstruction": 64,
            "index_type": "HNSW",
        },
    )
    # Index on user_id; no index_params are passed for this scalar field.
    _collection.create_index(
        field_name="user_id",
        index_name="idx_user_id",
    )

# Module-level handle used by the functions below, bound whether or not the
# collection was just created, then loaded.
_collection = Collection(COLLECTION_NAME)
_collection.load()

# Shared embeddings client used by text_to_vector().
embeddings = OpenAIEmbeddings(model=EMBEDDING_MODEL)
|
|
|
|
|
|
def text_to_vector(text: str):
    """Embed *text* with the module-level ``embeddings`` client.

    Returns whatever ``embeddings.embed_query`` returns for the given text.
    """
    vector = embeddings.embed_query(text)
    return vector
|
|
|
|
|
|
def insert_document(document_id: int, user_id: int, vector: list, collection: str):
    """Insert a single document row into the module-level ``_collection``.

    The row is passed column-wise, one single-element list per field, in the
    order document id, user id, vector.

    Note: the ``collection`` argument is accepted but not used by this
    function; the insert always targets the module-level ``_collection``.

    Returns the result of ``_collection.insert``.
    """
    # One column per field, each holding exactly one value (a single row).
    columns = [[document_id], [user_id], [vector]]
    return _collection.insert(data=columns)
|