pravin007s committed
Commit 8d26180 · verified · 1 Parent(s): 6908b72

Update app.py

Files changed (1):
  1. app.py +59 -62
app.py CHANGED
@@ -6,7 +6,14 @@ Original file is located at
 """
 
 import os
+import asyncio
 from huggingface_hub import login
+from transformers import MarianMTModel, MarianTokenizer, pipeline, AutoTokenizer, AutoModelForCausalLM
+import aiohttp
+import io
+from PIL import Image
+import matplotlib.pyplot as plt
+import gradio as gr
 
 # Retrieve the actual token from the environment variable
 hf_token = os.getenv("HF_TOKEN")
@@ -18,23 +25,23 @@ if hf_token:
 else:
     raise ValueError("Hugging Face token not found in environment variables.")
 
-# Import necessary libraries
-from transformers import MarianMTModel, MarianTokenizer, pipeline
-import requests
-import io
-from PIL import Image
-import matplotlib.pyplot as plt
-import gradio as gr
-
-# Load the translation model and tokenizer
+# Load the translation model and tokenizer (cached for faster loading)
 model_name = "Helsinki-NLP/opus-mt-mul-en"
-tokenizer = MarianTokenizer.from_pretrained(model_name)
-model = MarianMTModel.from_pretrained(model_name)
+tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir="./cache")
+model = MarianMTModel.from_pretrained(model_name, cache_dir="./cache")
 
 # Create a translation pipeline
 translator = pipeline("translation", model=model, tokenizer=tokenizer)
 
-# Function for translation
+# Load GPT-Neo model for creative text generation (cached)
+gpt_neo_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M", cache_dir="./cache")
+gpt_neo_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M", cache_dir="./cache")
+
+# API credentials and endpoint for image generation
+API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
+headers = {"Authorization": f"Bearer {hf_token}"}
+
+# Function for translation (batch translation for multiple inputs)
 def translate_text(tamil_text):
     try:
         translation = translator(tamil_text, max_length=40)
@@ -43,78 +50,68 @@ def translate_text(tamil_text):
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
-# API credentials and endpoint
-API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
-headers = {"Authorization": f"Bearer {hf_token}"}
-
-# Function to send payload and generate image
-def generate_image(prompt):
+# Asynchronous function to send payload and generate image
+async def generate_image_async(prompt):
     try:
-        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
-
-        # Check if the response is successful
-        if response.status_code == 200:
-            print("API call successful, generating image...")
-            image_bytes = response.content
-
-            # Try opening the image
-            try:
-                image = Image.open(io.BytesIO(image_bytes))
-                return image
-            except Exception as e:
-                print(f"Error opening image: {e}")
-                return None
-        else:
-            print(f"Failed to get image: Status code {response.status_code}")
-            print("Response content:", response.text) # Print response for debugging
-            return None
-
+        async with aiohttp.ClientSession() as session:
+            async with session.post(API_URL, headers=headers, json={"inputs": prompt}) as response:
+                if response.status == 200:
+                    print("API call successful, generating image...")
+                    image_bytes = await response.read()
+
+                    # Try opening the image
+                    try:
+                        image = Image.open(io.BytesIO(image_bytes))
+                        return image
+                    except Exception as e:
+                        print(f"Error opening image: {e}")
+                        return None
+                else:
+                    print(f"Failed to get image: Status code {response.status}")
+                    return None
     except Exception as e:
         print(f"An error occurred: {e}")
        return None
 
-# Display image
-def show_image(image):
-    if image:
-        plt.imshow(image)
-        plt.axis('off') # Hide axes
-        plt.show()
-    else:
-        print("No image to display")
-
-# Load GPT-Neo model for creative text generation
-from transformers import AutoTokenizer, AutoModelForCausalLM
-gpt_neo_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
-gpt_neo_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
-
-# Function to generate creative text based on translated text
-def generate_creative_text(translated_text):
+# Generate creative text based on the translated text (with optimization for generation)
+def generate_creative_text(translated_text, max_length=50):
     input_ids = gpt_neo_tokenizer(translated_text, return_tensors='pt').input_ids
-    generated_text_ids = gpt_neo_model.generate(input_ids, max_length=100)
+    generated_text_ids = gpt_neo_model.generate(input_ids, max_length=max_length, num_return_sequences=1, do_sample=True, top_k=50)
     creative_text = gpt_neo_tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
     return creative_text
 
-# Function to handle the full workflow
-def translate_generate_image_and_text(tamil_text):
+# Handle the full workflow: translate, generate image, generate creative text
+async def translate_generate_image_and_text(tamil_text):
     # Step 1: Translate Tamil text to English
     translated_text = translate_text(tamil_text)
 
-    # Step 2: Generate an image based on the translated text
-    image = generate_image(translated_text)
+    # Step 2: Generate an image based on the translated text asynchronously
+    image = await generate_image_async(translated_text)
 
     # Step 3: Generate creative text based on the translated text
     creative_text = generate_creative_text(translated_text)
 
     return translated_text, creative_text, image
 
-# Create Gradio interface
+# Display image
+def show_image(image):
+    if image:
+        plt.imshow(image)
+        plt.axis('off') # Hide axes
+        plt.show()
+    else:
+        print("No image to display")
+
+# Create Gradio interface with live updates for faster feedback
 interface = gr.Interface(
-    fn=translate_generate_image_and_text,
+    fn=lambda tamil_text: asyncio.run(translate_generate_image_and_text(tamil_text)),
     inputs="text",
     outputs=["text", "text", "image"],
-    title="Tamil to English Translation, Image Generation & Creative Text",
-    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
+    title="Optimized Tamil to English Translation, Image Generation & Creative Text",
+    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation.",
+    live=True # Enables real-time outputs for faster feedback
 )
 
 # Launch Gradio app
 interface.launch()
+
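For quick manual verification of the new asynchronous image call outside Gradio, a minimal sketch like the one below can exercise the same Inference API request that generate_image_async makes. It assumes HF_TOKEN is set in the environment and that aiohttp and Pillow are installed; fetch_image, the sample prompt, and the output filename are illustrative names, not part of the commit.

import asyncio
import io
import os

import aiohttp
from PIL import Image

API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
HEADERS = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}


async def fetch_image(prompt):
    # POST the prompt and decode the returned bytes into a PIL image
    async with aiohttp.ClientSession() as session:
        async with session.post(API_URL, headers=HEADERS, json={"inputs": prompt}) as response:
            if response.status != 200:
                print(f"Request failed with status {response.status}: {await response.text()}")
                return None
            return Image.open(io.BytesIO(await response.read()))


if __name__ == "__main__":
    image = asyncio.run(fetch_image("a lighthouse at sunset"))
    if image is not None:
        image.save("test_image.png")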
 
 
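On the text-generation side, GPT-Neo (like GPT-2) ships without a pad token, so sampling with generate typically logs a warning unless a pad token id is supplied. A hedged variant of generate_creative_text along those lines, reusing the committed max_length=50 and top_k=50 settings (the explicit attention_mask and pad_token_id are additions for illustration, not part of the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

gpt_neo_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
gpt_neo_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")


def generate_creative_text(translated_text, max_length=50):
    # Keep the attention mask from tokenization and reuse EOS as the pad token,
    # which avoids the padding warning during sampling.
    encoded = gpt_neo_tokenizer(translated_text, return_tensors="pt")
    generated_ids = gpt_neo_model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=max_length,
        do_sample=True,
        top_k=50,
        num_return_sequences=1,
        pad_token_id=gpt_neo_tokenizer.eos_token_id,
    )
    return gpt_neo_tokenizer.decode(generated_ids[0], skip_special_tokens=True)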
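One further note on the Gradio wiring: recent Gradio releases can call coroutine functions directly, so the asyncio.run(...) lambda may be unnecessary if the installed version supports async handlers. A hedged sketch of that variant follows; it assumes the functions defined in app.py above, and live=True is left out here because live mode re-runs the whole pipeline as the input changes, which is costly with remote model calls.

# Assumes gr and translate_generate_image_and_text are defined as in app.py above.
interface = gr.Interface(
    fn=translate_generate_image_and_text,  # Gradio awaits the coroutine itself
    inputs="text",
    outputs=["text", "text", "image"],
    title="Optimized Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation.",
)

interface.launch()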