File size: 5,809 Bytes
3426410 4fb66c9 3426410 46dc492 3426410 4fb66c9 c22866b 4fb66c9 3426410 2cb976e 3426410 642d3cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Union, Optional, Dict
import logging
from langchain.chains import SequentialChain, TransformChain
from .model import CancerClassifier, CancerExtractor
# Logging: configure the root logger at INFO and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application. Interactive API docs are exposed at /docs and /redoc.
app = FastAPI(
    title="Cancer Text Processing API",
    version="1.0.0",
    description="API for cancer-related text classification and information extraction",
    docs_url="/docs",
    redoc_url="/redoc",
)
@app.get("/")
def read_root():
    """Return a static banner confirming the API is up and naming the /process endpoint."""
    banner = {"message": "Cancer Classify & Extract API is live 🚀, Access Endpoint: /process"}
    return JSONResponse(banner)
class TextInput(BaseModel):
    """Request body for /process: one text or a batch of texts."""

    # Either a single string or a list of strings; process_texts wraps a
    # single string into a one-element batch before running the pipeline.
    text: Union[str, List[str]]
class ProcessingResult(BaseModel):
    """Per-text outcome of the classification + extraction pipeline."""

    # The original input text this result corresponds to.
    text: str
    # Stringified classifier output, or an {"error": ...} dict when
    # classification of this text raised.
    classification: Union[str, dict]
    # Stringified extractor output, or an {"error": ...} dict when
    # extraction of this text raised.
    extraction: Union[str, dict]
    # First error message found for this text (classification is checked
    # before extraction); None when both steps succeeded.
    error: Optional[str] = None
class BatchResponse(BaseModel):
    """Response body for /process."""

    # One ProcessingResult per input text, in the same order as the input.
    results: List[ProcessingResult]
# Initialize models once, at import time. If either model fails to load, the
# module import fails with RuntimeError so the app cannot start half-configured.
try:
    logger.info("Loading classification model...")
    classification_pipeline = CancerClassifier()
    logger.info("Loading extraction model...")
    extraction_pipeline = CancerExtractor()
    logger.info("Models loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback, which is what you need to debug a
    # failed model load; `from e` makes the original cause explicit.
    logger.exception("Failed to load models: %s", e)
    raise RuntimeError("Could not initialize models") from e
def batch_classification_transform(inputs: Dict) -> Dict:
    """Classify every text in ``inputs["input_texts"]``.

    A bare string is treated as a batch of one. Each successful prediction is
    stored as ``str(prediction)``. A text whose prediction raises is recorded as
    ``{"error": <message>}`` so the rest of the batch still runs.

    Returns:
        ``{"classification_results": [...]}``, with one entry per input text.

    Raises:
        Re-raises (after logging) any failure outside the per-text handling,
        e.g. a missing ``"input_texts"`` key.
    """
    try:
        batch = inputs["input_texts"]
        batch = [batch] if isinstance(batch, str) else batch

        outcomes = []
        for item in batch:
            try:
                entry = str(classification_pipeline.predict(item))
            except Exception as exc:
                logger.warning(f"Classification failed for text: {item[:50]}... Error: {str(exc)}")
                entry = {"error": str(exc)}
            outcomes.append(entry)

        return {"classification_results": outcomes}
    except Exception as exc:
        logger.error(f"Batch classification failed: {str(exc)}")
        raise
def batch_extraction_transform(inputs: Dict) -> Dict:
    """Run information extraction on every text in ``inputs["input_texts"]``.

    A bare string is treated as a batch of one. Each successful extraction is
    stored as ``str(output)``. A text whose extraction raises is recorded as
    ``{"error": <message>}`` so the rest of the batch still runs.

    Returns:
        ``{"extraction_results": [...]}``, with one entry per input text.

    Raises:
        Re-raises (after logging) any failure outside the per-text handling,
        e.g. a missing ``"input_texts"`` key.
    """
    try:
        batch = inputs["input_texts"]
        batch = [batch] if isinstance(batch, str) else batch

        outcomes = []
        for item in batch:
            try:
                entry = str(extraction_pipeline.predict(item))
            except Exception as exc:
                logger.warning(f"Extraction failed for text: {item[:50]}... Error: {str(exc)}")
                entry = {"error": str(exc)}
            outcomes.append(entry)

        return {"extraction_results": outcomes}
    except Exception as exc:
        logger.error(f"Batch extraction failed: {str(exc)}")
        raise
# Two independent transform chains. Both read the same "input_texts" key and
# each writes its own result list.
classification_chain = TransformChain(
    transform=batch_classification_transform,
    input_variables=["input_texts"],
    output_variables=["classification_results"],
)
extraction_chain = TransformChain(
    transform=batch_extraction_transform,
    input_variables=["input_texts"],
    output_variables=["extraction_results"],
)

# Run classification, then extraction, and expose both result lists.
processing_chain = SequentialChain(
    chains=[classification_chain, extraction_chain],
    input_variables=["input_texts"],
    output_variables=["classification_results", "extraction_results"],
    verbose=True,
)
@app.post("/process", response_model=BatchResponse)
async def process_texts(input: TextInput):
    """
    Process cancer-related texts through the classification and extraction pipeline.

    Args:
        input: TextInput object containing either a single string or a list of strings.

    Returns:
        BatchResponse with one ProcessingResult per input text, in input order.
        A text whose classification or extraction failed carries the error
        message in ``error`` (classification errors take precedence).

    Raises:
        HTTPException: 400 if the payload is not a string or a list of strings;
            500 if the pipeline fails as a whole.
    """
    try:
        texts = [input.text] if isinstance(input.text, str) else input.text
        # Validate input
        if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
            raise HTTPException(status_code=400, detail="Input must be string or list of strings")

        # Process through LangChain pipeline
        chain_result = processing_chain({"input_texts": texts})

        # Format results: one entry per input text, matched by position.
        results = []
        for i, text in enumerate(texts):
            classification = chain_result["classification_results"][i]
            extraction = chain_result["extraction_results"][i]
            error = None
            if isinstance(classification, dict) and "error" in classification:
                error = classification["error"]
            elif isinstance(extraction, dict) and "error" in extraction:
                error = extraction["error"]
            results.append(ProcessingResult(
                text=text,
                classification=classification,
                extraction=extraction,
                error=error
            ))
        return BatchResponse(results=results)
    except HTTPException:
        # Already carries the intended status (e.g. the 400 above). Without this
        # clause the generic handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.exception("Processing failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Runs one prediction through each model on a fixed phrase.

    Returns:
        ``{"status": "healthy", "models": [...]}`` when both predictions succeed.

    Raises:
        HTTPException: 500 with the underlying error message if either model fails.
    """
    # Test with a simple cancer-related phrase
    test_text = "breast cancer diagnosis"
    try:
        classification_pipeline.predict(test_text)
        extraction_pipeline.predict(test_text)
    except Exception as e:
        # Log the failure so an unhealthy service is visible in the logs, not
        # only to whichever client polled /health.
        logger.exception("Health check failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "healthy", "models": ["classification", "extraction"]}
|