File size: 5,781 Bytes
ae433b0
 
 
 
 
 
 
 
cf12b53
 
 
 
 
ae433b0
cf12b53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae433b0
 
 
 
cf12b53
 
ae433b0
 
cf12b53
ae433b0
 
 
 
cf12b53
ae433b0
cf12b53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae433b0
 
 
cf12b53
 
 
 
 
 
 
 
 
 
 
 
ae433b0
cf12b53
ae433b0
 
cf12b53
 
ae433b0
 
cf12b53
ae433b0
cf12b53
 
 
 
 
 
 
 
ae433b0
cf12b53
 
ae433b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf12b53
 
 
 
 
 
 
 
 
ae433b0
cf12b53
 
 
 
 
 
ae433b0
 
cf12b53
ae433b0
cf12b53
ae433b0
 
 
cf12b53
ae433b0
 
 
 
 
 
 
 
 
 
cf12b53
ae433b0
 
 
 
 
cf12b53
ae433b0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import json
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

import gradio as gr
from gliner import GLiNER
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch
import requests
import pynvml

# Module logger, plus process-wide root logging at DEBUG verbosity.
logger = logging.getLogger(__name__)
logging.basicConfig(level="DEBUG")

# Initialize NVML and log every GPU the driver reports. NVML failure aborts
# start-up with a RuntimeError that points at the NVIDIA driver.
try:
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    logger.info(f"NVML Initialized. GPU Count: {device_count}")
    for i in range(device_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        name = pynvml.nvmlDeviceGetName(handle)
        # Some pynvml versions return bytes here; decode so the log shows the
        # plain device name rather than a b'...' literal.
        if isinstance(name, bytes):
            name = name.decode("utf-8", errors="replace")
        logger.info(f"GPU {i}: {name}")
except pynvml.NVMLError as e:
    logger.error(f"NVML Initialization Failed: {str(e)}")
    # Chain the NVML error so its details survive in the traceback.
    raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.") from e

# Fail fast unless PyTorch can see a CUDA device; otherwise log what it found.
if torch.cuda.is_available():
    logger.info(f"CUDA Version: {torch.version.cuda}")
    logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
    logger.info(f"Device Count: {torch.cuda.device_count()}")
else:
    logger.error("CUDA not available")
    raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")

# Read the corpus file and keep only each record's "text" field; these strings
# are what gets embedded and returned by retrieval.
with open("legal_corpus.json", encoding="utf-8") as corpus_file:
    corpus = json.load(corpus_file)
documents = [record["text"] for record in corpus]

# Embed every corpus document once at start-up. Refuse an empty corpus up front
# instead of failing later with an opaque IndexError on embeddings.shape[1].
if not documents:
    raise RuntimeError("legal_corpus.json contains no documents to index.")

# Sentence-transformer embedder, placed on the GPU.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
# FAISS's Python API expects a C-contiguous float32 matrix.
embeddings = np.ascontiguousarray(
    embedder.encode(documents, convert_to_numpy=True), dtype="float32"
)

# Exact (brute-force) L2 index. NOTE: IndexFlatL2 lives in CPU memory; a GPU
# index would need a faiss-gpu build plus faiss.index_cpu_to_gpu.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Load the Arabic GLiNER checkpoint together with its tokenizer, then move it to the GPU.
gliner_model = GLiNER.from_pretrained(
    "NAMAA-Space/gliner_arabic-v2.1",
    load_tokenizer=True,
).cuda()

# Initialize the vLLM engine (default: Qwen2-7B-Instruct-AWQ).
use_qwq_32b = False  # Set to True if H200 detection is fixed
model_name = "Qwen/QwQ-32B" if use_qwq_32b else "Qwen/Qwen2-7B-Instruct-AWQ"
try:
    llm = LLM(
        model=model_name,
        # Only the 7B checkpoint is AWQ-quantized; "Qwen/QwQ-32B" is the
        # unquantized release, so forcing AWQ there would break engine init.
        quantization=None if use_qwq_32b else "awq",
        max_model_len=4096,
        gpu_memory_utilization=0.9,
        device="cuda",
    )
    logger.info(f"Loaded LLM: {model_name}")
except Exception as e:
    # logger.exception keeps the traceback; re-raise so start-up still aborts.
    logger.exception(f"Failed to initialize LLM: {str(e)}")
    raise

sampling_params = SamplingParams(temperature=0.7, max_tokens=512)

def fetch_external_legal_data(query, queue):
    """Fetch supplementary legal text for *query* from the external (mock) API.

    Exactly one string is put on *queue* on every path:

    - the response's ``"text"`` field (converted to ``str``);
    - ``"No external data found"`` when that field is missing or null, or the
      body is not a JSON object;
    - ``"Failed to fetch external data"`` when the request fails or the body is
      not valid JSON.

    Always producing a value means a reader waiting on the queue is never left
    without one.

    Args:
        query: Search string sent as the ``query`` URL parameter.
        queue: Queue that receives the single result string.
    """
    try:
        response = requests.get(
            "https://api.example.com/legal",
            params={"query": query},
            timeout=5,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers non-JSON bodies on requests versions whose
        # JSONDecodeError is not a RequestException subclass.
        logger.warning(f"External legal data request failed: {exc}")
        queue.put("Failed to fetch external data")
        return

    # A JSON list/scalar has no .get(); treat it like a missing field instead
    # of letting AttributeError escape without a value on the queue.
    text = payload.get("text") if isinstance(payload, dict) else None
    # The result is used as prompt text, so always hand back a string.
    queue.put("No external data found" if text is None else str(text))

def run_ner(text, entity_types, queue):
    """Extract entities from *text* with the GLiNER Arabic model.

    Args:
        text: Text to tag.
        entity_types: Comma-separated entity labels, e.g. ``"person,law"``.
        queue: Receives exactly one list of ``{"text", "label", "score"}``
            dicts (score rounded to two decimals). The list is ``[]`` when
            *text* is empty or no non-blank label was given.
    """
    # Split on commas and drop blanks left by stray or trailing commas
    # ("person,,law,") so GLiNER is never asked for an empty label.
    labels = [part.strip() for part in (entity_types or "").split(",")]
    labels = [label for label in labels if label]
    if not text or not labels:
        queue.put([])
        return
    entities = gliner_model.predict_entities(text, labels, threshold=0.5)
    queue.put(
        [
            {"text": e["text"], "label": e["label"], "score": round(e["score"], 2)}
            for e in entities
        ]
    )

def retrieve_documents(query, k=2):
    """Return up to *k* corpus documents closest to *query*.

    The query is embedded with the module-level ``embedder`` and searched in the
    ``IndexFlatL2`` ``index``; results come back nearest first.

    Args:
        query: Text to search for.
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``min(k, index.ntotal)`` document strings.
    """
    # Never ask for more neighbours than the index holds.
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, k)
    # FAISS pads missing hits with -1; skip them rather than letting
    # documents[-1] silently return the last corpus entry.
    return [documents[idx] for idx in indices[0] if idx >= 0]

def generate_legal_insight(text, entities, retrieved_docs, external_data):
    """Generate a legal insight for *text* with the LLM (retrieval-augmented).

    Args:
        text: The user's input text.
        entities: Entity dicts with ``"text"`` and ``"label"`` keys.
        retrieved_docs: Corpus passages included as context.
        external_data: Supplementary text from the external lookup.

    Returns:
        The generated insight with surrounding whitespace removed.
    """
    entity_str = ", ".join(f"{e['text']} ({e['label']})" for e in entities)
    # str() keeps a non-string external payload from raising TypeError here.
    context = "\n".join(retrieved_docs) + "\nExternal Data: " + str(external_data)
    prompt = f"""You are a legal assistant for Arabic law. Using the following context, extracted entities, and external data, provide a concise legal insight.

Context:
{context}

Entities:
{entity_str}

Input Text:
{text}

Insight:"""
    # Instruct-tuned checkpoints are trained on their chat template; wrapping the
    # prompt in it supplies the role markers a raw completion prompt lacks.
    chat_prompt = llm.get_tokenizer().apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    outputs = llm.generate([chat_prompt], sampling_params)
    return outputs[0].outputs[0].text.strip()

def main_interface(text, entity_types):
    """Run the full analysis pipeline for one Gradio request.

    NER and the external HTTP lookup run concurrently on worker threads.
    Document retrieval and LLM generation then run on the calling thread.

    Args:
        text: Arabic legal text entered by the user.
        entity_types: Comma-separated entity labels for NER.

    Returns:
        ``(entities, retrieved_docs, external_data, insight)``, in the order of
        the Gradio output components.
    """

    def _collect(worker, *args):
        # The workers report through a Queue. Run one and return the value it
        # put. If the worker raises, the exception surfaces via Future.result()
        # instead of leaving this request blocked forever on queue.get().
        result_queue = Queue()
        worker(*args, result_queue)
        return result_queue.get_nowait()

    with ThreadPoolExecutor(max_workers=2) as pool:
        ner_future = pool.submit(_collect, run_ner, text, entity_types)
        external_future = pool.submit(_collect, fetch_external_legal_data, text)
        ner_result = ner_future.result()
        external_data = external_future.result()

    retrieved_docs = retrieve_documents(text)

    insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)

    return ner_result, retrieved_docs, external_data, insight

# Gradio UI. The two inputs (text, comma-separated entity labels) feed
# main_interface, whose four return values fill the four outputs in order.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")

    with gr.Row():
        text_input = gr.Textbox(
            label="Arabic Legal Text",
            lines=5,
            placeholder="Enter Arabic legal text...",
        )
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization",
        )

    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    # Components are created left to right, so layout order is unchanged.
    docs_output, external_output, insight_output = (
        gr.Textbox(label=caption)
        for caption in ("Retrieved Legal Context", "External Legal Data", "Legal Insight")
    )

    submit_btn.click(
        fn=main_interface,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, external_output, insight_output],
    )

demo.launch()