acecalisto3 committed on
Commit
c156661
·
verified ·
1 Parent(s): bddaba5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -140,7 +140,7 @@ class CustomDataset(torch.utils.data.Dataset):
140
  text,
141
  max_length=self.max_length,
142
  padding='max_length',
143
- truncation=True,
144
  return_attention_mask=True,
145
  return_tensors='pt',
146
  )
@@ -154,7 +154,7 @@ class CustomDataset(torch.utils.data.Dataset):
154
  logger.error(f"Error in processing item {idx}: {e}")
155
  raise
156
 
157
- def train_model(model_name, data, batch_size, epochs, learning_rate=1e-5, max_length=512):
158
  try:
159
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
160
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -266,7 +266,7 @@ iface = gr.Interface(
266
  fn=gradio_interface,
267
  inputs=[
268
  gr.Textbox(lines=5, label="Enter comma-separated URLs"),
269
- gr.File(label="Upload file (including zip files)"),
270
  gr.Textbox(lines=10, label="Enter or paste large text"),
271
  gr.Textbox(label="Model name", value="distilbert-base-uncased"),
272
  gr.Number(label="Batch size", value=8),
 
140
  text,
141
  max_length=self.max_length,
142
  padding='max_length',
143
+ trunc ation=True,
144
  return_attention_mask=True,
145
  return_tensors='pt',
146
  )
 
154
  logger.error(f"Error in processing item {idx}: {e}")
155
  raise
156
 
157
+ def train_model(model_name, data, batch_size, epochs, learning_rate=1e-5, max_length=2048):
158
  try:
159
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
160
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
266
  fn=gradio_interface,
267
  inputs=[
268
  gr.Textbox(lines=5, label="Enter comma-separated URLs"),
269
+ gr.File(label="Upload file (including zip files)", type="file", max_size=1 * 1024 * 1024 * 1024), # Allow files up to 1 GB
270
  gr.Textbox(lines=10, label="Enter or paste large text"),
271
  gr.Textbox(label="Model name", value="distilbert-base-uncased"),
272
  gr.Number(label="Batch size", value=8),