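# Scraped page header (Hugging Face file view), kept as comments so the file parses:
# author: hjawwad456
# commit: 8a64d7a "fix styles to show cross on crop images"
# (page links: raw, history, blame; file size 16.3 kB)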
import json
from typing import Generator, List
import gradio as gr
from openai import OpenAI
from crop_utils import get_image_crop
from prompts import (
get_chat_system_prompt,
get_live_event_system_prompt,
get_live_event_user_prompt,
get_street_interview_prompt,
get_street_interview_system_prompt,
)
from transcript import TranscriptProcessor
from utils import css, get_transcript_for_url, head
from utils import openai_tools as tools
from utils import setup_openai_key
# Module-level OpenAI client, constructed when this module is imported.
# NOTE(review): get_initial_analysis() and chat() each create their own local
# client, so nothing visible in this file uses this instance. It is also built
# before main() calls setup_openai_key(); if that helper is what provides the
# API key, constructing the client here may fail at import time. Confirm, and
# consider removing it.
client: OpenAI = OpenAI()
def get_initial_analysis(
    transcript_processor: TranscriptProcessor, cid, rsid, origin, ct, uid
) -> Generator[str, None, None]:
    """Stream the initial analysis of the transcript from OpenAI.

    The prompt pair depends on the call type. ``ct == "si"`` (street
    interview) uses the street-interview prompts. Any other value uses the
    live-event prompts, which also receive the speaker mapping and the
    transcript in the system prompt.

    Args:
        transcript_processor: Provides the transcript text
            (``get_transcript()``) and ``speaker_mapping``.
        cid, rsid, origin, ct, uid: Request parameters forwarded to the
            prompt builders. ``origin`` also chooses the link scheme:
            ``"http"`` when it contains ``"localhost"``, otherwise ``"https"``.

    Yields:
        The accumulated response text after each streamed chunk. If anything
        fails, the error is printed and a single fixed error message is
        yielded instead.
    """
    try:
        transcript = transcript_processor.get_transcript()
        speaker_mapping = transcript_processor.speaker_mapping
        client = OpenAI()
        link_start = "http" if "localhost" in origin else "https"

        if ct == "si":  # street interview
            user_prompt = get_street_interview_prompt(transcript, uid, rsid, link_start)
            system_prompt = get_street_interview_system_prompt(cid, rsid, origin, ct)
        else:
            system_prompt = get_live_event_system_prompt(
                cid, rsid, origin, ct, speaker_mapping, transcript
            )
            user_prompt = get_live_event_user_prompt(uid, link_start)

        # Both call types use the same model and settings; only the prompts differ.
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            stream=True,
            temperature=0.1,
        )

        collected_messages = []
        for chunk in completion:
            # Some stream chunks (e.g. usage-only or content-filter chunks)
            # have an empty ``choices`` list; skip them rather than raising.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                collected_messages.append(content)
                # Yield the accumulated message so far
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Error in initial analysis: {str(e)}")
        yield "An error occurred during initial analysis. Please check your API key and file path."
def _parse_tool_arguments(raw_arguments) -> dict:
    """Decode a tool call's JSON ``arguments`` string into a dict.

    Tool arguments are text generated by the model, so they are parsed with
    ``json.loads`` and never evaluated as Python code. Malformed JSON, or JSON
    that is not an object, yields an empty dict.
    """
    try:
        parsed = json.loads(raw_arguments or "{}")
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _stream_text(completion) -> Generator[str, None, None]:
    """Yield the accumulated text of a streamed chat completion after each chunk."""
    collected = []
    for chunk in completion:
        # Chunks without choices (e.g. usage-only chunks) carry no text.
        if chunk.choices and chunk.choices[0].delta.content:
            collected.append(chunk.choices[0].delta.content)
            yield "".join(collected)


def _respond_to_tool_call(
    client, messages, transcript_processor, cid, rsid, origin, uid
) -> Generator:
    """Resolve the model's tool call and yield what the chat should display.

    The request is re-issued without streaming so the tool call and its
    arguments arrive complete. Only the first tool call is handled.
    """
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        tools=tools,
        # Match the streamed request so the retry makes the same kind of choice.
        temperature=0.3,
    )
    assistant_message = response.choices[0].message
    if not assistant_message.tool_calls:
        # The retry chose not to call a tool: show its text instead of nothing.
        if assistant_message.content:
            yield assistant_message.content
        return

    # NOTE(review): if the model issues parallel tool calls, only the first is
    # answered; a follow-up request would then be rejected for the missing
    # tool responses. Consider disabling parallel tool calls.
    tool_call = assistant_message.tool_calls[0]
    tool_name = tool_call.function.name
    args = _parse_tool_arguments(tool_call.function.arguments)

    if tool_name == "get_image":
        # Return the image directly in the chat
        yield get_image_crop(cid, rsid, uid)
        return

    if tool_name == "correct_speaker_name_with_url":
        url = args.get("url")
        if url:
            transcript_processor.correct_speaker_mapping_with_agenda(url)
            corrected_speaker_mapping = transcript_processor.speaker_mapping
            tool_content = json.dumps(
                {
                    "speaker_mapping": f"Corrected Speaker Mapping... {corrected_speaker_mapping}"
                }
            )
        else:
            tool_content = "No URL Provided"
        # Report the tool result back to the model and stream its follow-up.
        messages.append(assistant_message)
        messages.append(
            {
                "role": "tool",
                "content": tool_content,
                "name": tool_name,
                "tool_call_id": tool_call.id,
            }
        )
        final_response = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            stream=True,
        )
        yield from _stream_text(final_response)
        return

    if tool_name == "correct_call_type":
        call_type = args.get("call_type")
        if call_type:
            # Stream the analysis again for the corrected call type.
            yield from get_initial_analysis(
                transcript_processor, cid, rsid, origin, call_type, uid
            )
            return

    # Unknown tool, or a call-type correction without a call type: show
    # whatever text the model produced instead of leaving the reply empty.
    yield assistant_message.content or "Sorry, I couldn't complete that request."


def chat(
    message: str,
    chat_history: List,
    transcript_processor: TranscriptProcessor,
    cid,
    rsid,
    origin,
    ct,
    uid,
):
    """Stream an assistant reply to ``message``, handling tool calls.

    Builds the chat system prompt from the transcript and speaker mapping,
    replays ``chat_history`` as (user, assistant) pairs (``None`` entries are
    skipped), and streams a gpt-4o completion offered the tools from
    ``utils.openai_tools``. If the model starts a tool call, the reply comes
    from the tool handler instead:

    - ``get_image``: yields the result of ``get_image_crop(cid, rsid, uid)``.
    - ``correct_speaker_name_with_url``: updates the speaker mapping from the
      URL, reports the result to the model, and streams its follow-up.
    - ``correct_call_type``: streams a fresh initial analysis for the new
      call type.

    Yields:
        The accumulated reply so far, or the tool's output. Any exception is
        printed with its traceback and replaced by a fixed apology message.
    """
    try:
        client = OpenAI()
        link_start = "http" if "localhost" in origin else "https"
        speaker_mapping = transcript_processor.speaker_mapping
        system_prompt = get_chat_system_prompt(
            cid=cid,
            rsid=rsid,
            origin=origin,
            ct=ct,
            speaker_mapping=speaker_mapping,
            transcript=transcript_processor.get_transcript(),
            link_start=link_start,
        )
        messages = [{"role": "system", "content": system_prompt}]
        for user_msg, assistant_msg in chat_history:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Add the current message
        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=tools,
            stream=True,
            temperature=0.3,
        )

        collected_messages = []
        for chunk in completion:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta.tool_calls:
                # Streamed tool-call arguments arrive in fragments, so resolve
                # the call from a complete response and stop streaming.
                yield from _respond_to_tool_call(
                    client, messages, transcript_processor, cid, rsid, origin, uid
                )
                return
            if delta.content is not None:
                collected_messages.append(delta.content)
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Unexpected error in chat: {str(e)}")
        import traceback

        print(f"Traceback: {traceback.format_exc()}")
        yield "Sorry, there was an error processing your request."
def create_chat_interface():
    """Create and configure the Gradio chat interface.

    Layout: a chatbot, a message textbox, an empty iframe (``#link-frame``),
    and ``gr.State`` holders for the transcript processor and the request
    parameters read from the page URL.

    On page load the query parameters are validated and the transcript(s)
    fetched (``on_app_load``). A "processing" placeholder is then shown, and
    the initial analysis is streamed into the chatbot. Submitting the textbox
    streams a reply from ``chat``.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(
        fill_height=True,
        fill_width=True,
        css=css,
        head=head,
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
        ),
    ) as demo:
        chatbot = gr.Chatbot(
            elem_id="chatbot_box",
            layout="bubble",
            show_label=False,
            show_share_button=False,
            show_copy_all_button=False,
            show_copy_button=False,
            render=True,
        )
        msg = gr.Textbox(elem_id="chatbot_textbox", show_label=False)
        transcript_processor_state = gr.State()  # maintain state of imp things
        call_id_state = gr.State()
        colab_id_state = gr.State()
        origin_state = gr.State()
        ct_state = gr.State()
        turl_state = gr.State()
        uid_state = gr.State()
        iframe_html = "<iframe id='link-frame'></iframe>"
        gr.HTML(value=iframe_html)  # Add iframe to the UI

        def respond(
            message: str,
            chat_history: List,
            transcript_processor,
            cid,
            rsid,
            origin,
            ct,
            uid,
        ):
            """Append the user's message and stream the assistant's reply.

            Yields ``("", chat_history)`` so the textbox is cleared and the
            last chat turn is updated in place as chunks arrive.
            """
            if not transcript_processor:
                bot_message = "Transcript processor not initialized."
                chat_history.append((message, bot_message))
                # This is a generator: the update must be yielded, because a
                # ``return value`` here never reaches Gradio.
                yield "", chat_history
                return
            chat_history.append((message, ""))
            for chunk in chat(
                message,
                chat_history[:-1],  # Exclude the current incomplete message
                transcript_processor,
                cid,
                rsid,
                origin,
                ct,
                uid,
            ):
                chat_history[-1] = (message, chunk)
                yield "", chat_history

        msg.submit(
            respond,
            [
                msg,
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            [msg, chatbot],
        )

        # Handle initial loading with streaming
        def on_app_load(request: gr.Request):
            """Validate query parameters and build the transcript processor.

            Returns eight values matching the ``demo.load`` outputs: the
            chatbot value, the processor, then cid, rsid, origin, ct, turl and
            uid. On failure the chatbot holds the error message and the other
            seven values are ``None``.
            """
            query = request.query_params
            cid = query.get("cid", None)
            rsid = query.get("rsid", None)
            origin = query.get("origin", None)
            ct = query.get("ct", None)
            turl = query.get("turl", None)
            uid = query.get("uid", None)
            pnames = query.get("pnames", None)
            empty_states = [None] * 7

            required_params = ["cid", "rsid", "origin", "ct", "turl", "uid"]
            missing_params = [
                param for param in required_params if query.get(param) is None
            ]
            # "rp" calls split ``pnames`` below, so it is required for them.
            if ct == "rp" and pnames is None:
                missing_params.append("pnames")
            if missing_params:
                error_message = (
                    f"Missing required parameters: {', '.join(missing_params)}"
                )
                return [[(None, error_message)], *empty_states]

            try:
                if ct == "rp":
                    # ``turl`` and ``pnames`` are comma-separated lists here.
                    transcript_urls = turl.split(",")
                    person_names = [
                        pname.replace("_", " ") for pname in pnames.split(",")
                    ]
                    transcript_data = []
                    # A separate loop name keeps ``turl`` (returned below) intact.
                    for transcript_url in transcript_urls:
                        print("Getting Transcript for URL")
                        transcript_data.append(get_transcript_for_url(transcript_url))
                    print("Now creating Processor")
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data,
                        call_type=ct,
                        person_names=person_names,
                    )
                else:
                    transcript_data = get_transcript_for_url(turl)
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data, call_type=ct
                    )
                # Empty assistant message; later steps fill it in.
                chatbot_value = [(None, "")]
                return [
                    chatbot_value,
                    transcript_processor,
                    cid,
                    rsid,
                    origin,
                    ct,
                    turl,
                    uid,
                ]
            except Exception as e:
                print(e)
                error_message = f"Error processing call_id {cid}: {str(e)}"
                return [[(None, error_message)], *empty_states]

        def display_processing_message(chatbot_value):
            """Show a "processing" placeholder before the analysis streams.

            If ``on_app_load`` already put an error message in the chatbot,
            it is kept so the user can see it.
            """
            if chatbot_value and chatbot_value[0][1]:
                return chatbot_value
            return [(None, "Video is being processed. Please wait for the results...")]

        def stream_initial_analysis(
            chatbot_value, transcript_processor, cid, rsid, origin, ct, uid
        ):
            """Stream the initial analysis into the first chatbot message."""
            if not transcript_processor:
                # Generator: yield the unchanged value instead of returning it.
                yield chatbot_value
                return
            try:
                for chunk in get_initial_analysis(
                    transcript_processor, cid, rsid, origin, ct, uid
                ):
                    # Update the existing message instead of creating a new one
                    chatbot_value[0] = (None, chunk)
                    yield chatbot_value
            except Exception as e:
                chatbot_value[0] = (None, f"Error during analysis: {str(e)}")
                yield chatbot_value

        demo.load(
            on_app_load,
            inputs=None,
            outputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                turl_state,
                uid_state,
            ],
        ).then(
            display_processing_message,
            inputs=[chatbot],
            outputs=[chatbot],
        ).then(
            stream_initial_analysis,
            inputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            outputs=[chatbot],
        )
    return demo
def main():
    """Entry point: configure the OpenAI key, build the UI, and launch it.

    A startup failure is printed and then re-raised, so the process still
    terminates with the original exception.
    """
    try:
        setup_openai_key()
        create_chat_interface().launch(share=True)
    except Exception as exc:
        print(f"Error starting application: {str(exc)}")
        raise
# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()