Commit
395aceb
·
verified ·
1 Parent(s): b9f3278

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  import os
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
5
 
6
  # Load Hugging Face token from the environment variable
7
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -10,7 +11,8 @@ if HF_TOKEN is None:
10
 
11
  # Check for GPU support and configure appropriately
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
- print(f"Device being used: {device}")
 
14
 
15
  # Model configurations
16
  MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
@@ -27,6 +29,7 @@ syrian_to_msa_model = AutoModelForSeq2SeqLM.from_pretrained(SYRIAN_TO_MSA_MODEL)
27
 
28
  print("Models loaded successfully!")
29
 
 
30
  def translate_msa_to_syrian(text):
31
  """Translate from Modern Standard Arabic to Syrian dialect"""
32
  if not text.strip():
@@ -40,6 +43,7 @@ def translate_msa_to_syrian(text):
40
  except Exception as e:
41
  return f"Translation error: {str(e)}"
42
 
 
43
  def translate_syrian_to_msa(text):
44
  """Translate from Syrian dialect to Modern Standard Arabic"""
45
  if not text.strip():
 
2
  import torch
3
  import os
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import spaces
6
 
7
  # Load Hugging Face token from the environment variable
8
  HF_TOKEN = os.getenv("HF_TOKEN")
 
11
 
12
  # Check for GPU support and configure appropriately
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ zero = torch.Tensor([0]).to(device)
15
+ print(f"Device being used: {zero.device}")
16
 
17
  # Model configurations
18
  MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
 
29
 
30
  print("Models loaded successfully!")
31
 
32
+ @spaces.GPU(duration=120)
33
  def translate_msa_to_syrian(text):
34
  """Translate from Modern Standard Arabic to Syrian dialect"""
35
  if not text.strip():
 
43
  except Exception as e:
44
  return f"Translation error: {str(e)}"
45
 
46
+ @spaces.GPU(duration=120)
47
  def translate_syrian_to_msa(text):
48
  """Translate from Syrian dialect to Modern Standard Arabic"""
49
  if not text.strip():