# Model probe script: evaluates how accurately an LLM predicts ICD-10 codes
# from endocrinology transcripts (exact-match against the dataset's top_1 code).
# Standard library
import asyncio
import json

# Third-party
from datasets import load_dataset
from starfish import StructuredLLM, data_factory
from starfish.common.env_loader import load_env_file

# Load environment variables (e.g. LLM provider API keys) before any model calls.
load_env_file()
def run_model_probe(model_name="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", num_datapoints=10):
    """Probe a model's ability to predict ICD-10 codes from medical transcripts.

    Loads the first ``num_datapoints`` rows of the endocrinology transcription
    dataset, asks the model for an ICD-10 code per transcript, and scores
    exact-match accuracy against each row's ``top_1`` code.

    Args:
        model_name: Model identifier passed to StructuredLLM.
        num_datapoints: Number of dataset rows to evaluate.

    Returns:
        dict with keys ``total_samples``, ``exact_matches`` and ``accuracy``
        (a percentage; 0.0 when no samples were processed).
    """
    dataset = load_dataset("starfishdata/endocrinology_transcription_and_notes_and_icd_codes", split="train")
    top_n_data = dataset.select(range(num_datapoints))

    # Normalize each dataset row into {id, transcript, icd_10_code}.
    parsed_data = []
    for idx, entry in enumerate(top_n_data):
        # The transcript may be a plain string or nested one level deep under
        # a "transcript" key — handle both shapes.
        transcript = entry["transcript"] if isinstance(entry["transcript"], str) else entry["transcript"].get("transcript", "")
        # The ICD-10 field is a JSON-encoded string; pull out the top_1 code
        # and fall back to "" on malformed JSON rather than crashing the probe.
        icd_codes_str = entry.get("icd_10_code", "{}")
        try:
            icd_codes = json.loads(icd_codes_str)
            top_1_code = icd_codes.get("top_1", {}).get("code", "")
        except json.JSONDecodeError:
            top_1_code = ""
        parsed_data.append({"id": idx, "transcript": transcript, "icd_10_code": top_1_code})

    model_probe_prompt = """
Given a transcript of a patient's medical history, determine the ICD-10 code that is most relevant to the patient's condition.
Transcript: {{transcript}}
Please do not return anything other than the ICD-10 code in json format.
like this: {"icd_10_code": "A00.0"}
"""

    response_gen_llm = StructuredLLM(model_name=model_name, prompt=model_probe_prompt, output_schema=[{"name": "icd_10_code", "type": "str"}])

    # data_factory wraps the async worker and exposes a synchronous .run()
    # entry point with managed concurrency. Without this decorator the
    # .run(...) call below would raise AttributeError, since a bare coroutine
    # function has no .run attribute. (max_concurrency=10 is a reasonable
    # default for a small probe — tune if the provider rate-limits.)
    @data_factory(max_concurrency=10)
    async def model_probe_batch(input_data):
        response = await response_gen_llm.run(transcript=input_data["transcript"])
        return [{"id": input_data["id"], "generated_icd_10_code": response.data[0]["icd_10_code"], "actual_icd_10_code": input_data["icd_10_code"]}]

    def evaluate_model():
        """Run the batch probe and compute exact-match accuracy."""
        data = model_probe_batch.run(input_data=parsed_data[:num_datapoints])
        exact_matches = sum(1 for item in data if item["generated_icd_10_code"] == item["actual_icd_10_code"])
        total_samples = len(data)
        # Guard against ZeroDivisionError when no datapoints came back.
        accuracy = (exact_matches / total_samples) * 100 if total_samples else 0.0
        return {"total_samples": total_samples, "exact_matches": exact_matches, "accuracy": accuracy}

    return evaluate_model()
if __name__ == "__main__":
    # Example usage when running this file directly: probe 5 datapoints
    # against the default Llama 3.1 8B Instruct Turbo endpoint.
    results = run_model_probe(model_name="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", num_datapoints=5)
    print(results)