File size: 6,434 Bytes
78c807c
 
 
 
 
 
 
3cc8829
 
 
 
78c807c
 
 
 
 
 
 
 
 
 
 
 
51697f3
78c807c
 
 
 
51697f3
78c807c
 
 
 
 
 
 
 
 
51697f3
78c807c
 
 
51697f3
78c807c
 
 
218ec6c
3cc8829
 
51697f3
3cc8829
 
 
 
 
78c807c
 
 
 
 
 
3cc8829
78c807c
 
 
 
3cc8829
 
 
78c807c
 
 
 
 
3cc8829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78c807c
 
 
3cc8829
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
import warnings

# Filtert die UserWarning von scikit-learn wegen der fehlenden Stichhaltigkeit von Klassen im Klassifikationsbericht heraus
warnings.filterwarnings("ignore", category=UserWarning)

# 1. Laden und Vorbereiten des Datensatzes (einmalig beim Start)
try:
    dataset = load_dataset("banking77")
    texts = dataset['train']['text'] + dataset['test']['text']
    labels = dataset['train']['label'] + dataset['test']['label']
    label_encoder = LabelEncoder()
    numerical_labels = label_encoder.fit_transform(labels)
    label_names = label_encoder.classes_
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, numerical_labels, test_size=0.2, random_state=42, stratify=numerical_labels
    )
    print("Datensatz 'banking77' erfolgreich geladen.")
except Exception as e:
    print(f"Fehler beim Laden des Datensatzes: {e}")
    label_names = ["Fehler beim Laden"]
    pipeline = None
    print("Modell wird nicht trainiert, da der Datensatz nicht geladen werden konnte.")

# 2. Trainieren des Modells (einmalig beim Start)
if 'pipeline' not in locals() or pipeline is None:
    try:
        pipeline = Pipeline([
            ('tfidf', TfidfVectorizer()),
            ('classifier', LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42))
        ])
        pipeline.fit(train_texts, train_labels)
        print("Modell erfolgreich trainiert.")
    except Exception as e:
        print(f"Fehler beim Trainieren des Modells: {e}")
        pipeline = None
        print("Modell konnte nicht trainiert werden.")

# 3. Funktion für die Vorhersage
def predict_intent(text):
    if pipeline is not None and len(label_names) > 0:
        try:
            prediction = pipeline.predict([text])[0]
            predicted_label = label_names[prediction]  # Korrektur: String-Label abrufen
            probabilities = pipeline.predict_proba([text])[0]
            confidences = {label_names[i]: f"{probabilities[i]:.2f}" for i in range(len(label_names))}
            return predicted_label, confidences
        except Exception as e:
            return "Fehler bei der Vorhersage", {"Fehler": f"Ein Fehler ist bei der Vorhersage aufgetreten: {e}"}
    else:
        return "Fehler", {"Fehler": "Modell nicht geladen oder trainiert."}

# 4. Erstellen der Gradio Interface
iface = gr.Interface(
    fn=predict_intent,
    inputs=gr.Textbox(label="Gib deine Kundenanfrage ein:", placeholder="z.B. Ich habe mein Passwort vergessen."),
    outputs=[
        gr.Label(label="Vorhergesagte Kundenintention:"),
        gr.JSON(label="Konfidenzwerte:")
    ],
    #  title und description auf Deutsch
    title="KI-gestützte Vorhersage von Kundenanfragen",
    description="Diese Anwendung sagt die Absicht einer Kundenanfrage voraus.  Gib eine Anfrage ein, um die vorhergesagte Kategorie und die Konfidenzwerte zu sehen.  Das Modell wurde auf dem Datensatz Banking77 trainiert.",
    examples=[
        ["Ich habe mein Passwort vergessen."],
        ["Wie kann ich Geld überweisen?"],
        ["Meine Karte ist verloren gegangen."],
        ["Was ist der aktuelle Zinssatz für ein Sparkonto?"]
    ],
    css="""
    .container {
        margin: 0 auto;
        max-width: 700px;
        padding: 20px;
        text-align: center;
    }
    .input_output_section {
        display: flex;
        flex-direction: column;
        align-items: center;
        margin-bottom: 20px;
    }
    .label {
        font-weight: bold;
        margin-bottom: 5px;
        color: #4a5568; /* Dunkleres Grau für bessere Lesbarkeit */
    }
    .textbox {
        border: 1px solid #cbd5e0;  /* Etwas hellerer Rahmen */
        border-radius: 0.375rem; /* Abgerundete Ecken gemäß Tailwind */
        padding: 0.75rem;
        width: 100%;
        max-width: 400px; /* Begrenze die Breite des Textfelds */
        margin-bottom: 1rem;
        font-size: 1rem;
        box-shadow: inset 0 2px 4px rgba(0,0,0,0.06); /* Subtiler Schatten */
        transition: border-color 0.2s ease-in-out, box-shadow 0.2s ease-in-out; /* Sanfte Übergänge */
    }
    .textbox:focus {
        outline: none;
        border-color: #3182ce;  /* Blauer Fokus-Rand */
        box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.16);  /* Heller Fokus-Schatten */
    }
    .label_output {
        font-size: 1.25rem;
        font-weight: 600;
        color: #2d3748; /* Noch dunkler für die Ausgabe */
        margin-bottom: 1.5rem;
        padding: 0.5rem;
        border-radius: 0.375rem;
        background-color: #edf2f7; /* Sehr helles Grau für Hintergrund der Ausgabe */
        box-shadow: 0 1px 3px rgba(0,0,0,0.08); /* Sehr schwacher Schatten */
        min-width: 200px; /* Mindestbreite für die Ausgabe */
        text-align: center;
    }
    .json_output {
        background-color: #f7fafc; /* Noch helleres Grau für JSON */
        border: 1px solid #e2e8f0;
        border-radius: 0.375rem;
        padding: 1rem;
        font-family: 'Menlo', monospace; /* Monospace-Schriftart für JSON */
        font-size: 0.875rem;
        line-height: 1.5rem;
        overflow-x: auto; /* Horizontal scrollbar bei Überlauf */
        max-width: 400px;  /* Maximale Breite */
        margin: 0 auto; /* Zentrieren */
    }
    .examples {
        margin-top: 2rem;
        text-align: center;
    }
    .example_item {
        cursor: pointer;
        padding: 0.5rem 1rem;
        margin: 0.5rem;
        background-color: #e2e8f0; /* Hellgrauer Hintergrund für Beispiele */
        color: #2d3748;
        border-radius: 0.375rem;
        border: 1px solid #f0f4f8;
        transition: background-color 0.2s ease-in-out, transform 0.1s ease;
        display: inline-block; /* Damit die Breite automatisch angepasst wird */
        font-size: 0.9rem;
        box-shadow: 0 1px 2px rgba(0,0,0,0.05);
    }
    .example_item:hover {
        background-color: #cbd5e0; /* Dunkleres Grau bei Hover */
        transform: translateY(-2px); /* Leichter Hover-Effekt */
        border-color: #a0aec0;
    }
    """,
)

# 5. Starten der Gradio App (wird beim Ausführen des Skripts aktiv)
iface.launch(share=True)