Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -140,7 +140,7 @@ class CustomDataset(torch.utils.data.Dataset):
|
|
140 |
text,
|
141 |
max_length=self.max_length,
|
142 |
padding='max_length',
|
143 |
-
|
144 |
return_attention_mask=True,
|
145 |
return_tensors='pt',
|
146 |
)
|
@@ -154,7 +154,7 @@ class CustomDataset(torch.utils.data.Dataset):
|
|
154 |
logger.error(f"Error in processing item {idx}: {e}")
|
155 |
raise
|
156 |
|
157 |
-
def train_model(model_name, data, batch_size, epochs, learning_rate=1e-5, max_length=
|
158 |
try:
|
159 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
160 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
@@ -266,7 +266,7 @@ iface = gr.Interface(
|
|
266 |
fn=gradio_interface,
|
267 |
inputs=[
|
268 |
gr.Textbox(lines=5, label="Enter comma-separated URLs"),
|
269 |
-
gr.File(label="Upload file (including zip files)"),
|
270 |
gr.Textbox(lines=10, label="Enter or paste large text"),
|
271 |
gr.Textbox(label="Model name", value="distilbert-base-uncased"),
|
272 |
gr.Number(label="Batch size", value=8),
|
|
|
140 |
text,
|
141 |
max_length=self.max_length,
|
142 |
padding='max_length',
|
143 |
+
truncation=True,
|
144 |
return_attention_mask=True,
|
145 |
return_tensors='pt',
|
146 |
)
|
|
|
154 |
logger.error(f"Error in processing item {idx}: {e}")
|
155 |
raise
|
156 |
|
157 |
+
def train_model(model_name, data, batch_size, epochs, learning_rate=1e-5, max_length=2048):
|
158 |
try:
|
159 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
160 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
266 |
fn=gradio_interface,
|
267 |
inputs=[
|
268 |
gr.Textbox(lines=5, label="Enter comma-separated URLs"),
|
269 |
+
gr.File(label="Upload file (including zip files)", type="file", max_size=1 * 1024 * 1024 * 1024), # Allow files up to 1 GB
|
270 |
gr.Textbox(lines=10, label="Enter or paste large text"),
|
271 |
gr.Textbox(label="Model name", value="distilbert-base-uncased"),
|
272 |
gr.Number(label="Batch size", value=8),
|