Update src/app/main_agent.py
src/app/main_agent.py  CHANGED  (+14 -17)

@@ -57,15 +57,8 @@ import torch
 from transformers import pipeline
 import os
 
-model_name = 'meta-llama/Llama-3.1-8B'  # "google/gemma-3-4b-it"
 # Load the Gemma 3 model pipeline once
-pipe = pipeline(
-    task="text-generation",
-    model=model_name,  # or your preferred Gemma 3 model
-    device=0,  # set -1 for CPU, 0 or other for GPU
-    torch_dtype=torch.bfloat16,
-    use_auth_token=os.getenv("HF_TOKEN")
-)
+pipe = pipeline("text-generation", model="google/gemma-3-1b-it", device="cuda", torch_dtype=torch.bfloat16)
 
 def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
     accent_tool = Tool(
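
A side note on the hunk above (not part of the commit): the new one-liner hard-codes device="cuda" and drops the use_auth_token argument. The gated Gemma checkpoint still needs authentication, which huggingface_hub picks up from the HF_TOKEN environment variable, and the call will fail outright on a CPU-only Space. A minimal device-flexible sketch of the same load, assuming the same model id:

import torch
from transformers import pipeline

# Sketch: same pipeline load, but falling back to CPU when CUDA is absent.
# huggingface_hub reads HF_TOKEN from the environment, so no explicit token
# argument is needed for the gated checkpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "text-generation",
    model="google/gemma-3-1b-it",
    device=device,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)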

@@ -94,15 +87,19 @@ def create_agent(accent_tool_obj) -> tuple[Runnable, Runnable]:
     def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
         user_question = messages[-1].content
         transcript = accent_tool_obj.last_transcript or ""
-        …  # nine removed lines; contents not shown in this view
+        messages = [
+            [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": "You are a helpful assistant."},]
+                },
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "Analyse the transcript. "},]
+                },
+            ],
+        ]
+        outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
         response_text = outputs[0]['generated_text']
 
         return AIMessage(content=response_text)
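
Two bugs in the added block are worth flagging: pipe(prompt, ...) references a prompt name that is never defined (a NameError at runtime), and the new messages list both shadows the function's messages parameter and never feeds transcript or user_question to the model. A corrected sketch, assuming the intent was a single chat-format call (the message wording is illustrative, not from the commit; pipe, AIMessage, BaseMessage, and accent_tool_obj come from the enclosing create_agent scope and module imports):

def follow_up_node(messages: list[BaseMessage]) -> AIMessage:
    user_question = messages[-1].content
    transcript = accent_tool_obj.last_transcript or ""
    # One flat chat conversation (not a batch of one), under a name that does
    # not shadow the `messages` parameter, carrying the transcript and question.
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": f"Analyse the transcript.\n\nTranscript: {transcript}\n\nQuestion: {user_question}",
        },
    ]
    outputs = pipe(chat, max_new_tokens=256, do_sample=False)
    # With chat input, generated_text holds the whole conversation; the reply
    # is the content of the final assistant message.
    response_text = outputs[0]["generated_text"][-1]["content"]
    return AIMessage(content=response_text)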