add
This commit is contained in:
parent
2462ec5b7e
commit
c3429e78f7
@ -5,4 +5,8 @@
|
||||
<orderEntry type="jdk" jdkName="chat" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PackageRequirementsSettings">
|
||||
<option name="removeUnused" value="true" />
|
||||
<option name="modifyBaseFiles" value="true" />
|
||||
</component>
|
||||
</module>
|
10
document_query/README.md
Normal file
10
document_query/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
# 环境变量
|
||||
|
||||
gRPC 监听地址
|
||||
```bash
|
||||
BIND=0.0.0.0:12345
|
||||
MILVUS_HOST=127.0.0.1
|
||||
MILVUS_PORT=19530
|
||||
OPENAI_API_BASE=http://
|
||||
OPENAI_API_KEY=
|
||||
```
|
18
document_query/ai.proto
Normal file
18
document_query/ai.proto
Normal file
@ -0,0 +1,18 @@
|
||||
// Service and message definitions for the document-query AI backend.
syntax = "proto3";

// LLMQuery exposes document ingestion over a server-streaming RPC.
service LLMQuery {
  // Adds a document; replies are streamed back to the caller.
  rpc AddDocument (QueryDocumentRequest) returns (stream QueryDocumentReply) {}
}

// Request carrying the document text and its storage coordinates.
message QueryDocumentRequest {
  string text = 1;        // raw document text to embed and store
  uint64 user_id = 2;     // owner of the document
  string database = 3;    // target database name
  string collection = 4;  // target collection name
}

// One streamed reply chunk.
message QueryDocumentReply {
  string text = 1;
}
|
29
document_query/ai_pb2.py
Normal file
29
document_query/ai_pb2.py
Normal file
@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ai.proto
#
# Machine-generated module: builds the QueryDocumentRequest /
# QueryDocumentReply message classes and the LLMQuery service descriptor
# from the serialized FileDescriptorProto below. Regenerate with protoc
# instead of editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


# Serialized ai.proto file descriptor — do not modify.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x08\x61i.proto\"[\n\x14QueryDocumentRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\x12\x10\n\x08\x64\x61tabase\x18\x03 \x01(\t\x12\x12\n\ncollection\x18\x04 \x01(\t\"\"\n\x12QueryDocumentReply\x12\x0c\n\x04text\x18\x01 \x01(\t2I\n\x08LLMQuery\x12=\n\x0b\x41\x64\x64\x44ocument\x12\x15.QueryDocumentRequest\x1a\x13.QueryDocumentReply\"\x00\x30\x01\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ai_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  # Pure-Python descriptors need the byte offsets of each message/service
  # inside the serialized file descriptor above.
  DESCRIPTOR._options = None
  _globals['_QUERYDOCUMENTREQUEST']._serialized_start=12
  _globals['_QUERYDOCUMENTREQUEST']._serialized_end=103
  _globals['_QUERYDOCUMENTREPLY']._serialized_start=105
  _globals['_QUERYDOCUMENTREPLY']._serialized_end=139
  _globals['_LLMQUERY']._serialized_start=141
  _globals['_LLMQUERY']._serialized_end=214
# @@protoc_insertion_point(module_scope)
|
66
document_query/ai_pb2_grpc.py
Normal file
66
document_query/ai_pb2_grpc.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import ai_pb2 as ai__pb2
|
||||
|
||||
|
||||
class LLMQueryStub(object):
    """Missing associated documentation comment in .proto file."""
    # Generated client stub for the LLMQuery service.

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # AddDocument is a server-streaming RPC: one request in, a stream
        # of QueryDocumentReply messages out.
        self.AddDocument = channel.unary_stream(
                '/LLMQuery/AddDocument',
                request_serializer=ai__pb2.QueryDocumentRequest.SerializeToString,
                response_deserializer=ai__pb2.QueryDocumentReply.FromString,
                )
|
||||
|
||||
|
||||
class LLMQueryServicer(object):
    """Missing associated documentation comment in .proto file."""
    # Generated server-side base class; applications subclass this and
    # override AddDocument (the override must keep this exact name —
    # gRPC dispatches by method name).

    def AddDocument(self, request, context):
        """Missing associated documentation comment in .proto file."""
        # Default implementation: report UNIMPLEMENTED to the caller.
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LLMQueryServicer_to_server(servicer, server):
    # Registers a LLMQueryServicer implementation with a grpc.Server,
    # wiring AddDocument as a unary-stream (server-streaming) handler.
    rpc_method_handlers = {
            'AddDocument': grpc.unary_stream_rpc_method_handler(
                    servicer.AddDocument,
                    request_deserializer=ai__pb2.QueryDocumentRequest.FromString,
                    response_serializer=ai__pb2.QueryDocumentReply.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'LLMQuery', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
class LLMQuery(object):
    """Missing associated documentation comment in .proto file."""
    # Generated convenience wrapper: call the RPC without building a stub.

    @staticmethod
    def AddDocument(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        # Single-use invocation of the server-streaming AddDocument RPC.
        return grpc.experimental.unary_stream(request, target, '/LLMQuery/AddDocument',
            ai__pb2.QueryDocumentRequest.SerializeToString,
            ai__pb2.QueryDocumentReply.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
73
document_query/init.py
Normal file
73
document_query/init.py
Normal file
@ -0,0 +1,73 @@
|
||||
import os
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from pymilvus import (
|
||||
connections,
|
||||
utility,
|
||||
FieldSchema,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
Collection,
|
||||
)
|
||||
|
||||
# --- Module initialisation: connect to Milvus and ensure the collection ---

# text-embedding-ada-002 (used below) returns 1536-dimensional vectors;
# the collection's vector field must match that or inserts are rejected.
# The original hard-coded dim=2, which can never hold an ada-002 embedding.
EMBEDDING_DIM = 1536

MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

if not utility.has_collection("leaf_documents"):
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,
    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,  # was 2 — too small for ada-002 embeddings
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True,
    )
    collection_name = "leaf_documents"
    print("Create collection...")
    _collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
        shards_num=2,
    )
    # HNSW index for L2 vector similarity search.
    _collection.create_index(
        field_name="vector",
        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
    )
    # Scalar index to speed up per-user filtering.
    _collection.create_index(
        field_name="user_id",
        index_name="idx_user_id",
    )

# Re-open by name (works for both the just-created and pre-existing case)
# and load it into memory so it can serve queries/inserts.
_collection = Collection("leaf_documents")
_collection.load()

# Embedding client; langchain reads OPENAI_API_BASE / OPENAI_API_KEY from
# the environment (see README).
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
|
||||
|
||||
def text_to_vector(text: str):
    """Embed *text* with the module-level OpenAI embeddings client.

    Returns the embedding as a list of floats.
    """
    vector = embeddings.embed_query(text)
    return vector
|
||||
|
||||
|
||||
def insert_document(document_id: int, user_id: int, vector: list):
    """Insert one (document_id, user_id, vector) row into the collection.

    Milvus takes column-oriented data: one list per field, in schema order.
    Returns the pymilvus insert result.
    """
    columns = [
        [document_id],
        [user_id],
        [vector],
    ]
    return _collection.insert(data=columns)
|
38
document_query/server.py
Normal file
38
document_query/server.py
Normal file
@ -0,0 +1,38 @@
|
||||
import os
|
||||
from concurrent import futures
|
||||
import ai_pb2
|
||||
import ai_pb2_grpc
|
||||
import grpc
|
||||
import init
|
||||
|
||||
|
||||
class AIServer(ai_pb2_grpc.LLMQueryServicer):
    """Servicer implementing the LLMQuery service (see ai.proto)."""

    # The method MUST be named AddDocument to match the RPC declared in
    # ai.proto — the original override was named QueryDocumentRequest, so
    # gRPC never dispatched to it and every call hit the base class's
    # UNIMPLEMENTED handler.
    def AddDocument(self, request, context):
        """Embed the request text, store it in Milvus, stream back an ack."""
        _text = request.text
        _user_id = request.user_id
        _database = request.database      # NOTE(review): not used by init.insert_document yet
        _collection = request.collection  # NOTE(review): not used by init.insert_document yet

        vector = init.text_to_vector(_text)
        # init.insert_document expects (document_id, user_id, vector); the
        # original passed (_user_id, vector, _database) — wrong order and
        # arity. The request carries no document id, so a placeholder is
        # used here. TODO: derive a real document id (e.g. from a sequence).
        init.insert_document(0, _user_id, vector)

        # AddDocument is declared `returns (stream QueryDocumentReply)` and
        # registered as unary_stream, so this method must yield replies.
        # The original returned ai_pb2.AddDocumentReply(id=...), a message
        # type and field that do not exist in ai.proto.
        yield ai_pb2.QueryDocumentReply(text=_text)
|
||||
|
||||
|
||||
def serve():
    """Start the gRPC server on $BIND (default [::]:50051) and block."""
    _ADDR = os.getenv("BIND")
    if _ADDR is None:
        _ADDR = "[::]:50051"  # default listen address
    print("Listening on", _ADDR)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    ai_pb2_grpc.add_LLMQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(_ADDR)
    server.start()
    server.wait_for_termination()


# Guard the entry point: the original called serve() unconditionally at
# module level, which blocks forever on import of this module.
if __name__ == "__main__":
    serve()
|
0
document_query/sync.py
Normal file
0
document_query/sync.py
Normal file
@ -0,0 +1,4 @@
|
||||
langchain~=0.0.325
|
||||
pymilvus~=2.3.2
|
||||
pymysql~=1.1.0
|
||||
openai~=0.28.1
|
20
rpc/init.py
20
rpc/init.py
@ -38,23 +38,23 @@ if not utility.has_collection("leaf_documents"):
|
||||
)
|
||||
collection_name = "leaf_documents"
|
||||
print("Create collection...")
|
||||
collection = Collection(
|
||||
_collection = Collection(
|
||||
name=collection_name,
|
||||
schema=schema,
|
||||
using='default',
|
||||
shards_num=2
|
||||
)
|
||||
collection.create_index(
|
||||
_collection.create_index(
|
||||
field_name="vector",
|
||||
index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
|
||||
)
|
||||
collection.create_index(
|
||||
_collection.create_index(
|
||||
field_name="user_id",
|
||||
index_name="idx_user_id"
|
||||
)
|
||||
|
||||
collection = Collection("leaf_documents")
|
||||
collection.load()
|
||||
_collection = Collection("leaf_documents")
|
||||
_collection.load()
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
|
||||
|
||||
@ -63,7 +63,11 @@ def text_to_vector(text: str):
|
||||
return embeddings.embed_query(text)
|
||||
|
||||
|
||||
def insert_document(document_id: int, vector: list):
|
||||
collection.insert(
|
||||
data=[document_id, vector],
|
||||
def insert_document(document_id: int, user_id: int, vector: list, collection: str):
|
||||
return _collection.insert(
|
||||
data=[
|
||||
[document_id],
|
||||
[user_id],
|
||||
[vector]
|
||||
],
|
||||
)
|
||||
|
@ -4,11 +4,19 @@ from langchain.embeddings import OpenAIEmbeddings
|
||||
import ai_pb2
|
||||
import ai_pb2_grpc
|
||||
import grpc
|
||||
import init
|
||||
|
||||
|
||||
class AIServer(ai_pb2_grpc.LLMQueryServicer):
|
||||
def AddDocument(self, request, context):
|
||||
print("AddDocument called with", request.text)
|
||||
_text = request.text
|
||||
_user_id = request.user_id
|
||||
_database = request.database
|
||||
_collection = request.collection
|
||||
|
||||
vector = init.text_to_vector(_text)
|
||||
data = init.insert_document(_user_id, vector, _database, _collection)
|
||||
|
||||
|
||||
return ai_pb2.AddDocumentReply(
|
||||
id=request.text
|
||||
|
Loading…
Reference in New Issue
Block a user