|
import base64
import logging
import os
import secrets
import uuid
from datetime import datetime, timezone

import cohere
import requests
from flask import Flask, render_template, request, jsonify, redirect, url_for
from flask_login import LoginManager, login_user, logout_user, login_required, current_user
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import DeclarativeBase
|
|
|
|
|
|
|
# Root logger: emit INFO and above for the whole application.
logging.basicConfig(level="INFO")


class Base(DeclarativeBase):
    """Declarative base passed as ``model_class`` when ``db`` is created."""
|
|
|
app = Flask(__name__, instance_path="/tmp/instance")

# Flask-Login keeps the user id in Flask's signed session cookie, so a secret
# key is mandatory: without one Flask refuses to write the session and
# login_user() raises. Read it from the environment; fall back to a random
# per-process key so local development still works (sessions are then lost on
# every restart and are not shared between worker processes, hence the warning).
app.secret_key = os.environ.get("SESSION_SECRET") or os.environ.get("SECRET_KEY")
if not app.secret_key:
    logging.warning("SESSION_SECRET/SECRET_KEY not set; using a random per-process secret key")
    app.secret_key = secrets.token_hex(32)

# Secure cookies are only sent over HTTPS. The default stays False (the former
# hard-coded value) so plain-HTTP development keeps working; set
# SESSION_COOKIE_SECURE=1 when the app is served behind TLS.
app.config["SESSION_COOKIE_SECURE"] = (
    os.environ.get("SESSION_COOKIE_SECURE", "").strip().lower() in {"1", "true", "yes", "on"}
)
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# DATABASE_URL wins when set; otherwise a SQLite file under /tmp is used.
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ.get("DATABASE_URL", "sqlite:////tmp/inkboard.db")
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {
    # Check pooled connections before use and recycle them after 300 s so
    # stale connections are not handed out to requests.
    "pool_pre_ping": True,
    "pool_recycle": 300,
}

db = SQLAlchemy(app, model_class=Base)
|
|
|
|
|
|
|
# Flask-Login wiring: passing the app to the constructor runs init_app(app).
# Unauthenticated requests to @login_required views are sent to the 'login'
# endpoint, and login_message is the notice Flask-Login shows them.
login_manager = LoginManager(app)
login_manager.login_view = 'login'
login_manager.login_message = 'Please log in to access InkBoard'
|
|
|
|
|
|
|
# Third-party credentials come from the environment; either may be missing.
API_KEY = os.environ.get("API_KEY")  # used for the Cohere client below
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY")

# No Cohere key -> no client; callers must check for None.
if API_KEY:
    cohere_client = cohere.Client(API_KEY)
else:
    cohere_client = None
|
|
|
|
|
|
|
from models import User, Creation |
|
|
|
@login_manager.user_loader
def load_user(user_id):
    """Return the User whose id Flask-Login stored in the session, or None.

    Flask-Login passes the id as a string. A malformed value yields None
    (the request is treated as anonymous) instead of raising and giving a 500.
    """
    try:
        pk = int(user_id)
    except (TypeError, ValueError):
        return None
    # Session.get replaces the legacy Query.get, deprecated in SQLAlchemy 2.0.
    return db.session.get(User, pk)
|
|
|
|
|
|
|
@app.route('/')
def index():
    """Serve the landing page to visitors and the dashboard to signed-in users."""
    if not current_user.is_authenticated:
        return render_template('index.html')
    return render_template('dashboard.html', user=current_user)
|
|
|
@app.route('/register', methods=['GET', 'POST'])
def register():
    """Render the sign-up page (GET) or create and log in a new account (POST).

    POST accepts a JSON object or form fields ``username``, ``email`` and
    ``password``. Validation failures answer JSON ``{'error': ...}`` with 400
    (form posts included). On success the new user is logged in and the
    client receives JSON ``{'success': True, 'redirect': ...}`` for JSON
    requests or an HTTP redirect to ``index`` for form posts.
    """
    if request.method != 'POST':
        return render_template('register.html')

    # silent=True plus the isinstance check: malformed JSON or a non-object
    # body (list, string, null) is reported as missing fields, not a 500.
    data = request.get_json(silent=True) if request.is_json else request.form
    if not isinstance(data, dict):
        data = {}
    # JSON values may be null or non-strings; coerce before stripping.
    # NOTE: the password is stripped as well; login() strips it too, so the
    # two must stay in sync or existing accounts stop matching.
    username, email, password = (
        str(data.get(key) or '').strip() for key in ('username', 'email', 'password')
    )

    if not all([username, email, password]):
        return jsonify({'error': 'All fields are required'}), 400

    taken = User.query.filter((User.username == username) | (User.email == email)).first()
    if taken:
        return jsonify({'error': 'Username or email already exists'}), 400

    user = User(username=username, email=email)
    user.set_password(password)
    db.session.add(user)
    try:
        db.session.commit()
    except IntegrityError:
        # A concurrent sign-up can claim the name/email between the check
        # above and this commit (assumes unique constraints on those columns
        # in models.py - TODO confirm). Roll back and answer like the pre-check.
        db.session.rollback()
        return jsonify({'error': 'Username or email already exists'}), 400

    login_user(user)
    if request.is_json:
        return jsonify({'success': True, 'redirect': url_for('index')})
    return redirect(url_for('index'))
|
|
|
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Render the login page (GET) or authenticate a user (POST).

    POST accepts a JSON object or form fields ``username`` and ``password``;
    ``username`` may hold either the account's username or its email address.
    Bad credentials answer JSON ``{'error': 'Invalid credentials'}`` with 401.
    On success the response mirrors register(): JSON for JSON requests, an
    HTTP redirect to ``index`` otherwise.
    """
    if request.method != 'POST':
        return render_template('login.html')

    # Tolerate malformed or non-object bodies and null/non-string values
    # instead of crashing with a 500 on .get()/.strip().
    data = request.get_json(silent=True) if request.is_json else request.form
    if not isinstance(data, dict):
        data = {}
    username = str(data.get('username') or '').strip()
    # Stripped exactly as register() does before hashing.
    password = str(data.get('password') or '').strip()

    user = None
    if username and password:  # no point querying for empty credentials
        user = User.query.filter((User.username == username) | (User.email == username)).first()
    if user is None or not user.check_password(password):
        return jsonify({'error': 'Invalid credentials'}), 401

    login_user(user)
    if request.is_json:
        return jsonify({'success': True, 'redirect': url_for('index')})
    return redirect(url_for('index'))
|
|
|
@app.route('/logout')
@login_required
def logout():
    """Clear the signed-in user from the session and go back to the landing page."""
    logout_user()
    home = url_for('index')
    return redirect(home)
|
|
|
|
|
|
|
@app.route('/generate', methods=['POST'])
def generate_content():
    """Expand a scene idea into a story paragraph plus an illustration.

    Expects a JSON object with a non-empty string ``scene_idea``. The story
    comes from Cohere and the image from ``generate_image_hf``. An image
    failure is logged and replaced by a placeholder so the story is still
    returned. The result is stored as a ``Creation`` owned by the current
    user, or with ``user_id=None`` for anonymous callers.

    Returns JSON ``{success, story, image_url, creation_id}``. The status is
    400 for a missing or invalid body or an empty idea. It is 500 when Cohere
    is not configured, or when generation or storage fails.
    """
    if not cohere_client:
        return jsonify({'error': 'Cohere API key not configured'}), 500

    # Validate before the broad try below, so client mistakes (non-JSON body,
    # non-object JSON, non-string idea) get a 400 rather than being reported
    # as a server failure.
    data = request.get_json(silent=True)
    scene_idea = data.get('scene_idea') if isinstance(data, dict) else None
    if not isinstance(scene_idea, str) or not scene_idea.strip():
        return jsonify({'error': 'Please provide a scene idea'}), 400
    scene_idea = scene_idea.strip()

    try:
        story_prompt = f"Transform this scene idea into a vivid paragraph:\nScene idea: {scene_idea}"
        story_response = cohere_client.generate(
            model='command',
            prompt=story_prompt,
            max_tokens=200,
            temperature=0.7,
            k=0,
        )
        expanded_story = story_response.generations[0].text.strip()

        try:
            image_url = generate_image_hf(scene_idea, expanded_story)
        except Exception as e:
            # Best effort: a failed illustration must not cost the user the story.
            logging.warning(f"Image generation failed: {e}")
            image_url = "<svg>Your fallback image here</svg>"

        creation_id = str(uuid.uuid4())
        creation = Creation(
            id=creation_id,
            user_id=current_user.id if current_user.is_authenticated else None,
            scene_idea=scene_idea,
            story=expanded_story,
            image_url=image_url,
        )
        db.session.add(creation)
        db.session.commit()
    except Exception:
        # Discard any half-finished DB work and keep the traceback in the log.
        db.session.rollback()
        logging.exception("Error in /generate route")
        return jsonify({'error': 'Server failed to generate content'}), 500

    return jsonify({
        'success': True,
        'story': expanded_story,
        'image_url': image_url,
        'creation_id': creation_id,
    })
|
|
|
@app.route('/save_journal', methods=['POST'])
@login_required
def save_journal():
    """Set the journal entry on one of the current user's creations.

    Expects JSON ``{"creation_id": ..., "journal_entry": ...}``; a missing or
    null entry clears it.

    Returns JSON ``{'success': True}``. The status is 400 when the body is not
    a JSON object or the entry is not a string. It is 404 when the creation
    does not exist or belongs to another user.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Invalid request body'}), 400

    creation_id = data.get('creation_id')
    journal_entry = data.get('journal_entry')
    if journal_entry is None:
        journal_entry = ''
    if not isinstance(journal_entry, str):
        return jsonify({'error': 'journal_entry must be a string'}), 400

    # Creation ids are created as str(uuid4()) in generate_content, so a
    # non-string id cannot match; filtering on user_id makes other users'
    # creations indistinguishable from missing ones.
    creation = None
    if isinstance(creation_id, str):
        creation = Creation.query.filter_by(id=creation_id, user_id=current_user.id).first()
    if not creation:
        return jsonify({'error': 'Not found'}), 404

    creation.journal_entry = journal_entry.strip()
    # Same naive-UTC timestamp that datetime.utcnow() produced, without its
    # Python 3.12 deprecation warning.
    creation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
    db.session.commit()
    return jsonify({'success': True})
|
|
|
def _serialize_creation(row):
    """Map one Creation row to the JSON-ready dict sent to the client."""
    return {
        'id': row.id,
        'scene_idea': row.scene_idea,
        'story': row.story,
        'image_url': row.image_url,
        'journal_entry': row.journal_entry,
        'created_at': row.created_at.isoformat(),
    }


@app.route('/get_creations')
@login_required
def get_creations():
    """List the signed-in user's creations as JSON, newest first."""
    owned = Creation.query.filter_by(user_id=current_user.id)
    rows = owned.order_by(Creation.created_at.desc()).all()
    return jsonify({'creations': list(map(_serialize_creation, rows))})
|
|
|
|
|
|
|
def generate_image_hf(scene_idea, story):
    """Illustrate *scene_idea* through the Hugging Face Inference API.

    Tries each Stable Diffusion model in order and returns the first image
    as a ``data:`` URI. ``story`` is accepted for interface compatibility but
    is not currently part of the prompt.

    Returns:
        None when ``HUGGINGFACE_API_KEY`` is unset. On success, the data URI
        labelled with the MIME type the API reported. Otherwise the
        placeholder string ``"<svg>Image generation failed</svg>"``.
    """
    if not HUGGINGFACE_API_KEY:
        return None

    prompt = f"An artistic illustration of: {scene_idea}, dreamy and vivid"
    models = [
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1",
        "CompVis/stable-diffusion-v1-4",
    ]
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}

    for model in models:
        try:
            res = requests.post(
                f"https://api-inference.huggingface.co/models/{model}",
                headers=headers,
                json={"inputs": prompt},
                timeout=60,
            )
        except requests.RequestException as e:
            logging.warning(f"Model {model} failed: {str(e)}")
            continue

        content_type = res.headers.get("Content-Type", "").split(";")[0].strip().lower()
        # Accept only image payloads (or an untyped body, as before). Label
        # the URI with the real format, which may be JPEG rather than PNG.
        if res.status_code == 200 and (not content_type or content_type.startswith("image/")):
            mime = content_type or "image/png"
            encoded = base64.b64encode(res.content).decode("ascii")
            return f"data:{mime};base64,{encoded}"

        # Non-image answers (e.g. a model still loading, or a JSON error body)
        # used to be skipped silently; log them so failures are diagnosable.
        logging.warning(
            f"Model {model} failed: HTTP {res.status_code} ({content_type or 'no content type'})"
        )

    return "<svg>Image generation failed</svg>"
|
|
|
|
|
|
|
# Create any tables that do not exist yet; this runs on import, before serving.
with app.app_context():
    db.create_all()


if __name__ == '__main__':
    # Development server, listening on every network interface at port 5000.
    app.run(port=5000, host='0.0.0.0')