update
This commit is contained in:
parent
c9e04df385
commit
db30aaf68c
10
rpc/README.md
Normal file
10
rpc/README.md
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# 环境变量
|
||||||
|
|
||||||
|
gRPC 监听地址
|
||||||
|
```bash
|
||||||
|
BIND=0.0.0.0:12345
|
||||||
|
MILVUS_ADDR=127.0.0.1
|
||||||
|
MILVUS_PORT=19530
|
||||||
|
OPENAI_API_BASE=http://
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
```
|
@ -15,3 +15,4 @@ message AddDocumentRequest {
|
|||||||
message AddDocumentReply {
|
message AddDocumentReply {
|
||||||
string id = 1;
|
string id = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: rpc/ai.proto
|
# source: ai.proto
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
@ -13,17 +13,17 @@ _sym_db = _symbol_database.Default()
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0crpc/ai.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"Y\n\x12\x41\x64\x64\x44ocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\x1e\n\x10\x41\x64\x64\x44ocumentReply\x12\n\n\x02id\x18\x01 \x01(\t2C\n\x08LLMQuery\x12\x37\n\x0b\x41\x64\x64\x44ocument\x12\x13.AddDocumentRequest\x1a\x11.AddDocumentReply\"\x00\x62\x06proto3')
|
||||||
|
|
||||||
_globals = globals()
|
_globals = globals()
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rpc.ai_pb2', _globals)
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
DESCRIPTOR._options = None
|
DESCRIPTOR._options = None
|
||||||
_globals['_ADDDOCUMENTREQUEST']._serialized_start=16
|
_globals['_ADDDOCUMENTREQUEST']._serialized_start=12
|
||||||
_globals['_ADDDOCUMENTREQUEST']._serialized_end=105
|
_globals['_ADDDOCUMENTREQUEST']._serialized_end=101
|
||||||
_globals['_ADDDOCUMENTREPLY']._serialized_start=107
|
_globals['_ADDDOCUMENTREPLY']._serialized_start=103
|
||||||
_globals['_ADDDOCUMENTREPLY']._serialized_end=137
|
_globals['_ADDDOCUMENTREPLY']._serialized_end=133
|
||||||
_globals['_LLMQUERY']._serialized_start=139
|
_globals['_LLMQUERY']._serialized_start=135
|
||||||
_globals['_LLMQUERY']._serialized_end=206
|
_globals['_LLMQUERY']._serialized_end=202
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
23
rpc/ai_pb2.pyi
Normal file
23
rpc/ai_pb2.pyi
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import message as _message
|
||||||
|
from typing import ClassVar as _ClassVar, Optional as _Optional
|
||||||
|
|
||||||
|
DESCRIPTOR: _descriptor.FileDescriptor
|
||||||
|
|
||||||
|
class AddDocumentRequest(_message.Message):
|
||||||
|
__slots__ = ["text", "user_id", "database", "collection"]
|
||||||
|
TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
USER_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
DATABASE_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
COLLECTION_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
text: str
|
||||||
|
user_id: int
|
||||||
|
database: str
|
||||||
|
collection: str
|
||||||
|
def __init__(self, text: _Optional[str] = ..., user_id: _Optional[int] = ..., database: _Optional[str] = ..., collection: _Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
|
class AddDocumentReply(_message.Message):
|
||||||
|
__slots__ = ["id"]
|
||||||
|
ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
id: str
|
||||||
|
def __init__(self, id: _Optional[str] = ...) -> None: ...
|
@ -2,7 +2,7 @@
|
|||||||
"""Client and server classes corresponding to protobuf-defined services."""
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
import grpc
|
import grpc
|
||||||
|
|
||||||
from rpc import ai_pb2 as rpc_dot_ai__pb2
|
import ai_pb2 as ai__pb2
|
||||||
|
|
||||||
|
|
||||||
class LLMQueryStub(object):
|
class LLMQueryStub(object):
|
||||||
@ -16,8 +16,8 @@ class LLMQueryStub(object):
|
|||||||
"""
|
"""
|
||||||
self.AddDocument = channel.unary_unary(
|
self.AddDocument = channel.unary_unary(
|
||||||
'/LLMQuery/AddDocument',
|
'/LLMQuery/AddDocument',
|
||||||
request_serializer=rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
|
request_serializer=ai__pb2.AddDocumentRequest.SerializeToString,
|
||||||
response_deserializer=rpc_dot_ai__pb2.AddDocumentReply.FromString,
|
response_deserializer=ai__pb2.AddDocumentReply.FromString,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -35,8 +35,8 @@ def add_LLMQueryServicer_to_server(servicer, server):
|
|||||||
rpc_method_handlers = {
|
rpc_method_handlers = {
|
||||||
'AddDocument': grpc.unary_unary_rpc_method_handler(
|
'AddDocument': grpc.unary_unary_rpc_method_handler(
|
||||||
servicer.AddDocument,
|
servicer.AddDocument,
|
||||||
request_deserializer=rpc_dot_ai__pb2.AddDocumentRequest.FromString,
|
request_deserializer=ai__pb2.AddDocumentRequest.FromString,
|
||||||
response_serializer=rpc_dot_ai__pb2.AddDocumentReply.SerializeToString,
|
response_serializer=ai__pb2.AddDocumentReply.SerializeToString,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
generic_handler = grpc.method_handlers_generic_handler(
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
@ -60,7 +60,7 @@ class LLMQuery(object):
|
|||||||
timeout=None,
|
timeout=None,
|
||||||
metadata=None):
|
metadata=None):
|
||||||
return grpc.experimental.unary_unary(request, target, '/LLMQuery/AddDocument',
|
return grpc.experimental.unary_unary(request, target, '/LLMQuery/AddDocument',
|
||||||
rpc_dot_ai__pb2.AddDocumentRequest.SerializeToString,
|
ai__pb2.AddDocumentRequest.SerializeToString,
|
||||||
rpc_dot_ai__pb2.AddDocumentReply.FromString,
|
ai__pb2.AddDocumentReply.FromString,
|
||||||
options, channel_credentials,
|
options, channel_credentials,
|
||||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
69
rpc/init.py
Normal file
69
rpc/init.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from pymilvus import (
|
||||||
|
connections,
|
||||||
|
utility,
|
||||||
|
FieldSchema,
|
||||||
|
CollectionSchema,
|
||||||
|
DataType,
|
||||||
|
Collection,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init
|
||||||
|
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
|
||||||
|
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
|
||||||
|
|
||||||
|
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
|
||||||
|
|
||||||
|
if not utility.has_collection("leaf_documents"):
|
||||||
|
_document_id = FieldSchema(
|
||||||
|
name="document_id",
|
||||||
|
dtype=DataType.INT64,
|
||||||
|
is_primary=True,
|
||||||
|
)
|
||||||
|
_user_id = FieldSchema(
|
||||||
|
name="user_id",
|
||||||
|
dtype=DataType.INT64,
|
||||||
|
|
||||||
|
)
|
||||||
|
_document_vector = FieldSchema(
|
||||||
|
name="vector",
|
||||||
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
|
dim=2
|
||||||
|
)
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=[_document_id, _user_id, _document_vector],
|
||||||
|
enable_dynamic_field=True
|
||||||
|
)
|
||||||
|
collection_name = "leaf_documents"
|
||||||
|
print("Create collection...")
|
||||||
|
collection = Collection(
|
||||||
|
name=collection_name,
|
||||||
|
schema=schema,
|
||||||
|
using='default',
|
||||||
|
shards_num=2
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="vector",
|
||||||
|
index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
|
||||||
|
)
|
||||||
|
collection.create_index(
|
||||||
|
field_name="user_id",
|
||||||
|
index_name="idx_user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
collection = Collection("leaf_documents")
|
||||||
|
collection.load()
|
||||||
|
|
||||||
|
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||||
|
|
||||||
|
|
||||||
|
def text_to_vector(text: str):
|
||||||
|
return embeddings.embed_query(text)
|
||||||
|
|
||||||
|
|
||||||
|
def insert_document(document_id: int, vector: list):
|
||||||
|
collection.insert(
|
||||||
|
data=[document_id, vector],
|
||||||
|
)
|
31
rpc/server.py
Normal file
31
rpc/server.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import os
|
||||||
|
from concurrent import futures
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
import ai_pb2
|
||||||
|
import ai_pb2_grpc
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class AIServer(ai_pb2_grpc.LLMQueryServicer):
|
||||||
|
def AddDocument(self, request, context):
|
||||||
|
print("AddDocument called with", request.text)
|
||||||
|
|
||||||
|
return ai_pb2.AddDocumentReply(
|
||||||
|
id=request.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
|
||||||
|
_ADDR = os.getenv("BIND")
|
||||||
|
if _ADDR is None:
|
||||||
|
_ADDR = "[::]:50051"
|
||||||
|
print("Listening on", _ADDR)
|
||||||
|
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||||
|
ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
|
||||||
|
server.add_insecure_port(_ADDR)
|
||||||
|
server.start()
|
||||||
|
server.wait_for_termination()
|
||||||
|
|
||||||
|
|
||||||
|
serve()
|
Loading…
Reference in New Issue
Block a user