Update app.py
app.py CHANGED
@@ -1,181 +1,85 @@
-import os
-import re
-import time
 import json
-from itertools import cycle
-
 import torch
 import gradio as gr
-from
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
-
-from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
-from examples import examples as input_examples
-from nuextract_logging import log_event
-
-
-MAX_INPUT_SIZE = 100_000
-MAX_NEW_TOKENS = 4_000
-MAX_WINDOW_SIZE = 10_000
-
-markdown_description = """
-<!DOCTYPE html>
-<html lang="en">
-<head>
-<meta charset="UTF-8">
-<meta name="viewport" content="width=device-width, initial-scale=1.0">
-</head>
-<body>
-<img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;">
-<br>
-<ul>
-<li>NuMind is a startup developing custom information extraction solutions.</li>
-<li>NuExtract is a zero-shot model. See the blog posts for more info (<a href="https://numind.ai/blog/nuextract-a-foundation-model-for-structured-extraction">NuExtract</a>, <a href="https://numind.ai/blog/nuextract-1-5---multilingual-infinite-context-still-small-and-better-than-gpt-4o">NuExtract-v1.5</a>).</li>
-<li>We have started to deploy NuMind Enterprise to customize, serve, and monitor NuExtract privately. If that interests you, let's chat 😊.</li>
-<li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
-</ul>
-<h1>NuExtract-v1.5</h1>
-<p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction.
-It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian).
-To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
-<ul>
-<li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
-</ul>
-<i>⚠️ In this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
-<br>
-<i>⚠️ The model is trained to assume a valid JSON template. Attempts to use invalid JSON could lead to unpredictable results.</i>
-</body>
-</html>
-"""
-
-
-def highlight_words(input_text, json_output):
-    colors = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
-    color_map = {}
-    highlighted_text = input_text
-
-    leaves = extract_leaves(json_output)
-    for path, value in leaves:
-        path_key = tuple(path)
-        if path_key not in color_map:
-            color_map[path_key] = next(colors)
-        color = color_map[path_key]
-
-        escaped_value = re.escape(value).replace(r'\ ', r'\s+')  # escape value and replace spaces with \s+
-        pattern = rf"(?<=[ \n\t]){escaped_value}(?=[ \n\t\.\,\?\:\;])"
-        replacement = f"<span style='background-color: {color};'>{unquote(value)}</span>"
-        highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)
-
-    return highlighted_text
-
-def predict_chunk(text, template, current, model, tokenizer):
-    current = clean_json_text(current)
-
-    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
-    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
-    output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=False, use_cache=False)[0], skip_special_tokens=True)
-    print(output)
-    return clean_json_text(output.split("<|output|>")[1])
-
-def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
-    # Split text into chunks of n tokens
-    tokens = tokenizer.tokenize(text)
-    chunks = split_document(text, window_size, overlap, tokenizer)
-
-    # Iterate over text chunks
-    prev = template
-    full_pred = ""
-
-    for i, chunk in enumerate(chunks):
-        print(f"Processing chunk {i}...")
-        pred = predict_chunk(chunk, template, prev, model, tokenizer)

[…]

-        #
-
-
-        except:
-            highlighted_pred = text
-
-        # attempt json parsing
-        template_dict = None
-        pred_dict = None
-        try:
-            template_dict = json.loads(template)
-        except:
-            pass
-        try:
-            pred_dict = json.loads(pred)
-        except:
-            pass

-        #
-
-            synced_pred = sync_empty_fields(pred_dict, template_dict)
-            synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
-        elif pred_dict:
-            synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
-        else:
-            synced_pred = pred
-
-        # Return progress, current prediction, and updated HTML
-        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred
-
-        # Iterate
-        prev = pred
-
-
-######
-
-# Load the model and tokenizer
-model_name = "numind/NuExtract-v1.5"
-auth_token = os.environ.get("HF_TOKEN") or False
-model = AutoModelForCausalLM.from_pretrained(model_name,
-                                             trust_remote_code=True,
-                                             torch_dtype=torch.bfloat16,
-                                             device_map="auto", use_auth_token=auth_token)
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
-model.eval()
-
-def gradio_interface_function(template, text, size, is_example):
-    if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
-        yield "", "Input text too long for space. Download model to use unrestricted.", ""
-        return  # End the function since there was an error
-
-    # Initialize the sliding window prediction process
-    # Check if size is a boolean (from examples) and use a default if it is
-    if isinstance(size, bool) or size == 'True':
-        window_size = 4000  # Use default window size for examples
-    else:
-        window_size = int(size)

[…]

-)
-
 import json
 import torch
 import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer

+# Simplified extraction function
+def extract_structure(template, text, progress=None):
+    try:
+        # Format the input
+        prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"

+        # Generate prediction
+        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
+        output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=2000)[0], skip_special_tokens=True)

+        # Extract result
+        result = output.split("<|output|>")[1]

+        # Highlight found items in text (simplified)
+        highlighted = f"<p>Processed text of length {len(text)} characters</p>"
+
+        return "Processing complete", result, highlighted
+    except Exception as e:
+        return f"Error: {str(e)}", "{}", "<p>Processing failed</p>"
+
+# Load model
+model_name = "numind/NuExtract-1.5"
+try:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,  # Using float16 instead of bfloat16 for better compatibility
+        trust_remote_code=True,
+        device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model_loaded = True
+except Exception as e:
+    print(f"Model loading error: {e}")
+    model_loaded = False
+
+# Create interface
+with gr.Blocks() as demo:
+    gr.Markdown("# NuExtract-1.5 Demo")

+    if not model_loaded:
+        gr.Markdown("## ⚠️ Model failed to load. Using dummy mode.")
+
+    with gr.Row():
+        with gr.Column():
+            template_input = gr.Textbox(
+                label="Template (JSON)",
+                value='{"name": "", "email": ""}',
+                lines=5
+            )
+            text_input = gr.Textbox(
+                label="Input Text",
+                value="Contact: John Smith ([email protected])",
+                lines=10
+            )
+            submit_btn = gr.Button("Extract Information")
+
+        with gr.Column():
+            progress_output = gr.Textbox(label="Progress")
+            result_output = gr.Textbox(label="Extracted Information")
+            html_output = gr.HTML(label="Highlighted Text")
+
+    submit_btn.click(
+        fn=extract_structure,
+        inputs=[template_input, text_input],
+        outputs=[progress_output, result_output, html_output]
+    )
+
+    # Simple example
+    gr.Examples(
+        [
+            [
+                '{"name": "", "email": ""}',
+                'Contact: John Smith ([email protected])'
+            ]
+        ],
+        [template_input, text_input]
+    )
+
+if __name__ == "__main__":
+    demo.launch()
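
Two notes for readers of this diff. First, what the commit removes: the old sliding_window_prediction handled documents longer than one context window by chunking the input with split_document(text, window_size, overlap, tokenizer) and feeding each chunk to predict_chunk together with the JSON extracted so far, which the prompt carries in its "### Current:" section so later chunks can fill in fields earlier chunks missed. A condensed sketch of that loop follows; it reuses the Space's own data.py helper and the removed predict_chunk, so it is a restatement of the deleted code, not a standalone program:

```python
from data import split_document  # helper shipped with this Space, per the removed imports

def sliding_window_sketch(template, text, model, tokenizer,
                          window_size=4000, overlap=128):
    # Cut the document into overlapping token windows.
    chunks = split_document(text, window_size, overlap, tokenizer)

    # The first chunk starts from the bare template; every later chunk
    # sees the JSON predicted so far via the prompt's "### Current:" section.
    prev = template
    for chunk in chunks:
        prev = predict_chunk(chunk, template, prev, model, tokenizer)
    return prev
```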
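Second, the new code sends the whole input through a single <|input|>…<|output|> prompt. A minimal sketch of exercising that prompt outside Gradio, assuming numind/NuExtract-1.5 downloads and loads exactly as in the commit; the template and contact string are illustrative values, not part of the commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "numind/NuExtract-1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)

# Illustrative inputs; any JSON template and document work the same way.
template = '{"name": "", "email": ""}'
text = "Contact: John Smith (john@example.com)"

# Same single-pass prompt shape as the new extract_structure function.
prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
output = tokenizer.decode(
    model.generate(**inputs, max_new_tokens=2000)[0],
    skip_special_tokens=True,
)

# The model's JSON prediction follows the <|output|> marker.
print(output.split("<|output|>")[1])
```

Note the trade-off the diff makes: without the sliding window, anything beyond the model's context is simply truncated by the tokenizer, and the per-field highlighting from highlight_words is replaced by a plain length summary.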