# Page residue from the Spaces file view this module was copied from
# (status: Running; file size: 6,403 bytes).
# Per-line blame commits seen in that view: 62cf4ef, 6fdc19a, 980d57f, 55a3068.
import json
from typing import Any
import streamlit as st
import time
from surf_spot_finder.tools import (
driving_hours_to_meters,
get_area_lat_lon,
get_surfing_spots,
get_wave_forecast,
get_wind_forecast,
)
from surf_spot_finder.config import Config
from any_agent import AgentConfig, AnyAgent, TracingConfig
from any_agent.tracing.trace import AgentTrace, TotalTokenUseAndCost
from any_agent.tracing.otel_types import StatusCode
from any_agent.evaluation import evaluate, TraceEvaluationResult
async def run_agent(user_inputs: dict[str, Any]):
st.markdown("### π Running Surf Spot Finder...")
if "huggingface" in user_inputs["model_id"]:
model_args = {
"extra_headers": {"X-HF-Bill-To": "mozilla-ai"},
"temperature": 0.0,
}
else:
model_args = {}
agent_config = AgentConfig(
model_id=user_inputs["model_id"],
model_args=model_args,
tools=[
get_wind_forecast,
get_wave_forecast,
get_area_lat_lon,
get_surfing_spots,
driving_hours_to_meters,
],
)
config = Config(
location=user_inputs["location"],
max_driving_hours=user_inputs["max_driving_hours"],
date=user_inputs["date"],
framework=user_inputs["framework"],
main_agent=agent_config,
managed_agents=[],
evaluation_cases=[user_inputs.get("evaluation_case")]
if user_inputs.get("evaluation_case")
else None,
)
agent = await AnyAgent.create_async(
agent_framework=config.framework,
agent_config=config.main_agent,
managed_agents=config.managed_agents,
tracing=TracingConfig(console=True, cost_info=True),
)
query = config.input_prompt_template.format(
LOCATION=config.location,
MAX_DRIVING_HOURS=config.max_driving_hours,
DATE=config.date,
)
st.markdown("#### π Query")
st.code(query, language="text")
start_time = time.time()
with st.spinner("π€ Analyzing surf spots..."):
agent_trace: AgentTrace = await agent.run_async(query)
agent.exit()
end_time = time.time()
execution_time = end_time - start_time
cost: TotalTokenUseAndCost = agent_trace.get_total_cost()
st.markdown("### π Results")
time_col, cost_col, tokens_col = st.columns(3)
with time_col:
st.info(f"β±οΈ Execution Time: {execution_time:.2f} seconds")
with cost_col:
st.info(f"π° Estimated Cost: ${cost.total_cost:.6f}")
with tokens_col:
st.info(f"π¦ Total Tokens: {cost.total_tokens:,}")
st.markdown("#### Final Output")
st.info(agent_trace.final_output)
# Display the agent trace in a more organized way
with st.expander("### π§© Agent Trace"):
for span in agent_trace.spans:
# Header with name and status
col1, col2 = st.columns([4, 1])
with col1:
st.markdown(f"**{span.name}**")
if span.attributes:
# st.json(span.attributes, expanded=False)
if "input.value" in span.attributes:
try:
input_value = json.loads(span.attributes["input.value"])
if isinstance(input_value, list) and len(input_value) > 0:
st.write(f"Input: {input_value[-1]}")
else:
st.write(f"Input: {input_value}")
except: # noqa: E722
st.write(f"Input: {span.attributes['input.value']}")
if "output.value" in span.attributes:
try:
output_value = json.loads(span.attributes["output.value"])
if isinstance(output_value, list) and len(output_value) > 0:
st.write(f"Output: {output_value[-1]}")
else:
st.write(f"Output: {output_value}")
except: # noqa: E722
st.write(f"Output: {span.attributes['output.value']}")
with col2:
status_color = (
"green" if span.status.status_code == StatusCode.OK else "red"
)
st.markdown(
f"<span style='color: {status_color}'>β {span.status.status_code.name}</span>",
unsafe_allow_html=True,
)
if config.evaluation_cases is not None:
assert (
len(config.evaluation_cases) == 1
), "Only one evaluation case is supported in the demo"
st.markdown("### π Evaluation Results")
with st.spinner("Evaluating results..."):
case = config.evaluation_cases[0]
result: TraceEvaluationResult = evaluate(
evaluation_case=case,
trace=agent_trace,
agent_framework=config.framework,
)
all_results = (
result.checkpoint_results
+ result.hypothesis_answer_results
+ result.direct_results
)
# Create columns for better layout
col1, col2 = st.columns(2)
with col1:
st.markdown("#### Criteria Results")
for checkpoint in all_results:
if checkpoint.passed:
st.success(f"β
{checkpoint.criteria}")
else:
st.error(f"β {checkpoint.criteria}")
with col2:
st.markdown("#### Overall Score")
total_points = sum([result.points for result in all_results])
if total_points == 0:
msg = "Total points is 0, cannot calculate score."
raise ValueError(msg)
passed_points = sum(
[result.points for result in all_results if result.passed]
)
# Create a nice score display
st.markdown(f"### {passed_points}/{total_points}")
percentage = (passed_points / total_points) * 100
st.progress(percentage / 100)
st.markdown(f"**{percentage:.1f}%**")
# (end of file view)