update

commit db30aaf68c (parent c9e04df385)

rpc/README.md (new file)
@@ -0,0 +1,10 @@
# Environment Variables

gRPC listen address, plus the Milvus and OpenAI connection settings:

```bash
BIND=0.0.0.0:12345
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
OPENAI_API_BASE=http://
OPENAI_API_KEY=
```
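
A quick sketch of where each variable is read, mirroring rpc/init.py and rpc/server.py below; the OPENAI_* values are not looked up explicitly but are picked up from the environment by the OpenAI client behind langchain's OpenAIEmbeddings:

```python
import os

# Fallbacks match the ones hard-coded in rpc/server.py and rpc/init.py.
BIND = os.getenv("BIND") or "[::]:50051"               # gRPC listen address
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"  # Milvus host
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"      # Milvus port

# OPENAI_API_BASE / OPENAI_API_KEY need no explicit lookup here; the OpenAI
# client used by OpenAIEmbeddings reads them from the environment.
```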

rpc/ai.proto
@@ -15,3 +15,4 @@ message AddDocumentRequest {
message AddDocumentReply {
    string id = 1;
}
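
As a sanity check, this message round-trips through the regenerated Python classes (a minimal sketch; it assumes rpc/ is on the import path, the same assumption rpc/server.py makes with `import ai_pb2`):

```python
import ai_pb2  # generated from ai.proto (see rpc/ai_pb2.py below)

reply = ai_pb2.AddDocumentReply(id="doc-42")   # "doc-42" is an example value
data = reply.SerializeToString()               # wire-format bytes
assert ai_pb2.AddDocumentReply.FromString(data).id == "doc-42"
```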

rpc/ai_pb2.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: rpc/ai.proto
+# source: ai.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -13,17 +13,17 @@ _sym_db = _symbol_database.Default()
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0crpc/ai.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc.ai_pb2', _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
-  _globals['_ADDDOCUMENTREQUEST']._serialized_start=16
-  _globals['_ADDDOCUMENTREQUEST']._serialized_end=105
-  _globals['_ADDDOCUMENTREPLY']._serialized_start=107
-  _globals['_ADDDOCUMENTREPLY']._serialized_end=137
-  _globals['_LLMQUERY']._serialized_start=139
-  _globals['_LLMQUERY']._serialized_end=206
+  _globals['_ADDDOCUMENTREQUEST']._serialized_start=12
+  _globals['_ADDDOCUMENTREQUEST']._serialized_end=101
+  _globals['_ADDDOCUMENTREPLY']._serialized_start=103
+  _globals['_ADDDOCUMENTREPLY']._serialized_end=133
+  _globals['_LLMQUERY']._serialized_start=135
+  _globals['_LLMQUERY']._serialized_end=202
# @@protoc_insertion_point(module_scope)
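
The only substantive change here is the proto source path baked into the generated module (rpc/ai.proto becomes ai.proto) and the byte offsets that shift with it, i.e. the stubs were regenerated from inside rpc/ rather than from the repository root. A hypothetical regeneration command to reproduce that (assuming a grpcio-tools version new enough for --pyi_out, which would also emit the rpc/ai_pb2.pyi stub added below):

```python
# Run from inside rpc/ so the proto is referenced as plain "ai.proto".
from grpc_tools import protoc

protoc.main([
    "protoc",               # argv[0] placeholder
    "-I.",
    "--python_out=.",       # ai_pb2.py
    "--pyi_out=.",          # ai_pb2.pyi type stubs
    "--grpc_python_out=.",  # ai_pb2_grpc.py
    "ai.proto",
])
```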

rpc/ai_pb2.pyi (new file)
@@ -0,0 +1,23 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class AddDocumentRequest(_message.Message):
    __slots__ = ["text", "user_id", "database", "collection"]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    USER_ID_FIELD_NUMBER: _ClassVar[int]
    DATABASE_FIELD_NUMBER: _ClassVar[int]
    COLLECTION_FIELD_NUMBER: _ClassVar[int]
    text: str
    user_id: int
    database: str
    collection: str
    def __init__(self, text: _Optional[str] = ..., user_id: _Optional[int] = ..., database: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ...

class AddDocumentReply(_message.Message):
    __slots__ = ["id"]
    ID_FIELD_NUMBER: _ClassVar[int]
    id: str
    def __init__(self, id: _Optional[str] = ...) -> None: ...
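
The stub file mainly enables static checking of the message fields; a small illustration of what a type checker can now flag (the field values are arbitrary examples):

```python
import ai_pb2

# Matches the declarations in ai_pb2.pyi.
ok = ai_pb2.AddDocumentRequest(
    text="hello",
    user_id=1,
    database="default",
    collection="leaf_documents",
)

# user_id is declared as int, so mypy/pyright flag this; protobuf would also
# reject it at runtime with a TypeError.
bad = ai_pb2.AddDocumentRequest(user_id="not-an-int")
```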

rpc/ai_pb2_grpc.py
@@ -2,7 +2,7 @@
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

-from rpc import ai_pb2 as rpc_dot_ai__pb2
+import ai_pb2 as ai__pb2


class LLMQueryStub(object):
@@ -16,8 +16,8 @@ class LLMQueryStub(object):
        """
        self.AddDocument = channel.unary_unary(
                '/LLMQuery/AddDocument',
-                request_serializer=rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
-                response_deserializer=rpc_dot_ai__pb2.AddDocumentReply.FromString,
+                request_serializer=ai__pb2.AddDocumentRequest.SerializeToString,
+                response_deserializer=ai__pb2.AddDocumentReply.FromString,
                )

@@ -35,8 +35,8 @@ def add_LLMQueryServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'AddDocument': grpc.unary_unary_rpc_method_handler(
                    servicer.AddDocument,
-                    request_deserializer=rpc_dot_ai__pb2.AddDocumentRequest.FromString,
-                    response_serializer=rpc_dot_ai__pb2.AddDocumentReply.SerializeToString,
+                    request_deserializer=ai__pb2.AddDocumentRequest.FromString,
+                    response_serializer=ai__pb2.AddDocumentReply.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
@@ -60,7 +60,7 @@ class LLMQuery(object):
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/LLMQuery/AddDocument',
-            rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
-            rpc_dot_ai__pb2.AddDocumentReply.FromString,
+            ai__pb2.AddDocumentRequest.SerializeToString,
+            ai__pb2.AddDocumentReply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
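
For reference, a hypothetical client call against these regenerated stubs; the target address follows the BIND example in rpc/README.md, and the database/collection values are illustrative (the current server handler only looks at text):

```python
import grpc

import ai_pb2
import ai_pb2_grpc

# Talk to the server started by rpc/server.py (BIND=0.0.0.0:12345 in the README).
with grpc.insecure_channel("127.0.0.1:12345") as channel:
    stub = ai_pb2_grpc.LLMQueryStub(channel)
    reply = stub.AddDocument(ai_pb2.AddDocumentRequest(
        text="hello world",
        user_id=1,
        database="default",
        collection="leaf_documents",
    ))
    print("stored document id:", reply.id)
```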

rpc/init.py (new file)
@@ -0,0 +1,69 @@
import os

from langchain.embeddings import OpenAIEmbeddings
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)

# Milvus connection settings (see rpc/README.md).
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# Create the collection on first run.
if not utility.has_collection("leaf_documents"):
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        dim=1536,  # text-embedding-ada-002 returns 1536-dimensional vectors
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True,
    )
    collection_name = "leaf_documents"
    print("Create collection...")
    collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
        shards_num=2,
    )
    # HNSW index on the vector field; M/efConstruction belong under "params".
    collection.create_index(
        field_name="vector",
        index_params={
            "metric_type": "L2",
            "index_type": "HNSW",
            "params": {"M": 8, "efConstruction": 64},
        },
    )
    # Scalar index for filtering by user.
    collection.create_index(
        field_name="user_id",
        index_name="idx_user_id",
    )

collection = Collection("leaf_documents")
collection.load()

embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")


def text_to_vector(text: str):
    """Embed a text with OpenAI and return its vector."""
    return embeddings.embed_query(text)


def insert_document(document_id: int, user_id: int, vector: list):
    """Insert one row; Milvus expects column-wise data, one list per field."""
    collection.insert(
        data=[[document_id], [user_id], [vector]],
    )
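
A hypothetical end-to-end use of these helpers (the ids are arbitrary example values, and the snippet assumes it runs from rpc/ so the module imports as init):

```python
from init import text_to_vector, insert_document  # rpc/init.py

vector = text_to_vector("hello world")  # 1536-dim ada-002 embedding
insert_document(document_id=1, user_id=42, vector=vector)
```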

rpc/server.py (new file)
@@ -0,0 +1,31 @@
import os
from concurrent import futures
from langchain.embeddings import OpenAIEmbeddings  # imported but not used yet
import ai_pb2
import ai_pb2_grpc
import grpc


class AIServer(ai_pb2_grpc.LLMQueryServicer):
    def AddDocument(self, request, context):
        print("AddDocument called with", request.text)

        # For now the reply simply echoes the request text back as the id.
        return ai_pb2.AddDocumentReply(
            id=request.text
        )


def serve():
    # gRPC listen address, e.g. BIND=0.0.0.0:12345 (see rpc/README.md).
    _ADDR = os.getenv("BIND")
    if _ADDR is None:
        _ADDR = "[::]:50051"
    print("Listening on", _ADDR)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
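
The handler does not touch Milvus yet. One hypothetical way to wire it to the helpers in rpc/init.py, sketched under the assumption that the three-argument insert_document above is used; the crc32-derived document_id is purely a placeholder, since the schema's primary key is not auto-generated and the request carries no id field:

```python
import zlib

import ai_pb2
import ai_pb2_grpc
import init  # rpc/init.py: connects to Milvus, defines text_to_vector/insert_document


class StoringAIServer(ai_pb2_grpc.LLMQueryServicer):
    def AddDocument(self, request, context):
        doc_id = zlib.crc32(request.text.encode("utf-8"))  # placeholder id scheme
        vector = init.text_to_vector(request.text)
        init.insert_document(document_id=doc_id, user_id=request.user_id, vector=vector)
        return ai_pb2.AddDocumentReply(id=str(doc_id))
```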