# -*- coding: utf-8 -*-
"""gen ai project f.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1iF7hdOjWNeFUtGvUYdaFsBErJGnY1h5J
"""
import os
import asyncio
from huggingface_hub import login
from transformers import MarianMTModel, MarianTokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
import aiohttp
import io
from PIL import Image
import matplotlib.pyplot as plt
import gradio as gr
# Retrieve the token from the environment variable
hf_token = os.getenv("HF_TOKEN")

# Check that the token was retrieved properly
if hf_token:
    # Use the retrieved token
    login(token=hf_token, add_to_git_credential=True)
else:
    raise ValueError("Hugging Face token not found in environment variables.")
# Load the translation model and tokenizer (cached for faster loading)
model_name = "Helsinki-NLP/opus-mt-mul-en"
tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir="./cache")
model = MarianMTModel.from_pretrained(model_name, cache_dir="./cache")
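# Note: opus-mt-mul-en is a multilingual-to-English MarianMT model, so the same
# checkpoint handles Tamil input without needing a language-specific model.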
# Create a translation pipeline
translator = pipeline("translation", model=model, tokenizer=tokenizer)
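# The pipeline returns a list with one dict per input; translate_text() below reads
# result[0]['translation_text'] from that structure.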
# Load GPT-Neo model for creative text generation (cached)
gpt_neo_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M", cache_dir="./cache")
gpt_neo_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", cache_dir="./cache")
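# gpt-neo-125M is the smallest GPT-Neo checkpoint, presumably chosen here to keep
# text generation responsive on modest hardware.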
# API credentials and endpoint for image generation
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
headers = {"Authorization": f"Bearer {hf_token}"}
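# The serverless Inference API is called with a JSON payload of the form
# {"inputs": "<prompt>"} and, on success, returns the generated image as raw bytes
# (see generate_image_async below).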
# Translate Tamil text to English
def translate_text(tamil_text):
    try:
        translation = translator(tamil_text, max_length=40)
        translated_text = translation[0]['translation_text']
        return translated_text
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Asynchronous function to send the prompt to the inference API and return the image
async def generate_image_async(prompt):
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(API_URL, headers=headers, json={"inputs": prompt}) as response:
                if response.status == 200:
                    print("API call successful, generating image...")
                    image_bytes = await response.read()
                    # Try to decode the response body as an image
                    try:
                        image = Image.open(io.BytesIO(image_bytes))
                        return image
                    except Exception as e:
                        print(f"Error opening image: {e}")
                        return None
                else:
                    print(f"Failed to get image: Status code {response.status}")
                    return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
# Generate creative text based on the translated text
def generate_creative_text(translated_text, max_length=50):
    input_ids = gpt_neo_tokenizer(translated_text, return_tensors='pt').input_ids
    generated_text_ids = gpt_neo_model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,  # sample rather than greedy-decode, so outputs vary between calls
        top_k=50
    )
    creative_text = gpt_neo_tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
    return creative_text
# Handle the full workflow: translate, generate an image, and generate creative text
async def translate_generate_image_and_text(tamil_text):
    # Step 1: Translate Tamil text to English
    translated_text = translate_text(tamil_text)

    # Step 2: Generate an image based on the translated text asynchronously
    image = await generate_image_async(translated_text)

    # Step 3: Generate creative text based on the translated text
    creative_text = generate_creative_text(translated_text)

    return translated_text, creative_text, image
# Display an image with matplotlib (debugging helper; not called by the Gradio interface)
def show_image(image):
    if image:
        plt.imshow(image)
        plt.axis('off')  # Hide axes
        plt.show()
    else:
        print("No image to display")
# Create the Gradio interface with live updates for faster feedback
interface = gr.Interface(
    fn=lambda tamil_text: asyncio.run(translate_generate_image_and_text(tamil_text)),
    inputs="text",
    outputs=["text", "text", "image"],
    title="Optimized Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation.",
    live=True  # Enables real-time outputs for faster feedback
)

# Launch the Gradio app
interface.launch()
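# To run locally, set the HF_TOKEN environment variable and execute `python app.py`;
# on Hugging Face Spaces, app.py is started automatically and HF_TOKEN can be
# provided as a repository secret.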