# langchain-chat-with-milvus/document_ai/search.py

import json
import openai
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import proto.documents_pb2
from langchain import text_splitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings import OpenAIEmbeddings
# from langchain.llms.openai import OpenAI;
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.document import Document
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
import init
import doc_client
import openai
from langchain.adapters import openai as lc_openai
#
# question = """
# yarn : File C:\\Users\\ivamp\\AppData\\Roaming\\npm\\yarn.ps1 cannot be loaded because running scripts is disabled on this sy
# stem. For more information, see about_Execution_Policies at https:/go.microsoft.com/fwlink/?LinkID=135170.
# At line:1 char:1
# + yarn config set registry https://registry.npm.taobao.org/
# + ~~~~
# + CategoryInfo : SecurityError: (:) [], PSSecurityException
# + FullyQualifiedErrorId : UnauthorizedAccess
#
# 是什么问题,该怎么解决
# """
# The user's question to answer (sent verbatim to the LLM).
question = """
错误 yarn 什么了遇到我
"""

# Load a precomputed embedding for the question instead of calling
# init.text_to_vector(question) — avoids an embedding API call while testing.
# (Removed the dead `vec = ""` sentinel: the `with` block always rebinds it,
# and an open/parse failure propagates either way.)
with open("../question_vec.json", "r") as f:
    vec = json.load(f)
# Vector-similarity search over the Milvus collection: find the chunks nearest
# to the question embedding, restricted to records owned by user 2, returning
# the listed metadata fields alongside each hit.
res = init.collection.search(
    data=[vec],
    anns_field="vector",
    param={"metric_type": "L2"},
    limit=10,
    expr="user_id == 2",
    output_fields=["todo_id", "title", "source", "todo_description", "language", "text", "user_id"],
)

# Keep at most the 5 closest hits.
if len(res[0]) > 5:
    res[0] = res[0][:5]
# document_chunk_ids = []
# real_document = []
plain_text = ""
document_chunks = []
for i in range(len(res[0])):
_chunk_id = res[0][i].id
print("正在获取分块 " + str(_chunk_id) + " 的内容...")
try:
_chunk_content = doc_client.stub.GetDocumentChunk(proto.documents_pb2.GetDocumentChunkByIdRequest(
id=_chunk_id
))
# print(_chunk_content)
# _doc_content_full = _chunk_content.content
# print("DOC OBJ:" + _doc_content_full)
document_chunks.append(_chunk_content.content)
plain_text += "=== \n" + _chunk_content.content + " ===\n"
# real_document.append(_doc_content)
# doc_obj = Document(page_content=_doc_content_full, metadata={"source": _chunk_content.title})
# doc_obj = Document(page_content=_doc_content_full, metadata={"source": "chunked content"})
# real_document.append(doc_obj)
except Exception as e:
print(e)
# print(real_document)
print("正在调用 LLM...")

# Build the chat transcript: a system instruction describing the desired answer
# style, a second system message carrying the retrieved context, and finally
# the user's question. Keeping the context in its own system message separates
# instructions from data.
messages = [
    {
        "role": "system",
        "content": """
回答问题使用以下上下文并以提问的语言和Markdown回答并告诉来源。
你得用“你”的身份指代用户。如果用户的问题有语法错误或者上下文的意思不对,你可以告诉用户。
请辨别上下文中的内容,有一些是不相干的。
"""
    },
    {
        "role": "system",
        "content": f"""
{plain_text}
"""
    },
    {
        "role": "user",
        "content": f"""
{question}
"""
    }
]

# temperature=0 for deterministic, grounded answers.
completion = openai.ChatCompletion.create(
    messages=messages, model="gpt-3.5-turbo", temperature=0
)

# Bind the reply to a fresh name: the original rebound `res`, shadowing the
# Milvus search result above.
reply = completion["choices"][0]["message"].to_dict_recursive()
print(reply)
# prompt_template = f"""
# ---
# {plain_text}
# ---
# Question: {question}
# Answer:"""
#
# print(prompt_template)
# # PROMPT = PromptTemplate(
# # template=prompt_template, input_variables=["real_document", "question"]
# # )
#
#
# ChatOpenAI
# llm = OpenAI(temperature=0, model_name="gpt-3.5-turbo")
# # chain = LLMChain(llm=llm, prompt=PROMPT)
#
# output = llm(prompt_template)
# gpt = openai.Completion.create(
# engine="gpt-3.5-turbo",
# prompt=prompt_template,
# max_tokens=150,
# temperature=0,
# top_p=1,
# frequency_penalty=0,
# presence_penalty=0,
# stop=["==="]
# )
# output = gpt["choices"][0]["text"]
# print(output)
# output = chain({"real_document": real_document, "question": question}, return_only_outputs=True)
# print(output)
# chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="map_reduce", return_intermediate_steps=True,
# verbose=True)
#
# question = "必须使用中文回复:" + question
# output = chain({"input_documents": real_document, "question": question}, return_only_outputs=False)
# print("回复:" + output["output_text"])