This commit is contained in:
iVampireSP.com 2023-11-14 19:41:35 +08:00
parent c9e04df385
commit db30aaf68c
No known key found for this signature in database
GPG Key ID: 2F7B001CA27A8132
8 changed files with 151 additions and 17 deletions

10
rpc/README.md Normal file
View File

@ -0,0 +1,10 @@
# 环境变量
gRPC 监听地址
```bash
BIND=0.0.0.0:12345
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
OPENAI_API_BASE=http://
OPENAI_API_KEY=
```

View File

@ -15,3 +15,4 @@ message AddDocumentRequest {
message AddDocumentReply {
string id = 1;
}

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: rpc/ai.proto
# source: ai.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@ -13,17 +13,17 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0crpc/ai.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc.ai_pb2', _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_ADDDOCUMENTREQUEST']._serialized_start=16
_globals['_ADDDOCUMENTREQUEST']._serialized_end=105
_globals['_ADDDOCUMENTREPLY']._serialized_start=107
_globals['_ADDDOCUMENTREPLY']._serialized_end=137
_globals['_LLMQUERY']._serialized_start=139
_globals['_LLMQUERY']._serialized_end=206
_globals['_ADDDOCUMENTREQUEST']._serialized_start=12
_globals['_ADDDOCUMENTREQUEST']._serialized_end=101
_globals['_ADDDOCUMENTREPLY']._serialized_start=103
_globals['_ADDDOCUMENTREPLY']._serialized_end=133
_globals['_LLMQUERY']._serialized_start=135
_globals['_LLMQUERY']._serialized_end=202
# @@protoc_insertion_point(module_scope)

23
rpc/ai_pb2.pyi Normal file
View File

@ -0,0 +1,23 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class AddDocumentRequest(_message.Message):
__slots__ = ["text", "user_id", "database", "collection"]
TEXT_FIELD_NUMBER: _ClassVar[int]
USER_ID_FIELD_NUMBER: _ClassVar[int]
DATABASE_FIELD_NUMBER: _ClassVar[int]
COLLECTION_FIELD_NUMBER: _ClassVar[int]
text: str
user_id: int
database: str
collection: str
def __init__(self, text: _Optional[str] = ..., user_id: _Optional[int] = ..., database: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ...
class AddDocumentReply(_message.Message):
__slots__ = ["id"]
ID_FIELD_NUMBER: _ClassVar[int]
id: str
def __init__(self, id: _Optional[str] = ...) -> None: ...

View File

@ -2,7 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from rpc import ai_pb2 as rpc_dot_ai__pb2
import ai_pb2 as ai__pb2
class LLMQueryStub(object):
@ -16,8 +16,8 @@ class LLMQueryStub(object):
"""
self.AddDocument = channel.unary_unary(
'/LLMQuery/AddDocument',
request_serializer=rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
response_deserializer=rpc_dot_ai__pb2.AddDocumentReply.FromString,
request_serializer=ai__pb2.AddDocumentRequest.SerializeToString,
response_deserializer=ai__pb2.AddDocumentReply.FromString,
)
@ -35,8 +35,8 @@ def add_LLMQueryServicer_to_server(servicer, server):
rpc_method_handlers = {
'AddDocument': grpc.unary_unary_rpc_method_handler(
servicer.AddDocument,
request_deserializer=rpc_dot_ai__pb2.AddDocumentRequest.FromString,
response_serializer=rpc_dot_ai__pb2.AddDocumentReply.SerializeToString,
request_deserializer=ai__pb2.AddDocumentRequest.FromString,
response_serializer=ai__pb2.AddDocumentReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@ -60,7 +60,7 @@ class LLMQuery(object):
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/LLMQuery/AddDocument',
rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
rpc_dot_ai__pb2.AddDocumentReply.FromString,
ai__pb2.AddDocumentRequest.SerializeToString,
ai__pb2.AddDocumentReply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

69
rpc/init.py Normal file
View File

@ -0,0 +1,69 @@
import os
from langchain.embeddings import OpenAIEmbeddings
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
# --- Module initialization ---------------------------------------------------
# Connects to Milvus and ensures the "leaf_documents" collection exists.
# NOTE: this runs at import time, so importing this module has side effects
# (opens a connection and may create a collection).
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
# NOTE(review): the accompanying README documents MILVUS_ADDR, but the code
# reads MILVUS_HOST — confirm which name deployments actually set.
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
if not utility.has_collection("leaf_documents"):
    # Primary key of a stored document.
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    # Owner of the document.
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        # NOTE(review): dim=2 looks inconsistent with the
        # text-embedding-ada-002 embeddings produced below (1536 dimensions)
        # — confirm the intended vector dimension before first insert.
        dim=2
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True
    )
    collection_name = "leaf_documents"
    print("Create collection...")
    collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
        shards_num=2
    )
    # ANN index over the embedding vectors (HNSW, L2 distance).
    collection.create_index(
        field_name="vector",
        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
    )
    # Scalar index so queries can filter by owner.
    collection.create_index(
        field_name="user_id",
        index_name="idx_user_id"
    )
# Open (or re-open) a handle to the collection and load it into memory
# so it is ready to serve inserts/searches.
collection = Collection("leaf_documents")
collection.load()
# Embedding model used to turn document text into vectors.
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")


def text_to_vector(text: str):
    """Return the OpenAI embedding vector for *text*.

    Makes a network call to the OpenAI embeddings API via langchain;
    the result is a list of floats.
    """
    return embeddings.embed_query(text)
def insert_document(document_id: int, vector: list, user_id: int = 0):
    """Insert one document row into the loaded "leaf_documents" collection.

    Args:
        document_id: primary key of the document.
        vector: embedding vector for the document text.
        user_id: owner of the document (defaults to 0 for callers that
            predate this parameter).

    The collection schema declares three fields (document_id, user_id,
    vector), so the insert must supply one column list per field, in
    schema order.  The previous ``data=[document_id, vector]`` form
    omitted user_id and passed scalars where pymilvus expects columns.
    """
    collection.insert(
        data=[[document_id], [user_id], [vector]],
    )

31
rpc/server.py Normal file
View File

@ -0,0 +1,31 @@
import os
from concurrent import futures
from langchain.embeddings import OpenAIEmbeddings
import ai_pb2
import ai_pb2_grpc
import grpc
class AIServer(ai_pb2_grpc.LLMQueryServicer):
    """gRPC servicer implementing the LLMQuery service."""

    def AddDocument(self, request, context):
        """Handle the AddDocument RPC.

        NOTE(review): currently a placeholder — it logs the request text and
        echoes it back as the reply ``id`` without storing anything; confirm
        whether the Milvus insert path was meant to be wired in here.
        """
        print("AddDocument called with", request.text)
        return ai_pb2.AddDocumentReply(
            id=request.text
        )
def serve():
    """Start the gRPC server and block until it terminates.

    The listen address is taken from the BIND environment variable,
    falling back to "[::]:50051" when it is unset.
    """
    # os.getenv's default parameter replaces the manual None check.
    bind_addr = os.getenv("BIND", "[::]:50051")
    print("Listening on", bind_addr)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(bind_addr)
    server.start()
    # Blocks the calling thread for the lifetime of the server.
    server.wait_for_termination()
# Guard the entry point so importing this module (e.g. for tests or by
# other services) does not start a blocking gRPC server as a side effect.
if __name__ == "__main__":
    serve()