Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitignore +4 -0
- README.md +21 -1
- pyproject.toml +7 -7
- src/chattr/__init__.py +0 -82
- src/chattr/__main__.py +2 -6
- src/chattr/graph/__init__.py +0 -0
- src/chattr/graph/builder.py +255 -0
- src/chattr/graph/runner.py +6 -0
- src/chattr/gui.py +32 -52
- src/chattr/settings.py +139 -0
- src/chattr/utils.py +77 -0
- uv.lock +0 -0
.gitignore
CHANGED
@@ -5,6 +5,10 @@
|
|
5 |
.venv/
|
6 |
logs/
|
7 |
results/
|
|
|
|
|
|
|
|
|
8 |
.github/
|
9 |
.trunk/
|
10 |
.idea/
|
|
|
5 |
.venv/
|
6 |
logs/
|
7 |
results/
|
8 |
+
qdrant_storage/
|
9 |
+
assets/audio/
|
10 |
+
assets/video/
|
11 |
+
|
12 |
.github/
|
13 |
.trunk/
|
14 |
.idea/
|
README.md
CHANGED
@@ -8,4 +8,24 @@ app_port: 7860
|
|
8 |
short_description: Chat with Characters
|
9 |
---
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
short_description: Chat with Characters
|
9 |
---
|
10 |
|
11 |
+
## **Chattr**: App part of the Chatacter Backend
|
12 |
+
|
13 |
+
### Environment Variables
|
14 |
+
|
15 |
+
The configuration of the server is done using environment variables:
|
16 |
+
|
17 |
+
| Name | Description | Required | Default Value |
|
18 |
+
|:---------------------------|:---------------------------------|:--------:|:-------------------------------------------|
|
19 |
+
| `MODEL__URL` | OpenAI-compatible endpoint | ✘ | `https://api.groq.com/openai/v1` |
|
20 |
+
| `MODEL__NAME` | Model name to use for chat | ✘ | `llama3-70b-8192` |
|
21 |
+
| `MODEL__API_KEY` | API key for model access | ✔ | `None` |
|
22 |
+
| `MODEL__TEMPERATURE` | Model temperature (0.0-1.0) | ✘ | `0.0` |
|
23 |
+
| `SHORT_TERM_MEMORY__URL` | Redis URL for memory store | ✘ | `redis://localhost:6379` |
|
24 |
+
| `VECTOR_DATABASE__NAME` | Vector database collection name | ✘ | `chattr` |
|
25 |
+
| `VOICE_GENERATOR_MCP__URL` | MCP service for audio generation | ✘ | `http://localhost:8001/gradio_api/mcp/sse` |
|
26 |
+
| `VIDEO_GENERATOR_MCP__URL` | MCP service for video generation | ✘ | `http://localhost:8002/gradio_api/mcp/sse` |
|
27 |
+
| `DIRECTORY__ASSETS` | Base assets directory | ✘ | `./assets` |
|
28 |
+
| `DIRECTORY__LOG` | Log files directory | ✘ | `./logs` |
|
29 |
+
| `DIRECTORY__IMAGE` | Image assets directory | ✘ | `./assets/image` |
|
30 |
+
| `DIRECTORY__AUDIO` | Audio assets directory | ✘ | `./assets/audio` |
|
31 |
+
| `DIRECTORY__VIDEO` | Video assets directory | ✘ | `./assets/video` |
|
pyproject.toml
CHANGED
@@ -8,13 +8,13 @@ authors = [
|
|
8 |
]
|
9 |
requires-python = ">=3.12"
|
10 |
dependencies = [
|
11 |
-
"gradio>=5.
|
12 |
-
"langchain>=0.3.
|
13 |
"langchain-mcp-adapters>=0.1.9",
|
14 |
-
"langchain-openai>=0.3.
|
15 |
-
"langgraph>=0.
|
16 |
-
"
|
17 |
-
"
|
18 |
]
|
19 |
|
20 |
[project.scripts]
|
@@ -25,4 +25,4 @@ requires = ["uv_build"]
|
|
25 |
build-backend = "uv_build"
|
26 |
|
27 |
[dependency-groups]
|
28 |
-
dev = ["ruff>=0.12.
|
|
|
8 |
]
|
9 |
requires-python = ">=3.12"
|
10 |
dependencies = [
|
11 |
+
"gradio>=5.39.0",
|
12 |
+
"langchain>=0.3.27",
|
13 |
"langchain-mcp-adapters>=0.1.9",
|
14 |
+
"langchain-openai>=0.3.28",
|
15 |
+
"langgraph>=0.6.3",
|
16 |
+
"langgraph-checkpoint-redis>=0.0.8",
|
17 |
+
"m3u8>=6.0.0",
|
18 |
]
|
19 |
|
20 |
[project.scripts]
|
|
|
25 |
build-backend = "uv_build"
|
26 |
|
27 |
[dependency-groups]
|
28 |
+
dev = ["ruff>=0.12.7", "ty>=0.0.1a16"]
|
src/chattr/__init__.py
CHANGED
@@ -1,82 +0,0 @@
|
|
1 |
-
from datetime import datetime
|
2 |
-
from os import getenv
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
from dotenv import load_dotenv
|
6 |
-
from loguru import logger
|
7 |
-
from requests import get
|
8 |
-
|
9 |
-
load_dotenv()
|
10 |
-
|
11 |
-
SERVER_URL: str = getenv(key="SERVER_URL", default="127.0.0.1")
|
12 |
-
SERVER_PORT: int = int(getenv(key="SERVER_PORT", default="7860"))
|
13 |
-
CURRENT_DATE: str = datetime.now().strftime(format="%Y-%m-%d_%H-%M-%S")
|
14 |
-
MCP_VOICE_GENERATOR: str = getenv(
|
15 |
-
key="MCP_VOICE_GENERATOR", default="http://localhost:8001/"
|
16 |
-
)
|
17 |
-
MCP_VIDEO_GENERATOR: str = getenv(
|
18 |
-
key="MCP_VIDEO_GENERATOR", default="http://localhost:8002/"
|
19 |
-
)
|
20 |
-
VECTOR_DATABASE_NAME: str = getenv(
|
21 |
-
key="VECTOR_DATABASE_NAME", default="chattr"
|
22 |
-
)
|
23 |
-
DOCKER_MODEL_RUNNER_URL: str = getenv(
|
24 |
-
key="DOCKER_MODEL_RUNNER_URL", default="http://127.0.0.1:12434/engines/v1"
|
25 |
-
)
|
26 |
-
DOCKER_MODEL_RUNNER_MODEL_NAME: str = getenv(
|
27 |
-
key="DOCKER_MODEL_RUNNER_MODEL_NAME",
|
28 |
-
default="ai/qwen3:0.6B-Q4_0",
|
29 |
-
)
|
30 |
-
GROQ_URL: str = getenv(
|
31 |
-
key="MODEL_URL", default="https://api.groq.com/openai/v1"
|
32 |
-
)
|
33 |
-
GROQ_MODEL_NAME: str = getenv(key="GROQ_MODEL_NAME", default="llama3-70b-8192")
|
34 |
-
|
35 |
-
BASE_DIR: Path = Path.cwd()
|
36 |
-
ASSETS_DIR: Path = BASE_DIR / "assets"
|
37 |
-
LOG_DIR: Path = BASE_DIR / "logs"
|
38 |
-
IMAGE_DIR: Path = ASSETS_DIR / "image"
|
39 |
-
AUDIO_DIR: Path = ASSETS_DIR / "audio"
|
40 |
-
VIDEO_DIR: Path = ASSETS_DIR / "video"
|
41 |
-
|
42 |
-
LOG_FILE_PATH: Path = LOG_DIR / f"{CURRENT_DATE}.log"
|
43 |
-
AUDIO_FILE_PATH: Path = AUDIO_DIR / f"{CURRENT_DATE}.wav"
|
44 |
-
VIDEO_FILE_PATH: Path = VIDEO_DIR / f"{CURRENT_DATE}.mp4"
|
45 |
-
|
46 |
-
ASSETS_DIR.mkdir(exist_ok=True)
|
47 |
-
IMAGE_DIR.mkdir(exist_ok=True)
|
48 |
-
AUDIO_DIR.mkdir(exist_ok=True)
|
49 |
-
VIDEO_DIR.mkdir(exist_ok=True)
|
50 |
-
LOG_DIR.mkdir(exist_ok=True)
|
51 |
-
|
52 |
-
MODEL_URL: str = (
|
53 |
-
DOCKER_MODEL_RUNNER_URL
|
54 |
-
if get(DOCKER_MODEL_RUNNER_URL, timeout=10).status_code == 200
|
55 |
-
else GROQ_URL
|
56 |
-
)
|
57 |
-
MODEL_NAME: str = (
|
58 |
-
DOCKER_MODEL_RUNNER_MODEL_NAME
|
59 |
-
if MODEL_URL == DOCKER_MODEL_RUNNER_URL
|
60 |
-
else GROQ_MODEL_NAME
|
61 |
-
)
|
62 |
-
MODEL_API_KEY: str = (
|
63 |
-
"not-needed"
|
64 |
-
if MODEL_URL == DOCKER_MODEL_RUNNER_URL
|
65 |
-
else getenv("GROQ_API_KEY")
|
66 |
-
)
|
67 |
-
MODEL_TEMPERATURE: float = float(getenv(key="MODEL_TEMPERATURE", default=0.0))
|
68 |
-
|
69 |
-
logger.add(
|
70 |
-
sink=LOG_FILE_PATH,
|
71 |
-
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
|
72 |
-
colorize=True,
|
73 |
-
)
|
74 |
-
logger.info(f"Current date: {CURRENT_DATE}")
|
75 |
-
logger.info(f"Base directory: {BASE_DIR}")
|
76 |
-
logger.info(f"Assets directory: {ASSETS_DIR}")
|
77 |
-
logger.info(f"Log directory: {LOG_DIR}")
|
78 |
-
logger.info(f"Audio file path: {AUDIO_FILE_PATH}")
|
79 |
-
logger.info(f"Log file path: {LOG_FILE_PATH}")
|
80 |
-
logger.info(f"Model URL is going to be used is {MODEL_URL}")
|
81 |
-
logger.info(f"Model name is going to be used is {MODEL_NAME}")
|
82 |
-
logger.info(f"Model temperature is going to be used is {MODEL_TEMPERATURE}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/chattr/__main__.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1 |
from gradio import Blocks
|
2 |
|
3 |
-
from chattr import SERVER_PORT, SERVER_URL
|
4 |
from chattr.gui import app_block
|
5 |
|
6 |
|
7 |
def main() -> None:
|
8 |
-
"""
|
9 |
-
Initializes and launches the Gradio-based Chattr application server with API access, monitoring, and PWA support enabled.
|
10 |
-
"""
|
11 |
app: Blocks = app_block()
|
12 |
app.queue(api_open=True).launch(
|
13 |
-
|
14 |
-
server_port=SERVER_PORT,
|
15 |
debug=True,
|
16 |
show_api=True,
|
17 |
enable_monitoring=True,
|
|
|
1 |
from gradio import Blocks
|
2 |
|
|
|
3 |
from chattr.gui import app_block
|
4 |
|
5 |
|
6 |
def main() -> None:
|
7 |
+
"""Initializes and launches the Gradio-based Chattr application server with API access, monitoring, and PWA support enabled."""
|
|
|
|
|
8 |
app: Blocks = app_block()
|
9 |
app.queue(api_open=True).launch(
|
10 |
+
server_port=7860,
|
|
|
11 |
debug=True,
|
12 |
show_api=True,
|
13 |
enable_monitoring=True,
|
src/chattr/graph/__init__.py
ADDED
File without changes
|
src/chattr/graph/builder.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains the Graph class, which represents the main orchestration graph for the Chattr application."""
|
2 |
+
|
3 |
+
from json import dumps
|
4 |
+
from logging import getLogger
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import AsyncGenerator, Self
|
7 |
+
|
8 |
+
from gradio import ChatMessage
|
9 |
+
from gradio.components.chatbot import MetadataDict
|
10 |
+
from langchain_core.messages import HumanMessage
|
11 |
+
from langchain_core.runnables import Runnable, RunnableConfig
|
12 |
+
from langchain_core.tools import BaseTool
|
13 |
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
14 |
+
from langchain_mcp_adapters.sessions import (
|
15 |
+
SSEConnection,
|
16 |
+
StdioConnection,
|
17 |
+
StreamableHttpConnection,
|
18 |
+
WebsocketConnection,
|
19 |
+
)
|
20 |
+
from langchain_openai import ChatOpenAI
|
21 |
+
from langgraph.checkpoint.redis.aio import AsyncRedisSaver
|
22 |
+
from langgraph.graph import START, StateGraph
|
23 |
+
from langgraph.graph.message import MessagesState
|
24 |
+
from langgraph.graph.state import CompiledStateGraph
|
25 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
26 |
+
from langgraph.store.redis.aio import AsyncRedisStore
|
27 |
+
|
28 |
+
from chattr.settings import Settings
|
29 |
+
from chattr.utils import convert_audio_to_wav, download_file, is_url
|
30 |
+
|
31 |
+
logger = getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
class Graph:
|
35 |
+
"""
|
36 |
+
Represents the main orchestration graph for the Chattr application.
|
37 |
+
This class manages the setup and execution of the conversational agent, tools, and state graph.
|
38 |
+
"""
|
39 |
+
|
40 |
+
settings: Settings
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
store: AsyncRedisStore,
|
45 |
+
saver: AsyncRedisSaver,
|
46 |
+
tools: list[BaseTool],
|
47 |
+
):
|
48 |
+
self._long_term_memory: AsyncRedisStore = store
|
49 |
+
self._short_term_memory: AsyncRedisSaver = saver
|
50 |
+
self._tools: list[BaseTool] = tools
|
51 |
+
self._llm: ChatOpenAI = self._initialize_llm()
|
52 |
+
self._model: Runnable = self._llm.bind_tools(self._tools)
|
53 |
+
self._graph: CompiledStateGraph = self._build_state_graph()
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
async def create(cls, settings: Settings) -> Self:
|
57 |
+
"""Async factory method to create a Graph instance."""
|
58 |
+
cls.settings: Settings = settings
|
59 |
+
store, saver = await cls._setup_memory()
|
60 |
+
tools: list[BaseTool] = await cls._setup_tools(
|
61 |
+
MultiServerMCPClient(cls._create_mcp_config())
|
62 |
+
)
|
63 |
+
return cls(store, saver, tools)
|
64 |
+
|
65 |
+
def _build_state_graph(self) -> CompiledStateGraph:
|
66 |
+
"""
|
67 |
+
Construct and compile the state graph for the Chattr application.
|
68 |
+
This method defines the nodes and edges for the conversational agent and tool interactions.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
CompiledStateGraph: The compiled state graph is ready for execution.
|
72 |
+
"""
|
73 |
+
|
74 |
+
async def _call_model(state: MessagesState) -> MessagesState:
|
75 |
+
response = await self._model.ainvoke(
|
76 |
+
[self.settings.model.system_message] + state["messages"]
|
77 |
+
)
|
78 |
+
return MessagesState(messages=[response])
|
79 |
+
|
80 |
+
graph_builder: StateGraph = StateGraph(MessagesState)
|
81 |
+
graph_builder.add_node("agent", _call_model)
|
82 |
+
graph_builder.add_node("tools", ToolNode(self._tools))
|
83 |
+
graph_builder.add_edge(START, "agent")
|
84 |
+
graph_builder.add_conditional_edges("agent", tools_condition)
|
85 |
+
graph_builder.add_edge("tools", "agent")
|
86 |
+
return graph_builder.compile(
|
87 |
+
debug=True,
|
88 |
+
checkpointer=self._short_term_memory,
|
89 |
+
store=self._long_term_memory,
|
90 |
+
)
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def _create_mcp_config(
|
94 |
+
cls,
|
95 |
+
) -> dict[
|
96 |
+
str,
|
97 |
+
StdioConnection
|
98 |
+
| SSEConnection
|
99 |
+
| StreamableHttpConnection
|
100 |
+
| WebsocketConnection,
|
101 |
+
]:
|
102 |
+
"""
|
103 |
+
Create the configuration dictionary for MCP (Multi-Component Protocol) servers.
|
104 |
+
This method sets up the connection details for each MCP server used by the application.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
dict: A dictionary mapping server names to their connection configurations.
|
108 |
+
"""
|
109 |
+
|
110 |
+
return {
|
111 |
+
"vector_database": StdioConnection(
|
112 |
+
command="uvx",
|
113 |
+
args=["mcp-server-qdrant"],
|
114 |
+
env={
|
115 |
+
"QDRANT_URL": str(cls.settings.vector_database.url),
|
116 |
+
"COLLECTION_NAME": cls.settings.vector_database.name,
|
117 |
+
},
|
118 |
+
transport="stdio",
|
119 |
+
),
|
120 |
+
"time": StdioConnection(
|
121 |
+
command="uvx",
|
122 |
+
args=["mcp-server-time"],
|
123 |
+
transport="stdio",
|
124 |
+
),
|
125 |
+
cls.settings.voice_generator_mcp.name: SSEConnection(
|
126 |
+
url=str(cls.settings.voice_generator_mcp.url),
|
127 |
+
transport=cls.settings.voice_generator_mcp.transport,
|
128 |
+
),
|
129 |
+
}
|
130 |
+
|
131 |
+
def _initialize_llm(self) -> ChatOpenAI:
|
132 |
+
"""
|
133 |
+
Initialize the ChatOpenAI language model using the provided settings.
|
134 |
+
This method creates and returns a ChatOpenAI instance configured with the model's URL, name, API key, and temperature.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
ChatOpenAI: The initialized ChatOpenAI language model instance.
|
138 |
+
|
139 |
+
Raises:
|
140 |
+
Exception: If the model initialization fails.
|
141 |
+
"""
|
142 |
+
try:
|
143 |
+
return ChatOpenAI(
|
144 |
+
base_url=str(self.settings.model.url),
|
145 |
+
model=self.settings.model.name,
|
146 |
+
api_key=self.settings.model.api_key,
|
147 |
+
temperature=self.settings.model.temperature,
|
148 |
+
)
|
149 |
+
except Exception as e:
|
150 |
+
logger.error(f"Failed to initialize ChatOpenAI model: {e}")
|
151 |
+
raise
|
152 |
+
|
153 |
+
@classmethod
|
154 |
+
async def _setup_memory(cls) -> tuple[AsyncRedisStore, AsyncRedisSaver]:
|
155 |
+
"""
|
156 |
+
Initialize and set up the Redis store and checkpointer for state persistence.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
tuple[AsyncRedisStore, AsyncRedisSaver]: Configured Redis store and saver instances.
|
160 |
+
"""
|
161 |
+
store_ctx = AsyncRedisStore.from_conn_string(str(cls.settings.memory.url))
|
162 |
+
checkpointer_ctx = AsyncRedisSaver.from_conn_string(
|
163 |
+
str(cls.settings.memory.url)
|
164 |
+
)
|
165 |
+
store = await store_ctx.__aenter__()
|
166 |
+
checkpointer = await checkpointer_ctx.__aenter__()
|
167 |
+
await store.setup()
|
168 |
+
await checkpointer.asetup()
|
169 |
+
return store, checkpointer
|
170 |
+
|
171 |
+
@staticmethod
|
172 |
+
async def _setup_tools(_mcp_client: MultiServerMCPClient) -> list[BaseTool]:
|
173 |
+
"""
|
174 |
+
Retrieve a list of tools from the provided MCP client.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
_mcp_client: The MultiServerMCPClient instance used to fetch available tools.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
list[BaseTool]: A list of BaseTool objects retrieved from the MCP client.
|
181 |
+
"""
|
182 |
+
return await _mcp_client.get_tools()
|
183 |
+
|
184 |
+
def draw_graph(self) -> None:
|
185 |
+
"""Render the compiled state graph as a Mermaid PNG image and save it."""
|
186 |
+
self._graph.get_graph().draw_mermaid_png(
|
187 |
+
output_file_path=self.settings.directory.assets / "graph.png"
|
188 |
+
)
|
189 |
+
|
190 |
+
async def generate_response(
|
191 |
+
self, message: str, history: list[ChatMessage]
|
192 |
+
) -> AsyncGenerator[tuple[str, list[ChatMessage], Path | None]]:
|
193 |
+
"""
|
194 |
+
Generate a response to a user message and update the conversation history.
|
195 |
+
This asynchronous method streams responses from the state graph and yields updated history and audio file paths as needed.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
message: The user's input message as a string.
|
199 |
+
history: The conversation history as a list of ChatMessage objects.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
AsyncGenerator[tuple[str, list[ChatMessage], Path]]: Yields a tuple containing an empty string, the updated history, and a Path to an audio file if generated.
|
203 |
+
"""
|
204 |
+
async for response in self._graph.astream(
|
205 |
+
MessagesState(messages=[HumanMessage(content=message)]),
|
206 |
+
RunnableConfig(configurable={"thread_id": "1", "user_id": "1"}),
|
207 |
+
stream_mode="updates",
|
208 |
+
):
|
209 |
+
if response.keys() == {"agent"}:
|
210 |
+
last_agent_message = response["agent"]["messages"][-1]
|
211 |
+
if last_agent_message.tool_calls:
|
212 |
+
history.append(
|
213 |
+
ChatMessage(
|
214 |
+
role="assistant",
|
215 |
+
content=dumps(
|
216 |
+
last_agent_message.tool_calls[0]["args"], indent=4
|
217 |
+
),
|
218 |
+
metadata=MetadataDict(
|
219 |
+
title=last_agent_message.tool_calls[0]["name"],
|
220 |
+
id=last_agent_message.tool_calls[0]["id"],
|
221 |
+
),
|
222 |
+
)
|
223 |
+
)
|
224 |
+
else:
|
225 |
+
history.append(
|
226 |
+
ChatMessage(
|
227 |
+
role="assistant", content=last_agent_message.content
|
228 |
+
)
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
last_tool_message = response["tools"]["messages"][-1]
|
232 |
+
history.append(
|
233 |
+
ChatMessage(
|
234 |
+
role="assistant",
|
235 |
+
content=last_tool_message.content,
|
236 |
+
metadata=MetadataDict(
|
237 |
+
title=last_tool_message.name,
|
238 |
+
id=last_tool_message.id,
|
239 |
+
),
|
240 |
+
)
|
241 |
+
)
|
242 |
+
if is_url(last_tool_message.content):
|
243 |
+
logger.info(f"Downloading audio from {last_tool_message.content}")
|
244 |
+
file_path: Path = (
|
245 |
+
self.settings.directory.audio / last_tool_message.id
|
246 |
+
)
|
247 |
+
download_file(
|
248 |
+
last_tool_message.content, file_path.with_suffix(".aac")
|
249 |
+
)
|
250 |
+
logger.info(f"Audio downloaded to {file_path.with_suffix('.aac')}")
|
251 |
+
convert_audio_to_wav(
|
252 |
+
file_path.with_suffix(".aac"), file_path.with_suffix(".wav")
|
253 |
+
)
|
254 |
+
yield "", history, file_path.with_suffix(".wav")
|
255 |
+
yield "", history, None
|
src/chattr/graph/runner.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from asyncio import run
|
2 |
+
|
3 |
+
from chattr.graph.builder import Graph
|
4 |
+
from chattr.settings import Settings
|
5 |
+
|
6 |
+
graph: Graph = run(Graph.create(Settings()))
|
src/chattr/gui.py
CHANGED
@@ -1,61 +1,41 @@
|
|
1 |
-
|
2 |
-
from gradio import Blocks, Button, Chatbot, ChatMessage, Row, Textbox
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
history: list[ChatMessage], thread_id: str
|
7 |
-
) -> list[ChatMessage]:
|
8 |
-
"""
|
9 |
-
Appends an assistant message about a quarterly sales plot to the chat history for the specified thread ID.
|
10 |
-
|
11 |
-
If the thread ID is 0, raises a Gradio error prompting for a valid thread ID.
|
12 |
-
|
13 |
-
Returns:
|
14 |
-
The updated chat history including the new assistant message.
|
15 |
-
"""
|
16 |
-
if thread_id == 0:
|
17 |
-
raise gradio.Error("Please enter a thread ID.")
|
18 |
-
history.append(
|
19 |
-
ChatMessage(
|
20 |
-
role="assistant",
|
21 |
-
content=f"Here is the plot of quarterly sales for {thread_id}.",
|
22 |
-
metadata={"title": "🛠️ Used tool Weather API"},
|
23 |
-
)
|
24 |
-
)
|
25 |
-
return history
|
26 |
|
27 |
|
28 |
def app_block() -> Blocks:
|
29 |
-
"""
|
30 |
-
|
|
|
31 |
|
32 |
Returns:
|
33 |
-
Blocks: The
|
34 |
"""
|
35 |
-
|
36 |
-
history = [
|
37 |
-
ChatMessage(role="assistant", content="How can I help you?"),
|
38 |
-
ChatMessage(
|
39 |
-
role="user", content="Can you make me a plot of quarterly sales?"
|
40 |
-
),
|
41 |
-
ChatMessage(
|
42 |
-
role="assistant",
|
43 |
-
content="I am happy to provide you that report and plot.",
|
44 |
-
),
|
45 |
-
]
|
46 |
-
with Blocks() as app:
|
47 |
-
with Row():
|
48 |
-
thread_id: Textbox = Textbox(
|
49 |
-
label="Thread ID", info="Enter Thread ID"
|
50 |
-
)
|
51 |
-
|
52 |
-
chatbot: Chatbot = Chatbot(history, type="messages")
|
53 |
-
|
54 |
with Row():
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains the Gradio-based GUI for the Chattr app."""
|
|
|
2 |
|
3 |
+
from gradio import (
|
4 |
+
Audio,
|
5 |
+
Blocks,
|
6 |
+
Button,
|
7 |
+
Chatbot,
|
8 |
+
ClearButton,
|
9 |
+
Column,
|
10 |
+
PlayableVideo,
|
11 |
+
Row,
|
12 |
+
Textbox,
|
13 |
+
)
|
14 |
|
15 |
+
from chattr.graph.runner import graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
def app_block() -> Blocks:
|
19 |
+
"""Creates and returns the main Gradio Blocks interface for the Chattr app.
|
20 |
+
|
21 |
+
This function sets up the user interface, including video, audio, chatbot, and input controls.
|
22 |
|
23 |
Returns:
|
24 |
+
Blocks: The constructed Gradio Blocks interface for the chat application.
|
25 |
"""
|
26 |
+
with Blocks() as chat:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
with Row():
|
28 |
+
with Column():
|
29 |
+
video = PlayableVideo()
|
30 |
+
audio = Audio(sources="upload", type="filepath", format="wav")
|
31 |
+
with Column():
|
32 |
+
chatbot = Chatbot(
|
33 |
+
type="messages", show_copy_button=True, show_share_button=True
|
34 |
+
)
|
35 |
+
msg = Textbox()
|
36 |
+
with Row():
|
37 |
+
button = Button("Send", variant="primary")
|
38 |
+
ClearButton([msg, chatbot, video], variant="stop")
|
39 |
+
button.click(graph.generate_response, [msg, chatbot], [msg, chatbot, audio])
|
40 |
+
msg.submit(graph.generate_response, [msg, chatbot], [msg, chatbot, audio])
|
41 |
+
return chat
|
src/chattr/settings.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains the settings for the Chattr app."""
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Literal, Self
|
6 |
+
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from langchain_core.messages import SystemMessage
|
9 |
+
from pydantic import (
|
10 |
+
BaseModel,
|
11 |
+
DirectoryPath,
|
12 |
+
Field,
|
13 |
+
HttpUrl,
|
14 |
+
RedisDsn,
|
15 |
+
SecretStr,
|
16 |
+
StrictStr,
|
17 |
+
model_validator,
|
18 |
+
)
|
19 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
20 |
+
|
21 |
+
logger = getLogger(__name__)
|
22 |
+
|
23 |
+
load_dotenv()
|
24 |
+
|
25 |
+
|
26 |
+
class ModelSettings(BaseModel):
|
27 |
+
url: HttpUrl = Field(default=None)
|
28 |
+
name: StrictStr = Field(default=None)
|
29 |
+
api_key: SecretStr = Field(default=None)
|
30 |
+
temperature: float = Field(default=0.0, ge=0.0, le=1.0)
|
31 |
+
system_message: SystemMessage = SystemMessage(
|
32 |
+
content="You are a helpful assistant that can answer questions about the time and generate audio files from text."
|
33 |
+
)
|
34 |
+
|
35 |
+
@model_validator(mode="after")
|
36 |
+
def check_api_key_exist(self) -> Self:
|
37 |
+
"""
|
38 |
+
Ensure that an API key and model name are provided if a model URL is set.
|
39 |
+
This method validates the presence of required credentials for the model provider.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Self: The validated ModelSettings instance.
|
43 |
+
|
44 |
+
Raises:
|
45 |
+
ValueError: If the API key or model name is missing when a model URL is provided.
|
46 |
+
"""
|
47 |
+
if self.url:
|
48 |
+
if not self.api_key or not self.api_key.get_secret_value():
|
49 |
+
raise ValueError(
|
50 |
+
"You need to provide API Key for the Model provider via `MODEL__API_KEY`"
|
51 |
+
)
|
52 |
+
if not self.name:
|
53 |
+
raise ValueError("You need to provide Model name via `MODEL__NAME`")
|
54 |
+
return self
|
55 |
+
|
56 |
+
|
57 |
+
class MemorySettings(BaseModel):
|
58 |
+
url: RedisDsn = Field(default=RedisDsn(url="redis://localhost:6379"))
|
59 |
+
|
60 |
+
|
61 |
+
class VectorDatabaseSettings(BaseModel):
|
62 |
+
name: StrictStr = Field(default="chattr")
|
63 |
+
url: HttpUrl = Field(default=HttpUrl(url="http://localhost:6333"))
|
64 |
+
|
65 |
+
|
66 |
+
class MCPSettings(BaseModel):
|
67 |
+
name: StrictStr = Field(default=None)
|
68 |
+
url: HttpUrl = Field(default=None)
|
69 |
+
command: StrictStr = Field(default=None)
|
70 |
+
args: List[StrictStr] = Field(default=[])
|
71 |
+
transport: Literal["sse", "stdio", "streamable_http", "websocket"] = Field(
|
72 |
+
default=None
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
class DirectorySettings(BaseModel):
|
77 |
+
base: DirectoryPath = Field(default_factory=lambda: Path.cwd())
|
78 |
+
assets: DirectoryPath = Field(default_factory=lambda: Path.cwd() / "assets")
|
79 |
+
log: DirectoryPath = Field(default_factory=lambda: Path.cwd() / "logs")
|
80 |
+
image: DirectoryPath = Field(
|
81 |
+
default_factory=lambda: Path.cwd() / "assets" / "image"
|
82 |
+
)
|
83 |
+
audio: DirectoryPath = Field(
|
84 |
+
default_factory=lambda: Path.cwd() / "assets" / "audio"
|
85 |
+
)
|
86 |
+
video: DirectoryPath = Field(
|
87 |
+
default_factory=lambda: Path.cwd() / "assets" / "video"
|
88 |
+
)
|
89 |
+
|
90 |
+
@model_validator(mode="after")
|
91 |
+
def create_missing_dirs(self) -> Self:
|
92 |
+
"""
|
93 |
+
Ensure that all specified directories exist, creating them if necessary.
|
94 |
+
This method checks and creates any missing directories defined in the DirectorySettings.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Self: The validated DirectorySettings instance.
|
98 |
+
"""
|
99 |
+
for directory in [
|
100 |
+
self.base,
|
101 |
+
self.assets,
|
102 |
+
self.log,
|
103 |
+
self.image,
|
104 |
+
self.audio,
|
105 |
+
self.video,
|
106 |
+
]:
|
107 |
+
directory.mkdir(exist_ok=True)
|
108 |
+
logger.info(f"Created directory: {directory}")
|
109 |
+
return self
|
110 |
+
|
111 |
+
|
112 |
+
class Settings(BaseSettings):
|
113 |
+
"""Configuration for the Chattr app."""
|
114 |
+
|
115 |
+
model_config = SettingsConfigDict(
|
116 |
+
env_nested_delimiter="__",
|
117 |
+
env_parse_none_str="None",
|
118 |
+
env_file=".env",
|
119 |
+
extra="ignore",
|
120 |
+
)
|
121 |
+
|
122 |
+
model: ModelSettings = ModelSettings()
|
123 |
+
memory: MemorySettings = MemorySettings()
|
124 |
+
vector_database: VectorDatabaseSettings = VectorDatabaseSettings()
|
125 |
+
voice_generator_mcp: MCPSettings = MCPSettings(
|
126 |
+
url="http://localhost:8080/gradio_api/mcp/sse",
|
127 |
+
transport="sse",
|
128 |
+
name="voice_generator",
|
129 |
+
)
|
130 |
+
video_generator_mcp: MCPSettings = MCPSettings(
|
131 |
+
url="http://localhost:8002/gradio_api/mcp/sse",
|
132 |
+
transport="sse",
|
133 |
+
name="video_generator",
|
134 |
+
)
|
135 |
+
directory: DirectorySettings = DirectorySettings()
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
print(Settings().model_dump())
|
src/chattr/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module contains utility functions for the Chattr app."""
|
2 |
+
|
3 |
+
from logging import getLogger
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from m3u8 import M3U8, load
|
8 |
+
from pydantic import HttpUrl, ValidationError
|
9 |
+
from pydub import AudioSegment
|
10 |
+
from requests import Session
|
11 |
+
|
12 |
+
logger = getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def is_url(value: Optional[str]) -> bool:
|
16 |
+
"""
|
17 |
+
Check if a string is a valid URL.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
value: The string to check. Can be None.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
bool: True if the string is a valid URL, False otherwise.
|
24 |
+
"""
|
25 |
+
if value is None:
|
26 |
+
return False
|
27 |
+
|
28 |
+
try:
|
29 |
+
HttpUrl(value)
|
30 |
+
return True
|
31 |
+
except ValidationError:
|
32 |
+
return False
|
33 |
+
|
34 |
+
|
35 |
+
def download_file(url: HttpUrl, path: Path) -> None:
|
36 |
+
"""
|
37 |
+
Download a file from a URL and save it to a local path.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
url: The URL to download the file from.
|
41 |
+
path: The local file path where the downloaded file will be saved.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
None
|
45 |
+
|
46 |
+
Raises:
|
47 |
+
requests.RequestException: If the HTTP request fails.
|
48 |
+
IOError: If file writing fails.
|
49 |
+
"""
|
50 |
+
if str(url).endswith(".m3u8"):
|
51 |
+
_playlist: M3U8 = load(url)
|
52 |
+
url: str = str(url).replace("playlist.m3u8", _playlist.segments[0].uri)
|
53 |
+
print(url)
|
54 |
+
session = Session()
|
55 |
+
response = session.get(url, stream=True, timeout=30)
|
56 |
+
response.raise_for_status()
|
57 |
+
with open(path, "wb") as f:
|
58 |
+
for chunk in response.iter_content(chunk_size=8192):
|
59 |
+
if chunk:
|
60 |
+
f.write(chunk)
|
61 |
+
|
62 |
+
|
63 |
+
def convert_audio_to_wav(input_path: Path, output_path: Path) -> None:
|
64 |
+
"""
|
65 |
+
Convert an audio file from aac to WAV format.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
input_path: The path to the input aac file.
|
69 |
+
output_path: The path to the output WAV file.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
None
|
73 |
+
"""
|
74 |
+
logger.info(f"Converting {input_path} to WAV format")
|
75 |
+
audio = AudioSegment.from_file(input_path, "aac")
|
76 |
+
audio.export(output_path, "wav")
|
77 |
+
logger.info(f"Converted {input_path} to {output_path}")
|
uv.lock
CHANGED
The diff for this file is too large to render.
See raw diff
|
|