Update app.py
app.py
CHANGED
@@ -486,5 +486,239 @@ def start_answer_generation(model_choice: str):
     thread.start()

     return f"Answer generation started using {model_choice}. Check progress."

+def get_generation_progress():
+    """
+    Get the current progress of answer generation.
+    """
+    # Always return a (status message, results table) pair, since this handler
+    # is wired to two Gradio outputs (generation_status and answers_table).
+    if not processing_status["is_processing"] and processing_status["progress"] == 0:
+        return "Not started", None
+
+    if processing_status["is_processing"]:
+        progress = processing_status["progress"]
+        total = processing_status["total"]
+        status_msg = f"Generating answers... {progress}/{total} completed"
+        return status_msg, None
+    else:
+        # Generation completed
+        if cached_answers:
+            # Create DataFrame with results
+            display_data = []
+            for task_id, data in cached_answers.items():
+                display_data.append({
+                    "Task ID": task_id,
+                    "Question": data["question"][:100] + "..." if len(data["question"]) > 100 else data["question"],
+                    "Generated Answer": data["answer"][:200] + "..." if len(data["answer"]) > 200 else data["answer"]
+                })
+
+            df = pd.DataFrame(display_data)
+            status_msg = f"Answer generation completed! {len(cached_answers)} answers ready for submission."
+            return status_msg, df
+        else:
+            return "Answer generation completed but no answers were generated.", None
+
+def submit_cached_answers(profile: gr.OAuthProfile | None):
+    """
+    Submit the cached answers to the evaluation API.
+    """
+    global cached_answers
+
+    if not profile:
+        return "Please log in to Hugging Face first.", None
+
+    if not cached_answers:
+        return "No cached answers available. Please generate answers first.", None
+
+    username = profile.username
+    space_id = os.getenv("SPACE_ID")
+    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "Unknown"
+
+    # Prepare submission payload
+    answers_payload = []
+    for task_id, data in cached_answers.items():
+        answers_payload.append({
+            "task_id": task_id,
+            "submitted_answer": data["answer"]
+        })
+
+    submission_data = {
+        "username": username.strip(),
+        "agent_code": agent_code,
+        "answers": answers_payload
+    }
+
+    # Submit to API
+    api_url = DEFAULT_API_URL
+    submit_url = f"{api_url}/submit"
+
+    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
+
+    try:
+        response = requests.post(submit_url, json=submission_data, timeout=60)
+        response.raise_for_status()
+        result_data = response.json()
+
+        final_status = (
+            f"Submission Successful!\n"
+            f"User: {result_data.get('username')}\n"
+            f"Overall Score: {result_data.get('score', 'N/A')}% "
+            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
+            f"Message: {result_data.get('message', 'No message received.')}"
+        )
+
+        # Create results DataFrame
+        results_log = []
+        for task_id, data in cached_answers.items():
+            results_log.append({
+                "Task ID": task_id,
+                "Question": data["question"],
+                "Submitted Answer": data["answer"]
+            })
+
+        results_df = pd.DataFrame(results_log)
+        return final_status, results_df
+
+    except requests.exceptions.HTTPError as e:
+        error_detail = f"Server responded with status {e.response.status_code}."
+        try:
+            error_json = e.response.json()
+            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
+        except Exception:
+            error_detail += f" Response: {e.response.text[:500]}"
+        return f"Submission Failed: {error_detail}", None
+
+    except requests.exceptions.Timeout:
+        return "Submission Failed: The request timed out.", None
+
+    except Exception as e:
+        return f"Submission Failed: {e}", None

-def
+def clear_cache():
+    """
+    Clear all cached data.
+    """
+    global cached_answers, cached_questions, processing_status
+    cached_answers = {}
+    cached_questions = []
+    processing_status = {"is_processing": False, "progress": 0, "total": 0}
+    return "Cache cleared successfully.", None
+
+def test_media_processing(image_files, audio_files, question):
+    """
+    Test the media processing functionality with uploaded files.
+    """
+    if not question:
+        question = "What can you tell me about the uploaded media?"
+
+    agent = IntelligentAgent(debug=True)
+
+    # Gather the uploaded files' paths into lists (None when nothing was uploaded)
+    image_paths = [img.name for img in image_files] if image_files else None
+    audio_paths = [aud.name for aud in audio_files] if audio_files else None
+
+    try:
+        result = agent(question, image_files=image_paths, audio_files=audio_paths)
+        return result
+    except Exception as e:
+        return f"Error processing media: {e}"
+
+# --- Enhanced Gradio Interface ---
+with gr.Blocks(title="Intelligent Agent with Media Processing") as demo:
+    gr.Markdown("# Intelligent Agent with Conditional Search and Media Processing")
+    gr.Markdown("This agent can process images and audio files, and it uses an LLM to decide when search is needed, optimizing for both accuracy and efficiency.")
+
+    with gr.Row():
+        gr.LoginButton()
+        clear_btn = gr.Button("Clear Cache", variant="secondary")
+
+    with gr.Tab("Media Processing Test"):
+        gr.Markdown("### Test Image and Audio Processing")
+
+        with gr.Row():
+            with gr.Column():
+                image_upload = gr.File(
+                    label="Upload Images",
+                    file_types=["image"],
+                    file_count="multiple"
+                )
+                audio_upload = gr.File(
+                    label="Upload Audio Files",
+                    file_types=["audio"],
+                    file_count="multiple"
+                )
+
+            with gr.Column():
+                test_question = gr.Textbox(
+                    label="Question about the media",
+                    placeholder="What can you tell me about these files?",
+                    lines=3
+                )
+                test_btn = gr.Button("Process Media", variant="primary")
+
+        test_output = gr.Textbox(
+            label="Processing Result",
+            lines=10,
+            interactive=False
+        )
+
+        test_btn.click(
+            fn=test_media_processing,
+            inputs=[image_upload, audio_upload, test_question],
+            outputs=test_output
+        )
+
+    with gr.Tab("Step 1: Fetch Questions"):
+        gr.Markdown("### Fetch Questions from API")
+        fetch_btn = gr.Button("Fetch Questions", variant="primary")
+        fetch_status = gr.Textbox(label="Fetch Status", lines=2, interactive=False)
+        questions_table = gr.DataFrame(label="Available Questions", wrap=True)
+
+        fetch_btn.click(
+            fn=fetch_questions,
+            outputs=[fetch_status, questions_table]
+        )
+
+    with gr.Tab("Step 2: Generate Answers"):
+        gr.Markdown("### Generate Answers with Intelligent Search Decision")
+
+        with gr.Row():
+            model_choice = gr.Dropdown(
+                choices=["Llama 3.1 8B", "Mistral 7B"],
+                value="Llama 3.1 8B",
+                label="Select Model"
+            )
+            generate_btn = gr.Button("Start Answer Generation", variant="primary")
+            refresh_btn = gr.Button("Refresh Progress", variant="secondary")
+
+        generation_status = gr.Textbox(label="Generation Status", lines=2, interactive=False)
+        answers_table = gr.DataFrame(label="Generated Answers", wrap=True)
+
+        generate_btn.click(
+            fn=start_answer_generation,
+            inputs=[model_choice],
+            outputs=generation_status
+        )
+
+        refresh_btn.click(
+            fn=get_generation_progress,
+            outputs=[generation_status, answers_table]
+        )
+
+    with gr.Tab("Step 3: Submit Results"):
+        gr.Markdown("### Submit Generated Answers")
+        submit_btn = gr.Button("Submit Answers", variant="primary")
+        submit_status = gr.Textbox(label="Submission Status", lines=4, interactive=False)
+        results_table = gr.DataFrame(label="Submission Results", wrap=True)
+
+        submit_btn.click(
+            fn=submit_cached_answers,
+            outputs=[submit_status, results_table]
+        )
+
+    # Clear cache functionality
+    clear_btn.click(
+        fn=clear_cache,
+        outputs=[fetch_status, questions_table]
+    )
+
+if __name__ == "__main__":
+    demo.launch()
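Design note: the Step 2 tab works because the generation thread and the "Refresh Progress" handler share one mutable processing_status dict, so the handler can report progress without blocking the UI. A minimal standalone sketch of that polling pattern, assuming only the standard library (worker, poll, and the task IDs below are illustrative names, not part of app.py):

import threading
import time

# Shared state in the same shape as app.py's processing_status dict.
status = {"is_processing": False, "progress": 0, "total": 0}
results = {}

def worker(task_ids):
    # Stand-in for the answer-generation thread: update shared state as work completes.
    status.update(is_processing=True, progress=0, total=len(task_ids))
    for i, task_id in enumerate(task_ids, start=1):
        time.sleep(0.1)  # placeholder for a slow LLM call
        results[task_id] = f"answer for {task_id}"
        status["progress"] = i
    status["is_processing"] = False

def poll():
    # Same contract as get_generation_progress above: always a (message, payload) pair.
    if status["is_processing"]:
        return f"Generating... {status['progress']}/{status['total']} completed", None
    return "Done", dict(results)

thread = threading.Thread(target=worker, args=(["t1", "t2", "t3"],), daemon=True)
thread.start()
while thread.is_alive():
    print(poll()[0])
    time.sleep(0.2)
print(poll())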