import gradio as gr
import os
import docx
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments, pipeline
from datasets import Dataset
import re
import logging
from datetime import datetime
import warnings
# Suppress FutureWarning from huggingface_hub
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub.file_download")
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize tokenizer and model with error handling
model_name = "aubmindlab/bert-base-arabertv2"
try:
    logger.info(f"{datetime.now()}: Loading tokenizer for {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"{datetime.now()}: Loading model for {model_name}")
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
except Exception as e:
    logger.error(f"{datetime.now()}: Failed to load model/tokenizer: {e}")
    raise
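# Note: aubmindlab/bert-base-arabertv2 ships only the pretrained encoder, so the
# question-answering head loaded here starts from fresh weights (transformers
# typically logs a "newly initialized" warning); it only becomes useful after fine-tuning.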
# Directory to save fine-tuned model
MODEL_SAVE_PATH = "./fine_tuned_model"
# Custom Arabic text preprocessing function
def preprocess_arabic_text(text):
    logger.info(f"{datetime.now()}: Preprocessing text (length: {len(text)} characters)")
    # Remove Arabic diacritics
    diacritics = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
    text = diacritics.sub('', text)
    # Normalize Arabic characters
    text = re.sub(r'[أإآ]', 'ا', text)
    text = re.sub(r'ى', 'ي', text)
    text = re.sub(r'ة', 'ه', text)
    # Remove extra spaces and non-essential characters
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s]', '', text)
    logger.info(f"{datetime.now()}: Text preprocessed, new length: {len(text)} characters")
    return text.strip()
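# Illustrative example of what the normalization above does:
#   preprocess_arabic_text("ما هو قانونُ الإيمانِ؟")  ->  "ما هو قانون الايمان"
# (diacritics and punctuation are dropped, and alef variants are unified).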
# Function to extract text from .docx
def extract_text_docx(file_path):
    logger.info(f"{datetime.now()}: Extracting text from .docx file: {file_path}")
    try:
        doc = docx.Document(file_path)
        text = "\n".join([para.text for para in doc.paragraphs if para.text.strip()])
        logger.info(f"{datetime.now()}: Successfully extracted {len(text)} characters from .docx")
        return text
    except Exception as e:
        logger.error(f"{datetime.now()}: Error extracting text from .docx: {e}")
        return ""
# Function to extract text from .pdf
def extract_text_pdf(file_path):
    logger.info(f"{datetime.now()}: Extracting text from .pdf file: {file_path}")
    try:
        text = ""
        with fitz.open(file_path) as doc:  # ensure the document handle is closed
            for page in doc:
                text += page.get_text()
        logger.info(f"{datetime.now()}: Successfully extracted {len(text)} characters from .pdf")
        return text
    except Exception as e:
        logger.error(f"{datetime.now()}: Error extracting text from .pdf: {e}")
        return ""
# Function to chunk text into roughly max_length-character segments for the dataset
def chunk_text(text, max_length=512):
    logger.info(f"{datetime.now()}: Chunking text into segments")
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        current_chunk.append(word)
        current_length += len(word) + 1  # +1 for the separating space
        if current_length >= max_length:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    logger.info(f"{datetime.now()}: Created {len(chunks)} text chunks")
    return chunks
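# Note: max_length here counts characters, not tokens; chunks of roughly 512
# characters typically stay well under the 512-token limit applied during
# tokenization below, and truncation=True guards the rest.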
# Function to prepare dataset
def prepare_dataset(text):
    logger.info(f"{datetime.now()}: Preparing dataset")
    chunks = chunk_text(text)
    data = {"text": chunks}
    dataset = Dataset.from_dict(data)
    logger.info(f"{datetime.now()}: Dataset prepared with {len(dataset)} examples")
    return dataset
# Function to tokenize dataset
def tokenize_dataset(dataset):
    logger.info(f"{datetime.now()}: Tokenizing dataset")
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    logger.info(f"{datetime.now()}: Dataset tokenized")
    return tokenized_dataset
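# Note: dataset.map() keeps the original "text" column and adds the tokenizer
# outputs (input_ids, token_type_ids, attention_mask); by default the Trainer
# later drops any columns the model's forward() does not accept.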
# Function to fine-tune model
def fine_tune_model(dataset):
    logger.info(f"{datetime.now()}: Starting model fine-tuning")
    # NOTE: a question-answering head only computes a loss when the batch carries
    # start_positions/end_positions; the unlabeled chunks produced above do not
    # include them (see the labelled-example sketch below).
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        save_steps=10_000,
        save_total_limit=2,
        logging_dir='./logs',
        logging_steps=200,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    model.save_pretrained(MODEL_SAVE_PATH)
    tokenizer.save_pretrained(MODEL_SAVE_PATH)
    logger.info(f"{datetime.now()}: Model fine-tuned and saved to {MODEL_SAVE_PATH}")
# Function to handle file upload and training
def upload_and_train(files, progress=gr.Progress()):
    uploaded_files = []
    all_text = ""
    training_log = []
    def log_and_update(step, desc, progress_value):
        msg = f"{datetime.now()}: {step} - {desc}"
        logger.info(msg)
        training_log.append(msg)
        progress(progress_value, desc=desc)
        return "\n".join(training_log)
    log_and_update("Starting upload", "Loading books...", 0.1)
    for file in files:
        file_name = os.path.basename(file.name)
        uploaded_files.append(file_name)
        if file_name.endswith(".docx"):
            text = extract_text_docx(file.name)
        elif file_name.endswith(".pdf"):
            text = extract_text_pdf(file.name)
        else:
            continue  # skip unsupported file types
        all_text += text + "\n"
    if not all_text.strip():
        msg = f"{datetime.now()}: No valid text extracted from uploaded files."
        logger.error(msg)
        training_log.append(msg)
        return "\n".join(training_log), ", ".join(uploaded_files), ""
    log_and_update("Text extraction complete", "Extracting ideas...", 0.4)
    cleaned_text = preprocess_arabic_text(all_text)
    log_and_update("Preprocessing complete", "Preparing dataset...", 0.6)
    dataset = prepare_dataset(cleaned_text)
    tokenized_dataset = tokenize_dataset(dataset)
    log_and_update("Dataset preparation complete", "Training in progress...", 0.8)
    fine_tune_model(tokenized_dataset)
    log_and_update("Training complete", "Training completed!", 1.0)
    # Example QA
    qa_pipeline = pipeline("question-answering", model=MODEL_SAVE_PATH, tokenizer=MODEL_SAVE_PATH)
    example_question = "ما هو قانون الإيمان وفقًا للكتاب؟"  # "What is the law of faith according to the book?"
    example_answer = qa_pipeline(question=example_question, context=cleaned_text[:512])["answer"]
    final_message = (
        f"Training process finished: Enter your question\n\n"
        f"**مثال لسؤال**: {example_question}\n"
        f"**الإجابة**: {example_answer}\n\n"
        f"**سجل التدريب**:\n" + "\n".join(training_log)
    )
    # Also return the cleaned text so later questions have a context to search
    return final_message, ", ".join(uploaded_files), cleaned_text
# Function to answer questions
def answer_question(question, context):
    if not os.path.exists(MODEL_SAVE_PATH):
        return "النظام لم يتم تدريبه بعد. الرجاء رفع الكتب وتدريب النظام أولاً."  # "The system has not been trained yet."
    if not context.strip():
        return "لا يوجد نص متاح للإجابة. الرجاء رفع الكتب وتدريب النظام أولاً."  # guard: the QA pipeline cannot run on an empty context
    qa_pipeline = pipeline("question-answering", model=MODEL_SAVE_PATH, tokenizer=MODEL_SAVE_PATH)
    answer = qa_pipeline(question=question, context=context[:512])["answer"]
    return answer
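# Sketch (assumption, not used by the UI): the call above only searches the first
# 512 characters of the stored context; scoring every chunk and keeping the
# highest-scoring span would cover the whole book.
def answer_question_over_chunks(question, context):
    qa_pipeline = pipeline("question-answering", model=MODEL_SAVE_PATH, tokenizer=MODEL_SAVE_PATH)
    best = {"score": 0.0, "answer": ""}
    for chunk in chunk_text(context):
        result = qa_pipeline(question=question, context=chunk)
        if result["score"] > best["score"]:
            best = result
    return best["answer"]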
# Gradio Interface with Tabs
with gr.Blocks(title="Arabic Book Analysis AI") as demo:
    gr.Markdown("# نظام ذكاء اصطناعي لتحليل الكتب باللغة العربية")  # "An AI system for analysing Arabic-language books"
    # Holds the preprocessed book text so the QA handlers have a context to search
    context_state = gr.State(value="")
    with gr.Tabs():
        with gr.TabItem("التدريب والسؤال"):  # "Training and questions"
            with gr.Row():
                with gr.Column():
                    file_upload = gr.File(file_types=[".docx", ".pdf"], label="رفع الكتب", file_count="multiple")
                    upload_button = gr.Button("رفع وتدريب")
                    uploaded_files = gr.Textbox(label="الكتب المرفوعة")
                with gr.Column():
                    training_status = gr.Textbox(label="حالة التدريب", lines=10)
            with gr.Row():
                question_input = gr.Textbox(label="أدخل سؤالك بالعربية", placeholder="مثال: ما هو قانون الإيمان؟")
                answer_output = gr.Textbox(label="الإجابة")
            ask_button = gr.Button("طرح السؤال")
            # Event handlers
            upload_button.click(
                fn=upload_and_train,
                inputs=[file_upload],
                outputs=[training_status, uploaded_files, context_state]
            )
            ask_button.click(
                fn=answer_question,
                inputs=[question_input, context_state],
                outputs=[answer_output]
            )
        with gr.TabItem("طرح الأسئلة فقط"):  # "Ask questions only"
            gr.Markdown("أدخل سؤالك بالعربية وسيتم الإجابة بناءً على محتوى الكتب المدربة.")
            question_input_qa = gr.Textbox(label="أدخل سؤالك", placeholder="مثال: ما هو قانون الإيمان؟")
            answer_output_qa = gr.Textbox(label="الإجابة")
            ask_button_qa = gr.Button("طرح السؤال")
            ask_button_qa.click(
                fn=answer_question,
                inputs=[question_input_qa, context_state],
                outputs=[answer_output_qa]
            )
if __name__ == "__main__":
    demo.launch()