Spaces:
Build error
Build error
import gradio as gr | |
from gliner import GLiNER | |
from vllm import LLM, SamplingParams | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
import json | |
import torch | |
import requests | |
import threading | |
from queue import Queue | |
import logging | |
import pynvml | |
# Configure logging: root logger at DEBUG level, plus a module-level logger
# that the rest of this file uses.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Initialize NVML for GPU debugging: log the number of GPUs NVML reports and
# the name of each one. Any NVML error is logged and converted into a
# RuntimeError, which aborts start-up.
try:
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    logger.info(f"NVML Initialized. GPU Count: {device_count}")
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        logger.info(f"GPU {i}: {name}")
except pynvml.NVMLError as e:
    logger.error(f"NVML Initialization Failed: {str(e)}")
    raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.")
# Verify CUDA: refuse to start unless PyTorch can see a CUDA device, then log
# the CUDA version, the name of device 0 and the device count for diagnostics.
if not torch.cuda.is_available():
    logger.error("CUDA not available")
    raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
logger.info(f"CUDA Version: {torch.version.cuda}")
logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
logger.info(f"Device Count: {torch.cuda.device_count()}")
# Load legal corpus: the JSON file is iterated item by item and only each
# item's "text" field is kept, in file order.
with open("legal_corpus.json", "r", encoding="utf-8") as f:
    corpus = json.load(f)
documents = [item["text"] for item in corpus]
# Initialize sentence transformer (GPU) and embed the whole corpus once at
# start-up; row i of `embeddings` corresponds to documents[i].
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
embeddings = embedder.encode(documents, convert_to_numpy=True)
# Build a flat (exhaustive) L2 FAISS index over the corpus embeddings.
# NOTE(review): the earlier comment called this "FAISS-GPU", but no GPU
# resources or CPU-to-GPU index transfer is visible here - confirm whether
# the index is meant to live on the GPU.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# Initialize GLiNER (GPU): load the pretrained Arabic NER model and move it
# to the CUDA device.
gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
gliner_model = gliner_model.cuda()
# Initialize LLM (default to Qwen2-7B-Instruct-AWQ).
use_qwq_32b = False  # Set to True if H200 detection is fixed
model_name = "Qwen/Qwen2-7B-Instruct-AWQ" if not use_qwq_32b else "Qwen/QwQ-32B"
# Only the default checkpoint is AWQ-quantized. Requesting AWQ for the
# unquantized QwQ-32B weights would make model loading fail, so the
# quantization flag follows the model choice.
quantization = None if use_qwq_32b else "awq"
try:
    llm = LLM(
        model=model_name,
        quantization=quantization,
        max_model_len=4096,
        gpu_memory_utilization=0.9,
        device="cuda",
    )
    logger.info(f"Loaded LLM: {model_name}")
except Exception as e:
    # Log for the Space's build/runtime log, then re-raise: the app cannot
    # serve requests without the LLM.
    logger.error(f"Failed to initialize LLM: {str(e)}")
    raise
# Sampling settings shared by every generation request.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
def fetch_external_legal_data(query, queue):
    """Fetch external legal data via HTTP request (mock API).

    Exactly one string is always put on ``queue``, so a consumer that calls
    ``queue.get()`` after this function finishes never blocks.

    Args:
        query: Search text sent as the ``query`` URL parameter.
        queue: Queue-like object; receives the fetched text, or a fallback
            message when the request or the response body is unusable.
    """
    try:
        response = requests.get(
            "https://api.example.com/legal",
            params={"query": query},
            timeout=5,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a non-JSON body on requests versions whose
        # decoding error is not a RequestException subclass.
        queue.put("Failed to fetch external data")
        return

    # The body may be valid JSON without being an object, and "text" may be
    # null or non-string; callers concatenate the result, so only pass on str.
    external_text = payload.get("text") if isinstance(payload, dict) else None
    if isinstance(external_text, str):
        queue.put(external_text)
    else:
        queue.put("No external data found")
def run_ner(text, entity_types, queue):
    """Run NER with gliner_arabic-v2.1 and put the result on ``queue``.

    Exactly one list is always put on ``queue``, so a consumer that calls
    ``queue.get()`` after this function finishes never blocks.

    Args:
        text: Text to analyse. Empty or ``None`` yields an empty result.
        entity_types: Comma-separated label names. Surrounding whitespace is
            stripped and blank labels (e.g. from ``"a,,b"`` or a trailing
            comma) are dropped; if none remain the result is empty.
        queue: Queue-like object; receives a list of dicts with the keys
            ``"text"``, ``"label"`` and ``"score"`` (rounded to 2 decimals).
    """
    labels = [part.strip() for part in (entity_types or "").split(",")]
    labels = [label for label in labels if label]
    if not text or not labels:
        queue.put([])
        return

    try:
        entities = gliner_model.predict_entities(text, labels, threshold=0.5)
    except Exception:
        # Thread boundary: an uncaught error would leave the queue empty and
        # hang the consumer, so log it and report "no entities" instead.
        logger.exception("NER inference failed")
        queue.put([])
        return

    queue.put(
        [
            {"text": e["text"], "label": e["label"], "score": round(e["score"], 2)}
            for e in entities
        ]
    )
def retrieve_documents(query, k=2):
    """Retrieve up to ``k`` corpus documents nearest to ``query``.

    The query is embedded with the module-level ``embedder`` and searched
    against the module-level FAISS ``index``.

    Args:
        query: Text to search for.
        k: Maximum number of documents to return. Values <= 0 return ``[]``.

    Returns:
        A list of document strings in the order returned by the index search.
        It may be shorter than ``k`` when the index holds fewer vectors.
    """
    if k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, k)
    # FAISS pads missing neighbours with -1 when fewer than k vectors exist;
    # indexing documents[-1] would silently return the last document.
    return [documents[idx] for idx in indices[0] if 0 <= idx < len(documents)]
def generate_legal_insight(text, entities, retrieved_docs, external_data):
    """Generate insight with LLM using RAG.

    Builds a single prompt from the retrieved documents, the external data,
    the extracted entities and the input text, then asks the module-level
    ``llm`` for one completion using ``sampling_params``.

    Args:
        text: The user's input text.
        entities: Iterable of dicts with ``"text"`` and ``"label"`` keys.
        retrieved_docs: Iterable of context documents; each is converted with
            ``str()`` before joining.
        external_data: Extra context. ``None`` becomes an empty string and
            other non-string values are converted with ``str()``, so a
            missing or odd API payload cannot raise ``TypeError`` here.

    Returns:
        The text of the first completion for the prompt.
    """
    entity_str = ", ".join(f"{e['text']} ({e['label']})" for e in entities)
    external_text = "" if external_data is None else str(external_data)
    docs_text = "\n".join(str(doc) for doc in retrieved_docs)
    context = docs_text + "\nExternal Data: " + external_text
    prompt = (
        "You are a legal assistant for Arabic law. Using the following context, "
        "extracted entities, and external data, provide a concise legal insight.\n"
        f"Context:\n{context}\n"
        f"Entities:\n{entity_str}\n"
        f"Input Text:\n{text}\n"
        "Insight:"
    )
    outputs = llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text
def main_interface(text, entity_types):
    """Main Gradio interface with threading.

    Runs NER and the external-data fetch in parallel threads, then retrieves
    context documents and generates the LLM insight.

    Args:
        text: Input text from the UI.
        entity_types: Comma-separated entity labels from the UI.

    Returns:
        A 4-tuple ``(ner_result, retrieved_docs, external_data, insight)``.
    """
    ner_queue = Queue()
    external_queue = Queue()
    workers = [
        threading.Thread(target=run_ner, args=(text, entity_types, ner_queue)),
        threading.Thread(target=fetch_external_legal_data, args=(text, external_queue)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both producers have finished, so empty() is reliable here. A worker that
    # died before enqueuing would otherwise make a blocking get() hang the
    # request forever; fall back to the same values the workers use on failure.
    ner_result = ner_queue.get() if not ner_queue.empty() else []
    if not external_queue.empty():
        external_data = external_queue.get()
    else:
        external_data = "Failed to fetch external data"

    retrieved_docs = retrieve_documents(text)
    insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)
    return ner_result, retrieved_docs, external_data, insight
# Gradio interface: two text inputs, an "Analyze" button and four outputs.
# Clicking the button calls main_interface with (text_input, entity_types)
# and routes its four return values to the output components in order.
# NOTE(review): the nesting of the components under gr.Row() below follows
# the usual Blocks layout (inputs side by side, everything else stacked) -
# confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
    with gr.Row():
        text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization"
        )
    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    docs_output = gr.Textbox(label="Retrieved Legal Context")
    external_output = gr.Textbox(label="External Legal Data")
    insight_output = gr.Textbox(label="Legal Insight")
    submit_btn.click(
        fn=main_interface,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, external_output, insight_output]
    )
# Start the web server; this runs at import time (there is no __main__ guard).
demo.launch()