|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This script provides a Gradio interface for interacting with a chatbot based on Retrieval-Augmented Generation. |
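
Typical invocation (the script filename below is an assumption; adjust to the actual file name):

    python rag_demo.py --server-port 7860 --qianfan_api_key $API_KEY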
|
""" |
|
|
|
import argparse |
|
import base64 |
|
import copy |
|
import hashlib |
|
import json |
|
import logging |
|
import os |
|
import textwrap |
|
from collections import namedtuple |
|
from datetime import datetime |
|
from functools import partial |
|
|
|
import faiss |
|
import gradio as gr |
|
import numpy as np |
|
from bot_requests import BotClient |
|
|
|
os.environ["NO_PROXY"] = "localhost,127.0.0.1" |
|
|
|
logging.root.setLevel(logging.INFO) |
|
|
|
FILE_URL_DEFAULT = "data/coffee.txt" |
|
RELEVANT_PASSAGE_DEFAULT = textwrap.dedent( |
|
"""\ |
|
1675年时,英格兰就有3000多家咖啡馆;启蒙运动时期,咖啡馆成为民众深入讨论宗教和政治的聚集地, |
|
1670年代的英国国王查理二世就曾试图取缔咖啡馆。这一时期的英国人认为咖啡具有药用价值, |
|
甚至名医也会推荐将咖啡用于医疗。""" |
|
) |
|
|
|
QUERY_REWRITE_PROMPT = textwrap.dedent( |
|
"""\ |
|
【当前时间】 |
|
{TIMESTAMP} |
|
|
|
【对话内容】 |
|
{CONVERSATION} |
|
|
|
你的任务是根据上面user与assistant的对话内容,理解user意图,改写user的最后一轮对话,以便更高效地从知识库查找相关知识。具体的改写要求如下: |
|
1. 如果user的问题包括几个小问题,请将它们分成多个单独的问题。 |
|
2. 如果user的问题涉及到之前对话的信息,请将这些信息融入问题中,形成一个不需要上下文就可以理解的完整问题。 |
|
3. 如果user的问题是在比较或关联多个事物时,先将其拆分为单个事物的问题,例如‘A与B比起来怎么样’,拆分为:‘A怎么样’以及‘B怎么样’。 |
|
4. 如果user的问题中描述事物的限定词有多个,请将多个限定词拆分成单个限定词。 |
|
5. 如果user的问题具有**时效性(需要包含当前时间信息,才能得到正确的回复)**的时候,需要将当前时间信息添加到改写的query中;否则不加入当前时间信息。 |
|
6. 只在**确有必要**的情况下改写,不需要改写时query输出[]。输出不超过 5 个改写问题,不要为了凑满数量而输出冗余问题。 |
|
|
|
【输出格式】只输出 JSON ,不要给出多余内容 |
|
```json |
|
{{ |
|
"query": ["改写问题1", "改写问题2"...] |
|
}}``` |
|
""" |
|
) |
|
ANSWER_PROMPT = textwrap.dedent( |
|
"""\ |
|
你是阅读理解问答专家。 |
|
|
|
【文档知识】 |
|
{DOC_CONTENT} |
|
|
|
你的任务是根据对话内容,理解用户需求,参考文档知识回答用户问题,知识参考详细原则如下: |
|
- 对于同一信息点,如文档知识与模型通用知识均可支撑,应优先以文档知识为主,并对信息进行验证和综合。 |
|
- 如果文档知识不足或信息冲突,必须指出“根据资料无法确定”或“不同资料存在矛盾”,不得引入文档知识与通识之外的主观推测。 |
|
|
|
同时,回答问题需要综合考虑规则要求中的各项内容,详细要求如下: |
|
【规则要求】 |
|
* 回答问题时,应优先参考与问题紧密相关的文档知识,不要在答案中引入任何与问题无关的文档内容。 |
|
* 回答中不可以让用户知道你查询了相关文档。 |
|
* 回复答案不要出现'根据文档知识','根据当前时间'等表述。 |
|
* 论述突出重点内容,以分点条理清晰的结构化格式输出。 |
|
|
|
【当前时间】 |
|
{TIMESTAMP} |
|
|
|
【对话内容】 |
|
{CONVERSATION} |
|
|
|
直接输出回复内容即可。 |
|
""" |
|
) |
|
QUERY_DEFAULT = "1675 年时,英格兰有多少家咖啡馆?" |
|
|
|
|
|
def get_args() -> argparse.Namespace: |
|
""" |
|
Parse and return command line arguments for the ERNIE models chat demo. |
|
Configures server settings, model endpoint, and document processing parameters. |
|
|
|
Returns: |
|
        argparse.Namespace: Parsed command-line arguments.
|
""" |
|
    parser = argparse.ArgumentParser(description="ERNIE models web chat demo.")
|
|
|
parser.add_argument( |
|
"--server-port", type=int, default=7860, help="Demo server port." |
|
) |
|
parser.add_argument( |
|
"--server-name", type=str, default="0.0.0.0", help="Demo server name." |
|
) |
|
parser.add_argument( |
|
"--max_char", |
|
type=int, |
|
default=20000, |
|
help="Maximum character limit for messages.", |
|
) |
|
parser.add_argument( |
|
"--max_retry_num", type=int, default=3, help="Maximum retry number for request." |
|
) |
|
parser.add_argument( |
|
"--model_map", |
|
type=str, |
|
default='{"ernie-4.5-turbo-128k-preview": "https://qianfan.baidubce.com/v2"}', |
|
help="""JSON string defining model name to endpoint mappings. |
|
Required Format: |
|
{"ERNIE-4.5": "http://localhost:port/v1"} |
|
|
|
Note: |
|
            - Endpoints must be valid HTTP URLs.
|
- Specify ONE model endpoint in JSON format. |
|
- Prefix determines model capabilities: |
|
* ERNIE-4.5: Text-only model |
|
""", |
|
) |
|
parser.add_argument( |
|
"--embedding_service_url", |
|
type=str, |
|
default="https://qianfan.baidubce.com/v2", |
|
help="Embedding service url.", |
|
) |
|
parser.add_argument( |
|
"--qianfan_api_key", |
|
type=str, |
|
default=os.environ.get("API_KEY"), |
|
help="Qianfan API key.", |
|
) |
|
parser.add_argument( |
|
"--embedding_model", |
|
type=str, |
|
default="embedding-v1", |
|
help="Embedding model name.", |
|
) |
|
parser.add_argument( |
|
"--embedding_dim", |
|
type=int, |
|
default=384, |
|
help="Dimension of the embedding vector.", |
|
) |
|
parser.add_argument( |
|
"--chunk_size", |
|
type=int, |
|
default=512, |
|
help="Chunk size for splitting long documents.", |
|
) |
|
parser.add_argument( |
|
"--top_k", type=int, default=3, help="Top k results to retrieve." |
|
) |
|
parser.add_argument( |
|
"--faiss_index_path", |
|
type=str, |
|
default="data/faiss_index", |
|
help="Faiss index path.", |
|
) |
|
parser.add_argument( |
|
"--text_db_path", |
|
type=str, |
|
default="data/text_db.jsonl", |
|
help="Text database path.", |
|
) |
|
parser.add_argument( |
|
"--concurrency_limit", type=int, default=10, help="Default concurrency limit." |
|
) |
|
parser.add_argument( |
|
"--max_queue_size", type=int, default=50, help="Maximum queue size for request." |
|
) |
|
|
|
args = parser.parse_args() |
|
try: |
|
args.model_map = json.loads(args.model_map) |
|
|
|
|
|
if len(args.model_map) < 1: |
|
raise ValueError("model_map must contain at least one model configuration") |
|
except json.JSONDecodeError as e: |
|
raise ValueError("Invalid JSON format for --model_map") from e |
|
|
|
return args |
|
|
|
|
|
class FaissTextDatabase: |
|
""" |
|
A vector database for text retrieval using FAISS. |
|
Provides efficient similarity search and document management capabilities. |
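
    Example (sketch; assumes a configured ``args`` namespace and a ``BotClient`` instance):

        db = FaissTextDatabase(args, bot_client)

        db.add_embeddings("data/coffee.txt", segments, save_file=True)

        context = db.search_with_context(["咖啡馆的历史"])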
|
""" |
|
|
|
def __init__(self, args, bot_client: BotClient): |
|
""" |
|
Initialize the FaissTextDatabase. |
|
|
|
Args: |
|
            args: parsed command-line arguments used for configuration

            bot_client: BotClient instance used to generate embeddings
|
""" |
|
self.logger = logging.getLogger(__name__) |
|
|
|
self.bot_client = bot_client |
|
self.embedding_dim = getattr(args, "embedding_dim", 384) |
|
self.top_k = getattr(args, "top_k", 3) |
|
self.context_size = getattr(args, "context_size", 2) |
|
self.faiss_index_path = getattr(args, "faiss_index_path", "data/faiss_index") |
|
self.text_db_path = getattr(args, "text_db_path", "data/text_db.jsonl") |
|
|
|
|
|
if os.path.exists(self.faiss_index_path) and os.path.exists(self.text_db_path): |
|
self.index = faiss.read_index(self.faiss_index_path) |
|
with open(self.text_db_path, "r", encoding="utf-8") as f: |
|
self.text_db = json.load(f) |
|
else: |
|
self.index = faiss.IndexFlatIP(self.embedding_dim) |
|
self.text_db = { |
|
"file_md5s": [], |
|
"chunks": [], |
|
} |
|
|
|
def calculate_md5(self, file_path: str) -> str: |
|
""" |
|
Calculate the MD5 hash of a file |
|
|
|
Args: |
|
file_path: the path of the source file |
|
|
|
Returns: |
|
str: the MD5 hash |
|
""" |
|
with open(file_path, "rb") as f: |
|
return hashlib.md5(f.read()).hexdigest() |
|
|
|
def is_file_processed(self, file_path: str) -> bool: |
|
""" |
|
Check if the file has been processed before |
|
|
|
Args: |
|
file_path: the path of the source file |
|
|
|
Returns: |
|
bool: whether the file has been processed |
|
""" |
|
file_md5 = self.calculate_md5(file_path) |
|
return file_md5 in self.text_db["file_md5s"] |
|
|
|
def add_embeddings( |
|
self, |
|
file_path: str, |
|
segments: list[str], |
|
progress_bar: gr.Progress = None, |
|
save_file: bool = False, |
|
) -> bool: |
|
""" |
|
Stores document embeddings in FAISS database after checking for duplicates. |
|
Generates embeddings for each text segment, updates the FAISS index and metadata database, |
|
and persists changes to disk. Includes optional progress tracking for Gradio interfaces. |
|
|
|
Args: |
|
file_path: the path of the source file |
|
            segments: the list of text segments to embed

            progress_bar: optional Gradio progress bar for UI feedback

            save_file: whether to persist the index and metadata to disk immediately
|
|
|
Returns: |
|
bool: whether the operation was successful |
|
""" |
|
file_md5 = self.calculate_md5(file_path) |
|
if file_md5 in self.text_db["file_md5s"]: |
|
self.logger.info(f"File already processed: {file_path} (MD5: {file_md5})") |
|
return False |
|
|
|
|
|
vectors = [] |
|
file_name = os.path.basename(file_path) |
|
file_txt = "".join(file_name.split(".")[:-1])[:30] |
|
for i, segment in enumerate(segments): |
|
vectors.append(self.bot_client.embed_fn(file_txt + "\n" + segment)) |
|
if progress_bar is not None: |
|
progress_bar((i + 1) / len(segments), desc=file_name + " Processing...") |
|
vectors = np.array(vectors) |
|
self.index.add(vectors.astype("float32")) |
|
|
|
start_id = len(self.text_db["chunks"]) |
|
for i, text in enumerate(segments): |
|
self.text_db["chunks"].append( |
|
{ |
|
"file_md5": file_md5, |
|
"file_name": file_name, |
|
"file_txt": file_txt, |
|
"text": text, |
|
"vector_id": start_id + i, |
|
} |
|
) |
|
|
|
self.text_db["file_md5s"].append(file_md5) |
|
if save_file: |
|
self.save() |
|
return True |
|
|
|
def search_with_context(self, query_list: list) -> str: |
|
""" |
|
Finds the most relevant text chunks for multiple queries and includes surrounding context. |
|
Uses FAISS to find the closest matching embeddings, then retrieves adjacent chunks |
|
from the same source document to provide better context understanding. |
|
|
|
Args: |
|
query_list: list of input query strings |
|
|
|
Returns: |
|
str: the concatenated output string |
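
        Output sketch (one numbered block per merged group of adjacent chunks):

            段落1:
            <file title>
            <chunk text>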
|
""" |
|
|
|
all_indices = [] |
|
for query in query_list: |
|
query_vector = np.array([self.bot_client.embed_fn(query)]).astype("float32") |
|
            _, indices = self.index.search(query_vector, self.top_k)

            # FAISS pads results with -1 when the index holds fewer than top_k vectors
            all_indices.extend(i for i in indices[0].tolist() if i >= 0)
|
|
|
|
|
        unique_indices = sorted(set(all_indices))

        self.logger.info(f"Retrieved indices: {all_indices}")

        self.logger.info(f"Unique indices after deduplication: {unique_indices}")

        if not unique_indices:

            return ""
|
|
|
|
|
expanded_indices = set() |
|
file_boundaries = {} |
|
for target_idx in unique_indices: |
|
target_chunk = self.text_db["chunks"][target_idx] |
|
target_file_md5 = target_chunk["file_md5"] |
|
|
|
if target_file_md5 not in file_boundaries: |
|
file_start = target_idx |
|
while ( |
|
file_start > 0 |
|
and self.text_db["chunks"][file_start - 1]["file_md5"] |
|
== target_file_md5 |
|
): |
|
file_start -= 1 |
|
file_end = target_idx |
|
while ( |
|
file_end < len(self.text_db["chunks"]) - 1 |
|
and self.text_db["chunks"][file_end + 1]["file_md5"] |
|
== target_file_md5 |
|
): |
|
                    file_end += 1

                # Cache the boundaries so later hits in the same file reuse them
                file_boundaries[target_file_md5] = (file_start, file_end)

            else:

                file_start, file_end = file_boundaries[target_file_md5]
|
|
|
|
|
start = max(file_start, target_idx - self.context_size) |
|
end = min(file_end, target_idx + self.context_size) |
|
|
|
for pos in range(start, end + 1): |
|
expanded_indices.add(pos) |
|
|
|
|
|
sorted_indices = sorted(expanded_indices) |
|
groups = [] |
|
current_group = [sorted_indices[0]] |
|
for i in range(1, len(sorted_indices)): |
|
if ( |
|
sorted_indices[i] == sorted_indices[i - 1] + 1 |
|
and self.text_db["chunks"][sorted_indices[i]]["file_md5"] |
|
== self.text_db["chunks"][sorted_indices[i - 1]]["file_md5"] |
|
): |
|
current_group.append(sorted_indices[i]) |
|
else: |
|
groups.append(current_group) |
|
current_group = [sorted_indices[i]] |
|
groups.append(current_group) |
|
|
|
|
|
result = "" |
|
for idx, group in enumerate(groups): |
|
result += "\n段落{idx}:\n{title}\n".format( |
|
idx=idx + 1, title=self.text_db["chunks"][group[0]]["file_txt"] |
|
) |
|
            for chunk_idx in group:

                result += self.text_db["chunks"][chunk_idx]["text"] + "\n"
|
self.logger.info(f"Merged chunk range: {group[0]}-{group[-1]}") |
|
|
|
return result |
|
|
|
def save(self) -> None: |
|
"""Save the database to disk.""" |
|
faiss.write_index(self.index, self.faiss_index_path) |
|
|
|
with open(self.text_db_path, "w", encoding="utf-8") as f: |
|
json.dump(self.text_db, f, ensure_ascii=False, indent=2) |
|
|
|
|
|
class GradioEvents: |
|
""" |
|
Manages event handling and UI interactions for Gradio applications. |
|
Provides methods to process user inputs, trigger callbacks, and update interface components. |
|
""" |
|
|
|
@staticmethod |
|
def get_history_conversation(task_history: list) -> tuple: |
|
""" |
|
Converts task history into conversation format for model processing. |
|
Transforms query-response pairs into structured message history and plain text. |
|
|
|
Args: |
|
task_history (list): List of tuples containing queries and responses. |
|
|
|
Returns: |
|
tuple: Tuple containing two elements: |
|
- conversation (list): List of dictionaries representing the conversation history. |
|
- conversation_str (str): String representation of the conversation history. |
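
        Example:

            get_history_conversation([("hi", "hello")]) returns the message list

            [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]

            plus the same exchange flattened into a "user:/assistant:" transcript string.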
|
""" |
|
conversation = [] |
|
conversation_str = "" |
|
for query_h, response_h in task_history: |
|
conversation.append({"role": "user", "content": query_h}) |
|
conversation.append({"role": "assistant", "content": response_h}) |
|
conversation_str += f"user:\n{query_h}\n assistant:\n{response_h}\n " |
|
return conversation, conversation_str |
|
|
|
@staticmethod |
|
def chat_stream( |
|
query: str, |
|
task_history: list, |
|
model: str, |
|
faiss_db: FaissTextDatabase, |
|
bot_client: BotClient, |
|
) -> dict: |
|
""" |
|
Streams chatbot responses by processing queries with context from history and FAISS database. |
|
Integrates language model generation with knowledge retrieval to produce dynamic responses. |
|
Yields response events in real-time for interactive conversation experiences. |
|
|
|
Args: |
|
query (str): The query string. |
|
task_history (list): The task history record list. |
|
            model (str): Name of the model used to generate responses.

            faiss_db (FaissTextDatabase): The FAISS database instance.

            bot_client (BotClient): The chatbot client instance.
|
|
|
Yields: |
|
dict: A dictionary containing the event type and its corresponding content. |
|
""" |
|
conversation, conversation_str = GradioEvents.get_history_conversation( |
|
task_history |
|
) |
|
conversation_str += f"user:\n{query}\n" |
|
|
|
search_info_message = QUERY_REWRITE_PROMPT.format( |
|
TIMESTAMP=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), |
|
CONVERSATION=conversation_str, |
|
) |
|
search_conversation = [{"role": "user", "content": search_info_message}] |
|
search_info_result = GradioEvents.get_sub_query( |
|
search_conversation, model, bot_client |
|
) |
|
if search_info_result is None: |
|
search_info_result = {"query": [query]} |
|
|
|
if search_info_result.get("query", []): |
|
relevant_passages = faiss_db.search_with_context( |
|
search_info_result["query"] |
|
) |
|
yield {"type": "relevant_passage", "content": relevant_passages} |
|
|
|
query = ANSWER_PROMPT.format( |
|
DOC_CONTENT=relevant_passages, |
|
TIMESTAMP=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), |
|
CONVERSATION=conversation_str, |
|
) |
|
|
|
conversation.append({"role": "user", "content": query}) |
|
try: |
|
req_data = {"messages": conversation} |
|
for chunk in bot_client.process_stream(model, req_data): |
|
if "error" in chunk: |
|
raise Exception(chunk["error"]) |
|
|
|
message = chunk.get("choices", [{}])[0].get("delta", {}) |
|
content = message.get("content", "") |
|
|
|
if content: |
|
yield {"type": "answer", "content": content} |
|
|
|
except Exception as e: |
|
raise gr.Error("Exception: " + repr(e)) |
|
|
|
@staticmethod |
|
def predict_stream( |
|
query: str, |
|
chatbot: list, |
|
task_history: list, |
|
model: str, |
|
faiss_db: FaissTextDatabase, |
|
bot_client: BotClient, |
|
) -> tuple: |
|
""" |
|
Generates streaming responses by combining model predictions with knowledge retrieval. |
|
Processes user queries using conversation history and FAISS database context, |
|
yielding updated chat messages and relevant passages in real-time. |
|
|
|
Args: |
|
query (str): The content of the user's input query. |
|
chatbot (list): The chatbot's historical message list. |
|
task_history (list): The task history record list. |
|
            model (str): Name of the model used to generate responses.

            faiss_db (FaissTextDatabase): The FAISS database instance.

            bot_client (BotClient): The chatbot client instance.
|
|
|
Yields: |
|
tuple: A tuple containing the updated chatbot's message list and the relevant passage. |
|
""" |
|
query = query if query else QUERY_DEFAULT |
|
|
|
logging.info(f"User: {query}") |
|
chatbot.append({"role": "user", "content": query}) |
|
|
|
|
|
yield chatbot, None |
|
|
|
new_texts = GradioEvents.chat_stream( |
|
query, |
|
task_history, |
|
model, |
|
faiss_db, |
|
bot_client, |
|
) |
|
|
|
response = "" |
|
current_relevant_passage = None |
|
for new_text in new_texts: |
|
if not isinstance(new_text, dict): |
|
continue |
|
|
|
if new_text.get("type") == "embedding": |
|
current_relevant_passage = new_text["content"] |
|
yield chatbot, current_relevant_passage |
|
continue |
|
elif new_text.get("type") == "relevant_passage": |
|
current_relevant_passage = new_text["content"] |
|
yield chatbot, current_relevant_passage |
|
continue |
|
elif new_text.get("type") == "answer": |
|
response += new_text["content"] |
|
|
|
|
|
if chatbot[-1].get("role") == "assistant": |
|
chatbot.pop(-1) |
|
|
|
if response: |
|
chatbot.append({"role": "assistant", "content": response}) |
|
yield chatbot, current_relevant_passage |
|
|
|
logging.info(f"History: {task_history}") |
|
task_history.append((query, response)) |
|
logging.info(f"ERNIE models: {response}") |
|
|
|
@staticmethod |
|
def regenerate( |
|
chatbot: list, |
|
task_history: list, |
|
model: str, |
|
faiss_db: FaissTextDatabase, |
|
bot_client: BotClient, |
|
) -> tuple: |
|
""" |
|
Regenerate the chatbot's response based on the latest user query |
|
|
|
Args: |
|
chatbot (list): Chat history list |
|
task_history (list): Task history |
|
model (str): Model name to use |
|
            faiss_db (FaissTextDatabase): FAISS database instance

            bot_client (BotClient): Bot request client instance
|
|
|
Yields: |
|
tuple: Updated chatbot and relevant_passage |
|
""" |
|
if not task_history: |
|
yield chatbot, None |
|
return |
|
|
|
item = task_history.pop(-1) |
|
while len(chatbot) != 0 and chatbot[-1].get("role") == "assistant": |
|
chatbot.pop(-1) |
|
        if chatbot:  # guard against an empty history after removing assistant turns

            chatbot.pop(-1)
|
|
|
yield from GradioEvents.predict_stream( |
|
item[0], |
|
chatbot, |
|
task_history, |
|
model, |
|
faiss_db, |
|
bot_client, |
|
) |
|
|
|
@staticmethod |
|
def reset_user_input() -> gr.update: |
|
""" |
|
Reset user input box content. |
|
|
|
Returns: |
|
gr.update: An update object representing the cleared value |
|
""" |
|
return gr.update(value="") |
|
|
|
@staticmethod |
|
def reset_state() -> namedtuple: |
|
""" |
|
Reset chat state and clear all history. |
|
|
|
Returns: |
|
tuple: A named tuple containing the updated values for chatbot, task_history, file_btn, and relevant_passage |
|
""" |
|
GradioEvents.gc() |
|
|
|
reset_result = namedtuple( |
|
"reset_result", ["chatbot", "task_history", "file_btn", "relevant_passage"] |
|
) |
|
return reset_result( |
|
[], |
|
[], |
|
gr.update(value=None), |
|
gr.update(value=None), |
|
) |
|
|
|
@staticmethod |
|
def gc(): |
|
""" |
|
Force garbage collection to free memory. |
|
""" |
|
import gc |
|
|
|
gc.collect() |
|
|
|
@staticmethod |
|
def get_image_url(image_path: str) -> str: |
|
""" |
|
Encode image file to Base64 format and generate data URL. |
|
Reads an image file from disk, encodes it as Base64, and formats it |
|
as a data URL that can be used directly in HTML or API requests. |
|
|
|
Args: |
|
image_path (str): Path to the image file. Must be a valid file path. |
|
|
|
Returns: |
|
str: Data URL string in format "data:image/{ext};base64,{encoded_data}" |
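
        For example, "assets/logo.png" yields "data:image/png;base64,iVBORw0...".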
|
""" |
|
base64_image = "" |
|
        extension = image_path.split(".")[-1].lower()

        if extension == "jpg":

            # "jpg" is not a registered MIME subtype; browsers expect "jpeg"
            extension = "jpeg"
|
with open(image_path, "rb") as image_file: |
|
base64_image = base64.b64encode(image_file.read()).decode("utf-8") |
|
url = f"data:image/{extension};base64,{base64_image}" |
|
return url |
|
|
|
@staticmethod |
|
def get_sub_query( |
|
conversation: list, model_name: str, bot_client: BotClient |
|
) -> dict: |
|
""" |
|
Enhances user queries by generating alternative phrasings using language models. |
|
Creates semantically similar variations of the original query to improve retrieval accuracy. |
|
Returns structured dictionary containing both original and rephrased queries. |
|
|
|
Args: |
|
conversation (list): The conversation history. |
|
model_name (str): The name of the model to use for rephrasing. |
|
bot_client (BotClient): The bot client instance. |
|
|
|
Returns: |
|
            dict: Parsed JSON with the rewritten queries, or None if parsing fails.
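
        Expected model output per QUERY_REWRITE_PROMPT, e.g.:

            {"query": ["改写问题1", "改写问题2"]}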
|
""" |
|
req_data = {"messages": conversation} |
|
try: |
|
response = bot_client.process(model_name, req_data) |
|
search_info_res = response["choices"][0]["message"]["content"] |
|
start = search_info_res.find("{") |
|
end = search_info_res.rfind("}") + 1 |
|
if start >= 0 and end > start: |
|
search_info_res = search_info_res[start:end] |
|
search_info_res = json.loads(search_info_res) |
|
if search_info_res.get("sub_query_list", []): |
|
unique_list = list(set(search_info_res["sub_query_list"])) |
|
search_info_res["sub_query_list"] = unique_list |
|
return search_info_res |
|
except Exception: |
|
logging.error("Error: Model output is not a valid JSON") |
|
return None |
|
|
|
@staticmethod |
|
def split_oversized_line(line: str, chunk_size: int) -> tuple: |
|
""" |
|
Split a line into two parts based on punctuation marks or whitespace while preserving |
|
natural language boundaries and maintaining the original content structure. |
|
|
|
Args: |
|
line (str): The line to split. |
|
chunk_size (int): The maximum length of each chunk. |
|
|
|
Returns: |
|
tuple: Two strings, the first part of the original line and the rest of the line. |
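
        Example:

            split_oversized_line("你好。世界", 3) returns ("你好。", "世界").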
|
""" |
|
PUNCTUATIONS = { |
|
".", |
|
"。", |
|
"!", |
|
"!", |
|
"?", |
|
"?", |
|
",", |
|
",", |
|
";", |
|
";", |
|
":", |
|
":", |
|
} |
|
|
|
if len(line) <= chunk_size: |
|
return line, "" |
|
|
|
|
|
        # Prefer to split just after the last punctuation mark within the limit;
        # scanning from chunk_size - 1 keeps the head at most chunk_size characters
        split_pos = -1

        for i in range(chunk_size - 1, -1, -1):

            if line[i] in PUNCTUATIONS:

                split_pos = i + 1

                break



        # Fall back to the last whitespace, then to a hard cut that guarantees progress
        if split_pos == -1:

            split_pos = line.rfind(" ", 0, chunk_size)

            if split_pos < 1:

                split_pos = chunk_size
|
|
|
return line[:split_pos], line[split_pos:] |
|
|
|
@staticmethod |
|
def split_text_into_chunks(file_url: str, chunk_size: int) -> list: |
|
""" |
|
Split file text into chunks of a specified size while respecting natural language boundaries |
|
and avoiding mid-word splits whenever possible. |
|
|
|
Args: |
|
file_url (str): The file URL. |
|
chunk_size (int): The maximum length of each chunk. |
|
|
|
Returns: |
|
list: A list of strings, where each element represents a chunk of the original text. |
|
""" |
|
with open(file_url, "r", encoding="utf-8") as f: |
|
text = f.read() |
|
|
|
if not text: |
|
logging.error("Error: File is empty") |
|
return [] |
|
lines = [line.strip() for line in text.split("\n") if line.strip()] |
|
chunks = [] |
|
current_chunk = [] |
|
current_length = 0 |
|
|
|
for line in lines: |
|
|
|
if current_length + len(line) > chunk_size and current_chunk: |
|
chunks.append("\n".join(current_chunk)) |
|
current_chunk = [] |
|
current_length = 0 |
|
|
|
|
|
while len(line) > chunk_size: |
|
head, line = GradioEvents.split_oversized_line(line, chunk_size) |
|
chunks.append(head) |
|
|
|
|
|
if line: |
|
current_chunk.append(line) |
|
current_length += len(line) + 1 |
|
|
|
if current_chunk: |
|
chunks.append("\n".join(current_chunk)) |
|
return chunks |
|
|
|
@staticmethod |
|
def file_upload( |
|
files_url: list, |
|
chunk_size: int, |
|
faiss_db: FaissTextDatabase, |
|
progress_bar: gr.Progress = gr.Progress(), |
|
    ):
|
""" |
|
Uploads and processes multiple files by splitting them into semantically meaningful chunks, |
|
then indexes them in the FAISS database with progress tracking. |
|
|
|
Args: |
|
files_url (list): List of file URLs. |
|
chunk_size (int): Maximum chunk size. |
|
faiss_db (FaissTextDatabase): FAISS database instance. |
|
progress_bar (gr.Progress): Progress bar instance. |
|
|
|
        Yields:

            gr.update: Visibility updates for the progress bar shown during processing.
|
""" |
|
if not files_url: |
|
return |
|
yield gr.update(visible=True) |
|
for file_url in files_url: |
|
if not GradioEvents.save_file_to_db( |
|
file_url, chunk_size, faiss_db, progress_bar |
|
): |
|
file_name = os.path.basename(file_url) |
|
gr.Info(f"{file_name} already processed.") |
|
|
|
yield gr.update(visible=False) |
|
|
|
@staticmethod |
|
def save_file_to_db( |
|
file_url: str, |
|
chunk_size: int, |
|
faiss_db: FaissTextDatabase, |
|
progress_bar: gr.Progress = None, |
|
save_file: bool = False, |
|
): |
|
""" |
|
Processes and indexes document content into FAISS database with semantic-aware chunking. |
|
Handles file validation, text segmentation, embedding generation and storage operations. |
|
|
|
Args: |
|
file_url (str): File URL. |
|
chunk_size (int): Chunk size. |
|
faiss_db (FaissTextDatabase): FAISS database instance. |
|
            progress_bar (gr.Progress): Progress bar instance.

            save_file (bool): Whether to persist the index and metadata to disk.
|
|
|
Returns: |
|
bool: True if the file was saved successfully, otherwise False. |
|
""" |
|
if not os.path.exists(file_url): |
|
logging.error(f"File not found: {file_url}") |
|
return False |
|
|
|
file_name = os.path.basename(file_url) |
|
if not faiss_db.is_file_processed(file_url): |
|
logging.info(f"{file_url} not processed yet, processing now...") |
|
try: |
|
segments = GradioEvents.split_text_into_chunks(file_url, chunk_size) |
|
faiss_db.add_embeddings(file_url, segments, progress_bar, save_file) |
|
|
|
logging.info(f"{file_url} processed successfully.") |
|
return True |
|
            except Exception as e:

                logging.error(f"Error processing {file_url}: {e!s}")

                # gr.Error must be raised, not just instantiated, to surface in the UI
                raise gr.Error(f"Error processing file: {file_name}") from e
|
else: |
|
logging.info(f"{file_url} already processed.") |
|
return False |
|
|
|
|
|
def launch_demo( |
|
args: argparse.Namespace, |
|
bot_client: BotClient, |
|
faiss_db_template: FaissTextDatabase, |
|
): |
|
""" |
|
Launch demo program |
|
|
|
Args: |
|
args (argparse.Namespace): argparse Namespace object containing parsed command line arguments |
|
bot_client (BotClient): Bot client instance |
|
        faiss_db_template (FaissTextDatabase): FAISS database instance, deep-copied per session
|
""" |
|
css = """ |
|
/* Hide original Chinese text */ |
|
#file-upload .wrap { |
|
font-size: 0 !important; |
|
position: relative; |
|
display: flex; |
|
flex-direction: column; |
|
align-items: center; |
|
justify-content: center; |
|
} |
|
|
|
/* Insert English prompt text below the SVG icon */ |
|
#file-upload .wrap::after { |
|
content: "Drag and drop files here or click to upload"; |
|
font-size: 18px; |
|
color: #555; |
|
margin-top: 8px; |
|
white-space: nowrap; |
|
} |
|
""" |
|
with gr.Blocks(css=css) as demo: |
|
model_name = gr.State(next(iter(args.model_map.keys()))) |
|
faiss_db = gr.State(copy.deepcopy(faiss_db_template)) |
|
|
|
logo_url = GradioEvents.get_image_url("assets/logo.png") |
|
gr.Markdown( |
|
f"""\ |
|
<p align="center"><img src="{logo_url}" \ |
|
style="height: 60px"/><p>""" |
|
) |
|
gr.Markdown( |
|
"""\ |
|
<center><font size=3>This demo is based on ERNIE models. \ |
|
(本演示基于文心大模型实现。)</center>""" |
|
) |
|
gr.Markdown( |
|
"""\ |
|
<center><font size=3> <a href="https://ernie.baidu.com/">ERNIE Bot</a> | \ |
|
<a href="https://github.com/PaddlePaddle/ERNIE">GitHub</a> | \ |
|
<a href="https://huggingface.co/baidu">Hugging Face</a> | \ |
|
<a href="https://aistudio.baidu.com/modelsoverview">BAIDU AI Studio</a> | \ |
|
<a href="https://yiyan.baidu.com/blog/publication/">Technical Report</a></center>""" |
|
) |
|
|
|
chatbot = gr.Chatbot(label="ERNIE", type="messages") |
|
|
|
with gr.Row(equal_height=True): |
|
file_btn = gr.File( |
|
label="Knowledge Base Upload (System default will be used if none provided. Accepted formats: TXT, MD)", |
|
height="150px", |
|
file_types=[".txt", ".md"], |
|
elem_id="file-upload", |
|
file_count="multiple", |
|
) |
|
relevant_passage = gr.Textbox( |
|
label="Relevant Passage", |
|
lines=5, |
|
max_lines=5, |
|
placeholder=RELEVANT_PASSAGE_DEFAULT, |
|
interactive=False, |
|
) |
|
with gr.Row(): |
|
progress_bar = gr.Textbox(label="Progress", visible=False) |
|
|
|
query = gr.Textbox(label="Query", elem_id="text_input", value=QUERY_DEFAULT) |
|
|
|
with gr.Row(): |
|
empty_btn = gr.Button("🧹 Clear History(清除历史)") |
|
submit_btn = gr.Button("🚀 Submit(发送)", elem_id="submit-button") |
|
regen_btn = gr.Button("🤔️ Regenerate(重试)") |
|
|
|
task_history = gr.State([]) |
|
|
|
predict_with_clients = partial( |
|
GradioEvents.predict_stream, bot_client=bot_client |
|
) |
|
regenerate_with_clients = partial( |
|
GradioEvents.regenerate, bot_client=bot_client |
|
) |
|
        # No extra arguments to bind here, so use the method directly
        file_upload_with_clients = GradioEvents.file_upload
|
|
|
chunk_size = gr.State(args.chunk_size) |
|
file_btn.change( |
|
fn=file_upload_with_clients, |
|
inputs=[file_btn, chunk_size, faiss_db], |
|
outputs=[progress_bar], |
|
) |
|
query.submit( |
|
predict_with_clients, |
|
inputs=[query, chatbot, task_history, model_name, faiss_db], |
|
outputs=[chatbot, relevant_passage], |
|
show_progress=True, |
|
) |
|
query.submit(GradioEvents.reset_user_input, [], [query]) |
|
submit_btn.click( |
|
predict_with_clients, |
|
inputs=[query, chatbot, task_history, model_name, faiss_db], |
|
outputs=[chatbot, relevant_passage], |
|
show_progress=True, |
|
) |
|
submit_btn.click(GradioEvents.reset_user_input, [], [query]) |
|
empty_btn.click( |
|
GradioEvents.reset_state, |
|
outputs=[chatbot, task_history, file_btn, relevant_passage], |
|
show_progress=True, |
|
) |
|
regen_btn.click( |
|
regenerate_with_clients, |
|
inputs=[chatbot, task_history, model_name, faiss_db], |
|
outputs=[chatbot, relevant_passage], |
|
show_progress=True, |
|
) |
|
|
|
demo.queue( |
|
default_concurrency_limit=args.concurrency_limit, max_size=args.max_queue_size |
|
) |
|
demo.launch(server_port=args.server_port, server_name=args.server_name) |
|
|
|
|
|
def main(): |
|
"""Main function that runs when this script is executed.""" |
|
args = get_args() |
|
bot_client = BotClient(args) |
|
faiss_db = FaissTextDatabase(args, bot_client) |
|
|
|
|
|
GradioEvents.save_file_to_db( |
|
FILE_URL_DEFAULT, args.chunk_size, faiss_db, save_file=True |
|
) |
|
|
|
launch_demo(args, bot_client, faiss_db) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|