From c3429e78f74213c3fcf114588491c6a93f9b5e36 Mon Sep 17 00:00:00 2001 From: iVamp Date: Wed, 15 Nov 2023 20:20:21 +0800 Subject: [PATCH] add --- .idea/langchain.iml | 4 ++ document_query/README.md | 10 +++++ document_query/ai.proto | 18 +++++++++ document_query/ai_pb2.py | 29 ++++++++++++++ document_query/ai_pb2_grpc.py | 66 +++++++++++++++++++++++++++++++ document_query/init.py | 73 +++++++++++++++++++++++++++++++++++ document_query/server.py | 38 ++++++++++++++++++ document_query/sync.py | 0 requirements.txt | 4 ++ rpc/init.py | 20 ++++++---- rpc/server.py | 10 ++++- 11 files changed, 263 insertions(+), 9 deletions(-) create mode 100644 document_query/README.md create mode 100644 document_query/ai.proto create mode 100644 document_query/ai_pb2.py create mode 100644 document_query/ai_pb2_grpc.py create mode 100644 document_query/init.py create mode 100644 document_query/server.py create mode 100644 document_query/sync.py diff --git a/.idea/langchain.iml b/.idea/langchain.iml index bda60cf..765ca18 100644 --- a/.idea/langchain.iml +++ b/.idea/langchain.iml @@ -5,4 +5,8 @@ + + \ No newline at end of file diff --git a/document_query/README.md b/document_query/README.md new file mode 100644 index 0000000..eda8426 --- /dev/null +++ b/document_query/README.md @@ -0,0 +1,10 @@ +# 环境变量 + +gRPC 监听地址 +```bash +BIND=0.0.0.0:12345 +MILVUS_HOST=127.0.0.1 +MILVUS_PORT=19530 +OPENAI_API_BASE=http:// +OPENAI_API_KEY= +``` \ No newline at end of file diff --git a/document_query/ai.proto b/document_query/ai.proto new file mode 100644 index 0000000..8e6c617 --- /dev/null +++ b/document_query/ai.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +service LLMQuery { + rpc AddDocument (QueryDocumentRequest) returns (stream QueryDocumentReply) {} +} + + +message QueryDocumentRequest { + string text = 1; + uint64 user_id = 2; + string database = 3; + string collection = 4; +} + +message QueryDocumentReply { + string text = 1; +} + diff --git a/document_query/ai_pb2.py 
b/document_query/ai_pb2.py new file mode 100644 index 0000000..214adb4 --- /dev/null +++ b/document_query/ai_pb2.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: ai.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"[\n\x14QueryDocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\"\n\x12QueryDocumentReply\x12\x0c\n\x04text\x18\x01 \x01(\t2I\n\x08LLMQuery\x12=\n\x0b\x41\x64\x64\x44ocument\x12\x15.QueryDocumentRequest\x1a\x13.QueryDocumentReply\"\x00\x30\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_QUERYDOCUMENTREQUEST']._serialized_start=12 + _globals['_QUERYDOCUMENTREQUEST']._serialized_end=103 + _globals['_QUERYDOCUMENTREPLY']._serialized_start=105 + _globals['_QUERYDOCUMENTREPLY']._serialized_end=139 + _globals['_LLMQUERY']._serialized_start=141 + _globals['_LLMQUERY']._serialized_end=214 +# @@protoc_insertion_point(module_scope) diff --git a/document_query/ai_pb2_grpc.py b/document_query/ai_pb2_grpc.py new file mode 100644 index 0000000..98c01ff --- /dev/null +++ b/document_query/ai_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import ai_pb2 as ai__pb2 + + +class LLMQueryStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.AddDocument = channel.unary_stream( + '/LLMQuery/AddDocument', + request_serializer=ai__pb2.QueryDocumentRequest.SerializeToString, + response_deserializer=ai__pb2.QueryDocumentReply.FromString, + ) + + +class LLMQueryServicer(object): + """Missing associated documentation comment in .proto file.""" + + def AddDocument(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LLMQueryServicer_to_server(servicer, server): + rpc_method_handlers = { + 'AddDocument': grpc.unary_stream_rpc_method_handler( + servicer.AddDocument, + request_deserializer=ai__pb2.QueryDocumentRequest.FromString, + response_serializer=ai__pb2.QueryDocumentReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'LLMQuery', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class LLMQuery(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def AddDocument(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/LLMQuery/AddDocument', + ai__pb2.QueryDocumentRequest.SerializeToString, + ai__pb2.QueryDocumentReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/document_query/init.py b/document_query/init.py new file mode 100644 index 0000000..a6e57ff --- /dev/null +++ b/document_query/init.py @@ -0,0 +1,73 @@ +import os + +from langchain.embeddings import OpenAIEmbeddings +from pymilvus import ( + connections, + utility, + FieldSchema, + CollectionSchema, + DataType, + Collection, +) + +# init +MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1" +MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530" + +connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT) + +if not utility.has_collection("leaf_documents"): + _document_id = FieldSchema( + name="document_id", + dtype=DataType.INT64, + is_primary=True, + ) + _user_id = FieldSchema( + name="user_id", + dtype=DataType.INT64, + + ) + _document_vector = FieldSchema( + name="vector", + dtype=DataType.FLOAT_VECTOR, + dim=1536 # must equal the embedding size: text-embedding-ada-002 returns 1536 floats + ) + schema = CollectionSchema( + fields=[_document_id, _user_id, _document_vector], + enable_dynamic_field=True + ) + collection_name = "leaf_documents" + print("Create collection...") + _collection = Collection( + name=collection_name, + schema=schema, + using='default', + shards_num=2 + ) + _collection.create_index( + field_name="vector", + index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"}, + ) + _collection.create_index( + field_name="user_id", + index_name="idx_user_id" + ) + +_collection = 
Collection("leaf_documents") +_collection.load() + +embeddings = OpenAIEmbeddings(model="text-embedding-ada-002") + + +def text_to_vector(text: str): + return embeddings.embed_query(text) + + +def insert_document(document_id: int, user_id: int, vector: list): + return _collection.insert( + data=[ + [document_id], + [user_id], + [vector] + ], + ) diff --git a/document_query/server.py b/document_query/server.py new file mode 100644 index 0000000..b9b1286 --- /dev/null +++ b/document_query/server.py @@ -0,0 +1,38 @@ +import os +from concurrent import futures +import ai_pb2 +import ai_pb2_grpc +import grpc +import init + + +class AIServer(ai_pb2_grpc.LLMQueryServicer): + def AddDocument(self, request, context): + """Embed request.text, insert it into Milvus, and stream back a single reply.""" + import uuid + + # NOTE(review): the request carries no document id, so a random 63-bit key is generated here; confirm. + document_id = uuid.uuid4().int >> 65 + vector = init.text_to_vector(request.text) + init.insert_document(document_id, request.user_id, vector) + # AddDocument is declared server-streaming in ai.proto, so the handler yields its reply. + yield ai_pb2.QueryDocumentReply( + text=request.text + ) + + +def serve(): + _ADDR = os.getenv("BIND") + if _ADDR is None: + _ADDR = "[::]:50051" + print("Listening on", _ADDR) + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server) + server.add_insecure_port(_ADDR) + server.start() + server.wait_for_termination() + + +if __name__ == "__main__": + serve() diff --git a/document_query/sync.py b/document_query/sync.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index e69de29..3b0fabe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,4 @@ +langchain~=0.0.325 +pymilvus~=2.3.2 +pymysql~=1.1.0 +openai~=0.28.1 \ No newline at end of file diff --git a/rpc/init.py b/rpc/init.py index 0e76569..d285e9a 100644 --- a/rpc/init.py +++ b/rpc/init.py @@ -38,23 +38,23 @@ if not utility.has_collection("leaf_documents"): ) collection_name = "leaf_documents" print("Create collection...") - collection = Collection( + _collection = Collection( 
name=collection_name, schema=schema, using='default', shards_num=2 ) - collection.create_index( + _collection.create_index( field_name="vector", index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"}, ) - collection.create_index( + _collection.create_index( field_name="user_id", index_name="idx_user_id" ) -collection = Collection("leaf_documents") -collection.load() +_collection = Collection("leaf_documents") +_collection.load() embeddings = OpenAIEmbeddings(model="text-embedding-ada-002") @@ -63,7 +63,11 @@ def text_to_vector(text: str): return embeddings.embed_query(text) -def insert_document(document_id: int, vector: list): - collection.insert( - data=[document_id, vector], +def insert_document(document_id: int, user_id: int, vector: list, collection: str): + return _collection.insert( + data=[ + [document_id], + [user_id], + [vector] + ], ) diff --git a/rpc/server.py b/rpc/server.py index 92e11dd..364dbae 100644 --- a/rpc/server.py +++ b/rpc/server.py @@ -4,11 +4,19 @@ from langchain.embeddings import OpenAIEmbeddings import ai_pb2 import ai_pb2_grpc import grpc +import init class AIServer(ai_pb2_grpc.LLMQueryServicer): def AddDocument(self, request, context): - print("AddDocument called with", request.text) + _text = request.text + _user_id = request.user_id + _database = request.database + _collection = request.collection + + vector = init.text_to_vector(_text) + data = init.insert_document(_user_id, vector, _database, _collection) + return ai_pb2.AddDocumentReply( id=request.text