# legal.py — Flask backend: RAG question answering, summarization, TTS, and MP3 transcription endpoints
from flask import Flask, request, jsonify, send_from_directory
import speech_recognition as sr
import datetime
import pyttsx3
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM
import faiss
import numpy as np
import pandas as pd
import json
from pydub import AudioSegment
import os
from werkzeug.utils import secure_filename
import tempfile
app = Flask(__name__, static_folder='.') # Serve static files from the current directory
# Load Hugging Face API key from environment variable
hf_token = os.environ.get("API_KEY")
if not hf_token:
    # Attempt to load from a .env file if not set in the environment
    from dotenv import load_dotenv
    load_dotenv()
    hf_token = os.environ.get("API_KEY")
if not hf_token:
    raise ValueError("Hugging Face API key not found. Please set 'API_KEY' as an environment variable or in a .env file.")
login(token=hf_token)
# QA Models
qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
# Summarization Model
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
summarizer_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
summarizer_pipeline = pipeline("summarization", model=summarizer_model, tokenizer=summarizer_tokenizer)
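# For reference: the question-answering pipeline returns a dict such as
#   {"score": 0.42, "start": 10, "end": 25, "answer": "..."}
# and the summarization pipeline returns a list like [{"summary_text": "..."}].
# (Illustrative values; only the keys are meaningful.)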
embed_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
# Load both datasets: a parquet corpus and a JSON corpus
df_parquet = pd.read_parquet("ibtehaj dataset.parquet")
corpus_parquet = df_parquet["text"].dropna().tolist()
# Load the JSON dataset
with open("pdf_data.json", "r", encoding="utf-8") as f:
    json_data = json.load(f)
# Extract non-empty text entries from the JSON dataset
corpus_json = []
for entry in json_data:
    if isinstance(entry, dict) and "text" in entry:
        text = entry["text"].strip()
        if text:
            corpus_json.append(text)
# Combine both corpora
corpus = corpus_parquet + corpus_json
# Compute embeddings
embeddings = embed_model.encode(corpus, show_progress_bar=True, batch_size=16)
# Build FAISS index
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.array(embeddings))
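# Illustrative only: index.search returns L2 distances and row indices into
# `corpus`. A hypothetical lookup would look like:
#   D, I = index.search(embed_model.encode(["example query"]), 3)
#   top_passages = [corpus[i] for i in I[0]]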
def rag_answer(question: str, k: int = 3) -> str:
    q_emb = embed_model.encode([question])
    D, I = index.search(q_emb, k)
    context = "\n\n".join(corpus[i] for i in I[0] if 0 <= i < len(corpus))
    if not context.strip():
        return "Context is empty. Try rephrasing the question."
    try:
        result = qa_pipeline(question=question, context=context)
        raw_answer = result.get("answer", "No answer found.")
        # Summarize if the extracted answer is long (>40 words or >300 characters)
        if len(raw_answer.split()) > 40 or len(raw_answer) > 300:
            summary = summarizer_pipeline(raw_answer, max_length=50, min_length=15, do_sample=False)
            summarized_answer = summary[0]['summary_text']
        else:
            summarized_answer = raw_answer
        return f"Answer: {summarized_answer}\n\n[Context Used]:\n{context[:500]}..."
    except Exception as e:
        return f"Error: {e}"
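# Example usage (hypothetical question, for illustration only):
#   print(rag_answer("Which court hears the appeal?", k=3))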
# Global TTS engine, kept as a module-level singleton so it can be reused and stopped
tts_engine = None

def init_tts_engine():
    global tts_engine
    if tts_engine is None:
        tts_engine = pyttsx3.init()
        tts_engine.setProperty('rate', 150)
        tts_engine.setProperty('volume', 1.0)
        voices = tts_engine.getProperty('voices')
        # Prefer a female voice (e.g., Microsoft Zira on Windows) when one is available
        for v in voices:
            if "zira" in v.name.lower() or "female" in v.name.lower():
                tts_engine.setProperty('voice', v.id)
                break

init_tts_engine()
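# Note (assumption about deployment): runAndWait() blocks, and a single shared
# engine is used here, so concurrent /read-aloud requests would contend for it.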
# In-memory state, simplified for the web context: one global conversation
conversation_history = []
last_question_text = ""
last_answer_text = ""

@app.route('/')
def serve_index():
    return send_from_directory('.', 'index.html')

@app.route('/<path:path>')
def serve_static_files(path):
    return send_from_directory('.', path)
@app.route('/answer', methods=['POST'])
def generate_answer_endpoint():
    global last_question_text, last_answer_text, conversation_history
    data = request.get_json()
    question = data.get('question', '').strip()
    if not question:
        return jsonify({"answer": "Please provide a question."}), 400
    last_question_text = question
    timestamp = datetime.datetime.now().strftime("%H:%M:%S")
    conversation_history.append({"role": "user", "time": timestamp, "text": question})
    ans = rag_answer(question)
    last_answer_text = ans
    conversation_history.append({"role": "bot", "time": timestamp, "text": ans})
    return jsonify({"answer": ans})
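# Example request (assuming the Flask dev server on the default port 5000;
# the question text is a placeholder):
#   curl -X POST http://localhost:5000/answer \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the limitation period for an appeal?"}'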
@app.route('/read-aloud', methods=['POST'])
def read_aloud_endpoint():
    data = request.get_json()
    text_to_read = data.get('text', '').strip()
    if not text_to_read:
        return jsonify({"status": "No text provided to read."}), 400
    temp_audio_path = None
    try:
        # Render the speech to a temporary audio file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
            temp_audio_path = fp.name
        tts_engine.save_to_file(text_to_read, temp_audio_path)
        tts_engine.runAndWait()
        # Server-side TTS is optional here: the frontend's SpeechSynthesis API
        # already handles playback, so this endpoint only confirms generation.
        # A real deployment would stream the file back with Flask's send_file.
        return jsonify({"status": "TTS audio generated (server-side)."})
    except Exception as e:
        return jsonify({"status": f"Error during TTS: {str(e)}"}), 500
    finally:
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
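# Example request:
#   curl -X POST http://localhost:5000/read-aloud \
#        -H "Content-Type: application/json" -d '{"text": "Hello there"}'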
@app.route('/upload-mp3', methods=['POST'])
def upload_mp3_endpoint():
    global last_question_text, last_answer_text, conversation_history
    if 'file' not in request.files:
        return jsonify({"message": "No file part"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"message": "No selected file"}), 400
    filename = secure_filename(file.filename)
    # Save the upload and its WAV conversion in a temporary directory
    with tempfile.TemporaryDirectory() as tmpdir:
        mp3_path = os.path.join(tmpdir, filename)
        file.save(mp3_path)
        wav_path = os.path.join(tmpdir, os.path.splitext(filename)[0] + ".wav")
        try:
            sound = AudioSegment.from_mp3(mp3_path)
            sound.export(wav_path, format="wav")
        except Exception as e:
            return jsonify({"message": f"Error converting MP3 to WAV: {e}"}), 500
        try:
            recognizer = sr.Recognizer()
            with sr.AudioFile(wav_path) as src:
                audio = recognizer.record(src)
            text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            return jsonify({"message": "Speech not understood."}), 400
        except sr.RequestError as e:
            return jsonify({"message": f"Speech recognition service error: {e}"}), 500
        # Return the transcription and let the frontend decide whether to
        # summarize it or generate an answer from it.
        return jsonify({
            "message": "MP3 transcribed successfully.",
            "transcription": text
        })
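# Example request (multipart upload; "recording.mp3" is a placeholder name):
#   curl -X POST http://localhost:5000/upload-mp3 -F "file=@recording.mp3"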
@app.route('/summarize', methods=['POST'])
def summarize_endpoint():
    data = request.get_json()
    text_to_summarize = data.get('text', '').strip()
    if not text_to_summarize:
        return jsonify({"summary": "No text provided for summarization."}), 400

    def chunk_text(text, max_chunk_size=4000):
        # Greedily pack sentences into chunks of at most max_chunk_size characters
        sentences = text.split(". ")
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # +2 accounts for the ". " separator
            if len(current_chunk) + len(sentence) + 2 < max_chunk_size:
                current_chunk += sentence + ". "
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = sentence + ". "
        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks

    try:
        # Summarize each chunk, then summarize the concatenation of the summaries.
        # Chunks longer than the model's maximum input may be truncated by the pipeline.
        chunks = chunk_text(text_to_summarize)
        summaries = [
            summarizer_pipeline(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
            for chunk in chunks
        ]
        final_input = " ".join(summaries)
        final_summary = summarizer_pipeline(final_input, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
        return jsonify({"summary": final_summary})
    except Exception as e:
        return jsonify({"summary": f"Error during summarization: {e}"}), 500
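# Example request:
#   curl -X POST http://localhost:5000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<long text to summarize>"}'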
@app.route('/history', methods=['GET'])
def get_history():
    return jsonify({"history": conversation_history})
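# Example request:
#   curl http://localhost:5000/history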
if __name__ == '__main__':
    # Expected files alongside this script:
    #   ibtehaj dataset.parquet
    #   pdf_data.json
    #   man.jpg (UI image)
    app.run(debug=True)  # debug=True enables automatic reloading on code changes