Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -6,52 +6,106 @@ import faiss
|
|
6 |
import numpy as np
|
7 |
import json
|
8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
with open("legal_corpus.json", "r", encoding="utf-8") as f:
|
12 |
corpus = json.load(f)
|
13 |
documents = [item["text"] for item in corpus]
|
14 |
|
15 |
-
# Initialize sentence transformer
|
16 |
-
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
17 |
embeddings = embedder.encode(documents, convert_to_numpy=True)
|
18 |
|
19 |
-
# Initialize FAISS
|
20 |
dimension = embeddings.shape[1]
|
21 |
index = faiss.IndexFlatL2(dimension)
|
22 |
index.add(embeddings)
|
23 |
|
24 |
-
# Initialize GLiNER
|
25 |
gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
# Initialize QwQ-32B
|
28 |
-
llm = LLM(
|
29 |
-
model="Qwen/QwQ-32B",
|
30 |
-
quantization="awq",
|
31 |
-
max_model_len=4096,
|
32 |
-
gpu_memory_utilization=0.9
|
33 |
-
)
|
34 |
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
|
35 |
|
36 |
-
def
|
37 |
-
"""
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
def run_ner(text, entity_types):
|
43 |
"""Run NER with gliner_arabic-v2.1."""
|
44 |
if not text or not entity_types:
|
45 |
-
|
|
|
46 |
entity_list = [e.strip() for e in entity_types.split(",")]
|
47 |
entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
|
48 |
-
|
49 |
|
50 |
-
def
|
51 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
|
53 |
-
context = "\n".join(retrieved_docs)
|
54 |
-
prompt = f"""You are a legal assistant for Arabic law. Using the following context
|
55 |
|
56 |
Context:
|
57 |
{context}
|
@@ -67,21 +121,31 @@ Insight:"""
|
|
67 |
return outputs[0].outputs[0].text
|
68 |
|
69 |
def main_interface(text, entity_types):
|
70 |
-
"""Main Gradio interface."""
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
retrieved_docs = retrieve_documents(text)
|
76 |
|
77 |
-
|
78 |
-
insight = generate_legal_insight(text, ner_result, retrieved_docs)
|
79 |
|
80 |
-
return ner_result, retrieved_docs, insight
|
81 |
|
82 |
# Gradio interface
|
83 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
84 |
-
gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and
|
85 |
with gr.Row():
|
86 |
text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
|
87 |
entity_types = gr.Textbox(
|
@@ -92,12 +156,13 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
|
92 |
submit_btn = gr.Button("Analyze")
|
93 |
ner_output = gr.JSON(label="Extracted Entities")
|
94 |
docs_output = gr.Textbox(label="Retrieved Legal Context")
|
|
|
95 |
insight_output = gr.Textbox(label="Legal Insight")
|
96 |
|
97 |
submit_btn.click(
|
98 |
fn=main_interface,
|
99 |
inputs=[text_input, entity_types],
|
100 |
-
outputs=[ner_output, docs_output, insight_output]
|
101 |
)
|
102 |
|
103 |
demo.launch()
|
|
|
6 |
import numpy as np
|
7 |
import json
|
8 |
import torch
|
9 |
+
import requests
|
10 |
+
import threading
|
11 |
+
from queue import Queue
|
12 |
+
import logging
|
13 |
+
import pynvml
|
14 |
|
15 |
+
# Configure logging
|
16 |
+
logging.basicConfig(level=logging.DEBUG)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
# Initialize NVML for GPU debugging
|
20 |
+
try:
|
21 |
+
pynvml.nvmlInit()
|
22 |
+
device_count = pynvml.nvmlDeviceGetCount()
|
23 |
+
logger.info(f"NVML Initialized. GPU Count: {device_count}")
|
24 |
+
for i in range(device_count):
|
25 |
+
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
26 |
+
name = pynvml.nvmlDeviceGetName(handle)
|
27 |
+
logger.info(f"GPU {i}: {name}")
|
28 |
+
except pynvml.NVMLError as e:
|
29 |
+
logger.error(f"NVML Initialization Failed: {str(e)}")
|
30 |
+
raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.")
|
31 |
+
|
32 |
+
# Verify CUDA
|
33 |
+
if not torch.cuda.is_available():
|
34 |
+
logger.error("CUDA not available")
|
35 |
+
raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
|
36 |
+
logger.info(f"CUDA Version: {torch.version.cuda}")
|
37 |
+
logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
|
38 |
+
logger.info(f"Device Count: {torch.cuda.device_count()}")
|
39 |
+
|
40 |
+
# Load legal corpus
|
41 |
with open("legal_corpus.json", "r", encoding="utf-8") as f:
|
42 |
corpus = json.load(f)
|
43 |
documents = [item["text"] for item in corpus]
|
44 |
|
45 |
+
# Initialize sentence transformer (GPU)
|
46 |
+
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
|
47 |
embeddings = embedder.encode(documents, convert_to_numpy=True)
|
48 |
|
49 |
+
# Initialize FAISS index (faiss.IndexFlatL2 is a CPU flat L2 index; it is not moved to the GPU here)
|
50 |
dimension = embeddings.shape[1]
|
51 |
index = faiss.IndexFlatL2(dimension)
|
52 |
index.add(embeddings)
|
53 |
|
54 |
+
# Initialize GLiNER (GPU)
|
55 |
gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
|
56 |
+
gliner_model = gliner_model.cuda()
|
57 |
+
|
58 |
+
# Initialize LLM (default to Qwen2-7B-Instruct-AWQ)
|
59 |
+
use_qwq_32b = False # Set to True if H200 detection is fixed
|
60 |
+
model_name = "Qwen/Qwen2-7B-Instruct-AWQ" if not use_qwq_32b else "Qwen/QwQ-32B"
|
61 |
+
try:
|
62 |
+
llm = LLM(
|
63 |
+
model=model_name,
|
64 |
+
quantization="awq",
|
65 |
+
max_model_len=4096,
|
66 |
+
gpu_memory_utilization=0.9,
|
67 |
+
device="cuda"
|
68 |
+
)
|
69 |
+
logger.info(f"Loaded LLM: {model_name}")
|
70 |
+
except Exception as e:
|
71 |
+
logger.error(f"Failed to initialize LLM: {str(e)}")
|
72 |
+
raise
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
|
75 |
|
76 |
+
def fetch_external_legal_data(query, queue):
    """Fetch external legal data over HTTP and publish the result on ``queue``.

    Exactly one string is always put on ``queue``, even when the request
    fails or the response is not the expected JSON object. The caller reads
    the queue after joining this worker, so a code path that puts nothing
    would leave it waiting forever.

    Args:
        query: Text sent as the ``query`` request parameter.
        queue: Object with a ``put`` method (e.g. ``queue.Queue``) that
            receives the result string.
    """
    try:
        response = requests.get(
            "https://api.example.com/legal",
            params={"query": query},
            timeout=5,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a non-JSON body on requests versions whose JSON
        # decode error does not derive from RequestException.
        logger.exception("External legal data request failed")
        queue.put("Failed to fetch external data")
        return

    # The payload is only useful when it is a JSON object carrying "text".
    value = payload.get("text") if isinstance(payload, dict) else None
    # Downstream code concatenates this value into a prompt, so always hand
    # back a string.
    queue.put("No external data found" if value is None else str(value))
|
88 |
|
89 |
+
def run_ner(text, entity_types, queue):
    """Run NER with gliner_arabic-v2.1 and publish the entities on ``queue``.

    Exactly one list is always put on ``queue``. If the model call fails,
    an empty list is published and the error is logged, so that a caller
    waiting on the queue cannot block forever.

    Args:
        text: Input text to analyse.
        entity_types: Comma-separated entity labels, e.g. ``"person, court"``.
        queue: Object with a ``put`` method (e.g. ``queue.Queue``) that
            receives a list of ``{"text", "label", "score"}`` dicts, with
            the score rounded to two decimals.
    """
    if not text or not entity_types:
        queue.put([])
        return

    # Drop blank labels produced by inputs such as "a,,b" or a trailing comma.
    entity_list = [label.strip() for label in entity_types.split(",") if label.strip()]
    if not entity_list:
        queue.put([])
        return

    try:
        entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
        results = [
            {"text": e["text"], "label": e["label"], "score": round(e["score"], 2)}
            for e in entities
        ]
    except Exception:
        # Thread boundary: an escaping exception would kill the worker
        # without a result ever reaching the queue.
        logger.exception("NER prediction failed")
        results = []
    queue.put(results)
|
97 |
|
98 |
+
def retrieve_documents(query, k=2):
    """Retrieve up to ``k`` documents nearest to ``query`` from the FAISS index.

    Args:
        query: Text to embed and search for.
        k: Maximum number of documents to return.

    Returns:
        A list of document strings in the order returned by the index
        search. It may hold fewer than ``k`` items when the index has
        fewer entries, and is empty when ``k`` is not positive.
    """
    if k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _distances, indices = index.search(query_embedding, k)
    # FAISS fills missing neighbours with -1; without this filter
    # documents[-1] would silently return the last document.
    return [documents[idx] for idx in indices[0] if 0 <= idx < len(documents)]
|
103 |
+
|
104 |
+
def generate_legal_insight(text, entities, retrieved_docs, external_data):
|
105 |
+
"""Generate insight with LLM using RAG."""
|
106 |
entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
|
107 |
+
context = "\n".join(retrieved_docs) + "\nExternal Data: " + external_data
|
108 |
+
prompt = f"""You are a legal assistant for Arabic law. Using the following context, extracted entities, and external data, provide a concise legal insight.
|
109 |
|
110 |
Context:
|
111 |
{context}
|
|
|
121 |
return outputs[0].outputs[0].text
|
122 |
|
123 |
def main_interface(text, entity_types):
    """Gradio handler: run NER and the external lookup in threads, then RAG.

    NER and the external HTTP lookup run concurrently. After both threads
    have been joined, documents are retrieved and passed with the other
    results to the LLM.

    Args:
        text: Legal text entered by the user.
        entity_types: Comma-separated entity labels for NER.

    Returns:
        Tuple ``(ner_result, retrieved_docs, external_data, insight)``.
    """
    ner_queue = Queue()
    external_queue = Queue()

    workers = [
        threading.Thread(target=run_ner, args=(text, entity_types, ner_queue)),
        threading.Thread(target=fetch_external_legal_data, args=(text, external_queue)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both workers have finished, so an empty queue means the worker died
    # before publishing. Fall back instead of blocking forever on get().
    ner_result = [] if ner_queue.empty() else ner_queue.get_nowait()
    if external_queue.empty():
        external_data = "Failed to fetch external data"
    else:
        external_data = external_queue.get_nowait()

    retrieved_docs = retrieve_documents(text)
    insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)

    return ner_result, retrieved_docs, external_data, insight
|
145 |
|
146 |
# Gradio interface
|
147 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
148 |
+
gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
|
149 |
with gr.Row():
|
150 |
text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
|
151 |
entity_types = gr.Textbox(
|
|
|
156 |
submit_btn = gr.Button("Analyze")
|
157 |
ner_output = gr.JSON(label="Extracted Entities")
|
158 |
docs_output = gr.Textbox(label="Retrieved Legal Context")
|
159 |
+
external_output = gr.Textbox(label="External Legal Data")
|
160 |
insight_output = gr.Textbox(label="Legal Insight")
|
161 |
|
162 |
submit_btn.click(
|
163 |
fn=main_interface,
|
164 |
inputs=[text_input, entity_types],
|
165 |
+
outputs=[ner_output, docs_output, external_output, insight_output]
|
166 |
)
|
167 |
|
168 |
demo.launch()
|