Spaces:
Sleeping
Sleeping
import os | |
import random | |
import gradio as gr | |
from dotenv import load_dotenv | |
from sqlalchemy import ( | |
TIMESTAMP, | |
Boolean, | |
Column, | |
ForeignKey, | |
Integer, | |
String, | |
Text, | |
create_engine, | |
or_, | |
) | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import Mapped, relationship, sessionmaker | |
from sqlalchemy.sql import func | |
from datasets import load_dataset | |
# Load variables from a local .env file first, so that everything below
# (including the dataset download and DATABASE_URL) can see them.
load_dotenv()

# Mongolian Flickr30k captions dataset; images are read from its "train" split.
ds = load_dataset("bilguun/flickr30k-mn")

DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    # Fail fast with an actionable message instead of letting create_engine(None)
    # raise an obscure argument error.
    raise RuntimeError(
        "DATABASE_URL is not set; define it in the environment or in a .env file."
    )
engine = create_engine(DATABASE_URL)
# Session factory. autoflush is disabled, so pending changes must be flushed
# (or committed) explicitly before later queries in the same session see them.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base shared by the ORM models below.
Base = declarative_base()
# images_captions model | |
class ImagesCaptions(Base):
    """ORM mapping for the ``images_captions`` table.

    Each row holds an image reference, its caption text and two alternative
    caption versions (``caption_mn_v1`` / ``caption_mn_v2``) that the blind
    test presents side by side.
    """

    __tablename__ = "images_captions"

    # Primary key; task.image_caption_id points here.
    id = Column(Integer, index=True, primary_key=True)
    # NOTE(review): used as a row index into the dataset's "train" split
    # elsewhere in this file — confirm the two stay aligned.
    image_id = Column(Integer)
    image_name = Column(String)
    caption_num = Column(Integer)
    caption = Column(Text)
    # The two caption versions compared by annotators.
    caption_mn_v1 = Column(Text)
    caption_mn_v2 = Column(Text)
# task model | |
class Task(Base):
    """ORM mapping for the ``task`` table: one comparison shown to annotators.

    ``status`` is a free-form string; the code in this file uses
    ``"pending"``, ``"in_progress"`` and ``"done"``. When ``reverse_caption``
    is true, the two caption versions are displayed in swapped order.
    """

    __tablename__ = "task"

    id = Column(Integer, index=True, primary_key=True)
    image_caption_id = Column(Integer, ForeignKey("images_captions.id"))
    caption_num = Column(Integer)
    reverse_caption = Column(Boolean)
    status = Column(String)

    # Many-to-one link to the caption row this task compares.
    image_caption: Mapped[ImagesCaptions] = relationship(ImagesCaptions)
# task_submission model | |
class TaskSubmission(Base):
    """ORM mapping for the ``task_submission`` table: one annotator answer.

    ``choice`` is a text column; the UI handler in this file writes the
    caption version number (1 or 2) into it. ``created_at`` is filled in by
    the database server at insert time.
    """

    __tablename__ = "task_submission"

    id = Column(Integer, index=True, primary_key=True)
    task_id = Column(Integer, ForeignKey("task.id"))
    choice = Column(Text)
    created_by = Column(String)
    created_at = Column(TIMESTAMP, server_default=func.now())

    # Many-to-one link back to the answered task.
    task: Mapped[Task] = relationship(Task)
def get_random_task(limit: int = 500) -> Task | None:
    """Return a random open task, or ``None`` if no open task exists.

    A task is "open" when its status is ``"pending"`` or ``"in_progress"``.
    Candidates are ordered by ``image_caption_id`` and capped at ``limit``
    rows. They are first narrowed to tasks whose ``image_caption_id`` is a
    multiple of a divisor drawn from 1..5 (presumably to spread concurrent
    annotators over different tasks). If that narrowing leaves no candidate,
    the query is retried without it, so the last open tasks are never hidden
    behind an unlucky divisor.

    Args:
        limit: Maximum number of candidate rows fetched before the random
            pick. Defaults to 500.

    Returns:
        A ``Task`` detached from its (already closed) session, or ``None``.
        Only its already-loaded column attributes, such as ``id``, are safe
        to read. Relationships such as ``image_caption`` need a live session.
    """
    with SessionLocal() as db:
        open_tasks = db.query(Task).filter(
            Task.status.in_(("pending", "in_progress")),
            # NULL ids never satisfy the divisor filter; exclude them from the
            # fallback query too so both paths select the same kind of rows.
            Task.image_caption_id.is_not(None),
        )
        divisor = random.randint(1, 5)
        tasks = (
            open_tasks.filter(Task.image_caption_id % divisor == 0)
            .order_by(Task.image_caption_id.asc())
            .limit(limit)
            .all()
        )
        if not tasks and divisor != 1:
            # Near the end of annotation, every remaining open task may be a
            # non-multiple of ``divisor``. Without this retry the UI would show
            # no task even though work is left.
            tasks = (
                open_tasks.order_by(Task.image_caption_id.asc())
                .limit(limit)
                .all()
            )
    return random.choice(tasks) if tasks else None
def random_task():
    """Pick an open task and build the values displayed for it.

    Returns a 4-tuple ``(image, left_caption, right_caption, task_id)``.
    Every element is ``None`` when no open task exists, or when the picked
    task can no longer be found.

    The left caption is ``caption_mn_v1`` and the right one ``caption_mn_v2``,
    unless the task's ``reverse_caption`` flag is set, in which case they are
    swapped.
    """
    nothing = (None, None, None, None)
    picked = get_random_task()
    if picked is None:
        return nothing
    with SessionLocal() as session:
        # get_random_task() closes its session, so re-load the task here to be
        # able to resolve the lazy ``image_caption`` relationship.
        fresh = session.query(Task).filter(Task.id == picked.id).first()
        if fresh is None:
            return nothing
        pair = fresh.image_caption
        left = str(pair.caption_mn_v1)
        right = str(pair.caption_mn_v2)
        if fresh.reverse_caption:
            left, right = right, left
        # NOTE(review): assumes image_id is a row index into the "train" split.
        picture = ds["train"][pair.image_id]["image"]
        return picture, left, right, int(fresh.id)
# Styles for the two caption buttons (they use elem_classes="caption-btn").
# NOTE(review): the ".dark" rule presumably applies under Gradio's dark theme — confirm.
css = """
.caption-btn {
background: #fcdccc;
border: 2px solid #f09162;
}
.dark .caption-btn {
background: #26201f;
border: 2px solid #40271a;
}
"""
with gr.Blocks(css=css) as blind_test:
    # Annotator name; mirrored into browser storage so it survives reloads.
    username = gr.Textbox(
        label="Нэрээ оруулна уу", placeholder="Нэр", max_lines=1, max_length=40
    )
    # Browser-side storage holding a one-element list: [username].
    local_storage = gr.BrowserState([""])

    def load_from_local_storage(saved_values):
        """Return the username kept in browser storage ("" when nothing is stored)."""
        print("loading from local storage", saved_values)
        return saved_values[0] if saved_values else ""

    def save_to_local_storage(username):
        """Wrap the current username in the list shape kept in browser storage."""
        return [username]

    # Id of the task currently on screen; None until a task has been loaded.
    task_id = gr.State(None)

    img_preview = gr.Image(
        value=None,
        label="Зураг",
        show_label=True,
        show_download_button=False,
        height=400,
    )
    md_desc = gr.Markdown(
        "### Доорх хоёр тайлбараас зурагтай хамгийн сайн тохирч буйг сонгоно уу."
    )
    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=1):
            caption_choice1_button = gr.Button(
                value=None, variant="secondary", elem_classes="caption-btn"
            )
        with gr.Column(scale=1):
            caption_choice2_button = gr.Button(
                value=None, variant="secondary", elem_classes="caption-btn"
            )

    # Components refreshed with a task. The order matches the 4-tuple returned
    # by random_task() and on_submit().
    task_outputs = [img_preview, caption_choice1_button, caption_choice2_button, task_id]

    def _keep_current():
        """Return no-op updates so every component in task_outputs stays as is."""
        return tuple(gr.update() for _ in task_outputs)

    def on_submit(username: str, choice: int, task_id: int):
        """Record the annotator's choice for ``task_id`` and serve the next task.

        Args:
            username: Annotator name. Blank or whitespace-only names are
                rejected with a warning.
            choice: Clicked button, 1 (left) or 2 (right). Any other value is
                ignored.
            task_id: Id of the task currently on screen.

        If the task was displayed with swapped captions (``reverse_caption``),
        the choice is flipped back, so that the stored value is 1 for
        ``caption_mn_v1`` and 2 for ``caption_mn_v2``. A task reaching three
        submissions is marked ``"done"``; otherwise it is ``"in_progress"``.

        Returns:
            Values for ``task_outputs``. These are no-op updates when the input
            is rejected; otherwise a new task from ``random_task()``, which is
            also served when ``task_id`` no longer matches a task.
        """
        print("on_submit", username, choice, task_id)
        if not (username and username.strip()):
            gr.Warning("Нэрээ оруулна уу!")
            return _keep_current()
        if choice not in (1, 2):
            return _keep_current()
        with SessionLocal() as db:
            task = db.query(Task).filter(Task.id == task_id).first()
            if task is not None:
                if task.reverse_caption:
                    choice = 2 if choice == 1 else 1
                db.add(
                    TaskSubmission(
                        task_id=task.id, choice=str(choice), created_by=username
                    )
                )
                # SessionLocal uses autoflush=False. Flush explicitly so the
                # count below includes the submission just added.
                db.flush()
                submission_count = (
                    db.query(TaskSubmission)
                    .filter(TaskSubmission.task_id == task.id)
                    .count()
                )
                task.status = "done" if submission_count >= 3 else "in_progress"
                # One commit: the submission and the status change are
                # persisted together or not at all.
                db.commit()
        return random_task()

    def submit_choice1(username, task_id):
        """Handle a click on the left caption button."""
        return on_submit(username, 1, task_id)

    def submit_choice2(username, task_id):
        """Handle a click on the right caption button."""
        return on_submit(username, 2, task_id)

    # Wire the buttons to their handlers; previously they were never attached.
    caption_choice1_button.click(
        fn=submit_choice1, inputs=[username, task_id], outputs=task_outputs
    )
    caption_choice2_button.click(
        fn=submit_choice2, inputs=[username, task_id], outputs=task_outputs
    )

    # Restore the saved username on page load and persist it whenever it changes.
    blind_test.load(
        fn=load_from_local_storage, inputs=[local_storage], outputs=[username]
    )
    username.change(
        fn=save_to_local_storage, inputs=[username], outputs=[local_storage]
    )

    # Show a first task as soon as the page loads.
    blind_test.load(fn=random_task, outputs=task_outputs)

blind_test.launch()