alakxender commited on
Commit
82b0ab8
·
1 Parent(s): 8a588ad
Files changed (3) hide show
  1. app.py +9 -5
  2. title_gen.py +4 -2
  3. typo_check.py +15 -8
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
 
3
- from typo_check import css, process_input
4
- from title_gen import generate_title, MODEL_OPTIONS
5
 
6
 
7
  # Create Gradio interface using the latest syntax
@@ -11,6 +11,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
11
  gr.Markdown("# <center>Dhivehi Typo Correction</center>")
12
  gr.Markdown("This app uses a fine-tuned T5 model to correct typos in Dhivehi text. Enter text with typos and the model will attempt to fix them.")
13
 
 
 
 
14
  with gr.Row():
15
  input_text = gr.Textbox(
16
  lines=1,
@@ -19,7 +22,8 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
19
  rtl=True,
20
  elem_classes="textbox1"
21
  )
22
-
 
23
  with gr.Row():
24
  corrected_text = gr.Textbox(
25
  lines=1,
@@ -37,7 +41,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
37
  submit_btn = gr.Button("ރަނގަޅު ކޮށްލުމަށް",elem_classes="textbox1") # "Correct" in Dhivehi
38
  submit_btn.click(
39
  fn=process_input,
40
- inputs=input_text,
41
  outputs=[corrected_text, highlighted_diff]
42
  )
43
 
@@ -78,7 +82,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
78
  with gr.Row():
79
  article_content = gr.Textbox(lines=10, label="Article Content", rtl=True, elem_classes="textbox1")
80
  with gr.Row():
81
- model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="V6 Model", label="Model")
82
  with gr.Row():
83
  seed = gr.Slider(0, 10000, value=42, step=1, label="Random Seed")
84
  use_sampling = gr.Checkbox(label="Use Sampling (Creative/Random)", value=False)
 
1
  import gradio as gr
2
 
3
+ from typo_check import css, process_input,MODEL_OPTIONS_TYPO
4
+ from title_gen import generate_title, MODEL_OPTIONS_TITLE
5
 
6
 
7
  # Create Gradio interface using the latest syntax
 
11
  gr.Markdown("# <center>Dhivehi Typo Correction</center>")
12
  gr.Markdown("This app uses a fine-tuned T5 model to correct typos in Dhivehi text. Enter text with typos and the model will attempt to fix them.")
13
 
14
+ with gr.Row():
15
+ model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS_TYPO.keys()), value="A3 Model", label="Model")
16
+
17
  with gr.Row():
18
  input_text = gr.Textbox(
19
  lines=1,
 
22
  rtl=True,
23
  elem_classes="textbox1"
24
  )
25
+
26
+
27
  with gr.Row():
28
  corrected_text = gr.Textbox(
29
  lines=1,
 
41
  submit_btn = gr.Button("ރަނގަޅު ކޮށްލުމަށް",elem_classes="textbox1") # "Correct" in Dhivehi
42
  submit_btn.click(
43
  fn=process_input,
44
+ inputs=[input_text,model_choice],
45
  outputs=[corrected_text, highlighted_diff]
46
  )
47
 
 
82
  with gr.Row():
83
  article_content = gr.Textbox(lines=10, label="Article Content", rtl=True, elem_classes="textbox1")
84
  with gr.Row():
85
+ model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS_TITLE.keys()), value="V6 Model", label="Model")
86
  with gr.Row():
87
  seed = gr.Slider(0, 10000, value=42, step=1, label="Random Seed")
88
  use_sampling = gr.Checkbox(label="Use Sampling (Creative/Random)", value=False)
title_gen.py CHANGED
@@ -2,9 +2,10 @@ import random
2
  import numpy as np
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
5
 
6
  # Available models
7
- MODEL_OPTIONS = {
8
  "V6 Model": "alakxender/t5-divehi-title-generation-v6",
9
  "XS Model": "alakxender/t5-dhivehi-title-generation-xs"
10
  }
@@ -28,6 +29,7 @@ prefix = "2title: "
28
  max_input_length = 512
29
  max_target_length = 32
30
 
 
31
  def generate_title(content, seed, use_sampling, model_choice):
32
  random.seed(seed)
33
  np.random.seed(seed)
@@ -35,7 +37,7 @@ def generate_title(content, seed, use_sampling, model_choice):
35
  if torch.cuda.is_available():
36
  torch.cuda.manual_seed_all(seed)
37
 
38
- model_dir = MODEL_OPTIONS[model_choice]
39
  tokenizer, model = get_model_and_tokenizer(model_dir)
40
 
41
  input_text = prefix + content.strip()
 
2
  import numpy as np
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import spaces
6
 
7
  # Available models
8
+ MODEL_OPTIONS_TITLE = {
9
  "V6 Model": "alakxender/t5-divehi-title-generation-v6",
10
  "XS Model": "alakxender/t5-dhivehi-title-generation-xs"
11
  }
 
29
  max_input_length = 512
30
  max_target_length = 32
31
 
32
+ @spaces.GPU()
33
  def generate_title(content, seed, use_sampling, model_choice):
34
  random.seed(seed)
35
  np.random.seed(seed)
 
37
  if torch.cuda.is_available():
38
  torch.cuda.manual_seed_all(seed)
39
 
40
+ model_dir = MODEL_OPTIONS_TITLE[model_choice]
41
  tokenizer, model = get_model_and_tokenizer(model_dir)
42
 
43
  input_text = prefix + content.strip()
typo_check.py CHANGED
@@ -5,19 +5,25 @@ import difflib
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
  import gradio as gr
 
8
 
9
- # Load the fine-tuned model and tokenizer
10
- MODEL_PATH = "alakxender/dhivehi-quick-spell-check-t5" # Change this to your model path if different
 
 
 
11
 
12
  # Function to load model and tokenizer
13
- def load_model():
14
  print("Loading model and tokenizer...")
15
  try:
16
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
 
17
  if tokenizer.pad_token is None:
18
  tokenizer.pad_token = tokenizer.eos_token
19
 
20
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
21
 
22
  # Move model to GPU if available
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -67,7 +73,7 @@ def correct_typo(text, model, tokenizer, device):
67
  return f"Error: {str(e)}"
68
 
69
  # Initialize model and tokenizer
70
- model, tokenizer, device = load_model()
71
 
72
  if model is None:
73
  print("Failed to load model. Please check your model and tokenizer paths.")
@@ -103,9 +109,10 @@ def highlight_differences(original, corrected):
103
  return f'<div class="dhivehi-diff">{" ".join(html_parts)}</div>'
104
 
105
  # Function to process the input for Gradio
106
- def process_input(text):
 
107
  if model is None:
108
- load_model()
109
 
110
  corrected = correct_typo(text, model, tokenizer, device)
111
  highlighted = highlight_differences(text, corrected)
 
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
  import gradio as gr
8
+ import spaces
9
 
10
+ # Available models
11
+ MODEL_OPTIONS_TYPO = {
12
+ "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
13
+ "XS Model": "alakxender/dhivehi-quick-spell-check-t5"
14
+ }
15
 
16
  # Function to load model and tokenizer
17
+ def load_model(model_choice):
18
  print("Loading model and tokenizer...")
19
  try:
20
+ selected_model = MODEL_OPTIONS_TYPO[model_choice]
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(selected_model)
23
  if tokenizer.pad_token is None:
24
  tokenizer.pad_token = tokenizer.eos_token
25
 
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)
27
 
28
  # Move model to GPU if available
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
73
  return f"Error: {str(e)}"
74
 
75
  # Initialize model and tokenizer
76
+ model, tokenizer, device = load_model("A3 Model")
77
 
78
  if model is None:
79
  print("Failed to load model. Please check your model and tokenizer paths.")
 
109
  return f'<div class="dhivehi-diff">{" ".join(html_parts)}</div>'
110
 
111
  # Function to process the input for Gradio
112
+ @spaces.GPU()
113
+ def process_input(text,model_choice):
114
  if model is None:
115
+ load_model(model_choice)
116
 
117
  corrected = correct_typo(text, model, tokenizer, device)
118
  highlighted = highlight_differences(text, corrected)