# aig2 / app.py
# Repository file-viewer header (not part of the program):
# vitorcalvi's picture · 1 · 5630c13 · raw · history blame · 5.01 kB
import os
import tempfile
import warnings
from typing import Optional

import numpy as np
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Suppress every Python warning for the whole process (the filter has no
# category or module restriction, so third-party warnings are hidden too).
warnings.filterwarnings("ignore")
# ASGI application object: the route below registers on it and the
# __main__ block hands it to uvicorn as "app:app".
app = FastAPI()
def _estimate_speech_rate(samples, sample_rate, frame_time=0.01, smooth_time=0.1, rel_threshold=0.3):
    """Approximate speech rate as energy-envelope peaks per second.

    The signal is cut into non-overlapping frames of ``frame_time`` seconds,
    the RMS of each frame is smoothed with a moving average spanning
    ``smooth_time`` seconds, and local maxima that rise above
    ``rel_threshold`` times the loudest frame are counted. This is a rough
    heuristic, not a calibrated syllable detector.

    Args:
        samples: 1-D NumPy array holding the mono signal.
        sample_rate: Sampling rate of ``samples`` in Hz.
        frame_time: Frame length in seconds.
        smooth_time: Moving-average span in seconds.
        rel_threshold: Fraction of the maximum envelope a peak must exceed.

    Returns:
        Peaks per second as a float; 0.0 for silent, non-finite or very
        short input.
    """
    hop = max(int(sample_rate * frame_time), 1)
    n_frames = samples.size // hop
    if n_frames < 3:
        # A local maximum needs a neighbour on both sides.
        return 0.0

    frames = samples[: n_frames * hop].reshape(n_frames, hop).astype(np.float64)
    envelope = np.sqrt(np.mean(frames ** 2, axis=1))

    # Clamp the kernel to the envelope length: np.convolve(mode="same")
    # returns max(len(a), len(v)) values, which would otherwise grow the
    # envelope for short clips.
    width = min(max(int(round(smooth_time / frame_time)), 1), n_frames)
    if width > 1:
        envelope = np.convolve(envelope, np.full(width, 1.0 / width), mode="same")

    loudest = envelope.max()
    if not np.isfinite(loudest) or loudest <= 0:
        return 0.0

    inner = envelope[1:-1]
    # Strict on the left, non-strict on the right: a flat plateau counts once.
    is_peak = (
        (inner > envelope[:-2])
        & (inner >= envelope[2:])
        & (inner > rel_threshold * loudest)
    )
    return float(np.count_nonzero(is_peak) * sample_rate / samples.size)


def extract_audio_features(audio_file_path):
    """Load an audio file and derive the features used for stress scoring.

    Args:
        audio_file_path: Path of the audio file, passed to ``torchaudio.load``.

    Returns:
        A 6-tuple ``(f0, energy, speech_rate, mfccs, waveform, sample_rate)``:
        per-frame pitch estimates in Hz (NumPy array), squared sample values
        (NumPy array), the speech-rate estimate in events per second (float),
        13 MFCC coefficients per frame (NumPy array), the mono waveform
        (NumPy array) and the file's sample rate.

    Raises:
        ValueError: If the file contains no samples.
    """
    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Collapse the channel axis so every later step sees a 1-D mono signal.
    # squeeze(0) is used instead of squeeze() so that a one-sample file does
    # not degrade to a 0-d tensor.
    if waveform.dim() > 1:
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)
    if waveform.numel() == 0:
        raise ValueError("Audio file contains no samples.")

    # detect_pitch_frequency returns ONE tensor of per-frame frequencies; it
    # provides no voiced/unvoiced flags, so every finite positive estimate is
    # kept. win_length is left at the library default because it is a
    # median-smoothing window counted in frames, not an FFT size.
    # NOTE(review): clips shorter than the smoothing window may make
    # torchaudio raise here - confirm the minimum supported duration.
    pitch = torchaudio.functional.detect_pitch_frequency(
        waveform, sample_rate, frame_time=0.01
    ).numpy()
    f0 = pitch[np.isfinite(pitch) & (pitch > 0)]

    # Instantaneous energy: one squared value per sample.
    energy = waveform.pow(2).numpy()

    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=13)
    mfccs = mfcc_transform(waveform.unsqueeze(0)).squeeze(0).numpy()

    samples = waveform.numpy()
    # torchaudio.functional has no tempo estimator, so the rate is derived
    # from the energy envelope instead.
    # NOTE(review): the 4.4 events/s norm used by the caller was chosen for
    # a different estimator - re-calibrate against this heuristic.
    speech_rate = _estimate_speech_rate(samples, sample_rate)

    return f0, energy, speech_rate, mfccs, samples, sample_rate
def analyze_voice_stress(audio_file_path):
    """Score vocal stress for an audio file on a 0-100 scale.

    Mean pitch, mean energy and speech rate are converted to z-scores
    against fixed reference values, combined with weights 0.4 / 0.2 / 0.4,
    squashed through a logistic function and mapped to one of five labels.
    The pitch reference depends on a gender guess made from mean pitch
    (below 165 Hz is treated as male).

    Args:
        audio_file_path: Path handed to ``extract_audio_features``.

    Returns:
        dict with "stress_level" (float, 0-100), "category" (label) and
        "gender" ("male" or "female").

    Raises:
        ValueError: If no pitch values were extracted, or if the features
            yield a non-finite score.
    """
    # Only pitch, energy and speech rate feed the score; the remaining
    # features returned by the extractor are not used here.
    f0, energy, speech_rate, *_ = extract_audio_features(audio_file_path)
    if len(f0) == 0:
        raise ValueError("Could not extract fundamental frequency from the audio.")

    # Convert to built-in floats so the arithmetic below runs in double
    # precision regardless of the dtype of the feature arrays.
    mean_f0 = float(np.mean(f0))
    mean_energy = float(np.mean(energy))
    speech_rate = float(speech_rate)

    gender = 'male' if mean_f0 < 165 else 'female'

    # Reference means and spreads used to normalise each feature.
    norm_mean_f0 = 110 if gender == 'male' else 220
    norm_std_f0 = 20
    norm_mean_energy = 0.02
    norm_std_energy = 0.005
    norm_speech_rate = 4.4
    norm_std_speech_rate = 0.5

    z_f0 = (mean_f0 - norm_mean_f0) / norm_std_f0
    z_energy = (mean_energy - norm_mean_energy) / norm_std_energy
    z_speech_rate = (speech_rate - norm_speech_rate) / norm_std_speech_rate

    stress_score = (0.4 * z_f0) + (0.4 * z_speech_rate) + (0.2 * z_energy)
    if not np.isfinite(stress_score):
        # Without this check int() below would fail with an opaque message.
        raise ValueError("Audio features produced a non-finite stress score.")

    # Logistic function written so that exp() only ever receives a
    # non-positive argument and therefore cannot overflow.
    if stress_score >= 0:
        sigmoid = 1.0 / (1.0 + float(np.exp(-stress_score)))
    else:
        growth = float(np.exp(stress_score))
        sigmoid = growth / (1.0 + growth)
    stress_level = sigmoid * 100

    categories = ["Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress"]
    # Five 20-point buckets; a level of exactly 100 stays in the top bucket.
    category_idx = min(int(stress_level / 20), 4)
    stress_category = categories[category_idx]

    return {"stress_level": stress_level, "category": stress_category, "gender": gender}
def analyze_text_stress(text: str):
    """Rate stress in free text by counting known stress keywords.

    Each keyword of a fixed five-word list that occurs anywhere in the
    lower-cased text (substring match, counted once per keyword) adds
    20 points, capped at 100.

    Args:
        text: The text to inspect.

    Returns:
        dict with "stress_level" (int, 0-100) and "category" (label).
    """
    lowered = text.lower()
    hits = 0
    for keyword in ("anxious", "nervous", "stress", "panic", "tense"):
        if keyword in lowered:
            hits += 1

    level = min(hits * 20, 100)
    labels = ("Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress")
    # Five 20-point buckets; a level of exactly 100 stays in the top bucket.
    bucket = min(level // 20, 4)
    return {"stress_level": level, "category": labels[bucket]}
class StressResponse(BaseModel):
    """Response schema declared for the POST /analyze-stress/ route."""

    # Stress estimate; the analyzers in this file produce values from 0 to 100.
    stress_level: float
    # Human-readable label, "Very Low Stress" through "Very High Stress".
    category: str
    # Only set for audio analysis. The default is None, so the annotation
    # must admit None rather than claim a plain str.
    gender: Optional[str] = None
@app.post("/analyze-stress/", response_model=StressResponse)
async def analyze_stress(
    file: UploadFile = File(None),
    file_path: str = Form(None),
    text: str = Form(None)
):
    """Analyse stress from an uploaded WAV file, a server-side WAV path, or text.

    Inputs are tried in that order: an upload wins over ``file_path``, and
    any audio input wins over ``text``.

    Args:
        file: Uploaded audio; the filename must end in ".wav" (any case).
        file_path: Path of a ".wav" file on the server's filesystem.
        text: Free text scored by keyword matching.

    Returns:
        JSONResponse carrying the analyzer's result dict.

    Raises:
        HTTPException: 400 when no usable input is given, the extension is
            not ".wav", or ``file_path`` is not an existing file; 500 when
            reading the upload or analysing the audio fails.
    """
    # Truthiness instead of "is None": an empty string would otherwise pass
    # this check, match neither branch below and make the handler return None.
    if not file and not file_path and not text:
        raise HTTPException(status_code=400, detail="Either a file, file path, or text input is required.")

    # Text is only considered when no audio input was supplied.
    if not file and not file_path:
        return JSONResponse(content=analyze_text_stress(text))

    # Validate before entering the try block so these 400s are not rewrapped
    # as 500s by the generic handler.
    if file:
        # filename may be missing on an upload; treat that as "not a .wav".
        if not (file.filename or "").lower().endswith(".wav"):
            raise HTTPException(status_code=400, detail="Only .wav files are supported.")
    else:
        # NOTE(security): file_path makes the server open a path chosen by
        # the client, which also reveals whether that path exists. Restrict
        # it to a dedicated directory or remove it before exposing this
        # endpoint to untrusted callers.
        if not file_path.lower().endswith(".wav"):
            raise HTTPException(status_code=400, detail="Only .wav files are supported.")
        # isfile() also rejects a directory whose name ends in ".wav".
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=400, detail="File path does not exist.")

    temp_wav_path = None
    try:
        if file:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                # Record the name first so the finally block removes the file
                # even when reading or writing the upload fails.
                temp_wav_path = temp_file.name
                temp_file.write(await file.read())
            wav_path = temp_wav_path
        else:
            wav_path = file_path

        # NOTE(review): this call is synchronous and runs inside the async
        # handler; consider moving it to a worker thread if requests overlap.
        result = analyze_voice_stress(wav_path)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Only the temporary copy of an upload is deleted; a caller-supplied
        # file_path is never removed.
        if temp_wav_path is not None:
            os.remove(temp_wav_path)
if __name__ == "__main__":
    import uvicorn

    # The PORT environment variable overrides the default port 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    # Listens on all interfaces with auto-reload switched on.
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=True)