diff --git a/document_ai/doc_client.py b/document_ai/doc_client.py
new file mode 100644
index 0000000..a82ab20
--- /dev/null
+++ b/document_ai/doc_client.py
@@ -0,0 +1,8 @@
+import grpc
+import documents_pb2_grpc
+import documents_pb2
+
+print("Connecting to the document service...")
+channel = grpc.insecure_channel('localhost:8081')
+
+stub = documents_pb2_grpc.DocumentSearchServiceStub(channel)
diff --git a/document_ai/document_query.proto b/document_ai/document_query.proto
new file mode 100644
index 0000000..4cc520e
--- /dev/null
+++ b/document_ai/document_query.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+service DocumentQuery {
+    rpc Query(QueryRequest) returns (QueryResponse) {}
+}
+
+
+message QueryRequest {
+    string question = 1;
+    uint64 user_id = 2;
+}
+
+message QueryResponse {
+    string text = 1;
+}
+
diff --git a/document_ai/document_query_pb2.py b/document_ai/document_query_pb2.py
new file mode 100644
index 0000000..11b1ff3
--- /dev/null
+++ b/document_ai/document_query_pb2.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: document_query.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x64ocument_query.proto\"1\n\x0cQueryRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\"\x1d\n\rQueryResponse\x12\x0c\n\x04text\x18\x01 \x01(\t29\n\rDocumentQuery\x12(\n\x05Query\x12\r.QueryRequest\x1a\x0e.QueryResponse\"\x00\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_query_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+  DESCRIPTOR._options = None
+  _globals['_QUERYREQUEST']._serialized_start=24
+  _globals['_QUERYREQUEST']._serialized_end=73
+  _globals['_QUERYRESPONSE']._serialized_start=75
+  _globals['_QUERYRESPONSE']._serialized_end=104
+  _globals['_DOCUMENTQUERY']._serialized_start=106
+  _globals['_DOCUMENTQUERY']._serialized_end=163
+# @@protoc_insertion_point(module_scope)
diff --git a/document_ai/document_query_pb2_grpc.py b/document_ai/document_query_pb2_grpc.py
new file mode 100644
index 0000000..74994a9
--- /dev/null
+++ b/document_ai/document_query_pb2_grpc.py
@@ -0,0 +1,66 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import document_query_pb2 as document__query__pb2
+
+
+class DocumentQueryStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+ """ + self.Query = channel.unary_unary( + '/DocumentQuery/Query', + request_serializer=document__query__pb2.QueryRequest.SerializeToString, + response_deserializer=document__query__pb2.QueryResponse.FromString, + ) + + +class DocumentQueryServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Query(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DocumentQueryServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Query': grpc.unary_unary_rpc_method_handler( + servicer.Query, + request_deserializer=document__query__pb2.QueryRequest.FromString, + response_serializer=document__query__pb2.QueryResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'DocumentQuery', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class DocumentQuery(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Query(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/DocumentQuery/Query', + document__query__pb2.QueryRequest.SerializeToString, + document__query__pb2.QueryResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/document_ai/documents.proto b/document_ai/documents.proto new file mode 100644 index 0000000..1cfc135 --- /dev/null +++ b/document_ai/documents.proto @@ -0,0 +1,52 @@ +syntax = "proto3"; + +option go_package="./utils"; +package utils; + +message Document { + uint64 id = 1; + string title = 2; + string description = 3; + string content = 4; + uint64 vector_id = 5; + uint64 library_id = 6; + uint64 user_id = 7; +} + +message GetDocumentsRequest { + string library = 1; + string text = 2; +} + +message GetDocumentsResponse { + repeated Document documents = 1; +} + +message GetNoVectorDocumentsRequest { + Document document = 1; +} + +message GetNoVectorDocumentsResponse { + repeated Document documents = 1; +} + + +message UpdateDocumentRequest { + uint64 id = 1; + uint64 vector_id = 2; +} + +message UpdateDocumentResponse { + Document document = 1; +} + +message GetDocumentByIdRequest { + uint64 id = 1; +} + + +service DocumentSearchService { + rpc GetNoVectorDocuments(GetNoVectorDocumentsRequest) returns (GetNoVectorDocumentsResponse); + rpc UpdateDocument(UpdateDocumentRequest) returns (UpdateDocumentResponse); + rpc GetDocumentById(GetDocumentByIdRequest) returns (Document); +} diff --git a/document_ai/documents_pb2.py b/document_ai/documents_pb2.py new file mode 100644 index 0000000..a38abd4 --- /dev/null +++ b/document_ai/documents_pb2.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: documents.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ocuments.proto\x12\x05utils\"\x83\x01\n\x08\x44ocument\x12\n\n\x02id\x18\x01 \x01(\x04\x12\r\n\x05title\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x11\n\tvector_id\x18\x05 \x01(\x04\x12\x12\n\nlibrary_id\x18\x06 \x01(\x04\x12\x0f\n\x07user_id\x18\x07 \x01(\x04\"4\n\x13GetDocumentsRequest\x12\x0f\n\x07library\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\":\n\x14GetDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"@\n\x1bGetNoVectorDocumentsRequest\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"B\n\x1cGetNoVectorDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"6\n\x15UpdateDocumentRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x11\n\tvector_id\x18\x02 \x01(\x04\";\n\x16UpdateDocumentResponse\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"$\n\x16GetDocumentByIdRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x32\x8a\x02\n\x15\x44ocumentSearchService\x12_\n\x14GetNoVectorDocuments\x12\".utils.GetNoVectorDocumentsRequest\x1a#.utils.GetNoVectorDocumentsResponse\x12M\n\x0eUpdateDocument\x12\x1c.utils.UpdateDocumentRequest\x1a\x1d.utils.UpdateDocumentResponse\x12\x41\n\x0fGetDocumentById\x12\x1d.utils.GetDocumentByIdRequest\x1a\x0f.utils.DocumentB\tZ\x07./utilsb\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'documents_pb2', _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+  DESCRIPTOR._options = None
+  DESCRIPTOR._serialized_options = b'Z\007./utils'
+  _globals['_DOCUMENT']._serialized_start=27
+  _globals['_DOCUMENT']._serialized_end=158
+  _globals['_GETDOCUMENTSREQUEST']._serialized_start=160
+  _globals['_GETDOCUMENTSREQUEST']._serialized_end=212
+  _globals['_GETDOCUMENTSRESPONSE']._serialized_start=214
+  _globals['_GETDOCUMENTSRESPONSE']._serialized_end=272
+  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_start=274
+  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_end=338
+  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_start=340
+  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_end=406
+  _globals['_UPDATEDOCUMENTREQUEST']._serialized_start=408
+  _globals['_UPDATEDOCUMENTREQUEST']._serialized_end=462
+  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_start=464
+  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_end=523
+  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_start=525
+  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_end=561
+  _globals['_DOCUMENTSEARCHSERVICE']._serialized_start=564
+  _globals['_DOCUMENTSEARCHSERVICE']._serialized_end=830
+# @@protoc_insertion_point(module_scope)
diff --git a/document_ai/documents_pb2_grpc.py b/document_ai/documents_pb2_grpc.py
new file mode 100644
index 0000000..1465f89
--- /dev/null
+++ b/document_ai/documents_pb2_grpc.py
@@ -0,0 +1,132 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import documents_pb2 as documents__pb2 + + +class DocumentSearchServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetNoVectorDocuments = channel.unary_unary( + '/utils.DocumentSearchService/GetNoVectorDocuments', + request_serializer=documents__pb2.GetNoVectorDocumentsRequest.SerializeToString, + response_deserializer=documents__pb2.GetNoVectorDocumentsResponse.FromString, + ) + self.UpdateDocument = channel.unary_unary( + '/utils.DocumentSearchService/UpdateDocument', + request_serializer=documents__pb2.UpdateDocumentRequest.SerializeToString, + response_deserializer=documents__pb2.UpdateDocumentResponse.FromString, + ) + self.GetDocumentById = channel.unary_unary( + '/utils.DocumentSearchService/GetDocumentById', + request_serializer=documents__pb2.GetDocumentByIdRequest.SerializeToString, + response_deserializer=documents__pb2.Document.FromString, + ) + + +class DocumentSearchServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def GetNoVectorDocuments(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def UpdateDocument(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDocumentById(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DocumentSearchServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetNoVectorDocuments': grpc.unary_unary_rpc_method_handler( + servicer.GetNoVectorDocuments, + request_deserializer=documents__pb2.GetNoVectorDocumentsRequest.FromString, + response_serializer=documents__pb2.GetNoVectorDocumentsResponse.SerializeToString, + ), + 'UpdateDocument': grpc.unary_unary_rpc_method_handler( + servicer.UpdateDocument, + request_deserializer=documents__pb2.UpdateDocumentRequest.FromString, + response_serializer=documents__pb2.UpdateDocumentResponse.SerializeToString, + ), + 'GetDocumentById': grpc.unary_unary_rpc_method_handler( + servicer.GetDocumentById, + request_deserializer=documents__pb2.GetDocumentByIdRequest.FromString, + response_serializer=documents__pb2.Document.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'utils.DocumentSearchService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class DocumentSearchService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def GetNoVectorDocuments(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetNoVectorDocuments',
+            documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
+            documents__pb2.GetNoVectorDocumentsResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def UpdateDocument(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/UpdateDocument',
+            documents__pb2.UpdateDocumentRequest.SerializeToString,
+            documents__pb2.UpdateDocumentResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def GetDocumentById(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetDocumentById',
+            documents__pb2.GetDocumentByIdRequest.SerializeToString,
+            documents__pb2.Document.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/document_ai/init.py b/document_ai/init.py
new file mode 100644
index 0000000..0b8d87e
--- /dev/null
+++ b/document_ai/init.py
@@ -0,0 +1,73 @@
+import os
+
+from langchain.embeddings import OpenAIEmbeddings
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+)
+
+# init
+MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
+MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
+
+connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
+
+if not utility.has_collection("leaf_documents"):
+    _document_id = FieldSchema(
+        name="document_id",
+        dtype=DataType.INT64,
+        is_primary=True,
+    )
+    _user_id = FieldSchema(
+        name="user_id",
+        dtype=DataType.INT64,
+
+    )
+    _document_vector = FieldSchema(
+        name="vector",
+        dtype=DataType.FLOAT_VECTOR,
+        dim=1536
+    )
+    schema = CollectionSchema(
+        fields=[_document_id, _user_id, _document_vector],
+        enable_dynamic_field=True
+    )
+    collection_name = "leaf_documents"
+    print("Creating collection...")
+    collection = Collection(
+        name=collection_name,
+        schema=schema,
+        using='default',
+        shards_num=2
+    )
+    collection.create_index(
+        field_name="vector",
+        index_params={"metric_type": "L2", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}},
+    )
+    collection.create_index(
+        field_name="user_id",
+        index_name="idx_user_id"
+    )
+
+collection = Collection("leaf_documents")
+collection.load()
+
+embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+
+def text_to_vector(text: str):
+    return embeddings.embed_query(text)
+
+
+def insert_document(document_id: int, user_id: int, vector: list):
+    return collection.insert(
+        data=[
+            [document_id],
+            [user_id],
+            [vector]
+        ],
+    ).primary_keys[0]
diff --git a/document_ai/run.py b/document_ai/run.py
new file mode 100644
index 0000000..e69de29
diff --git a/document_ai/search.py b/document_ai/search.py
new file mode 100644
index 0000000..3c2b527
--- /dev/null
+++ b/document_ai/search.py
@@ -0,0 +1,79 @@
+import json
+import documents_pb2
+from langchain import text_splitter
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema,
+    CollectionSchema,
+    DataType,
+    Collection,
+)
+
+import init
+import doc_client
+
+
+question = """
+yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
+stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
+At line:1 char:1
++ yarn config set registry https://registry.npm.taobao.org/
++ ~~~~
+    + CategoryInfo          : SecurityError: (:) [], PSSecurityException
+    + FullyQualifiedErrorId : UnauthorizedAccess
+
+What is the problem and how do I fix it?
+"""
+
+vec = init.text_to_vector(question + " (You must reply in Chinese.)")
+
+# vec = ""
+#
+# with open("../question_vec.json", "r") as f:
+#     vec = json.load(f)
+
+search_param = {
+    "data": [vec],
+    "anns_field": "vector",
+    "param": {"metric_type": "L2"},
+    "limit": 10,
+    "expr": "user_id == 2",
+    "output_fields": ["document_id", "user_id"],
+}
+res = init.collection.search(**search_param)
+
+document_ids = []
+real_document = []
+
+for i in range(len(res[0])):
+    _doc_id = res[0][i].id
+    print("Fetching content for document " + str(_doc_id) + "...")
+
+    try:
+        _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
+            id=_doc_id
+        ))
+        _doc_content_full = _doc_content.title + "\n" + _doc_content.content
+
+        # real_document.append(_doc_content)
+        doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
+
+        real_document.append(doc_obj)
+
+    except Exception as e:
+        print(e)
+
+
+
+print(real_document)
+
+print("Calling the LLM...")
+chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
+                                   verbose=True)
+output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+print("Reply: " + output["output_text"])
diff --git a/document_ai/server.py b/document_ai/server.py
new file mode 100644
index 0000000..64b46ae
--- /dev/null
+++ b/document_ai/server.py
@@ -0,0 +1,81 @@
+import os
+from concurrent import futures
+import document_query_pb2
+import document_query_pb2_grpc
+import grpc
+import documents_pb2
+import init
+import doc_client
+from langchain.llms.openai import OpenAI
+from langchain.schema.document import Document
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+
+
+class AIServer(document_query_pb2_grpc.DocumentQueryServicer):
+    def Query(self, request, context):
+        vec = init.text_to_vector(request.question)
+
+        question = request.question + " (You must reply in Chinese.)"
+
+        search_param = {
+            "data": [vec],
+            "anns_field": "vector",
+            "param": {"metric_type": "L2"},
+            "limit": 10,
+            "expr": "user_id == " + str(request.user_id),
+            "output_fields": ["document_id", "user_id"],
+        }
+
+        res = init.collection.search(**search_param)
+
+        document_ids = []
+        real_document = []
+
+        for i in range(len(res[0])):
+            _doc_id = res[0][i].id
+            print("Fetching content for document " + str(_doc_id) + "...")
+
+            try:
+                _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
+                    id=_doc_id
+                ))
+                _doc_content_full = _doc_content.title + "\n" + _doc_content.content
+
+                # real_document.append(_doc_content)
+                doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
+
+                real_document.append(doc_obj)
+
+            except Exception as e:
+                print(e)
+
+        print(real_document)
+
+        print("Calling the LLM...")
+        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
+                                           return_intermediate_steps=True,
+                                           verbose=True)
+        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+        print("Reply: " + output["output_text"])
+
+        return document_query_pb2.QueryResponse(
+            text=output["output_text"]
+        )
+
+
+def serve():
+    _ADDR = os.getenv("BIND")
+    if _ADDR is None:
+        _ADDR = "[::]:50051"
+    print("Listening on", _ADDR)
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
+    server.add_insecure_port(_ADDR)
+    server.start()
+    server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    serve()
diff --git a/document_ai/worker.py b/document_ai/worker.py
new file mode 100644
index 0000000..70b1885
--- /dev/null
+++ b/document_ai/worker.py
@@ -0,0 +1,27 @@
+import documents_pb2_grpc
+import documents_pb2
+import init
+import doc_client
+
+print("Fetching documents that still need vectors...")
+documents_response = doc_client.stub.GetNoVectorDocuments(documents_pb2.GetNoVectorDocumentsRequest()).documents
+
+# get all documents with no vector
+for document in documents_response:
+    doc_content = document.title + "\n" + document.content
+
+    print("Updating vector...")
+    text_vector = init.text_to_vector(doc_content)
+
+    # update vector
+    update_vector_response = init.insert_document(document.id, document.user_id, text_vector)
+    print(update_vector_response)
+
+    # update vector_id
+    update_vector_id_response = doc_client.stub.UpdateDocument(documents_pb2.UpdateDocumentRequest(
+        id=document.id,
+        vector_id=update_vector_response
+    ))
+
+    print(update_vector_id_response)
+    print("Vector update finished")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e69de29
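
A minimal sketch of how the new DocumentQuery service could be exercised once the Python stubs are generated. The localhost:50051 target mirrors the server's default bind address; the sample question and user_id are placeholders, not part of the change. requirements.txt is added empty here; judging by the imports above, a filled-in version would presumably list at least grpcio, grpcio-tools, protobuf, pymilvus, langchain and openai.

import grpc

import document_query_pb2
import document_query_pb2_grpc

# Connect to the AI server started by document_ai/server.py (defaults to [::]:50051).
channel = grpc.insecure_channel("localhost:50051")
stub = document_query_pb2_grpc.DocumentQueryStub(channel)

# The server embeds the question, searches Milvus for this user's documents,
# fetches their contents over gRPC and answers with the QA chain.
response = stub.Query(document_query_pb2.QueryRequest(
    question="How do I fix the yarn execution policy error?",
    user_id=2,
))
print(response.text)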