diff --git a/rpc/README.md b/rpc/README.md
new file mode 100644
index 0000000..eda8426
--- /dev/null
+++ b/rpc/README.md
@@ -0,0 +1,10 @@
+# Environment variables
+
+gRPC listen address and service configuration:
+```bash
+BIND=0.0.0.0:12345
+MILVUS_HOST=127.0.0.1
+MILVUS_PORT=19530
+OPENAI_API_BASE=http://
+OPENAI_API_KEY=
+```
\ No newline at end of file
diff --git a/rpc/ai.proto b/rpc/ai.proto
index eab85ac..b8664aa 100644
--- a/rpc/ai.proto
+++ b/rpc/ai.proto
@@ -14,4 +14,5 @@ message AddDocumentRequest {
 
 message AddDocumentReply {
   string id = 1;
-}
\ No newline at end of file
+}
+
diff --git a/rpc/ai_pb2.py b/rpc/ai_pb2.py
index 5a89d60..f4d65f5 100644
--- a/rpc/ai_pb2.py
+++ b/rpc/ai_pb2.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: rpc/ai.proto
+# source: ai.proto
 """Generated protocol buffer code."""
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
@@ -13,17 +13,17 @@
 _sym_db = _symbol_database.Default()
 
 
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0crpc/ai.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
 
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc.ai_pb2', _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
 if _descriptor._USE_C_DESCRIPTORS == False:
   DESCRIPTOR._options = None
-  _globals['_ADDDOCUMENTREQUEST']._serialized_start=16
-  _globals['_ADDDOCUMENTREQUEST']._serialized_end=105
-  _globals['_ADDDOCUMENTREPLY']._serialized_start=107
-  _globals['_ADDDOCUMENTREPLY']._serialized_end=137
-  _globals['_LLMQUERY']._serialized_start=139
-  _globals['_LLMQUERY']._serialized_end=206
+  _globals['_ADDDOCUMENTREQUEST']._serialized_start=12
+  _globals['_ADDDOCUMENTREQUEST']._serialized_end=101
+  _globals['_ADDDOCUMENTREPLY']._serialized_start=103
+  _globals['_ADDDOCUMENTREPLY']._serialized_end=133
+  _globals['_LLMQUERY']._serialized_start=135
+  _globals['_LLMQUERY']._serialized_end=202
 # @@protoc_insertion_point(module_scope)
diff --git a/rpc/ai_pb2.pyi b/rpc/ai_pb2.pyi
new file mode 100644
index 0000000..71b4314
--- /dev/null
+++ b/rpc/ai_pb2.pyi
@@ -0,0 +1,23 @@
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from typing import ClassVar as _ClassVar, Optional as _Optional
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class AddDocumentRequest(_message.Message):
+    __slots__ = ["text", "user_id", "database", "collection"]
+    TEXT_FIELD_NUMBER: _ClassVar[int]
+    USER_ID_FIELD_NUMBER: _ClassVar[int]
+    DATABASE_FIELD_NUMBER: _ClassVar[int]
+    COLLECTION_FIELD_NUMBER: _ClassVar[int]
+    text: str
+    user_id: int
+    database: str
+    collection: str
+    def __init__(self, text: _Optional[str] = ..., user_id: _Optional[int] = ..., database: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ...
+
+class AddDocumentReply(_message.Message):
+    __slots__ = ["id"]
+    ID_FIELD_NUMBER: _ClassVar[int]
+    id: str
+    def __init__(self, id: _Optional[str] = ...) -> None: ...
diff --git a/rpc/ai_pb2_grpc.py b/rpc/ai_pb2_grpc.py
index d0c8a10..b509391 100644
--- a/rpc/ai_pb2_grpc.py
+++ b/rpc/ai_pb2_grpc.py
@@ -2,7 +2,7 @@
 """Client and server classes corresponding to protobuf-defined services."""
 import grpc
 
-from rpc import ai_pb2 as rpc_dot_ai__pb2
+import ai_pb2 as ai__pb2
 
 
 class LLMQueryStub(object):
@@ -16,8 +16,8 @@ class LLMQueryStub(object):
         """
         self.AddDocument = channel.unary_unary(
                 '/LLMQuery/AddDocument',
-                request_serializer=rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
-                response_deserializer=rpc_dot_ai__pb2.AddDocumentReply.FromString,
+                request_serializer=ai__pb2.AddDocumentRequest.SerializeToString,
+                response_deserializer=ai__pb2.AddDocumentReply.FromString,
                 )
 
 
@@ -35,8 +35,8 @@ def add_LLMQueryServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'AddDocument': grpc.unary_unary_rpc_method_handler(
                     servicer.AddDocument,
-                    request_deserializer=rpc_dot_ai__pb2.AddDocumentRequest.FromString,
-                    response_serializer=rpc_dot_ai__pb2.AddDocumentReply.SerializeToString,
+                    request_deserializer=ai__pb2.AddDocumentRequest.FromString,
+                    response_serializer=ai__pb2.AddDocumentReply.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -60,7 +60,7 @@ class LLMQuery(object):
             timeout=None,
             metadata=None):
         return grpc.experimental.unary_unary(request, target, '/LLMQuery/AddDocument',
-            rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
-            rpc_dot_ai__pb2.AddDocumentReply.FromString,
+            ai__pb2.AddDocumentRequest.SerializeToString,
+            ai__pb2.AddDocumentReply.FromString,
             options, channel_credentials, insecure, call_credentials, compression,
             wait_for_ready, timeout, metadata)
diff --git a/rpc/grpc.py b/rpc/grpc.py
deleted file mode 100644
index e69de29..0000000
diff --git a/rpc/init.py b/rpc/init.py
new file mode 100644
index 0000000..0e76569
--- /dev/null
+++ b/rpc/init.py
@@ -0,0 +1,69 @@
+import os
+
+from langchain.embeddings import OpenAIEmbeddings
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+)
+
+# init
+MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
+MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
+
+connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
+
+if not utility.has_collection("leaf_documents"):
+    _document_id = FieldSchema(
+        name="document_id",
+        dtype=DataType.INT64,
+        is_primary=True,
+    )
+    _user_id = FieldSchema(
+        name="user_id",
+        dtype=DataType.INT64,
+
+    )
+    _document_vector = FieldSchema(
+        name="vector",
+        dtype=DataType.FLOAT_VECTOR,
+        dim=2
+    )
+    schema = CollectionSchema(
+        fields=[_document_id, _user_id, _document_vector],
+        enable_dynamic_field=True
+    )
+    collection_name = "leaf_documents"
+    print("Create collection...")
+    collection = Collection(
+        name=collection_name,
+        schema=schema,
+        using='default',
+        shards_num=2
+    )
+    collection.create_index(
+        field_name="vector",
+        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
+    )
+    collection.create_index(
+        field_name="user_id",
+        index_name="idx_user_id"
+    )
+
+collection = Collection("leaf_documents")
+collection.load()
+
+embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+
+def text_to_vector(text: str):
+    return embeddings.embed_query(text)
+
+
+def insert_document(document_id: int, user_id: int, vector: list):
+    collection.insert(
+        data=[[document_id], [user_id], [vector]],
+    )
diff --git a/rpc/server.py b/rpc/server.py
new file mode 100644
index 0000000..92e11dd
--- /dev/null
+++ b/rpc/server.py
@@ -0,0 +1,31 @@
+import os
+from concurrent import futures
+from langchain.embeddings import OpenAIEmbeddings
+import ai_pb2
+import ai_pb2_grpc
+import grpc
+
+
+class AIServer(ai_pb2_grpc.LLMQueryServicer):
+    def AddDocument(self, request, context):
+        print("AddDocument called with", request.text)
+
+        return ai_pb2.AddDocumentReply(
+            id=request.text
+        )
+
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+
+serve()
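
For context, a minimal sketch of how the helpers in `rpc/init.py` could be wired into the `AddDocument` flow. It is not part of this change: `add_document` is a hypothetical name, it assumes the code runs from the `rpc/` directory (so `init` imports as a top-level module), and it assumes the collection's vector dimension matches the embedding size (text-embedding-ada-002 returns 1536-dimensional vectors, while `init.py` currently creates the field with `dim=2`).

```python
# Hypothetical glue code, not part of this diff.
from init import text_to_vector, insert_document


def add_document(document_id: int, user_id: int, text: str) -> None:
    # Embed the text with OpenAI via langchain, then store one row in Milvus.
    vector = text_to_vector(text)
    insert_document(document_id, user_id, vector)
```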
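
And a small client sketch for exercising the regenerated top-level stubs end to end. `client_example.py` is a hypothetical file, the target assumes `server.py`'s default `BIND` of `[::]:50051`, and the field values are placeholders.

```python
# client_example.py - hypothetical, for trying the AddDocument RPC locally.
import grpc

import ai_pb2
import ai_pb2_grpc


def main() -> None:
    # server.py listens on BIND, defaulting to [::]:50051.
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = ai_pb2_grpc.LLMQueryStub(channel)
        reply = stub.AddDocument(ai_pb2.AddDocumentRequest(
            text="hello milvus",
            user_id=1,
            database="default",
            collection="leaf_documents",
        ))
        # The current servicer simply echoes the request text back as the id.
        print("AddDocument reply id:", reply.id)


if __name__ == "__main__":
    main()
```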