ash-171 committed on
Commit
d429407
·
verified ·
1 Parent(s): 74cfe7e

Update src/app/main_agent.py

Browse files
Files changed (1) hide show
  1. src/app/main_agent.py +48 -117
src/app/main_agent.py CHANGED
@@ -1,119 +1,50 @@
1
- # from langchain_core.messages import BaseMessage, AIMessage
2
- # from langchain_core.runnables import RunnableLambda, Runnable
3
- # from langchain_community.llms import Ollama
4
- # from langchain.tools import Tool
5
- # from langgraph.graph import MessageGraph
6
- # import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # llm = Ollama(model="gemma3", temperature=0.0) # llama3.1
9
 
10
- # def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
11
- # accent_tool = Tool(
12
- # name="AccentAnalyzer",
13
- # func=accent_tool_obj.analyze,
14
- # description="Analyze a public MP4 video URL and determine the English accent with transcription."
15
- # )
16
-
17
- # def analyze_node(messages: list[BaseMessage]) -> AIMessage:
18
- # last_input = messages[-1].content
19
- # match = re.search(r'https?://\S+', last_input)
20
- # if match:
21
- # url = match.group()
22
- # result = accent_tool.func(url)
23
- # else:
24
- # result = "No valid video URL found in your message."
25
- # return AIMessage(content=result)
26
-
27
- # graph = MessageGraph()
28
- # graph.add_node("analyze_accent", RunnableLambda(analyze_node))
29
- # graph.set_entry_point("analyze_accent")
30
- # graph.set_finish_point("analyze_accent")
31
- # analysis_agent = graph.compile()
32
-
33
- # # Follow-up agent that uses transcript and responds to questions
34
- # def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
35
- # user_question = messages[-1].content
36
- # transcript = accent_tool_obj.last_transcript or ""
37
- # prompt = f"""You are given this transcript of a video:
38
-
39
- # \"\"\"{transcript}\"\"\"
40
-
41
- # Now respond to the user's follow-up question: {user_question}
42
- # """
43
- # response = llm.invoke(prompt)
44
- # return AIMessage(content=response)
45
-
46
- # follow_up_agent = RunnableLambda(follow_up_node)
47
-
48
- # return analysis_agent, follow_up_agent
49
-
50
-
51
- from langchain_core.messages import BaseMessage, AIMessage
52
- from langchain_core.runnables import RunnableLambda, Runnable
53
- from langchain.tools import Tool
54
- from langgraph.graph import MessageGraph
55
- import re
56
- import torch
57
- from transformers import pipeline
58
- import os
59
- # Load model directly
60
- from transformers import AutoTokenizer, AutoModelForCausalLM
61
-
62
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
63
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
64
-
65
- def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
66
- accent_tool = Tool(
67
- name="AccentAnalyzer",
68
- func=accent_tool_obj.analyze,
69
- description="Analyze a public MP4 video URL and determine the English accent with transcription."
70
- )
71
-
72
- def analyze_node(messages: list[BaseMessage]) -> AIMessage:
73
- last_input = messages[-1].content
74
- match = re.search(r'https?://\S+', last_input)
75
- if match:
76
- url = match.group()
77
- result = accent_tool.func(url)
78
- else:
79
- result = "No valid video URL found in your message."
80
- return AIMessage(content=result)
81
-
82
- graph = MessageGraph()
83
- graph.add_node("analyze_accent", RunnableLambda(analyze_node))
84
- graph.set_entry_point("analyze_accent")
85
- graph.set_finish_point("analyze_accent")
86
- analysis_agent = graph.compile()
87
-
88
- # Follow-up agent that uses transcript and responds to questions
89
- def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
90
- user_question = messages[-1].content
91
- transcript = accent_tool_obj.last_transcript or ""
92
- messages = [
93
- [
94
- {
95
- "role": "system",
96
- "content": [{"type": "text", "text": "You are a helpful assistant."},]
97
- },
98
- {
99
- "role": "user",
100
- "content": [{"type": "text", "text": "Analyse the transcript. "},]
101
- },
102
- ],
103
- ]
104
- inputs = tokenizer.apply_chat_template(
105
- messages,
106
- add_generation_prompt=True,
107
- tokenize=True,
108
- return_dict=True,
109
- return_tensors="pt",
110
- )
111
- outputs = model.generate(**inputs, max_new_tokens=64)
112
- outputs = tokenizer.batch_decode(outputs)
113
- response_text = outputs[0]['generated_text']
114
-
115
- return AIMessage(content=response_text)
116
-
117
- follow_up_agent = RunnableLambda(follow_up_node)
118
-
119
- return analysis_agent, follow_up_agent
 
1
"""Agent wiring for the accent-analysis app: builds the analysis and follow-up agents."""

import re

from langchain.tools import Tool
from langchain_community.llms import Ollama
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.graph import MessageGraph

# Deterministic local model (temperature 0); llama3.1 is an alternative model name.
llm = Ollama(model="gemma3", temperature=0.0)
10
def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
    """Build the two runnables used by the app.

    Args:
        accent_tool_obj: object exposing ``analyze(url)`` and a
            ``last_transcript`` attribute holding the most recent transcript.

    Returns:
        ``(analysis_agent, follow_up_agent)`` — the first analyzes a video URL
        found in the latest message; the second answers follow-up questions
        against the stored transcript via the module-level ``llm``.
    """
    accent_tool = Tool(
        name="AccentAnalyzer",
        func=accent_tool_obj.analyze,
        description="Analyze a public MP4 video URL and determine the English accent with transcription.",
    )

    def analyze_node(messages: list[BaseMessage]) -> AIMessage:
        """Pull the first URL out of the latest message and run the accent tool on it."""
        latest_text = messages[-1].content
        found = re.search(r'https?://\S+', latest_text)
        if found is None:
            return AIMessage(content="No valid video URL found in your message.")
        return AIMessage(content=accent_tool.func(found.group()))

    # Single-node graph: entry and finish are both the analysis step.
    graph = MessageGraph()
    graph.add_node("analyze_accent", RunnableLambda(analyze_node))
    graph.set_entry_point("analyze_accent")
    graph.set_finish_point("analyze_accent")
    analysis_agent = graph.compile()

    def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
        """Answer the user's follow-up question using the last stored transcript."""
        user_question = messages[-1].content
        transcript = accent_tool_obj.last_transcript or ""
        prompt = f"""You are given this transcript of a video:

\"\"\"{transcript}\"\"\"

Now respond to the user's follow-up question: {user_question}
"""
        return AIMessage(content=llm.invoke(prompt))

    follow_up_agent = RunnableLambda(follow_up_node)

    return analysis_agent, follow_up_agent
49
 
 
50