Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import uuid | |
| import base64 | |
| import json | |
| import re | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from datetime import datetime | |
| from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
| import anthropic | |
# ============================================================
# === GLOBALS & DATA STORAGE FILES
# ============================================================
# JSON files (relative paths) that the app reads at start-up and rewrites
# whenever the matching in-memory structure changes.
LIKES_CACHE_FILE = "likes_cache.json"    # per-image like/dislike/heart counters
LOG_CACHE_FILE = "log_cache.json"        # prompt / model-output log entries
QUOTE_CACHE_FILE = "quotes_cache.json"   # user-submitted quotes
# Prefix that is concatenated with a saved image filename to form its link.
STATIC_URL_PREFIX = "https://huggingface.co/spaces/awacke1/dalle-3-xl-lora-v2/file="
| # Initialize caches / load from JSON | |
def load_json(file):
    """Return the parsed JSON content of *file*, or ``{}`` if the path does not exist.

    The file is read as UTF-8. Any error raised by ``json.load`` (for example
    on malformed content) propagates to the caller.
    """
    if not os.path.exists(file):
        return {}
    with open(file, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def save_json(file, data):
    """Serialise *data* as JSON (4-space indent) into *file*, overwriting it.

    The file is opened in text mode with UTF-8 encoding.
    """
    with open(file, 'w', encoding='utf-8') as handle:
        json.dump(obj=data, fp=handle, indent=4)
# Vote counters keyed by image filename; a missing or empty cache becomes {}.
likes_cache = load_json(LIKES_CACHE_FILE) or {}

# Logs and quotes are lists, so a missing file must yield [] rather than the
# {} that load_json returns for a non-existent path.
if os.path.exists(LOG_CACHE_FILE):
    chat_logs = load_json(LOG_CACHE_FILE)
else:
    chat_logs = []

if os.path.exists(QUOTE_CACHE_FILE):
    quotes = load_json(QUOTE_CACHE_FILE)
else:
    quotes = []

# DataFrame for images (starts empty; rows are appended as images are saved)
image_metadata = pd.DataFrame(
    columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created']
)
| # ============================================================ | |
| # === ANTHROPIC CLIENT (Claude) | |
| # ============================================================ | |
# The API key is read from the environment; without it no client is created
# and claude_client stays None.
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", None)
claude_client = None
if anthropic_api_key:
    claude_client = anthropic.Anthropic(api_key=anthropic_api_key)
| # ============================================================ | |
| # === IMAGE PIPELINE | |
| # ============================================================ | |
# Largest seed value offered to the user (maximum of a signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max


def _load_sdxl_pipeline():
    """Build the SDXL pipeline with the "dalle" LoRA adapter and move it to CUDA."""
    loaded = StableDiffusionXLPipeline.from_pretrained(
        "fluently/Fluently-XL-v4",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    loaded.scheduler = EulerAncestralDiscreteScheduler.from_config(loaded.scheduler.config)
    loaded.load_lora_weights(
        "ehristoforu/dalle-3-xl-v2",
        weight_name="dalle-3-xl-lora-v2.safetensors",
        adapter_name="dalle",
    )
    loaded.set_adapters("dalle")
    loaded.to("cuda")
    return loaded


# The pipeline is only created when a CUDA device is present; otherwise
# ``pipe`` stays None.
pipe = _load_sdxl_pipeline() if torch.cuda.is_available() else None
| # ============================================================ | |
| # === HELPER FUNCTIONS | |
| # ============================================================ | |
def randomize_seed_fn(seed: int, randomize_seed: bool):
    """Return the seed to use, as an ``int``.

    When *randomize_seed* is truthy a fresh value is drawn uniformly from
    ``[0, MAX_SEED]``; otherwise the supplied *seed* is used.
    """
    chosen = random.randint(0, MAX_SEED) if randomize_seed else seed
    return int(chosen)
def sanitize_prompt(prompt):
    """Turn *prompt* into a short slug that is safe to embed in a filename.

    The prompt is lower-cased, every character that is not a word character,
    whitespace or ``-`` is removed, and each run of whitespace (spaces, tabs,
    newlines) is collapsed into a single ``_``. The result is limited to 50
    characters and never ends with ``_``.

    Returns:
        The slug; an empty string when nothing usable remains.
    """
    stripped = re.sub(r'[^\w\s-]', '', prompt.lower())
    # Whitespace must not survive: the slug becomes part of a filename and of
    # a link built by plain string concatenation.
    slug = re.sub(r'\s+', '_', stripped.strip())
    return slug[:50].rstrip('_')
def save_image_locally(img, prompt):
    """Save *img* as a PNG in the working directory and register it.

    The filename is ``<YYYYmmdd_HHMMSS>_<sanitized prompt>.png``. A zeroed
    vote entry is added to ``likes_cache`` (and persisted) if the filename is
    new, and a row describing the image is appended to ``image_metadata``.

    Returns:
        The filename the image was written to.
    """
    global image_metadata

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{stamp}_{sanitize_prompt(prompt)}.png"
    img.save(filename)

    if filename not in likes_cache:
        likes_cache[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
        save_json(LIKES_CACHE_FILE, likes_cache)

    record = pd.DataFrame([{
        'Filename': filename,
        'Prompt': prompt,
        'Likes': 0,
        'Dislikes': 0,
        'Hearts': 0,
        'Created': str(datetime.now()),
    }])
    image_metadata = pd.concat([image_metadata, record], ignore_index=True)
    return filename
def log_input_output(user_input, model_output, link=""):
    """Append one timestamped entry to ``chat_logs`` and rewrite the log file.

    Args:
        user_input: Text the user submitted.
        model_output: Text describing what the model produced.
        link: Optional link(s) associated with the output; stored under
            ``"file_link"``.
    """
    entry = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "input": user_input,
        "output": model_output,
        "file_link": link,
    }
    chat_logs.append(entry)
    save_json(LOG_CACHE_FILE, chat_logs)
def generate_image(
    prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, randomize_seed
):
    """Generate an image for *prompt*, save it, log it and return UI payloads.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt; ignored unless *use_negative_prompt*.
        use_negative_prompt: Whether *negative_prompt* is applied.
        seed: Seed requested by the user.
        width, height: Output size in pixels (cast to ``int`` before use).
        guidance_scale: Guidance scale passed to the pipeline.
        randomize_seed: When truthy, *seed* is replaced by a random one.

    Returns:
        A 5-tuple ``(filenames, seed, links, gallery_items, metadata_rows)``:
        the saved filenames, the seed actually used, one link per file, the
        current gallery entries, and the metadata table as a list of rows.

    Raises:
        gr.Error: If no pipeline was loaded (``pipe`` is None).
    """
    if pipe is None:
        # Surface the problem as a UI error instead of returning a message
        # string where the gallery expects image paths.
        raise gr.Error("No GPU available, cannot generate images.")

    seed = randomize_seed_fn(seed, randomize_seed)
    if not use_negative_prompt:
        negative_prompt = ""

    # Seed the sampler so the reported seed actually determines the output.
    generator = torch.Generator().manual_seed(seed)

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=20,
        num_images_per_prompt=1,
        cross_attention_kwargs={"scale": 0.65},
        generator=generator,
        output_type="pil",
    ).images

    filenames = [save_image_locally(img, prompt) for img in images]
    links = [f"{STATIC_URL_PREFIX}{f}" for f in filenames]

    # Log the generation
    log_input_output(user_input=prompt, model_output="(image generated)", link=", ".join(links))

    # Return Gradio objects
    return filenames, seed, links, get_image_gallery(), image_metadata.values.tolist()
def get_image_gallery():
    """Return ``(filename, caption)`` pairs for every tracked image still on disk.

    Rows of ``image_metadata`` whose file no longer exists are skipped. The
    caption shows the filename, the prompt and the three vote counters.
    """
    gallery = []
    for _, row in image_metadata.iterrows():
        path = row["Filename"]
        if not os.path.exists(path):
            continue
        caption = f"{path}\nPrompt: {row['Prompt']}\n👍 {row['Likes']} 👎 {row['Dislikes']} ❤️ {row['Hearts']}"
        gallery.append((path, caption))
    return gallery
def vote_image(filename, vote_type):
    """Add one vote of *vote_type* to *filename* and return refreshed UI data.

    Args:
        filename: Image filename to vote on; ignored when falsy or not
            present in ``likes_cache``.
        vote_type: Key of the counter in the cache entry (the callers in this
            file pass ``'likes'``, ``'dislikes'`` or ``'hearts'``). The
            matching DataFrame column is its capitalised form.

    Returns:
        ``(gallery_items, metadata_rows)`` reflecting the state after the vote.
    """
    if filename and filename in likes_cache:
        likes_cache[filename][vote_type] += 1
        save_json(LIKES_CACHE_FILE, likes_cache)

        column = vote_type.capitalize()
        matches = image_metadata['Filename'] == filename
        # Use a boolean mask with .loc: .at accepts only a single scalar row
        # label, not the Index object a filename lookup produces.
        if matches.any():
            image_metadata.loc[matches, column] += 1

    return get_image_gallery(), image_metadata.values.tolist()
def delete_image(filename):
    """Remove one image from disk, from the vote cache and from the metadata.

    Each step is independent: the file is deleted only if it exists, the
    cache entry only if present, and the metadata rows with a matching
    ``Filename`` are always filtered out.

    Returns:
        ``(gallery_items, metadata_rows)`` after the deletion.
    """
    global image_metadata

    file_present = bool(filename) and os.path.exists(filename)
    if file_present:
        os.remove(filename)

    if filename in likes_cache:
        del likes_cache[filename]
        save_json(LIKES_CACHE_FILE, likes_cache)

    keep_mask = image_metadata['Filename'] != filename
    image_metadata = image_metadata[keep_mask]
    return get_image_gallery(), image_metadata.values.tolist()
def delete_all_images():
    """Delete every tracked image file and reset the metadata and vote cache.

    Returns:
        ``(gallery_items, metadata_rows)`` for the now-empty state.
    """
    global image_metadata, likes_cache

    tracked = image_metadata["Filename"].tolist()
    # filter() is lazy, so each existence check happens right before its removal.
    for path in filter(os.path.exists, tracked):
        os.remove(path)

    image_metadata = pd.DataFrame(
        columns=['Filename', 'Prompt', 'Likes', 'Dislikes', 'Hearts', 'Created']
    )
    likes_cache.clear()
    save_json(LIKES_CACHE_FILE, likes_cache)
    return get_image_gallery(), image_metadata.values.tolist()
| # === QUOTES Demo (Optional) === | |
def add_quote(q):
    """Store *q* as a new quote (unless it is blank) and return the quote table.

    A non-blank quote is appended with zero likes and the current timestamp,
    and the quotes file is rewritten.

    Returns:
        One ``[index, text, likes, created]`` row per stored quote.
    """
    if q.strip():
        created = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        quotes.append({"text": q, "likes": 0, "created": created})
        save_json(QUOTE_CACHE_FILE, quotes)

    table = []
    for position, item in enumerate(quotes):
        table.append([position, item["text"], item["likes"], item["created"]])
    return table
def like_quote(idx):
    """Add one like to the quote at position *idx* and return the quote table.

    Args:
        idx: Position of the quote. Whole-number floats (as delivered by a
            numeric input field) are accepted; ``None``, non-numeric values,
            fractional values and out-of-range positions leave the quotes
            unchanged.

    Returns:
        One ``[index, text, likes, created]`` row per stored quote.
    """
    position = None
    # A float cannot index a list and None cannot be range-compared, so
    # normalise to int first and reject anything that is not a whole number.
    if isinstance(idx, (int, float)) and float(idx).is_integer():
        position = int(idx)

    if position is not None and 0 <= position < len(quotes):
        quotes[position]["likes"] += 1
        save_json(QUOTE_CACHE_FILE, quotes)

    return [[i, itm["text"], itm["likes"], itm["created"]] for i, itm in enumerate(quotes)]
| # === CLAUDE Chat === | |
def chat_claude(user_message):
    """Send *user_message* to Claude as a single user turn and return the reply text.

    Returns a fixed notice instead of calling the API when no client is
    configured or the message is blank. A successful exchange is recorded
    through ``log_input_output``.
    """
    if not claude_client:
        return "No Anthropic API key configured."
    if not user_message.strip():
        return "Empty message."

    conversation = [{"role": "user", "content": user_message}]
    resp = claude_client.messages.create(
        model="claude-3-sonnet-20240229",
        max_tokens=1000,
        messages=conversation,
    )
    # The reply text is taken from the first content block of the response.
    first_block = resp.content[0]
    text = first_block.text
    log_input_output(user_input=user_message, model_output=text, link="")
    return text
| # === Refresh gallery + DF | |
def refresh_gallery_and_df():
    """Return Gradio updates carrying the current gallery entries and metadata rows."""
    gallery_update = gr.update(value=get_image_gallery())
    table_update = gr.update(value=image_metadata.values.tolist())
    return gallery_update, table_update
| # ============================================================ | |
| # === BUILD GRADIO UI | |
| # ============================================================ | |
# Markdown shown at the top of the app.
DESCRIPTION = """# 🎨 ArtForge & Claude Chat
Generate AI art, chat with Claude, log everything, and vote on images.
"""

# Sample prompts offered below the generation controls.
examples = [
    "Futuristic cityscape in neon lighting",
    "Cute cat wearing a wizard hat",
    "Surreal landscape with floating islands",
]
def _selected_gallery_filename(evt: gr.SelectData):
    """Return the filename of the gallery item the user clicked, or None.

    The gallery is filled from ``get_image_gallery()``, so the click index is
    resolved against a fresh copy of that list. This yields the same filename
    keys that ``likes_cache`` and ``image_metadata`` use.
    """
    entries = get_image_gallery()
    position = evt.index
    if isinstance(position, int) and 0 <= position < len(entries):
        return entries[position][0]
    return None


with gr.Blocks(css=".gradio-container {max-width: 1024px !important}") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab("Generate Images"):
        with gr.Row():
            prompt = gr.Text(label="Prompt", max_lines=1)
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Result", columns=1, preview=True)
        use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
        negative_prompt = gr.Text(
            label="Negative prompt",
            lines=3,
            value="(deformed, distorted:1.3), poorly drawn, bad anatomy",
        )
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1024)
        height = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1024)
        guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=7)
        # Hidden sink for the list of links returned by generate_image.
        generated_links = gr.JSON(visible=False)
        gr.Examples(examples=examples, inputs=prompt)

    with gr.Tab("Chat with Claude"):
        claude_input = gr.Textbox(label="Your Message")
        claude_output = gr.Textbox(label="Claude's Reply", lines=4)
        send_claude = gr.Button("Send to Claude")
        send_claude.click(chat_claude, inputs=claude_input, outputs=claude_output)

    with gr.Tab("Logs & Management"):
        with gr.Accordion("All Logs", open=False):
            # Snapshot of the log taken when the UI is built.
            logs_data = gr.Dataframe(
                value=pd.DataFrame(chat_logs),
                label="Input/Output Logs",
                interactive=False,
                wrap=True
            )

    with gr.Tab("Gallery & Voting"):
        image_gallery = gr.Gallery(label="Generated Images", columns=4)
        metadata_df = gr.Dataframe(
            label="Image Metadata",
            headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
            interactive=False
        )
        # Filename of the image most recently clicked in the gallery.
        selected_image = gr.State()
        with gr.Row():
            like_button = gr.Button("👍 Like")
            dislike_button = gr.Button("👎 Dislike")
            heart_button = gr.Button("❤️ Heart")
            delete_image_button = gr.Button("🗑️ Delete Image")
            delete_all_button = gr.Button("🗑️ Delete All")

        # The handler receives the click through its gr.SelectData-annotated
        # parameter, so no input components are listed.
        image_gallery.select(fn=_selected_gallery_filename, inputs=None, outputs=selected_image)
        like_button.click(fn=lambda x: vote_image(x, 'likes'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        dislike_button.click(fn=lambda x: vote_image(x, 'dislikes'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        heart_button.click(fn=lambda x: vote_image(x, 'hearts'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        delete_image_button.click(fn=delete_image, inputs=selected_image, outputs=[image_gallery, metadata_df])
        delete_all_button.click(fn=delete_all_images, outputs=[image_gallery, metadata_df])

    with gr.Tab("Quotes (Optional)"):
        quote_input = gr.Textbox(label="Enter a quote")
        add_q_button = gr.Button("Add Quote")
        quote_df = gr.Dataframe(
            value=[(idx, q['text'], q['likes'], q['created']) for idx, q in enumerate(quotes)],
            headers=["Index", "Text", "Likes", "Created"],
            interactive=False,
        )
        selected_quote = gr.Number(label="Index to Like")
        like_q_button = gr.Button("Like Quote")
        add_q_button.click(fn=add_quote, inputs=quote_input, outputs=quote_df)
        like_q_button.click(fn=like_quote, inputs=selected_quote, outputs=quote_df)

    # Wired after all tabs exist so a generation refreshes the real gallery
    # and metadata table in the "Gallery & Voting" tab.
    run_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, randomize_seed],
        outputs=[result, seed, generated_links, image_gallery, metadata_df],
        api_name="run"
    )

    demo.load(fn=refresh_gallery_and_df, outputs=[image_gallery, metadata_df])
# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.queue(max_size=20)  # request queue limited to 20 entries
    demo.launch()