Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
from transformers import pipeline
|
3 |
import PyPDF2
|
4 |
from docx import Document
|
|
|
5 |
|
6 |
# Load pipelines
|
7 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
@@ -11,7 +12,7 @@ ner = pipeline("ner", model="Jean-Baptiste/roberta-large-ner-english", grouped_e
|
|
11 |
def read_file(file_obj):
|
12 |
name = file_obj.name
|
13 |
if name.endswith(".txt"):
|
14 |
-
return file_obj.read().decode("utf-8")
|
15 |
elif name.endswith(".pdf"):
|
16 |
reader = PyPDF2.PdfReader(file_obj)
|
17 |
return " ".join([page.extract_text() for page in reader.pages if page.extract_text()])
|
@@ -26,10 +27,25 @@ def is_contract(text):
|
|
26 |
result = classifier(text[:1000], ["contract", "not a contract"])
|
27 |
return result['labels'][0] == 'contract', result
|
28 |
|
29 |
-
#
|
30 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
entities = ner(text[:1000])
|
32 |
-
|
|
|
|
|
|
|
33 |
|
34 |
# Main logic
|
35 |
def process_file(file):
|
@@ -39,7 +55,7 @@ def process_file(file):
|
|
39 |
|
40 |
is_contract_flag, classification = is_contract(text)
|
41 |
if is_contract_flag:
|
42 |
-
parties =
|
43 |
return "β
This is a contract.", ", ".join(parties)
|
44 |
else:
|
45 |
return "β This is NOT a contract.", ""
|
@@ -50,10 +66,10 @@ iface = gr.Interface(
|
|
50 |
inputs=gr.File(file_types=[".txt", ".pdf", ".docx"], label="Upload a document"),
|
51 |
outputs=[
|
52 |
gr.Textbox(label="Classification Result"),
|
53 |
-
gr.Textbox(label="Detected Parties (ORG/PER)")
|
54 |
],
|
55 |
title="Contract Classifier with RoBERTa",
|
56 |
-
description="Upload a document (.pdf, .txt, .docx) to detect if it's a contract and extract involved parties using RoBERTa."
|
57 |
)
|
58 |
|
59 |
iface.launch()
|
|
|
2 |
from transformers import pipeline
|
3 |
import PyPDF2
|
4 |
from docx import Document
|
5 |
+
import re
|
6 |
|
7 |
# Load pipelines
|
8 |
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
|
|
12 |
def read_file(file_obj):
|
13 |
name = file_obj.name
|
14 |
if name.endswith(".txt"):
|
15 |
+
return file_obj.read().decode("utf-8", errors="ignore")
|
16 |
elif name.endswith(".pdf"):
|
17 |
reader = PyPDF2.PdfReader(file_obj)
|
18 |
return " ".join([page.extract_text() for page in reader.pages if page.extract_text()])
|
|
|
27 |
result = classifier(text[:1000], ["contract", "not a contract"])
|
28 |
return result['labels'][0] == 'contract', result
|
29 |
|
30 |
+
# Rule-based + NER-based party extraction
|
31 |
+
def extract_parties_with_rules(text):
|
32 |
+
results = set()
|
33 |
+
|
34 |
+
# Rule-based: between X and Y
|
35 |
+
matches = re.findall(r'between\s+(.*?)\s+and\s+(.*?)[\.,\n]', text, re.IGNORECASE)
|
36 |
+
for match in matches:
|
37 |
+
results.update(match)
|
38 |
+
|
39 |
+
# Rule-based: "X" (Party A), etc.
|
40 |
+
named_matches = re.findall(r'β([^β]+)β\s*\(.*?Party [AB]\)', text)
|
41 |
+
results.update(named_matches)
|
42 |
+
|
43 |
+
# NER fallback
|
44 |
entities = ner(text[:1000])
|
45 |
+
ner_parties = [ent['word'] for ent in entities if ent['entity_group'] in ['ORG', 'PER']]
|
46 |
+
results.update(ner_parties)
|
47 |
+
|
48 |
+
return list(results)
|
49 |
|
50 |
# Main logic
|
51 |
def process_file(file):
|
|
|
55 |
|
56 |
is_contract_flag, classification = is_contract(text)
|
57 |
if is_contract_flag:
|
58 |
+
parties = extract_parties_with_rules(text)
|
59 |
return "β
This is a contract.", ", ".join(parties)
|
60 |
else:
|
61 |
return "β This is NOT a contract.", ""
|
|
|
66 |
inputs=gr.File(file_types=[".txt", ".pdf", ".docx"], label="Upload a document"),
|
67 |
outputs=[
|
68 |
gr.Textbox(label="Classification Result"),
|
69 |
+
gr.Textbox(label="Detected Parties (ORG/PER or Rule-based)")
|
70 |
],
|
71 |
title="Contract Classifier with RoBERTa",
|
72 |
+
description="Upload a document (.pdf, .txt, .docx) to detect if it's a contract and extract involved parties using RoBERTa + Rule-based matching."
|
73 |
)
|
74 |
|
75 |
iface.launch()
|