Spaces:
Running
Running
| import gradio as gr | |
| from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator | |
| from transformers import VisionEncoderDecoderModel, AutoProcessor, AutoTokenizer | |
| from PIL import Image | |
| import torch | |
| from huggingface_hub import snapshot_download | |
# Hugging Face Hub id of the vision-encoder/text-decoder OCR checkpoint. It is
# used both for the up-front snapshot download and for loading the model.
OCR_MODEL_REPO = "QuickHawk/trocr-indic"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fetch the whole model repository into the local Hugging Face cache.
snapshot_download(repo_id=OCR_MODEL_REPO)

# Image preprocessing is taken from the DeiT encoder checkpoint and text
# tokenisation from the IndicBART decoder checkpoint.
ENCODER_MODEL_NAME = "facebook/deit-base-distilled-patch16-224"
DECODER_MODEL_NAME = "ai4bharat/IndicBART"

processor = AutoProcessor.from_pretrained(ENCODER_MODEL_NAME, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_NAME, use_fast=True)

model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL_REPO)
model.to(device)
# Two-letter language code (as parsed from the model's "<2xx>" language tag)
# -> English display name shown in the "Predicted Language" output.
LANG_MAP = dict(
    [
        ("as", "Assamese"),
        ("bn", "Bengali"),
        ("gu", "Gujarati"),
        ("hi", "Hindi"),
        ("kn", "Kannada"),
        ("ml", "Malayalam"),
        ("mr", "Marathi"),
        ("or", "Odia"),
        ("pa", "Punjabi"),
        ("ta", "Tamil"),
        ("te", "Telugu"),
        ("ur", "Urdu"),
    ]
)
# Vocabulary ids of the decoder's sequence-start, sequence-end and padding
# tokens, passed explicitly to `model.generate` in `predict`.
# Uses the public `convert_tokens_to_ids` API rather than the private
# `_convert_token_to_id_with_added_voc` helper. The public method delegates to
# that helper, so the ids are unchanged, but it is not tied to a private name
# that may change between transformers releases.
bos_id, eos_id, pad_id = tokenizer.convert_tokens_to_ids(["<s>", "</s>", "<pad>"])
def lang_code_from_tag(token):
    """Extract the language code from a language tag such as ``"<2hi>"``.

    The decoder is started from the ``<2en>`` tag, and the token it generates
    next is read as a tag of the same ``<2xx>`` shape.

    Args:
        token: Decoded token text.

    Returns:
        The text between ``<2`` and ``>`` (e.g. ``"hi"``), or ``None`` when
        *token* does not have that shape (for example ``"</s>"`` or an
        ordinary word piece).
    """
    token = token.strip()
    if len(token) > 3 and token.startswith("<2") and token.endswith(">"):
        return token[2:-1]
    return None


def predict(image):
    """Recognise the text in *image* and report the language it is in.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A ``(text, language_name)`` pair of strings:

        - ``text`` is the decoded caption, transliterated from the ``"hi"``
          script into the script of the detected language.
        - ``language_name`` is that language's entry in ``LANG_MAP``.

        Both strings are empty when *image* is ``None``. If the generated
        language tag is missing or not in ``LANG_MAP``, the caption is
        returned untransliterated and the label is ``"Unknown (<token>)"``.
        Previously this case raised ``KeyError``.
    """
    if image is None:
        return "", ""

    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            use_cache=True,
            num_beams=4,
            max_length=128,
            min_length=1,
            early_stopping=True,
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            # Decoding starts from the <2en> tag; the token generated right
            # after it (index 1) is read as the language tag below.
            decoder_start_token_id=tokenizer.convert_tokens_to_ids("<2en>"),
        )

    sequence = output_ids[0]
    lang_token = tokenizer.decode(sequence[1]) if len(sequence) > 1 else ""
    lang = lang_code_from_tag(lang_token)
    caption = tokenizer.decode(sequence, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    if lang not in LANG_MAP:
        # The model did not emit a recognised language tag, so there is no
        # target script to transliterate into.
        return caption, f"Unknown ({lang_token})"
    return UnicodeIndicTransliterator.transliterate(caption, "hi", lang), LANG_MAP[lang]
# Web UI: one PIL image in, two text boxes out (recognised text and detected
# language), both filled from the (text, language_name) pair `predict` returns.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Text(label="Predicted Text"),
        gr.Text(label="Predicted Language"),
    ],
)
demo.launch()