inkboard7 / app.py
aminskjen's picture
Update app.py
e67bfab verified
import os
import logging
import uuid
import base64
import requests
from datetime import datetime
from flask import Flask, render_template, request, jsonify, redirect, url_for
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager, login_user, logout_user, login_required, current_user
from sqlalchemy.orm import DeclarativeBase
import cohere
# ------------------ Setup ------------------
logging.basicConfig(level=logging.INFO)


class Base(DeclarativeBase):
    """Declarative base handed to Flask-SQLAlchemy as ``model_class`` for ``db``."""
app = Flask(__name__, instance_path="/tmp/instance")

# Flask signs its session cookie with the secret key, and Flask-Login stores the
# logged-in user's id in that session, so login_user() cannot work without one.
# Read it from the environment so sessions survive restarts and are shared
# between worker processes; otherwise fall back to a random per-process key.
_secret_key = os.environ.get("SECRET_KEY")
if not _secret_key:
    logging.warning(
        "SECRET_KEY is not set; using a random per-process key, so sessions "
        "will be invalidated on every restart."
    )
    _secret_key = os.urandom(32)
app.secret_key = _secret_key

# Session cookies are also sent over plain HTTP (not restricted to HTTPS).
app.config["SESSION_COOKIE_SECURE"] = False
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# ------------------ Database ------------------
# DATABASE_URL overrides the default SQLite file stored under /tmp.
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ.get("DATABASE_URL", "sqlite:////tmp/inkboard.db")
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {
    # Check pooled connections before use and recycle them after 300 s so a
    # connection closed by the database server is not handed out again.
    "pool_pre_ping": True,
    "pool_recycle": 300,
}
db = SQLAlchemy(app, model_class=Base)

# ------------------ Flask-Login ------------------
login_manager = LoginManager()
login_manager.init_app(app)
# Unauthenticated access to @login_required views redirects to the 'login' view.
login_manager.login_view = 'login'
login_manager.login_message = 'Please log in to access InkBoard'

# ------------------ API Keys ------------------
API_KEY = os.environ.get("API_KEY")  # Cohere API key
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY")
# Without a key the client stays None and /generate reports a configuration error.
cohere_client = cohere.Client(API_KEY) if API_KEY else None
# ------------------ Models ------------------
from models import User, Creation
@login_manager.user_loader
def load_user(user_id):
    """Return the ``User`` whose id is stored in the session, or ``None``.

    Flask-Login calls this on each request with the id saved by
    ``login_user``. Returning ``None`` makes Flask-Login treat the request
    as anonymous instead of failing it.
    """
    try:
        pk = int(user_id)
    except (TypeError, ValueError):
        # A malformed id in the session means "not logged in", not a 500.
        return None
    # Session.get() replaces Query.get(), which is legacy in SQLAlchemy 2.0.
    return db.session.get(User, pk)
# ------------------ Routes ------------------
@app.route('/')
def index():
    """Landing page: the public page for visitors, the dashboard once logged in."""
    if not current_user.is_authenticated:
        return render_template('index.html')
    return render_template('dashboard.html', user=current_user)
@app.route('/register', methods=['GET', 'POST'])
def register():
    """Show the registration form (GET) or create and log in a new user (POST).

    POST accepts either a JSON object or form fields ``username``, ``email``
    and ``password``. On success the client is redirected to the index page
    (JSON callers receive ``{'success': True, 'redirect': ...}``). Validation
    errors are returned as JSON with status 400.
    """
    from sqlalchemy.exc import IntegrityError

    if request.method != 'POST':
        return render_template('register.html')

    # silent=True: a malformed JSON body is treated as "no fields supplied".
    data = request.get_json(silent=True) if request.is_json else request.form
    if not isinstance(data, dict):
        # e.g. a JSON array or null body (request.form is a dict subclass).
        data = {}

    def _text(name):
        # Missing, null or non-string values count as empty instead of crashing.
        value = data.get(name)
        return value.strip() if isinstance(value, str) else ''

    username = _text('username')
    email = _text('email')
    # NOTE: the password is stripped; login() strips it the same way, so the
    # two must stay in sync.
    password = _text('password')
    if not all([username, email, password]):
        return jsonify({'error': 'All fields are required'}), 400

    taken = User.query.filter((User.username == username) | (User.email == email)).first()
    if taken:
        return jsonify({'error': 'Username or email already exists'}), 400

    user = User(username=username, email=email)
    user.set_password(password)
    db.session.add(user)
    try:
        db.session.commit()
    except IntegrityError:
        # Presumably a concurrent registration won the race between the
        # check above and this commit (assumes unique constraints on
        # username/email in models.py — TODO confirm). Roll back so the
        # session stays usable.
        db.session.rollback()
        return jsonify({'error': 'Username or email already exists'}), 400

    login_user(user)
    return jsonify({'success': True, 'redirect': url_for('index')}) if request.is_json else redirect(url_for('index'))
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or authenticate a user (POST).

    POST accepts either a JSON object or form fields ``username`` (which may
    also be the user's email) and ``password``. On success the client is
    redirected to the index page (JSON callers receive
    ``{'success': True, 'redirect': ...}``); otherwise JSON with status 401.
    """
    if request.method != 'POST':
        return render_template('login.html')

    # silent=True: a malformed JSON body simply fails authentication.
    data = request.get_json(silent=True) if request.is_json else request.form
    if not isinstance(data, dict):
        data = {}

    def _text(name):
        # Missing, null or non-string values count as empty instead of crashing.
        value = data.get(name)
        return value.strip() if isinstance(value, str) else ''

    username = _text('username')
    # Stripped to match how register() hashed the password.
    password = _text('password')

    if username and password:
        user = User.query.filter((User.username == username) | (User.email == username)).first()
        if user and user.check_password(password):
            login_user(user)
            return jsonify({'success': True, 'redirect': url_for('index')}) if request.is_json else redirect(url_for('index'))
    return jsonify({'error': 'Invalid credentials'}), 401
@app.route('/logout')
@login_required
def logout():
    """Log the current user out and send them back to the landing page."""
    logout_user()
    landing = url_for('index')
    return redirect(landing)
# --- Generate route: login is optional; anonymous creations are saved with user_id=None ---
@app.route('/generate', methods=['POST'])
def generate_content():
    """Expand a scene idea into a story plus an illustration, and save it.

    Expects a JSON body ``{"scene_idea": "..."}``. Login is optional: for
    anonymous visitors the creation is stored with ``user_id=None``.

    Returns:
        JSON ``{success, story, image_url, creation_id}`` on success;
        400 when the scene idea is missing or the body is not a JSON object;
        500 when Cohere is not configured or generation/saving fails.
    """
    if not cohere_client:
        return jsonify({'error': 'Cohere API key not configured'}), 500

    # Bad client input is a 400, not a "server failed" 500: parse silently
    # and validate the shape before touching any external service.
    data = request.get_json(silent=True)
    raw_idea = data.get('scene_idea') if isinstance(data, dict) else None
    scene_idea = raw_idea.strip() if isinstance(raw_idea, str) else ''
    if not scene_idea:
        return jsonify({'error': 'Please provide a scene idea'}), 400

    try:
        story_prompt = f"Transform this scene idea into a vivid paragraph:\nScene idea: {scene_idea}"
        story_response = cohere_client.generate(
            model='command',
            prompt=story_prompt,
            max_tokens=200,
            temperature=0.7,
            k=0,
        )
        expanded_story = story_response.generations[0].text.strip()

        try:
            image_url = generate_image_hf(scene_idea, expanded_story)
        except Exception as e:
            # Best effort: a failed illustration must not lose the story.
            logging.warning(f"Image generation failed: {e}")
            image_url = "<svg>Your fallback image here</svg>"

        creation_id = str(uuid.uuid4())
        creation = Creation(
            id=creation_id,
            user_id=current_user.id if current_user.is_authenticated else None,
            scene_idea=scene_idea,
            story=expanded_story,
            image_url=image_url,
        )
        db.session.add(creation)
        db.session.commit()
    except Exception as e:
        # Roll back so a failed commit does not leave the scoped session in
        # an unusable state for later requests; log with traceback.
        db.session.rollback()
        logging.exception(f"Error in /generate route: {e}")
        return jsonify({'error': 'Server failed to generate content'}), 500

    return jsonify({
        'success': True,
        'story': expanded_story,
        'image_url': image_url,
        'creation_id': creation_id,
    })
@app.route('/save_journal', methods=['POST'])
@login_required
def save_journal():
    """Attach (or replace) the journal entry on one of the user's creations.

    Expects JSON ``{"creation_id": ..., "journal_entry": "..."}``. A missing
    or null ``journal_entry`` stores an empty entry.

    Returns:
        ``{'success': True}``; 400 for a non-object body or non-string entry;
        404 when the creation does not exist or is not owned by the user.
    """
    from datetime import timezone

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Expected a JSON object'}), 400

    creation_id = data.get('creation_id')
    entry = data.get('journal_entry')
    if entry is None:
        entry = ''
    elif not isinstance(entry, str):
        return jsonify({'error': 'journal_entry must be a string'}), 400
    journal_entry = entry.strip()

    # Filtering on user_id as well prevents editing other users' creations.
    creation = Creation.query.filter_by(id=creation_id, user_id=current_user.id).first()
    if not creation:
        return jsonify({'error': 'Not found'}), 404

    creation.journal_entry = journal_entry
    # Same naive-UTC value datetime.utcnow() produced (deprecated since 3.12).
    creation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
    db.session.commit()
    return jsonify({'success': True})
@app.route('/get_creations')
@login_required
def get_creations():
    """Return every creation owned by the current user as JSON, newest first."""
    owned = (
        Creation.query
        .filter_by(user_id=current_user.id)
        .order_by(Creation.created_at.desc())
    )
    serialized = []
    for item in owned.all():
        serialized.append({
            'id': item.id,
            'scene_idea': item.scene_idea,
            'story': item.story,
            'image_url': item.image_url,
            'journal_entry': item.journal_entry,
            'created_at': item.created_at.isoformat(),
        })
    return jsonify({'creations': serialized})
# ------------------ Image Generation ------------------
def generate_image_hf(scene_idea, story, timeout=60):
    """Illustrate *scene_idea* with the Hugging Face Inference API.

    Tries each Stable Diffusion model in order and returns the first image
    received, encoded as a ``data:`` URL.

    Args:
        scene_idea: Short scene description used to build the image prompt.
        story: The expanded story text. Currently unused; kept so existing
            callers keep working.
        timeout: Per-model request timeout in seconds (default 60).

    Returns:
        ``data:<mime>;base64,...`` on success; ``None`` when
        ``HUGGINGFACE_API_KEY`` is not set; otherwise the placeholder string
        ``"<svg>Image generation failed</svg>"``.
    """
    if not HUGGINGFACE_API_KEY:
        return None

    prompt = f"An artistic illustration of: {scene_idea}, dreamy and vivid"
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
    models = (
        "runwayml/stable-diffusion-v1-5",
        "stabilityai/stable-diffusion-2-1",
        "CompVis/stable-diffusion-v1-4",
    )
    for model in models:
        try:
            res = requests.post(
                f"https://api-inference.huggingface.co/models/{model}",
                headers=headers,
                json={"inputs": prompt},
                timeout=timeout,
            )
        except requests.RequestException as e:
            logging.warning(f"Model {model} failed: {str(e)}")
            continue

        # Strip parameters such as "; charset=..." from the media type.
        content_type = res.headers.get("Content-Type", "").split(";")[0].strip()
        if res.status_code == 200 and (not content_type or content_type.startswith("image/")):
            # Use the served image type; fall back to PNG if none was declared.
            mime = content_type or "image/png"
            encoded = base64.b64encode(res.content).decode("ascii")
            return f"data:{mime};base64,{encoded}"

        # Error statuses and 200 responses carrying a non-image body (e.g. a
        # JSON error) are not images: log them and try the next model.
        logging.warning(
            f"Model {model} returned HTTP {res.status_code} "
            f"({content_type or 'no content type'})"
        )
    return "<svg>Image generation failed</svg>"
# ------------------ Init & Run ------------------
with app.app_context():
    # Create any tables that do not exist yet; runs whenever this module is imported.
    db.create_all()

if __name__ == "__main__":
    # Development server, listening on all interfaces.
    app.run(port=5000, host="0.0.0.0")