wikihop-server / BFS_solver.py
stillerman's picture
stillerman HF Staff
reorganizing
2c2bab6
raw
history blame
2.85 kB
from collections import deque
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn
# Module-level rich Console shared by the solver: it backs the progress
# spinner shown while searching and prints the solution panel and table.
console = Console()
class WikiSolver:
    """Breadth-first shortest-path solver over a wiki link graph.

    The solver talks to the graph only through the engine object given to
    ``__init__``: it calls ``engine.article_exists(title)`` and
    ``engine.reset(start, target)``, and reads the ``'available_links'``
    entry of the state returned by ``reset``.
    """

    def __init__(self, wiki_run_engine):
        """Initialize the solver with a WikiRunEnvironment instance."""
        self.engine = wiki_run_engine

    def find_path(self, start_article, target_article):
        """Find a shortest path from ``start_article`` to ``target_article``.

        Args:
            start_article: Title of the article to start from.
            target_article: Title of the article to reach.

        Returns:
            A ``(path, error)`` tuple. On success ``path`` is the list of
            article titles from start to target (both included) and
            ``error`` is ``None``. On failure ``path`` is ``None`` and
            ``error`` is a human-readable message.

        Note:
            Links are read by calling ``engine.reset(current, target)``
            for every expanded article, so the engine is left reset to
            whichever article was expanded last.
        """
        engine = self.engine

        if not engine.article_exists(start_article):
            return None, "Start article not found in wiki data"
        if not engine.article_exists(target_article):
            return None, "Target article not found in wiki data"

        # Trivial case: nothing to search, so do not spin up the progress UI.
        if start_article == target_article:
            return [start_article], None

        # Predecessor map doubles as the visited set. Storing one parent per
        # article avoids copying the whole path for every queued article.
        parents = {start_article: None}
        queue = deque([start_article])

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        ) as progress:
            progress.add_task("Finding path...", total=None)

            while queue:
                current = queue.popleft()

                # The engine only exposes links for its current article, so
                # it has to be reset onto ``current`` before reading them.
                state = engine.reset(current, target_article)

                for next_article in state['available_links']:
                    if next_article in parents or not engine.article_exists(next_article):
                        continue
                    parents[next_article] = current

                    # BFS discovers articles in non-decreasing distance, so
                    # the first time the target is seen the path is already
                    # a shortest one; no need to expand the rest of the level.
                    if next_article == target_article:
                        path = [target_article]
                        node = target_article
                        while node != start_article:
                            node = parents[node]
                            path.append(node)
                        path.reverse()
                        return path, None

                    queue.append(next_article)

        return None, "No path found"

    def display_solution(self, path, start_article, target_article):
        """Print the solution to the console as a summary panel and a table.

        Args:
            path: List of article titles from start to target, or a falsy
                value (``None`` / empty list) when no solution exists.
            start_article: Title shown as the origin in the summary.
            target_article: Title shown as the destination in the summary.
        """
        if not path:
            console.print("[red]No solution found![/red]")
            return

        # Summary panel; the step count is the number of hops, i.e. one less
        # than the number of articles on the path.
        console.print(Panel(
            f"[bold cyan]Solution Found![/bold cyan]\n"
            f"From: [green]{start_article}[/green]\n"
            f"To: [red]{target_article}[/red]\n"
            f"Steps: [yellow]{len(path)-1}[/yellow]",
            title="Wiki Run Solver",
            border_style="cyan"
        ))

        # One row per article on the path, numbered from 0 (the start).
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Step", style="dim")
        table.add_column("Article", style="cyan")

        for i, article in enumerate(path):
            table.add_row(str(i), article)

        console.print(table)