ArabicLAWLLM / app.py
ghostai1's picture
Update app.py
cf12b53 verified
raw
history blame
5.78 kB
import gradio as gr
from gliner import GLiNER
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import torch
import requests
import threading
from queue import Queue
import logging
import pynvml
# Set up logging: the root handler emits DEBUG and above, and this module logs
# through its own named logger.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
# Initialize NVML for GPU debugging and log every visible device.
# A failure here aborts startup; the error message points at the NVIDIA drivers.
try:
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    logger.info(f"NVML Initialized. GPU Count: {device_count}")
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        # Older pynvml releases return the name as bytes; decode so the log
        # shows "NVIDIA H200" rather than "b'NVIDIA H200'".
        if isinstance(name, bytes):
            name = name.decode("utf-8", errors="replace")
        logger.info(f"GPU {i}: {name}")
except pynvml.NVMLError as e:
    logger.error(f"NVML Initialization Failed: {str(e)}")
    # Chain the NVML error so the original cause stays in the traceback.
    raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.") from e
# Verify CUDA. The embedder, GLiNER and the LLM are all placed on "cuda" later
# in this file, so startup is aborted when no GPU is visible to torch.
if torch.cuda.is_available():
    logger.info(f"CUDA Version: {torch.version.cuda}")
    logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    logger.info(f"Device Count: {torch.cuda.device_count()}")
else:
    logger.error("CUDA not available")
    raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
# Load the legal corpus from the working directory. Each entry is expected to
# be a mapping with a "text" key (TODO confirm schema); only that text is
# kept, in file order, as the retrieval documents.
with open("legal_corpus.json", encoding="utf-8") as f:
    corpus = json.loads(f.read())
documents = [entry["text"] for entry in corpus]
# Embed the whole corpus once with a sentence transformer running on the GPU.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
embeddings = embedder.encode(documents, convert_to_numpy=True)
# Exact (brute-force) L2 index. Note: faiss.IndexFlatL2 is a CPU index; nothing
# here transfers it to a GPU.
dimension = embedder.get_sentence_embedding_dimension() or embeddings.shape[1]
# FAISS requires a C-contiguous float32 matrix of shape (n, dimension). The
# reshape also keeps an empty corpus (whose encoding may come back 1-D) addable.
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32).reshape(-1, dimension)
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# Load the Arabic GLiNER NER checkpoint together with its tokenizer, then move
# the model onto the GPU in the same expression.
gliner_model = GLiNER.from_pretrained(
    "NAMAA-Space/gliner_arabic-v2.1",
    load_tokenizer=True,
).cuda()
# Initialize LLM (default to Qwen2-7B-Instruct-AWQ).
use_qwq_32b = False  # Set to True if H200 detection is fixed
model_name = "Qwen/QwQ-32B" if use_qwq_32b else "Qwen/Qwen2-7B-Instruct-AWQ"
# Only the 7B checkpoint is AWQ-quantized. "Qwen/QwQ-32B" is published in full
# precision, so forcing quantization="awq" on it would make vLLM fail to load;
# None lets vLLM take the quantization (if any) from the model's own config.
llm_quantization = None if use_qwq_32b else "awq"
try:
    llm = LLM(
        model=model_name,
        quantization=llm_quantization,
        max_model_len=4096,
        gpu_memory_utilization=0.9,
        device="cuda",
    )
    logger.info(f"Loaded LLM: {model_name}")
except Exception as e:
    # Log with the full traceback, then abort startup: generate_legal_insight
    # cannot work without the model.
    logger.exception(f"Failed to initialize LLM: {e}")
    raise
# Decoding settings shared by every generate() call.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
def fetch_external_legal_data(query, queue):
    """Fetch supplementary legal text for *query* from the external (mock) API.

    Meant to run as a worker thread. Exactly one string is always put on
    *queue*, so a consumer waiting on ``queue.get()`` cannot hang:

    * the payload's "text" field on success,
    * "No external data found" when the JSON is not an object or has no
      string "text" field,
    * "Failed to fetch external data" on any HTTP, network or JSON-decoding
      error, or on an unexpected failure.

    Args:
        query: search text sent as the ``query`` URL parameter.
        queue: queue.Queue that receives the single result string.
    """
    result = "Failed to fetch external data"
    try:
        response = requests.get(
            "https://api.example.com/legal",
            params={"query": query},
            timeout=5,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers JSON decode errors on older requests versions.
        logger.warning(f"External legal data request failed: {exc}")
    else:
        text = payload.get("text") if isinstance(payload, dict) else None
        result = text if isinstance(text, str) else "No external data found"
    finally:
        # Put in `finally` so even an unexpected exception still yields a result.
        queue.put(result)
def run_ner(text, entity_types, queue):
    """Run NER with gliner_arabic-v2.1 and post the entities on *queue*.

    Meant to run as a worker thread. Exactly one list is always put on
    *queue*: an empty list on blank input or on failure. This ensures a
    consumer waiting on ``queue.get()`` cannot hang.

    Args:
        text: input text to tag.
        entity_types: comma-separated entity labels, e.g. "person,law".
            Surrounding whitespace is stripped, and blank labels (from "a,,b"
            or a trailing comma) are ignored.
        queue: queue.Queue receiving a list of
            ``{"text": ..., "label": ..., "score": ...}`` dicts, with the score
            rounded to 2 decimals.
    """
    entities = []
    try:
        labels = [label.strip() for label in (entity_types or "").split(",")]
        labels = [label for label in labels if label]
        if text and labels:
            predictions = gliner_model.predict_entities(text, labels, threshold=0.5)
            entities = [
                {"text": p["text"], "label": p["label"], "score": round(p["score"], 2)}
                for p in predictions
            ]
    except Exception:
        # Thread boundary: log the failure rather than letting the worker die
        # silently, and fall back to "no entities".
        logger.exception("NER failed")
        entities = []
    finally:
        queue.put(entities)
def retrieve_documents(query, k=2):
    """Return up to *k* corpus documents nearest to *query*, closest first.

    Searches the module-level FAISS index (exact L2 over the embedded corpus).
    *k* is clamped to the number of indexed vectors. Any -1 ids, which FAISS
    uses to pad results when fewer than *k* neighbours exist, are dropped.
    Otherwise ``documents[-1]`` would silently return the last document, or
    raise on an empty corpus.

    Args:
        query: text to embed and search for.
        k: maximum number of documents to return.

    Returns:
        List of document texts (possibly empty).
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
    _distances, indices = index.search(query_embedding, k)
    return [documents[idx] for idx in indices[0] if idx >= 0]
def _build_legal_prompt(text, entities, retrieved_docs, external_data):
    """Assemble the RAG prompt: instructions, context, entities, input text."""
    entity_str = ", ".join(f"{e['text']} ({e['label']})" for e in entities or ())
    # external_data can be None (e.g. `.get("text", ...)` on a payload whose
    # "text" is null); plain string concatenation would raise TypeError on it.
    external_str = "" if external_data is None else str(external_data)
    context = "\n".join(retrieved_docs or ()) + "\nExternal Data: " + external_str
    # Sections are newline-separated with no trailing newline after "Insight:".
    return "\n".join(
        [
            "You are a legal assistant for Arabic law. Using the following context, "
            "extracted entities, and external data, provide a concise legal insight.",
            "Context:",
            context,
            "Entities:",
            entity_str,
            "Input Text:",
            str(text),
            "Insight:",
        ]
    )


def generate_legal_insight(text, entities, retrieved_docs, external_data):
    """Generate a legal insight with the LLM using RAG.

    Args:
        text: the user's input text.
        entities: list of ``{"text", "label", ...}`` dicts from NER.
        retrieved_docs: corpus documents returned by retrieval.
        external_data: supplementary text from the external API (None is
            treated as empty).

    Returns:
        The text of the first completion vLLM produces for the prompt.
    """
    prompt = _build_legal_prompt(text, entities, retrieved_docs, external_data)
    outputs = llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text
def main_interface(text, entity_types):
    """Run the full analysis pipeline for one Gradio request.

    NER and the external-API lookup run concurrently on two worker threads.
    FAISS retrieval and LLM generation then run on the calling thread.

    Args:
        text: input legal text from the UI.
        entity_types: comma-separated entity labels for GLiNER.

    Returns:
        Tuple ``(entities, retrieved_docs, external_data, insight)``.
    """
    from queue import Empty

    ner_queue = Queue()
    external_queue = Queue()
    ner_thread = threading.Thread(target=run_ner, args=(text, entity_types, ner_queue))
    external_thread = threading.Thread(
        target=fetch_external_legal_data, args=(text, external_queue)
    )
    ner_thread.start()
    external_thread.start()
    ner_thread.join()
    external_thread.join()

    def _take(result_queue, fallback, what):
        # Both workers have exited, so an empty queue means the worker raised
        # before posting; a blocking get() would then hang this request forever.
        try:
            return result_queue.get_nowait()
        except Empty:
            logger.warning(f"{what} worker produced no result; using fallback")
            return fallback

    ner_result = _take(ner_queue, [], "NER")
    external_data = _take(external_queue, "Failed to fetch external data", "External data")
    retrieved_docs = retrieve_documents(text)
    insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)
    return ner_result, retrieved_docs, external_data, insight
# Gradio interface
def _analyze_for_ui(text, entity_types):
    """Adapt main_interface's results to the output components.

    main_interface returns the retrieved documents as a list, and a Textbox
    renders a list as its Python repr (``['...', '...']``). The documents are
    therefore joined into readable, blank-line-separated text here.
    """
    entities, retrieved_docs, external_data, insight = main_interface(text, entity_types)
    return entities, "\n\n".join(map(str, retrieved_docs)), external_data, insight


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
    with gr.Row():
        text_input = gr.Textbox(
            label="Arabic Legal Text",
            lines=5,
            placeholder="Enter Arabic legal text...",
        )
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization",
        )
    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    docs_output = gr.Textbox(label="Retrieved Legal Context")
    external_output = gr.Textbox(label="External Legal Data")
    insight_output = gr.Textbox(label="Legal Insight")
    submit_btn.click(
        fn=_analyze_for_ui,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, external_output, insight_output],
    )
demo.launch()