ML_sup / app.py
ElPierrito's picture
Update app.py
51697f3 verified
import gradio as gr
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
import warnings
# Filtert die UserWarning von scikit-learn wegen der fehlenden Stichhaltigkeit von Klassen im Klassifikationsbericht heraus
warnings.filterwarnings("ignore", category=UserWarning)
# 1. Laden und Vorbereiten des Datensatzes (einmalig beim Start)
try:
dataset = load_dataset("banking77")
texts = dataset['train']['text'] + dataset['test']['text']
labels = dataset['train']['label'] + dataset['test']['label']
label_encoder = LabelEncoder()
numerical_labels = label_encoder.fit_transform(labels)
label_names = label_encoder.classes_
train_texts, test_texts, train_labels, test_labels = train_test_split(
texts, numerical_labels, test_size=0.2, random_state=42, stratify=numerical_labels
)
print("Datensatz 'banking77' erfolgreich geladen.")
except Exception as e:
print(f"Fehler beim Laden des Datensatzes: {e}")
label_names = ["Fehler beim Laden"]
pipeline = None
print("Modell wird nicht trainiert, da der Datensatz nicht geladen werden konnte.")
# 2. Trainieren des Modells (einmalig beim Start)
if 'pipeline' not in locals() or pipeline is None:
try:
pipeline = Pipeline([
('tfidf', TfidfVectorizer()),
('classifier', LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42))
])
pipeline.fit(train_texts, train_labels)
print("Modell erfolgreich trainiert.")
except Exception as e:
print(f"Fehler beim Trainieren des Modells: {e}")
pipeline = None
print("Modell konnte nicht trainiert werden.")
# 3. Funktion für die Vorhersage
def predict_intent(text):
if pipeline is not None and len(label_names) > 0:
try:
prediction = pipeline.predict([text])[0]
predicted_label = label_names[prediction] # Korrektur: String-Label abrufen
probabilities = pipeline.predict_proba([text])[0]
confidences = {label_names[i]: f"{probabilities[i]:.2f}" for i in range(len(label_names))}
return predicted_label, confidences
except Exception as e:
return "Fehler bei der Vorhersage", {"Fehler": f"Ein Fehler ist bei der Vorhersage aufgetreten: {e}"}
else:
return "Fehler", {"Fehler": "Modell nicht geladen oder trainiert."}
# 4. Erstellen der Gradio Interface
iface = gr.Interface(
fn=predict_intent,
inputs=gr.Textbox(label="Gib deine Kundenanfrage ein:", placeholder="z.B. Ich habe mein Passwort vergessen."),
outputs=[
gr.Label(label="Vorhergesagte Kundenintention:"),
gr.JSON(label="Konfidenzwerte:")
],
# title und description auf Deutsch
title="KI-gestützte Vorhersage von Kundenanfragen",
description="Diese Anwendung sagt die Absicht einer Kundenanfrage voraus. Gib eine Anfrage ein, um die vorhergesagte Kategorie und die Konfidenzwerte zu sehen. Das Modell wurde auf dem Datensatz Banking77 trainiert.",
examples=[
["Ich habe mein Passwort vergessen."],
["Wie kann ich Geld überweisen?"],
["Meine Karte ist verloren gegangen."],
["Was ist der aktuelle Zinssatz für ein Sparkonto?"]
],
css="""
.container {
margin: 0 auto;
max-width: 700px;
padding: 20px;
text-align: center;
}
.input_output_section {
display: flex;
flex-direction: column;
align-items: center;
margin-bottom: 20px;
}
.label {
font-weight: bold;
margin-bottom: 5px;
color: #4a5568; /* Dunkleres Grau für bessere Lesbarkeit */
}
.textbox {
border: 1px solid #cbd5e0; /* Etwas hellerer Rahmen */
border-radius: 0.375rem; /* Abgerundete Ecken gemäß Tailwind */
padding: 0.75rem;
width: 100%;
max-width: 400px; /* Begrenze die Breite des Textfelds */
margin-bottom: 1rem;
font-size: 1rem;
box-shadow: inset 0 2px 4px rgba(0,0,0,0.06); /* Subtiler Schatten */
transition: border-color 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Sanfte Übergänge */
}
.textbox:focus {
outline: none;
border-color: #3182ce; /* Blauer Fokus-Rand */
box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.16); /* Heller Fokus-Schatten */
}
.label_output {
font-size: 1.25rem;
font-weight: 600;
color: #2d3748; /* Noch dunkler für die Ausgabe */
margin-bottom: 1.5rem;
padding: 0.5rem;
border-radius: 0.375rem;
background-color: #edf2f7; /* Sehr helles Grau für Hintergrund der Ausgabe */
box-shadow: 0 1px 3px rgba(0,0,0,0.08); /* Sehr schwacher Schatten */
min-width: 200px; /* Mindestbreite für die Ausgabe */
text-align: center;
}
.json_output {
background-color: #f7fafc; /* Noch helleres Grau für JSON */
border: 1px solid #e2e8f0;
border-radius: 0.375rem;
padding: 1rem;
font-family: 'Menlo', monospace; /* Monospace-Schriftart für JSON */
font-size: 0.875rem;
line-height: 1.5rem;
overflow-x: auto; /* Horizontal scrollbar bei Überlauf */
max-width: 400px; /* Maximale Breite */
margin: 0 auto; /* Zentrieren */
}
.examples {
margin-top: 2rem;
text-align: center;
}
.example_item {
cursor: pointer;
padding: 0.5rem 1rem;
margin: 0.5rem;
background-color: #e2e8f0; /* Hellgrauer Hintergrund für Beispiele */
color: #2d3748;
border-radius: 0.375rem;
border: 1px solid #f0f4f8;
transition: background-color 0.2s ease-in-out, transform 0.1s ease;
display: inline-block; /* Damit die Breite automatisch angepasst wird */
font-size: 0.9rem;
box-shadow: 0 1px 2px rgba(0,0,0,0.05);
}
.example_item:hover {
background-color: #cbd5e0; /* Dunkleres Grau bei Hover */
transform: translateY(-2px); /* Leichter Hover-Effekt */
border-color: #a0aec0;
}
""",
)
# 5. Starten der Gradio App (wird beim Ausführen des Skripts aktiv)
iface.launch(share=True)