# NOTE(review): the lines that originally appeared here were Hugging Face
# Spaces page chrome (status, file size, commit hashes, line-number gutter)
# captured by a web scrape — not part of the application source.
import os
import tempfile
import warnings
from typing import Optional

import numpy as np
import torchaudio
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel

warnings.filterwarnings("ignore")
app = FastAPI()
def extract_audio_features(audio_file_path):
    """Load an audio file and derive pitch, energy, speech-rate and MFCC features.

    Args:
        audio_file_path: Path to an audio file readable by torchaudio (.wav).

    Returns:
        Tuple ``(f0, energy, speech_rate, mfccs, waveform, sample_rate)`` where
        ``f0`` is the voiced fundamental-frequency track in Hz (np.ndarray),
        ``energy`` is the per-sample squared amplitude (np.ndarray),
        ``speech_rate`` is an estimated syllable rate in syllables/sec (float),
        ``mfccs`` is a (13, n_frames) np.ndarray, ``waveform`` is the mono
        signal (np.ndarray) and ``sample_rate`` is in Hz.
    """
    # Load the audio file using torchaudio
    waveform, sample_rate = torchaudio.load(audio_file_path)
    # Ensure waveform is mono by averaging channels if necessary
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = waveform.squeeze()  # Remove channel dimension if it's 1
    # Extract pitch (fundamental frequency). detect_pitch_frequency returns a
    # single tensor of per-frame pitch estimates — NOT a (pitch, voiced, _)
    # tuple as the previous code assumed. Keep only plausibly-voiced frames by
    # dropping non-positive estimates.
    pitch = torchaudio.functional.detect_pitch_frequency(
        waveform, sample_rate, frame_time=0.01
    )
    f0 = pitch[pitch > 0]
    # Extract per-sample energy (squared amplitude)
    energy = waveform.pow(2).numpy()
    # Extract MFCCs; MFCC expects a (channel, time) tensor
    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=13)
    mfccs = mfcc_transform(waveform.unsqueeze(0)).squeeze(0).numpy()
    # Estimate speech rate (simplified, syllables/sec): frame the energy at
    # 20 ms, then count upward crossings of the mean frame energy, treating
    # each energy burst as one syllable. (torchaudio has no tempo estimator;
    # the previous torchaudio.functional.estimate_tempo call does not exist.)
    frame_len = max(int(0.02 * sample_rate), 1)
    n_frames = len(energy) // frame_len
    duration = len(energy) / sample_rate if sample_rate else 0
    if n_frames > 1 and duration > 0:
        frames = energy[: n_frames * frame_len].reshape(n_frames, frame_len).mean(axis=1)
        threshold = frames.mean()
        bursts = int(np.sum((frames[1:] >= threshold) & (frames[:-1] < threshold)))
        speech_rate = bursts / duration
    else:
        speech_rate = 0
    return f0.numpy(), energy, speech_rate, mfccs, waveform.numpy(), sample_rate
def analyze_voice_stress(audio_file_path):
    """Score vocal stress for an audio file on a 0-100 scale.

    Compares mean pitch, mean energy and speech rate against rough population
    norms (gender-dependent for pitch), combines the z-scores into a weighted
    sum, and squashes the result through a sigmoid.

    Args:
        audio_file_path: Path to the audio file to analyze.

    Returns:
        dict with "stress_level" (float, 0-100), "category" (str label) and
        "gender" ("male" or "female", inferred from mean pitch).

    Raises:
        ValueError: If no fundamental frequency could be extracted.
    """
    f0, energy, speech_rate, mfccs, waveform, sample_rate = extract_audio_features(audio_file_path)
    if len(f0) == 0:
        raise ValueError("Could not extract fundamental frequency from the audio.")
    mean_f0 = np.mean(f0)
    mean_energy = np.mean(energy)
    # Crude gender split at 165 Hz decides which pitch norm applies.
    gender = 'male' if mean_f0 < 165 else 'female'
    norm_mean_f0 = 110 if gender == 'male' else 220
    norm_std_f0 = 20
    norm_mean_energy = 0.02
    norm_std_energy = 0.005
    norm_speech_rate = 4.4  # syllables per second
    norm_std_speech_rate = 0.5
    # z-scores against the norms; pitch and speech rate dominate the weighting.
    z_f0 = (mean_f0 - norm_mean_f0) / norm_std_f0
    z_energy = (mean_energy - norm_mean_energy) / norm_std_energy
    z_speech_rate = (speech_rate - norm_speech_rate) / norm_std_speech_rate
    stress_score = (0.4 * z_f0) + (0.4 * z_speech_rate) + (0.2 * z_energy)
    # Sigmoid maps the unbounded score into the 0-100 range.
    stress_level = float(1 / (1 + np.exp(-stress_score)) * 100)
    categories = ["Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress"]
    category_idx = min(int(stress_level / 20), 4)
    stress_category = categories[category_idx]
    return {"stress_level": stress_level, "category": stress_category, "gender": gender}
def analyze_text_stress(text: str):
    """Estimate stress from keyword occurrences in *text*.

    Each distinct stress keyword found (substring match, case-insensitive)
    adds 20 points, capped at 100.

    Args:
        text: Free-form user text.

    Returns:
        dict with "stress_level" (int, 0-100) and "category" (str label).
    """
    keywords = ("anxious", "nervous", "stress", "panic", "tense")
    lowered = text.lower()
    hits = sum(1 for keyword in keywords if keyword in lowered)
    level = min(hits * 20, 100)
    labels = ["Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress"]
    return {"stress_level": level, "category": labels[min(level // 20, 4)]}
class StressResponse(BaseModel):
    """Response schema for the /analyze-stress/ endpoint.

    stress_level: estimated stress on a 0-100 scale.
    category: human-readable stress bucket (e.g. "Moderate Stress").
    gender: detected speaker gender; only present for audio analysis.
    """
    stress_level: float
    category: str
    # `gender: str = None` was an invalid implicit-Optional annotation
    # (rejected by pydantic v2); the optionality must be explicit.
    gender: Optional[str] = None
@app.post("/analyze-stress/", response_model=StressResponse)
async def analyze_stress(
    file: UploadFile = File(None),
    file_path: str = Form(None),
    text: str = Form(None)
):
    """Analyze stress from an uploaded .wav, a server-side .wav path, or text.

    Exactly one input is expected; audio (upload or path) takes precedence
    over text. Returns a StressResponse-shaped JSON payload.
    """
    # Reject the request outright when no input of any kind was supplied.
    if file is None and file_path is None and text is None:
        raise HTTPException(status_code=400, detail="Either a file, file path, or text input is required.")
    # Audio path: either spool the upload to a temp .wav or use the given path.
    if file or file_path:
        if file:
            if not file.filename.endswith(".wav"):
                raise HTTPException(status_code=400, detail="Only .wav files are supported.")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as spooled:
                spooled.write(await file.read())
                wav_path = spooled.name
        else:
            if not file_path.endswith(".wav"):
                raise HTTPException(status_code=400, detail="Only .wav files are supported.")
            if not os.path.exists(file_path):
                raise HTTPException(status_code=400, detail="File path does not exist.")
            wav_path = file_path
        try:
            return JSONResponse(content=analyze_voice_stress(wav_path))
        except Exception as exc:
            raise HTTPException(status_code=500, detail=str(exc))
        finally:
            # Only uploads create a temp file; never delete a caller's path.
            if file:
                os.remove(wav_path)
    # Text path: simple keyword-based scoring.
    elif text:
        return JSONResponse(content=analyze_text_stress(text))
if __name__ == "__main__":
    import uvicorn

    # Honor the PORT environment variable (set by many hosting platforms);
    # fall back to 7860 otherwise.
    serve_port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=serve_port, reload=True)