diff --git a/document_query/README.md b/document_query/README.md
new file mode 100644
index 0000000..eda8426
--- /dev/null
+++ b/document_query/README.md
@@ -0,0 +1,10 @@
+# Environment Variables
+
+The service is configured through the following environment variables:
+
+```bash
+# gRPC listen address (server.py falls back to [::]:50051 when unset)
+BIND=0.0.0.0:12345
+# Milvus connection (read by init.py as MILVUS_HOST / MILVUS_PORT)
+MILVUS_HOST=127.0.0.1
+MILVUS_PORT=19530
+# OpenAI-compatible embeddings endpoint and API key
+OPENAI_API_BASE=http://
+OPENAI_API_KEY=
+```
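+
+## Example client
+
+A minimal sketch of calling the `AddDocument` RPC from Python, assuming the
+server is running and listening on `127.0.0.1:12345`:
+
+```python
+import grpc
+
+import ai_pb2
+import ai_pb2_grpc
+
+channel = grpc.insecure_channel("127.0.0.1:12345")
+stub = ai_pb2_grpc.LLMQueryStub(channel)
+
+request = ai_pb2.QueryDocumentRequest(
+    text="hello world",
+    user_id=1,
+    database="default",
+    collection="leaf_documents",
+)
+
+# AddDocument is a server-streaming RPC, so iterate over the replies.
+for reply in stub.AddDocument(request):
+    print(reply.text)
+```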
diff --git a/document_query/ai.proto b/document_query/ai.proto
new file mode 100644
index 0000000..8e6c617
--- /dev/null
+++ b/document_query/ai.proto
@@ -0,0 +1,18 @@
+syntax = "proto3";
+
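+// The Python stubs (ai_pb2.py, ai_pb2_grpc.py) can be regenerated with, e.g.:
+//   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. ai.proto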
+service LLMQuery {
+ rpc AddDocument (QueryDocumentRequest) returns (stream QueryDocumentReply) {}
+}
+
+
+message QueryDocumentRequest {
+ string text = 1;
+ uint64 user_id = 2;
+ string database = 3;
+ string collection = 4;
+}
+
+message QueryDocumentReply {
+ string text = 1;
+}
+
diff --git a/document_query/ai_pb2.py b/document_query/ai_pb2.py
new file mode 100644
index 0000000..214adb4
--- /dev/null
+++ b/document_query/ai_pb2.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: ai.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"[\n\x14QueryDocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\"\n\x12QueryDocumentReply\x12\x0c\n\x04text\x18\x01 \x01(\t2I\n\x08LLMQuery\x12=\n\x0b\x41\x64\x64\x44ocument\x12\x15.QueryDocumentRequest\x1a\x13.QueryDocumentReply\"\x00\x30\x01\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+ DESCRIPTOR._options = None
+ _globals['_QUERYDOCUMENTREQUEST']._serialized_start=12
+ _globals['_QUERYDOCUMENTREQUEST']._serialized_end=103
+ _globals['_QUERYDOCUMENTREPLY']._serialized_start=105
+ _globals['_QUERYDOCUMENTREPLY']._serialized_end=139
+ _globals['_LLMQUERY']._serialized_start=141
+ _globals['_LLMQUERY']._serialized_end=214
+# @@protoc_insertion_point(module_scope)
diff --git a/document_query/ai_pb2_grpc.py b/document_query/ai_pb2_grpc.py
new file mode 100644
index 0000000..98c01ff
--- /dev/null
+++ b/document_query/ai_pb2_grpc.py
@@ -0,0 +1,66 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import ai_pb2 as ai__pb2
+
+
+class LLMQueryStub(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.AddDocument = channel.unary_stream(
+ '/LLMQuery/AddDocument',
+ request_serializer=ai__pb2.QueryDocumentRequest.SerializeToString,
+ response_deserializer=ai__pb2.QueryDocumentReply.FromString,
+ )
+
+
+class LLMQueryServicer(object):
+ """Missing associated documentation comment in .proto file."""
+
+ def AddDocument(self, request, context):
+ """Missing associated documentation comment in .proto file."""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_LLMQueryServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'AddDocument': grpc.unary_stream_rpc_method_handler(
+ servicer.AddDocument,
+ request_deserializer=ai__pb2.QueryDocumentRequest.FromString,
+ response_serializer=ai__pb2.QueryDocumentReply.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'LLMQuery', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class LLMQuery(object):
+ """Missing associated documentation comment in .proto file."""
+
+ @staticmethod
+ def AddDocument(request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None):
+ return grpc.experimental.unary_stream(request, target, '/LLMQuery/AddDocument',
+ ai__pb2.QueryDocumentRequest.SerializeToString,
+ ai__pb2.QueryDocumentReply.FromString,
+ options, channel_credentials,
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/document_query/init.py b/document_query/init.py
new file mode 100644
index 0000000..a6e57ff
--- /dev/null
+++ b/document_query/init.py
@@ -0,0 +1,73 @@
+import os
+
+from langchain.embeddings import OpenAIEmbeddings
+from pymilvus import (
+ connections,
+ utility,
+ FieldSchema,
+ CollectionSchema,
+ DataType,
+ Collection,
+)
+
+# init
+MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
+MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
+
+connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
+
+if not utility.has_collection("leaf_documents"):
+ _document_id = FieldSchema(
+ name="document_id",
+ dtype=DataType.INT64,
+ is_primary=True,
+ )
+ _user_id = FieldSchema(
+ name="user_id",
+ dtype=DataType.INT64,
+ )
+ _document_vector = FieldSchema(
+ name="vector",
+ dtype=DataType.FLOAT_VECTOR,
+        # text-embedding-ada-002 embeddings are 1536-dimensional
+        dim=1536
+ )
+ schema = CollectionSchema(
+ fields=[_document_id, _user_id, _document_vector],
+ enable_dynamic_field=True
+ )
+ collection_name = "leaf_documents"
+ print("Create collection...")
+ _collection = Collection(
+ name=collection_name,
+ schema=schema,
+ using='default',
+ shards_num=2
+ )
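+    # HNSW vector index with L2 distance; M and efConstruction are HNSW
+    # build-time parameters.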
+ _collection.create_index(
+ field_name="vector",
+ index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
+ )
+ _collection.create_index(
+ field_name="user_id",
+ index_name="idx_user_id"
+ )
+
+_collection = Collection("leaf_documents")
+_collection.load()
+
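+# OpenAIEmbeddings reads OPENAI_API_BASE / OPENAI_API_KEY from the environment.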
+embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+
+def text_to_vector(text: str):
+ return embeddings.embed_query(text)
+
+
+def insert_document(document_id: int, user_id: int, vector: list):
+ return _collection.insert(
+ data=[
+ [document_id],
+ [user_id],
+ [vector]
+ ],
+ )
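+
+
+# Minimal usage sketch (assumes Milvus is reachable and OPENAI_API_KEY is set):
+# embed a text, then insert it under a caller-chosen document id.
+if __name__ == "__main__":
+    _vec = text_to_vector("hello world")
+    print(insert_document(document_id=1, user_id=1, vector=_vec))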
diff --git a/document_query/server.py b/document_query/server.py
new file mode 100644
index 0000000..b9b1286
--- /dev/null
+++ b/document_query/server.py
@@ -0,0 +1,38 @@
+import os
+from concurrent import futures
+import ai_pb2
+import ai_pb2_grpc
+import grpc
+import init
+
+
+class AIServer(ai_pb2_grpc.LLMQueryServicer):
+    def AddDocument(self, request, context):
+        _text = request.text
+        _user_id = request.user_id
+        _database = request.database
+        _collection = request.collection
+
+        vector = init.text_to_vector(_text)
+        # insert_document expects (document_id, user_id, vector); the request
+        # carries no document id yet, so a placeholder id of 0 is used here.
+        init.insert_document(0, _user_id, vector)
+
+        # AddDocument is declared as a server-streaming RPC, so the reply is
+        # yielded rather than returned.
+        yield ai_pb2.QueryDocumentReply(
+            text=_text
+        )
+
+
+def serve():
+ _ADDR = os.getenv("BIND")
+ if _ADDR is None:
+ _ADDR = "[::]:50051"
+ print("Listening on", _ADDR)
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
+ server.add_insecure_port(_ADDR)
+ server.start()
+ server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    serve()
diff --git a/document_query/sync.py b/document_query/sync.py
new file mode 100644
index 0000000..e69de29
diff --git a/requirements.txt b/requirements.txt
index e69de29..3b0fabe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+langchain~=0.0.325
+pymilvus~=2.3.2
+pymysql~=1.1.0
+openai~=0.28.1
+# gRPC runtime and codegen tooling (used by server.py and the ai_pb2*.py stubs)
+grpcio~=1.59.0
+grpcio-tools~=1.59.0
\ No newline at end of file
diff --git a/rpc/init.py b/rpc/init.py
index 0e76569..d285e9a 100644
--- a/rpc/init.py
+++ b/rpc/init.py
@@ -38,23 +38,23 @@ if not utility.has_collection("leaf_documents"):
)
collection_name = "leaf_documents"
print("Create collection...")
- collection = Collection(
+ _collection = Collection(
name=collection_name,
schema=schema,
using='default',
shards_num=2
)
- collection.create_index(
+ _collection.create_index(
field_name="vector",
index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
)
- collection.create_index(
+ _collection.create_index(
field_name="user_id",
index_name="idx_user_id"
)
-collection = Collection("leaf_documents")
-collection.load()
+_collection = Collection("leaf_documents")
+_collection.load()
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
@@ -63,7 +63,11 @@ def text_to_vector(text: str):
return embeddings.embed_query(text)
-def insert_document(document_id: int, vector: list):
- collection.insert(
- data=[document_id, vector],
+def insert_document(document_id: int, user_id: int, vector: list, collection: str):
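+    # `collection` is accepted for forward compatibility; inserts currently go
+    # to the module-level "leaf_documents" collection.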
+ return _collection.insert(
+ data=[
+ [document_id],
+ [user_id],
+ [vector]
+ ],
)
diff --git a/rpc/server.py b/rpc/server.py
index 92e11dd..364dbae 100644
--- a/rpc/server.py
+++ b/rpc/server.py
@@ -4,11 +4,19 @@ from langchain.embeddings import OpenAIEmbeddings
import ai_pb2
import ai_pb2_grpc
import grpc
+import init
class AIServer(ai_pb2_grpc.LLMQueryServicer):
def AddDocument(self, request, context):
- print("AddDocument called with", request.text)
+ _text = request.text
+ _user_id = request.user_id
+ _database = request.database
+ _collection = request.collection
+
+        vector = init.text_to_vector(_text)
+        # insert_document expects (document_id, user_id, vector, collection);
+        # the request carries no document id yet, so a placeholder id of 0 is used.
+        data = init.insert_document(0, _user_id, vector, _collection)
+
return ai_pb2.AddDocumentReply(
id=request.text