File size: 2,840 Bytes
5301c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from starfish import data_factory
from starfish.common.env_loader import load_env_file
from datasets import load_dataset
import json
import asyncio
import os
import random
from agents import Agent, Runner, function_tool, ModelSettings
from agents.tool import WebSearchTool
from pydantic import BaseModel, Field

# Runs at import time, before anything below reads the environment.
# NOTE(review): presumably loads variables such as HUGGING_FACE_HUB_TOKEN and the
# model provider API key from a .env file — confirm against starfish.common.env_loader.
load_env_file()


# Structured output schema for a single diagnosis-code suggestion.
# It is passed as `output_type` to the diagnosis agent built in run_model_gen.
# All three fields are required (the `...` default marks them as mandatory).
# Documented with comments rather than a docstring so the generated schema is unchanged.
class DiagnosisSuggestion(BaseModel):
    # Suggested diagnosis code string.
    code: str = Field(
        ...,
        description="The suggested diagnosis code (e.g., ICD-10)",
    )
    # Confidence score reported by the model for this suggestion.
    confidence: float = Field(
        ...,
        description="Model confidence in the suggestion, between 0 and 1",
    )
    # Free-text justification for the chosen code.
    reason: str = Field(
        ...,
        description="Explanation or rationale for the suggested diagnosis",
    )


async def run_model_gen(num_datapoints, model_name="openai/gpt-4o-mini"):
    """Sample transcripts from the dataset and ask an agent for a diagnosis code for each.

    Loads the "train" split of the transcript dataset, picks ``num_datapoints``
    rows at random (without replacement), and runs every transcript through a
    web-search-enabled coding agent via a ``data_factory`` job.

    Args:
        num_datapoints: Number of transcripts to sample. Must be between 0 and
            the number of rows in the dataset, inclusive.
        model_name: Model identifier handed to the ``Agent``.

    Returns:
        Whatever ``generate_data.run`` returns for the sampled transcripts.
        Each task yields records with the keys ``"transcript"`` and
        ``"icd_10_code"``; presumably the result is a list of those records —
        confirm against the starfish ``data_factory`` documentation.

    Raises:
        ValueError: If ``num_datapoints`` is negative or exceeds the dataset size.
    """
    # Get HF token from environment (None if unset)
    hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")

    # Load the dataset
    dataset = load_dataset("starfishdata/playground_endocronology_notes_1500", split="train", token=hf_token)

    # Get total number of samples
    total_samples = len(dataset)

    # random.sample would raise an opaque ValueError for an out-of-range count;
    # fail with a message that names the actual limits instead.
    if not 0 <= num_datapoints <= total_samples:
        raise ValueError(
            f"num_datapoints must be between 0 and {total_samples} (dataset size), got {num_datapoints}"
        )

    # Generate random indices (sampling without replacement, so no duplicates)
    random_indices = random.sample(range(total_samples), num_datapoints)

    # Create list of dictionaries with only transcript key
    transcript_list = [{"transcript": dataset[idx]["transcript"]} for idx in random_indices]

    # Create the Agent
    diagnosis_code_agent = Agent(
        name="Diagnosis Code Agent",
        tools=[WebSearchTool()],
        model=model_name,
        output_type=DiagnosisSuggestion,
        model_settings=ModelSettings(tool_choice="required"),
        tool_use_behavior="stop_on_first_tool",
        instructions="""
        You are an Endocrinology Medical Coding Specialist.
        You will be provided with a medical transcript describing a patient encounter.
        Your task is to analyze the medical transcript and assign the most appropriate diagnosis code(s).
        You will have access to a web search tool and only use it to search endocrinology related code and verification.
        Use it only to verify the accuracy or current validity of the diagnosis codes.
        """,
    )

    web_search_prompt = """Please select top 3 likely code from given list for this doctor and patient conversation transcript.
        Transcript: {transcript}
    """

    @data_factory(max_concurrency=100, task_runner_timeout=300)
    async def generate_data(transcript):
        # One agent run per transcript; the prompt template is filled with the transcript text.
        diagnosis_code_result = await Runner.run(diagnosis_code_agent, input=web_search_prompt.format(transcript=transcript))

        # final_output is expected to be a DiagnosisSuggestion (the agent's output_type).
        code_result = diagnosis_code_result.final_output.model_dump()

        # Only the code is kept; confidence and reason are dropped.
        return [{"transcript": transcript, "icd_10_code": code_result["code"]}]

    # NOTE(review): .run() is called without await inside this coroutine —
    # confirm that data_factory's run() is a synchronous entry point.
    return generate_data.run(transcript_list)


if __name__ == "__main__":
    # Run the async function. run_model_gen requires num_datapoints; calling it
    # with no arguments raises TypeError before any work is done, so pass a
    # small sample size for this smoke run.
    results = asyncio.run(run_model_gen(num_datapoints=10))
    print(len(results))
    # Guard the peek at the first record: an empty result would raise IndexError.
    if results:
        print(results[0].keys())