This commit is contained in:
iVamp 2023-11-15 20:20:21 +08:00
parent 2462ec5b7e
commit c3429e78f7
11 changed files with 263 additions and 9 deletions

View File

@ -5,4 +5,8 @@
<orderEntry type="jdk" jdkName="chat" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="chat" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
<component name="PackageRequirementsSettings">
<option name="removeUnused" value="true" />
<option name="modifyBaseFiles" value="true" />
</component>
</module> </module>

10
document_query/README.md Normal file
View File

@ -0,0 +1,10 @@
# Environment Variables
gRPC listen address
```bash
BIND=0.0.0.0:12345
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
OPENAI_API_BASE=http://
OPENAI_API_KEY=
```

18
document_query/ai.proto Normal file
View File

@ -0,0 +1,18 @@
// gRPC contract for the document-query service.
syntax = "proto3";
// LLM-backed document service.
// NOTE(review): the method is named AddDocument but its messages are named
// Query* — consider renaming one side for consistency (requires regenerating
// ai_pb2*.py and updating callers).
service LLMQuery {
// Server-streaming: one request in, a stream of replies out.
rpc AddDocument (QueryDocumentRequest) returns (stream QueryDocumentReply) {}
}
// A document to embed and store, plus where it belongs.
message QueryDocumentRequest {
string text = 1;
uint64 user_id = 2;
string database = 3;
string collection = 4;
}
// One streamed reply chunk; only carries text.
message QueryDocumentReply {
string text = 1;
}

29
document_query/ai_pb2.py Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ai.proto
# NOTE(review): machine-generated from ai.proto; never hand-edit — regenerate
# with grpcio-tools (`python -m grpc_tools.protoc -I. --python_out=. ai.proto`).
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

# Serialized FileDescriptorProto for ai.proto; the builder below turns it into
# the QueryDocumentRequest / QueryDocumentReply message classes.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"[\n\x14QueryDocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\"\n\x12QueryDocumentReply\x12\x0c\n\x04text\x18\x01 \x01(\t2I\n\x08LLMQuery\x12=\n\x0b\x41\x64\x64\x44ocument\x12\x15.QueryDocumentRequest\x1a\x13.QueryDocumentReply\"\x00\x30\x01\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
    # Pure-Python descriptors: record byte offsets of each message/service.
    DESCRIPTOR._options = None
    _globals['_QUERYDOCUMENTREQUEST']._serialized_start=12
    _globals['_QUERYDOCUMENTREQUEST']._serialized_end=103
    _globals['_QUERYDOCUMENTREPLY']._serialized_start=105
    _globals['_QUERYDOCUMENTREPLY']._serialized_end=139
    _globals['_LLMQUERY']._serialized_start=141
    _globals['_LLMQUERY']._serialized_end=214
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,66 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import ai_pb2 as ai__pb2
# NOTE(review): generated by grpcio-tools from ai.proto — regenerate rather
# than hand-edit.
class LLMQueryStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # AddDocument is declared server-streaming in ai.proto, hence
        # unary_stream: one request in, an iterator of replies out.
        self.AddDocument = channel.unary_stream(
                '/LLMQuery/AddDocument',
                request_serializer=ai__pb2.QueryDocumentRequest.SerializeToString,
                response_deserializer=ai__pb2.QueryDocumentReply.FromString,
                )
# NOTE(review): generated base class — a real server subclasses this and
# overrides AddDocument (the override must keep this exact method name).
class LLMQueryServicer(object):
    """Missing associated documentation comment in .proto file."""

    def AddDocument(self, request, context):
        """Missing associated documentation comment in .proto file."""
        # Default stub behavior: report UNIMPLEMENTED to the client.
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
# NOTE(review): generated registration helper — binds a servicer instance's
# handlers to a grpc.Server under the '/LLMQuery/...' method paths.
def add_LLMQueryServicer_to_server(servicer, server):
    rpc_method_handlers = {
            # unary_stream matches the `returns (stream ...)` declaration.
            'AddDocument': grpc.unary_stream_rpc_method_handler(
                    servicer.AddDocument,
                    request_deserializer=ai__pb2.QueryDocumentRequest.FromString,
                    response_serializer=ai__pb2.QueryDocumentReply.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'LLMQuery', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
# NOTE(review): generated convenience wrapper — lets callers invoke the RPC
# without constructing a channel/stub themselves (grpc.experimental API).
class LLMQuery(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def AddDocument(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(request, target, '/LLMQuery/AddDocument',
            ai__pb2.QueryDocumentRequest.SerializeToString,
            ai__pb2.QueryDocumentReply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

73
document_query/init.py Normal file
View File

@ -0,0 +1,73 @@
import os
from langchain.embeddings import OpenAIEmbeddings
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
# init: connect to Milvus and make sure the "leaf_documents" collection
# exists, is indexed, and is loaded before this module's helpers are used.
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

if not utility.has_collection("leaf_documents"):
    # Primary key of a stored document.
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    # Owning user; indexed below for filtered queries.
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        # BUG FIX: was dim=2. This module embeds with OpenAI
        # text-embedding-ada-002 (see `embeddings` below), which produces
        # 1536-dimension vectors, so every insert against a dim=2 field
        # would be rejected by Milvus with a dimension mismatch.
        dim=1536,
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True,
    )
    collection_name = "leaf_documents"
    print("Create collection...")
    _collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
        shards_num=2,
    )
    # HNSW vector index for L2 similarity search.
    _collection.create_index(
        field_name="vector",
        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
    )
    # Scalar index so per-user filtering is cheap.
    _collection.create_index(
        field_name="user_id",
        index_name="idx_user_id",
    )

# Re-open the (possibly pre-existing) collection and load it into memory.
_collection = Collection("leaf_documents")
_collection.load()
# Shared embedding model (OpenAI ada-002; produces 1536-dim vectors per
# OpenAI's documentation — confirm against the collection's vector dim).
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")


def text_to_vector(text: str) -> list:
    """Embed *text* with the module-level OpenAI embedding model."""
    return embeddings.embed_query(text)
def insert_document(document_id: int, user_id: int, vector: list):
    """Insert one row (primary key, owner, embedding) into leaf_documents.

    Returns whatever Collection.insert returns (a pymilvus MutationResult).
    """
    # Milvus takes column-oriented data: one list per schema field.
    columns = [
        [document_id],
        [user_id],
        [vector],
    ]
    return _collection.insert(data=columns)

38
document_query/server.py Normal file
View File

@ -0,0 +1,38 @@
import os
from concurrent import futures
import ai_pb2
import ai_pb2_grpc
import grpc
import init
class AIServer(ai_pb2_grpc.LLMQueryServicer):
    """Implements the LLMQuery service declared in ai.proto."""

    def AddDocument(self, request, context):
        """Embed the request text, store it in Milvus, and stream back a reply.

        BUG FIXES versus the original handler:
        - The method was named ``QueryDocumentRequest`` (the message's name),
          which never overrides the servicer's ``AddDocument`` stub, so every
          call returned UNIMPLEMENTED.
        - It returned ``ai_pb2.AddDocumentReply(id=...)`` — neither that
          message type nor an ``id`` field exists in ai.proto; the reply type
          is ``QueryDocumentReply`` with a single ``text`` field.
        - The RPC is server-streaming, so the handler must ``yield`` replies,
          not ``return`` one.
        - ``init.insert_document`` takes ``(document_id, user_id, vector)``;
          the old call passed ``(_user_id, vector, _database)``.
        """
        vector = init.text_to_vector(request.text)
        # NOTE(review): no document id is present in the request; the user id
        # is reused as the primary key here — TODO: generate a real unique
        # document_id. request.database / request.collection are currently
        # ignored by the insert helper; verify whether they should route the
        # write.
        init.insert_document(request.user_id, request.user_id, vector)
        yield ai_pb2.QueryDocumentReply(text=request.text)
def serve():
    """Start the gRPC server and block until it terminates.

    The listen address comes from the BIND environment variable
    (e.g. "0.0.0.0:12345"), defaulting to "[::]:50051" when unset.
    """
    # os.getenv with a default replaces the explicit None check; behavior is
    # identical (the default applies only when BIND is unset).
    _ADDR = os.getenv("BIND", "[::]:50051")
    print("Listening on", _ADDR)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
    # NOTE(review): add_insecure_port — no TLS; fine for local use, revisit
    # before exposing this service publicly.
    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()  # blocks the calling thread
serve()

0
document_query/sync.py Normal file
View File

View File

@ -0,0 +1,4 @@
langchain~=0.0.325
pymilvus~=2.3.2
pymysql~=1.1.0
openai~=0.28.1

View File

@ -38,23 +38,23 @@ if not utility.has_collection("leaf_documents"):
) )
collection_name = "leaf_documents" collection_name = "leaf_documents"
print("Create collection...") print("Create collection...")
collection = Collection( _collection = Collection(
name=collection_name, name=collection_name,
schema=schema, schema=schema,
using='default', using='default',
shards_num=2 shards_num=2
) )
collection.create_index( _collection.create_index(
field_name="vector", field_name="vector",
index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"}, index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
) )
collection.create_index( _collection.create_index(
field_name="user_id", field_name="user_id",
index_name="idx_user_id" index_name="idx_user_id"
) )
collection = Collection("leaf_documents") _collection = Collection("leaf_documents")
collection.load() _collection.load()
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002") embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
@ -63,7 +63,11 @@ def text_to_vector(text: str):
return embeddings.embed_query(text) return embeddings.embed_query(text)
def insert_document(document_id: int, vector: list): def insert_document(document_id: int, user_id: int, vector: list, collection: str):
collection.insert( return _collection.insert(
data=[document_id, vector], data=[
[document_id],
[user_id],
[vector]
],
) )

View File

@ -4,11 +4,19 @@ from langchain.embeddings import OpenAIEmbeddings
import ai_pb2 import ai_pb2
import ai_pb2_grpc import ai_pb2_grpc
import grpc import grpc
import init
class AIServer(ai_pb2_grpc.LLMQueryServicer): class AIServer(ai_pb2_grpc.LLMQueryServicer):
def AddDocument(self, request, context): def AddDocument(self, request, context):
print("AddDocument called with", request.text) _text = request.text
_user_id = request.user_id
_database = request.database
_collection = request.collection
vector = init.text_to_vector(_text)
data = init.insert_document(_user_id, vector, _database, _collection)
return ai_pb2.AddDocumentReply( return ai_pb2.AddDocumentReply(
id=request.text id=request.text