# Author: 24Sureshkumar
# Last commit: 7b64e9f "Update app.py" (verified)
# File size: 5.63 kB
# app.py
# Install the required libraries before running this script
# (MarianTokenizer additionally needs sentencepiece):
# pip install transformers gradio Pillow requests torch sentencepiece
import os
import requests
from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer
from PIL import Image, ImageDraw
import io
import gradio as gr
import torch
# Run inference on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the MarianMT model and tokenizer for translation. "opus-mt-mul-en" is a
# multilingual-to-English checkpoint; this app uses it for Tamil -> English.
model_name = "Helsinki-NLP/opus-mt-mul-en"
translation_model = MarianMTModel.from_pretrained(model_name).to(device)
translation_tokenizer = MarianTokenizer.from_pretrained(model_name)

# Load GPT-Neo for creative text generation.
text_generation_model_name = "EleutherAI/gpt-neo-1.3B"
text_generation_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device)
text_generation_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)

# GPT-Neo's tokenizer has no padding token. Reuse the EOS token instead of
# adding a brand-new '[PAD]' token: a newly added token gets an id past the end
# of the model's embedding matrix (the embeddings are never resized here), so
# any input that actually got padded would index out of range.
if text_generation_tokenizer.pad_token is None:
    text_generation_tokenizer.pad_token = text_generation_tokenizer.eos_token
# Read the Hugging Face API key used for the hosted image-generation endpoint.
# Set it as an environment variable before running:
#   export HF_API_KEY='your_actual_api_key'   # Linux/macOS
#   setx HF_API_KEY "your_actual_api_key"     # Windows cmd (open a new terminal afterwards)
api_key = os.getenv('HF_API_KEY')
# Treat an empty value like a missing one: with `HF_API_KEY=` the old `is None`
# check passed and every request went out with a bare "Bearer " header, which
# presumably the API rejects. Failing fast at startup is clearer.
if not api_key:
    raise ValueError("Hugging Face API key is not set. Please set it as environment variable HF_API_KEY.")
headers = {"Authorization": f"Bearer {api_key}"}

# Hugging Face Inference API endpoint used for image generation. Point this at a
# different model's URL to use another text-to-image model.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
# Query Hugging Face API to generate image with error handling
def query(payload, timeout=120):
    """POST *payload* to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection and freeze the Gradio handler that called it.

    Returns:
        The response body as ``bytes`` when the status code is 200, otherwise
        ``None``. Network failures (connection errors, timeouts) are reported
        and also yield ``None``, so callers need only one failure check.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: Request to {API_URL} failed: {e}")
        return None
    if response.status_code != 200:
        print(f"Error: Received status code {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.content
# Translate Tamil text to English
def translate_text(tamil_text):
    """Translate *tamil_text* into English using the MarianMT model.

    The text is tokenized (with padding/truncation enabled) into PyTorch
    tensors on ``device``. It is then run through ``generate``, and the first
    output sequence is decoded with special tokens stripped.
    """
    encoded = translation_tokenizer(
        tamil_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    output_ids = translation_model.generate(**encoded)
    return translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Generate an image based on the translated text with error handling
def _error_image(message):
    """Return a 300x300 red placeholder image with *message* drawn in white."""
    error_img = Image.new('RGB', (300, 300), color=(255, 0, 0))
    ImageDraw.Draw(error_img).text((10, 150), message, fill=(255, 255, 255))
    return error_img


def generate_image(prompt):
    """Generate an image for *prompt* via the remote image API.

    Args:
        prompt: Text prompt sent as ``{"inputs": prompt}`` to ``query``.

    Returns:
        A PIL image. If the API call fails, or the returned bytes cannot be
        decoded as an image, a red placeholder image is returned instead, with
        "Image Generation Failed" or "Invalid Image Data" respectively.
    """
    image_bytes = query({"inputs": prompt})
    if image_bytes is None:
        return _error_image("Image Generation Failed")
    # This is a UI boundary: any decode failure becomes a placeholder image
    # instead of an exception in the Gradio handler.
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open() is lazy and only parses the header. Force a full decode
        # here so truncated/corrupt data is caught by this handler rather than
        # failing later when the image is actually read.
        image.load()
    except Exception as e:
        print(f"Error: {e}")
        return _error_image("Invalid Image Data")
    return image
# Generate creative text based on the translated English text
def generate_creative_text(translated_text, max_new_tokens=100):
    """Continue *translated_text* with GPT-Neo and return the decoded result.

    Args:
        translated_text: English prompt (normally the output of translate_text).
        max_new_tokens: Number of tokens to generate beyond the prompt.
            ``max_length`` would count the prompt tokens too, so a prompt of
            ~100 tokens or more left no room for any new text.

    Returns:
        The decoded first output sequence with special tokens removed. For a
        decoder-only model like GPT-Neo, ``generate`` returns the prompt tokens
        followed by the new ones, so the text starts with the prompt.
    """
    inputs = text_generation_tokenizer(translated_text, return_tensors="pt", padding=True, truncation=True).to(device)
    generated_tokens = text_generation_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # GPT-Neo has no dedicated pad token; give generate() the EOS id
        # explicitly rather than relying on its fallback.
        pad_token_id=text_generation_tokenizer.eos_token_id,
    )
    creative_text = text_generation_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return creative_text
# Function to handle the full workflow
def translate_generate_image_and_text(tamil_text):
    """Run the full pipeline: translate, then generate an image and creative text.

    Args:
        tamil_text: Tamil input from the Gradio textbox (may be empty or None).

    Returns:
        A ``(translated_text, creative_text, image)`` tuple matching the three
        Gradio output components. For empty or whitespace-only input, no model
        or remote API is called and ``("", "", None)`` is returned.
    """
    # Nothing to translate: skip both models and the (network, paid) image
    # API call instead of feeding them an empty prompt.
    if tamil_text is None or not tamil_text.strip():
        return "", "", None
    # Step 1: Translate Tamil to English
    translated_text = translate_text(tamil_text)
    # Step 2: Generate an image from the translated text
    image = generate_image(translated_text)
    # Step 3: Generate creative text from the translated text
    creative_text = generate_creative_text(translated_text)
    return translated_text, creative_text, image
# CSS styling for the Gradio app, passed to gr.Blocks(css=...).
# The #transart-title / #transart-subtitle rules target the <div> ids used in
# title_markdown; the remaining rules set the page background and container font.
css = """
#transart-title {
font-size: 2.5em;
font-weight: bold;
color: #4CAF50;
text-align: center;
margin-bottom: 10px;
}
#transart-subtitle {
font-size: 1.25em;
text-align: center;
color: #555555;
margin-bottom: 20px;
}
body {
background-color: #f0f0f5;
}
.gradio-container {
font-family: 'Arial', sans-serif;
}
"""
# Title and subtitle HTML for Gradio markdown; rendered by gr.Markdown at the top
# of the interface. The div ids are the hooks for the CSS rules above.
title_markdown = """
# <div id="transart-title">TransArt</div>
### <div id="transart-subtitle">Tamil to English Translation, Creative Text & Image Generation</div>
"""
# Build the Gradio UI: a row with the Tamil input column and the output column
# (translation, creative text, image), plus a button that runs the pipeline.
with gr.Blocks(css=css) as interface:
    gr.Markdown(title_markdown)
    with gr.Row():
        with gr.Column():
            tamil_input = gr.Textbox(label="Enter Tamil Text", placeholder="Type Tamil text here...", lines=3)
        with gr.Column():
            translated_output = gr.Textbox(label="Translated Text", interactive=False)
            creative_text_output = gr.Textbox(label="Creative Generated Text", interactive=False)
            generated_image_output = gr.Image(label="Generated Image")
    # Output order must match the tuple returned by translate_generate_image_and_text.
    pipeline_outputs = [translated_output, creative_text_output, generated_image_output]
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=translate_generate_image_and_text,
        inputs=tamil_input,
        outputs=pipeline_outputs,
    )
if __name__ == "__main__":
    # server_name="0.0.0.0" binds to all network interfaces so the app is
    # reachable from outside the local machine/container; debug=True runs
    # Gradio in its debug mode.
    launch_options = {"debug": True, "server_name": "0.0.0.0"}
    interface.launch(**launch_options)