oberbics committed on
Commit f764538 · verified · 1 Parent(s): 5f7fbe7

Update app.py

Files changed (1)
  1. app.py +77 -173
app.py CHANGED
@@ -1,181 +1,85 @@
-import os
-import re
-import time
 import json
-from itertools import cycle
-
 import torch
 import gradio as gr
-from urllib.parse import unquote
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
-from examples import examples as input_examples
-from nuextract_logging import log_event
-
-
-MAX_INPUT_SIZE = 100_000
-MAX_NEW_TOKENS = 4_000
-MAX_WINDOW_SIZE = 10_000
-
-markdown_description = """
-<!DOCTYPE html>
-<html lang="en">
-<head>
-<meta charset="UTF-8">
-<meta name="viewport" content="width=device-width, initial-scale=1.0">
-</head>
-<body>
-<img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;">
-<br>
-<ul>
-<li>NuMind is a startup developing custom information extraction solutions.</li>
-<li>NuExtract is a zero-shot model. See the blog posts for more info (<a href="https://numind.ai/blog/nuextract-a-foundation-model-for-structured-extraction">NuExtract</a>, <a href="https://numind.ai/blog/nuextract-1-5---multilingual-infinite-context-still-small-and-better-than-gpt-4o">NuExtract-v1.5</a>).</li>
-<li>We have started to deploy NuMind Enterprise to customize, serve, and monitor NuExtract privately. If that interests you, let's chat 😊.</li>
-<li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
-</ul>
-<h1>NuExtract-v1.5</h1>
-<p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction.
-It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian).
-To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
-<ul>
-<li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
-</ul>
-<i>⚠️ In this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
-<br>
-<i>⚠️ The model is trained to assume a valid JSON template. Attempts to use invalid JSON could lead to unpredictable results.</i>
-</body>
-</html>
-"""
-
-
-def highlight_words(input_text, json_output):
-    colors = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
-    color_map = {}
-    highlighted_text = input_text
-
-    leaves = extract_leaves(json_output)
-    for path, value in leaves:
-        path_key = tuple(path)
-        if path_key not in color_map:
-            color_map[path_key] = next(colors)
-        color = color_map[path_key]
-
-        escaped_value = re.escape(value).replace(r'\ ', r'\s+')  # escape value and replace spaces with \s+
-        pattern = rf"(?<=[ \n\t]){escaped_value}(?=[ \n\t\.\,\?\:\;])"
-        replacement = f"<span style='background-color: {color};'>{unquote(value)}</span>"
-        highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)
-
-    return highlighted_text
-
-def predict_chunk(text, template, current, model, tokenizer):
-    current = clean_json_text(current)
-
-    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
-    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
-    output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=False, use_cache=False)[0], skip_special_tokens=True)
-    print(output)
-    return clean_json_text(output.split("<|output|>")[1])
-
-def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
-    # Split text into chunks of n tokens
-    tokens = tokenizer.tokenize(text)
-    chunks = split_document(text, window_size, overlap, tokenizer)
-
-    # Iterate over text chunks
-    prev = template
-    full_pred = ""
-
-    for i, chunk in enumerate(chunks):
-        print(f"Processing chunk {i}...")
-        pred = predict_chunk(chunk, template, prev, model, tokenizer)
-
-        # Handle broken output
-        pred = handle_broken_output(pred, prev)
-
-        # create highlighted text
-        try:
-            highlighted_pred = highlight_words(text, json.loads(pred))
-        except:
-            highlighted_pred = text
-
-        # attempt json parsing
-        template_dict = None
-        pred_dict = None
-        try:
-            template_dict = json.loads(template)
-        except:
-            pass
-        try:
-            pred_dict = json.loads(pred)
-        except:
-            pass
-
-        # Sync empty fields
-        if template_dict and pred_dict:
-            synced_pred = sync_empty_fields(pred_dict, template_dict)
-            synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)
-        elif pred_dict:
-            synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
-        else:
-            synced_pred = pred
-
-        # Return progress, current prediction, and updated HTML
-        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred
-
-        # Iterate
-        prev = pred
-
-
-######
-
-# Load the model and tokenizer
-model_name = "numind/NuExtract-v1.5"
-auth_token = os.environ.get("HF_TOKEN") or False
-model = AutoModelForCausalLM.from_pretrained(model_name,
-                                             trust_remote_code=True,
-                                             torch_dtype=torch.bfloat16,
-                                             device_map="auto", use_auth_token=auth_token)
-tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token)
-model.eval()
-
-def gradio_interface_function(template, text, size, is_example):
-    if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
-        yield "", "Input text too long for space. Download model to use unrestricted.", ""
-        return  # End the function since there was an error
-
-    # Initialize the sliding window prediction process
-    # Check if size is a boolean (from examples) and use a default if it is
-    if isinstance(size, bool) or size == 'True':
-        window_size = 4000  # Use default window size for examples
-    else:
-        window_size = int(size)
-
-    prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=window_size)
-
-    # Iterate over the generator to return values at each step
-    for progress, full_pred, html_content in prediction_generator:
-        yield progress, full_pred, html_content
-
-# Removed the logging code entirely
-
-# Set up the Gradio interface
-iface = gr.Interface(
-    description=markdown_description,
-    fn=gradio_interface_function,
-    inputs=[
-        gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
-        gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
-        gr.Textbox(lines=2, placeholder="Enter windows size here...", label="Size"),
-        gr.Checkbox(label="Is Example?", visible=False),
-    ],
-    outputs=[
-        gr.Textbox(label="Progress"),
-        gr.Textbox(label="Model Output"),
-        gr.HTML(label="Model Output with Highlighted Words"),
-    ],
-    examples=input_examples,
-    # live=True  # Enable real-time updates
-)
-
-iface.launch(debug=True, share=True)
+
+# Simplified extraction function
+def extract_structure(template, text, progress=None):
+    try:
+        # Format the input
+        prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"
+
+        # Generate prediction
+        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
+        output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=2000)[0], skip_special_tokens=True)
+
+        # Extract result
+        result = output.split("<|output|>")[1]
+
+        # Highlight found items in text (simplified)
+        highlighted = f"<p>Processed text of length {len(text)} characters</p>"
+
+        return "Processing complete", result, highlighted
+    except Exception as e:
+        return f"Error: {str(e)}", "{}", "<p>Processing failed</p>"
+
+# Load model
+model_name = "numind/NuExtract-1.5"
+try:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,  # Using float16 instead of bfloat16 for better compatibility
+        trust_remote_code=True,
+        device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model_loaded = True
+except Exception as e:
+    print(f"Model loading error: {e}")
+    model_loaded = False
+
+# Create interface
+with gr.Blocks() as demo:
+    gr.Markdown("# NuExtract-1.5 Demo")
+
+    if not model_loaded:
+        gr.Markdown("## ⚠️ Model failed to load. Using dummy mode.")
+
+    with gr.Row():
+        with gr.Column():
+            template_input = gr.Textbox(
+                label="Template (JSON)",
+                value='{"name": "", "email": ""}',
+                lines=5
+            )
+            text_input = gr.Textbox(
+                label="Input Text",
+                value="Contact: John Smith ([email protected])",
+                lines=10
+            )
+            submit_btn = gr.Button("Extract Information")
+
+        with gr.Column():
+            progress_output = gr.Textbox(label="Progress")
+            result_output = gr.Textbox(label="Extracted Information")
+            html_output = gr.HTML(label="Highlighted Text")
+
+    submit_btn.click(
+        fn=extract_structure,
+        inputs=[template_input, text_input],
+        outputs=[progress_output, result_output, html_output]
+    )
+
+    # Simple example
+    gr.Examples(
+        [
+            [
+                '{"name": "", "email": ""}',
+                'Contact: John Smith ([email protected])'
+            ]
+        ],
+        [template_input, text_input]
+    )
+
+if __name__ == "__main__":
+    demo.launch()
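
Both the removed predict_chunk and the new extract_structure build the same NuExtract prompt: a JSON template and the source text wrapped between <|input|> and <|output|> markers, with everything after <|output|> taken as the prediction. A minimal sketch of that contract outside Gradio, assuming the same model and loading options as the commit; the template and contact string here are illustrative placeholders, not part of the commit:

import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model the same way the committed app.py does.
model_name = "numind/NuExtract-1.5"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Illustrative inputs: a template with empty values and the text to extract from.
template = json.dumps({"name": "", "email": ""})
text = "Contact: Jane Doe (jane.doe@example.com)"  # placeholder contact string

# NuExtract prompt format, as used in both versions of app.py.
prompt = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"

inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
output = tokenizer.decode(
    model.generate(**inputs, max_new_tokens=2000)[0], skip_special_tokens=True
)

# The model's JSON prediction follows the <|output|> marker.
print(output.split("<|output|>")[1])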