add document ai
This commit is contained in:
parent
db30aaf68c
commit
2462ec5b7e
8
document_ai/doc_client.py
Normal file
8
document_ai/doc_client.py
Normal file
@ -0,0 +1,8 @@
|
||||
# gRPC client handle for the document storage service.
# Other modules (search.py, server.py, worker.py) import this module and
# call methods on `stub`.
import grpc
import documents_pb2_grpc
import documents_pb2  # NOTE(review): unused here; callers import it directly — confirm before removing

print("正在连接...")
# assumes the document service listens on localhost:8081 — TODO confirm / make configurable
channel = grpc.insecure_channel('localhost:8081')

stub = documents_pb2_grpc.DocumentSearchServiceStub(channel)
|
16
document_ai/document_query.proto
Normal file
16
document_ai/document_query.proto
Normal file
@ -0,0 +1,16 @@
|
||||
// Query service exposed by the Python AI server (server.py).
syntax = "proto3";

// Answers a natural-language question against one user's documents.
service DocumentQuery {
  rpc Query(QueryRequest) returns (QueryResponse) {}
}


message QueryRequest {
  string question = 1;  // natural-language question text
  uint64 user_id = 2;   // restricts the vector search to this user's documents
}

message QueryResponse {
  string text = 1;      // generated answer text
}
|
||||
|
29
document_ai/document_query_pb2.py
Normal file
29
document_ai/document_query_pb2.py
Normal file
@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: document_query.proto
# NOTE(review): machine-generated module — regenerate from document_query.proto
# (grpc_tools.protoc) instead of editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x64ocument_query.proto\"1\n\x0cQueryRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07user_id\x18\x02 \x01(\x04\"\x1d\n\rQueryResponse\x12\x0c\n\x04text\x18\x01 \x01(\t29\n\rDocumentQuery\x12(\n\x05Query\x12\r.QueryRequest\x1a\x0e.QueryResponse\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'document_query_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
  _globals['_QUERYREQUEST']._serialized_start=24
  _globals['_QUERYREQUEST']._serialized_end=73
  _globals['_QUERYRESPONSE']._serialized_start=75
  _globals['_QUERYRESPONSE']._serialized_end=104
  _globals['_DOCUMENTQUERY']._serialized_start=106
  _globals['_DOCUMENTQUERY']._serialized_end=163
# @@protoc_insertion_point(module_scope)
|
66
document_ai/document_query_pb2_grpc.py
Normal file
66
document_ai/document_query_pb2_grpc.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
# NOTE(review): machine-generated from document_query.proto — regenerate with
# grpc_tools.protoc instead of editing by hand.
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import document_query_pb2 as document__query__pb2


class DocumentQueryStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Query = channel.unary_unary(
                '/DocumentQuery/Query',
                request_serializer=document__query__pb2.QueryRequest.SerializeToString,
                response_deserializer=document__query__pb2.QueryResponse.FromString,
                )


class DocumentQueryServicer(object):
    """Missing associated documentation comment in .proto file."""

    def Query(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_DocumentQueryServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Query': grpc.unary_unary_rpc_method_handler(
                    servicer.Query,
                    request_deserializer=document__query__pb2.QueryRequest.FromString,
                    response_serializer=document__query__pb2.QueryResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'DocumentQuery', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class DocumentQuery(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def Query(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/DocumentQuery/Query',
            document__query__pb2.QueryRequest.SerializeToString,
            document__query__pb2.QueryResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
52
document_ai/documents.proto
Normal file
52
document_ai/documents.proto
Normal file
@ -0,0 +1,52 @@
|
||||
// Document storage/search API served by the backend "utils" package and
// consumed by the Python side through doc_client.py.
syntax = "proto3";

option go_package="./utils";
package utils;

// A stored document plus its vector-store bookkeeping fields.
message Document {
  uint64 id = 1;
  string title = 2;
  string description = 3;
  string content = 4;
  uint64 vector_id = 5;   // Milvus primary key written back by worker.py via UpdateDocument
  uint64 library_id = 6;
  uint64 user_id = 7;
}

message GetDocumentsRequest {
  string library = 1;
  string text = 2;
}

message GetDocumentsResponse {
  repeated Document documents = 1;
}

// NOTE(review): request carries a Document filter; worker.py sends it empty — confirm semantics.
message GetNoVectorDocumentsRequest {
  Document document = 1;
}

message GetNoVectorDocumentsResponse {
  repeated Document documents = 1;
}


// Records the Milvus row id for a document after embedding.
message UpdateDocumentRequest {
  uint64 id = 1;
  uint64 vector_id = 2;
}

message UpdateDocumentResponse {
  Document document = 1;
}

message GetDocumentByIdRequest {
  uint64 id = 1;
}


service DocumentSearchService {
  // Documents that do not yet have an embedding vector.
  rpc GetNoVectorDocuments(GetNoVectorDocumentsRequest) returns (GetNoVectorDocumentsResponse);
  rpc UpdateDocument(UpdateDocumentRequest) returns (UpdateDocumentResponse);
  rpc GetDocumentById(GetDocumentByIdRequest) returns (Document);
}
|
42
document_ai/documents_pb2.py
Normal file
42
document_ai/documents_pb2.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: documents.proto
# NOTE(review): machine-generated module — regenerate from documents.proto
# (grpc_tools.protoc) instead of editing by hand.
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ocuments.proto\x12\x05utils\"\x83\x01\n\x08\x44ocument\x12\n\n\x02id\x18\x01 \x01(\x04\x12\r\n\x05title\x18\x02 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x03 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x11\n\tvector_id\x18\x05 \x01(\x04\x12\x12\n\nlibrary_id\x18\x06 \x01(\x04\x12\x0f\n\x07user_id\x18\x07 \x01(\x04\"4\n\x13GetDocumentsRequest\x12\x0f\n\x07library\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\":\n\x14GetDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"@\n\x1bGetNoVectorDocumentsRequest\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"B\n\x1cGetNoVectorDocumentsResponse\x12\"\n\tdocuments\x18\x01 \x03(\x0b\x32\x0f.utils.Document\"6\n\x15UpdateDocumentRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x12\x11\n\tvector_id\x18\x02 \x01(\x04\";\n\x16UpdateDocumentResponse\x12!\n\x08\x64ocument\x18\x01 \x01(\x0b\x32\x0f.utils.Document\"$\n\x16GetDocumentByIdRequest\x12\n\n\x02id\x18\x01 \x01(\x04\x32\x8a\x02\n\x15\x44ocumentSearchService\x12_\n\x14GetNoVectorDocuments\x12\".utils.GetNoVectorDocumentsRequest\x1a#.utils.GetNoVectorDocumentsResponse\x12M\n\x0eUpdateDocument\x12\x1c.utils.UpdateDocumentRequest\x1a\x1d.utils.UpdateDocumentResponse\x12\x41\n\x0fGetDocumentById\x12\x1d.utils.GetDocumentByIdRequest\x1a\x0f.utils.DocumentB\tZ\x07./utilsb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'documents_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'Z\007./utils'
  _globals['_DOCUMENT']._serialized_start=27
  _globals['_DOCUMENT']._serialized_end=158
  _globals['_GETDOCUMENTSREQUEST']._serialized_start=160
  _globals['_GETDOCUMENTSREQUEST']._serialized_end=212
  _globals['_GETDOCUMENTSRESPONSE']._serialized_start=214
  _globals['_GETDOCUMENTSRESPONSE']._serialized_end=272
  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_start=274
  _globals['_GETNOVECTORDOCUMENTSREQUEST']._serialized_end=338
  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_start=340
  _globals['_GETNOVECTORDOCUMENTSRESPONSE']._serialized_end=406
  _globals['_UPDATEDOCUMENTREQUEST']._serialized_start=408
  _globals['_UPDATEDOCUMENTREQUEST']._serialized_end=462
  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_start=464
  _globals['_UPDATEDOCUMENTRESPONSE']._serialized_end=523
  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_start=525
  _globals['_GETDOCUMENTBYIDREQUEST']._serialized_end=561
  _globals['_DOCUMENTSEARCHSERVICE']._serialized_start=564
  _globals['_DOCUMENTSEARCHSERVICE']._serialized_end=830
# @@protoc_insertion_point(module_scope)
|
132
document_ai/documents_pb2_grpc.py
Normal file
132
document_ai/documents_pb2_grpc.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
# NOTE(review): machine-generated from documents.proto — regenerate with
# grpc_tools.protoc instead of editing by hand.
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import documents_pb2 as documents__pb2


class DocumentSearchServiceStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.GetNoVectorDocuments = channel.unary_unary(
                '/utils.DocumentSearchService/GetNoVectorDocuments',
                request_serializer=documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
                response_deserializer=documents__pb2.GetNoVectorDocumentsResponse.FromString,
                )
        self.UpdateDocument = channel.unary_unary(
                '/utils.DocumentSearchService/UpdateDocument',
                request_serializer=documents__pb2.UpdateDocumentRequest.SerializeToString,
                response_deserializer=documents__pb2.UpdateDocumentResponse.FromString,
                )
        self.GetDocumentById = channel.unary_unary(
                '/utils.DocumentSearchService/GetDocumentById',
                request_serializer=documents__pb2.GetDocumentByIdRequest.SerializeToString,
                response_deserializer=documents__pb2.Document.FromString,
                )


class DocumentSearchServiceServicer(object):
    """Missing associated documentation comment in .proto file."""

    def GetNoVectorDocuments(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def UpdateDocument(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetDocumentById(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_DocumentSearchServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'GetNoVectorDocuments': grpc.unary_unary_rpc_method_handler(
                    servicer.GetNoVectorDocuments,
                    request_deserializer=documents__pb2.GetNoVectorDocumentsRequest.FromString,
                    response_serializer=documents__pb2.GetNoVectorDocumentsResponse.SerializeToString,
            ),
            'UpdateDocument': grpc.unary_unary_rpc_method_handler(
                    servicer.UpdateDocument,
                    request_deserializer=documents__pb2.UpdateDocumentRequest.FromString,
                    response_serializer=documents__pb2.UpdateDocumentResponse.SerializeToString,
            ),
            'GetDocumentById': grpc.unary_unary_rpc_method_handler(
                    servicer.GetDocumentById,
                    request_deserializer=documents__pb2.GetDocumentByIdRequest.FromString,
                    response_serializer=documents__pb2.Document.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'utils.DocumentSearchService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class DocumentSearchService(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def GetNoVectorDocuments(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetNoVectorDocuments',
            documents__pb2.GetNoVectorDocumentsRequest.SerializeToString,
            documents__pb2.GetNoVectorDocumentsResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def UpdateDocument(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/UpdateDocument',
            documents__pb2.UpdateDocumentRequest.SerializeToString,
            documents__pb2.UpdateDocumentResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def GetDocumentById(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/utils.DocumentSearchService/GetDocumentById',
            documents__pb2.GetDocumentByIdRequest.SerializeToString,
            documents__pb2.Document.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
73
document_ai/init.py
Normal file
73
document_ai/init.py
Normal file
@ -0,0 +1,73 @@
|
||||
# Milvus bootstrap shared by the other document_ai scripts: connects to
# Milvus, ensures the "leaf_documents" collection exists, and exposes
# `collection`, `text_to_vector` and `insert_document`.
import os

from langchain.embeddings import OpenAIEmbeddings
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)

# Connection settings; `or` (rather than a getenv default) also falls back
# when the variable is set but empty.
MILVUS_HOST = os.getenv("MILVUS_HOST") or "127.0.0.1"
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"

connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)

# First run only: create the collection (one embedding row per document).
if not utility.has_collection("leaf_documents"):
    _document_id = FieldSchema(
        name="document_id",
        dtype=DataType.INT64,
        is_primary=True,
    )
    _user_id = FieldSchema(
        name="user_id",
        dtype=DataType.INT64,

    )
    _document_vector = FieldSchema(
        name="vector",
        dtype=DataType.FLOAT_VECTOR,
        dim=1536  # matches OpenAI text-embedding-ada-002 output size
    )
    schema = CollectionSchema(
        fields=[_document_id, _user_id, _document_vector],
        enable_dynamic_field=True
    )
    collection_name = "leaf_documents"
    print("Create collection...")
    collection = Collection(
        name=collection_name,
        schema=schema,
        using='default',
        shards_num=2
    )
    # HNSW index for vector similarity search (L2 metric, same metric the
    # search code in search.py / server.py uses).
    collection.create_index(
        field_name="vector",
        index_params={"metric_type": "L2", "M": 8, "efConstruction": 64, "index_type": "HNSW"},
    )
    # Scalar index so `expr="user_id == ..."` filters are efficient.
    collection.create_index(
        field_name="user_id",
        index_name="idx_user_id"
    )

# Always (re)open the collection and load it into memory for searching.
collection = Collection("leaf_documents")
collection.load()

embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")


def text_to_vector(text: str):
    """Embed *text* with OpenAI and return the embedding vector."""
    return embeddings.embed_query(text)


def insert_document(document_id: int, user_id: int, vector: list):
    """Insert one embedding row and return its primary key.

    The primary field is document_id, so the returned key identifies the
    document whose embedding was stored.
    """
    return collection.insert(
        data=[
            [document_id],
            [user_id],
            [vector]
        ],
    ).primary_keys[0]
|
0
document_ai/run.py
Normal file
0
document_ai/run.py
Normal file
79
document_ai/search.py
Normal file
79
document_ai/search.py
Normal file
@ -0,0 +1,79 @@
|
||||
# Ad-hoc end-to-end test of the retrieval + QA pipeline: embeds a hard-coded
# question, vector-searches Milvus, fetches the matching documents over gRPC,
# then runs a map-reduce QA chain over them.
import json
import documents_pb2
from langchain import text_splitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)

import init
import doc_client


question = """
yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
At line:1 char:1
+ yarn config set registry https://registry.npm.taobao.org/
+ ~~~~
+ CategoryInfo : SecurityError: (:) [], PSSecurityException
+ FullyQualifiedErrorId : UnauthorizedAccess

是什么问题,该怎么解决
"""

# Embed the question; the suffix asks the model to reply in Chinese.
vec = init.text_to_vector(question + " (必须使用中文回复)")

# vec = ""
#
# with open("../question_vec.json", "r") as f:
#     vec = json.load(f)

search_param = {
    "data": [vec],
    "anns_field": "vector",
    "param": {"metric_type": "L2"},
    "limit": 10,
    "expr": "user_id == 2",  # hard-coded test user
    # NOTE(review): these field names do not match the leaf_documents schema
    # created in init.py (document_id / user_id / vector) — verify they exist.
    "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
}
res = init.collection.search(**search_param)

document_ids = []
real_document = []

# Resolve each vector hit to its full document via the gRPC document service.
for i in range(len(res[0])):
    _doc_id = res[0][i].id
    print("正在获取 " + str(_doc_id) + " 的内容...")

    try:
        _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
            id=_doc_id
        ))
        _doc_content_full = _doc_content.title + "\n" + _doc_content.content

        # real_document.append(_doc_content)
        doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})

        real_document.append(doc_obj)

    # Best-effort: a document that cannot be fetched is skipped with a log line.
    except Exception as e:
        print(e)


print(real_document)

print("正在调用 LLM...")
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
                                   verbose=True)
output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
print("回复:" + output["output_text"])
|
80
document_ai/server.py
Normal file
80
document_ai/server.py
Normal file
@ -0,0 +1,80 @@
|
||||
# gRPC AI server: answers DocumentQuery/Query requests (see AIServer below).
import os
from concurrent import futures
import document_query_pb2
import document_query_pb2_grpc
import grpc
import documents_pb2
import init
import doc_client
from langchain.llms.openai import OpenAI
from langchain.schema.document import Document
# NOTE(review): OpenAIEmbeddings is unused in this module (embeddings live in
# init.py) — confirm before removing.
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
||||
|
||||
|
||||
class AIServer(document_query_pb2_grpc.DocumentQueryServicer):
    """gRPC servicer answering natural-language questions over a user's documents.

    Per request: embed the question, run a Milvus vector search scoped to the
    requesting user, fetch each hit's full text from the document service,
    then run a map-reduce QA chain over the results.

    Fix: subclass the generated ``DocumentQueryServicer`` base (which provides
    UNIMPLEMENTED fallbacks for unhandled RPCs) instead of the experimental
    client-side ``DocumentQuery`` helper class.
    """

    def Query(self, request, context):
        """Handle DocumentQuery/Query.

        Args:
            request: QueryRequest carrying `question` and `user_id`.
            context: grpc.ServicerContext (unused on the success path).

        Returns:
            QueryResponse whose `text` is the generated answer.
        """
        vec = init.text_to_vector(request.question)

        # Suffix asks the LLM to answer in Chinese.
        question = request.question + "(必须使用中文回复)"

        search_param = {
            "data": [vec],
            "anns_field": "vector",
            "param": {"metric_type": "L2"},
            "limit": 10,
            # Scope the vector search to the requesting user's documents.
            "expr": "user_id == " + str(request.user_id),
            "output_fields": ["document_id", "user_id"],
        }

        res = init.collection.search(**search_param)

        real_document = []

        # Resolve each Milvus hit to the full document via the document service.
        for hit in res[0]:
            _doc_id = hit.id
            print("正在获取 " + str(_doc_id) + " 的内容...")

            try:
                _doc_content = doc_client.stub.GetDocumentById(documents_pb2.GetDocumentByIdRequest(
                    id=_doc_id
                ))
                _doc_content_full = _doc_content.title + "\n" + _doc_content.content

                doc_obj = Document(page_content=_doc_content_full, metadata={"source": _doc_content.title})
                real_document.append(doc_obj)

            except Exception as e:
                # Best-effort: skip documents that cannot be fetched.
                print(e)

        print(real_document)

        print("正在调用 LLM...")
        chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce",
                                           return_intermediate_steps=True,
                                           verbose=True)
        output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
        print("回复:" + output["output_text"])

        return document_query_pb2.QueryResponse(
            text=output["output_text"]
        )
|
||||
|
||||
|
||||
def serve():
    """Start the gRPC server for DocumentQuery and block until termination.

    The bind address comes from the BIND environment variable, defaulting to
    "[::]:50051" when unset.
    """
    addr = os.getenv("BIND")
    if addr is None:
        addr = "[::]:50051"
    print("Listening on", addr)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    document_query_pb2_grpc.add_DocumentQueryServicer_to_server(AIServer(), server)
    server.add_insecure_port(addr)
    server.start()
    server.wait_for_termination()


# Fix: guard the entry point so importing this module no longer starts the
# server as a side effect; running it as a script behaves as before.
if __name__ == "__main__":
    serve()
|
27
document_ai/worker.py
Normal file
27
document_ai/worker.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Embedding backfill worker: fetches documents that have no vector yet,
# embeds them, inserts the embeddings into Milvus, and writes the resulting
# vector id back to the document service.
import documents_pb2_grpc  # NOTE(review): unused here; the stub comes from doc_client — confirm before removing
import documents_pb2
import init
import doc_client

print("获取需要更新的数据...")
documents_response = doc_client.stub.GetNoVectorDocuments(documents_pb2.GetNoVectorDocumentsRequest()).documents

# Embed every document that is still missing a vector.
for document in documents_response:
    docContent = document.title + "\n" + document.content

    print("正在更新向量...")
    text_vector = init.text_to_vector(docContent)

    # Insert the embedding into Milvus; returns the primary key of the new row.
    update_vector_response = init.insert_document(document.id, document.user_id, text_vector)
    print(update_vector_response)

    # Record the Milvus row id on the document so it is not re-embedded.
    update_vector_id_response = doc_client.stub.UpdateDocument(documents_pb2.UpdateDocumentRequest(
        id=document.id,
        vector_id=update_vector_response
    ))

    print(update_vector_id_response)
    print("更新向量完成")
|
0
requirements.txt
Normal file
0
requirements.txt
Normal file
Loading…
Reference in New Issue
Block a user