ghostai1 committed on
Commit
cf12b53
·
verified ·
1 Parent(s): 2c56c93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -33
app.py CHANGED
@@ -6,52 +6,106 @@ import faiss
6
  import numpy as np
7
  import json
8
  import torch
 
 
 
 
 
9
 
10
- # Load mock legal corpus
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  with open("legal_corpus.json", "r", encoding="utf-8") as f:
12
  corpus = json.load(f)
13
  documents = [item["text"] for item in corpus]
14
 
15
- # Initialize sentence transformer for embeddings
16
- embedder = SentenceTransformer("all-MiniLM-L6-v2") # Lightweight embedder
17
  embeddings = embedder.encode(documents, convert_to_numpy=True)
18
 
19
- # Initialize FAISS index
20
  dimension = embeddings.shape[1]
21
  index = faiss.IndexFlatL2(dimension)
22
  index.add(embeddings)
23
 
24
- # Initialize GLiNER model
25
  gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Initialize QwQ-32B
28
- llm = LLM(
29
- model="Qwen/QwQ-32B",
30
- quantization="awq",
31
- max_model_len=4096,
32
- gpu_memory_utilization=0.9
33
- )
34
  sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
35
 
36
- def retrieve_documents(query, k=2):
37
- """Retrieve top-k relevant documents using FAISS."""
38
- query_embedding = embedder.encode([query], convert_to_numpy=True)
39
- distances, indices = index.search(query_embedding, k)
40
- return [documents[idx] for idx in indices[0]]
 
 
 
 
 
 
 
41
 
42
- def run_ner(text, entity_types):
43
  """Run NER with gliner_arabic-v2.1."""
44
  if not text or not entity_types:
45
- return []
 
46
  entity_list = [e.strip() for e in entity_types.split(",")]
47
  entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
48
- return [{"text": e["text"], "label": e["label"], "score": round(e["score"], 2)} for e in entities]
49
 
50
- def generate_legal_insight(text, entities, retrieved_docs):
51
- """Generate insight with QwQ-32B using RAG."""
 
 
 
 
 
 
52
  entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
53
- context = "\n".join(retrieved_docs)
54
- prompt = f"""You are a legal assistant for Arabic law. Using the following context and extracted entities, provide a concise legal insight (e.g., summary or explanation). Ensure the response is grounded in the context and entities.
55
 
56
  Context:
57
  {context}
@@ -67,21 +121,31 @@ Insight:"""
67
  return outputs[0].outputs[0].text
68
 
69
  def main_interface(text, entity_types):
70
- """Main Gradio interface."""
71
- # Run NER
72
- ner_result = run_ner(text, entity_types)
 
 
 
 
 
 
73
 
74
- # Retrieve relevant documents
 
 
 
 
 
75
  retrieved_docs = retrieve_documents(text)
76
 
77
- # Generate legal insight
78
- insight = generate_legal_insight(text, ner_result, retrieved_docs)
79
 
80
- return ner_result, retrieved_docs, insight
81
 
82
  # Gradio interface
83
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
84
- gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and QwQ-32B")
85
  with gr.Row():
86
  text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
87
  entity_types = gr.Textbox(
@@ -92,12 +156,13 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
92
  submit_btn = gr.Button("Analyze")
93
  ner_output = gr.JSON(label="Extracted Entities")
94
  docs_output = gr.Textbox(label="Retrieved Legal Context")
 
95
  insight_output = gr.Textbox(label="Legal Insight")
96
 
97
  submit_btn.click(
98
  fn=main_interface,
99
  inputs=[text_input, entity_types],
100
- outputs=[ner_output, docs_output, insight_output]
101
  )
102
 
103
  demo.launch()
 
6
  import numpy as np
7
  import json
8
  import torch
9
+ import requests
10
+ import threading
11
+ from queue import Queue
12
+ import logging
13
+ import pynvml
14
 
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.DEBUG)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Initialize NVML for GPU debugging
20
+ try:
21
+ pynvml.nvmlInit()
22
+ device_count = pynvml.nvmlDeviceGetCount()
23
+ logger.info(f"NVML Initialized. GPU Count: {device_count}")
24
+ for i in range(device_count):
25
+ handle = pynvml.nvmlDeviceGetHandleByIndex(i)
26
+ name = pynvml.nvmlDeviceGetName(handle)
27
+ logger.info(f"GPU {i}: {name}")
28
+ except pynvml.NVMLError as e:
29
+ logger.error(f"NVML Initialization Failed: {str(e)}")
30
+ raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.")
31
+
32
+ # Verify CUDA
33
+ if not torch.cuda.is_available():
34
+ logger.error("CUDA not available")
35
+ raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
36
+ logger.info(f"CUDA Version: {torch.version.cuda}")
37
+ logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
38
+ logger.info(f"Device Count: {torch.cuda.device_count()}")
39
+
40
+ # Load legal corpus
41
  with open("legal_corpus.json", "r", encoding="utf-8") as f:
42
  corpus = json.load(f)
43
  documents = [item["text"] for item in corpus]
44
 
45
+ # Initialize sentence transformer (GPU)
46
+ embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
47
  embeddings = embedder.encode(documents, convert_to_numpy=True)
48
 
49
+ # Initialize FAISS index (IndexFlatL2 is a CPU index; it is not moved to the GPU here)
50
  dimension = embeddings.shape[1]
51
  index = faiss.IndexFlatL2(dimension)
52
  index.add(embeddings)
53
 
54
+ # Initialize GLiNER (GPU)
55
  gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
56
+ gliner_model = gliner_model.cuda()
57
+
58
+ # Initialize LLM (default to Qwen2-7B-Instruct-AWQ)
59
+ use_qwq_32b = False # Set to True if H200 detection is fixed
60
+ model_name = "Qwen/Qwen2-7B-Instruct-AWQ" if not use_qwq_32b else "Qwen/QwQ-32B"
61
+ try:
62
+ llm = LLM(
63
+ model=model_name,
64
+ quantization="awq",
65
+ max_model_len=4096,
66
+ gpu_memory_utilization=0.9,
67
+ device="cuda"
68
+ )
69
+ logger.info(f"Loaded LLM: {model_name}")
70
+ except Exception as e:
71
+ logger.error(f"Failed to initialize LLM: {str(e)}")
72
+ raise
73
 
 
 
 
 
 
 
 
74
  sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
75
 
76
+ def fetch_external_legal_data(query, queue):
77
+ """Fetch external legal data via HTTP request (mock API)."""
78
+ try:
79
+ response = requests.get(
80
+ "https://api.example.com/legal",
81
+ params={"query": query},
82
+ timeout=5
83
+ )
84
+ response.raise_for_status()
85
+ queue.put(response.json().get("text", "No external data found"))
86
+ except requests.RequestException:
87
+ queue.put("Failed to fetch external data")
88
 
89
+ def run_ner(text, entity_types, queue):
90
  """Run NER with gliner_arabic-v2.1."""
91
  if not text or not entity_types:
92
+ queue.put([])
93
+ return
94
  entity_list = [e.strip() for e in entity_types.split(",")]
95
  entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
96
+ queue.put([{"text": e["text"], "label": e["label"], "score": round(e["score"], 2)} for e in entities])
97
 
98
+ def retrieve_documents(query, k=2):
99
+ """Retrieve top-k documents using the FAISS index (a CPU IndexFlatL2)."""
100
+ query_embedding = embedder.encode([query], convert_to_numpy=True)
101
+ distances, indices = index.search(query_embedding, k)
102
+ return [documents[idx] for idx in indices[0]]
103
+
104
+ def generate_legal_insight(text, entities, retrieved_docs, external_data):
105
+ """Generate insight with LLM using RAG."""
106
  entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
107
+ context = "\n".join(retrieved_docs) + "\nExternal Data: " + external_data
108
+ prompt = f"""You are a legal assistant for Arabic law. Using the following context, extracted entities, and external data, provide a concise legal insight.
109
 
110
  Context:
111
  {context}
 
121
  return outputs[0].outputs[0].text
122
 
123
  def main_interface(text, entity_types):
124
+ """Main Gradio interface with threading."""
125
+ ner_queue = Queue()
126
+ external_queue = Queue()
127
+
128
+ ner_thread = threading.Thread(target=run_ner, args=(text, entity_types, ner_queue))
129
+ external_thread = threading.Thread(target=fetch_external_legal_data, args=(text, external_queue))
130
+
131
+ ner_thread.start()
132
+ external_thread.start()
133
 
134
+ ner_thread.join()
135
+ external_thread.join()
136
+
137
+ ner_result = ner_queue.get()
138
+ external_data = external_queue.get()
139
+
140
  retrieved_docs = retrieve_documents(text)
141
 
142
+ insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)
 
143
 
144
+ return ner_result, retrieved_docs, external_data, insight
145
 
146
  # Gradio interface
147
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
148
+ gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
149
  with gr.Row():
150
  text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
151
  entity_types = gr.Textbox(
 
156
  submit_btn = gr.Button("Analyze")
157
  ner_output = gr.JSON(label="Extracted Entities")
158
  docs_output = gr.Textbox(label="Retrieved Legal Context")
159
+ external_output = gr.Textbox(label="External Legal Data")
160
  insight_output = gr.Textbox(label="Legal Insight")
161
 
162
  submit_btn.click(
163
  fn=main_interface,
164
  inputs=[text_input, entity_types],
165
+ outputs=[ner_output, docs_output, external_output, insight_output]
166
  )
167
 
168
  demo.launch()