Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM | |
| from aksharamukha import transliterate | |
| import torch | |
# Pick the compute device once; the locally loaded models below are moved to it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint identifiers used by the translation stages.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
_MT5_SI_EN_CHECKPOINT = "thilina/mt5-sinhalese-english"
_SINGLISH_CHECKPOINT = "Dhahlan2000/Simple_Translation-model-for-GPT-v14"

# English -> Sinhala: seq2seq model wrapped in a translation pipeline.
trans_model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT).to(device)
eng_trans_tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT)
translator = pipeline(
    "translation",
    model=trans_model,
    tokenizer=eng_trans_tokenizer,
    src_lang="eng_Latn",
    tgt_lang="sin_Sinh",
    max_length=400,
    device=device,
)

# Sinhala -> English: model and tokenizer are driven manually (no pipeline).
sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained(_MT5_SI_EN_CHECKPOINT).to(device)
si_trans_tokenizer = AutoTokenizer.from_pretrained(_MT5_SI_EN_CHECKPOINT)

# Singlish -> Sinhala: text2text pipeline built directly from the checkpoint id
# (no explicit device is passed here).
singlish_pipe = pipeline("text2text-generation", model=_SINGLISH_CHECKPOINT)
| # Translation functions | |
def translate_Singlish_to_sinhala(text):
    """Run ``text`` through the Singlish-to-Sinhala text2text pipeline.

    The input is prefixed with the task instruction
    ``"translate Singlish to Sinhala: "`` before being handed to
    ``singlish_pipe``.

    Returns:
        The ``generated_text`` field of the pipeline's first result.
    """
    prompt = f"translate Singlish to Sinhala: {text}"
    results = singlish_pipe(prompt, clean_up_tokenization_spaces=False)
    return results[0]['generated_text']
def translate_english_to_sinhala(text):
    """Translate English text to Sinhala, one line at a time.

    The text is split on ``"\\n"`` and each non-blank line is passed to the
    module-level ``translator`` pipeline; the results are re-joined with
    ``"\\n"`` so the line structure of the input is kept.

    Blank (empty or whitespace-only) lines are not sent to the model: there
    is nothing to translate, and a generative model can only invent text for
    them. They are emitted as empty lines instead.

    Args:
        text: English text, possibly spanning several lines.

    Returns:
        The Sinhala translation with one output line per input line.
    """
    translated_parts = []
    for part in text.split("\n"):
        if not part.strip():
            # Nothing to translate; keep the line so the layout survives.
            translated_parts.append("")
            continue
        result = translator(part, clean_up_tokenization_spaces=False)
        translated_parts.append(result[0]['translation_text'])

    # Remove one specific Sinhala phrase from the output.
    # NOTE(review): presumably a known spurious model output - confirm.
    return "\n".join(translated_parts).replace("ප් රභූවරුන්", "")
def translate_sinhala_to_english(text):
    """Translate Sinhala text to English, one line at a time.

    The text is split on ``"\\n"``; each line is stripped, tokenized
    (truncated to 512 tokens), run through ``sin_trans_model.generate`` and
    decoded without special tokens. Results are re-joined with ``"\\n"``.

    Lines that are empty after stripping are not sent to the model (it would
    have to generate text from an empty source); they are emitted as empty
    lines so the output keeps one line per input line.

    Args:
        text: Sinhala text, possibly spanning several lines.

    Returns:
        The English translation with one output line per input line.
    """
    translated_parts = []
    for part in text.split("\n"):
        source = part.strip()
        if not source:
            # Blank line: nothing to translate, preserve the line break only.
            translated_parts.append("")
            continue
        inputs = si_trans_tokenizer(
            source,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        # NOTE(review): no generation length is given here, so the output
        # length is bounded by the model/library defaults - confirm that is
        # long enough for full lines.
        outputs = sin_trans_model.generate(**inputs)
        translated_parts.append(
            si_trans_tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        )
    return "\n".join(translated_parts)
def transliterate_from_sinhala(text):
    """Romanize Sinhala-script ``text`` into lowercase Latin characters.

    The text is converted from the 'Sinhala' scheme to the 'Velthuis' scheme,
    then every ``.``, ``*`` and ``"`` character is removed and the result is
    lower-cased.
    """
    velthuis_text = transliterate.process('Sinhala', 'Velthuis', text)
    # Delete the three marker characters in a single pass.
    without_markers = velthuis_text.translate(str.maketrans("", "", '.*"'))
    return without_markers.lower()
def transliterate_to_sinhala(text):
    """Convert ``text`` from the 'Velthuis' scheme to the 'Sinhala' script."""
    sinhala_text = transliterate.process('Velthuis', 'Sinhala', text)
    return sinhala_text
# Conversation model: the tokenizer and causal LM are loaded from the same
# checkpoint id, and the model is then moved to the selected device.
conv_model_name = "google/gemma-7b"
tokenizer = AutoTokenizer.from_pretrained(conv_model_name)
model = AutoModelForCausalLM.from_pretrained(conv_model_name)
model = model.to(device)
def conversation_predict(text, max_new_tokens=128):
    """Generate the conversation model's continuation of ``text``.

    Args:
        text: Prompt passed to the module-level ``tokenizer`` and ``model``.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The decoded text of the newly generated tokens only: the prompt is
        not echoed back and special tokens are omitted.
    """
    encoded = tokenizer(text, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    # Bound the number of *new* tokens explicitly. Without it, generation
    # falls back to the library's default total-length limit, which counts
    # the prompt tokens too and leaves little or no room for a reply.
    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)

    # The sequence returned by a causal LM starts with the prompt tokens;
    # drop them so callers receive only the model's answer.
    reply_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=True)
def ai_predicted(user_input):
    """Produce the chatbot reply for one user message.

    The literal message ``exit`` (any letter case) short-circuits to
    ``"Goodbye!"``. Otherwise the message goes through, in order:
    ``translate_Singlish_to_sinhala`` -> ``transliterate_to_sinhala`` ->
    ``translate_sinhala_to_english`` -> ``conversation_predict``; the part of
    the model output after the last ``"</s>"`` is then passed through
    ``translate_english_to_sinhala`` -> ``transliterate_from_sinhala``.

    Returns:
        The final transliterated reply string.
    """
    if user_input.lower() == 'exit':
        return "Goodbye!"

    # Inbound chain: user text -> English prompt for the conversation model.
    english_prompt = translate_sinhala_to_english(
        transliterate_to_sinhala(translate_Singlish_to_sinhala(user_input))
    )
    model_output = conversation_predict(english_prompt)

    # Keep only what follows the last "</s>" marker (the whole string when
    # the marker is absent).
    final_segment = model_output.rpartition("</s>")[2]

    # Outbound chain: English reply -> Sinhala -> romanized text.
    return transliterate_from_sinhala(translate_english_to_sinhala(final_segment))
| # Gradio Interface | |
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Chat callback: yield the reply to ``message``.

    Only ``message`` influences the reply; it is forwarded to
    ``ai_predicted``. ``history``, ``system_message``, ``max_tokens``,
    ``temperature`` and ``top_p`` are accepted so the signature matches the
    inputs the chat interface supplies, but they are not used.

    The previous implementation also assembled a role/content ``messages``
    list from ``history`` that was never read; it was removed because it had
    no effect on the result and could raise while indexing history entries
    that are not (user, assistant) pairs.

    Yields:
        The reply string, exactly once.
    """
    yield ai_predicted(message)
# Chat UI around `respond`. The additional inputs are listed in the same order
# as the parameters `respond` declares after (message, history).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()