update

parent a77ff095f8
commit c8e5c8f389
@@ -6,7 +6,6 @@
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="PackageRequirementsSettings">
-    <option name="removeUnused" value="true" />
     <option name="modifyBaseFiles" value="true" />
   </component>
 </module>
@@ -2,5 +2,6 @@
 <project version="4">
   <component name="VcsDirectoryMappings">
     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+    <mapping directory="$PROJECT_DIR$/document_ai/proto" vcs="Git" />
   </component>
 </project>
@@ -1,39 +0,0 @@
-import time
-
-import documents_pb2_grpc
-import documents_pb2
-import init
-import doc_client
-import sys
-import signal
-
-
-def sync_documents():
-    while True:
-        documents_response = doc_client.stub.GetNoVectorDocuments(documents_pb2.GetNoVectorDocumentsRequest()).documents
-
-        # get all documents with no vector
-        for document in documents_response:
-            docContent = document.title + "\n" + document.content
-
-            print("Updating vector...")
-            text_vector = init.text_to_vector(docContent)
-
-            # update vector
-            update_vector_response = init.insert_document(document.id, document.user_id, text_vector)
-            print(update_vector_response)
-
-            # update vector_id
-            update_vector_id_response = doc_client.stub.UpdateDocument(documents_pb2.UpdateDocumentRequest(
-                id=document.id,
-                vector_id=update_vector_response
-            ))
-
-            print(update_vector_id_response)
-            print("Vector update complete")
-
-        time.sleep(1 * 5)
-
-
-if __name__ == '__main__':
-    sync_documents()
@@ -1,8 +1,8 @@
 import grpc
-import documents_pb2_grpc
-import documents_pb2
+import proto.documents_pb2_grpc
+import proto.documents_pb2
 
-print("Connecting...")
+print("Connecting to the Library Server...")
 channel = grpc.insecure_channel('localhost:8081')
 
-stub = documents_pb2_grpc.DocumentSearchServiceStub(channel)
+stub = proto.documents_pb2_grpc.DocumentSearchServiceStub(channel)
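Note: the imports above now resolve the generated stubs from a proto package instead of top-level modules. The commit does not show how the stubs were regenerated; one plausible invocation, assuming the .proto sources live in document_ai/proto/ and that directory is an importable package, is:

    # hypothetical regeneration command; all paths here are assumptions
    python -m grpc_tools.protoc -I document_ai/proto \
        --python_out=document_ai/proto \
        --grpc_python_out=document_ai/proto \
        document_ai/proto/documents.proto document_ai/proto/document_query.proto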
@@ -1,16 +0,0 @@
-syntax = "proto3";
-
-service DocumentQuery {
-  rpc Query(QueryRequest) returns (QueryResponse) {}
-}
-
-
-message QueryRequest {
-  string question = 1;
-  uint64 user_id = 2;
-}
-
-message QueryResponse {
-  string text = 1;
-}
-
@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: document_query.proto
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x64ocument_query.proto\"1\n\x0cQueryRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\"\x1d\n\rQueryResponse\x12\x0c\n\x04text\x18\x01 \x01(\t29\n\rDocumentQuery\x12(\n\x05Query\x12\r.QueryRequest\x1a\x0e.QueryResponse\"\x00\x62\x06proto3')
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_query_pb2', _globals)
-if _descriptor._USE_C_DESCRIPTORS == False:
-  DESCRIPTOR._options = None
-  _globals['_QUERYREQUEST']._serialized_start=24
-  _globals['_QUERYREQUEST']._serialized_end=73
-  _globals['_QUERYRESPONSE']._serialized_start=75
-  _globals['_QUERYRESPONSE']._serialized_end=104
-  _globals['_DOCUMENTQUERY']._serialized_start=106
-  _globals['_DOCUMENTQUERY']._serialized_end=163
-# @@protoc_insertion_point(module_scope)
@@ -1,66 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-
-import document_query_pb2 as document__query__pb2
-
-
-class DocumentQueryStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.Query = channel.unary_unary(
-                '/DocumentQuery/Query',
-                request_serializer=document__query__pb2.QueryRequest.SerializeToString,
-                response_deserializer=document__query__pb2.QueryResponse.FromString,
-                )
-
-
-class DocumentQueryServicer(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def Query(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-
-def add_DocumentQueryServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'Query': grpc.unary_unary_rpc_method_handler(
-                    servicer.Query,
-                    request_deserializer=document__query__pb2.QueryRequest.FromString,
-                    response_serializer=document__query__pb2.QueryResponse.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'DocumentQuery', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-
-
-# This class is part of an EXPERIMENTAL API.
-class DocumentQuery(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def Query(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(request, target, '/DocumentQuery/Query',
-            document__query__pb2.QueryRequest.SerializeToString,
-            document__query__pb2.QueryResponse.FromString,
-            options, channel_credentials,
-            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -1,52 +0,0 @@
-syntax = "proto3";
-
-option go_package="./utils";
-package utils;
-
-message Document {
-  uint64 id = 1;
-  string title = 2;
-  string description = 3;
-  string content = 4;
-  uint64 vector_id = 5;
-  uint64 library_id = 6;
-  uint64 user_id = 7;
-}
-
-message GetDocumentsRequest {
-  string library = 1;
-  string text = 2;
-}
-
-message GetDocumentsResponse {
-  repeated Document documents = 1;
-}
-
-message GetNoVectorDocumentsRequest {
-  Document document = 1;
-}
-
-message GetNoVectorDocumentsResponse {
-  repeated Document documents = 1;
-}
-
-
-message UpdateDocumentRequest {
-  uint64 id = 1;
-  uint64 vector_id = 2;
-}
-
-message UpdateDocumentResponse {
-  Document document = 1;
-}
-
-message GetDocumentByIdRequest {
-  uint64 id = 1;
-}
-
-
-service DocumentSearchService {
-  rpc GetNoVectorDocuments(GetNoVectorDocumentsRequest) returns (GetNoVectorDocumentsResponse);
-  rpc UpdateDocument(UpdateDocumentRequest) returns (UpdateDocumentResponse);
-  rpc GetDocumentById(GetDocumentByIdRequest) returns (Document);
-}
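Note: this deleted service only knows whole documents. The chunk-oriented replacement is not included in the diff, yet the new code below calls GetNoVectorDocumentChunks, UpdateDocumentChunk, and GetDocumentChunk with GetNotVectorDocumentChunksRequest, UpdateChunkedDocumentRequest, and GetDocumentChunkByIdRequest. A hypothetical sketch of what the successor .proto plausibly declares, inferred purely from those call sites (any field beyond the ones actually read there is an assumption):

    // hypothetical sketch, inferred from call sites in vector.py / server.py below
    message DocumentChunk {
      uint64 id = 1;
      string content = 2;    // read as chunk.content
      uint64 vector_id = 3;  // written via UpdateDocumentChunk
      Document document = 4; // chunk.document.id / .library_id / .user_id are read
    }

    message GetDocumentChunkByIdRequest { uint64 id = 1; }

    message GetNotVectorDocumentChunksRequest {}

    message GetNoVectorDocumentChunksResponse { repeated DocumentChunk chunks = 1; }

    message UpdateChunkedDocumentRequest {
      uint64 id = 1;
      uint64 vector_id = 2;
    }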
@@ -1,42 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: documents.proto
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ocuments.proto\x12\x05utils\"\x83\x01\n\x08\x44ocument\x12\n\n\x02id\x18\x01 \x01(\x04\x12\r\n\x05title\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x11\n\tvector_id\x18\x05 \x01(\x04\x12\x12\n\nlibrary_id\x18\x06 \x01(\x04\x12\x0f\n\x07user_id\x18\x07 \x01(\x04\"4\n\x13GetDocumentsRequest\x12\x0f\n\x07library\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\":\n\x14GetDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"@\n\x1bGetNoVectorDocumentsRequest\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"B\n\x1cGetNoVectorDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"6\n\x15UpdateDocumentRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x11\n\tvector_id\x18\x02 \x01(\x04\";\n\x16UpdateDocumentResponse\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"$\n\x16GetDocumentByIdRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x32\x8a\x02\n\x15\x44ocumentSearchService\x12_\n\x14GetNoVectorDocuments\x12\".utils.GetNoVectorDocumentsRequest\x1a#.utils.GetNoVectorDocumentsResponse\x12M\n\x0eUpdateDocument\x12\x1c.utils.UpdateDocumentRequest\x1a\x1d.utils.UpdateDocumentResponse\x12\x41\n\x0fGetDocumentById\x12\x1d.utils.GetDocumentByIdRequest\x1a\x0f.utils.DocumentB\tZ\x07./utilsb\x06proto3')
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'documents_pb2', _globals)
-if _descriptor._USE_C_DESCRIPTORS == False:
-  DESCRIPTOR._options = None
-  DESCRIPTOR._serialized_options = b'Z\007./utils'
-  _globals['_DOCUMENT']._serialized_start=27
-  _globals['_DOCUMENT']._serialized_end=158
-  _globals['_GETDOCUMENTSREQUEST']._serialized_start=160
-  _globals['_GETDOCUMENTSREQUEST']._serialized_end=212
-  _globals['_GETDOCUMENTSRESPONSE']._serialized_start=214
-  _globals['_GETDOCUMENTSRESPONSE']._serialized_end=272
-  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_start=274
-  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_end=338
-  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_start=340
-  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_end=406
-  _globals['_UPDATEDOCUMENTREQUEST']._serialized_start=408
-  _globals['_UPDATEDOCUMENTREQUEST']._serialized_end=462
-  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_start=464
-  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_end=523
-  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_start=525
-  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_end=561
-  _globals['_DOCUMENTSEARCHSERVICE']._serialized_start=564
-  _globals['_DOCUMENTSEARCHSERVICE']._serialized_end=830
-# @@protoc_insertion_point(module_scope)
@@ -1,132 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-
-import documents_pb2 as documents__pb2
-
-
-class DocumentSearchServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.GetNoVectorDocuments = channel.unary_unary(
-                '/utils.DocumentSearchService/GetNoVectorDocuments',
-                request_serializer=documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
-                response_deserializer=documents__pb2.GetNoVectorDocumentsResponse.FromString,
-                )
-        self.UpdateDocument = channel.unary_unary(
-                '/utils.DocumentSearchService/UpdateDocument',
-                request_serializer=documents__pb2.UpdateDocumentRequest.SerializeToString,
-                response_deserializer=documents__pb2.UpdateDocumentResponse.FromString,
-                )
-        self.GetDocumentById = channel.unary_unary(
-                '/utils.DocumentSearchService/GetDocumentById',
-                request_serializer=documents__pb2.GetDocumentByIdRequest.SerializeToString,
-                response_deserializer=documents__pb2.Document.FromString,
-                )
-
-
-class DocumentSearchServiceServicer(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def GetNoVectorDocuments(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def UpdateDocument(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def GetDocumentById(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-
-def add_DocumentSearchServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'GetNoVectorDocuments': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetNoVectorDocuments,
-                    request_deserializer=documents__pb2.GetNoVectorDocumentsRequest.FromString,
-                    response_serializer=documents__pb2.GetNoVectorDocumentsResponse.SerializeToString,
-            ),
-            'UpdateDocument': grpc.unary_unary_rpc_method_handler(
-                    servicer.UpdateDocument,
-                    request_deserializer=documents__pb2.UpdateDocumentRequest.FromString,
-                    response_serializer=documents__pb2.UpdateDocumentResponse.SerializeToString,
-            ),
-            'GetDocumentById': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetDocumentById,
-                    request_deserializer=documents__pb2.GetDocumentByIdRequest.FromString,
-                    response_serializer=documents__pb2.Document.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'utils.DocumentSearchService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-
-
-# This class is part of an EXPERIMENTAL API.
-class DocumentSearchService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def GetNoVectorDocuments(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetNoVectorDocuments',
-            documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
-            documents__pb2.GetNoVectorDocumentsResponse.FromString,
-            options, channel_credentials,
-            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
-
-    @staticmethod
-    def UpdateDocument(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/UpdateDocument',
-            documents__pb2.UpdateDocumentRequest.SerializeToString,
-            documents__pb2.UpdateDocumentResponse.FromString,
-            options, channel_credentials,
-            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
-
-    @staticmethod
-    def GetDocumentById(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetDocumentById',
-            documents__pb2.GetDocumentByIdRequest.SerializeToString,
-            documents__pb2.Document.FromString,
-            options, channel_credentials,
-            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -20,12 +20,19 @@ if not utility.has_collection("leaf_documents"):
     _document_id = FieldSchema(
         name="document_id",
         dtype=DataType.INT64,
+    )
+    _document_chunk_id = FieldSchema(
+        name="document_chunk_id",
+        dtype=DataType.INT64,
         is_primary=True,
     )
+    _library_id = FieldSchema(
+        name="library_id",
+        dtype=DataType.INT64,
+    )
     _user_id = FieldSchema(
         name="user_id",
         dtype=DataType.INT64,

     )
     _document_vector = FieldSchema(
         name="vector",
@@ -33,7 +40,7 @@ if not utility.has_collection("leaf_documents"):
         dim=1536
     )
     schema = CollectionSchema(
-        fields=[_document_id, _user_id, _document_vector],
+        fields=[_document_id, _document_chunk_id, _library_id, _user_id, _document_vector],
         enable_dynamic_field=True
     )
     collection_name = "leaf_documents"
@@ -63,10 +70,12 @@ def text_to_vector(text: str):
    return embeddings.embed_query(text)


-def insert_document(document_id: int, user_id: int, vector: list):
+def insert_document(document_id: int, document_chunk_id: int, library_id: int, user_id: int, vector: list):
     return collection.insert(
         data=[
             [document_id],
+            [document_chunk_id],
+            [library_id],
             [user_id],
             [vector]
         ],
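Note: with this change each Milvus entity represents one chunk, keyed by document_chunk_id, and carries its parent document, library, and user ids for filtering. A minimal usage sketch of the new signature (ids are placeholder values), mirroring the call made in document_ai/vector.py below:

    # minimal sketch with placeholder ids; mirrors the call in document_ai/vector.py
    mutation = init.insert_document(
        document_id=42,          # parent document
        document_chunk_id=1001,  # primary key of the Milvus entity
        library_id=7,
        user_id=3,
        vector=init.text_to_vector("some chunk text"),  # 1536-dim embedding
    )
    print(mutation)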
@@ -1,11 +1,11 @@
 from threading import Thread
 
-import chunk
+import vector
 import server
 
 if __name__ == '__main__':
     # Start the worker thread
-    worker_thread = Thread(target=chunk.sync_documents, args=())
+    worker_thread = Thread(target=vector.sync_documents, args=())
     worker_thread.start()
 
     # Start the server thread
@@ -1,5 +1,5 @@
 import json
-import documents_pb2
+import proto.documents_pb2
 from langchain import text_splitter
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.embeddings import OpenAIEmbeddings
@@ -17,20 +17,24 @@ from pymilvus import (
 import init
 import doc_client
 
+#
+# question = """
+# yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
+# stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
+# At line:1 char:1
+# + yarn config set registry https://registry.npm.taobao.org/
+# + ~~~~
+# + CategoryInfo : SecurityError: (:) [], PSSecurityException
+# + FullyQualifiedErrorId : UnauthorizedAccess
+#
+# What is the problem, and how do I fix it?
+# """
 
 question = """
-yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
-stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
-At line:1 char:1
-+ yarn config set registry https://registry.npm.taobao.org/
-+ ~~~~
-+ CategoryInfo : SecurityError: (:) [], PSSecurityException
-+ FullyQualifiedErrorId : UnauthorizedAccess
+Why am I completely lost developing with WHMCS?
 
-What is the problem, and how do I fix it?
 """
 
-vec = init.text_to_vector(question + " (must reply in Chinese)")
+vec = init.text_to_vector(question)
 
 # vec = ""
 #
@@ -47,33 +51,37 @@ search_param = {
 }
 res = init.collection.search(**search_param)
 
-document_ids = []
+document_chunk_ids = []
 real_document = []
 
 for i in range(len(res[0])):
-    _doc_id = res[0][i].id
-    print("Fetching content of " + str(_doc_id) + "...")
+    _chunk_id = res[0][i].id
+    print("Fetching content of chunk " + str(_chunk_id) + "...")
 
     try:
-        _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
-            id=_doc_id
+        _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+            id=_chunk_id
         ))
-        _doc_content_full = _doc_content.title + "\n" + _doc_content.content
+        # print(_chunk_content)
 
+        _doc_content_full = _chunk_content.content
+
         # real_document.append(_doc_content)
-        doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
+        # doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title})
+        doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
 
         real_document.append(doc_obj)
 
     except Exception as e:
        print(e)
 
 
print(real_document)
 
print("Calling the LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
                                   verbose=True)

+question = "You must reply in Chinese: " + question
 output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
 print("Reply: " + output["output_text"])
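Note: retrieval is now chunk-granular, so each Milvus hit's .id is a document_chunk_id and the chunk body is fetched over gRPC before being wrapped in a langchain Document. A condensed sketch of the loop above (error handling elided; same stubs as the diff uses):

    # condensed sketch of the retrieval loop above; assumes res holds Milvus hits
    real_document = []
    for hit in res[0]:
        chunk = doc_client.stub.GetDocumentChunk(
            proto.documents_pb2.GetDocumentChunkByIdRequest(id=hit.id)
        )
        real_document.append(
            Document(page_content=chunk.content, metadata={"source": "chunked content"})
        )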
@@ -2,83 +2,122 @@ import os
 from concurrent import futures
 
 import langchain
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 
-import document_query_pb2
-import document_query_pb2_grpc
+import proto.document_query_pb2
+import proto.document_query_pb2_grpc
 import grpc
-import documents_pb2
+import proto.documents_pb2
 import init
 import doc_client
 from langchain.llms.openai import OpenAI
 from langchain.schema.document import Document
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.cache import InMemoryCache
 
 langchain.llm_cache = InMemoryCache()
 
+CHUNK_SIZE = 500
 
-class AIServer(document_query_pb2_grpc.DocumentQuery):
-    def Query(self, request, context):
-        vec = init.text_to_vector(request.question)
 
-        question = request.question + " (must reply in Chinese)"
+class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
+    def Query(self, target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        print("New request: " + target.question)
+        vec = init.text_to_vector(target.question)
+
+        question = "Reply in spoken language:" + target.question
 
         search_param = {
             "data": [vec],
             "anns_field": "vector",
             "param": {"metric_type": "L2"},
             "limit": 10,
-            "expr": "user_id == " + str(request.user_id),
+            "expr": "user_id == " + str(target.user_id),
             "output_fields": ["document_id", "user_id"],
         }
 
         res = init.collection.search(**search_param)
 
-        document_ids = []
+        # # at most 5
+        # if len(res[0]) > 5:
+        #     res[0] = res[0][:5]
+
+        # document_chunk_ids = []
         real_document = []
 
         for i in range(len(res[0])):
-            _doc_id = res[0][i].id
-            print("Fetching content of " + str(_doc_id) + "...")
+            _chunk_id = res[0][i].id
+            print("Fetching content of chunk " + str(_chunk_id) + "...")
 
             try:
-                _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
-                    id=_doc_id
+                _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
+                    id=_chunk_id
                 ))
-                _doc_content_full = _doc_content.title + "\n" + _doc_content.content
 
-                # real_document.append(_doc_content)
-                doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
+                _doc_content_full = _chunk_content.content
+
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
 
                 real_document.append(doc_obj)
 
             except Exception as e:
                 print(e)
 
-        # print(real_document)
+        print(real_document)
 
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=0)
-        all_splits = text_splitter.split_documents(real_document)
-
-        print("real_document: ", all_splits)
-
-        # document length
-        # print("document length: ", len(all_splits))
-
-        print("Calling the LLM: " + question + "...")
-
-        chain = load_qa_with_sources_chain(OpenAI(temperature=0, max_tokens=4097), chain_type="map_reduce",
-                                           return_intermediate_steps=False,
-                                           verbose=False)
-        output = chain({"input_documents": all_splits, "question": question}, return_only_outputs=False)
+        print("Calling the LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+
         print("Reply: " + output["output_text"])
 
-        return document_query_pb2.QueryResponse(
+        return proto.document_query_pb2.QueryResponse(
             text=output["output_text"]
-            # text = "test"
         )
+
+    def Chunk(self,
+              target,
+              options=(),
+              channel_credentials=None,
+              call_credentials=None,
+              insecure=False,
+              compression=None,
+              wait_for_ready=None,
+              timeout=None,
+              metadata=None):
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=CHUNK_SIZE,
+            chunk_overlap=20,
+            length_function=len,
+            add_start_index=True,
+        )
+
+        page_contents = text_splitter.create_documents([
+            target.text
+        ])
+
+        texts = []
+
+        for page_content in page_contents:
+            texts.append(page_content.page_content)
+
+        return proto.document_query_pb2.ChunkResponse(
+            texts=texts
+        )
@@ -89,7 +128,12 @@ def serve():
     print("Listening on", _ADDR)
 
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
-    document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+    proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
 
     server.add_insecure_port(_ADDR)
     server.start()
     server.wait_for_termination()
 
 
+if __name__ == '__main__':
+    serve()
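Note: the server now exposes a second method, Chunk, which splits raw text into roughly CHUNK_SIZE-character pieces with RecursiveCharacterTextSplitter and returns them as ChunkResponse(texts=...). A client-side sketch: the ChunkRequest message with a text field is an assumption (the servicer reads target.text), and the port is a placeholder since _ADDR is not shown in this diff:

    # hypothetical client call; ChunkRequest and the port are assumptions
    import grpc
    import proto.document_query_pb2
    import proto.document_query_pb2_grpc

    channel = grpc.insecure_channel('localhost:50051')  # _ADDR not shown in this diff
    stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
    resp = stub.Chunk(proto.document_query_pb2.ChunkRequest(text="long document text..."))
    print(resp.texts)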
document_ai/vector.py (new file)
@@ -0,0 +1,72 @@
+import time
+
+import proto.documents_pb2_grpc
+import proto.documents_pb2
+import init
+import doc_client
+import sys
+import signal
+
+from threading import Thread
+
+threads = []
+
+def sync_documents():
+    while True:
+        chunks_response = doc_client.stub.GetNoVectorDocumentChunks(proto.documents_pb2.GetNotVectorDocumentChunksRequest()).chunks
+
+        # get all document chunks with no vector
+        for chunk in chunks_response:
+            #
+            # # no more than 10 threads
+            # if len(threads) >= 10:
+            #     print("Thread pool is full, waiting 5 seconds...")
+            #     time.sleep(5)
+            #     continue
+            #
+            # # wait
+            # for t in threads:
+            #     if t.is_alive():
+            #         t.join()
+            #         print("Thread " + str(t) + " has finished.")
+            #     threads.remove(t)
+            #
+            # # create a thread
+            # print("Creating a thread...")
+            # t = Thread(target=vector_and_save, args=(chunk,))
+            # threads.append(t)
+            #
+            vector_and_save(chunk)
+
+        print("Entering the next loop...")
+        time.sleep(1 * 5)
+
+
+def vector_and_save(chunk):
+    chunk_content = chunk.content
+
+    print("Vectorizing text...")
+    text_vector = init.text_to_vector(chunk_content)
+
+    # update vector
+    update_vector_response = init.insert_document(
+        document_id=chunk.document.id,
+        document_chunk_id=chunk.id,
+        library_id=chunk.document.library_id,
+        user_id=chunk.document.user_id,
+        vector=text_vector
+    )
+    print(update_vector_response)
+
+    # update vector_id
+    update_vector_id_response = doc_client.stub.UpdateDocumentChunk(proto.documents_pb2.UpdateChunkedDocumentRequest(
+        id=chunk.id,
+        vector_id=update_vector_response
+    ))
+
+    print(update_vector_id_response)
+    print("Vectorization complete.")
+
+
+if __name__ == '__main__':
+    sync_documents()
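Note: vector.py keeps a commented-out attempt at a hand-managed thread list. If the sync loop ever needs concurrency, a bounded pool via concurrent.futures is a simpler shape; this is a sketch only, not part of the commit:

    # sketch only, not in the commit: bounded concurrency instead of hand-managed threads
    from concurrent.futures import ThreadPoolExecutor

    def sync_documents_concurrently(max_workers=10):
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            while True:
                chunks = doc_client.stub.GetNoVectorDocumentChunks(
                    proto.documents_pb2.GetNotVectorDocumentChunksRequest()
                ).chunks
                # the executor caps in-flight vectorization at max_workers
                list(pool.map(vector_and_save, chunks))
                time.sleep(5)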