ash-171 committed on
Commit
a182f79
·
verified ·
1 Parent(s): ec3ab59

Update src/app/main_agent.py

Browse files
Files changed (1) hide show
  1. src/app/main_agent.py +14 -17
src/app/main_agent.py CHANGED
@@ -57,15 +57,8 @@ import torch
57
  from transformers import pipeline
58
  import os
59
 
60
- model_name = 'meta-llama/Llama-3.1-8B' #"google/gemma-3-4b-it"
61
  # Load the Gemma 3 model pipeline once
62
- gemma_pipeline = pipeline(
63
- task="text-generation",
64
- model= model_name, # or your preferred Gemma 3 model
65
- device=0, # set -1 for CPU, 0 or other for GPU
66
- torch_dtype=torch.bfloat16,
67
- use_auth_token=os.getenv("HF_TOKEN")
68
- )
69
 
70
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
71
  accent_tool = Tool(
@@ -94,15 +87,19 @@ def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
94
  def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
95
  user_question = messages[-1].content
96
  transcript = accent_tool_obj.last_transcript or ""
97
- prompt = f"""You are given this transcript of a video:
98
-
99
- \"\"\"{transcript}\"\"\"
100
-
101
- Now respond to the user's follow-up question: {user_question}
102
- """
103
- # Use the pipeline to generate the response text
104
- # pipeline output is a list of dicts with 'generated_text'
105
- outputs = gemma_pipeline(prompt, max_new_tokens=256, do_sample=False)
 
 
 
 
106
  response_text = outputs[0]['generated_text']
107
 
108
  return AIMessage(content=response_text)
 
57
  from transformers import pipeline
58
  import os
59
 
 
60
  # Load the Gemma 3 model pipeline once
61
+ pipe = pipeline("text-generation", model="google/gemma-3-1b-it", device="cuda", torch_dtype=torch.bfloat16)
 
 
 
 
 
 
62
 
63
  def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
64
  accent_tool = Tool(
 
87
  def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
88
  user_question = messages[-1].content
89
  transcript = accent_tool_obj.last_transcript or ""
90
+ messages = [
91
+ [
92
+ {
93
+ "role": "system",
94
+ "content": [{"type": "text", "text": "You are a helpful assistant."},]
95
+ },
96
+ {
97
+ "role": "user",
98
+ "content": [{"type": "text", "text": "Analyse the transcript. "},]
99
+ },
100
+ ],
101
+ ]
102
+ outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
103
  response_text = outputs[0]['generated_text']
104
 
105
  return AIMessage(content=response_text)