iVamp 2023-11-18 23:08:22 +08:00
parent a77ff095f8
commit c8e5c8f389
16 changed files with 199 additions and 442 deletions

View File

@@ -6,7 +6,6 @@
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
<option name="removeUnused" value="true" />
<option name="modifyBaseFiles" value="true" />
</component>
</module>

View File

@@ -2,5 +2,6 @@
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="$PROJECT_DIR$/document_ai/proto" vcs="Git" />
</component>
</project>

View File

@@ -1,39 +0,0 @@
import time
import documents_pb2_grpc
import documents_pb2
import init
import doc_client
import sys
import signal
def sync_documents():
while True:
documents_response = doc_client.stub.GetNoVectorDocuments(documents_pb2.GetNoVectorDocumentsRequest()).documents
# # get all documents with no vector
for document in documents_response:
docContent = document.title + "\n" + document.content
print("正在更新向量...")
text_vector = init.text_to_vector(docContent)
# update vector
update_vector_response = init.insert_document(document.id, document.user_id, text_vector)
print(update_vector_response)
# update vector_id
update_vector_id_response = doc_client.stub.UpdateDocument(documents_pb2.UpdateDocumentRequest(
id=document.id,
vector_id=update_vector_response
))
print(update_vector_id_response)
print("更新向量完成")
time.sleep(1 * 5)
if __name__ == '__main__':
sync_documents()

View File

@@ -1,8 +1,8 @@
import grpc
import documents_pb2_grpc
import documents_pb2
import proto.documents_pb2_grpc
import proto.documents_pb2
print("正在连接...")
print("正在连接到 Library Server...")
channel = grpc.insecure_channel('localhost:8081')
stub = documents_pb2_grpc.DocumentSearchServiceStub(channel)
stub = proto.documents_pb2_grpc.DocumentSearchServiceStub(channel)
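For reference, a slightly more defensive version of this client module is sketched below. The generated modules and the localhost:8081 address come from the file above; the 5-second readiness deadline is an assumed example value, not something this commit sets.

import grpc
import proto.documents_pb2_grpc

print("Connecting to Library Server...")
channel = grpc.insecure_channel('localhost:8081')
try:
    # Fail fast if the Library Server is unreachable instead of hanging
    # on the first RPC; the 5-second deadline is an assumption.
    grpc.channel_ready_future(channel).result(timeout=5)
except grpc.FutureTimeoutError:
    raise SystemExit("Library Server at localhost:8081 is not reachable")
stub = proto.documents_pb2_grpc.DocumentSearchServiceStub(channel)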

View File

@@ -1,16 +0,0 @@
syntax = "proto3";
service DocumentQuery {
rpc Query(QueryRequest) returns (QueryResponse) {}
}
message QueryRequest {
string question = 1;
uint64 user_id = 2;
}
message QueryResponse {
string text = 1;
}
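A hypothetical one-off client for the Query RPC defined above (the .proto file itself is being moved into proto/ by this commit). The server address is a placeholder, not taken from the diff; server.py reads its own _ADDR value.

import grpc
import proto.document_query_pb2
import proto.document_query_pb2_grpc

channel = grpc.insecure_channel('localhost:50051')  # placeholder address
stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
response = stub.Query(proto.document_query_pb2.QueryRequest(
    question="Why does PowerShell refuse to run yarn.ps1?",
    user_id=1,
))
print(response.text)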

View File

@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: document_query.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x64ocument_query.proto\"1\n\x0cQueryRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\"\x1d\n\rQueryResponse\x12\x0c\n\x04text\x18\x01 \x01(\t29\n\rDocumentQuery\x12(\n\x05Query\x12\r.QueryRequest\x1a\x0e.QueryResponse\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_query_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_QUERYREQUEST']._serialized_start=24
_globals['_QUERYREQUEST']._serialized_end=73
_globals['_QUERYRESPONSE']._serialized_start=75
_globals['_QUERYRESPONSE']._serialized_end=104
_globals['_DOCUMENTQUERY']._serialized_start=106
_globals['_DOCUMENTQUERY']._serialized_end=163
# @@protoc_insertion_point(module_scope)

View File

@@ -1,66 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import document_query_pb2 as document__query__pb2
class DocumentQueryStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Query = channel.unary_unary(
'/DocumentQuery/Query',
request_serializer=document__query__pb2.QueryRequest.SerializeToString,
response_deserializer=document__query__pb2.QueryResponse.FromString,
)
class DocumentQueryServicer(object):
"""Missing associated documentation comment in .proto file."""
def Query(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_DocumentQueryServicer_to_server(servicer, server):
rpc_method_handlers = {
'Query': grpc.unary_unary_rpc_method_handler(
servicer.Query,
request_deserializer=document__query__pb2.QueryRequest.FromString,
response_serializer=document__query__pb2.QueryResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'DocumentQuery', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class DocumentQuery(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Query(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DocumentQuery/Query',
document__query__pb2.QueryRequest.SerializeToString,
document__query__pb2.QueryResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -1,52 +0,0 @@
syntax = "proto3";
option go_package="./utils";
package utils;
message Document {
uint64 id = 1;
string title = 2;
string description = 3;
string content = 4;
uint64 vector_id = 5;
uint64 library_id = 6;
uint64 user_id = 7;
}
message GetDocumentsRequest {
string library = 1;
string text = 2;
}
message GetDocumentsResponse {
repeated Document documents = 1;
}
message GetNoVectorDocumentsRequest {
Document document = 1;
}
message GetNoVectorDocumentsResponse {
repeated Document documents = 1;
}
message UpdateDocumentRequest {
uint64 id = 1;
uint64 vector_id = 2;
}
message UpdateDocumentResponse {
Document document = 1;
}
message GetDocumentByIdRequest {
uint64 id = 1;
}
service DocumentSearchService {
rpc GetNoVectorDocuments(GetNoVectorDocumentsRequest) returns (GetNoVectorDocumentsResponse);
rpc UpdateDocument(UpdateDocumentRequest) returns (UpdateDocumentResponse);
rpc GetDocumentById(GetDocumentByIdRequest) returns (Document);
}
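The go_package option suggests the real implementation of this service lives in a Go server, so the following is only a minimal in-memory fake of DocumentSearchService for exercising the Python side locally; every piece of data in it is invented.

from concurrent import futures
import grpc
import proto.documents_pb2 as documents_pb2
import proto.documents_pb2_grpc as documents_pb2_grpc

class FakeDocumentSearchService(documents_pb2_grpc.DocumentSearchServiceServicer):
    def GetNoVectorDocuments(self, request, context):
        doc = documents_pb2.Document(id=1, title="demo", content="hello world", user_id=1)
        return documents_pb2.GetNoVectorDocumentsResponse(documents=[doc])

    def UpdateDocument(self, request, context):
        # Echo the update back, since the response message wraps a Document.
        return documents_pb2.UpdateDocumentResponse(
            document=documents_pb2.Document(id=request.id, vector_id=request.vector_id))

    def GetDocumentById(self, request, context):
        return documents_pb2.Document(id=request.id, title="demo")

server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
documents_pb2_grpc.add_DocumentSearchServiceServicer_to_server(FakeDocumentSearchService(), server)
server.add_insecure_port('localhost:8081')  # matches doc_client.py above
server.start()
server.wait_for_termination()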

View File

@@ -1,42 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: documents.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ocuments.proto\x12\x05utils\"\x83\x01\n\x08\x44ocument\x12\n\n\x02id\x18\x01 \x01(\x04\x12\r\n\x05title\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x11\n\tvector_id\x18\x05 \x01(\x04\x12\x12\n\nlibrary_id\x18\x06 \x01(\x04\x12\x0f\n\x07user_id\x18\x07 \x01(\x04\"4\n\x13GetDocumentsRequest\x12\x0f\n\x07library\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\":\n\x14GetDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"@\n\x1bGetNoVectorDocumentsRequest\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"B\n\x1cGetNoVectorDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"6\n\x15UpdateDocumentRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x11\n\tvector_id\x18\x02 \x01(\x04\";\n\x16UpdateDocumentResponse\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"$\n\x16GetDocumentByIdRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x32\x8a\x02\n\x15\x44ocumentSearchService\x12_\n\x14GetNoVectorDocuments\x12\".utils.GetNoVectorDocumentsRequest\x1a#.utils.GetNoVectorDocumentsResponse\x12M\n\x0eUpdateDocument\x12\x1c.utils.UpdateDocumentRequest\x1a\x1d.utils.UpdateDocumentResponse\x12\x41\n\x0fGetDocumentById\x12\x1d.utils.GetDocumentByIdRequest\x1a\x0f.utils.DocumentB\tZ\x07./utilsb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'documents_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'Z\007./utils'
_globals['_DOCUMENT']._serialized_start=27
_globals['_DOCUMENT']._serialized_end=158
_globals['_GETDOCUMENTSREQUEST']._serialized_start=160
_globals['_GETDOCUMENTSREQUEST']._serialized_end=212
_globals['_GETDOCUMENTSRESPONSE']._serialized_start=214
_globals['_GETDOCUMENTSRESPONSE']._serialized_end=272
_globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_start=274
_globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_end=338
_globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_start=340
_globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_end=406
_globals['_UPDATEDOCUMENTREQUEST']._serialized_start=408
_globals['_UPDATEDOCUMENTREQUEST']._serialized_end=462
_globals['_UPDATEDOCUMENTRESPONSE']._serialized_start=464
_globals['_UPDATEDOCUMENTRESPONSE']._serialized_end=523
_globals['_GETDOCUMENTBYIDREQUEST']._serialized_start=525
_globals['_GETDOCUMENTBYIDREQUEST']._serialized_end=561
_globals['_DOCUMENTSEARCHSERVICE']._serialized_start=564
_globals['_DOCUMENTSEARCHSERVICE']._serialized_end=830
# @@protoc_insertion_point(module_scope)

View File

@@ -1,132 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import documents_pb2 as documents__pb2
class DocumentSearchServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.GetNoVectorDocuments = channel.unary_unary(
'/utils.DocumentSearchService/GetNoVectorDocuments',
request_serializer=documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
response_deserializer=documents__pb2.GetNoVectorDocumentsResponse.FromString,
)
self.UpdateDocument = channel.unary_unary(
'/utils.DocumentSearchService/UpdateDocument',
request_serializer=documents__pb2.UpdateDocumentRequest.SerializeToString,
response_deserializer=documents__pb2.UpdateDocumentResponse.FromString,
)
self.GetDocumentById = channel.unary_unary(
'/utils.DocumentSearchService/GetDocumentById',
request_serializer=documents__pb2.GetDocumentByIdRequest.SerializeToString,
response_deserializer=documents__pb2.Document.FromString,
)
class DocumentSearchServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def GetNoVectorDocuments(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def UpdateDocument(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetDocumentById(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_DocumentSearchServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'GetNoVectorDocuments': grpc.unary_unary_rpc_method_handler(
servicer.GetNoVectorDocuments,
request_deserializer=documents__pb2.GetNoVectorDocumentsRequest.FromString,
response_serializer=documents__pb2.GetNoVectorDocumentsResponse.SerializeToString,
),
'UpdateDocument': grpc.unary_unary_rpc_method_handler(
servicer.UpdateDocument,
request_deserializer=documents__pb2.UpdateDocumentRequest.FromString,
response_serializer=documents__pb2.UpdateDocumentResponse.SerializeToString,
),
'GetDocumentById': grpc.unary_unary_rpc_method_handler(
servicer.GetDocumentById,
request_deserializer=documents__pb2.GetDocumentByIdRequest.FromString,
response_serializer=documents__pb2.Document.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'utils.DocumentSearchService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class DocumentSearchService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def GetNoVectorDocuments(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetNoVectorDocuments',
documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
documents__pb2.GetNoVectorDocumentsResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def UpdateDocument(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/UpdateDocument',
documents__pb2.UpdateDocumentRequest.SerializeToString,
documents__pb2.UpdateDocumentResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GetDocumentById(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetDocumentById',
documents__pb2.GetDocumentByIdRequest.SerializeToString,
documents__pb2.Document.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -20,12 +20,19 @@ if not utility.has_collection("leaf_documents"):
_document_id = FieldSchema(
name="document_id",
dtype=DataType.INT64,
)
_document_chunk_id = FieldSchema(
name="document_chunk_id",
dtype=DataType.INT64,
is_primary=True,
)
_library_id = FieldSchema(
name="library_id",
dtype=DataType.INT64,
)
_user_id = FieldSchema(
name="user_id",
dtype=DataType.INT64,
)
_document_vector = FieldSchema(
name="vector",
@@ -33,7 +40,7 @@ if not utility.has_collection("leaf_documents"):
dim=1536
)
schema = CollectionSchema(
fields=[_document_id, _user_id, _document_vector],
fields=[_document_id, _document_chunk_id, _library_id, _user_id, _document_vector],
enable_dynamic_field=True
)
collection_name = "leaf_documents"
@@ -63,10 +70,12 @@ def text_to_vector(text: str):
return embeddings.embed_query(text)
def insert_document(document_id: int, user_id: int, vector: list):
def insert_document(document_id: int, document_chunk_id: int, library_id: int, user_id: int, vector: list):
return collection.insert(
data=[
[document_id],
[document_chunk_id],
[library_id],
[user_id],
[vector]
],
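Pulled together, the hunks above amount to the collection definition below. The field names and dim=1536 come from the diff; the FLOAT_VECTOR dtype is inferred (the line is elided by the hunk), and the index parameters at the end are assumptions added for completeness, since Milvus requires an index on the vector field before searching.

from pymilvus import CollectionSchema, Collection, FieldSchema, DataType

fields = [
    FieldSchema(name="document_id", dtype=DataType.INT64),
    FieldSchema(name="document_chunk_id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="library_id", dtype=DataType.INT64),
    FieldSchema(name="user_id", dtype=DataType.INT64),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=1536),
]
schema = CollectionSchema(fields=fields, enable_dynamic_field=True)
collection = Collection(name="leaf_documents", schema=schema)
# Assumed index settings; the diff does not show how the index is built.
collection.create_index("vector", {"index_type": "HNSW", "metric_type": "L2",
                                   "params": {"M": 8, "efConstruction": 64}})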

View File

@@ -1,11 +1,11 @@
from threading import Thread
import chunk
import vector
import server
if __name__ == '__main__':
# Start the worker thread
worker_thread = Thread(target=chunk.sync_documents, args=())
worker_thread = Thread(target=vector.sync_documents, args=())
worker_thread.start()
# Start the server thread
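A hedged variant of this entry point: marking the worker as a daemon thread (an addition, not in this commit) lets the whole process exit once the gRPC server stops, instead of hanging on the sync loop.

from threading import Thread
import vector
import server

if __name__ == '__main__':
    # daemon=True so stopping the server also ends the sync worker
    worker_thread = Thread(target=vector.sync_documents, daemon=True)
    worker_thread.start()
    server.serve()  # assumes server.py's serve() blocks, as shown below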

View File

@@ -1,5 +1,5 @@
import json
import documents_pb2
import proto.documents_pb2
from langchain import text_splitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings import OpenAIEmbeddings
@@ -17,20 +17,24 @@ from pymilvus import (
import init
import doc_client
#
# question = """
# yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
# stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
# At line:1 char:1
# + yarn config set registry https://registry.npm.taobao.org/
# + ~~~~
# + CategoryInfo : SecurityError: (:) [], PSSecurityException
# + FullyQualifiedErrorId : UnauthorizedAccess
#
# What is this problem and how do I fix it?
# """
question = """
yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
At line:1 char:1
+ yarn config set registry https://registry.npm.taobao.org/
+ ~~~~
+ CategoryInfo : SecurityError: (:) [], PSSecurityException
+ FullyQualifiedErrorId : UnauthorizedAccess
What is this problem and how do I fix it?
Why am I completely lost developing under WHMCS?
"""
vec = init.text_to_vector(question + " (you must reply in Chinese)")
vec = init.text_to_vector(question)
# vec = ""
#
@@ -47,33 +51,37 @@ search_param = {
}
res = init.collection.search(**search_param)
document_ids = []
document_chunk_ids = []
real_document = []
for i in range(len(res[0])):
_doc_id = res[0][i].id
print("正在获取 " + str(_doc_id) + " 的内容...")
_chunk_id = res[0][i].id
print("正在获取分块 " + str(_chunk_id) + " 的内容...")
try:
_doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
id=_doc_id
_chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
id=_chunk_id
))
_doc_content_full = _doc_content.title + "\n" + _doc_content.content
# print(_chunk_content)
_doc_content_full = _chunk_content.content
# real_document.append(_doc_content)
doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
# doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title})
doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
real_document.append(doc_obj)
except Exception as e:
print(e)
print(real_document)
print("正在调用 LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
verbose=True)
question = "必须使用中文回复:" + question
output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
print("回复:" + output["output_text"])

View File

@@ -2,83 +2,122 @@ import os
from concurrent import futures
import langchain
from langchain.text_splitter import RecursiveCharacterTextSplitter
import document_query_pb2
import document_query_pb2_grpc
import proto.document_query_pb2
import proto.document_query_pb2_grpc
import grpc
import documents_pb2
import proto.documents_pb2
import init
import doc_client
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.cache import InMemoryCache
langchain.llm_cache = InMemoryCache()
CHUNK_SIZE = 500
class AIServer(document_query_pb2_grpc.DocumentQuery):
def Query(self, request, context):
vec = init.text_to_vector(request.question)
question = request.question + " (you must reply in Chinese)"
class AIServer(proto.document_query_pb2_grpc.DocumentQueryServicer):
# Servicer methods receive (request, context); the client-side
# (target, options=..., ...) parameter list does not belong on a server.
def Query(self, request, context):
print("New request: " + request.question)
vec = init.text_to_vector(request.question)
question = "Reply in spoken language: " + request.question
search_param = {
"data": [vec],
"anns_field": "vector",
"param": {"metric_type": "L2"},
"limit": 10,
"expr": "user_id == " + str(request.user_id),
"expr": "user_id == " + str(target.user_id),
"output_fields": ["document_id", "user_id"],
}
res = init.collection.search(**search_param)
document_ids = []
# # keep at most 5 results
# if len(res[0]) > 5:
# res[0] = res[0][:5]
# document_chunk_ids = []
real_document = []
for i in range(len(res[0])):
_doc_id = res[0][i].id
print("正在获取 " + str(_doc_id) + " 的内容...")
_chunk_id = res[0][i].id
print("正在获取分块 " + str(_chunk_id) + " 的内容...")
try:
_doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
id=_doc_id
_chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
id=_chunk_id
))
_doc_content_full = _doc_content.title + "\n" + _doc_content.content
# real_document.append(_doc_content)
doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
_doc_content_full = _chunk_content.content
doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
real_document.append(doc_obj)
except Exception as e:
print(e)
# print(real_document)
print(real_document)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=0)
all_splits = text_splitter.split_documents(real_document)
print("正在调用 LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
return_intermediate_steps=True,
verbose=True)
print("real_document: ", all_splits)
# document length
# print("document length: ", len(all_splits))
print("Calling the LLM: " + question + "...")
# Note: max_tokens=4097 is the default model's entire context window and
# leaves no room for the prompt, so such requests fail; 1024 is an
# arbitrary safer cap chosen here for illustration.
chain = load_qa_with_sources_chain(OpenAI(temperature=0, max_tokens=1024), chain_type="map_reduce",
return_intermediate_steps=False,
verbose=False)
output = chain({"input_documents": all_splits, "question": question}, return_only_outputs=False)
output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
print("回复:" + output["output_text"])
return document_query_pb2.QueryResponse(
return proto.document_query_pb2.QueryResponse(
text=output["output_text"]
# text = "test"
)
def Chunk(self, request, context):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=20,
length_function=len,
add_start_index=True,
)
page_contents = text_splitter.create_documents([
request.text
])
texts = []
for page_content in page_contents:
texts.append(page_content.page_content)
return proto.document_query_pb2.ChunkResponse(
texts=texts
)
@@ -89,7 +128,12 @@ def serve():
print("Listening on", _ADDR)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
proto.document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
server.add_insecure_port(_ADDR)
server.start()
server.wait_for_termination()
if __name__ == '__main__':
serve()
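A hypothetical client for the new Chunk RPC. Only ChunkResponse(texts=...) is visible in this commit, so the request type name and its "text" field are assumptions about the updated proto, and the address is again a placeholder.

import grpc
import proto.document_query_pb2
import proto.document_query_pb2_grpc

channel = grpc.insecure_channel('localhost:50051')  # placeholder address
stub = proto.document_query_pb2_grpc.DocumentQueryStub(channel)
# ChunkRequest and its "text" field are assumed from ChunkResponse above.
resp = stub.Chunk(proto.document_query_pb2.ChunkRequest(text="a long document ..."))
for t in resp.texts:
    print(len(t), repr(t[:40]))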

document_ai/vector.py Normal file
View File

@@ -0,0 +1,72 @@
import time
import proto.documents_pb2_grpc
import proto.documents_pb2
import init
import doc_client
import sys
import signal
from threading import Thread
threads = []
def sync_documents():
while True:
chunks_response = doc_client.stub.GetNoVectorDocumentChunks(proto.documents_pb2.GetNotVectorDocumentChunksRequest()).chunks
# # get all documents with no vector
for chunk in chunks_response:
#
# # no more than 10 threads at a time
# if len(threads) >= 10:
# print("Thread pool full, waiting 5 seconds...")
# time.sleep(5)
# continue
#
# # wait for running threads
# for t in threads:
# if t.is_alive():
# t.join()
# print("Thread " + str(t) + " has finished.")
# threads.remove(t)
#
# # spawn a thread
# print("Spawning thread...")
# t = Thread(target=vector_and_save, args=(chunk,))
# threads.append(t)
#
vector_and_save(chunk)
print("进入下一次循环...")
time.sleep(1 * 5)
def vector_and_save(chunk):
chunk_content = chunk.content
print("正在进行文本向量化...")
text_vector = init.text_to_vector(chunk_content)
# update vector
update_vector_response = init.insert_document(
document_id=chunk.document.id,
document_chunk_id=chunk.id,
library_id=chunk.document.library_id,
user_id=chunk.document.user_id,
vector=text_vector
)
print(update_vector_response)
# update vector_id
update_vector_id_response = doc_client.stub.UpdateDocumentChunk(proto.documents_pb2.UpdateChunkedDocumentRequest(
id=chunk.id,
vector_id=update_vector_response
))
print(update_vector_id_response)
print("向量化完成。")
if __name__ == '__main__':
sync_documents()
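One failure mode worth noting: a single bad chunk aborts the whole while loop, which kills the worker thread that main.py starts. A sketch of a more forgiving loop, where the try/except is the only addition to the code above:

def sync_documents():
    while True:
        chunks_response = doc_client.stub.GetNoVectorDocumentChunks(
            proto.documents_pb2.GetNotVectorDocumentChunksRequest()).chunks
        for chunk in chunks_response:
            try:
                vector_and_save(chunk)
            except Exception as e:
                # Skip the failing chunk; it stays vector-less and is
                # retried on the next pass.
                print("Failed to vectorize chunk " + str(chunk.id) + ": " + str(e))
        print("Moving on to the next iteration...")
        time.sleep(5)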

View File

@@ -1 +1 @@
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. .\rpc\ai.proto
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. .\proto\*.proto
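A caveat on the new command: protoc does not expand * itself, and neither cmd.exe nor PowerShell glob arguments for native commands, so the wildcard form likely only works in shells that expand it (such as bash). Listing the files explicitly is the portable variant; the two file names below are inferred from this commit and may be incomplete.

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. .\proto\documents.proto .\proto\document_query.proto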