Sshubam committed on
Commit
4b47d17
·
verified ·
1 Parent(s): 00a469f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -1
app.py CHANGED
@@ -48,6 +48,36 @@ LANGUAGES = {
48
  "Bodo": "brx_Deva"
49
  }
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  @spaces.GPU
@@ -168,7 +198,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
168
  text_input
169
  ],
170
  outputs=text_output,
171
- fn=generate,
172
  cache_examples=True,
173
  examples_per_page=5
174
  )
 
48
  "Bodo": "brx_Deva"
49
  }
50
 
51
@spaces.GPU
def generate_for_examples(
    tgt_lang: str,
    message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    """Translate ``message`` into ``tgt_lang`` and return the decoded text.

    Non-streaming helper used to pre-compute cached Gradio example outputs.
    Builds a single-turn chat prompt, left-truncates it to the model's input
    budget, samples a continuation, and returns only the newly generated part.
    """
    # Single-turn conversation: one user message carrying the translation request.
    conversation = [
        {"role": "user", "content": f"Translate the following text to {tgt_lang}: {message}"}
    ]

    prompt_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    # Keep only the trailing MAX_INPUT_TOKEN_LENGTH tokens if the prompt is too long.
    if prompt_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        prompt_ids = prompt_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    prompt_ids = prompt_ids.to(model.device)

    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Strip the prompt prefix so only the model's continuation is decoded.
    return tokenizer.decode(generated[0][prompt_ids.shape[1]:], skip_special_tokens=True)
81
 
82
 
83
  @spaces.GPU
 
198
  text_input
199
  ],
200
  outputs=text_output,
201
+ fn=generate_for_examples,
202
  cache_examples=True,
203
  examples_per_page=5
204
  )