ash-171 commited on
Commit
20cf0cb
·
verified ·
1 Parent(s): b1bd1d7

Update src/app/main_agent.py

Browse files
Files changed (1) hide show
  1. src/app/main_agent.py +68 -7
src/app/main_agent.py CHANGED
@@ -1,11 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain_core.messages import BaseMessage, AIMessage
2
  from langchain_core.runnables import RunnableLambda, Runnable
3
- from langchain_community.llms import Ollama
4
  from langchain.tools import Tool
5
  from langgraph.graph import MessageGraph
6
  import re
 
 
7
 
8
- llm = Ollama(model="gemma3", temperature=0.0) # llama3.1
 
 
 
 
 
 
9
 
10
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
11
  accent_tool = Tool(
@@ -36,12 +93,16 @@ def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
36
  transcript = accent_tool_obj.last_transcript or ""
37
  prompt = f"""You are given this transcript of a video:
38
 
39
- \"\"\"{transcript}\"\"\"
 
 
 
 
 
 
 
40
 
41
- Now respond to the user's follow-up question: {user_question}
42
- """
43
- response = llm.invoke(prompt)
44
- return AIMessage(content=response)
45
 
46
  follow_up_agent = RunnableLambda(follow_up_node)
47
 
 
1
+ # from langchain_core.messages import BaseMessage, AIMessage
2
+ # from langchain_core.runnables import RunnableLambda, Runnable
3
+ # from langchain_community.llms import Ollama
4
+ # from langchain.tools import Tool
5
+ # from langgraph.graph import MessageGraph
6
+ # import re
7
+
8
+ # llm = Ollama(model="gemma3", temperature=0.0) # llama3.1
9
+
10
+ # def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
11
+ # accent_tool = Tool(
12
+ # name="AccentAnalyzer",
13
+ # func=accent_tool_obj.analyze,
14
+ # description="Analyze a public MP4 video URL and determine the English accent with transcription."
15
+ # )
16
+
17
+ # def analyze_node(messages: list[BaseMessage]) -> AIMessage:
18
+ # last_input = messages[-1].content
19
+ # match = re.search(r'https?://\S+', last_input)
20
+ # if match:
21
+ # url = match.group()
22
+ # result = accent_tool.func(url)
23
+ # else:
24
+ # result = "No valid video URL found in your message."
25
+ # return AIMessage(content=result)
26
+
27
+ # graph = MessageGraph()
28
+ # graph.add_node("analyze_accent", RunnableLambda(analyze_node))
29
+ # graph.set_entry_point("analyze_accent")
30
+ # graph.set_finish_point("analyze_accent")
31
+ # analysis_agent = graph.compile()
32
+
33
+ # # Follow-up agent that uses transcript and responds to questions
34
+ # def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
35
+ # user_question = messages[-1].content
36
+ # transcript = accent_tool_obj.last_transcript or ""
37
+ # prompt = f"""You are given this transcript of a video:
38
+
39
+ # \"\"\"{transcript}\"\"\"
40
+
41
+ # Now respond to the user's follow-up question: {user_question}
42
+ # """
43
+ # response = llm.invoke(prompt)
44
+ # return AIMessage(content=response)
45
+
46
+ # follow_up_agent = RunnableLambda(follow_up_node)
47
+
48
+ # return analysis_agent, follow_up_agent
49
+
50
+
51
  from langchain_core.messages import BaseMessage, AIMessage
52
  from langchain_core.runnables import RunnableLambda, Runnable
 
53
  from langchain.tools import Tool
54
  from langgraph.graph import MessageGraph
55
  import re
56
+ import torch
57
+ from transformers import pipeline
58
 
59
+ # Load the Gemma 3 model pipeline once
60
+ gemma_pipeline = pipeline(
61
+ task="text-generation",
62
+ model="google/gemma-3-4b-it", # or your preferred Gemma 3 model
63
+ device=0, # set -1 for CPU, 0 or other for GPU
64
+ torch_dtype=torch.bfloat16
65
+ )
66
 
67
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
68
  accent_tool = Tool(
 
93
  transcript = accent_tool_obj.last_transcript or ""
94
  prompt = f"""You are given this transcript of a video:
95
 
96
+ \"\"\"{transcript}\"\"\"
97
+
98
+ Now respond to the user's follow-up question: {user_question}
99
+ """
100
+ # Use the pipeline to generate the response text
101
+ # pipeline output is a list of dicts with 'generated_text'
102
+ outputs = gemma_pipeline(prompt, max_new_tokens=256, do_sample=False)
103
+ response_text = outputs[0]['generated_text']
104
 
105
+ return AIMessage(content=response_text)
 
 
 
106
 
107
  follow_up_agent = RunnableLambda(follow_up_node)
108