"""Gradio app for blind pairwise evaluation of Mongolian Flickr30k captions.

Each task shows one image and two captions (``caption_mn_v1`` / ``caption_mn_v2``),
optionally swapped per task. The annotator's pick is recorded in
``task_submission``. A task is marked ``done`` once it has
``REQUIRED_SUBMISSIONS`` submissions.
"""

import os
import random
from typing import Any

import gradio as gr
from dotenv import load_dotenv
from sqlalchemy import (
    TIMESTAMP,
    Boolean,
    Column,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
    or_,
)
from sqlalchemy.orm import Mapped, declarative_base, relationship, sessionmaker
from sqlalchemy.sql import func

from datasets import load_dataset

# Load .env first so every environment-driven setting below (including any
# Hugging Face credentials the dataset download may rely on) is visible.
load_dotenv()

ds = load_dataset("bilguun/flickr30k-mn")

DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    # Fail fast with a clear message instead of an obscure create_engine(None) error.
    raise RuntimeError("DATABASE_URL is not set (environment or .env file).")

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Number of submissions after which a task is considered finished.
REQUIRED_SUBMISSIONS = 3
# Size of the ordered candidate pool a random task is drawn from.
CANDIDATE_POOL_SIZE = 500


class ImagesCaptions(Base):
    """One caption row of an image, with two Mongolian caption variants to compare."""

    __tablename__ = "images_captions"
    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): used as a row index into ds["train"] by random_task — confirm.
    image_id = Column(Integer)
    image_name = Column(String)
    caption_num = Column(Integer)
    caption = Column(Text)  # original caption (presumably the source-language text)
    caption_mn_v1 = Column(Text)
    caption_mn_v2 = Column(Text)


class Task(Base):
    """A unit of annotation work pointing at one ImagesCaptions row."""

    __tablename__ = "task"
    id = Column(Integer, primary_key=True, index=True)
    image_caption_id = Column(Integer, ForeignKey("images_captions.id"))
    caption_num = Column(Integer)
    # When true, the two captions are shown in swapped order in the UI.
    reverse_caption = Column(Boolean)
    # Values used by this app: "pending", "in_progress", "done".
    status = Column(String)
    image_caption: Mapped[ImagesCaptions] = relationship("ImagesCaptions")


class TaskSubmission(Base):
    """One annotator's choice for a task."""

    __tablename__ = "task_submission"
    id = Column(Integer, primary_key=True, index=True)
    task_id = Column(Integer, ForeignKey("task.id"))
    # "1" selects caption_mn_v1 and "2" selects caption_mn_v2; on-screen swapping is already undone.
    choice = Column(Text)
    created_by = Column(String)
    created_at = Column(TIMESTAMP, server_default=func.now())
    task: Mapped[Task] = relationship("Task")


def get_random_task() -> Task | None:
    """Return a random task from the first 500 pending or in_progress tasks.

    Candidates are ordered by ``image_caption_id``. To vary the pool between
    calls, they are first restricted to ids divisible by a random modulus in
    [1, 5]. If that subset is empty, the restriction is dropped. This can
    happen near the end, when only a few unfinished tasks remain. So a task is
    returned whenever any unfinished task exists.

    Returns:
        A detached ``Task`` (its session is closed, so read only already-loaded
        column attributes such as ``id``), or ``None`` if no task is pending or
        in progress.
    """
    with SessionLocal() as db:
        active = db.query(Task).filter(
            or_(Task.status == "pending", Task.status == "in_progress")
        )
        modulus = random.randint(1, 5)
        for query in (active.filter(Task.image_caption_id % modulus == 0), active):
            tasks = (
                query.order_by(Task.image_caption_id.asc())
                .limit(CANDIDATE_POOL_SIZE)
                .all()
            )
            if tasks:
                return random.choice(tasks)
    return None


def random_task() -> tuple[Any, str | None, str | None, int | None]:
    """Pick a random unfinished task and build the UI payload for it.

    Returns:
        ``(image, caption1, caption2, task_id)``, in the order of the outputs
        ``[img_preview, caption_choice1_button, caption_choice2_button, task_id]``.
        When ``reverse_caption`` is set, ``caption1`` is ``caption_mn_v2``.
        ``on_submit`` undoes this swap before storing the choice. Returns
        ``(None, None, None, None)`` when no task is available.
    """
    picked = get_random_task()
    if picked is None:
        return None, None, None, None
    with SessionLocal() as db:
        # Re-load inside a live session so the image_caption relationship can lazy-load.
        task = db.get(Task, picked.id)
        if task is None:
            return None, None, None, None
        row = task.image_caption
        # Show an empty label rather than the literal text "None" for NULL captions.
        caption1 = str(row.caption_mn_v1 or "")
        caption2 = str(row.caption_mn_v2 or "")
        if task.reverse_caption:
            caption1, caption2 = caption2, caption1
        # NOTE(review): assumes image_id is the row index in the "train" split.
        image = ds["train"][row.image_id]["image"]
        return image, caption1, caption2, int(task.id)


def _keep_current_outputs() -> tuple:
    """Return fresh updates that leave the current image, buttons and task state as they are."""
    return tuple(gr.update(visible=True) for _ in range(4))


css = """
.caption-btn {
    background: #fcdccc;
    border: 2px solid #f09162;
}
.dark .caption-btn {
    background: #26201f;
    border: 2px solid #40271a;
}
"""

with gr.Blocks(css=css) as blind_test:
    username = gr.Textbox(
        label="Нэрээ оруулна уу", placeholder="Нэр", max_lines=1, max_length=40
    )
    # Persists the entered name in the browser across page reloads.
    local_storage = gr.BrowserState([""])

    @blind_test.load(inputs=[local_storage], outputs=[username])
    def load_from_local_storage(saved_values):
        """Restore the last entered name. Fall back to "" if nothing is stored."""
        print("loading from local storage", saved_values)
        return saved_values[0] if saved_values else ""

    @gr.on([username.change], inputs=[username], outputs=[local_storage])
    def save_to_local_storage(username):
        """Store the current name in browser storage."""
        return [username]

    task_id = gr.State(None)
    img_preview = gr.Image(
        None, label="Зураг", show_label=True, show_download_button=False, height=400
    )
    md_desc = gr.Markdown(
        "### Доорх хоёр тайлбараас зурагтай хамгийн сайн тохирч буйг сонгоно уу."
    )
    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=1):
            caption_choice1_button = gr.Button(
                None, variant="secondary", elem_classes="caption-btn"
            )
        with gr.Column(scale=1):
            caption_choice2_button = gr.Button(
                None, variant="secondary", elem_classes="caption-btn"
            )

    def on_submit(username: str, choice: int, task_id: int):
        """Record ``choice`` (1 or 2, as displayed) for ``task_id``, then serve the next task.

        Nothing is recorded if the name is empty or whitespace (a warning is
        shown) or if ``choice`` is invalid. In both cases the current task stays
        on screen. If the task no longer exists, nothing is recorded and a new
        task is loaded. The insert and the status update are committed in one
        transaction.

        Returns:
            The 4-tuple from ``random_task`` for the next task, or no-op
            updates when the current task is kept.
        """
        print("on_submit", username, choice, task_id)
        if not username or not username.strip():
            gr.Warning("Нэрээ оруулна уу!")
            return _keep_current_outputs()
        if choice not in (1, 2):
            return _keep_current_outputs()

        with SessionLocal() as db:
            task = db.get(Task, task_id) if task_id is not None else None
            if task is not None:
                # Map the on-screen position back to the stored caption order.
                if task.reverse_caption:
                    choice = 2 if choice == 1 else 1
                db.add(
                    TaskSubmission(
                        task_id=task.id,
                        choice=str(choice),  # column is Text
                        created_by=username.strip(),
                    )
                )
                # autoflush is disabled, so flush before counting to include the new row.
                db.flush()
                submission_count = (
                    db.query(TaskSubmission)
                    .filter(TaskSubmission.task_id == task.id)
                    .count()
                )
                task.status = (
                    "done" if submission_count >= REQUIRED_SUBMISSIONS else "in_progress"
                )
                db.commit()

        return random_task()

    @caption_choice1_button.click(
        inputs=[username, task_id],
        outputs=[img_preview, caption_choice1_button, caption_choice2_button, task_id],
    )
    def submit_choice1(username, task_id):
        """Handle a click on the first caption button."""
        return on_submit(username, 1, task_id)

    @caption_choice2_button.click(
        inputs=[username, task_id],
        outputs=[img_preview, caption_choice1_button, caption_choice2_button, task_id],
    )
    def submit_choice2(username, task_id):
        """Handle a click on the second caption button."""
        return on_submit(username, 2, task_id)

    # Serve the first task as soon as the page loads.
    blind_test.load(
        fn=random_task,
        outputs=[img_preview, caption_choice1_button, caption_choice2_button, task_id],
    )

blind_test.launch()