ghostai1 commited on
Commit
ae433b0
·
verified ·
1 Parent(s): 209a1de

Create app.py

Browse files

Framework build 1

Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gliner import GLiNER
3
+ from vllm import LLM, SamplingParams
4
+ from sentence_transformers import SentenceTransformer
5
+ import faiss
6
+ import numpy as np
7
+ import json
8
+ import torch
9
+
10
+ # Load mock legal corpus
11
+ with open("legal_corpus.json", "r", encoding="utf-8") as f:
12
+ corpus = json.load(f)
13
+ documents = [item["text"] for item in corpus]
14
+
15
+ # Initialize sentence transformer for embeddings
16
+ embedder = SentenceTransformer("all-MiniLM-L6-v2") # Lightweight embedder
17
+ embeddings = embedder.encode(documents, convert_to_numpy=True)
18
+
19
+ # Initialize FAISS index
20
+ dimension = embeddings.shape[1]
21
+ index = faiss.IndexFlatL2(dimension)
22
+ index.add(embeddings)
23
+
24
+ # Initialize GLiNER model
25
+ gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
26
+
27
+ # Initialize QwQ-32B
28
+ llm = LLM(
29
+ model="Qwen/QwQ-32B",
30
+ quantization="awq",
31
+ max_model_len=4096,
32
+ gpu_memory_utilization=0.9
33
+ )
34
+ sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
35
+
36
+ def retrieve_documents(query, k=2):
37
+ """Retrieve top-k relevant documents using FAISS."""
38
+ query_embedding = embedder.encode([query], convert_to_numpy=True)
39
+ distances, indices = index.search(query_embedding, k)
40
+ return [documents[idx] for idx in indices[0]]
41
+
42
+ def run_ner(text, entity_types):
43
+ """Run NER with gliner_arabic-v2.1."""
44
+ if not text or not entity_types:
45
+ return []
46
+ entity_list = [e.strip() for e in entity_types.split(",")]
47
+ entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
48
+ return [{"text": e["text"], "label": e["label"], "score": round(e["score"], 2)} for e in entities]
49
+
50
+ def generate_legal_insight(text, entities, retrieved_docs):
51
+ """Generate insight with QwQ-32B using RAG."""
52
+ entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
53
+ context = "\n".join(retrieved_docs)
54
+ prompt = f"""You are a legal assistant for Arabic law. Using the following context and extracted entities, provide a concise legal insight (e.g., summary or explanation). Ensure the response is grounded in the context and entities.
55
+
56
+ Context:
57
+ {context}
58
+
59
+ Entities:
60
+ {entity_str}
61
+
62
+ Input Text:
63
+ {text}
64
+
65
+ Insight:"""
66
+ outputs = llm.generate([prompt], sampling_params)
67
+ return outputs[0].outputs[0].text
68
+
69
+ def main_interface(text, entity_types):
70
+ """Main Gradio interface."""
71
+ # Run NER
72
+ ner_result = run_ner(text, entity_types)
73
+
74
+ # Retrieve relevant documents
75
+ retrieved_docs = retrieve_documents(text)
76
+
77
+ # Generate legal insight
78
+ insight = generate_legal_insight(text, ner_result, retrieved_docs)
79
+
80
+ return ner_result, retrieved_docs, insight
81
+
82
+ # Gradio interface
83
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
84
+ gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and QwQ-32B")
85
+ with gr.Row():
86
+ text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
87
+ entity_types = gr.Textbox(
88
+ label="Entity Types (comma-separated)",
89
+ value="person,law,organization",
90
+ placeholder="e.g., person,law,organization"
91
+ )
92
+ submit_btn = gr.Button("Analyze")
93
+ ner_output = gr.JSON(label="Extracted Entities")
94
+ docs_output = gr.Textbox(label="Retrieved Legal Context")
95
+ insight_output = gr.Textbox(label="Legal Insight")
96
+
97
+ submit_btn.click(
98
+ fn=main_interface,
99
+ inputs=[text_input, entity_types],
100
+ outputs=[ner_output, docs_output, insight_output]
101
+ )
102
+
103
+ demo.launch()