iVamp 2023-11-19 20:54:12 +08:00
parent c8e5c8f389
commit 80c871e09d
5 changed files with 318 additions and 20 deletions

document_ai/agent.py Normal file (+83)

@@ -0,0 +1,83 @@
from langchain.agents import Tool, load_tools
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from langchain.agents import initialize_agent
from langchain.agents import AgentType
import proto.documents_pb2
import init
import doc_client


# def fake_result(str: str) -> str:
#     print(str)
#     return "博客名称: iVampireSP.com"  # "Blog name: iVampireSP.com"


def search_document(question: str) -> str:
    """Embed the question and return the matching chunk contents from Milvus."""
    print("搜索请求:" + question)  # "Search request:"
    vec = init.text_to_vector(question)

    search_param = {
        "data": [vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }

    res = init.collection.search(**search_param)

    plain_text = ""

    # Fetch the full content of each matched chunk over gRPC.
    for hit in res[0]:
        _chunk_id = hit.id
        print("正在获取分块 " + str(_chunk_id) + " 的内容...")  # "Fetching content of chunk ..."
        try:
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))
            plain_text += "=== \n" + _chunk_content.content + " ===\n"
        except Exception as e:
            print(e)

    return plain_text


tools = [
    # Tool(
    #     name="Get Blog Name",
    #     func=fake_result,
    #     description="Get user's blog name from the Internet.",
    # ),
    Tool(
        name="Search user's Library Document",
        func=search_document,
        description="优先使用 Search user's Library Document.",  # "Prefer using this tool."
    )
]

llm = ChatOpenAI(temperature=0)

loaded_tools = load_tools(["llm-math"], llm=llm)
tools.extend(loaded_tools)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
memory.chat_memory.messages.append(HumanMessage(content="必须使用中文回复。"))  # "You must reply in Chinese."

agent_chain = initialize_agent(tools, llm,
                               agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
                               verbose=False,
                               memory=memory)

while True:
    question = input("请输入问题:")  # "Enter a question:"
    question = "必须使用中文回复:" + question  # prefix forcing a Chinese reply
    result = agent_chain.run(input=question)
    print(result)
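Note: agent.py depends on an `init` module that is not part of this commit. The following is a minimal hypothetical sketch of what it presumably provides, assuming OpenAI's text-embedding-ada-002 model and an already-created Milvus collection; the collection name, host, and port here are assumptions, not taken from the commit.

# init.py -- hypothetical sketch; the real module is not included in this commit.
import openai
from pymilvus import connections, Collection

# Assumed local Milvus instance on the default port.
connections.connect("default", host="127.0.0.1", port="19530")

# Assumed collection name; the commit only ever references `init.collection`.
collection = Collection("document_chunks")
collection.load()


def text_to_vector(text: str) -> list:
    # text-embedding-ada-002 returns a 1536-dimension vector, which must
    # match the dimension of the collection's "vector" field.
    resp = openai.Embedding.create(model="text-embedding-ada-002", input=text)
    return resp["data"][0]["embedding"]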

document_ai/chat.py Normal file (+121)

@@ -0,0 +1,121 @@
import proto.documents_pb2
import init
import doc_client
import openai

# Continuous conversation: the message history sent to the model on every turn.
# System prompt (Chinese): answer using the documents, reply in the question's
# language with Markdown, cite the source, address the user as "您", and be
# aware that some document content is irrelevant.
messages = [
    {
        "role": "system",
        "content": """
回答问题使用文档,并以提问的语言和 Markdown 回答,并告诉来源。
你得用「您」的身份指代用户。请辨别文档中的内容,有一些是不相干的。
"""
    },
]


def ask_question(question):
    messages.append({
        "role": "user",
        "content": f"""
{question}
"""
    })

    question_vec = init.text_to_vector(question)

    search_param = {
        "data": [question_vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }

    res = init.collection.search(**search_param)

    plain_text = ""

    for hit in res[0]:
        _chunk_id = hit.id
        if str(_chunk_id) == "0":  # skip placeholder chunks
            continue
        print("正在获取分块 " + str(_chunk_id) + " 的内容...")  # "Fetching content of chunk ..."
        try:
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))
            print(_chunk_content)
            plain_text += ("=== \n" + f"文档 ID:{_chunk_content.document.id}\n"
                           + f"文档内容: {_chunk_content.content}" + "===\n")
        except Exception as e:
            print(e)

    # Attach the retrieved documents as a temporary system message.
    messages.append({
        "role": "system",
        "content": f"""
文档: {plain_text}
"""
    })

    print("正在调用 LLM...")  # "Calling the LLM..."
    result = openai.ChatCompletion.create(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )

    res = result["choices"][0]["message"].to_dict_recursive()

    # Append the assistant reply to the history.
    messages.append({
        "role": "assistant",
        "content": res["content"]
    })

    # Remove the temporary document-context system message, keeping the first one.
    for i in range(len(messages)):
        if messages[i]["role"] == "system":
            if i == 0:
                continue
            messages.pop(i)
            break

    return res["content"]


while True:
    print(messages)
    if len(messages) > 10:
        # Keep the initial system prompt plus the most recent turns.
        messages = messages[:1] + messages[-9:]
        print("很抱歉,我只能记住最近 10 条上下文数据。让我们重新开始吧。")  # "Sorry, I can only remember the last 10 context entries. Let's start over."

    question = input("请输入问题:")  # "Enter a question:"
    resp = ask_question(question)
    print(resp)
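Both new files also call `doc_client.stub.GetDocumentChunk(...)`, but the `doc_client` module is not shown in this commit either. A hedged sketch of the wiring it presumably does; the channel address and the generated stub class name are assumptions to be matched against the real documents.proto service.

# doc_client.py -- hypothetical sketch; the actual stub setup is not in this commit.
import grpc
import proto.documents_pb2_grpc

# Assumed address and generated stub class name.
channel = grpc.insecure_channel("127.0.0.1:50051")
stub = proto.documents_pb2_grpc.DocumentServiceStub(channel)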

View File

@@ -1,9 +1,15 @@
 import json
+import openai
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
 import proto.documents_pb2
 from langchain import text_splitter
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.embeddings import OpenAIEmbeddings
-from langchain.llms.openai import OpenAI
+# from langchain.llms.openai import OpenAI;
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema.document import Document
 from pymilvus import (
     connections,
@@ -16,6 +22,8 @@ from pymilvus import (
 import init
 import doc_client
+import openai
+from langchain.adapters import openai as lc_openai

 #
 # question = """
@ -31,15 +39,15 @@ import doc_client
# """ # """
question = """ question = """
为什么我会在 WHMCS 下开发摸不着头脑 错误 yarn 什么了遇到我
""" """
vec = init.text_to_vector(question)
# vec = ""
# #
# with open("../question_vec.json", "r") as f: # vec = init.text_to_vector(question)
# vec = json.load(f)
vec = ""
with open("../question_vec.json", "r") as f:
vec = json.load(f)
search_param = { search_param = {
"data": [vec], "data": [vec],
@@ -51,8 +59,13 @@ search_param = {
 }

 res = init.collection.search(**search_param)

-document_chunk_ids = []
-real_document = []
+# 保留 5 个 (keep only 5 results)
+if len(res[0]) > 5:
+    res[0] = res[0][:5]
+
+# document_chunk_ids = []
+# real_document = []
+plain_text = ""

 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
@@ -65,23 +78,104 @@ for i in range(len(res[0])):
         # print(_chunk_content)

-        _doc_content_full = _chunk_content.content
+        # _doc_content_full = _chunk_content.content
+        # print("DOC OBJ:" + _doc_content_full)
+
+        plain_text += "=== \n" + _chunk_content.content + " ===\n"

         # real_document.append(_doc_content)
         # doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title})
-        doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
-        real_document.append(doc_obj)
+        # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+        # real_document.append(doc_obj)

     except Exception as e:
         print(e)

-print(real_document)
+# print(real_document)

 print("正在调用 LLM...")

-chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
-                                   verbose=True)
-
-question = "必须使用中文回复:" + question
-output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-print("回复:" + output["output_text"])
+# prompt_template = f"""Answer questions use the following context and reply in question spoken language and answer
+# with Markdown format, you can fix syntax errors in the context, but do not change the meaning of the context.
+# you can tell user context errors(syntax or meaning) in answer.
+# ---
+# {plain_text}
+# ---
+# Question: {question}
+# Answer:"""
+
+messages = [
+    {
+        "role": "system",
+        "content": """
+回答问题使用以下上下文,并以提问的语言和 Markdown 回答,并告诉来源。
+你得用「您」的身份指代用户。如果用户的问题有语法错误,或者上下文的意思不对,你可以告诉用户。
+请辨别上下文中的内容,有一些是不相干的。
+"""
+    },
+    {
+        "role": "system",
+        "content": f"""
+Context: {plain_text}
+"""
+    },
+    {
+        "role": "user",
+        "content": f"""
+{question}
+"""
+    }
+]
+
+result = openai.ChatCompletion.create(
+    messages=messages, model="gpt-3.5-turbo", temperature=0
+)
+
+res = result["choices"][0]["message"].to_dict_recursive()
+print(res)
+
+# prompt_template = f"""
+# ---
+# {plain_text}
+# ---
+# Question: {question}
+# Answer:"""
+#
+# print(prompt_template)
+# # PROMPT = PromptTemplate(
+# #     template=prompt_template, input_variables=["real_document", "question"]
+# # )
+#
+#
+# ChatOpenAI
+# llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo")
+# # chain = LLMChain(llm=llm, prompt=PROMPT)
+#
+# output = llm(prompt_template)
+#
+# gpt = openai.Completion.create(
+#     engine="gpt-3.5-turbo",
+#     prompt=prompt_template,
+#     max_tokens=150,
+#     temperature=0,
+#     top_p=1,
+#     frequency_penalty=0,
+#     presence_penalty=0,
+#     stop=["==="]
+# )
+# output = gpt["choices"][0]["text"]
+# print(output)
+
+# output = chain({"real_document": real_document, "question": question}, return_only_outputs=True)
+# print(output)
+
+# chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
+#                                    verbose=True)
+#
+# question = "必须使用中文回复:" + question
+# output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+# print("回复:" + output["output_text"])

View File

@@ -42,7 +42,7 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
"data": [vec], "data": [vec],
"anns_field": "vector", "anns_field": "vector",
"param": {"metric_type": "L2"}, "param": {"metric_type": "L2"},
"limit": 10, "limit": 5,
"expr": "user_id == " + str(target.user_id), "expr": "user_id == " + str(target.user_id),
"output_fields": ["document_id", "user_id"], "output_fields": ["document_id", "user_id"],
} }
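For reference, `limit` is Milvus's top-k: with `metric_type: L2`, the five hits with the smallest Euclidean distance are returned among the entities that pass the `expr` filter. A minimal standalone search call with the same parameters; the collection name, vector dimension, and filter value are assumptions for illustration.

from pymilvus import Collection

collection = Collection("document_chunks")  # assumed name
vec = [0.0] * 1536  # placeholder query vector; dimension assumed (ada-002)

res = collection.search(
    data=[vec],                   # query vector(s)
    anns_field="vector",          # vector field to search
    param={"metric_type": "L2"},  # smaller L2 distance = more similar
    limit=5,                      # top-k hits per query vector
    expr="user_id == 42",         # hypothetical scalar filter
    output_fields=["document_id", "user_id"],
)
print(res[0].ids, res[0].distances)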

Binary file not shown.