Spaces:
Build error
Build error
File size: 5,781 Bytes
ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 cf12b53 ae433b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import gradio as gr
from gliner import GLiNER
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import torch
import requests
import threading
from queue import Queue
import logging
import pynvml
# Configure logging: root logger at DEBUG, plus a module-level logger used below.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Initialize NVML for GPU debugging: log how many GPUs NVML reports and the
# name of each one. Any NVML failure aborts start-up with a RuntimeError.
try:
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    logger.info(f"NVML Initialized. GPU Count: {device_count}")
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        # Some pynvml releases return bytes here; decode so the log line shows
        # a plain name instead of a b'...' literal.
        if isinstance(name, bytes):
            name = name.decode("utf-8", errors="replace")
        logger.info(f"GPU {i}: {name}")
except pynvml.NVMLError as e:
    logger.error(f"NVML Initialization Failed: {str(e)}")
    # Chain explicitly so the NVML error is recorded as the cause.
    raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.") from e
# Verify CUDA: log the CUDA build version and device details when PyTorch can
# see a CUDA device; otherwise log the problem and abort start-up.
if torch.cuda.is_available():
    logger.info(f"CUDA Version: {torch.version.cuda}")
    logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    logger.info(f"Device Count: {torch.cuda.device_count()}")
else:
    logger.error("CUDA not available")
    raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
# Load legal corpus: the JSON file is iterated as a sequence of records, and
# the "text" field of each record becomes one retrievable document.
with open("legal_corpus.json", mode="r", encoding="utf-8") as f:
    corpus = json.loads(f.read())
documents = [record["text"] for record in corpus]
# An empty corpus would otherwise surface later as an opaque IndexError on
# embeddings.shape[1]; fail early with a message that names the real problem.
if not documents:
    raise RuntimeError("legal_corpus.json contains no documents to index.")

# Initialize sentence transformer (GPU) and embed every corpus document once.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
embeddings = embedder.encode(documents, convert_to_numpy=True)
# FAISS expects C-contiguous float32 matrices; normalise dtype and layout up
# front (no copy is made when the array already satisfies both).
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

# Build an exact (brute-force) L2 index over the document embeddings.
# NOTE: faiss.IndexFlatL2 constructed this way is a CPU index; only the
# embedding model above is placed on the GPU.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# Initialize GLiNER (GPU): load the Arabic checkpoint together with its
# tokenizer, then move the model onto the CUDA device.
gliner_model = GLiNER.from_pretrained(
    "NAMAA-Space/gliner_arabic-v2.1",
    load_tokenizer=True,
).cuda()
# Initialize LLM (default to Qwen2-7B-Instruct-AWQ)
use_qwq_32b = False  # Set to True if H200 detection is fixed
model_name = "Qwen/QwQ-32B" if use_qwq_32b else "Qwen/Qwen2-7B-Instruct-AWQ"
# Request AWQ only for the AWQ checkpoint. Forcing quantization="awq" on the
# non-AWQ QwQ-32B checkpoint is expected to fail at load time; passing None
# leaves the choice to the checkpoint's own configuration.
quantization = "awq" if model_name.endswith("-AWQ") else None
try:
    llm = LLM(
        model=model_name,
        quantization=quantization,
        max_model_len=4096,
        gpu_memory_utilization=0.9,
        device="cuda",
    )
    logger.info(f"Loaded LLM: {model_name}")
except Exception as e:
    # Start-up boundary: record the failure, then let it propagate unchanged.
    logger.error(f"Failed to initialize LLM: {str(e)}")
    raise

# Sampling settings shared by every generation request in this file.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
def fetch_external_legal_data(query, queue):
    """Fetch external legal data via HTTP request (mock API).

    The result is delivered through *queue* instead of being returned, so the
    function can be used as a thread target. Exactly one string is put on the
    queue on every code path, so a consumer waiting on the queue is never left
    without a result.

    Args:
        query: Text sent as the ``query`` request parameter.
        queue: Object with a ``put`` method. It receives the ``"text"`` field
            of the JSON response, ``"No external data found"`` when that field
            is missing or null, or ``"Failed to fetch external data"`` when
            the request or the JSON decoding fails.
    """
    try:
        response = requests.get(
            "https://api.example.com/legal",
            params={"query": query},
            timeout=5,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decode errors on requests versions where the
        # decode error is not a RequestException subclass.
        queue.put("Failed to fetch external data")
        return

    # A valid JSON body is not necessarily an object (it may be a list, a
    # string, ...); only objects can carry a "text" field.
    external_text = payload.get("text") if isinstance(payload, dict) else None
    if external_text is None:
        external_text = "No external data found"
    # Always hand back a str: the value is concatenated into the LLM prompt.
    queue.put(str(external_text))
def run_ner(text, entity_types, queue):
    """Run NER with gliner_arabic-v2.1.

    The result is delivered through *queue* instead of being returned, so the
    function can be used as a thread target. Exactly one list is put on the
    queue on every code path, including when the model call fails.

    Args:
        text: Input text to analyse. Empty text yields an empty result.
        entity_types: Comma-separated entity labels, e.g.
            ``"person,law,organization"``. Blank labels are ignored.
        queue: Object with a ``put`` method. It receives a list of dicts with
            keys ``"text"``, ``"label"`` and ``"score"`` (rounded to two
            decimals), or ``[]`` when there is nothing to extract or the
            model call raises.
    """
    if not text or not entity_types:
        queue.put([])
        return

    # Drop blank labels produced by stray commas or whitespace
    # ("person,,law", "law, ") so they are never passed to the model.
    entity_list = [label.strip() for label in entity_types.split(",") if label.strip()]
    if not entity_list:
        queue.put([])
        return

    try:
        entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
        results = [
            {"text": e["text"], "label": e["label"], "score": round(e["score"], 2)}
            for e in entities
        ]
    except Exception:
        # Thread boundary: an exception escaping here would leave the queue
        # empty and stall a consumer blocked on it, so log and degrade to [].
        logger.exception("NER failed")
        results = []
    queue.put(results)
def retrieve_documents(query, k=2):
    """Retrieve the top-k corpus documents closest to *query* using FAISS.

    Args:
        query: Text to embed and search for.
        k: Maximum number of documents to return. Values <= 0 return ``[]``.

    Returns:
        A list of at most ``k`` document strings, in the order FAISS returns
        them. The list is shorter when the index holds fewer than ``k``
        vectors.
    """
    if k <= 0:
        return []
    # Never request more neighbours than the index holds: FAISS pads missing
    # results with index -1, which would silently select documents[-1].
    top_k = min(k, index.ntotal)
    if top_k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, top_k)
    # Also skip any -1 padding defensively.
    return [documents[idx] for idx in indices[0] if idx >= 0]
def generate_legal_insight(text, entities, retrieved_docs, external_data):
    """Generate insight with LLM using RAG.

    Builds one prompt from the retrieved documents, the external data, the
    extracted entities and the input text, sends it to the module-level
    ``llm`` with ``sampling_params``, and returns the generated text.

    Args:
        text: The user's input text.
        entities: Iterable of dicts with ``"text"`` and ``"label"`` keys.
        retrieved_docs: Iterable of document strings, joined with newlines.
        external_data: String appended to the context after the documents.

    Returns:
        The text of the first completion of the first (and only) request.
    """
    entity_str = ", ".join(f"{e['text']} ({e['label']})" for e in entities)
    docs_section = "\n".join(retrieved_docs)
    context = docs_section + "\nExternal Data: " + external_data
    prompt = f"""You are a legal assistant for Arabic law. Using the following context, extracted entities, and external data, provide a concise legal insight.
Context:
{context}
Entities:
{entity_str}
Input Text:
{text}
Insight:"""
    # One prompt is submitted, so the single request output is at position 0.
    request_output = llm.generate([prompt], sampling_params)[0]
    first_completion = request_output.outputs[0]
    return first_completion.text
def main_interface(text, entity_types):
    """Main Gradio interface with threading.

    Runs NER and the external-data fetch in two worker threads, waits for
    both, then retrieves supporting documents and asks the LLM for an insight.

    Args:
        text: The user's input text.
        entity_types: Comma-separated entity labels for NER.

    Returns:
        A 4-tuple ``(ner_result, retrieved_docs, external_data, insight)``,
        in the order expected by the Gradio outputs.
    """
    # Local import: the file only imports Queue from the queue module.
    from queue import Empty

    ner_queue = Queue()
    external_queue = Queue()
    workers = (
        threading.Thread(target=run_ner, args=(text, entity_types, ner_queue)),
        threading.Thread(target=fetch_external_legal_data, args=(text, external_queue)),
    )
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both workers have finished, so an empty queue means the worker died
    # before producing a result. A blocking get() would then hang forever;
    # read without blocking and fall back to a neutral value instead.
    try:
        ner_result = ner_queue.get(block=False)
    except Empty:
        logger.error("NER worker finished without producing a result")
        ner_result = []
    try:
        external_data = external_queue.get(block=False)
    except Empty:
        logger.error("External-data worker finished without producing a result")
        external_data = "Failed to fetch external data"

    retrieved_docs = retrieve_documents(text)
    insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)
    return ner_result, retrieved_docs, external_data, insight
# Gradio interface: two inputs (text and entity labels), one button, and four
# outputs wired to main_interface's 4-tuple in the same order.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
    with gr.Row():
        text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization",
        )
    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    # NOTE(review): main_interface returns the retrieved documents as a list,
    # while this component is a Textbox; confirm the rendering is as intended.
    docs_output = gr.Textbox(label="Retrieved Legal Context")
    external_output = gr.Textbox(label="External Legal Data")
    insight_output = gr.Textbox(label="Legal Insight")
    submit_btn.click(
        fn=main_interface,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, external_output, insight_output],
    )

# Start the app (the stray trailing "|" after this call was not valid Python).
demo.launch()