"""InkBoard Flask application.

Wires together Flask, Flask-Login and Flask-SQLAlchemy, and exposes routes to
register/log in users, expand a "scene idea" into a story through Cohere,
request an illustration from the Hugging Face inference API, and store the
result as a ``Creation`` row.
"""
import base64
import logging
import os
import uuid
from datetime import datetime

import cohere
import requests
from flask import Flask, jsonify, redirect, render_template, request, url_for
from flask_login import (
    LoginManager,
    current_user,
    login_required,
    login_user,
    logout_user,
)
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase

# ------------------ Setup ------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Base(DeclarativeBase):
    """Declarative base handed to Flask-SQLAlchemy as the model base class."""


app = Flask(__name__, instance_path="/tmp/instance")

# Flask-Login stores the logged-in user id in the signed session cookie, so a
# secret key is mandatory: without one, login_user() raises RuntimeError.
_secret_key = os.environ.get("SECRET_KEY") or os.environ.get("SESSION_SECRET")
if not _secret_key:
    # Best-effort fallback so the app still works out of the box. Sessions
    # signed with a random key do not survive a restart and are not shared
    # between worker processes, hence the warning.
    logger.warning(
        "No SECRET_KEY/SESSION_SECRET configured; using a random per-process "
        "key. Sessions will be invalidated on restart."
    )
    _secret_key = os.urandom(32)
app.secret_key = _secret_key

app.config["SESSION_COOKIE_SECURE"] = False
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# ------------------ Database ------------------
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ.get(
    "DATABASE_URL", "sqlite:////tmp/inkboard.db"
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {
    'pool_pre_ping': True,
    "pool_recycle": 300,
}
db = SQLAlchemy(app, model_class=Base)

# ------------------ Flask-Login ------------------
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.login_message = 'Please log in to access InkBoard'

# ------------------ API Keys ------------------
API_KEY = os.environ.get("API_KEY")
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY")
# The Cohere client is only built when a key is present; /generate checks it.
cohere_client = cohere.Client(API_KEY) if API_KEY else None

# Hugging Face models tried in order until one answers with HTTP 200.
HF_IMAGE_MODELS = (
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
    "CompVis/stable-diffusion-v1-4",
)

# ------------------ Models ------------------
# Imported after `db` exists on purpose.
# NOTE(review): presumably `models` imports `db` from this module, which is
# why this import cannot move to the top of the file -- confirm.
from models import User, Creation  # noqa: E402


@login_manager.user_loader
def load_user(user_id):
    """Return the ``User`` for the id stored in the session, or ``None``.

    Flask-Login treats ``None`` as "no such user", so a malformed id in the
    session is reported that way instead of raising.
    """
    try:
        primary_key = int(user_id)
    except (TypeError, ValueError):
        return None
    return db.session.get(User, primary_key)


# ------------------ Request helpers ------------------
def _json_body():
    """Return the request's JSON object, or ``{}`` if it is missing/invalid.

    ``silent=True`` suppresses the content-type and parse errors that
    ``get_json`` would otherwise raise; non-object JSON (lists, ``null``,
    scalars) is also mapped to ``{}`` so callers can always use ``.get``.
    """
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _request_data():
    """Return the JSON object for JSON requests, the form data otherwise."""
    return _json_body() if request.is_json else request.form


def _text_field(data, key):
    """Return ``data[key]`` stripped, or ``''`` when absent or not a string."""
    value = data.get(key)
    return value.strip() if isinstance(value, str) else ''


# ------------------ Routes ------------------
@app.route('/')
def index():
    """Render the dashboard for logged-in users, the landing page otherwise."""
    if current_user.is_authenticated:
        return render_template('dashboard.html', user=current_user)
    return render_template('index.html')


@app.route('/register', methods=['GET', 'POST'])
def register():
    """Show the registration page (GET) or create and log in a user (POST).

    POST accepts either JSON or form data with ``username``, ``email`` and
    ``password``. Validation failures are answered with a JSON error and
    status 400. On success, JSON clients get a JSON body with a redirect
    target while form clients are redirected to the index page.
    """
    if request.method != 'POST':
        return render_template('register.html')

    data = _request_data()
    username = _text_field(data, 'username')
    email = _text_field(data, 'email')
    password = _text_field(data, 'password')

    if not all([username, email, password]):
        return jsonify({'error': 'All fields are required'}), 400

    taken = (
        User.query.filter_by(username=username).first()
        or User.query.filter_by(email=email).first()
    )
    if taken:
        return jsonify({'error': 'Username or email already exists'}), 400

    user = User(username=username, email=email)
    user.set_password(password)
    db.session.add(user)
    db.session.commit()
    login_user(user)

    if request.is_json:
        return jsonify({'success': True, 'redirect': url_for('index')})
    return redirect(url_for('index'))


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login page (GET) or authenticate a user (POST).

    The ``username`` field is matched against both the username and the email
    column. Failed attempts are answered with a JSON error and status 401.
    """
    if request.method != 'POST':
        return render_template('login.html')

    data = _request_data()
    username = _text_field(data, 'username')
    password = _text_field(data, 'password')

    user = User.query.filter(
        (User.username == username) | (User.email == username)
    ).first()
    if not (user and user.check_password(password)):
        return jsonify({'error': 'Invalid credentials'}), 401

    login_user(user)
    if request.is_json:
        return jsonify({'success': True, 'redirect': url_for('index')})
    return redirect(url_for('index'))


@app.route('/logout')
@login_required
def logout():
    """Log the current user out and redirect to the index page."""
    logout_user()
    return redirect(url_for('index'))


# --- Modified generate route: no login required now ---
@app.route('/generate', methods=['POST'])
def generate_content():
    """Expand a scene idea into a story plus image and persist the result.

    Expects a JSON object with a ``scene_idea`` string. Returns 500 when the
    Cohere client is not configured or generation/persistence fails, 400 when
    no scene idea was supplied, and otherwise a JSON body with the story, the
    image URL and the id of the stored ``Creation``. Anonymous callers are
    allowed; their creations are stored with ``user_id=None``.
    """
    if not cohere_client:
        return jsonify({'error': 'Cohere API key not configured'}), 500

    # Validate before entering the try block so that a bad request is
    # reported as 400 rather than being swallowed into the generic 500.
    scene_idea = _text_field(_json_body(), 'scene_idea')
    if not scene_idea:
        return jsonify({'error': 'Please provide a scene idea'}), 400

    try:
        story_prompt = (
            "Transform this scene idea into a vivid paragraph:\n"
            f"Scene idea: {scene_idea}"
        )
        story_response = cohere_client.generate(
            model='command',
            prompt=story_prompt,
            max_tokens=200,
            temperature=0.7,
            k=0,
        )
        expanded_story = story_response.generations[0].text.strip()

        # Image generation is best-effort: a failure must not lose the story.
        try:
            image_url = generate_image_hf(scene_idea, expanded_story)
        except Exception as exc:
            logger.warning("Image generation failed: %s", exc)
            image_url = "Your fallback image here"

        creation_id = str(uuid.uuid4())
        owner_id = current_user.id if current_user.is_authenticated else None
        creation = Creation(
            id=creation_id,
            user_id=owner_id,
            scene_idea=scene_idea,
            story=expanded_story,
            image_url=image_url,
        )
        db.session.add(creation)
        db.session.commit()

        return jsonify({
            'success': True,
            'story': expanded_story,
            'image_url': image_url,
            'creation_id': creation_id,
        })
    except Exception as exc:
        # Top-level boundary for this route: log with traceback, discard any
        # half-finished transaction and answer with a generic error.
        logger.exception("Error in /generate route: %s", exc)
        db.session.rollback()
        return jsonify({'error': 'Server failed to generate content'}), 500


@app.route('/save_journal', methods=['POST'])
@login_required
def save_journal():
    """Attach a journal entry to one of the current user's creations.

    Expects a JSON object with ``creation_id`` and ``journal_entry``. Responds
    with 404 when the creation does not exist or belongs to another user.
    """
    data = _json_body()
    creation_id = data.get('creation_id')
    journal_entry = _text_field(data, 'journal_entry')

    creation = Creation.query.filter_by(
        id=creation_id, user_id=current_user.id
    ).first()
    if not creation:
        return jsonify({'error': 'Not found'}), 404

    creation.journal_entry = journal_entry
    # Naive UTC timestamp, kept as in the original code.
    creation.updated_at = datetime.utcnow()
    db.session.commit()
    return jsonify({'success': True})


@app.route('/get_creations')
@login_required
def get_creations():
    """Return the current user's creations as JSON, newest first."""
    creations = (
        Creation.query
        .filter_by(user_id=current_user.id)
        .order_by(Creation.created_at.desc())
        .all()
    )
    serialized = [
        {
            'id': c.id,
            'scene_idea': c.scene_idea,
            'story': c.story,
            'image_url': c.image_url,
            'journal_entry': c.journal_entry,
            'created_at': c.created_at.isoformat(),
        }
        for c in creations
    ]
    return jsonify({'creations': serialized})


# ------------------ Image Generation ------------------
def generate_image_hf(scene_idea, story):
    """Request an illustration for ``scene_idea`` from Hugging Face.

    Each model in ``HF_IMAGE_MODELS`` is tried in order; the first HTTP 200
    response is returned as a base64 ``data:image/png`` URI.

    Args:
        scene_idea: Text used to build the image prompt.
        story: Accepted for interface compatibility; not used in the prompt.

    Returns:
        The data URI on success, ``None`` when no Hugging Face API key is
        configured, or the literal string ``"Image generation failed"`` when
        every model failed.

    NOTE(review): the two failure values are returned rather than raised, so
    the caller's fallback image is not used for them and they end up stored
    as ``image_url``. Kept as-is for compatibility -- confirm what the
    front end expects before changing.
    """
    if not HUGGINGFACE_API_KEY:
        return None

    prompt = f"An artistic illustration of: {scene_idea}, dreamy and vivid"
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}

    for model in HF_IMAGE_MODELS:
        try:
            res = requests.post(
                f"https://api-inference.huggingface.co/models/{model}",
                headers=headers,
                json={"inputs": prompt},
                timeout=60,
            )
        except requests.RequestException as exc:
            logger.warning("Model %s failed: %s", model, exc)
            continue

        if res.status_code == 200:
            encoded = base64.b64encode(res.content).decode()
            return f"data:image/png;base64,{encoded}"

        # Previously skipped silently; record why this model was rejected.
        logger.warning(
            "Model %s returned HTTP %s", model, res.status_code
        )

    return "Image generation failed"


# ------------------ Init & Run ------------------
with app.app_context():
    db.create_all()

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)