DVampire committed on
Commit
bf5c0e0
·
1 Parent(s): d3e5344

update website

Browse files
configs/paper_agent.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Base directory for run outputs and the tag naming this run.
workdir = "workdir"
tag = "paper_agent"

# Experiment directory is "<workdir>/<tag>".
exp_path = workdir + "/" + tag

# File name of the log file.
log_path = "agent.log"
src/config/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .config import config
2
+
3
+ __all__ = ['config']
src/config/config.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from mmengine import Config as MMConfig
3
+ from argparse import Namespace
4
+
5
+ from dotenv import load_dotenv
6
+ load_dotenv(verbose=True)
7
+
8
+ from finworld.utils import assemble_project_path, get_tag_name, Singleton, set_seed
9
+
10
def check_level(level: str) -> bool:
    """Return True when ``level`` is one of the supported level strings.

    Args:
        level: Level identifier such as ``'1day'`` or ``'5min'``.

    Returns:
        True if ``level`` is in the supported set, otherwise False.
    """
    supported = ('1day', '1min', '5min', '15min', '30min', '1hour', '4hour')
    return level in supported
18
+
19
def process_general(config: MMConfig) -> MMConfig:
    """Resolve the experiment paths stored on ``config`` and create them.

    The experiment directory is ``<workdir>/<tag>`` resolved through
    ``assemble_project_path``. The log path, and the optional checkpoint,
    plot and tracker logging directories, are re-rooted under it.

    Args:
        config: Configuration providing at least ``workdir`` and ``tag``.

    Returns:
        The same ``config`` object, updated in place.
    """
    config.exp_path = assemble_project_path(os.path.join(config.workdir, config.tag))
    os.makedirs(config.exp_path, exist_ok=True)

    log_name = getattr(config, 'log_path', 'finworld.log')
    config.log_path = os.path.join(config.exp_path, log_name)

    # Optional output directories: only handled when the key is present.
    optional_dirs = (("checkpoint_path", 'checkpoint'), ("plot_path", 'plot'))
    for key, fallback in optional_dirs:
        if key not in config:
            continue
        resolved = os.path.join(config.exp_path, getattr(config, key, fallback))
        setattr(config, key, resolved)
        os.makedirs(resolved, exist_ok=True)

    if "tracker" in config:
        # Re-root every tracker's logging directory under the experiment dir.
        for name, settings in config.tracker.items():
            rooted = os.path.join(config.exp_path, settings['logging_dir'])
            config.tracker[name]['logging_dir'] = rooted

    if "seed" in config:
        # presumably seeds the random number generators; set_seed is defined elsewhere
        set_seed(config.seed)

    return config
42
+
43
+
44
class Config(MMConfig, metaclass=Singleton):
    """Singleton configuration object built on top of ``MMConfig``.

    The instance starts empty; ``init_config`` loads a config file, merges
    command-line overrides into it and copies the result onto this object.
    """

    def __init__(self):
        super(Config, self).__init__()

    def init_config(self, config_path: str, args: Namespace) -> None:
        """Load ``config_path``, apply overrides from ``args`` and store the result.

        Args:
            config_path: Path of the config file; relative paths are resolved
                through ``assemble_project_path``.
            args: Parsed command-line arguments. Every attribute that is not
                ``None`` (except ``config`` and ``cfg_options``) overrides the
                config key of the same name; ``args.cfg_options``, when
                present, supplies additional overrides.

        Raises:
            AssertionError: If a ``downloader`` section has an unsupported
                ``level``.
        """
        # Initialize the general configuration
        mmconfig = MMConfig.fromfile(filename=assemble_project_path(config_path))

        # Work on a copy so the caller's ``args.cfg_options`` dict is never
        # mutated by the merge below.
        cfg_options = dict(getattr(args, 'cfg_options', None) or {})
        for name, value in vars(args).items():
            if name not in ('config', 'cfg_options') and value is not None:
                cfg_options[name] = value
        mmconfig.merge_from_dict(cfg_options)

        # get_tag_name is defined elsewhere; it receives the (possibly
        # missing) naming fields and its result replaces ``tag``.
        mmconfig.tag = get_tag_name(
            tag=getattr(mmconfig, 'tag', None),
            assets_name=getattr(mmconfig, 'assets_name', None),
            source=getattr(mmconfig, 'source', None),
            data_type=getattr(mmconfig, 'data_type', None),
            level=getattr(mmconfig, 'level', None),
        )

        # Process general configuration
        mmconfig = process_general(mmconfig)

        # Initialize the price downloader configuration
        if 'downloader' in mmconfig:
            if "assets_path" in mmconfig.downloader:
                mmconfig.downloader.assets_path = assemble_project_path(mmconfig.downloader.assets_path)
            # Raised explicitly (not via ``assert``) so the validation still
            # runs under ``python -O``; the exception type is unchanged.
            if not check_level(mmconfig.downloader.level):
                raise AssertionError(
                    f"Invalid level: {mmconfig.downloader.level}. Valid levels are: ['1day', '1min', '5min', '15min', '30min', '1hour', '4hour']"
                )

        if 'processor' in mmconfig:
            if "assets_path" in mmconfig.processor:
                mmconfig.processor.assets_path = assemble_project_path(mmconfig.processor.assets_path)
            # NOTE(review): if HF_REPO_NAME is unset this yields "None/<repo_id>".
            mmconfig.processor.repo_id = f"{os.getenv('HF_REPO_NAME')}/{mmconfig.processor.repo_id}"
            if 'repo_type' not in mmconfig.processor:
                mmconfig.processor.repo_type = 'dataset'

        # Adopt the loaded configuration's state on this singleton instance.
        self.__dict__.update(mmconfig.__dict__)


config = Config()
src/database/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Database management for paper caching
2
+
3
+ from .db import PapersDatabase, db
4
+
5
+ __all__ = ['PapersDatabase', 'db']
src/database/db.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import sqlite3
4
+ from datetime import date, datetime, timedelta
5
+ from typing import Any, Dict, List, Optional
6
+ from contextlib import contextmanager
7
+
8
+
9
class PapersDatabase():
    """SQLite-backed cache of parsed paper listings keyed by a date string.

    The database location is not known at construction time; call
    ``init_db`` with a config exposing ``db_path`` before any other method.
    Timestamps are written by SQLite's ``CURRENT_TIMESTAMP`` (UTC, formatted
    ``YYYY-MM-DD HH:MM:SS``), so all age arithmetic here is done in UTC.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.db_path = None

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        # Local import: ``timezone`` is not among the module-level imports.
        from datetime import timezone

        return datetime.now(timezone.utc)

    def init_db(self, config):
        """Record ``config.db_path`` and create the required tables.

        Args:
            config: Object exposing a ``db_path`` attribute.
        """
        self.db_path = config.db_path

        # SQLite cannot create missing parent directories on its own.
        parent_dir = os.path.dirname(self.db_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with self.get_connection() as conn:
            cursor = conn.cursor()

            # Create papers cache table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS papers_cache (
                    date_str TEXT PRIMARY KEY,
                    html_content TEXT NOT NULL,
                    parsed_cards TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Create latest_date table to track the most recent available date
            # (the CHECK constraint restricts it to a single row with id 1).
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS latest_date (
                    id INTEGER PRIMARY KEY CHECK (id = 1),
                    date_str TEXT NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Insert default latest_date record if it doesn't exist
            cursor.execute('''
                INSERT OR IGNORE INTO latest_date (id, date_str)
                VALUES (1, ?)
            ''', (date.today().isoformat(),))

            conn.commit()

    @contextmanager
    def get_connection(self):
        """Yield a connection to ``self.db_path`` and close it afterwards.

        Rows are returned as ``sqlite3.Row`` so columns can be read by name.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row  # Enable dict-like access
        try:
            yield conn
        finally:
            conn.close()

    def get_cached_papers(self, date_str: str) -> Optional[Dict[str, Any]]:
        """Return the cached cards for ``date_str``, or None when not cached.

        Returns:
            A dict with ``'cards'`` (the decoded JSON list) and
            ``'cached_at'`` (the row's ``created_at`` value), or None.
        """
        # get_connection() takes no arguments; it reads self.db_path itself.
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT parsed_cards, created_at
                FROM papers_cache
                WHERE date_str = ?
            ''', (date_str,))
            row = cursor.fetchone()

        if row:
            return {
                'cards': json.loads(row['parsed_cards']),
                'cached_at': row['created_at']
            }
        return None

    def cache_papers(self, date_str: str, html_content: str, parsed_cards: List[Dict[str, Any]]):
        """Insert or replace the cache entry for ``date_str``.

        ``parsed_cards`` is stored as a JSON string.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT OR REPLACE INTO papers_cache
                (date_str, html_content, parsed_cards, updated_at)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            ''', (date_str, html_content, json.dumps(parsed_cards)))
            conn.commit()

    def get_latest_cached_date(self) -> Optional[str]:
        """Return the date string stored in ``latest_date``, or None."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT date_str FROM latest_date WHERE id = 1')
            row = cursor.fetchone()
        return row['date_str'] if row else None

    def update_latest_date(self, date_str: str):
        """Overwrite the stored latest date with ``date_str``."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE latest_date
                SET date_str = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            ''', (date_str,))
            conn.commit()

    def is_cache_fresh(self, date_str: str, max_age_hours: int = 24) -> bool:
        """Return True when the entry for ``date_str`` is newer than ``max_age_hours``.

        Returns False when no entry exists for ``date_str``.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT updated_at
                FROM papers_cache
                WHERE date_str = ?
            ''', (date_str,))
            row = cursor.fetchone()

        if not row:
            return False

        now = self._utc_now()
        cached_time = datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00'))
        if cached_time.tzinfo is None:
            # CURRENT_TIMESTAMP is stored as naive UTC text; comparing it with
            # local time would skew the age by the local UTC offset.
            cached_time = cached_time.replace(tzinfo=now.tzinfo)
        age = now - cached_time
        return age.total_seconds() < max_age_hours * 3600

    def cleanup_old_cache(self, days_to_keep: int = 7):
        """Delete cache entries last updated more than ``days_to_keep`` days ago."""
        # Format the cutoff exactly like CURRENT_TIMESTAMP (UTC, space
        # separator) so the textual comparison in SQL orders correctly.
        cutoff = self._utc_now() - timedelta(days=days_to_keep)
        cutoff_date = cutoff.strftime('%Y-%m-%d %H:%M:%S')
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                DELETE FROM papers_cache
                WHERE updated_at < ?
            ''', (cutoff_date,))
            conn.commit()

    def __str__(self):
        return f"PapersDatabase(db_path={self.db_path})"

    def __repr__(self):
        return self.__str__()


db = PapersDatabase()
src/logger/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .logger import logger, LogLevel, AgentLogger, YELLOW_HEX
2
+ from .monitor import Monitor, Timing, TokenUsage
3
+
4
+ __all__ = ["logger",
5
+ "LogLevel",
6
+ "AgentLogger",
7
+ "Monitor",
8
+ "YELLOW_HEX",
9
+ "Timing",
10
+ "TokenUsage"]
src/logger/logger.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ from enum import IntEnum
4
+
5
+ from rich import box
6
+ from rich.console import Console, Group
7
+ from rich.panel import Panel
8
+ from rich.rule import Rule
9
+ from rich.syntax import Syntax
10
+ from rich.table import Table
11
+ from rich.tree import Tree
12
+
13
+ from src.utils import (
14
+ escape_code_brackets,
15
+ Singleton
16
+ )
17
+
18
# Hex colour used as the style for rules, panel borders and headlines.
YELLOW_HEX = "#d4b702"


class LogLevel(IntEnum):
    """Verbosity levels, ordered from silent to most detailed."""

    # Nothing is emitted.
    OFF = -1
    # Errors only.
    ERROR = 0
    # Regular output; used as the default level by the logging helpers.
    INFO = 1
    # Detailed output.
    DEBUG = 2
25
+
26
class AgentLogger(logging.Logger, metaclass=Singleton):
    """Singleton logger that writes plain records and rich renderables.

    Plain messages go through the standard ``logging`` handlers installed by
    ``init_logger``. Rich objects (``Rule``, ``Panel``, ``Group``, ``Tree``,
    ``Table``, ``Syntax``) are printed to a terminal console and, once
    ``init_logger`` has run, to a second console that writes to the log file.
    """

    def __init__(self, name="logger", level=logging.INFO):
        # Initialize the parent class
        super().__init__(name, level)

        # Define a formatter for log messages
        self.formatter = logging.Formatter(
            fmt="\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
            datefmt="%H:%M:%S",
        )

        # A terminal console exists from the start so rich output works even
        # before init_logger(); the file console is attached by init_logger().
        self.console = Console(width=100)
        self.file_console = None
        self._log_file = None
        self._owned_handlers = []

    def _release_outputs(self) -> None:
        """Detach the handlers and close the file opened by a previous init_logger()."""
        for handler in self._owned_handlers:
            self.removeHandler(handler)
            handler.close()
        self._owned_handlers = []
        if self._log_file is not None:
            self._log_file.close()
            self._log_file = None
        self.file_console = None

    def init_logger(self, log_path: str, level=logging.INFO):
        """Attach console and file outputs to the logger.

        Calling this again replaces the outputs installed by the previous
        call instead of stacking duplicate handlers and leaking the old
        log-file handle.

        Args:
            log_path (str): The log file path.
            level (int, optional): The logging level. Defaults to logging.INFO.
        """
        self._release_outputs()

        # Add a console handler for logging to the console
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        console_handler.setFormatter(self.formatter)
        self.addHandler(console_handler)

        # Add a file handler for logging to the file ('a' mode appends)
        file_handler = logging.FileHandler(log_path, mode="a")
        file_handler.setLevel(level)
        file_handler.setFormatter(self.formatter)
        self.addHandler(file_handler)

        self._owned_handlers = [console_handler, file_handler]

        self.console = Console(width=100)
        # Kept on the instance so the handle can be closed on re-initialisation.
        self._log_file = open(log_path, "a")
        self.file_console = Console(file=self._log_file, width=100)

        # Prevent duplicate logs from propagating to the root logger
        self.propagate = False

    def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
        """Log a message through ``info`` when ``level`` does not exceed ``self.level``.

        Args:
            level (LogLevel, optional): Level of the message, or the name of
                a ``LogLevel`` member. Defaults to LogLevel.INFO.
        """
        if isinstance(level, str):
            level = LogLevel[level.upper()]
        # NOTE(review): self.level holds a ``logging`` level (e.g. 20 for INFO)
        # while ``level`` is on the LogLevel scale (-1..2); confirm the
        # intended comparison.
        if level <= self.level:
            self.info(*args, **kwargs)

    def info(self, msg, *args, **kwargs):
        """Log ``msg``; rich renderables are printed, anything else is a log record.

        The ``style`` and ``level`` keyword arguments are accepted for the
        convenience helpers below and discarded before the record is emitted.
        """
        if isinstance(msg, (Rule, Panel, Group, Tree, Table, Syntax)):
            self.console.print(msg)
            if self.file_console is not None:
                self.file_console.print(msg)
        else:
            # Adjust stack level to show the actual caller
            kwargs.setdefault("stacklevel", 2)
            kwargs.pop("style", None)
            kwargs.pop("level", None)
            super().info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().warning(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().error(msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().critical(msg, *args, **kwargs)

    def debug(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().debug(msg, *args, **kwargs)

    def log_error(self, error_message: str) -> None:
        """Log ``error_message`` after passing it through ``escape_code_brackets``."""
        self.info(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)

    def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
        """Print ``content`` highlighted as markdown, under a titled rule when ``title`` is given."""
        markdown_content = Syntax(
            content,
            lexer="markdown",
            theme="github-dark",
            word_wrap=True,
        )
        if title:
            self.info(
                Group(
                    Rule(
                        "[bold italic]" + title,
                        align="left",
                        style=style,
                    ),
                    markdown_content,
                ),
                level=level,
            )
        else:
            self.info(markdown_content, level=level)

    def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
        """Print ``content`` highlighted as Python inside a titled panel."""
        self.info(
            Panel(
                Syntax(
                    content,
                    lexer="python",
                    theme="monokai",
                    word_wrap=True,
                ),
                title="[bold]" + title,
                title_align="left",
                box=box.HORIZONTALS,
            ),
            level=level,
        )

    def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
        """Print a horizontal rule carrying ``title``."""
        self.info(
            Rule(
                "[bold]" + title,
                characters="━",
                style=YELLOW_HEX,
            ),
            # Forward the caller's level like the sibling helpers do.
            level=level,
        )

    def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
        """Print a "New run" panel containing ``content`` with ``subtitle`` below it."""
        self.info(
            Panel(
                f"\n[bold]{escape_code_brackets(content)}\n",
                title="[bold]New run" + (f" - {title}" if title else ""),
                subtitle=subtitle,
                border_style=YELLOW_HEX,
                subtitle_align="left",
            ),
            level=level,
        )

    def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
        """Print each message as indented JSON, one after another."""
        messages_as_string = "\n".join(
            json.dumps(dict(message), indent=4, ensure_ascii=False) for message in messages
        )
        self.info(
            Syntax(
                messages_as_string,
                lexer="markdown",
                theme="github-dark",
                word_wrap=True,
            ),
            level=level,
        )

    def visualize_agent_tree(self, agent):
        """Print a tree of ``agent``: its tools and, recursively, its managed agents.

        The tree is printed to the terminal console only.
        """
        def create_tools_section(tools_dict):
            table = Table(show_header=True, header_style="bold")
            table.add_column("Name", style="#1E90FF")
            table.add_column("Description")
            table.add_column("Arguments")

            for name, tool in tools_dict.items():
                args = [
                    f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
                    for arg_name, info in getattr(tool, "inputs", {}).items()
                ]
                table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))

            return Group("🛠️ [italic #1E90FF]Tools:[/italic #1E90FF]", table)

        def get_agent_headline(agent, name: str | None = None):
            name_headline = f"{name} | " if name else ""
            return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"

        def build_agent_tree(parent_tree, agent_obj):
            """Recursively builds the agent tree."""
            parent_tree.add(create_tools_section(agent_obj.tools))

            if agent_obj.managed_agents:
                agents_branch = parent_tree.add("🤖 [italic #1E90FF]Managed agents:")
                for name, managed_agent in agent_obj.managed_agents.items():
                    agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
                    if managed_agent.__class__.__name__ == "CodeAgent":
                        agent_tree.add(
                            f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
                        )
                    agent_tree.add(f"📝 [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
                    build_agent_tree(agent_tree, managed_agent)

        main_tree = Tree(get_agent_headline(agent))
        if agent.__class__.__name__ == "CodeAgent":
            main_tree.add(
                f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
            )
        build_agent_tree(main_tree, agent)
        self.console.print(main_tree)


logger = AgentLogger()
src/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .path_utils import get_project_root, assemble_project_path
2
+ from .singleton import Singleton
3
+
4
+ __all__ = [
5
+ "get_project_root",
6
+ "assemble_project_path",
7
+ "Singleton"
8
+ ]
src/utils/hf_utils.py ADDED
File without changes
src/utils/path_utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+
4
def get_project_root():
    """Return, as a string, the directory two levels above this file's directory."""
    this_file = Path(__file__).resolve()
    return str(this_file.parents[2])
7
+
8
def assemble_project_path(path):
    """Return ``path`` unchanged if absolute, otherwise joined onto the project root."""
    if os.path.isabs(path):
        return path
    return os.path.join(get_project_root(), path)
src/utils/singleton.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A singleton metaclass for ensuring only one instance of a class."""
2
+
3
+ import abc
4
+
5
+
6
class Singleton(abc.ABCMeta, type):
    """
    Metaclass that hands out a single shared instance per class.

    The first call to a class constructs the instance; every later call
    returns that same object and ignores the arguments it is given.
    """

    # Maps each class using this metaclass to its one instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the existing instance of ``cls``, creating it on first use."""
        registry = cls._instances
        if cls not in registry:
            registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
18
+
19
+
20
class AbstractSingleton(abc.ABC, metaclass=Singleton):
    """
    Abstract base class whose subclasses are created through ``Singleton``.

    It adds no members of its own; it only combines ``abc.ABC`` with the
    ``Singleton`` metaclass.
    """