File size: 2,183 Bytes
3afd23e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import gradio as gr
from transformers import pipeline
# Hugging Face model checkpoint used for zero-shot classification.
MODEL_ID = "facebook/bart-large-mnli"

# Build the zero-shot classification pipeline once at import time so every
# call to classify_text reuses the same loaded model.
classifier = pipeline(task="zero-shot-classification", model=MODEL_ID)
def classify_text(sequence, candidate_labels, multi_label):
    """Score ``sequence`` against user-supplied labels with the zero-shot pipeline.

    Args:
        sequence: The text to classify.
        candidate_labels: Comma-separated label names as typed by the user,
            e.g. ``"finance, technology, sports"``. Surrounding whitespace,
            empty entries (from stray or trailing commas) and duplicates are
            ignored.
        multi_label: Forwarded (as a bool) to the pipeline's ``multi_label``
            option. Per the transformers docs, True scores each label
            independently instead of normalizing scores across labels.

    Returns:
        A dict mapping each label to its score, in the order the pipeline
        returns them (highest score first, per the transformers docs).

    Raises:
        gr.Error: If the text is blank or no non-empty label was supplied.
            Gradio shows this message to the user instead of a stack trace.
    """
    if not sequence or not sequence.strip():
        raise gr.Error("Please enter some text to classify.")

    # Split on commas, drop blanks left by stray/trailing commas, and
    # de-duplicate while keeping the user's order. Duplicate labels would
    # collide as keys in the result dict (silently losing a score) and would
    # split the probability mass between identical hypotheses.
    labels = list(dict.fromkeys(
        label.strip()
        for label in (candidate_labels or "").split(",")
        if label.strip()
    ))
    if not labels:
        raise gr.Error("Please provide at least one candidate label.")

    results = classifier(sequence, labels, multi_label=bool(multi_label))

    # Pair each returned label with its score; dicts keep insertion order,
    # so the pipeline's ranking is preserved in the JSON output.
    return dict(zip(results["labels"], results["scores"]))
# Sample (text, comma-separated candidate labels) pairs offered as clickable
# examples in the interface. Each row also sets the multi-label checkbox off.
_EXAMPLE_INPUTS = [
    (
        "The market has been incredibly volatile this year, with tech stocks leading the charge.",
        "finance, technology, sports, education",
    ),
    (
        "LeBron James scores 30 points to lead the Lakers to a Game 7 victory over the Celtics.",
        "sports, technology, finance, entertainment",
    ),
    (
        "Tesla's new battery technology could revolutionize the electric vehicle industry.",
        "technology, finance, environment, education",
    ),
    (
        "The local school district has announced a new STEM initiative to better prepare students for careers in technology.",
        "education, technology, politics, finance",
    ),
]

# One row per example, matching the interface inputs:
# [text, candidate labels, multi_label].
examples = [[text, labels, False] for text, labels in _EXAMPLE_INPUTS]
# Input widgets, in the positional order classify_text expects:
# (sequence, candidate_labels, multi_label).
input_components = [
    gr.Textbox(label="Text to classify"),
    gr.Textbox(label="Candidate labels (comma-separated)"),
    gr.Checkbox(label="Multi-label classification", value=False),
]

# The label -> score dict returned by classify_text is rendered as JSON.
output_component = gr.JSON(label="Classification Results")

iface = gr.Interface(
    fn=classify_text,
    inputs=input_components,
    outputs=output_component,
    title="Zero-Shot Text Classification with BART",
    description=(
        "This model uses 'bart-large-mnli' for zero-shot text classification. "
        "Enter text to classify, provide candidate labels separated by commas, "
        "and select whether it's multi-label classification."
    ),
    examples=examples,
    # Hide the page footer via custom CSS.
    css="footer{display:none !important}",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the Gradio version this Space pins.
    allow_flagging="never",
)


def main():
    """Start the Gradio server for the classification demo."""
    iface.launch()


if __name__ == "__main__":
    main()
|