Jose Benitez commited on
Commit
bbc89f6
·
0 Parent(s):
.env.example ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ REPLICATE_API_TOKEN=
2
+
3
+ SUPABASE_KEY=
4
+ SUPABASE_URL=
5
+
6
+ GOOGLE_CLIENT_ID=
7
+ GOOGLE_CLIENT_SECRET=
8
+ SECRET_KEY=
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ .env
Dockerfile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.8-slim
2
+
3
+ WORKDIR /usr/src/app
4
+ COPY . .
5
+ RUN pip install --no-cache-dir gradio
6
+ EXPOSE 7860
7
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
8
+
9
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ test card
2
+ 4242 4242 4242 4242
3
+ wip
assets/logo.jpg ADDED
auth.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from starlette.config import Config
2
+ from authlib.integrations.starlette_client import OAuth
3
+ from config import GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET
4
+
5
+ config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
6
+ starlette_config = Config(environ=config_data)
7
+ oauth = OAuth(starlette_config)
8
+ oauth.register(
9
+ name='google',
10
+ server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
11
+ client_kwargs={'scope': 'openid email profile'},
12
+ )
config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import logging
4
+
5
+ load_dotenv()
6
+
7
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
8
+ SUPABASE_KEY = os.getenv("SUPABASE_KEY")
9
+ SECRET_KEY = os.getenv("SECRET_KEY")
10
+
11
+ GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID")
12
+ GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET")
13
+
14
+ STRIPE_API_KEY = os.getenv("STRIPE_API_KEY")
15
+ STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
database.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from supabase import create_client, Client
3
+ from config import SUPABASE_URL, SUPABASE_KEY
4
+
5
+ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
6
+
7
+ def get_user_credits(user_id):
8
+ user = supabase.table("users").select("generation_credits, train_credits").eq("id", user_id).execute()
9
+ if user.data:
10
+ return user.data[0]["generation_credits"], user.data[0]["train_credits"]
11
+ return 0, 0
12
+
13
+ def update_user_credits(user_id, generation_credits, train_credits):
14
+ supabase.table("users").update({
15
+ "generation_credits": generation_credits,
16
+ "train_credits": train_credits
17
+ }).eq("id", user_id).execute()
18
+
19
+ def get_or_create_user(google_id, email, name, given_name, profile_picture):
20
+ user = supabase.table("users").select("*").eq("google_id", google_id).execute()
21
+
22
+ if not user.data:
23
+ new_user = {
24
+ "google_id": google_id,
25
+ "email": email,
26
+ "name": name,
27
+ "profile_picture": profile_picture,
28
+ "generation_credits": 2,
29
+ "train_credits": 1,
30
+ "given_name": given_name
31
+ }
32
+ result = supabase.table("users").insert(new_user).execute()
33
+ return result.data[0]
34
+ else:
35
+ return user.data[0]
36
+
37
+ def get_lora_models_info():
38
+ lora_models = supabase.table("lora_models").select("*").execute()
39
+ return lora_models.data
40
+
41
+ def get_user_by_id(user_id):
42
+ user = supabase.table("users").select("*").eq("id", user_id).execute()
43
+ if user.data:
44
+ return user.data[0]
45
+ return None
gradio_app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import os
4
+ import json
5
+ import zipfile
6
+ from pathlib import Path
7
+
8
+ from database import get_user_credits, update_user_credits, get_lora_models_info
9
+ from services.image_generation import generate_image
10
+ from services.train_lora import lora_pipeline
11
+ from utils.image_utils import url_to_pil_image
12
+
13
+ lora_models = get_lora_models_info()
14
+
15
+
16
+ if not isinstance(lora_models, list):
17
+ raise ValueError("Expected loras_models to be a list of dictionaries.")
18
+
19
+ login_css_path = Path(__file__).parent / 'static/css/login.css'
20
+ main_css_path = Path(__file__).parent / 'static/css/main.css'
21
+ landing_html_path = Path(__file__).parent / 'static/html/landing.html'
22
+ main_header_path = Path(__file__).parent / 'static/html/main_header.html'
23
+
24
+ if login_css_path.is_file(): # Check if the file exists
25
+ with login_css_path.open() as file:
26
+ login_css = file.read()
27
+
28
+ if main_css_path.is_file(): # Check if the file exists
29
+ with main_css_path.open() as file:
30
+ main_css = file.read()
31
+
32
+ if landing_html_path.is_file():
33
+ with landing_html_path.open() as file:
34
+ landin_page = file.read()
35
+
36
+ if main_header_path.is_file():
37
+ with main_header_path.open() as file:
38
+ main_header = file.read()
39
+
40
+ def update_selection(evt: gr.SelectData, width, height):
41
+ selected_lora = lora_models[evt.index]
42
+ new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}"
43
+ trigger_word = selected_lora["trigger_word"]
44
+ updated_text = f"#### Palabra clave: {trigger_word} ✨"
45
+
46
+ if "aspect" in selected_lora:
47
+ if selected_lora["aspect"] == "portrait":
48
+ width, height = 768, 1024
49
+ elif selected_lora["aspect"] == "landscape":
50
+ width, height = 1024, 768
51
+
52
+ return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height
53
+
54
+ def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
55
+ if not files:
56
+ return "No images uploaded. Please upload images before training."
57
+
58
+ # Create a directory in the user's home folder
59
+ output_dir = os.path.expanduser("~/gradio_training_data")
60
+ os.makedirs(output_dir, exist_ok=True)
61
+
62
+ # Create a zip file in the output directory
63
+ zip_path = os.path.join(output_dir, "training_data.zip")
64
+
65
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
66
+ for file_info in files:
67
+ file_path = file_info[0] # The first element of the tuple is the file path
68
+ file_name = os.path.basename(file_path)
69
+ zipf.write(file_path, file_name)
70
+
71
+ print(f"Zip file created at: {zip_path}")
72
+
73
+ print(f'[INFO] Procesando {trigger_word}')
74
+ # Now call the train_lora function with the zip file path
75
+ result = lora_pipeline(zip_path,
76
+ model_name,
77
+ trigger_word=trigger_word,
78
+ steps=train_steps,
79
+ lora_rank=lora_rank,
80
+ batch_size=batch_size,
81
+ autocaption=True,
82
+ learning_rate=learning_rate)
83
+
84
+ return f"{result}\n\nZip file saved at: {zip_path}"
85
+
86
+ def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
87
+ user = request.session.get('user')
88
+ if not user:
89
+ raise gr.Error("User not authenticated. Please log in.")
90
+
91
+ generation_credits, _ = get_user_credits(user['id'])
92
+
93
+ if generation_credits <= 0:
94
+ raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
95
+
96
+ image_url = generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress)
97
+ image = url_to_pil_image(image_url)
98
+
99
+ # Update user's credits
100
+ new_generation_credits = generation_credits - 1
101
+ update_user_credits(user['id'], new_generation_credits, user['train_credits'])
102
+
103
+ # Update session data
104
+ user['generation_credits'] = new_generation_credits
105
+ request.session['user'] = user
106
+
107
+ print(f"Generation credits remaining: {new_generation_credits}")
108
+
109
+ return image, new_generation_credits
110
+
111
+ def display_credits(request: gr.Request):
112
+ user = request.session.get('user')
113
+ if user:
114
+ generation_credits, train_credits = get_user_credits(user['id'])
115
+ return generation_credits, train_credits
116
+ return 0, 0
117
+
118
+ def load_greet_and_credits(request: gr.Request):
119
+ greeting = greet(request)
120
+ generation_credits, train_credits = display_credits(request)
121
+ return greeting, generation_credits, train_credits
122
+
123
+ def greet(request: gr.Request):
124
+ user = request.session.get('user')
125
+ if user:
126
+ with gr.Column():
127
+ with gr.Row():
128
+ greeting = f"Hola 👋 {user['given_name']}!"
129
+ return f"{greeting}\n"
130
+ return "OBTU AI. Please log in."
131
+
132
+ with gr.Blocks(theme=gr.themes.Soft(), css=login_css) as login_demo:
133
+ with gr.Column(elem_id="google-btn-container", elem_classes="google-btn-container svelte-vt1mxs gap"):
134
+ btn = gr.Button("Iniciar Sesion con Google", elem_classes="login-with-google-btn")
135
+ _js_redirect = """
136
+ () => {
137
+ url = '/login' + window.location.search;
138
+ window.open(url, '_blank');
139
+ }
140
+ """
141
+ btn.click(None, js=_js_redirect)
142
+ gr.HTML(landin_page)
143
+
144
+
145
+ header = '<script src="https://cdn.lordicon.com/lordicon.js"></script>'
146
+
147
+ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
148
+ title = gr.HTML(main_header)
149
+
150
+ with gr.Column(elem_id="logout-btn-container"):
151
+ gr.Button("Salir", link="/logout", elem_id="logout_btn")
152
+
153
+
154
+ greetings = gr.Markdown("Loading user information...")
155
+ gr.Button("Comprar Creditos", link="/buy_credits", elem_id="buy_credits_btn")
156
+
157
+ selected_index = gr.State(None)
158
+
159
+ with gr.Row():
160
+ with gr.Column():
161
+ generation_credits_display = gr.Number(label="Generation Credits", precision=0, interactive=False)
162
+ with gr.Column():
163
+ train_credits_display = gr.Number(label="Training Credits", precision=0, interactive=False)
164
+
165
+
166
+ with gr.Tabs():
167
+ with gr.TabItem('Generacion'):
168
+ with gr.Row():
169
+ with gr.Column(scale=3):
170
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Ingresa un prompt para empezar a crear")
171
+ with gr.Column(scale=1, elem_id="gen_column"):
172
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
173
+
174
+ with gr.Row():
175
+ with gr.Column(scale=4):
176
+ result = gr.Image(label="Imagen Generada")
177
+
178
+ with gr.Column(scale=3):
179
+ with gr.Accordion("Tus Modelos"):
180
+ user_model_gallery = gr.Gallery(
181
+ label="Galeria de Modelos",
182
+ allow_preview=False,
183
+ columns=3,
184
+ elem_id="galley"
185
+ )
186
+
187
+ with gr.Accordion("Modelos Publicos", open=False):
188
+ selected_info = gr.Markdown("")
189
+ gallery = gr.Gallery(
190
+ [(item["image_url"], item["lora_name"]) for item in lora_models],
191
+ label="Galeria de Modelos Publicos",
192
+ allow_preview=False,
193
+ columns=3,
194
+ elem_id="gallery"
195
+ )
196
+
197
+
198
+ with gr.Accordion("Configuracion Avanzada", open=False):
199
+ with gr.Row():
200
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
201
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
202
+ with gr.Row():
203
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
204
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
205
+ with gr.Row():
206
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
207
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
208
+
209
+ gallery.select(
210
+ update_selection,
211
+ inputs=[width, height],
212
+ outputs=[prompt, selected_info, selected_index, width, height]
213
+ )
214
+
215
+ gr.on(
216
+ triggers=[generate_button.click, prompt.submit],
217
+ fn=run_lora,
218
+ inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale],
219
+ outputs=[result, generation_credits_display]
220
+ )
221
+
222
+ with gr.TabItem("Training"):
223
+ gr.Markdown("# Entrena tu propio modelo 🧠")
224
+ gr.Markdown("En esta seccion podes entrenar tu propio modelo a partir de tus imagenes.")
225
+ with gr.Row():
226
+ with gr.Column():
227
+ train_dataset = gr.Gallery(columns=4, interactive=True, label="Tus Imagenes")
228
+ model_name = gr.Textbox(label="Nombre del Modelo",)
229
+ trigger_word = gr.Textbox(label="Palabra clave",
230
+ info="Esta seria una palabra clave para luego indicar al modelo cuando debe usar estas nuevas capacidad es que le vamos a ensenar",
231
+ )
232
+ train_button = gr.Button("Comenzar Training")
233
+ with gr.Accordion("Configuracion Avanzada", open=False):
234
+ train_steps = gr.Slider(label="Training Steps", minimum=100, maximum=10000, step=100, value=1000)
235
+ lora_rank = gr.Number(label='lora_rank', value=16)
236
+ batch_size = gr.Number(label='batch_size', value=1)
237
+ learning_rate = gr.Number(label='learning_rate', value=0.0004)
238
+ training_status = gr.Textbox(label="Training Status")
239
+
240
+
241
+
242
+ train_button.click(
243
+ compress_and_train,
244
+ inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
245
+ outputs=training_status
246
+ )
247
+
248
+
249
+ #main_demo.load(greet, None, title)
250
+ #main_demo.load(greet, None, greetings)
251
+ #main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display])
252
+ main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])
253
+
254
+
255
+
256
+ # TODO:
257
+ '''
258
+ - Galeria Modelos Propios (si existe alguno del user, si no, mostrar un mensaje para entrenar)
259
+ - Galeria Modelos Open Source (accordion)
260
+ - Training con creditos.
261
+ - Stripe(?)
262
+ - Mejorar boton de login/logout
263
+ - Retoque landing page
264
+ '''
265
+
266
+
main.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import FastAPI
3
+ from fastapi.staticfiles import StaticFiles
4
+ from starlette.middleware.sessions import SessionMiddleware
5
+ from config import SECRET_KEY
6
+ from routes import router, get_user
7
+ from gradio_app import login_demo, main_demo
8
+ import gradio as gr
9
+ from pathlib import Path
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+
12
+ app = FastAPI()
13
+
14
+ login_demo.queue()
15
+ main_demo.queue()
16
+
17
+ static_dir = Path("./static")
18
+ app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
19
+ #app.mount("/assets", StaticFiles(directory="assets", html=True), name="assets")
20
+
21
+ app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
22
+
23
+ app.include_router(router)
24
+
25
+ app = gr.mount_gradio_app(app, login_demo, path="/main")
26
+ app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user, show_error=True)
27
+
28
+ if __name__ == "__main__":
29
+ uvicorn.run(app)
30
+
31
+
models.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ class User(BaseModel):
4
+ id: str
5
+ google_id: str
6
+ email: str
7
+ name: str
8
+ given_name: str
9
+ profile_picture: str
10
+ generation_credits: int
11
+ train_credits: int
routes.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # routes.py
2
+ from fastapi import APIRouter, Depends, Request
3
+ from starlette.responses import RedirectResponse
4
+ from auth import oauth
5
+ from database import get_or_create_user, update_user_credits, get_user_by_id
6
+ from authlib.integrations.starlette_client import OAuthError
7
+ import gradio as gr
8
+ from utils.stripe_utils import create_checkout_session, verify_webhook, retrieve_stripe_session
9
+
10
+ router = APIRouter()
11
+
12
+ def get_user(request: Request):
13
+ user = request.session.get('user')
14
+ return user['name'] if user else None
15
+
16
+ @router.get('/')
17
+ def public(request: Request, user = Depends(get_user)):
18
+ root_url = gr.route_utils.get_root_url(request, "/", None)
19
+ print(f'Root URL: {root_url}')
20
+ if user:
21
+ return RedirectResponse(url=f'{root_url}/gradio/')
22
+ else:
23
+ return RedirectResponse(url=f'{root_url}/main/')
24
+
25
+ @router.route('/logout')
26
+ async def logout(request: Request):
27
+ request.session.pop('user', None)
28
+ return RedirectResponse(url='/')
29
+
30
+ @router.route('/login')
31
+ async def login(request: Request):
32
+ root_url = gr.route_utils.get_root_url(request, "/login", None)
33
+ redirect_uri = f"{root_url}/auth"
34
+ return await oauth.google.authorize_redirect(request, redirect_uri)
35
+
36
+ @router.route('/auth')
37
+ async def auth(request: Request):
38
+ try:
39
+ token = await oauth.google.authorize_access_token(request)
40
+ user_info = token.get('userinfo')
41
+ if user_info:
42
+ google_id = user_info['sub']
43
+ email = user_info['email']
44
+ name = user_info['name']
45
+ given_name = user_info['given_name']
46
+ profile_picture = user_info.get('picture', '')
47
+
48
+ user = get_or_create_user(google_id, email, name, given_name, profile_picture)
49
+ request.session['user'] = user
50
+
51
+ return RedirectResponse(url='/gradio')
52
+ else:
53
+ return RedirectResponse(url='/main')
54
+ except OAuthError as e:
55
+ print(f"OAuth Error: {str(e)}")
56
+ return RedirectResponse(url='/main')
57
+
58
+ # Handle Stripe payments
59
+ @router.get("/buy_credits")
60
+ async def buy_credits(request: Request):
61
+ user = request.session.get('user')
62
+ if not user:
63
+ return {"error": "User not authenticated"}
64
+ session = create_checkout_session(100, 50, user['id']) # $1 for 50 credits
65
+
66
+ # Store the session ID and user ID in the session
67
+ request.session['stripe_session_id'] = session['id']
68
+ request.session['user_id'] = user['id']
69
+ print(f"Stripe session created: {session['id']} for user {user['id']}")
70
+
71
+ return RedirectResponse(session['url'])
72
+
73
+ @router.post("/webhook")
74
+ async def stripe_webhook(request: Request):
75
+ payload = await request.body()
76
+ sig_header = request.headers.get("Stripe-Signature")
77
+
78
+ event = verify_webhook(payload, sig_header)
79
+
80
+ if event is None:
81
+ return {"error": "Invalid payload or signature"}
82
+
83
+ if event['type'] == 'checkout.session.completed':
84
+ session = event['data']['object']
85
+ user_id = session.get('client_reference_id')
86
+
87
+ if user_id:
88
+ # Fetch the user from the database
89
+ user = get_user_by_id(user_id) # You'll need to implement this function
90
+ if user:
91
+ # Update user's credits
92
+ new_credits = user['generation_credits'] + 50 # Assuming 50 credits were purchased
93
+ update_user_credits(user['id'], new_credits, user['train_credits'])
94
+ print(f"Credits updated for user {user['id']}")
95
+ else:
96
+ print(f"User not found for ID: {user_id}")
97
+ else:
98
+ print("No client_reference_id found in the session")
99
+
100
+ return {"status": "success"}
101
+
102
+ # @router.get("/success")
103
+ # async def payment_success(request: Request):
104
+ # print("Payment successful")
105
+ # user = request.session.get('user')
106
+ # print(user)
107
+ # if user:
108
+ # updated_user = get_user_by_id(user['id'])
109
+ # if updated_user:
110
+ # request.session['user'] = updated_user
111
+ # return RedirectResponse(url='/gradio', status_code=303)
112
+ # return RedirectResponse(url='/login', status_code=303)
113
+
114
+ @router.get("/cancel")
115
+ async def payment_cancel(request: Request):
116
+ print("Payment cancelled")
117
+ user = request.session.get('user')
118
+ print(user)
119
+ if user:
120
+ return RedirectResponse(url='/gradio', status_code=303)
121
+ return RedirectResponse(url='/login', status_code=303)
122
+
123
+ @router.get("/success")
124
+ async def payment_success(request: Request):
125
+ print("Payment successful")
126
+ stripe_session_id = request.session.get('stripe_session_id')
127
+ user_id = request.session.get('user_id')
128
+
129
+ print(f"Session data: stripe_session_id={stripe_session_id}, user_id={user_id}")
130
+
131
+ if stripe_session_id and user_id:
132
+ # Retrieve the Stripe session
133
+ stripe_session = retrieve_stripe_session(stripe_session_id)
134
+
135
+ if stripe_session.get('payment_status') == 'paid':
136
+ user = get_user_by_id(user_id)
137
+ if user:
138
+ # Update the session with the latest user data
139
+ request.session['user'] = user
140
+ print(f"User session updated: {user}")
141
+
142
+ # Clear the stripe_session_id and user_id from the session
143
+ request.session.pop('stripe_session_id', None)
144
+ request.session.pop('user_id', None)
145
+
146
+ return RedirectResponse(url='/gradio', status_code=303)
147
+ else:
148
+ print(f"User not found for ID: {user_id}")
149
+ else:
150
+ print(f"Payment not completed for session: {stripe_session_id}")
151
+ else:
152
+ print("No Stripe session ID or user ID found in the session")
153
+
154
+ return RedirectResponse(url='/login', status_code=303)
services/get_stripe.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ import stripe
2
+
services/image_generation.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import replicate
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+
6
+ #model_custom_test = "josebenitezg/flux-dev-ruth-estilo-1:c7ff81b58007c7cee3f69416e1e999192dafd8d1b1f269ea6cae137f04b34172"
7
+ flux_pro = "black-forest-labs/flux-pro"
8
+ def generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress, trigger_word='hi'):
9
+ print(f"Generating image for prompt: {prompt}")
10
+ img_url = replicate.run(
11
+ flux_pro,
12
+ input={
13
+ "steps": steps,
14
+ "prompt": prompt,
15
+ "guidance": cfg_scale,
16
+ "interval": 2,
17
+ "aspect_ratio": "1:1",
18
+ "safety_tolerance": 2
19
+ }
20
+ )
21
+ return img_url
services/train_lora.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import replicate
2
+ import os
3
+ from huggingface_hub import create_repo
4
+
5
+ REPLICATE_OWNER = "josebenitezg"
6
+
7
+ def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
8
+ print(f'Creating dataset for {model_name}')
9
+ repo_name = f"joselobenitezg/flux-dev-{model_name}"
10
+ create_repo(repo_name, repo_type='model')
11
+
12
+ lora_name = f"flux-dev-{model_name}"
13
+
14
+ model = replicate.models.create(
15
+ owner=REPLICATE_OWNER,
16
+ name=lora_name,
17
+ visibility="public", # or "private" if you prefer
18
+ hardware="gpu-t4", # Replicate will override this for fine-tuned models
19
+ description="A fine-tuned FLUX.1 model"
20
+ )
21
+
22
+ print(f"Model created: {model.name}")
23
+ print(f"Model URL: https://replicate.com/{model.owner}/{model.name}")
24
+
25
+ # Now use this model as the destination for your training
26
+ print(f"[INFO] Starting training")
27
+
28
+ print(f'\n[INFO] Parametros a entrenar: \n Trigger word: {trigger_word}\n steps: {steps} \n lora_rank: {lora_rank}\n autocaption: {autocaption}\n learning_rate: {learning_rate}\n')
29
+ training = replicate.trainings.create(
30
+ version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
31
+ input={
32
+ "input_images": open(zip_path, "rb"),
33
+ "steps": steps,
34
+ "lora_rank": lora_rank,
35
+ "batch_size": batch_size,
36
+ "autocaption": autocaption,
37
+ "trigger_word": trigger_word,
38
+ "learning_rate": learning_rate,
39
+ "hf_token": os.getenv('HF_TOKEN'), # optional
40
+ "hf_repo_id": repo_name, # optional
41
+ },
42
+ destination=f"{model.owner}/{model.name}"
43
+ )
44
+
45
+ print(f"Training started: {training.status}")
46
+ print(f"Training URL: https://replicate.com/p/{training.id}")
static/css/login.css ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ .login-with-google-btn {
3
+ display: inline-block;
4
+ width: 220px; /* Ancho fijo */
5
+ max-width: 100%; /* Para asegurar responsividad */
6
+ transition: background-color .3s, box-shadow .3s;
7
+ padding: 8px 12px 8px 35px;
8
+ border: none;
9
+ border-radius: 3px;
10
+ box-shadow: 0 -1px 0 rgba(0, 0, 0, .04), 0 1px 1px rgba(0, 0, 0, .25);
11
+ color: #757575;
12
+ font-size: 12px;
13
+ font-weight: 500;
14
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, Cantarell, "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
15
+ background-image: url();
16
+ background-color: white;
17
+ background-repeat: no-repeat;
18
+ background-position: 10px 50%;
19
+ background-size: 15px 15px;
20
+ text-align: center;
21
+ }
22
+
23
+ /* Contenedor para centrar el botón */
24
+ /* Estilos adicionales para el botón de Google y su contenedor */
25
+ .google-btn-container {
26
+ display: flex;
27
+ justify-content: flex-end;
28
+ width: 100%;
29
+ padding-right: 20px;
30
+ box-sizing: border-box;
31
+ position: absolute;
32
+ top: 20px;
33
+ right: 0;
34
+ }
35
+ .svelte-vt1mxs.gap {
36
+ position: static !important;
37
+ margin-top: 0 !important;
38
+ }
39
+
40
+ .login-with-google-btn:active {
41
+ background-color: #eeeeee;
42
+ }
43
+
44
+ .login-with-google-btn:focus {
45
+ outline: none;
46
+ box-shadow:
47
+ 0 -1px 0 rgba(0, 0, 0, .04),
48
+ 0 2px 4px rgba(0, 0, 0, .25),
49
+ 0 0 0 3px #c8dafc;
50
+ }
51
+
52
+ .login-with-google-btn:disabled {
53
+ filter: grayscale(100%);
54
+ background-color: #ebebeb;
55
+ box-shadow: 0 -1px 0 rgba(0, 0, 0, .04), 0 1px 1px rgba(0, 0, 0, .25);
56
+ cursor: not-allowed;
57
+ }
58
+
59
+ /* Estilos específicos para trabajar con las clases de Gradio */
60
+ .svelte-vt1mxs.gap {
61
+ position: absolute;
62
+ top: 20px;
63
+ right: 20px;
64
+ z-index: 1000;
65
+ }
66
+ @media(max-width: 768px) {
67
+ .feature-grid {
68
+ grid-template-columns: 1fr;
69
+ }
70
+ .google-btn-container {
71
+ position: static;
72
+ justify-content: center;
73
+ padding-right: 0;
74
+ margin-top: 20px;
75
+ }
76
+ }
static/css/main.css ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #gen_btn {
2
+ height: 100%
3
+ }
4
+
5
+ #title {
6
+ text-align: center
7
+ }
8
+
9
+ #title h1 {
10
+ font-size: 3em;
11
+ display: inline-flex;
12
+ align-items: center
13
+ }
14
+
15
+ #title img {
16
+ width: 100px;
17
+ margin-right: 0.5em
18
+ }
19
+
20
+ #gallery .grid-wrap {
21
+ height: 10vh
22
+ }
23
+
24
+ /* Estilo para el contenedor del botón */
25
+ #logout-btn-container.svelte-vt1mxs.gap {
26
+ position: absolute;
27
+ top: 10px;
28
+ right: 10px;
29
+ z-index: 1000;
30
+ display: flex;
31
+ justify-content: flex-end;
32
+ width: auto;
33
+ }
34
+
35
+ /* Estilo para el botón de logout */
36
+ #logout_btn.lg.secondary.svelte-cmf5ev {
37
+ width: auto;
38
+ min-width: 80px;
39
+ background-color: #f44336;
40
+ color: white;
41
+ border: none;
42
+ padding: 5px 10px;
43
+ border-radius: 5px;
44
+ cursor: pointer;
45
+ font-size: 0.9em;
46
+ transition: background-color 0.3s;
47
+ text-align: center;
48
+ text-decoration: none;
49
+ display: inline-block;
50
+ margin-left: auto; /* Empuja el botón hacia la derecha */
51
+ }
52
+
53
+ #logout_btn.lg.secondary.svelte-cmf5ev:hover {
54
+ background-color: #d32f2f;
55
+ }
56
+
57
+ /* Ajuste del layout principal si es necesario */
58
+ .gradio-container {
59
+ position: relative;
60
+ }
static/html/landing.html ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="es">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>ObtuAI - Creación Visual con IA</title>
7
+ <style>
8
+ :root {
9
+ --background-color: #0f0f0f;
10
+ --text-color: #ffffff;
11
+ --accent-color: #bb86fc;
12
+ --surface-color: #1e1e1e;
13
+ --border-radius: 20px;
14
+ }
15
+ body {
16
+ font-family: 'Arial', sans-serif;
17
+ line-height: 1.6;
18
+ color: var(--text-color);
19
+ background-color: var(--background-color);
20
+ margin: 0;
21
+ padding: 0;
22
+ }
23
+ .container {
24
+ width: 90%;
25
+ max-width: 1200px;
26
+ margin: auto;
27
+ overflow: hidden;
28
+ padding: 0 20px;
29
+ }
30
+ header {
31
+ background: var(--surface-color);
32
+ padding: 20px 0;
33
+ position: relative;
34
+ border-bottom-left-radius: var(--border-radius);
35
+ border-bottom-right-radius: var(--border-radius);
36
+ }
37
+ .header-content {
38
+ display: flex;
39
+ justify-content: space-between;
40
+ align-items: center;
41
+ }
42
+ .logo {
43
+ font-size: 2em;
44
+ font-weight: bold;
45
+ color: var(--accent-color);
46
+ }
47
+ #google-btn-container {
48
+ position: absolute;
49
+ right: 20px;
50
+ top: 20px;
51
+ }
52
+ .hero {
53
+ background: url('https://news.ubc.ca/wp-content/uploads/2023/08/AdobeStock_559145847.jpeg') no-repeat center center/cover;
54
+ height: 60vh;
55
+ position: relative;
56
+ display: flex;
57
+ align-items: center;
58
+ justify-content: center;
59
+ text-align: center;
60
+ margin-top: 20px;
61
+ border-radius: var(--border-radius);
62
+ overflow: hidden;
63
+ }
64
+ .hero-content {
65
+ background: rgba(0,0,0,0.7);
66
+ padding: 30px;
67
+ border-radius: var(--border-radius);
68
+ max-width: 600px;
69
+ }
70
+ .hero-content h1 {
71
+ font-size: 2.5em;
72
+ margin-bottom: 0.5em;
73
+ }
74
+ .features {
75
+ padding: 40px 0;
76
+ }
77
+ .feature-grid {
78
+ display: grid;
79
+ grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
80
+ gap: 30px;
81
+ }
82
+ .feature {
83
+ background: var(--surface-color);
84
+ padding: 30px;
85
+ border-radius: var(--border-radius);
86
+ box-shadow: 0 8px 20px rgba(0,0,0,0.3);
87
+ transition: transform 0.3s ease;
88
+ }
89
+ .feature:hover {
90
+ transform: translateY(-5px);
91
+ }
92
+ .testimonials {
93
+ background: var(--surface-color);
94
+ padding: 40px 0;
95
+ border-radius: var(--border-radius);
96
+ margin-top: 40px;
97
+ }
98
+ .testimonial {
99
+ background: rgba(255,255,255,0.1);
100
+ padding: 25px;
101
+ margin-top: 30px;
102
+ border-radius: var(--border-radius);
103
+ transition: transform 0.3s ease;
104
+ }
105
+ .testimonial:hover {
106
+ transform: scale(1.03);
107
+ }
108
+ footer {
109
+ background: var(--surface-color);
110
+ text-align: center;
111
+ padding: 20px 0;
112
+ margin-top: 40px;
113
+ border-top-left-radius: var(--border-radius);
114
+ border-top-right-radius: var(--border-radius);
115
+ }
116
+ @media(max-width: 768px) {
117
+ .header-content {
118
+ flex-direction: column;
119
+ text-align: center;
120
+ }
121
+ #google-btn-container {
122
+ position: static;
123
+ transform: none;
124
+ margin-top: 20px;
125
+ }
126
+ .hero-content {
127
+ padding: 20px;
128
+ }
129
+ }
130
+ </style>
131
+ </head>
132
+ <body>
133
+ <header>
134
+ <div class="container">
135
+ <div class="header-content">
136
+ <div class="logo">🎨 ObtuAI</div>
137
+ <div id="google-btn-container">
138
+ <!-- El botón será insertado aquí por Gradio -->
139
+ </div>
140
+ </div>
141
+ </div>
142
+ </header>
143
+
144
+ <div class="container">
145
+ <section class="hero">
146
+ <div class="hero-content">
147
+ <h1>🚀 Bienvenido al Futuro de la Creación Visual</h1>
148
+ <p>Crea imágenes con IA en segundos. ¡Escribe tu idea y mira cómo se convierte en arte!</p>
149
+ </div>
150
+ </section>
151
+
152
+ <section class="features">
153
+ <h2>🌟 Descubre el Poder de la Generación de Imágenes por IA</h2>
154
+ <div class="feature-grid">
155
+ <div class="feature">
156
+ <h3>Personaliza</h3>
157
+ <p>Alimenta tu modelo con tus propias imágenes y estilos.</p>
158
+ </div>
159
+ <div class="feature">
160
+ <h3>Entrena</h3>
161
+ <p>Nuestra IA aprende de tus preferencias.</p>
162
+ </div>
163
+ <div class="feature">
164
+ <h3>Crea</h3>
165
+ <p>Genera imágenes que reflejen tu visión única.</p>
166
+ </div>
167
+ </div>
168
+ </section>
169
+
170
+ <section class="testimonials">
171
+ <div class="container">
172
+ <h2>💬 Lo Que Dicen Nuestros Usuarios</h2>
173
+ <div class="testimonial">
174
+ <p>"ObtuAI ha revolucionado mi proceso creativo. ¡Ahora puedo visualizar mis ideas más locas en minutos!"</p>
175
+ <p><strong>- Ana, Diseñadora Gráfica</strong></p>
176
+ </div>
177
+ <div class="testimonial">
178
+ <p>"Entrenar mi propio modelo fue sorprendentemente fácil. Ahora hago fotografías mías y de mis clientes en segundos."</p>
179
+ <p><strong>- Carlos, Fotógrafo Profesional</strong></p>
180
+ </div>
181
+ </div>
182
+ </section>
183
+ </div>
184
+
185
+ <footer>
186
+ <p>ObtuAI - Tus ideas locas en píxeles con AI.</p>
187
+ </footer>
188
+ </body>
189
+ </html>
static/html/main_header.html ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="es">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>ObtuAI Header</title>
7
+ <style>
8
+ :root {
9
+ --background-color: #4337C9;
10
+ --text-color: #ffffff;
11
+ --accent-color: #bb86fc;
12
+ --surface-color:#4337C9;
13
+ --border-radius: 20px;
14
+ }
15
+ body {
16
+ line-height: 1.6;
17
+ color: var(--text-color);
18
+ background-color: var(--background-color);
19
+ margin: 0;
20
+ padding: 0;
21
+ }
22
+ header {
23
+ background: var(--surface-color);
24
+ padding: 15px 20px;
25
+ border-bottom-left-radius: var(--border-radius);
26
+ border-bottom-right-radius: var(--border-radius);
27
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
28
+ }
29
+ .header-content {
30
+ display: flex;
31
+ justify-content: space-between;
32
+ align-items: center;
33
+ max-width: 1200px;
34
+ margin: 0 auto;
35
+ }
36
+ .logo {
37
+ background: rgba(187, 134, 252, 0.1);
38
+ padding: 10px 20px;
39
+ border-radius: var(--border-radius);
40
+ }
41
+ .logo h1 {
42
+ font-size: 2em;
43
+ color: var(--accent-color);
44
+ margin: 0;
45
+ }
46
+ .status {
47
+ display: flex;
48
+ align-items: center;
49
+ background: rgba(255, 255, 255, 0.1);
50
+ padding: 10px 20px;
51
+ border-radius: var(--border-radius);
52
+ }
53
+ .badge {
54
+ background-color: #4CAF50;
55
+ color: white;
56
+ padding: 5px 15px;
57
+ border-radius: 25px;
58
+ font-size: 0.8em;
59
+ margin-left: 15px;
60
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
61
+ }
62
+
63
+ </style>
64
+ </head>
65
+ <body>
66
+ <header>
67
+ <div class="header-content">
68
+ <h1>Obtu AI 📸</h1>
69
+ <div class="status">
70
+ <lord-icon
71
+ src="https://cdn.lordicon.com/jgjfuggm.json"
72
+ trigger="loop"
73
+ state="loop-cycle"
74
+ colors="primary:#4be1ec,secondary:#4030e8"
75
+ style="width:50px;height:50px">
76
+ </lord-icon>
77
+ <span class="badge">GPU🔥</span>
78
+ </div>
79
+ </div>
80
+ </header>
81
+ </body>
82
+ </html>
utils/image_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from PIL import Image
3
+ from io import BytesIO
4
+
5
+ def url_to_pil_image(url):
6
+ try:
7
+ # Ensure url is a string, not a list
8
+ if isinstance(url, list):
9
+ url = url[0] # Take the first URL if it's a list
10
+
11
+ response = requests.get(url)
12
+ response.raise_for_status()
13
+ image = Image.open(BytesIO(response.content))
14
+
15
+ # Convert to RGB if the image is in RGBA mode (for transparency)
16
+ if image.mode == 'RGBA':
17
+ image = image.convert('RGB')
18
+
19
+ return image
20
+ except Exception as e:
21
+ print(f"Error loading image from URL: {url}")
22
+ print(f"Error details: {str(e)}")
23
+ return None
utils/stripe_utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import stripe
2
+ from config import STRIPE_API_KEY, STRIPE_WEBHOOK_SECRET
3
+
4
+ stripe.api_key = STRIPE_API_KEY
5
+
6
+
7
+ def create_checkout_session(amount, quantity, user_id):
8
+ session = stripe.checkout.Session.create(
9
+ payment_method_types=['card'],
10
+ line_items=[{
11
+ 'price_data': {
12
+ 'currency': 'usd',
13
+ 'unit_amount': amount,
14
+ 'product_data': {
15
+ 'name': f'Buy {quantity} credits',
16
+ },
17
+ },
18
+ 'quantity': 1,
19
+ }],
20
+ mode='payment',
21
+ success_url='http://localhost:8000/success?session_id={CHECKOUT_SESSION_ID}&user_id=' + str(user_id),
22
+ cancel_url='http://localhost:8000/cancel?user_id=' + str(user_id),
23
+
24
+ client_reference_id=str(user_id), # Add this line
25
+ )
26
+ return session
27
+
28
+ def verify_webhook(payload, signature):
29
+ try:
30
+ event = stripe.Webhook.construct_event(
31
+ payload, signature, STRIPE_WEBHOOK_SECRET
32
+ )
33
+ return event
34
+ except ValueError as e:
35
+ return None
36
+ except stripe.error.SignatureVerificationError as e:
37
+ return None
38
+
39
+ def retrieve_stripe_session(session_id):
40
+ try:
41
+ return stripe.checkout.Session.retrieve(session_id)
42
+ except stripe.error.StripeError as e:
43
+ print(f"Error retrieving Stripe session: {str(e)}")
44
+ return None