napstablook911 commited on
Commit
3a947b7
·
verified ·
1 Parent(s): c905e8e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +148 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,150 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import io
3
+ from PIL import Image
4
+ import soundfile as sf
5
+ import librosa
6
+ import numpy as np
7
+ import torch # Importa torch
8
+ import sys
9
+ sys.setrecursionlimit(2000) # Aumentiamo il limite di ricorsione
10
+
11
+ # --- Configurazione del Dispositivo ---
12
+ # Questo rileva automaticamente se MPS (GPU Apple Silicon) è disponibile
13
+ # Per ora, useremo la CPU come fallback se MPS è problematico per Stable Audio
14
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
15
+ # ******************** MODIFICA QUI: Forza device = "cpu" ********************
16
+ # Per superare i problemi di Stable Audio su MPS con float16/float32
17
+ # FORZA LA CPU PER TUTTI I MODELLI, per semplicità.
18
+ # Se la caption genera velocemente, potremmo tornare indietro e mettere il modello vit_gpt2 su MPS
19
+ device = "cpu"
20
+ # **************************************************************************
21
+ st.write(f"Utilizzo del dispositivo: {device}")
22
+
23
+
24
+ # --- 1. Caricamento dei Modelli AI (spostati qui, fuori dalle funzioni Streamlit) ---
25
+ @st.cache_resource
26
+ def load_models():
27
+ # Caricamento del modello per la captioning (ViT-GPT2)
28
+ from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModelForVision2Seq
29
+ st.write("Caricamento del modello ViT-GPT2 per la captioning dell'immagine...")
30
+
31
+ vit_gpt2_feature_extractor = AutoFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
32
+ vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
33
+
34
+ # Questo modello andrà sulla CPU
35
+ vit_gpt2_model = AutoModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
36
+
37
+ st.write("Modello ViT-GPT2 caricato.")
38
+
39
+ # Caricamento del modello Text-to-Audio (Stable Audio Open - 1.0)
40
+ from diffusers import DiffusionPipeline
41
+ st.write("Caricamento del modello Stable Audio Open - 1.0 per la generazione del soundscape...")
42
+ # ******************** MODIFICA QUI ********************
43
+ # Assicurati che non ci sia torch_dtype=torch.float16 e che vada sulla CPU
44
+ stable_audio_pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", force_download=True).to(device)
45
+ # ******************************************************
46
+ st.write("Modello Stable Audio Open 1.0 caricato.")
47
+
48
+ return vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline
49
+
50
+ # Carica i modelli all'avvio dell'app
51
+ vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline = load_models()
52
+
53
+
54
+ # --- 2. Funzioni della Pipeline ---
55
+ def generate_image_caption(image_pil):
56
+ pixel_values = vit_gpt2_feature_extractor(images=image_pil.convert("RGB"), return_tensors="pt").pixel_values
57
+ pixel_values = pixel_values.to(device) # Sposta input su CPU
58
+
59
+ # Token di inizio per GPT-2, assicurandosi che sia su CPU
60
+ # Ottieni il decoder_start_token_id dal modello o dal tokenizer
61
+ if hasattr(vit_gpt2_model.config, "decoder_start_token_id"):
62
+ decoder_start_token_id = vit_gpt2_model.config.decoder_start_token_id
63
+ else:
64
+ if vit_gpt2_tokenizer.pad_token_id is not None:
65
+ decoder_start_token_id = vit_gpt2_tokenizer.pad_token_id
66
+ else:
67
+ decoder_start_token_id = 50256 # Default GPT-2 EOS token
68
+
69
+ # Crea un input_ids iniziale con il decoder_start_token_id e spostalo su CPU
70
+ input_ids = torch.ones((1, 1), device=device, dtype=torch.long) * decoder_start_token_id
71
+
72
+
73
+ output_ids = vit_gpt2_model.generate(
74
+ pixel_values=pixel_values,
75
+ input_ids=input_ids,
76
+ max_length=50,
77
+ do_sample=True,
78
+ top_k=50,
79
+ temperature=0.7,
80
+ no_repeat_ngram_size=2,
81
+ early_stopping=True
82
+ )
83
+ caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
84
+ return caption
85
+
86
+
87
+ def generate_soundscape_from_caption(caption: str, duration_seconds: int = 10):
88
+ st.write(f"Generazione soundscape per: '{caption}' (durata: {duration_seconds}s)")
89
+ with st.spinner("Generazione audio in corso..."):
90
+ try:
91
+ # Assicurati che il modello sia già su CPU dal caricamento
92
+ audio_output = stable_audio_pipeline(
93
+ prompt=caption,
94
+ audio_end_in_s=duration_seconds
95
+ ).audios
96
+
97
+ audio_data = audio_output[0].cpu().numpy()
98
+ sample_rate = stable_audio_pipeline.sample_rate
99
+
100
+ audio_data = audio_data.astype(np.float32)
101
+ audio_data = librosa.util.normalize(audio_data)
102
+
103
+ buffer = io.BytesIO()
104
+ sf.write(buffer, audio_data, sample_rate, format='WAV')
105
+ buffer.seek(0)
106
+ return buffer.getvalue(), sample_rate
107
+
108
+ except Exception as e:
109
+ st.error(f"Errore durante la generazione dell'audio: {e}")
110
+ return None, None
111
+
112
+
113
+ # --- 3. Interfaccia Streamlit ---
114
+ st.title("Generatore di Paesaggi Sonori da Immagini")
115
+ st.write("Carica un'immagine e otterrai una descrizione testuale e un paesaggio sonoro generato!")
116
+
117
+ uploaded_file = st.file_uploader("Scegli un'immagine...", type=["jpg", "jpeg", "png"])
118
+
119
+ if uploaded_file is not None:
120
+ input_image = Image.open(uploaded_file)
121
+ st.image(input_image, caption='Immagine Caricata.', use_column_width=True)
122
+ st.write("")
123
+
124
+ audio_duration = st.slider("Durata audio (secondi):", 5, 30, 10, key="audio_duration_slider")
125
+
126
+
127
+ if st.button("Genera Paesaggio Sonoro"):
128
+ st.subheader("Processo in Corso...")
129
+
130
+ # PASSO 1: Genera la caption
131
+ st.write("Generazione della caption...")
132
+ caption = generate_image_caption(input_image)
133
+ st.write(f"Caption generata: **{caption}**")
134
+
135
+ # PASSO 2: Genera il soundscape
136
+ st.write("Generazione del paesaggio sonoro...")
137
+ audio_data_bytes, sample_rate = generate_soundscape_from_caption(caption, duration_seconds=audio_duration)
138
+
139
+ if audio_data_bytes is not None:
140
+ st.subheader("Paesaggio Sonoro Generato")
141
+ st.audio(audio_data_bytes, format='audio/wav', sample_rate=sample_rate)
142
 
143
+ st.download_button(
144
+ label="Scarica Audio WAV",
145
+ data=audio_data_bytes,
146
+ file_name="paesaggio_sonoro_generato.wav",
147
+ mime="audio/wav"
148
+ )
149
+ else:
150
+ st.error("La generazione del paesaggio sonoro è fallita.")