Spaces:
Running
Running
File size: 2,840 Bytes
5301c48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from starfish import data_factory
from starfish.common.env_loader import load_env_file
from datasets import load_dataset
import json
import asyncio
import os
import random
from agents import Agent, Runner, function_tool, ModelSettings
from agents.tool import WebSearchTool
from pydantic import BaseModel, Field
# Runs at import time, before anything below reads the environment via os.getenv.
# NOTE(review): presumably this populates os.environ from a local env file
# (e.g. the Hugging Face token read later) — confirm against starfish's env_loader.
load_env_file()
class DiagnosisSuggestion(BaseModel):
    # Structured answer the diagnosis agent is asked to produce (it is passed
    # to the Agent as `output_type`). All three fields are required.
    #
    # Documented with `#` comments on purpose: a class docstring would be
    # picked up by pydantic as the schema description and alter the schema
    # derived from this model.

    # The code itself, as free text; no format is enforced here.
    code: str = Field(
        ...,
        description="The suggested diagnosis code (e.g., ICD-10)",
    )

    # Described as lying in [0, 1], but the range is not validated.
    confidence: float = Field(
        ...,
        description="Model confidence in the suggestion, between 0 and 1",
    )

    # Free-text justification accompanying the code.
    reason: str = Field(
        ...,
        description="Explanation or rationale for the suggested diagnosis",
    )
async def run_model_gen(num_datapoints, model_name="openai/gpt-4o-mini"):
    """Label a random sample of endocrinology transcripts with diagnosis codes.

    Loads the ``train`` split of
    ``starfishdata/playground_endocronology_notes_1500``, picks
    ``num_datapoints`` distinct rows at random, and runs the "Diagnosis Code
    Agent" once per transcript through a ``data_factory``-decorated worker.

    Args:
        num_datapoints: How many transcripts to sample from the dataset.
        model_name: Model identifier handed to the ``Agent``.

    Returns:
        Whatever ``generate_data.run(...)`` returns. Each worker call yields
        ``[{"transcript": ..., "icd_10_code": ...}]``; presumably the factory
        flattens these into one list of dicts — confirm against starfish.

    Raises:
        ValueError: From ``random.sample`` when ``num_datapoints`` is negative
            or exceeds the number of rows in the dataset.
    """
    # os.getenv returns None when the variable is unset; the value is passed
    # through to load_dataset unchanged.
    access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
    notes = load_dataset(
        "starfishdata/playground_endocronology_notes_1500",
        split="train",
        token=access_token,
    )

    # Sample row positions without replacement, then keep only the transcript
    # column of each chosen row.
    row_count = len(notes)
    chosen_rows = random.sample(range(row_count), num_datapoints)
    records = []
    for row in chosen_rows:
        records.append({"transcript": notes[row]["transcript"]})

    # Prompt text is kept flush-left so the literal carries no extra leading
    # whitespace.
    agent_instructions = """
You are an Endocrinology Medical Coding Specialist.
You will be provided with a medical transcript describing a patient encounter.
Your task is to analyze the medical transcript and assign the most appropriate diagnosis code(s).
You will have access to a web search tool and only use it to search endocrinology related code and verification.
Use it only to verify the accuracy or current validity of the diagnosis codes.
"""
    coding_agent = Agent(
        name="Diagnosis Code Agent",
        tools=[WebSearchTool()],
        model=model_name,
        output_type=DiagnosisSuggestion,
        model_settings=ModelSettings(tool_choice="required"),
        tool_use_behavior="stop_on_first_tool",
        instructions=agent_instructions,
    )

    # NOTE(review): the prompt asks for the "top 3" codes "from given list",
    # but no list is supplied and DiagnosisSuggestion holds a single code —
    # confirm what was intended before changing the wording.
    prompt_template = """Please select top 3 likely code from given list for this doctor and patient conversation transcript.
Transcript: {transcript}
"""

    @data_factory(max_concurrency=100, task_runner_timeout=300)
    async def generate_data(transcript):
        # One agent run per transcript; only the `code` field of the
        # structured output is kept, confidence and reason are dropped.
        prompt = prompt_template.format(transcript=transcript)
        run_result = await Runner.run(coding_agent, input=prompt)
        suggestion = run_result.final_output.model_dump()
        record = {"transcript": transcript, "icd_10_code": suggestion["code"]}
        return [record]

    return generate_data.run(records)
if __name__ == "__main__":
    # Smoke-test entry point: generate a handful of labelled records and
    # print how many came back plus the keys of the first one.
    #
    # run_model_gen() has a required `num_datapoints` parameter; calling it
    # with no arguments raised TypeError before any work was done. The sample
    # size below is an arbitrary small value for a quick manual run.
    results = asyncio.run(run_model_gen(num_datapoints=5))
    print(len(results))
    # Only peek at the first record when there is one; indexing an empty
    # result would raise IndexError.
    if results:
        print(results[0].keys())
|