File size: 6,781 Bytes
da890b5
 
fcfe145
da890b5
fcfe145
 
 
da890b5
fcfe145
 
da890b5
fcfe145
 
da890b5
fcfe145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da890b5
fcfe145
 
 
da890b5
fcfe145
 
 
da890b5
fcfe145
 
 
da890b5
fcfe145
 
 
 
 
 
 
 
 
da890b5
fcfe145
 
da890b5
fcfe145
 
 
 
 
 
 
 
 
 
 
 
 
 
da890b5
 
 
fcfe145
 
da890b5
 
fcfe145
 
da890b5
 
fcfe145
da890b5
 
 
 
 
fcfe145
da890b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcfe145
 
 
 
 
 
da890b5
 
 
 
 
fcfe145
 
 
da890b5
 
 
fcfe145
da890b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json
import re

import gradio as gr
from transformers import pipeline

# Initialize pipelines (module-level: each triggers a model download/load on first run,
# so imports of this module are slow but inference calls reuse the loaded weights).

# Summarization pipeline with FLAN-T5 (text2text-generation, small variant).
summarizer = pipeline("text2text-generation", model="google/flan-t5-small", tokenizer="google/flan-t5-small")

# Sentiment analysis pipeline (tweet-tuned RoBERTa; labels: positive/neutral/negative).
sentiment_analyzer = pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest")

# Automatic speech recognition pipeline for audio uploads (Whisper small).
asr_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-small")

def convert_to_json(transcript_text):
    """
    Structure a raw transcript into ``{"session": [...]}`` speaker turns.

    Lines beginning with 'Therapist:'/'T:' or 'Client:'/'C:' (case-insensitive)
    open a new turn; continuation lines are appended to the current turn. If no
    speaker prefix is ever seen, the whole transcript becomes one 'Unknown' turn.
    """
    therapist_pattern = re.compile(r"^\s*(Therapist|T):", re.IGNORECASE)
    client_pattern = re.compile(r"^\s*(Client|C):", re.IGNORECASE)
    # Checked in order; therapist prefix wins if both somehow matched.
    speaker_rules = (("Therapist", therapist_pattern), ("Client", client_pattern))

    turns = []
    speaker = None
    fragments = []

    for raw_line in transcript_text.strip().split("\n"):
        stripped = raw_line.strip()

        matched_rule = next(
            ((label, pattern) for label, pattern in speaker_rules if pattern.match(stripped)),
            None,
        )

        if matched_rule is not None:
            label, pattern = matched_rule
            # Close out the previous speaker's accumulated turn, if any.
            if speaker and fragments:
                turns.append({"speaker": speaker, "text": " ".join(fragments).strip()})
                fragments = []
            speaker = label
            # Keep only the text after the 'Speaker:' prefix.
            fragments.append(pattern.sub("", stripped).strip())
        else:
            # Continuation line; text before any prefix is attributed to 'Unknown'.
            if speaker is None:
                speaker = "Unknown"
            fragments.append(stripped)

    # Flush the final turn.
    if speaker and fragments:
        turns.append({"speaker": speaker, "text": " ".join(fragments).strip()})

    # Guarantee a non-empty session even when nothing was collected.
    if not turns:
        turns = [{"speaker": "Unknown", "text": transcript_text.strip()}]

    return {"session": turns}

def analyze_session(transcript, custom_instruction, audio):
    """
    Summarize a psychotherapy session from text or audio.

    Args:
        transcript: Session transcript text (may be None/empty if audio is used).
        custom_instruction: Optional extra focus for the summary (may be None).
        audio: Optional audio input; when present it takes precedence over text.

    Returns:
        A Markdown-formatted string with summary, sentiment, recurring
        concerns, and suggested follow-up topics — or an error message when
        no usable input was provided.
    """
    if audio is not None:
        # Transcribe audio; Whisper pipeline returns {'text': ...}.
        asr_result = asr_pipeline(audio)
        transcript_text = asr_result['text']
    else:
        # Guard against None from a cleared Gradio textbox.
        transcript_text = transcript or ""

    if not transcript_text.strip():
        return "Please provide a transcript or an audio file."

    # Convert transcript to structured speaker turns.
    json_data = convert_to_json(transcript_text)

    # Prepare the prompt for summarization.
    prompt = (
        "You are a helpful assistant that summarizes psychotherapy sessions. "
        "The session is provided in JSON format with speaker turns. "
        "Summarize the key themes, emotional shifts, and patterns from this session. "
    )
    custom_focus = (custom_instruction or "").strip()
    if custom_focus:
        prompt += f" Additionally, {custom_focus}"
    # json.dumps (not str) so the prompt actually contains valid JSON, as it claims.
    prompt += "\n\nJSON data:\n" + json.dumps(json_data, ensure_ascii=False)

    # Summarize using the LLM.
    summary_output = summarizer(prompt, max_length=200, do_sample=False)
    summary = summary_output[0]['generated_text'].strip()

    # Sentiment analysis of the entire transcript.
    sentiment_results = sentiment_analyzer(transcript_text)
    main_sentiment = sentiment_results[0]['label']

    # Simple keyword-based recurring concerns; sorted for deterministic output
    # (a bare set would vary run to run).
    keywords_of_interest = {"anxiety", "depression", "relationship", "stress", "fear", "goals", "progress", "cognitive", "behavior"}
    recurring_concerns = sorted({word for word in transcript_text.lower().split() if word in keywords_of_interest})
    if not recurring_concerns:
        recurring_concerns_str = "No specific recurring concerns identified from the predefined list."
    else:
        recurring_concerns_str = "Recurring concerns include: " + ", ".join(recurring_concerns)

    # Suggest follow-up topics based on keywords appearing in the summary.
    follow_up_suggestions = []
    if "progress" in summary.lower():
        follow_up_suggestions.append("Explore client's perception of progress in more detail.")
    if "relationship" in summary.lower():
        follow_up_suggestions.append("Discuss client's relationship dynamics further.")
    if not follow_up_suggestions:
        follow_up_suggestions.append("Consider following up on the emotional themes identified in the summary.")
    follow_up_suggestions_str = " ".join(follow_up_suggestions)

    final_output = f"**Summary of Session:**\n{summary}\n\n**Overall Sentiment:** {main_sentiment}\n\n**{recurring_concerns_str}**\n\n**Suggested Follow-Up Topics:** {follow_up_suggestions_str}"

    return final_output

# Build Gradio UI
description = """# Psychotherapy Session Summarizer

This tool summarizes psychotherapy session transcripts (text or audio) into key themes, emotional shifts, and patterns.

**How to Use:**
- You may upload an audio file of the session or paste the text transcript.
- Optionally provide a custom focus or instruction (e.g., "Focus on how the client talks about their anxiety.").
- Click 'Summarize' to generate a summary along with identified concerns and suggested follow-ups.
"""

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        transcript_input = gr.Textbox(label="Session Transcript (Text)", lines=10, placeholder="Paste the session transcript here...")
        # type="filepath" hands analyze_session a plain path string, which the
        # ASR pipeline accepts; type="file" yields a tempfile wrapper object
        # that transformers cannot consume.
        # NOTE(review): `source=` is the Gradio 3.x kwarg (renamed `sources` in
        # 4.x) — confirm against the pinned gradio version.
        audio_input = gr.Audio(source="upload", type="filepath", label="Session Audio (Optional)")
    custom_instruction_input = gr.Textbox(label="Custom Instruction (Optional)", placeholder="e.g., Focus on anxiety and coping strategies.")
    summarize_button = gr.Button("Summarize")
    output_box = gr.Markdown()

    # Wire the button: (transcript, instruction, audio) -> Markdown report.
    summarize_button.click(fn=analyze_session, inputs=[transcript_input, custom_instruction_input, audio_input], outputs=output_box)

if __name__ == "__main__":
    demo.launch()