import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from aksharamukha import transliterate
import torch
from dotenv import load_dotenv
import os

# Load the Hugging Face access token from the environment (.env file)
load_dotenv()
access_token = os.getenv('token')
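# Assumed .env layout (placeholder value, not a real token):
#   token=hf_xxxxxxxxxxxxxxxxxxxx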
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# English -> Sinhala translation (NLLB-200)
chat_language = 'sin_Sinh'
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
eng_trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator = pipeline('translation', model=trans_model, tokenizer=eng_trans_tokenizer, src_lang="eng_Latn", tgt_lang=chat_language, max_length=400, device=device)
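# Usage sketch (standard translation-pipeline output shape):
#   translator("Hello, how are you?")[0]['translation_text']  # -> Sinhala string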
# Initialize translation models
# Sinhala -> English translation (mT5 fine-tune); the model/tokenizer pair is
# used directly in translate_sinhala_to_english below
sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained("thilina/mt5-sinhalese-english")
si_trans_tokenizer = AutoTokenizer.from_pretrained("thilina/mt5-sinhalese-english")

# Singlish (romanized Sinhala) -> Sinhala translation
singlish_pipe = pipeline("text2text-generation", model="Dhahlan2000/Simple_Translation-model-for-GPT-v15")
# Translation functions
def translate_Singlish_to_sinhala(text):
    translated_text = singlish_pipe(f"translate Singlish to Sinhala: {text}", clean_up_tokenization_spaces=False)[0]['generated_text']
    # Strip zero-width joiners that the model sometimes emits
    return translated_text.replace('\u200d', '')
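# The "translate Singlish to Sinhala: " prefix is assumed to be the prompt
# format this fine-tune was trained on; check the model card if outputs look off.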
def translate_english_to_sinhala(text):
    # Split the text into sentences or paragraphs
    parts = text.split("\n")  # Split by new lines for paragraphs, adjust as needed
    translated_parts = []
    for part in parts:
        translated_part = translator(part, clean_up_tokenization_spaces=False)[0]['translation_text']
        translated_parts.append(translated_part)
    # Join the translated parts back together
    translated_text = "\n".join(translated_parts)
    return translated_text.replace("ප් රභූවරුන්", "").replace('\u200d', '')
def translate_sinhala_to_english(text):
    # Split the text into sentences or paragraphs
    parts = text.split("\n")  # Split by new lines for paragraphs, adjust as needed
    translated_parts = []
    for part in parts:
        # Tokenize each part
        inputs = si_trans_tokenizer(part.strip(), return_tensors="pt", padding=True, truncation=True, max_length=512)
        # Generate translation
        outputs = sin_trans_model.generate(**inputs)
        # Decode translated text while preserving formatting
        translated_part = si_trans_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        translated_parts.append(translated_part)
    # Join the translated parts back together
    translated_text = "\n".join(translated_parts)
    return translated_text
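# Design note: chunking on "\n" keeps each generate() call within the 512-token
# truncation limit for typical chat turns; a very long single-line input would
# still be truncated silently.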
def transliterate_from_sinhala(text):
    # Define the source and target scripts
    source_script = 'Sinhala'
    target_script = 'Velthuis'
    # Perform transliteration
    latin_text = transliterate.process(source_script, target_script, text)
    # Strip the '.', '*' and '"' markers for readability
    for marker in ('.', '*', '"'):
        latin_text = latin_text.replace(marker, '')
    return latin_text.lower()
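# In the Velthuis scheme these characters act as notation markers (e.g. '.'
# prefixes retroflex consonants), so stripping them trades phonetic precision
# for chat-friendly output.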
def transliterate_to_sinhala(text):
    # Define the source and target scripts
    source_script = 'Velthuis'
    target_script = 'Sinhala'
    # Perform transliteration
    sinhala_text = transliterate.process(source_script, target_script, text)
    return sinhala_text
# Chat model (Gemma 2B, instruction-tuned)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.bfloat16,
    token=access_token,
)
def conversation_predict(input_text):
    inputs = tokenizer(input_text, return_tensors="pt")
    # Allow a reasonably long reply (generate() defaults to ~20 total tokens)
    outputs = model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
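# Note: gemma-2b-it is tuned for a chat format; a minimal sketch using the
# tokenizer's built-in template (assumes a transformers release that ships
# apply_chat_template):
#   chat = [{"role": "user", "content": input_text}]
#   prompt_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
#   outputs = model.generate(prompt_ids, max_new_tokens=256)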
def ai_predicted(user_input):
    """Full pipeline: normalize the Singlish input, render it in Sinhala script,
    translate to English, query the chat model, then map the reply back to
    romanized Sinhala."""
    user_input = translate_Singlish_to_sinhala(user_input)
    print("You(Singlish): ", user_input, "\n")
    user_input = transliterate_to_sinhala(user_input)
    print("You(Sinhala): ", user_input, "\n")
    user_input = translate_sinhala_to_english(user_input)
    print("You(English): ", user_input, "\n")
    # Get AI response
    ai_response = conversation_predict(user_input)
    print("AI(English): ", ai_response, "\n")
    response = translate_english_to_sinhala(ai_response)
    print("AI(Sinhala): ", response, "\n")
    response = transliterate_from_sinhala(response)
    print(response)
    return response
# Gradio Interface
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Build an OpenAI-style message list from the chat history; note that the
    # pipeline currently uses only the latest message, so the history, system
    # message, and sampling controls are accepted but not yet forwarded
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    response = ai_predicted(message)
    yield response
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch(share=True)
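# share=True requests a temporary public Gradio URL in addition to the local
# server; drop it to keep the app local-only.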