|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
|
|
# Zero-shot classification pipeline, built once at module import and reused
# by classify_text for every request.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
|
|
|
def classify_text(sequence, candidate_labels, multi_label):
    """Score ``sequence`` against user-supplied labels with the zero-shot pipeline.

    Args:
        sequence: Text to classify.
        candidate_labels: Comma-separated label names, e.g.
            ``"finance, sports"``. Whitespace around each label is stripped;
            blank entries (stray or trailing commas) and repeated labels are
            ignored.
        multi_label: Passed through unchanged as the pipeline's
            ``multi_label`` argument.

    Returns:
        A dict mapping each label to its score, built from the ``labels`` and
        ``scores`` entries of the pipeline result.

    Raises:
        gr.Error: If ``candidate_labels`` contains no non-blank label.
    """
    stripped = (label.strip() for label in (candidate_labels or "").split(","))
    # dict.fromkeys drops repeats while keeping first-seen order; a repeated
    # label would otherwise be sent to the model twice and then collapse into
    # a single key in the returned dict.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        # Fail with a readable message instead of classifying against an
        # empty-string label.
        raise gr.Error("Please provide at least one candidate label.")

    results = classifier(sequence, labels, multi_label=multi_label)
    return dict(zip(results["labels"], results["scores"]))
|
|
|
|
|
# Sample rows for the UI, one per example: [text, comma-separated labels,
# multi-label flag]. Every sample uses single-label mode, so the flag is
# appended once here instead of being repeated in each row.
examples = [
    [sample_text, sample_labels, False]
    for sample_text, sample_labels in (
        (
            "The market has been incredibly volatile this year, with tech stocks leading the charge.",
            "finance, technology, sports, education",
        ),
        (
            "LeBron James scores 30 points to lead the Lakers to a Game 7 victory over the Celtics.",
            "sports, technology, finance, entertainment",
        ),
        (
            "Tesla's new battery technology could revolutionize the electric vehicle industry.",
            "technology, finance, environment, education",
        ),
        (
            "The local school district has announced a new STEM initiative to better prepare students for careers in technology.",
            "education, technology, politics, finance",
        ),
    )
]
|
|
|
|
|
# Gradio front end: two text boxes and a checkbox are passed, in order, to
# classify_text(sequence, candidate_labels, multi_label); the returned dict
# is rendered in a JSON component.
iface = gr.Interface(
    title="Zero-Shot Text Classification with BART",
    description="This model uses 'bart-large-mnli' for zero-shot text classification. Enter text to classify, provide candidate labels separated by commas, and select whether it's multi-label classification.",
    fn=classify_text,
    inputs=[
        gr.Textbox(label="Text to classify"),
        gr.Textbox(label="Candidate labels (comma-separated)"),
        gr.Checkbox(label="Multi-label classification", value=False),
    ],
    outputs=gr.JSON(label="Classification Results"),
    examples=examples,
    # Custom CSS that hides the page footer.
    css="footer{display:none !important}",
    allow_flagging="never",
)
|
|
|
# Start the web app only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    iface.launch()
|
|