Spaces:
Paused
Paused
### GENERAL IMPORTS ### | |
from datetime import date, timedelta | |
### DOMAIN ### | |
from .domain import Employee, Task, ScheduleInfo | |
### TIMEFOLD ### | |
from timefold.solver.score import HardSoftDecimalScore | |
from timefold.solver.score._constraint_factory import ConstraintFactory | |
from timefold.solver.score._joiners import Joiners | |
from timefold.solver.score._group_by import ConstraintCollectors | |
from timefold.solver.score._annotations import constraint_provider | |
def get_slot_overlap(task1: Task, task2: Task) -> int: | |
"""Calculate the number of overlapping slots between two tasks. | |
Args: | |
task1 (Task): The first task. | |
task2 (Task): The second task. | |
Returns: | |
int: The number of overlapping slots. | |
""" | |
task1_end: int = task1.start_slot + task1.duration_slots | |
task2_end: int = task2.start_slot + task2.duration_slots | |
overlap_start: int = max(task1.start_slot, task2.start_slot) | |
overlap_end: int = min(task1_end, task2_end) | |
return max(0, overlap_end - overlap_start) | |
def get_slot_date(slot: int) -> date: | |
"""Convert a slot index to a date. | |
Args: | |
slot (int): The slot index. | |
Returns: | |
date: The date corresponding to the slot. | |
""" | |
return date.today() + timedelta(days=slot // 20) # 20 slots per day | |
def tasks_violate_sequence_order(task1: Task, task2: Task) -> bool: | |
"""Check if two tasks violate the project sequence order. | |
Args: | |
task1 (Task): The first task. | |
task2 (Task): The second task. | |
Returns: | |
bool: True if task1 should come before task2 but overlaps with it. | |
""" | |
# Different tasks only | |
if task1.id == task2.id: | |
return False | |
# Both tasks must have project_id attribute | |
if not (hasattr(task1, "project_id") and hasattr(task2, "project_id")): | |
return False | |
# Task1 must belong to a project | |
if task1.project_id == "": | |
return False | |
# Tasks must be in the same project | |
if task1.project_id != task2.project_id: | |
return False | |
# Task1 must have lower sequence number (should come first) | |
if task1.sequence_number >= task2.sequence_number: | |
return False | |
# Task1 overlaps with task2 (task1 should finish before task2 starts) | |
return task1.start_slot + task1.duration_slots > task2.start_slot | |
def define_constraints(constraint_factory: ConstraintFactory) -> list: | |
""" | |
Define the constraints for the timetable problem. | |
Args: | |
constraint_factory (ConstraintFactory): The constraint factory. | |
Returns: | |
list[Constraint]: The constraints. | |
""" | |
return [ | |
# Hard constraints | |
required_skill(constraint_factory), | |
no_overlapping_tasks(constraint_factory), | |
task_within_schedule(constraint_factory), | |
task_fits_in_schedule(constraint_factory), | |
unavailable_employee(constraint_factory), | |
maintain_project_task_order(constraint_factory), | |
# Soft constraints | |
undesired_day_for_employee(constraint_factory), | |
desired_day_for_employee(constraint_factory), | |
balance_employee_task_assignments(constraint_factory), | |
] | |
### CONSTRAINTS ### | |
def required_skill(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.filter( | |
lambda task: task.employee is not None | |
and task.required_skill not in task.employee.skills | |
) | |
.penalize(HardSoftDecimalScore.ONE_HARD) | |
.as_constraint("Required skill") | |
) | |
def no_overlapping_tasks(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each_unique_pair( | |
Task, | |
Joiners.equal(lambda task: task.employee.name), | |
Joiners.overlapping( | |
lambda task: task.start_slot, | |
lambda task: task.start_slot + task.duration_slots, | |
), | |
) | |
.penalize(HardSoftDecimalScore.ONE_HARD, get_slot_overlap) | |
.as_constraint("No overlapping tasks") | |
) | |
def task_within_schedule(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.filter(lambda task: task.start_slot < 0) | |
.penalize(HardSoftDecimalScore.ONE_HARD) | |
.as_constraint("Task within schedule") | |
) | |
def task_fits_in_schedule(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.join(ScheduleInfo) | |
.filter( | |
lambda task, schedule_info: task.start_slot + task.duration_slots | |
> schedule_info.total_slots | |
) | |
.penalize(HardSoftDecimalScore.ONE_HARD) | |
.as_constraint("Task fits in schedule") | |
) | |
def unavailable_employee(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.filter( | |
lambda task: task.employee is not None | |
and get_slot_date(task.start_slot) in task.employee.unavailable_dates | |
) | |
.penalize(HardSoftDecimalScore.ONE_HARD) | |
.as_constraint("Unavailable employee") | |
) | |
def undesired_day_for_employee(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.filter( | |
lambda task: task.employee is not None | |
and get_slot_date(task.start_slot) in task.employee.undesired_dates | |
) | |
.penalize(HardSoftDecimalScore.ONE_SOFT) | |
.as_constraint("Undesired day for employee") | |
) | |
def desired_day_for_employee(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.filter( | |
lambda task: task.employee is not None | |
and get_slot_date(task.start_slot) in task.employee.desired_dates | |
) | |
.reward(HardSoftDecimalScore.ONE_SOFT) | |
.as_constraint("Desired day for employee") | |
) | |
def maintain_project_task_order(constraint_factory: ConstraintFactory): | |
"""Ensure tasks within the same project maintain their original order.""" | |
return ( | |
constraint_factory.for_each(Task) | |
.join(Task) | |
.filter(tasks_violate_sequence_order) | |
.penalize( | |
HardSoftDecimalScore.ONE_SOFT, | |
lambda task1, task2: 100 | |
* (task1.start_slot + task1.duration_slots - task2.start_slot), | |
) # High penalty (100x) proportional to overlap to strongly encourage proper sequencing | |
.as_constraint("Project task sequence order") | |
) | |
def balance_employee_task_assignments(constraint_factory: ConstraintFactory): | |
return ( | |
constraint_factory.for_each(Task) | |
.group_by(lambda task: task.employee, ConstraintCollectors.count()) | |
.complement( | |
Employee, lambda e: 0 | |
) # Include all employees which are not assigned to any task | |
.group_by( | |
ConstraintCollectors.load_balance( | |
lambda employee, task_count: employee, | |
lambda employee, task_count: task_count, | |
) | |
) | |
.penalize_decimal( | |
HardSoftDecimalScore.ONE_SOFT, | |
lambda load_balance: load_balance.unfairness(), | |
) | |
.as_constraint("Balance employee task assignments") | |
) | |