Connor Adams commited on
Commit
4cc3ef0
·
1 Parent(s): 3c26361

Use Gemini

Browse files
agent.py CHANGED
@@ -1,15 +1,43 @@
 
1
  import logfire
2
 
3
  from pydantic_ai import Agent
4
- from tools import safe_duckduckgo_search_tool, get_youtube_transcript
 
 
5
  logfire.configure()
6
  logfire.instrument_pydantic_ai()
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  class BasicAgent:
9
  def __init__(self):
10
  self.agent = Agent(
11
- "openai:o3-mini",
12
- tools=[safe_duckduckgo_search_tool(), get_youtube_transcript],
13
  system_prompt="You are a helpful assistant that can answer questions about the world.",
14
  )
15
 
 
1
+ import os
2
  import logfire
3
 
4
  from pydantic_ai import Agent
5
+ from google import genai
6
+ from google.genai import types
7
+
8
  logfire.configure()
9
  logfire.instrument_pydantic_ai()
10
 
11
+ def web_search_tool(question: str) -> str | None:
12
+ """Given a question only, search the web to answer the question.
13
+
14
+ Args:
15
+ question (str): Question to answer
16
+
17
+ Returns:
18
+ str: Answer to the question
19
+
20
+ Raises:
21
+ RuntimeError: If processing fails"""
22
+ try:
23
+ client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
24
+ response = client.models.generate_content(
25
+ model="gemini-2.5-flash-preview-05-20",
26
+ contents=question,
27
+ config=types.GenerateContentConfig(
28
+ tools=[types.Tool(google_search=types.GoogleSearch())]
29
+ )
30
+ )
31
+
32
+ return response.text
33
+ except Exception as e:
34
+ raise RuntimeError(f"Processing failed: {str(e)}") from e
35
+
36
  class BasicAgent:
37
  def __init__(self):
38
  self.agent = Agent(
39
+ "gemini-2.5-flash-preview-05-20",
40
+ tools=[web_search_tool],
41
  system_prompt="You are a helpful assistant that can answer questions about the world.",
42
  )
43
 
tools/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .safe_duck import safe_duckduckgo_search_tool
2
- from .youtube_transcript_tool import get_youtube_transcript
 
 
 
tools/safe_duck.py DELETED
@@ -1,114 +0,0 @@
1
- """
2
- safe_duck_tool.py
3
- A resilient, family-friendly DuckDuckGo search Tool for Pydantic-AI.
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- import functools
9
- import time
10
- from dataclasses import dataclass
11
- from typing_extensions import TypedDict
12
-
13
- import anyio
14
- import anyio.to_thread
15
- from duckduckgo_search import DDGS, exceptions
16
- from pydantic import TypeAdapter
17
- from pydantic_ai.tools import Tool
18
-
19
-
20
- # ──────────────────────────────────────────────────────────────────────────
21
- # 1. Types
22
- # ──────────────────────────────────────────────────────────────────────────
23
- class DuckDuckGoResult(TypedDict):
24
- title: str
25
- href: str
26
- body: str
27
-
28
-
29
- duckduckgo_ta = TypeAdapter(list[DuckDuckGoResult])
30
-
31
-
32
- # ──────────────────────────────────────────────────────────────────────────
33
- # 2. Search wrapper with cache + back-off
34
- # ──────────────────────────────────────────────────────────────────────────
35
- @functools.lru_cache(maxsize=512)
36
- def _safe_search(
37
- query: str,
38
- *,
39
- ddgs_constructor_kwargs_tuple: tuple,
40
- safesearch: str,
41
- max_results: int | None,
42
- retries: int = 5,
43
- ) -> list[dict[str, str]]:
44
- wait = 1
45
- for _ in range(retries):
46
- try:
47
-
48
- ddgs = DDGS(**dict(ddgs_constructor_kwargs_tuple))
49
-
50
- return list(
51
- ddgs.text(query, safesearch=safesearch, max_results=max_results)
52
- )
53
- except exceptions.RatelimitException as e:
54
- time.sleep(getattr(e, "retry_after", wait))
55
- wait = min(wait * 2, 30)
56
- raise RuntimeError("DuckDuckGo kept rate-limiting after multiple attempts")
57
-
58
-
59
- # ──────────────────────────────────────────────────────────────────────────
60
- # 3. Tool implementation
61
- # ──────────────────────────────────────────────────────────────────────────
62
- @dataclass
63
- class _SafeDuckToolImpl:
64
- ddgs_constructor_kwargs: dict # Renamed from client_kwargs
65
- safesearch: str # Added to store safesearch setting
66
- max_results: int | None
67
-
68
- async def __call__(self, query: str) -> list[DuckDuckGoResult]:
69
- search = functools.partial(
70
- _safe_search,
71
- # Convert dict to sorted tuple of items to make it hashable
72
- ddgs_constructor_kwargs_tuple=tuple(
73
- sorted(self.ddgs_constructor_kwargs.items())
74
- ),
75
- safesearch=self.safesearch, # Pass stored safesearch
76
- max_results=self.max_results,
77
- )
78
- results = await anyio.to_thread.run_sync(search, query)
79
- # validate & coerce with Pydantic
80
- return duckduckgo_ta.validate_python(results)
81
-
82
-
83
- def safe_duckduckgo_search_tool(
84
- *,
85
- safesearch: str = "moderate", # "on" | "moderate" | "off"
86
- timeout: int = 15,
87
- max_results: int | None = None,
88
- proxy: str | None = None, # e.g. "socks5h://user:pw@host:1080"
89
- ) -> Tool:
90
- """
91
- Create a resilient, Safe-Search-enabled DuckDuckGo search Tool.
92
-
93
- Drop-in replacement for `pydantic_ai.common_tools.duckduckgo.duckduckgo_search_tool`.
94
- """
95
- # Arguments for DDGS constructor
96
- ddgs_constructor_kwargs = dict(
97
- timeout=timeout,
98
- proxy=proxy,
99
- )
100
- # Arguments for ddgs.text() method are handled separately (safesearch, max_results)
101
-
102
- impl = _SafeDuckToolImpl(
103
- ddgs_constructor_kwargs=ddgs_constructor_kwargs,
104
- safesearch=safesearch,
105
- max_results=max_results,
106
- )
107
- return Tool(
108
- impl.__call__,
109
- name="safe_duckduckgo_search",
110
- description=(
111
- "DuckDuckGo web search with Safe Search, automatic back-off, and "
112
- "LRU caching. Pass a plain-text query; returns a list of results."
113
- ),
114
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/youtube_transcript_tool.py DELETED
@@ -1,19 +0,0 @@
1
- from youtube_transcript_api import YouTubeTranscriptApi
2
-
3
- def get_youtube_transcript(video_id: str) -> str:
4
- """
5
- Fetches the transcript for a given YouTube video ID.
6
-
7
- Args:
8
- video_id: The ID of the YouTube video.
9
-
10
- Returns:
11
- The transcript of the video as a string, or an error message if the transcript cannot be fetched.
12
- """
13
- try:
14
- transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
15
- transcript = transcript_list.find_generated_transcript(['en'])
16
- fetched_transcript = transcript.fetch()
17
- return " ".join([segment['text'] for segment in fetched_transcript])
18
- except Exception as e:
19
- return f"Error fetching transcript: {str(e)}"