# Hugging Face Space status when this file was captured: "Runtime error".
import os
import tempfile

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
# Single checkpoint id shared by the captioning model and its preprocessors.
CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Load the vision encoder-decoder captioner, then the image processor and the
# tokenizer, all from that same checkpoint id.
caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CAPTION_CHECKPOINT)
# Text-to-speech is delegated to the Hugging Face Inference API (ChatTTS model).
CHAT_TTS_API = "https://api-inference.huggingface.co/models/2Noise/ChatTTS"

# Send an Authorization header only when a token is configured. Formatting an
# unset token directly would send the literal header value "Bearer None".
_hf_token = os.getenv("HF_TOKEN")
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}
def generate_caption(image):
    """Generate a caption for *image* with the ViT-GPT2 captioning model.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The decoded caption text, with special tokens and surrounding
        whitespace removed.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("No image provided.")
    # The processor expects 3-channel input; convert RGBA, grayscale or
    # palette uploads to RGB first (the model card's example does the same).
    if image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=50, num_beams=4)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
def tts_audio(text, timeout=60):
    """Request synthesized speech for *text* from the ChatTTS API endpoint.

    Args:
        text: The text to be spoken.
        timeout: Seconds to wait on the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection and
            hang the Gradio worker.

    Returns:
        The raw response body (presumably encoded audio; the exact format is
        whatever the API returns).

    Raises:
        RuntimeError: If the API responds with any status other than 200.
    """
    payload = {"inputs": text}
    response = requests.post(
        CHAT_TTS_API, headers=headers, json=payload, timeout=timeout
    )
    if response.status_code != 200:
        # Message text ("TTS API error") is kept verbatim for existing UI/logs.
        raise RuntimeError(f"TTS API 오류: {response.status_code}, {response.text}")
    return response.content
def process(image):
    """Caption *image*, synthesize speech for the caption, return both.

    Args:
        image: The PIL image (or ``None``) passed in by the Gradio input.

    Returns:
        A ``(caption, audio_path)`` pair matching the Interface outputs: the
        caption text for ``gr.Text`` and a file path for ``gr.Audio``.
    """
    caption = generate_caption(image)
    audio_bytes = tts_audio(caption)
    # gr.Audio outputs accept a file path or a (sample_rate, numpy_array)
    # tuple. A (bytes, filename) tuple is neither, so persist the API response
    # to a temp file and hand Gradio its path.
    # NOTE(review): the ".wav" suffix follows the intended "result.wav";
    # confirm the API actually returns WAV-encoded audio.
    # delete=False keeps the file around after closing so Gradio can read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
        audio_file.write(audio_bytes)
    return caption, audio_file.name
# Build the UI: one image in; caption text ("Description") and synthesized
# speech ("TTS voice") out. Components are created in the same order as
# before: image input first, then the two outputs.
image_input = gr.Image(type="pil")
caption_output = gr.Text(label="설명")
speech_output = gr.Audio(label="TTS 음성")

demo = gr.Interface(
    fn=process,
    inputs=image_input,
    outputs=[caption_output, speech_output],
    title="🎨 AI 그림 설명 낭독기",
)
demo.launch()