# user1729's picture
# main file modified
# 46dc492
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Union, Optional, Dict
import logging
from langchain.chains import SequentialChain, TransformChain
from .model import CancerClassifier, CancerExtractor
# Logging: configure the root logger at INFO level, then create the
# module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI(
    # Metadata shown in the generated API documentation.
    title="Cancer Text Processing API",
    version="1.0.0",
    description="API for cancer-related text classification and information extraction",
    # Paths of the interactive documentation pages.
    docs_url="/docs",
    redoc_url="/redoc",
)
@app.get("/")
def read_root():
    """Return a short liveness message that points clients at ``/process``."""
    payload = {"message": "Cancer Classify & Extract API is live 🚀, Access Endpoint: /process"}
    return JSONResponse(payload)
# Request body accepted by the /process endpoint.
class TextInput(BaseModel):
    # Either one text or a batch of texts; the endpoint normalises a single
    # string into a one-element list before processing.
    text: Union[str, List[str]]
# Outcome of running one input text through both models.
class ProcessingResult(BaseModel):
    # The input text this result belongs to.
    text: str
    # Stringified classifier output, or an {"error": ...} dict when the
    # classification of this text failed.
    classification: Union[str, dict]
    # Stringified extractor output, or an {"error": ...} dict when the
    # extraction for this text failed.
    extraction: Union[str, dict]
    # Error message for this text (classification errors take precedence over
    # extraction errors); None when both steps succeeded.
    error: Optional[str] = None
# Response body of the /process endpoint.
class BatchResponse(BaseModel):
    # One ProcessingResult per input text, in input order.
    results: List[ProcessingResult]
# Initialize models
# Both models are constructed at import time, so a model that cannot be loaded
# aborts application start-up with a RuntimeError.
try:
    logger.info("Loading classification model...")
    classification_pipeline = CancerClassifier()
    logger.info("Loading extraction model...")
    extraction_pipeline = CancerExtractor()
    logger.info("Models loaded successfully")
except Exception as e:
    # logger.exception logs at ERROR level like logger.error, but also records
    # the traceback, which is what is needed to diagnose a failed model load.
    logger.exception(f"Failed to load models: {str(e)}")
    # Chain explicitly so the RuntimeError carries the original failure as its
    # cause instead of presenting it as an error raised while handling another.
    raise RuntimeError("Could not initialize models") from e
def batch_classification_transform(inputs: Dict) -> Dict:
    """Run every text in ``inputs["input_texts"]`` through the classification model.

    A bare string is treated as a batch of one. Each successful prediction is
    stored as ``str(prediction)``; a prediction that raises is logged as a
    warning and stored as ``{"error": <message>}``, so the output list stays
    aligned with the input list. Any other failure is logged and re-raised.

    Returns:
        ``{"classification_results": [...]}`` with one entry per input text.
    """

    def classify_one(text):
        # Per-text isolation: one failing text must not abort the whole batch.
        try:
            return str(classification_pipeline.predict(text))
        except Exception as e:
            logger.warning(f"Classification failed for text: {text[:50]}... Error: {str(e)}")
            return {"error": str(e)}

    try:
        batch = inputs["input_texts"]
        if isinstance(batch, str):
            batch = [batch]  # Convert single text to batch of one
        return {"classification_results": [classify_one(item) for item in batch]}
    except Exception as e:
        logger.error(f"Batch classification failed: {str(e)}")
        raise
def batch_extraction_transform(inputs: Dict) -> Dict:
    """Run every text in ``inputs["input_texts"]`` through the extraction model.

    A bare string is treated as a batch of one. Each successful prediction is
    stored as ``str(prediction)``; a prediction that raises is logged as a
    warning and stored as ``{"error": <message>}``, so the output list stays
    aligned with the input list. Any other failure is logged and re-raised.

    Returns:
        ``{"extraction_results": [...]}`` with one entry per input text.
    """

    def extract_one(text):
        # Per-text isolation: one failing text must not abort the whole batch.
        try:
            return str(extraction_pipeline.predict(text))
        except Exception as e:
            logger.warning(f"Extraction failed for text: {text[:50]}... Error: {str(e)}")
            return {"error": str(e)}

    try:
        batch = inputs["input_texts"]
        if isinstance(batch, str):
            batch = [batch]  # Convert single text to batch of one
        return {"extraction_results": [extract_one(item) for item in batch]}
    except Exception as e:
        logger.error(f"Batch extraction failed: {str(e)}")
        raise
# Processing chains: each model is wrapped in a TransformChain that reads
# "input_texts" and writes its own results key.
classification_chain = TransformChain(
    input_variables=["input_texts"],
    output_variables=["classification_results"],
    transform=batch_classification_transform,
)

extraction_chain = TransformChain(
    input_variables=["input_texts"],
    output_variables=["extraction_results"],
    transform=batch_extraction_transform,
)

# Sequential chain: classification runs first, then extraction; both result
# keys are exposed in the combined output.
processing_chain = SequentialChain(
    chains=[classification_chain, extraction_chain],
    input_variables=["input_texts"],
    output_variables=["classification_results", "extraction_results"],
    verbose=True,
)
@app.post("/process", response_model=BatchResponse)
async def process_texts(input: TextInput):
    """
    Process cancer-related texts through classification and extraction pipeline

    Args:
        input: TextInput object containing either a single string or list of strings

    Returns:
        BatchResponse with processing results for each input text

    Raises:
        HTTPException: status 400 when the input is not a string or a list of
            strings; status 500 when the processing chain or the formatting
            of its results fails.
    """
    texts = [input.text] if isinstance(input.text, str) else input.text

    # Validate input. This runs before the try block on purpose: inside it, the
    # 400 response would be caught by the generic handler and turned into a 500.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        raise HTTPException(status_code=400, detail="Input must be string or list of strings")

    try:
        # Process through LangChain pipeline
        # NOTE(review): this call is synchronous inside an async handler, so it
        # appears to run on the event loop — confirm whether that is acceptable.
        chain_result = processing_chain({"input_texts": texts})

        # Format results: the i-th entry of each result list belongs to texts[i].
        results = []
        for i, text in enumerate(texts):
            classification = chain_result["classification_results"][i]
            extraction = chain_result["extraction_results"][i]

            # A failed step is represented as an {"error": ...} dict; report the
            # first one found, checking classification before extraction.
            error = None
            for outcome in (classification, extraction):
                if isinstance(outcome, dict) and "error" in outcome:
                    error = outcome["error"]
                    break

            results.append(ProcessingResult(
                text=text,
                classification=classification,
                extraction=extraction,
                error=error,
            ))
        return BatchResponse(results=results)
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(f"Processing failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Health check endpoint

    Runs a fixed sample phrase through both models. Returns a "healthy" status
    payload when both predictions complete; any exception is reported as an
    HTTP 500 whose detail is the exception text.
    """
    # Test with a simple cancer-related phrase
    probe = "breast cancer diagnosis"
    try:
        for pipeline in (classification_pipeline, extraction_pipeline):
            pipeline.predict(probe)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "healthy", "models": ["classification", "extraction"]}