# app.py
import os
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import spacy
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain_openai import ChatOpenAI
from transformers import pipeline
### 1. Translator ###
# ChatOpenAI reads the OPENAI_API_KEY environment variable.
chat = ChatOpenAI()

class TextTranslator(BaseModel):
    output: str = Field(description="Translated output")

output_parser = PydanticOutputParser(pydantic_object=TextTranslator)
format_instructions = output_parser.get_format_instructions()

def text_translator(input_text: str, language: str) -> str:
    template = """Enter the text that you want to translate:
    {input_text}, and enter the language that you want it to translate to {language}. {format_instructions}"""
    human_prompt = HumanMessagePromptTemplate.from_template(template)
    prompt = ChatPromptTemplate.from_messages([human_prompt]).format_prompt(
        input_text=input_text, language=language, format_instructions=format_instructions)
    response = chat.invoke(prompt.to_messages())
    return output_parser.parse(response.content).output
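
# Illustrative call (requires a valid OpenAI key; the exact wording of the output depends on the model):
#   text_translator("Good morning", "Spanish")  ->  "Buenos días"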
### 2. Sentiment ###
# Multilingual Twitter sentiment model.
sentiment_model = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment")

def sentiment_analysis(message, history):
    result = sentiment_model(message)
    return f"Sentiment: {result[0]['label']} (Probability: {result[0]['score']:.2f})"
### 3. Financial Analyst ###
nlp = spacy.load('en_core_web_sm')
nlp.add_pipe('sentencizer')

def split_in_sentences(text):
    return [str(sent).strip() for sent in nlp(text).sents]

def make_spans(text, results):
    labels = [r['label'] for r in results]
    return list(zip(split_in_sentences(text), labels))
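
# The (sentence, label) pairs returned by make_spans are the tuple format gr.HighlightedText accepts directly.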
auth_token = os.environ.get("HF_Token")

asr = pipeline("automatic-speech-recognition", "facebook/wav2vec2-base-960h")
def speech_to_text(audio):
    return asr(audio)["text"]

summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
def summarize_text(text):
    return summarizer(text)[0]['summary_text']
fin_model = pipeline("sentiment-analysis", model='yiyanghkust/finbert-tone')
def text_to_sentiment(text):
    return fin_model(text)[0]["label"]

def fin_ner(text):
    # gr.load replaces the deprecated gr.Interface.load for Hub-hosted models.
    ner = gr.load("dslim/bert-base-NER", src='models', hf_token=auth_token)
    return ner(text)

def fin_ext(text):
    return make_spans(text, fin_model(split_in_sentences(text)))

def fls(text):
    model = pipeline("text-classification", model="demo-org/finbert_fls", tokenizer="demo-org/finbert_fls", use_auth_token=auth_token)
    return make_spans(text, model(split_in_sentences(text)))
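
# fls() flags forward-looking statements: every sentence is classified by the FinBERT FLS model,
# and the resulting spans render in gr.HighlightedText just like the tone spans above.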
### 4. Personal Info Detection ###
# Load the hosted PII model once at startup instead of on every request.
pii_model = gr.load("models/iiiorg/piiranha-v1-detect-personal-information")

def detect_personal_info(text):
    return pii_model(text)
### 5. Customer Churn ###
script_dir = os.path.dirname(os.path.abspath(__file__))
pipeline_path = os.path.join(script_dir, 'toolkit', 'pipeline.joblib')
model_path = os.path.join(script_dir, 'toolkit', 'Random Forest Classifier.joblib')
pipeline_model = joblib.load(pipeline_path)
model = joblib.load(model_path)

def calculate_total_charges(tenure, monthly_charges):
    return tenure * monthly_charges
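
# TotalCharges is not asked for in the UI; it is approximated as tenure * MonthlyCharges so that
# predict() can supply every column the preprocessing pipeline expects.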
def predict(SeniorCitizen, Partner, Dependents, tenure,
            InternetService, OnlineSecurity, OnlineBackup, DeviceProtection, TechSupport,
            StreamingTV, StreamingMovies, Contract, PaperlessBilling, PaymentMethod,
            MonthlyCharges):
    TotalCharges = calculate_total_charges(tenure, MonthlyCharges)
    input_df = pd.DataFrame({
        'SeniorCitizen': [SeniorCitizen], 'Partner': [Partner], 'Dependents': [Dependents],
        'tenure': [tenure], 'InternetService': [InternetService], 'OnlineSecurity': [OnlineSecurity],
        'OnlineBackup': [OnlineBackup], 'DeviceProtection': [DeviceProtection], 'TechSupport': [TechSupport],
        'StreamingTV': [StreamingTV], 'StreamingMovies': [StreamingMovies], 'Contract': [Contract],
        'PaperlessBilling': [PaperlessBilling], 'PaymentMethod': [PaymentMethod],
        'MonthlyCharges': [MonthlyCharges], 'TotalCharges': [TotalCharges]
    })
    # Transform with the fitted preprocessing pipeline, then rebuild the column names:
    # numeric columns first, followed by the one-hot encoded categorical columns.
    X_processed = pipeline_model.transform(input_df)
    cat_encoder = pipeline_model.named_steps['preprocessor'].named_transformers_['cat'].named_steps['onehot']
    feature_names = [*input_df.select_dtypes(exclude='object').columns, *cat_encoder.get_feature_names_out()]
    final_df = pd.DataFrame(X_processed, columns=feature_names)
    pred_probs = model.predict_proba(final_df)[0]
    return {
        "Prediction: CHURN 🔴": pred_probs[1],
        "Prediction: STAY ✅": pred_probs[0]
    }
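
# The returned dict maps class labels to probabilities, which gr.Label displays with per-class confidences.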
### COMBINED UI ###
with gr.Blocks() as demo:
    with gr.Tab("Translator"):
        gr.Markdown("## Translator")
        input_text = gr.Textbox(label="Text to Translate")
        language = gr.Textbox(label="Target Language")
        output = gr.Textbox(label="Translated Text")
        gr.Button("Translate").click(text_translator, inputs=[input_text, language], outputs=output)

    with gr.Tab("Sentiment"):
        gr.Markdown("## Sentiment Analysis")
        gr.ChatInterface(sentiment_analysis, type="messages")

    with gr.Tab("Financial Analyst"):
        gr.Markdown("## Financial Analyst")
        audio = gr.Audio(sources=["microphone"], type="filepath")
        text_input = gr.Textbox()
        summary = gr.Textbox()
        tone_label = gr.Label()
        gr.Button("Speech to Text").click(speech_to_text, inputs=audio, outputs=text_input)
        gr.Button("Summarize").click(summarize_text, inputs=text_input, outputs=summary)
        gr.Button("Classify Tone").click(text_to_sentiment, inputs=summary, outputs=tone_label)
        tone_spans = gr.HighlightedText(label="Tone")
        fls_spans = gr.HighlightedText(label="Forward-Looking")
        ner_output = gr.HighlightedText(label="Entities")
        # Chain the two analyses with .then() so both highlighted views are populated.
        gr.Button("Analyze All").click(fn=fin_ext, inputs=text_input, outputs=tone_spans).then(fls, inputs=text_input, outputs=fls_spans)
        gr.Button("Entities").click(fin_ner, inputs=text_input, outputs=ner_output)

    with gr.Tab("Personal Info Detector"):
        gr.Markdown("## Detect Personal Info")
        pi_input = gr.Textbox()
        pi_output = gr.HighlightedText()
        gr.Button("Detect").click(detect_personal_info, inputs=pi_input, outputs=pi_output)

    with gr.Tab("Customer Churn"):
        gr.Markdown("## Customer Churn Prediction")
        inputs = [
            gr.Radio(["Yes", "No"], label="SeniorCitizen"),
            gr.Radio(["Yes", "No"], label="Partner"),
            gr.Radio(["No", "Yes"], label="Dependents"),
            gr.Slider(1, 73, step=1, label="Tenure"),
            gr.Radio(["DSL", "Fiber optic", "No Internet"], label="InternetService"),
            gr.Radio(["No", "Yes"], label="OnlineSecurity"),
            gr.Radio(["No", "Yes"], label="OnlineBackup"),
            gr.Radio(["No", "Yes"], label="DeviceProtection"),
            gr.Radio(["No", "Yes"], label="TechSupport"),
            gr.Radio(["No", "Yes"], label="StreamingTV"),
            gr.Radio(["No", "Yes"], label="StreamingMovies"),
            gr.Radio(["Month-to-month", "One year", "Two year"], label="Contract"),
            gr.Radio(["Yes", "No"], label="PaperlessBilling"),
            gr.Radio(["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"], label="PaymentMethod"),
            gr.Slider(18.40, 118.65, label="MonthlyCharges")
        ]
        churn_output = gr.Label()
        gr.Button("Predict").click(predict, inputs=inputs, outputs=churn_output)

if __name__ == "__main__":
    demo.launch()