"""Gradio demo: zero-shot text classification with facebook/bart-large-mnli."""

import gradio as gr
from transformers import pipeline

# Initialize the zero-shot classification pipeline with the BART model.
# Created once at module import so every request reuses the same model.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


def _parse_labels(candidate_labels):
    """Split a comma-separated string into unique, non-empty, stripped labels.

    Order of first appearance is preserved, e.g. "a, b,, a ," -> ["a", "b"].
    """
    stripped = (label.strip() for label in candidate_labels.split(","))
    # Drop empty entries (trailing commas, ",,") and de-duplicate: duplicate
    # labels would otherwise be scored twice and then collapse into a single
    # key of the result dict, silently losing one score.
    return list(dict.fromkeys(label for label in stripped if label))


def classify_text(sequence, candidate_labels, multi_label):
    """Classify *sequence* against user-supplied candidate labels.

    Args:
        sequence: Text to classify.
        candidate_labels: Comma-separated label names, e.g. "sports, finance".
        multi_label: Forwarded to the pipeline's ``multi_label`` argument
            (set from the "Multi-label classification" checkbox).

    Returns:
        A dict mapping each label to its score, in the order the pipeline
        reports them.

    Raises:
        gr.Error: If the text is blank or no non-empty label was supplied,
            so the UI shows a readable message instead of a stack trace.
    """
    if not sequence or not sequence.strip():
        raise gr.Error("Please enter some text to classify.")

    labels = _parse_labels(candidate_labels or "")
    if not labels:
        raise gr.Error("Please provide at least one candidate label.")

    # Perform classification.
    results = classifier(sequence, labels, multi_label=bool(multi_label))

    # Pair each returned label with its score for the JSON output component.
    return dict(zip(results["labels"], results["scores"]))


# Examples for the interface: [text, comma-separated labels, multi-label flag].
examples = [
    [
        "The market has been incredibly volatile this year, with tech stocks leading the charge.",
        "finance, technology, sports, education",
        False,
    ],
    [
        "LeBron James scores 30 points to lead the Lakers to a Game 7 victory over the Celtics.",
        "sports, technology, finance, entertainment",
        False,
    ],
    [
        "Tesla's new battery technology could revolutionize the electric vehicle industry.",
        "technology, finance, environment, education",
        False,
    ],
    [
        "The local school district has announced a new STEM initiative to better prepare students for careers in technology.",
        "education, technology, politics, finance",
        False,
    ],
]

# Define Gradio interface components.
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(label="Text to classify"),
        gr.Textbox(label="Candidate labels (comma-separated)"),
        gr.Checkbox(label="Multi-label classification", value=False),
    ],
    outputs=gr.JSON(label="Classification Results"),
    title="Zero-Shot Text Classification with BART",
    description="This model uses 'bart-large-mnli' for zero-shot text classification. Enter text to classify, provide candidate labels separated by commas, and select whether it's multi-label classification.",
    examples=examples,
    # Hides the default Gradio footer.
    css="footer{display:none !important}",
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode` — confirm against the installed Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()