File size: 5,631 Bytes
7b64e9f
 
 
 
 
 
 
 
 
 
e97aebb
7b64e9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59921cd
7b64e9f
 
a558492
7b64e9f
 
 
 
 
 
 
 
e5964e8
7b64e9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87e851b
7b64e9f
 
 
 
 
 
 
 
 
e97aebb
7b64e9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e97aebb
7b64e9f
 
 
 
 
ac06315
7b64e9f
 
 
 
 
 
 
 
 
 
ac06315
7b64e9f
 
 
 
 
e97aebb
87e851b
7b64e9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# app.py

# Install the required libraries before running this script:
# pip install transformers gradio Pillow requests torch

import os
import requests
from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer
from PIL import Image, ImageDraw
import io
import gradio as gr
import torch

# Select the compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the MarianMT model and tokenizer used for translation.
# NOTE(review): the checkpoint name suggests a multilingual-to-English model;
# this app feeds it Tamil text.
model_name = "Helsinki-NLP/opus-mt-mul-en"
translation_model = MarianMTModel.from_pretrained(model_name).to(device)
translation_tokenizer = MarianTokenizer.from_pretrained(model_name)

# Load GPT-Neo for creative text generation
text_generation_model_name = "EleutherAI/gpt-neo-1.3B"
text_generation_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device)
text_generation_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)

# If the tokenizer has no padding token, reuse the end-of-sequence token.
# Registering a brand-new '[PAD]' token would assign it an id past the end of the
# model's embedding matrix (the embeddings are never resized here), so any input
# that actually got padded would index out of range inside the model.
if text_generation_tokenizer.pad_token is None:
    text_generation_tokenizer.pad_token = text_generation_tokenizer.eos_token

# Read the Hugging Face API key from the environment. Set it before running:
# export HF_API_KEY='your_actual_api_key'  # on Linux/Mac
# setx HF_API_KEY "your_actual_api_key"    # on Windows cmd (restart terminal after)
# Surrounding whitespace is stripped so a stray newline cannot end up in the
# Authorization header, and an empty value is treated the same as a missing one
# (it would otherwise produce a "Bearer " header with no token).
api_key = (os.getenv('HF_API_KEY') or '').strip()
if not api_key:
    raise ValueError("Hugging Face API key is not set. Please set it as environment variable HF_API_KEY.")

# Authorization header sent with every image-generation request.
headers = {"Authorization": f"Bearer {api_key}"}

# Inference endpoint used for image generation; point this at another model's
# endpoint to switch image generators.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

# Query Hugging Face API to generate image with error handling
def query(payload):
    """POST *payload* as JSON to ``API_URL`` and return the raw response body.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        The response content (bytes) when the server answers with HTTP 200,
        otherwise ``None``. ``None`` is also returned when the request itself
        fails (connection error, timeout, ...), so callers only have one
        failure signal to check.
    """
    try:
        # Bound the wait: without a timeout, requests may block indefinitely
        # and hang the UI handler that called us.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        print(f"Error: Request failed: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error: Received status code {response.status_code}")
        print(f"Response: {response.text}")
        return None
    return response.content

# Translate Tamil text to English
def translate_text(tamil_text):
    """Translate *tamil_text* with the module-level MarianMT model.

    The text is tokenized (padded and truncated by the tokenizer), moved to
    ``device``, run through ``translation_model.generate``, and the first
    generated sequence is decoded with special tokens removed.

    Returns:
        The decoded translation string.
    """
    encoded = translation_tokenizer(
        tamil_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    encoded = encoded.to(device)

    output_ids = translation_model.generate(**encoded)
    first_sequence = output_ids[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)

# Generate an image based on the translated text with error handling
def _error_image(message):
    """Return a 300x300 red placeholder image with *message* drawn in white."""
    placeholder = Image.new('RGB', (300, 300), color=(255, 0, 0))
    ImageDraw.Draw(placeholder).text((10, 150), message, fill=(255, 255, 255))
    return placeholder


def generate_image(prompt):
    """Generate an image for *prompt* via the remote inference endpoint.

    Args:
        prompt: Text sent as the ``"inputs"`` field of the request.

    Returns:
        A PIL image. When the request fails or the response cannot be decoded
        as an image, a red placeholder image carrying an error message is
        returned instead, so the caller always gets something displayable.
    """
    image_bytes = query({"inputs": prompt})
    if image_bytes is None:
        return _error_image("Image Generation Failed")

    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy and only identifies the file; force the pixel data
        # to be decoded here so corrupt payloads are caught by this handler
        # rather than surfacing later when the image is displayed.
        image.load()
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any decoding failure
        # should degrade to the placeholder rather than crash the request.
        print(f"Error: {e}")
        return _error_image("Invalid Image Data")
    return image

# Generate creative text based on the translated English text
def generate_creative_text(translated_text):
    """Continue *translated_text* with the module-level GPT-Neo model.

    Args:
        translated_text: English prompt to feed to the text-generation model.

    Returns:
        The first generated sequence decoded to a string, special tokens removed.
        NOTE(review): causal-LM ``generate`` normally returns prompt plus
        continuation, so the result presumably begins with the prompt - confirm.
    """
    inputs = text_generation_tokenizer(translated_text, return_tensors="pt", padding=True, truncation=True).to(device)
    generated_tokens = text_generation_model.generate(
        **inputs,
        # Budget the *new* tokens. ``max_length`` also counts the prompt, so a
        # prompt of 100+ tokens would leave no room for any generated text.
        max_new_tokens=100,
        # Name the padding id explicitly (EOS is always within the model's
        # vocabulary) instead of relying on generate()'s implicit fallback.
        pad_token_id=text_generation_tokenizer.eos_token_id,
    )
    creative_text = text_generation_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return creative_text

# Function to handle the full workflow
def translate_generate_image_and_text(tamil_text):
    """Run the whole pipeline for one piece of input text.

    The input is translated first; the translation is then used as the prompt
    for image generation and, after that, for creative text generation.

    Returns:
        A ``(translated_text, creative_text, image)`` tuple, in the order the
        Gradio outputs expect.
    """
    english = translate_text(tamil_text)

    # Both generators consume the same English prompt; the image is requested
    # before the text, matching the original call order.
    picture = generate_image(english)
    story = generate_creative_text(english)

    return english, story, picture

# Custom CSS handed to gr.Blocks(css=...). It styles the title and subtitle
# elements by id, and sets the page background colour and the container font.
css = """
#transart-title {
    font-size: 2.5em;
    font-weight: bold;
    color: #4CAF50;
    text-align: center;
    margin-bottom: 10px;
}
#transart-subtitle {
    font-size: 1.25em;
    text-align: center;
    color: #555555;
    margin-bottom: 20px;
}
body {
    background-color: #f0f0f5;
}
.gradio-container {
    font-family: 'Arial', sans-serif;
}
"""

# Heading markup rendered through gr.Markdown. The div ids match the
# #transart-title / #transart-subtitle selectors in the CSS above.
title_markdown = """
# <div id="transart-title">TransArt</div>
### <div id="transart-subtitle">Tamil to English Translation, Creative Text & Image Generation</div>
"""

# Build Gradio interface
with gr.Blocks(css=css) as interface:
    gr.Markdown(title_markdown)

    # One row, two columns: the first holds the input box, the second the
    # three result widgets.
    with gr.Row():
        with gr.Column():
            tamil_input = gr.Textbox(
                label="Enter Tamil Text",
                placeholder="Type Tamil text here...",
                lines=3,
            )
        with gr.Column():
            translated_output = gr.Textbox(
                label="Translated Text",
                interactive=False,
            )
            creative_text_output = gr.Textbox(
                label="Creative Generated Text",
                interactive=False,
            )
            generated_image_output = gr.Image(label="Generated Image")

    # Clicking the button runs the full pipeline and fills all three outputs.
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=translate_generate_image_and_text,
        inputs=tamil_input,
        outputs=[
            translated_output,
            creative_text_output,
            generated_image_output,
        ],
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {"debug": True, "server_name": "0.0.0.0"}
    interface.launch(**launch_options)