Improvement
This commit is contained in: parent c8e5c8f389, commit 80c871e09d

document_ai/agent.py (new file, 83 lines)
@@ -0,0 +1,83 @@
from langchain.agents import Tool, load_tools
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, BaseMessage
from langchain.utilities import SerpAPIWrapper
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from getpass import getpass
import proto.documents_pb2

import init, doc_client

# NOTE: SerpAPIWrapper, getpass and BaseMessage are imported above but unused in this file.

# def fake_result(str: str) -> str:
#     print(str)
#     return "博客名称: iVampireSP.com"
#

def search_document(question: str) -> str:
    print("搜索请求:" + question)  # log the incoming search query
    vec = init.text_to_vector(question)

    # ANN search against the Milvus collection: top-5 nearest chunks
    # (L2 distance) belonging to user 2.
    search_param = {
        "data": [vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    res = init.collection.search(**search_param)

    plain_text = ""

    for i in range(len(res[0])):
        _chunk_id = res[0][i].id
        print("正在获取分块 " + str(_chunk_id) + " 的内容...")  # "fetching content of chunk ..."

        try:
            # Resolve the chunk ID to its full text via the gRPC document service.
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))

            plain_text += "=== \n" + _chunk_content.content + " ===\n"

        except Exception as e:
            print(e)
    return plain_text

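Milvus returns one Hits object per query vector, so `res[0]` above is the hit list for the single question vector. A minimal sketch of consuming hits together with their L2 distances, assuming a connected pymilvus `Collection` bound to `init.collection` and a query vector `vec` (field names mirror `search_param` above):

# Sketch only: smaller L2 distance means a closer match.
hits = init.collection.search(
    data=[vec],
    anns_field="vector",
    param={"metric_type": "L2"},
    limit=5,
    expr="user_id == 2",
    output_fields=["title"],
)[0]
for hit in hits:
    print(hit.id, hit.distance, hit.entity.get("title"))
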
tools = [
    # Tool(
    #     name="Get Blog Name",
    #     func=fake_result,
    #     description="Get user's blog name from the Internet.",
    # ),
    Tool(
        name="Search user's Library Document",
        func=search_document,
        description="优先使用 Search user's Library Document.",  # "prefer this tool"
    )
]

llm = ChatOpenAI(temperature=0)
loaded_tools = load_tools(["llm-math"], llm=llm)
tools.extend(loaded_tools)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Seed the history with "必须使用中文回复。" ("You must reply in Chinese.")
memory.chat_memory.messages.append(HumanMessage(content="必须使用中文回复。"))
# memory.clear()

agent_chain = initialize_agent(tools, llm,
                               agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
                               verbose=False,
                               memory=memory)

while True:
    question = input("请输入问题:")  # "Enter a question:"
    question = "必须使用中文回复:" + question  # prefix: "You must reply in Chinese:"
    result = agent_chain.run(input=question)
    print(result)
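Since the tool wraps a plain function, the retrieval path can be smoke-tested without the agent loop; a small sketch under that assumption (the query string is an arbitrary example):

# Exercise the Milvus + gRPC wiring directly, bypassing the agent.
context = search_document("WHMCS 开发")  # any test string works
print(context or "(no chunks found)")
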
document_ai/chat.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import json
from langchain.tools import BaseTool
import proto.documents_pb2

import init
import doc_client
import openai

# class Eva


# Multi-turn conversation state. The first system message instructs the model
# (in Chinese) to answer from the documents, reply in the asker's language in
# Markdown, cite sources, address the user as "你" ("you"), and ignore
# irrelevant document content.
messages = [
    {
        "role": "system",
        "content": """
回答问题使用文档,并以提问的语言和Markdown回答,并告诉来源。
你得用“你”的身份指代用户。请辨别文档中的内容,有一些是不相干的。
"""
    },
    # {
    #     "role": "system",
    #     "content": f"""
    # Context: {plain_text}
    # """
    # },
    # {
    #     "role": "user",
    #     "content": f"""
    # {question}
    # """
    # }
]

def ask_question(question):
    messages.append({
        "role": "user",
        "content": f"""
{question}
"""
    })

    question_vec = init.text_to_vector(question)
    search_param = {
        "data": [question_vec],
        "anns_field": "vector",
        "param": {"metric_type": "L2"},
        "limit": 5,
        "expr": "user_id == 2",
        "output_fields": ["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
    }
    res = init.collection.search(**search_param)

    plain_text = ""

    for i in range(len(res[0])):
        _chunk_id = res[0][i].id

        if _chunk_id == "0":
            continue

        print("正在获取分块 " + str(_chunk_id) + " 的内容...")  # "fetching content of chunk ..."

        try:
            _chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
                id=_chunk_id
            ))

            print(_chunk_content)

            # "文档 ID" = document ID, "文档内容" = document content
            plain_text += ("=== \n" + f"文档 ID:{_chunk_content.document.id}\n"
                           + f"文档内容: {_chunk_content.content}" + "===\n")

        except Exception as e:
            print(e)

    # Attach the retrieved chunks as a second system message ("文档" = documents).
    messages.append({
        "role": "system",
        "content": f"""
文档: {plain_text}
"""
    })

    print("正在调用 LLM...")  # "calling the LLM..."
    result = openai.ChatCompletion.create(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    res = result["choices"][0]["message"].to_dict_recursive()

    # Append the assistant reply to the running history.
    messages.append({
        "role": "assistant",
        "content": res["content"]
    })

    # Remove the extra system message (the per-question document context),
    # keeping only the first instruction system message.
    for i in range(len(messages)):
        if messages[i]["role"] == "system":
            if i == 0:
                continue
            messages.pop(i)
            break

    return res["content"]

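Note that the trim in the loop below, `messages = messages[-10:]`, also discards the leading instruction system message once the history grows. A minimal alternative sketch that preserves it (assuming `messages[0]` is that instruction message, as set up above):

def trim_history(history, max_len=10):
    # Keep the first (instruction) system message plus the most recent turns.
    if len(history) <= max_len:
        return history
    return [history[0]] + history[-(max_len - 1):]
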
while True:
    print(messages)

    if len(messages) > 10:
        messages = messages[-10:]
        print("很抱歉,我只能记住最近 10 条上下文数据。让我们重新开始吧。")  # "Sorry, I can only remember the last 10 context entries. Let's start over."

    question = input("请输入问题:")  # "Enter a question:"
    resp = ask_question(question)
    print(resp)
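chat.py imports `json` but never uses it; one plausible use, sketched here purely as an assumption, is persisting the conversation between runs (the filename is arbitrary):

def save_history(path="chat_history.json"):
    # Dump the running message list so a later session can reload it.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(messages, f, ensure_ascii=False, indent=2)
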
@@ -1,9 +1,15 @@
 import json
+
+import openai
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+
 import proto.documents_pb2
 from langchain import text_splitter
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.embeddings import OpenAIEmbeddings
-from langchain.llms.openai import OpenAI
+# from langchain.llms.openai import OpenAI
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema.document import Document
 from pymilvus import (
     connections,
@@ -16,6 +22,8 @@ from pymilvus import (

 import init
 import doc_client
+import openai
+from langchain.adapters import openai as lc_openai

 #
 # question = """
@@ -31,15 +39,15 @@ import doc_client
 # """

 question = """
-为什么我会在 WHMCS 下开发摸不着头脑
+错误 yarn 什么了遇到我
 """

-vec = init.text_to_vector(question)
-
-# vec = ""
 #
-# with open("../question_vec.json", "r") as f:
-#     vec = json.load(f)
+# vec = init.text_to_vector(question)
+
+vec = ""
+
+with open("../question_vec.json", "r") as f:
+    vec = json.load(f)

 search_param = {
     "data": [vec],
@@ -51,8 +59,13 @@ search_param = {
 }
 res = init.collection.search(**search_param)

-document_chunk_ids = []
-real_document = []
+# 保留 5 个 (keep only the top 5 hits)
+if len(res[0]) > 5:
+    res[0] = res[0][:5]
+
+# document_chunk_ids = []
+# real_document = []
+plain_text = ""

 for i in range(len(res[0])):
     _chunk_id = res[0][i].id
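The new client-side cap is redundant as written, since `search_param` already sets `"limit": 5`; it only matters if the two values are meant to diverge. A non-mutating sketch (hypothetical `top_hits` helper) avoids reassigning into the Milvus `SearchResult`:

def top_hits(res, top_k=5):
    # Hits is iterable; slice a plain list instead of mutating res[0] in place.
    return list(res[0])[:top_k]
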
@@ -65,23 +78,104 @@ for i in range(len(res[0])):

         # print(_chunk_content)

-        _doc_content_full = _chunk_content.content
+        # _doc_content_full = _chunk_content.content
+        # print("DOC OBJ:" + _doc_content_full)
+        plain_text += "=== \n" + _chunk_content.content + " ===\n"

         # real_document.append(_doc_content)
         # doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title})
-        doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
+        # doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})

-        real_document.append(doc_obj)
+        # real_document.append(doc_obj)

     except Exception as e:
         print(e)

-print(real_document)
+# print(real_document)

 print("正在调用 LLM...")
-chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
-                                   verbose=True)
-
-question = "必须使用中文回复:" + question
-output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
-print("回复:" + output["output_text"])
+
+# prompt_template = f"""Answer questions use the following context and reply in question spoken language and answer
+# with Markdown format, you can fix syntax errors in the context, but do not change the meaning of the context.
+# you can tell user context errors(syntax or meaning) in answer.
+# ---
+# {plain_text}
+# ---
+# Question: {question}
+# Answer:"""
+
+messages = [
+    {
+        "role": "system",
+        "content": """
+回答问题使用以下上下文,并以提问的语言和Markdown回答,并告诉来源。
+你得用“你”的身份指代用户。如果用户的问题有语法错误或者上下文的意思不对,你可以告诉用户。
+请辨别上下文中的内容,有一些是不相干的。
+"""
+    },
+    {
+        "role": "system",
+        "content": f"""
+Context: {plain_text}
+"""
+    },
+    {
+        "role": "user",
+        "content": f"""
+{question}
+"""
+    }
+]
+
+result = openai.ChatCompletion.create(
+    messages=messages, model="gpt-3.5-turbo", temperature=0
+)
+res = result["choices"][0]["message"].to_dict_recursive()
+print(res)
+
+
+# prompt_template = f"""
+# ---
+# {plain_text}
+# ---
+# Question: {question}
+# Answer:"""
+#
+# print(prompt_template)
+# # PROMPT = PromptTemplate(
+# #     template=prompt_template, input_variables=["real_document", "question"]
+# # )
+#
+#
+# ChatOpenAI
+
+# llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo")
+# # chain = LLMChain(llm=llm, prompt=PROMPT)
+#
+# output = llm(prompt_template)
+
+# gpt = openai.Completion.create(
+#     engine="gpt-3.5-turbo",
+#     prompt=prompt_template,
+#     max_tokens=150,
+#     temperature=0,
+#     top_p=1,
+#     frequency_penalty=0,
+#     presence_penalty=0,
+#     stop=["==="]
+# )
+
+# output = gpt["choices"][0]["text"]
+# print(output)
+
+# output = chain({"real_document": real_document, "question": question}, return_only_outputs=True)
+
+
+# print(output)
+
+# chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
+#                                    verbose=True)
+#
+# question = "必须使用中文回复:" + question
+# output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
+# print("回复:" + output["output_text"])
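For reference, `to_dict_recursive()` is a helper on the pre-1.0 openai-python response object; plain indexing reads the same field, as in this one-line sketch:

reply = result["choices"][0]["message"]["content"]  # equivalent plain access
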
@@ -42,7 +42,7 @@ class AIServer(proto.document_query_pb2_grpc.DocumentQuery):
         "data": [vec],
         "anns_field": "vector",
         "param": {"metric_type": "L2"},
-        "limit": 10,
+        "limit": 5,
         "expr": "user_id == " + str(target.user_id),
         "output_fields": ["document_id", "user_id"],
     }
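The tightened limit can also be made explicit with the keyword form of `Collection.search`; a sketch assuming the surrounding `AIServer` method's `vec` and `target` (parameters mirror the hunk above):

res = init.collection.search(
    data=[vec],
    anns_field="vector",
    param={"metric_type": "L2"},
    limit=5,  # was 10 before this commit
    expr="user_id == " + str(target.user_id),
    output_fields=["document_id", "user_id"],
)
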
BIN requirements.txt (binary file not shown)