```python
from starfish import StructuredLLM, data_factory
from starfish.common.env_loader import load_env_file
from datasets import load_dataset
import json

load_env_file()


def run_model_probe(model_name="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", num_datapoints=10):
    # Load the dataset
    dataset = load_dataset("starfishdata/endocrinology_transcription_and_notes_and_icd_codes", split="train")
    top_n_data = dataset.select(range(num_datapoints))

    # Create a list to store the parsed data
    parsed_data = []

    # Process each entry
    for idx, entry in enumerate(top_n_data):
        # Extract transcript - get the value directly from the transcript key
        transcript = entry["transcript"] if isinstance(entry["transcript"], str) else entry["transcript"].get("transcript", "")

        # Extract ICD-10 code (top_1 code)
        icd_codes_str = entry.get("icd_10_code", "{}")
        try:
            icd_codes = json.loads(icd_codes_str)
            top_1_code = icd_codes.get("top_1", {}).get("code", "")
        except json.JSONDecodeError:
            top_1_code = ""

        # Add to parsed data
        parsed_data.append({"id": idx, "transcript": transcript, "icd_10_code": top_1_code})

    model_probe_prompt = """
    Given a transcript of a patient's medical history, determine the ICD-10 code that is most relevant to the patient's condition.

    Transcript: {{transcript}}

    Please do not return anything other than the ICD-10 code in JSON format,
    like this: {"icd_10_code": "A00.0"}
    """

    response_gen_llm = StructuredLLM(model_name=model_name, prompt=model_probe_prompt, output_schema=[{"name": "icd_10_code", "type": "str"}])

    # Wrap the async worker in a Starfish data factory so it exposes the batch
    # .run() interface used in evaluate_model below. The decorator appears to have
    # been dropped from the original listing; max_concurrency is an assumed value.
    @data_factory(max_concurrency=10)
    async def model_probe_batch(input_data):
        response = await response_gen_llm.run(transcript=input_data["transcript"])
        return [{"id": input_data["id"], "generated_icd_10_code": response.data[0]["icd_10_code"], "actual_icd_10_code": input_data["icd_10_code"]}]

    def evaluate_model():
        data = model_probe_batch.run(input_data=parsed_data[:num_datapoints])

        # Calculate exact match accuracy
        exact_matches = sum(1 for item in data if item["generated_icd_10_code"] == item["actual_icd_10_code"])
        total_samples = len(data)
        accuracy = (exact_matches / total_samples) * 100

        return {"total_samples": total_samples, "exact_matches": exact_matches, "accuracy": accuracy}

    return evaluate_model()


if __name__ == "__main__":
    # Example usage when running this file directly
    results = run_model_probe(model_name="together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", num_datapoints=5)
    print(results)
```
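Because `run_model_probe()` takes the model name as a parameter, the same probe can be reused to compare several candidate models on the same slice of the dataset. The snippet below is a minimal sketch built on the function defined above; the commented-out second entry in `candidate_models` is a placeholder to replace with whatever model identifier your provider supports.

```python
# Minimal sketch: probe multiple models with run_model_probe() and compare
# their exact-match accuracy on the same datapoints.
candidate_models = [
    "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    # "together_ai/<another-model-id>",  # placeholder, replace with a real model id
]

for model in candidate_models:
    results = run_model_probe(model_name=model, num_datapoints=5)
    print(f"{model}: {results['exact_matches']}/{results['total_samples']} exact matches ({results['accuracy']:.1f}%)")
```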