Spaces:
Sleeping
Sleeping
DVampire
committed on
Commit
·
bf5c0e0
1
Parent(s):
d3e5344
update website
Browse files- configs/paper_agent.py +4 -0
- src/config/__init__.py +3 -0
- src/config/config.py +86 -0
- src/database/__init__.py +5 -0
- src/database/db.py +143 -0
- src/logger/__init__.py +10 -0
- src/logger/logger.py +229 -0
- src/utils/__init__.py +8 -0
- src/utils/hf_utils.py +0 -0
- src/utils/path_utils.py +12 -0
- src/utils/singleton.py +25 -0
configs/paper_agent.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
workdir = "workdir"
|
2 |
+
tag = "paper_agent"
|
3 |
+
exp_path = f"{workdir}/{tag}"
|
4 |
+
log_path = "agent.log"
|
src/config/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
# Public config API: the process-wide singleton `config` instance.
from .config import config

__all__ = ['config']
src/config/config.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from mmengine import Config as MMConfig
|
3 |
+
from argparse import Namespace
|
4 |
+
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
load_dotenv(verbose=True)
|
7 |
+
|
8 |
+
from finworld.utils import assemble_project_path, get_tag_name, Singleton, set_seed
|
9 |
+
|
def check_level(level: str) -> bool:
    """Return True if *level* is one of the supported bar intervals."""
    return level in ('1day', '1min', '5min', '15min', '30min', '1hour', '4hour')
18 |
+
|
def process_general(config: MMConfig) -> MMConfig:
    """Resolve experiment-relative paths in *config* and apply global settings.

    Creates the experiment directory (``workdir/tag``), rewrites the log,
    checkpoint, plot and tracker paths to live under it, and seeds RNGs when
    a seed is present.  Mutates *config* in place and returns it.
    """
    exp_path = assemble_project_path(os.path.join(config.workdir, config.tag))
    config.exp_path = exp_path
    os.makedirs(exp_path, exist_ok=True)

    # Log file lives inside the experiment directory.
    config.log_path = os.path.join(exp_path, getattr(config, 'log_path', 'finworld.log'))

    if "checkpoint_path" in config:
        config.checkpoint_path = os.path.join(exp_path, getattr(config, 'checkpoint_path', 'checkpoint'))
        os.makedirs(config.checkpoint_path, exist_ok=True)

    if "plot_path" in config:
        config.plot_path = os.path.join(exp_path, getattr(config, 'plot_path', 'plot'))
        os.makedirs(config.plot_path, exist_ok=True)

    if "tracker" in config:
        # Re-root every tracker's logging_dir under the experiment directory.
        for tracker_name, tracker_cfg in config.tracker.items():
            config.tracker[tracker_name]['logging_dir'] = os.path.join(exp_path, tracker_cfg['logging_dir'])

    if "seed" in config:
        set_seed(config.seed)

    return config
42 |
+
|
43 |
+
|
44 |
+
class Config(MMConfig, metaclass=Singleton):
|
45 |
+
def __init__(self):
|
46 |
+
super(Config, self).__init__()
|
47 |
+
|
48 |
+
def init_config(self, config_path: str, args: Namespace) -> None:
|
49 |
+
# Initialize the general configuration
|
50 |
+
mmconfig = MMConfig.fromfile(filename=assemble_project_path(config_path))
|
51 |
+
if 'cfg_options' not in args or args.cfg_options is None:
|
52 |
+
cfg_options = dict()
|
53 |
+
else:
|
54 |
+
cfg_options = args.cfg_options
|
55 |
+
for item in args.__dict__:
|
56 |
+
if item not in ['config', 'cfg_options'] and args.__dict__[item] is not None:
|
57 |
+
cfg_options[item] = args.__dict__[item]
|
58 |
+
mmconfig.merge_from_dict(cfg_options)
|
59 |
+
|
60 |
+
tag = get_tag_name(
|
61 |
+
tag=getattr(mmconfig, 'tag', None),
|
62 |
+
assets_name=getattr(mmconfig, 'assets_name', None),
|
63 |
+
source=getattr(mmconfig, 'source', None),
|
64 |
+
data_type= getattr(mmconfig, 'data_type', None),
|
65 |
+
level= getattr(mmconfig, 'level', None),
|
66 |
+
)
|
67 |
+
mmconfig.tag = tag
|
68 |
+
|
69 |
+
# Process general configuration
|
70 |
+
mmconfig = process_general(mmconfig)
|
71 |
+
|
72 |
+
# Initialize the price downloader configuration
|
73 |
+
if 'downloader' in mmconfig:
|
74 |
+
if "assets_path" in mmconfig.downloader:
|
75 |
+
mmconfig.downloader.assets_path = assemble_project_path(mmconfig.downloader.assets_path)
|
76 |
+
assert check_level(mmconfig.downloader.level), f"Invalid level: {mmconfig.downloader.level}. Valid levels are: ['1day', '1min', '5min', '15min', '30min', '1hour', '4hour']"
|
77 |
+
|
78 |
+
if 'processor' in mmconfig:
|
79 |
+
if "assets_path" in mmconfig.processor:
|
80 |
+
mmconfig.processor.assets_path = assemble_project_path(mmconfig.processor.assets_path)
|
81 |
+
mmconfig.processor.repo_id = f"{os.getenv('HF_REPO_NAME')}/{mmconfig.processor.repo_id}"
|
82 |
+
mmconfig.processor.repo_type = mmconfig.processor.repo_type if 'repo_type' in mmconfig.processor else 'dataset'
|
83 |
+
|
84 |
+
self.__dict__.update(mmconfig.__dict__)
|
85 |
+
|
86 |
+
config = Config()
|
src/database/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
# Database management for paper caching (SQLite-backed; see db.py).

from .db import PapersDatabase, db

__all__ = ['PapersDatabase', 'db']
src/database/db.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json
import os
import sqlite3
from contextlib import contextmanager
from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
7 |
+
|
8 |
+
|
9 |
+
class PapersDatabase():
|
10 |
+
def __init__(self, **kwargs):
|
11 |
+
super().__init__(**kwargs)
|
12 |
+
self.db_path = None
|
13 |
+
|
14 |
+
def init_db(self, config):
|
15 |
+
"""Initialize the database with required tables"""
|
16 |
+
|
17 |
+
self.db_path = config.db_path
|
18 |
+
|
19 |
+
with self.get_connection() as conn:
|
20 |
+
cursor = conn.cursor()
|
21 |
+
|
22 |
+
# Create papers cache table
|
23 |
+
cursor.execute('''
|
24 |
+
CREATE TABLE IF NOT EXISTS papers_cache (
|
25 |
+
date_str TEXT PRIMARY KEY,
|
26 |
+
html_content TEXT NOT NULL,
|
27 |
+
parsed_cards TEXT NOT NULL,
|
28 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
29 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
30 |
+
)
|
31 |
+
''')
|
32 |
+
|
33 |
+
# Create latest_date table to track the most recent available date
|
34 |
+
cursor.execute('''
|
35 |
+
CREATE TABLE IF NOT EXISTS latest_date (
|
36 |
+
id INTEGER PRIMARY KEY CHECK (id = 1),
|
37 |
+
date_str TEXT NOT NULL,
|
38 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
39 |
+
)
|
40 |
+
''')
|
41 |
+
|
42 |
+
# Insert default latest_date record if it doesn't exist
|
43 |
+
cursor.execute('''
|
44 |
+
INSERT OR IGNORE INTO latest_date (id, date_str)
|
45 |
+
VALUES (1, ?)
|
46 |
+
''', (date.today().isoformat(),))
|
47 |
+
|
48 |
+
conn.commit()
|
49 |
+
|
50 |
+
@contextmanager
|
51 |
+
def get_connection(self):
|
52 |
+
"""Context manager for database connections"""
|
53 |
+
conn = sqlite3.connect(self.db_path)
|
54 |
+
conn.row_factory = sqlite3.Row # Enable dict-like access
|
55 |
+
try:
|
56 |
+
yield conn
|
57 |
+
finally:
|
58 |
+
conn.close()
|
59 |
+
|
60 |
+
def get_cached_papers(self, date_str: str) -> Optional[Dict[str, Any]]:
|
61 |
+
"""Get cached papers for a specific date"""
|
62 |
+
with self.get_connection(self.db_path) as conn:
|
63 |
+
cursor = conn.cursor()
|
64 |
+
cursor.execute('''
|
65 |
+
SELECT parsed_cards, created_at
|
66 |
+
FROM papers_cache
|
67 |
+
WHERE date_str = ?
|
68 |
+
''', (date_str,))
|
69 |
+
|
70 |
+
row = cursor.fetchone()
|
71 |
+
if row:
|
72 |
+
return {
|
73 |
+
'cards': json.loads(row['parsed_cards']),
|
74 |
+
'cached_at': row['created_at']
|
75 |
+
}
|
76 |
+
return None
|
77 |
+
|
78 |
+
def cache_papers(self, date_str: str, html_content: str, parsed_cards: List[Dict[str, Any]]):
|
79 |
+
"""Cache papers for a specific date"""
|
80 |
+
with self.get_connection() as conn:
|
81 |
+
cursor = conn.cursor()
|
82 |
+
cursor.execute('''
|
83 |
+
INSERT OR REPLACE INTO papers_cache
|
84 |
+
(date_str, html_content, parsed_cards, updated_at)
|
85 |
+
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
|
86 |
+
''', (date_str, html_content, json.dumps(parsed_cards)))
|
87 |
+
conn.commit()
|
88 |
+
|
89 |
+
def get_latest_cached_date(self) -> Optional[str]:
|
90 |
+
"""Get the latest cached date"""
|
91 |
+
with self.get_connection() as conn:
|
92 |
+
cursor = conn.cursor()
|
93 |
+
cursor.execute('SELECT date_str FROM latest_date WHERE id = 1')
|
94 |
+
row = cursor.fetchone()
|
95 |
+
return row['date_str'] if row else None
|
96 |
+
|
97 |
+
def update_latest_date(self, date_str: str):
|
98 |
+
"""Update the latest available date"""
|
99 |
+
with self.get_connection() as conn:
|
100 |
+
cursor = conn.cursor()
|
101 |
+
cursor.execute('''
|
102 |
+
UPDATE latest_date
|
103 |
+
SET date_str = ?, updated_at = CURRENT_TIMESTAMP
|
104 |
+
WHERE id = 1
|
105 |
+
''', (date_str,))
|
106 |
+
conn.commit()
|
107 |
+
|
108 |
+
def is_cache_fresh(self, date_str: str, max_age_hours: int = 24) -> bool:
|
109 |
+
"""Check if cache is fresh (within max_age_hours)"""
|
110 |
+
with self.get_connection() as conn:
|
111 |
+
cursor = conn.cursor()
|
112 |
+
cursor.execute('''
|
113 |
+
SELECT updated_at
|
114 |
+
FROM papers_cache
|
115 |
+
WHERE date_str = ?
|
116 |
+
''', (date_str,))
|
117 |
+
|
118 |
+
row = cursor.fetchone()
|
119 |
+
if not row:
|
120 |
+
return False
|
121 |
+
|
122 |
+
cached_time = datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00'))
|
123 |
+
age = datetime.now(cached_time.tzinfo) - cached_time
|
124 |
+
return age.total_seconds() < max_age_hours * 3600
|
125 |
+
|
126 |
+
def cleanup_old_cache(self, days_to_keep: int = 7):
|
127 |
+
"""Clean up old cache entries"""
|
128 |
+
cutoff_date = (datetime.now() - timedelta(days=days_to_keep)).isoformat()
|
129 |
+
with self.get_connection() as conn:
|
130 |
+
cursor = conn.cursor()
|
131 |
+
cursor.execute('''
|
132 |
+
DELETE FROM papers_cache
|
133 |
+
WHERE updated_at < ?
|
134 |
+
''', (cutoff_date,))
|
135 |
+
conn.commit()
|
136 |
+
|
137 |
+
def __str__(self):
|
138 |
+
return f"PapersDatabase(db_path={self.db_path})"
|
139 |
+
|
140 |
+
def __repr__(self):
|
141 |
+
return self.__str__()
|
142 |
+
|
143 |
+
db = PapersDatabase()
|
src/logger/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public logging API: the AgentLogger singleton plus monitoring helpers.
# NOTE(review): `.monitor` is not part of this commit's added files — confirm
# it exists elsewhere in the package, otherwise this import fails.
from .logger import logger, LogLevel, AgentLogger, YELLOW_HEX
from .monitor import Monitor, Timing, TokenUsage

__all__ = ["logger",
           "LogLevel",
           "AgentLogger",
           "Monitor",
           "YELLOW_HEX",
           "Timing",
           "TokenUsage"]
src/logger/logger.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import json
|
3 |
+
from enum import IntEnum
|
4 |
+
|
5 |
+
from rich import box
|
6 |
+
from rich.console import Console, Group
|
7 |
+
from rich.panel import Panel
|
8 |
+
from rich.rule import Rule
|
9 |
+
from rich.syntax import Syntax
|
10 |
+
from rich.table import Table
|
11 |
+
from rich.tree import Tree
|
12 |
+
|
13 |
+
from src.utils import (
|
14 |
+
escape_code_brackets,
|
15 |
+
Singleton
|
16 |
+
)
|
17 |
+
|
18 |
+
YELLOW_HEX = "#d4b702"
|
19 |
+
|
20 |
+
class LogLevel(IntEnum):
|
21 |
+
OFF = -1 # No output
|
22 |
+
ERROR = 0 # Only errors
|
23 |
+
INFO = 1 # Normal output (default)
|
24 |
+
DEBUG = 2 # Detailed output
|
25 |
+
|
26 |
+
class AgentLogger(logging.Logger, metaclass=Singleton):
|
27 |
+
def __init__(self, name="logger", level=logging.INFO):
|
28 |
+
# Initialize the parent class
|
29 |
+
super().__init__(name, level)
|
30 |
+
|
31 |
+
# Define a formatter for log messages
|
32 |
+
self.formatter = logging.Formatter(
|
33 |
+
fmt="\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
|
34 |
+
datefmt="%H:%M:%S",
|
35 |
+
)
|
36 |
+
|
37 |
+
def init_logger(self, log_path: str, level=logging.INFO):
|
38 |
+
"""
|
39 |
+
Initialize the logger with a file path and optional main process check.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
log_path (str): The log file path.
|
43 |
+
level (int, optional): The logging level. Defaults to logging.INFO.
|
44 |
+
accelerator (Accelerator, optional): Accelerator instance to determine the main process.
|
45 |
+
"""
|
46 |
+
|
47 |
+
# Add a console handler for logging to the console
|
48 |
+
console_handler = logging.StreamHandler()
|
49 |
+
console_handler.setLevel(level)
|
50 |
+
console_handler.setFormatter(self.formatter)
|
51 |
+
self.addHandler(console_handler)
|
52 |
+
|
53 |
+
# Add a file handler for logging to the file
|
54 |
+
file_handler = logging.FileHandler(
|
55 |
+
log_path, mode="a"
|
56 |
+
) # 'a' mode appends to the file
|
57 |
+
file_handler.setLevel(level)
|
58 |
+
file_handler.setFormatter(self.formatter)
|
59 |
+
self.addHandler(file_handler)
|
60 |
+
|
61 |
+
self.console = Console(width=100)
|
62 |
+
self.file_console = Console(file=open(log_path, "a"), width=100)
|
63 |
+
|
64 |
+
# Prevent duplicate logs from propagating to the root logger
|
65 |
+
self.propagate = False
|
66 |
+
|
67 |
+
def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
|
68 |
+
"""Logs a message to the console.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
level (LogLevel, optional): Defaults to LogLevel.INFO.
|
72 |
+
"""
|
73 |
+
if isinstance(level, str):
|
74 |
+
level = LogLevel[level.upper()]
|
75 |
+
if level <= self.level:
|
76 |
+
self.info(*args, **kwargs)
|
77 |
+
|
78 |
+
def info(self, msg, *args, **kwargs):
|
79 |
+
"""
|
80 |
+
Overridden info method with stacklevel adjustment for correct log location.
|
81 |
+
"""
|
82 |
+
if isinstance(msg, (Rule, Panel, Group, Tree, Table, Syntax)):
|
83 |
+
self.console.print(msg)
|
84 |
+
self.file_console.print(msg)
|
85 |
+
else:
|
86 |
+
kwargs.setdefault(
|
87 |
+
"stacklevel", 2
|
88 |
+
) # Adjust stack level to show the actual caller
|
89 |
+
if "style" in kwargs:
|
90 |
+
kwargs.pop("style")
|
91 |
+
if "level" in kwargs:
|
92 |
+
kwargs.pop("level")
|
93 |
+
super().info(msg, *args, **kwargs)
|
94 |
+
|
95 |
+
def warning(self, msg, *args, **kwargs):
|
96 |
+
kwargs.setdefault("stacklevel", 2)
|
97 |
+
super().warning(msg, *args, **kwargs)
|
98 |
+
|
99 |
+
def error(self, msg, *args, **kwargs):
|
100 |
+
kwargs.setdefault("stacklevel", 2)
|
101 |
+
super().error(msg, *args, **kwargs)
|
102 |
+
|
103 |
+
def critical(self, msg, *args, **kwargs):
|
104 |
+
kwargs.setdefault("stacklevel", 2)
|
105 |
+
super().critical(msg, *args, **kwargs)
|
106 |
+
|
107 |
+
def debug(self, msg, *args, **kwargs):
|
108 |
+
kwargs.setdefault("stacklevel", 2)
|
109 |
+
super().debug(msg, *args, **kwargs)
|
110 |
+
|
111 |
+
def log_error(self, error_message: str) -> None:
|
112 |
+
self.info(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)
|
113 |
+
|
114 |
+
def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
|
115 |
+
markdown_content = Syntax(
|
116 |
+
content,
|
117 |
+
lexer="markdown",
|
118 |
+
theme="github-dark",
|
119 |
+
word_wrap=True,
|
120 |
+
)
|
121 |
+
if title:
|
122 |
+
self.info(
|
123 |
+
Group(
|
124 |
+
Rule(
|
125 |
+
"[bold italic]" + title,
|
126 |
+
align="left",
|
127 |
+
style=style,
|
128 |
+
),
|
129 |
+
markdown_content,
|
130 |
+
),
|
131 |
+
level=level,
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
self.info(markdown_content, level=level)
|
135 |
+
|
136 |
+
def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
|
137 |
+
self.info(
|
138 |
+
Panel(
|
139 |
+
Syntax(
|
140 |
+
content,
|
141 |
+
lexer="python",
|
142 |
+
theme="monokai",
|
143 |
+
word_wrap=True,
|
144 |
+
),
|
145 |
+
title="[bold]" + title,
|
146 |
+
title_align="left",
|
147 |
+
box=box.HORIZONTALS,
|
148 |
+
),
|
149 |
+
level=level,
|
150 |
+
)
|
151 |
+
|
152 |
+
def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
|
153 |
+
self.info(
|
154 |
+
Rule(
|
155 |
+
"[bold]" + title,
|
156 |
+
characters="β",
|
157 |
+
style=YELLOW_HEX,
|
158 |
+
),
|
159 |
+
level=LogLevel.INFO,
|
160 |
+
)
|
161 |
+
|
162 |
+
def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
|
163 |
+
self.info(
|
164 |
+
Panel(
|
165 |
+
f"\n[bold]{escape_code_brackets(content)}\n",
|
166 |
+
title="[bold]New run" + (f" - {title}" if title else ""),
|
167 |
+
subtitle=subtitle,
|
168 |
+
border_style=YELLOW_HEX,
|
169 |
+
subtitle_align="left",
|
170 |
+
),
|
171 |
+
level=level,
|
172 |
+
)
|
173 |
+
|
174 |
+
def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
|
175 |
+
messages_as_string = "\n".join([json.dumps(dict(message), indent=4, ensure_ascii=False) for message in messages])
|
176 |
+
self.info(
|
177 |
+
Syntax(
|
178 |
+
messages_as_string,
|
179 |
+
lexer="markdown",
|
180 |
+
theme="github-dark",
|
181 |
+
word_wrap=True,
|
182 |
+
),
|
183 |
+
level=level,
|
184 |
+
)
|
185 |
+
|
186 |
+
def visualize_agent_tree(self, agent):
|
187 |
+
def create_tools_section(tools_dict):
|
188 |
+
table = Table(show_header=True, header_style="bold")
|
189 |
+
table.add_column("Name", style="#1E90FF")
|
190 |
+
table.add_column("Description")
|
191 |
+
table.add_column("Arguments")
|
192 |
+
|
193 |
+
for name, tool in tools_dict.items():
|
194 |
+
args = [
|
195 |
+
f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
|
196 |
+
for arg_name, info in getattr(tool, "inputs", {}).items()
|
197 |
+
]
|
198 |
+
table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))
|
199 |
+
|
200 |
+
return Group("π οΈ [italic #1E90FF]Tools:[/italic #1E90FF]", table)
|
201 |
+
|
202 |
+
def get_agent_headline(agent, name: str | None = None):
|
203 |
+
name_headline = f"{name} | " if name else ""
|
204 |
+
return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"
|
205 |
+
|
206 |
+
def build_agent_tree(parent_tree, agent_obj):
|
207 |
+
"""Recursively builds the agent tree."""
|
208 |
+
parent_tree.add(create_tools_section(agent_obj.tools))
|
209 |
+
|
210 |
+
if agent_obj.managed_agents:
|
211 |
+
agents_branch = parent_tree.add("π€ [italic #1E90FF]Managed agents:")
|
212 |
+
for name, managed_agent in agent_obj.managed_agents.items():
|
213 |
+
agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
|
214 |
+
if managed_agent.__class__.__name__ == "CodeAgent":
|
215 |
+
agent_tree.add(
|
216 |
+
f"β
[italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
|
217 |
+
)
|
218 |
+
agent_tree.add(f"π [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
|
219 |
+
build_agent_tree(agent_tree, managed_agent)
|
220 |
+
|
221 |
+
main_tree = Tree(get_agent_headline(agent))
|
222 |
+
if agent.__class__.__name__ == "CodeAgent":
|
223 |
+
main_tree.add(
|
224 |
+
f"β
[italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
|
225 |
+
)
|
226 |
+
build_agent_tree(main_tree, agent)
|
227 |
+
self.console.print(main_tree)
|
228 |
+
|
229 |
+
logger = AgentLogger()
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public utils API: project-path helpers and the Singleton metaclass.
# NOTE(review): src/logger/logger.py imports `escape_code_brackets` from
# src.utils, but it is neither defined nor re-exported here (hf_utils.py is
# currently empty) — confirm where it is meant to live.
from .path_utils import get_project_root, assemble_project_path
from .singleton import Singleton

__all__ = [
    "get_project_root",
    "assemble_project_path",
    "Singleton"
]
src/utils/hf_utils.py
ADDED
File without changes
|
src/utils/path_utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import os
|
3 |
+
|
def get_project_root():
    """Return the absolute path of the repository root.

    The root is taken to be two directory levels above this file
    (src/utils/path_utils.py -> project root).
    """
    return str(Path(__file__).resolve().parents[2])
7 |
+
|
8 |
+
def assemble_project_path(path):
|
9 |
+
"""Assemble a path relative to the project root directory"""
|
10 |
+
if not os.path.isabs(path):
|
11 |
+
path = os.path.join(get_project_root(), path)
|
12 |
+
return path
|
src/utils/singleton.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A singleton metaclass for ensuring only one instance of a class."""
|
2 |
+
|
3 |
+
import abc
|
4 |
+
|
5 |
+
|
6 |
+
class Singleton(abc.ABCMeta, type):
|
7 |
+
"""
|
8 |
+
Singleton metaclass for ensuring only one instance of a class.
|
9 |
+
"""
|
10 |
+
|
11 |
+
_instances = {}
|
12 |
+
|
13 |
+
def __call__(cls, *args, **kwargs):
|
14 |
+
"""Call method for the singleton metaclass."""
|
15 |
+
if cls not in cls._instances:
|
16 |
+
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
17 |
+
return cls._instances[cls]
|
18 |
+
|
class AbstractSingleton(abc.ABC, metaclass=Singleton):
    """
    Abstract singleton class for ensuring only one instance of a class.

    Subclass this to get both abc.ABC semantics (abstract methods) and the
    one-instance-per-class guarantee from the Singleton metaclass.
    """

    pass