File size: 3,500 Bytes
ae433b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from gliner import GLiNER
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import torch

# Load the mock legal corpus: each entry is expected to carry its passage
# under the "text" key.
with open("legal_corpus.json", "r", encoding="utf-8") as f:
    corpus = json.load(f)
documents = [item["text"] for item in corpus]

# Fail fast with a clear message: with no documents there is nothing to embed
# or index, and reading the embedding dimension below would otherwise fail
# with an opaque indexing error.
if not documents:
    raise ValueError("legal_corpus.json contains no documents to index")

# Initialize sentence transformer for embeddings and embed every passage once
# at start-up.
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # Lightweight embedder
embeddings = embedder.encode(documents, convert_to_numpy=True)

# Initialize FAISS index (flat L2). Vectors are added in corpus order, so the
# label FAISS returns for a hit is the position of the passage in `documents`.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Initialize GLiNER model (Arabic checkpoint used by run_ner)
gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)

# Initialize QwQ-32B through vLLM.
# quantization="awq" needs a checkpoint that was actually quantized with AWQ,
# so load the AWQ release of QwQ-32B rather than the full-precision
# "Qwen/QwQ-32B" weights, which ship without an AWQ quantization config.
llm = LLM(
    model="Qwen/QwQ-32B-AWQ",
    quantization="awq",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
)
# Shared sampling settings for every generation request.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)

def retrieve_documents(query, k=2):
    """Retrieve the top-k corpus passages closest to ``query`` using FAISS.

    Args:
        query: Text to embed and search with.
        k: Maximum number of passages to return.

    Returns:
        A list of at most ``k`` passages from ``documents``, in the order
        FAISS reports them. Empty when ``k`` is not positive or the index
        holds no vectors.
    """
    # Check k first so a non-positive request never reaches FAISS, which
    # rejects k <= 0.
    if k <= 0 or index.ntotal == 0:
        return []
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with the label -1, and documents[-1] would silently return the
    # last passage instead of "no result".
    k = min(k, index.ntotal)
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, k)
    # Defensive filter: skip any padding label that still comes back.
    return [documents[idx] for idx in indices[0] if idx >= 0]

def run_ner(text, entity_types):
    """Run NER with gliner_arabic-v2.1.

    Args:
        text: Input text to scan for entities.
        entity_types: Comma-separated entity labels, e.g. "person,law".

    Returns:
        A list of dicts with keys "text", "label" and "score" (rounded to
        two decimals), one per predicted entity. Empty when the text is
        missing or blank, or when no usable label was supplied.
    """
    if not text or not text.strip() or not entity_types:
        return []
    # Drop blank labels produced by stray commas or spaces ("person,,law",
    # "person, "): an empty string is not a meaningful entity type to ask
    # the model about.
    stripped_labels = (raw.strip() for raw in entity_types.split(","))
    entity_list = [label for label in stripped_labels if label]
    if not entity_list:
        return []
    entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
    return [{"text": e["text"], "label": e["label"], "score": round(e["score"], 2)} for e in entities]

def generate_legal_insight(text, entities, retrieved_docs):
    """Ask QwQ-32B for a legal insight grounded in retrieved context (RAG).

    Args:
        text: The user's input text.
        entities: Entity dicts exposing "text" and "label" keys.
        retrieved_docs: Context passages to place in the prompt.

    Returns:
        The text of the first completion produced for the prompt.
    """
    # Render each entity as "surface form (label)" before joining.
    labelled = [f"{ent['text']} ({ent['label']})" for ent in entities]
    entity_str = ", ".join(labelled)
    # One passage per line in the context section of the prompt.
    context = "\n".join(retrieved_docs)
    prompt = f"""You are a legal assistant for Arabic law. Using the following context and extracted entities, provide a concise legal insight (e.g., summary or explanation). Ensure the response is grounded in the context and entities.

Context:
{context}

Entities:
{entity_str}

Input Text:
{text}

Insight:"""
    # A single prompt is submitted, so only the first request result matters.
    completions = llm.generate([prompt], sampling_params)
    first_request = completions[0]
    return first_request.outputs[0].text

def main_interface(text, entity_types):
    """Main Gradio interface: NER, retrieval and insight generation.

    Args:
        text: Text entered by the user.
        entity_types: Comma-separated entity labels for NER.

    Returns:
        A tuple ``(ner_result, retrieved_docs, insight)``: the extracted
        entities, the retrieved context passages, and the generated insight.
        For missing or blank input the tuple is ``([], [], "")``.
    """
    # Nothing to analyse: skip retrieval and LLM generation, which would
    # otherwise run on an empty query and produce an ungrounded answer.
    if not text or not text.strip():
        return [], [], ""

    # Run NER
    ner_result = run_ner(text, entity_types)

    # Retrieve relevant documents
    retrieved_docs = retrieve_documents(text)

    # Generate legal insight
    insight = generate_legal_insight(text, ner_result, retrieved_docs)

    return ner_result, retrieved_docs, insight

# Gradio interface
def _analyze_for_display(text, entity_types):
    """Adapt main_interface's results to the output components.

    The retrieved passages come back as a list, but a Textbox renders
    non-string values through str(), i.e. as a Python list literal with
    brackets and quotes. Join the passages into readable text instead.
    """
    ner_result, retrieved_docs, insight = main_interface(text, entity_types)
    return ner_result, "\n\n".join(retrieved_docs), insight


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and QwQ-32B")
    with gr.Row():
        text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization"
        )
    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    docs_output = gr.Textbox(label="Retrieved Legal Context")
    insight_output = gr.Textbox(label="Legal Insight")

    # Outputs are matched positionally to the tuple returned by the handler.
    submit_btn.click(
        fn=_analyze_for_display,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, insight_output]
    )

demo.launch()