Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_dataset | |
from sklearn.model_selection import train_test_split | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.pipeline import Pipeline | |
import warnings | |
# Filtert die UserWarning von scikit-learn wegen der fehlenden Stichhaltigkeit von Klassen im Klassifikationsbericht heraus | |
warnings.filterwarnings("ignore", category=UserWarning) | |
# 1. Laden und Vorbereiten des Datensatzes (einmalig beim Start) | |
try: | |
dataset = load_dataset("banking77") | |
texts = dataset['train']['text'] + dataset['test']['text'] | |
labels = dataset['train']['label'] + dataset['test']['label'] | |
label_encoder = LabelEncoder() | |
numerical_labels = label_encoder.fit_transform(labels) | |
label_names = label_encoder.classes_ | |
train_texts, test_texts, train_labels, test_labels = train_test_split( | |
texts, numerical_labels, test_size=0.2, random_state=42, stratify=numerical_labels | |
) | |
print("Datensatz 'banking77' erfolgreich geladen.") | |
except Exception as e: | |
print(f"Fehler beim Laden des Datensatzes: {e}") | |
label_names = ["Fehler beim Laden"] | |
pipeline = None | |
print("Modell wird nicht trainiert, da der Datensatz nicht geladen werden konnte.") | |
# 2. Trainieren des Modells (einmalig beim Start) | |
if 'pipeline' not in locals() or pipeline is None: | |
try: | |
pipeline = Pipeline([ | |
('tfidf', TfidfVectorizer()), | |
('classifier', LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42)) | |
]) | |
pipeline.fit(train_texts, train_labels) | |
print("Modell erfolgreich trainiert.") | |
except Exception as e: | |
print(f"Fehler beim Trainieren des Modells: {e}") | |
pipeline = None | |
print("Modell konnte nicht trainiert werden.") | |
# 3. Funktion für die Vorhersage | |
def predict_intent(text): | |
if pipeline is not None and len(label_names) > 0: | |
try: | |
prediction = pipeline.predict([text])[0] | |
predicted_label = label_names[prediction] # Korrektur: String-Label abrufen | |
probabilities = pipeline.predict_proba([text])[0] | |
confidences = {label_names[i]: f"{probabilities[i]:.2f}" for i in range(len(label_names))} | |
return predicted_label, confidences | |
except Exception as e: | |
return "Fehler bei der Vorhersage", {"Fehler": f"Ein Fehler ist bei der Vorhersage aufgetreten: {e}"} | |
else: | |
return "Fehler", {"Fehler": "Modell nicht geladen oder trainiert."} | |
# 4. Erstellen der Gradio Interface | |
iface = gr.Interface( | |
fn=predict_intent, | |
inputs=gr.Textbox(label="Gib deine Kundenanfrage ein:", placeholder="z.B. Ich habe mein Passwort vergessen."), | |
outputs=[ | |
gr.Label(label="Vorhergesagte Kundenintention:"), | |
gr.JSON(label="Konfidenzwerte:") | |
], | |
# title und description auf Deutsch | |
title="KI-gestützte Vorhersage von Kundenanfragen", | |
description="Diese Anwendung sagt die Absicht einer Kundenanfrage voraus. Gib eine Anfrage ein, um die vorhergesagte Kategorie und die Konfidenzwerte zu sehen. Das Modell wurde auf dem Datensatz Banking77 trainiert.", | |
examples=[ | |
["Ich habe mein Passwort vergessen."], | |
["Wie kann ich Geld überweisen?"], | |
["Meine Karte ist verloren gegangen."], | |
["Was ist der aktuelle Zinssatz für ein Sparkonto?"] | |
], | |
css=""" | |
.container { | |
margin: 0 auto; | |
max-width: 700px; | |
padding: 20px; | |
text-align: center; | |
} | |
.input_output_section { | |
display: flex; | |
flex-direction: column; | |
align-items: center; | |
margin-bottom: 20px; | |
} | |
.label { | |
font-weight: bold; | |
margin-bottom: 5px; | |
color: #4a5568; /* Dunkleres Grau für bessere Lesbarkeit */ | |
} | |
.textbox { | |
border: 1px solid #cbd5e0; /* Etwas hellerer Rahmen */ | |
border-radius: 0.375rem; /* Abgerundete Ecken gemäß Tailwind */ | |
padding: 0.75rem; | |
width: 100%; | |
max-width: 400px; /* Begrenze die Breite des Textfelds */ | |
margin-bottom: 1rem; | |
font-size: 1rem; | |
box-shadow: inset 0 2px 4px rgba(0,0,0,0.06); /* Subtiler Schatten */ | |
transition: border-color 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Sanfte Übergänge */ | |
} | |
.textbox:focus { | |
outline: none; | |
border-color: #3182ce; /* Blauer Fokus-Rand */ | |
box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.16); /* Heller Fokus-Schatten */ | |
} | |
.label_output { | |
font-size: 1.25rem; | |
font-weight: 600; | |
color: #2d3748; /* Noch dunkler für die Ausgabe */ | |
margin-bottom: 1.5rem; | |
padding: 0.5rem; | |
border-radius: 0.375rem; | |
background-color: #edf2f7; /* Sehr helles Grau für Hintergrund der Ausgabe */ | |
box-shadow: 0 1px 3px rgba(0,0,0,0.08); /* Sehr schwacher Schatten */ | |
min-width: 200px; /* Mindestbreite für die Ausgabe */ | |
text-align: center; | |
} | |
.json_output { | |
background-color: #f7fafc; /* Noch helleres Grau für JSON */ | |
border: 1px solid #e2e8f0; | |
border-radius: 0.375rem; | |
padding: 1rem; | |
font-family: 'Menlo', monospace; /* Monospace-Schriftart für JSON */ | |
font-size: 0.875rem; | |
line-height: 1.5rem; | |
overflow-x: auto; /* Horizontal scrollbar bei Überlauf */ | |
max-width: 400px; /* Maximale Breite */ | |
margin: 0 auto; /* Zentrieren */ | |
} | |
.examples { | |
margin-top: 2rem; | |
text-align: center; | |
} | |
.example_item { | |
cursor: pointer; | |
padding: 0.5rem 1rem; | |
margin: 0.5rem; | |
background-color: #e2e8f0; /* Hellgrauer Hintergrund für Beispiele */ | |
color: #2d3748; | |
border-radius: 0.375rem; | |
border: 1px solid #f0f4f8; | |
transition: background-color 0.2s ease-in-out, transform 0.1s ease; | |
display: inline-block; /* Damit die Breite automatisch angepasst wird */ | |
font-size: 0.9rem; | |
box-shadow: 0 1px 2px rgba(0,0,0,0.05); | |
} | |
.example_item:hover { | |
background-color: #cbd5e0; /* Dunkleres Grau bei Hover */ | |
transform: translateY(-2px); /* Leichter Hover-Effekt */ | |
border-color: #a0aec0; | |
} | |
""", | |
) | |
# 5. Starten der Gradio App (wird beim Ausführen des Skripts aktiv) | |
iface.launch(share=True) | |