# Hugging Face Spaces page header captured with this file; Space status at capture time: "Build error".
import gradio as gr | |
from gliner import GLiNER | |
from vllm import LLM, SamplingParams | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
import json | |
import torch | |
# Load the mock legal corpus; each entry is read through its "text" key.
with open("legal_corpus.json", "r", encoding="utf-8") as f:
    corpus = json.load(f)
documents = [item["text"] for item in corpus]

# Embed every corpus document once at import time with a lightweight embedder.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedder.encode(documents, convert_to_numpy=True)

# Build a flat L2 FAISS index over the document embeddings; row i of the index
# corresponds to documents[i].
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# GLiNER model used for entity extraction on the input text.
gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)

# QwQ-32B served through vLLM.
# NOTE(review): quantization="awq" requires a checkpoint that ships AWQ weights;
# confirm that "Qwen/QwQ-32B" does, or whether an AWQ variant of the repo was
# intended.
llm = LLM(
    model="Qwen/QwQ-32B",
    quantization="awq",
    max_model_len=4096,
    gpu_memory_utilization=0.9,
)
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
def retrieve_documents(query, k=2):
    """Retrieve the top-k corpus documents closest to ``query`` using FAISS.

    Args:
        query: Free-text query; it is embedded with the module-level
            ``embedder`` before searching the module-level ``index``.
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``k`` document strings, in the order FAISS returned
        them. Empty when the query is blank or ``k`` is not positive.
    """
    if not query or not query.strip() or k <= 0:
        return []
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _distances, indices = index.search(query_embedding, k)
    # FAISS pads the result with -1 when the index holds fewer than k vectors;
    # without this filter documents[-1] would silently return the last document.
    return [documents[idx] for idx in indices[0] if idx >= 0]
def run_ner(text, entity_types):
    """Run NER on ``text`` with the module-level GLiNER model.

    Args:
        text: Text to extract entities from.
        entity_types: Comma-separated entity labels, e.g. "person,law".

    Returns:
        A list of dicts with keys "text", "label" and "score" (rounded to two
        decimals). Empty when ``text`` is empty or no usable label is given.
    """
    if not text or not entity_types:
        return []
    # Drop blank labels produced by stray or trailing commas ("person,,law,")
    # so that no empty label is handed to the model.
    entity_list = [label.strip() for label in entity_types.split(",") if label.strip()]
    if not entity_list:
        return []
    entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
    return [
        {"text": e["text"], "label": e["label"], "score": round(e["score"], 2)}
        for e in entities
    ]
def generate_legal_insight(text, entities, retrieved_docs):
    """Generate a legal insight with the module-level LLM using RAG.

    Args:
        text: The user's input text.
        entities: Entity dicts; each is read through its "text" and "label" keys.
        retrieved_docs: Document strings used as grounding context.

    Returns:
        The text of the first completion, with surrounding whitespace removed.
    """
    entity_str = ", ".join(f"{e['text']} ({e['label']})" for e in entities)
    context = "\n".join(retrieved_docs)
    # Built from adjacent literals so the prompt text does not depend on the
    # indentation of this function body.
    prompt = (
        "You are a legal assistant for Arabic law. Using the following context "
        "and extracted entities, provide a concise legal insight (e.g., summary "
        "or explanation). Ensure the response is grounded in the context and "
        "entities.\n"
        "Context:\n"
        f"{context}\n"
        "Entities:\n"
        f"{entity_str}\n"
        "Input Text:\n"
        f"{text}\n"
        "Insight:"
    )
    outputs = llm.generate([prompt], sampling_params)
    # One prompt was submitted, so take the first completion of the first result.
    return outputs[0].outputs[0].text.strip()
def main_interface(text, entity_types):
    """Run the full pipeline for one request: NER, retrieval, then generation.

    Args:
        text: Input text from the UI.
        entity_types: Comma-separated entity labels from the UI.

    Returns:
        A tuple ``(entities, context, insight)``: the list of entity dicts, the
        retrieved documents joined into one string, and the generated insight.
        Blank input yields ``([], "", "")`` without calling the models.
    """
    if not text or not text.strip():
        # Nothing to analyze: skip retrieval and the LLM call entirely.
        return [], "", ""
    ner_result = run_ner(text, entity_types)
    retrieved_docs = retrieve_documents(text)
    insight = generate_legal_insight(text, ner_result, retrieved_docs)
    # The retrieved context is displayed in a gr.Textbox, which expects a
    # string; returning the raw list would show its Python repr instead.
    return ner_result, "\n\n".join(retrieved_docs), insight
# Gradio interface: two inputs side by side, an Analyze button, and three
# outputs that are filled by main_interface in the same order as its return.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and QwQ-32B")
    with gr.Row():
        text_input = gr.Textbox(
            label="Arabic Legal Text",
            lines=5,
            placeholder="Enter Arabic legal text...",
        )
        entity_types = gr.Textbox(
            label="Entity Types (comma-separated)",
            value="person,law,organization",
            placeholder="e.g., person,law,organization",
        )
    submit_btn = gr.Button("Analyze")
    ner_output = gr.JSON(label="Extracted Entities")
    docs_output = gr.Textbox(label="Retrieved Legal Context")
    insight_output = gr.Textbox(label="Legal Insight")
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=main_interface,
        inputs=[text_input, entity_types],
        outputs=[ner_output, docs_output, insight_output],
    )

demo.launch()