Update app.py
Browse files
app.py
CHANGED
@@ -4,86 +4,139 @@ import json
|
|
4 |
import uuid
|
5 |
import torch
|
6 |
import datetime
|
7 |
-
import
|
8 |
-
from transformers import AutoTokenizer,
|
9 |
-
from huggingface_hub import HfApi, create_repo,
|
10 |
-
from
|
11 |
|
12 |
-
#
|
13 |
-
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
|
14 |
HF_DATASET_REPO = "M2ai/mgtd-logs"
|
15 |
HF_TOKEN = os.getenv("Mgtd")
|
16 |
DATASET_CREATED = False
|
17 |
|
18 |
-
#
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
#
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def setup_hf_dataset():
|
26 |
global DATASET_CREATED
|
27 |
if not DATASET_CREATED and HF_TOKEN:
|
28 |
try:
|
29 |
-
api = HfApi()
|
30 |
create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
|
31 |
DATASET_CREATED = True
|
32 |
-
print(f"Dataset {HF_DATASET_REPO} is ready")
|
33 |
except Exception as e:
|
34 |
print(f"Error setting up dataset: {e}")
|
35 |
-
elif not HF_TOKEN:
|
36 |
-
print("Warning: HF_TOKEN not set. Logs will be saved locally only.")
|
37 |
|
|
|
38 |
def infer_and_log(text_input):
|
39 |
-
|
40 |
-
with torch.no_grad():
|
41 |
-
outputs = model(**inputs)
|
42 |
-
logits = outputs.logits.tolist()
|
43 |
-
predicted = torch.argmax(outputs.logits, dim=-1).item()
|
44 |
-
label = model.config.id2label[predicted]
|
45 |
-
|
46 |
timestamp = datetime.datetime.now().isoformat()
|
47 |
submission_id = str(uuid.uuid4())
|
|
|
48 |
log_data = {
|
49 |
"id": submission_id,
|
50 |
"timestamp": timestamp,
|
51 |
"input": text_input,
|
52 |
-
"
|
53 |
}
|
54 |
|
|
|
55 |
log_file = f"logs/{timestamp.replace(':', '_')}.json"
|
56 |
with open(log_file, "w") as f:
|
57 |
json.dump(log_data, f, indent=2)
|
58 |
|
59 |
if HF_TOKEN and DATASET_CREATED:
|
60 |
try:
|
61 |
-
|
62 |
-
api.upload_file(
|
63 |
path_or_fileobj=log_file,
|
64 |
path_in_repo=f"logs/{os.path.basename(log_file)}",
|
65 |
repo_id=HF_DATASET_REPO,
|
66 |
repo_type="dataset",
|
67 |
token=HF_TOKEN
|
68 |
)
|
69 |
-
print(f"Uploaded log {submission_id}
|
70 |
except Exception as e:
|
71 |
-
print(f"Error uploading
|
72 |
|
73 |
-
return
|
74 |
|
75 |
def clear_fields():
|
76 |
return "", ""
|
77 |
|
78 |
-
#
|
79 |
setup_hf_dataset()
|
80 |
|
|
|
81 |
with gr.Blocks() as app:
|
82 |
-
gr.Markdown("##
|
83 |
|
84 |
with gr.Row():
|
85 |
-
input_box = gr.Textbox(label="Input Text", lines=10
|
86 |
-
output_box = gr.Textbox(label="Output", lines=2, interactive=False)
|
87 |
|
88 |
with gr.Row():
|
89 |
submit_btn = gr.Button("Submit")
|
|
|
4 |
import uuid
|
5 |
import torch
|
6 |
import datetime
|
7 |
+
import torch.nn as nn
|
8 |
+
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
9 |
+
from huggingface_hub import HfApi, create_repo, hf_hub_download
|
10 |
+
from torchcrf import CRF
|
11 |
|
12 |
+
# Constants
HF_DATASET_REPO = "M2ai/mgtd-logs"  # dataset repo that receives the inference logs
HF_TOKEN = os.getenv("Mgtd")  # token used to create the log repo and upload to it
DATASET_CREATED = False  # set to True by setup_hf_dataset() once the repo exists

# Model identifiers
code = "ENG"  # sub-folder of the checkpoint repo (presumably a language code)
pntr = 2  # number embedded in the checkpoint file name ("mdeberta-epoch-<pntr>.pt")
model_name_or_path = "microsoft/mdeberta-v3-base"
hf_token = os.environ.get("HF_WRITE")  # Set this before running

# Download model checkpoint
file_path = hf_hub_download(
    repo_id="1024m/MGTD-Long-New",
    filename=f"{code}/mdeberta-epoch-{pntr}.pt",
    token=hf_token,
    local_dir="./checkpoints",
)
|
30 |
+
|
31 |
+
# Define CRF model
|
32 |
+
class AutoModelCRF(nn.Module):
    """Pretrained encoder followed by dropout, a linear emission layer and a CRF.

    Args:
        model_name_or_path: Identifier or path passed to ``AutoConfig`` and
            ``AutoModel`` to build the encoder.
        dropout: Dropout probability applied to the encoder output before
            the linear layer.
    """

    def __init__(self, model_name_or_path, dropout=0.075):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.num_labels = 2
        # NOTE(review): trust_remote_code=True lets the hub repo execute its
        # own Python code on load -- confirm this is needed for the encoder.
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            config=self.config,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask):
        """Decode the best tag sequence for each input sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder and, cast to bool,
                used as the CRF mask.

        Returns:
            A ``(tags, emissions)`` pair: ``tags`` is whatever
            ``CRF.decode`` returns, ``emissions`` is the output of the
            linear layer.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        seq_output = self.dropout(outputs[0])
        emissions = self.linear(seq_output)
        # A bool mask avoids the deprecated uint8-condition path in torch.where,
        # which the CRF decoder uses internally.
        tags = self.crf.decode(emissions, attention_mask.bool())
        return tags, emissions
|
48 |
+
|
49 |
+
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelCRF(model_name_or_path)
# NOTE(review): torch.load unpickles the downloaded file; only load checkpoints
# from a trusted repo (consider weights_only=True if the checkpoint allows it).
checkpoint = torch.load(file_path, map_location="cpu")
# The checkpoint is either a wrapper dict with "model_state_dict" or the
# state dict itself.
_state_dict = checkpoint.get("model_state_dict", checkpoint)
# strict=False tolerates key mismatches; report them so a wrong checkpoint
# does not silently leave layers at their initial weights.
_load_result = model.load_state_dict(_state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "Warning: checkpoint keys did not fully match the model "
        f"(missing: {_load_result.missing_keys}, "
        f"unexpected: {_load_result.unexpected_keys})"
    )
model = model.to(device)
model.eval()
|
57 |
+
|
58 |
+
# Inference function
|
59 |
+
def _merge_subword_tags(tokens, tags, skip_tokens):
    """Collapse per-subword tags into one tag (as a string) per word.

    A word starts at a token beginning with "▁"; its tag is the tag of the
    first piece. Tokens listed in ``skip_tokens`` are ignored entirely.
    """
    word_tags = []
    current_word = ""
    current_tag = None
    for token, tag in zip(tokens, tags):
        if token in skip_tokens:
            continue
        if token.startswith("▁"):
            # A new word begins: flush the tag of the previous one, if any.
            if current_word:
                word_tags.append(str(current_tag))
            current_word = token[1:] if token != "▁" else ""
            current_tag = tag
        else:
            # Continuation piece; if no word has been opened yet, let this
            # piece open one so a "None" tag is never emitted.
            if current_tag is None:
                current_tag = tag
            current_word += token
    if current_word:
        word_tags.append(str(current_tag))
    return word_tags


def get_word_classifications(text):
    """Tag each word of ``text`` with the model and return the tags as strings.

    The text is limited to its first 2048 space-separated chunks, tokenized
    (with truncation), run through the module-level ``model``, and the
    subword tags are merged to word level.

    Args:
        text: Input string to classify.

    Returns:
        A list with one tag string per detected word.
    """
    text = " ".join(text.split(" ")[:2048])
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    with torch.no_grad():
        tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Skip the tokenizer's own sequence markers as well as the legacy
    # "<s>"/"</s>" pair: a marker such as "[CLS]" would otherwise be treated
    # as a word and yield a spurious leading "None" tag.
    skip_tokens = {"<s>", "</s>"}
    for special in (tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token):
        if special is not None:
            skip_tokens.add(special)

    return _merge_subword_tags(tokens, tags[0], skip_tokens)
|
82 |
+
|
83 |
+
# HF logging setup
|
84 |
def setup_hf_dataset():
    """Create the log dataset repo on the Hub once, if a token is configured.

    Does nothing when the repo was already prepared or ``HF_TOKEN`` is
    empty. On success sets the module-level ``DATASET_CREATED`` flag; on
    failure prints the error and leaves the flag unchanged.
    """
    global DATASET_CREATED
    if DATASET_CREATED or not HF_TOKEN:
        return
    try:
        create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
        DATASET_CREATED = True
        print(f"Dataset {HF_DATASET_REPO} is ready.")
    except Exception as e:
        # Best effort: a failed setup only disables remote log upload.
        print(f"Error setting up dataset: {e}")
|
|
|
|
|
93 |
|
94 |
+
# Main inference + logging function
|
95 |
def infer_and_log(text_input):
    """Classify ``text_input``, log the result, and return the tags as text.

    The submission is written to ``logs/<timestamp>.json`` and, when
    ``HF_TOKEN`` is set and the dataset repo is ready, uploaded to
    ``HF_DATASET_REPO``. Upload failures are printed and do not affect the
    returned value.

    Args:
        text_input: Text to classify.

    Returns:
        The word-level tags joined by single spaces.
    """
    word_tags = get_word_classifications(text_input)
    timestamp = datetime.datetime.now().isoformat()
    submission_id = str(uuid.uuid4())

    log_data = {
        "id": submission_id,
        "timestamp": timestamp,
        "input": text_input,
        "output_tags": word_tags,
    }

    os.makedirs("logs", exist_ok=True)
    # ":" is replaced because it is not valid in file names on every platform.
    log_file = f"logs/{timestamp.replace(':', '_')}.json"
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)

    if HF_TOKEN and DATASET_CREATED:
        try:
            HfApi().upload_file(
                path_or_fileobj=log_file,
                path_in_repo=f"logs/{os.path.basename(log_file)}",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            print(f"Uploaded log {submission_id}")
        except Exception as e:
            # Best effort: the local log file is kept even if the upload fails.
            print(f"Error uploading log: {e}")

    return " ".join(word_tags)
|
126 |
|
127 |
def clear_fields():
    """Return a pair of empty strings."""
    return "", ""
|
129 |
|
130 |
+
# Prepare dataset once
setup_hf_dataset()
|
132 |
|
133 |
+
# Gradio UI
|
134 |
with gr.Blocks() as app:
|
135 |
+
gr.Markdown("## MDeBERTa+CRF Word Tagger")
|
136 |
|
137 |
with gr.Row():
|
138 |
+
input_box = gr.Textbox(label="Input Text", lines=10)
|
139 |
+
output_box = gr.Textbox(label="Output Tags", lines=2, interactive=False)
|
140 |
|
141 |
with gr.Row():
|
142 |
submit_btn = gr.Button("Submit")
|