import os
import random

import gradio as gr
from datasets import load_dataset

from inference import SentimentInference

# --- Initialize Sentiment Model ---
# Prefer a config.yaml that sits next to this script; fall back to one in the
# current working directory.
_SCRIPT_CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml")
CONFIG_PATH = _SCRIPT_CONFIG_PATH
if not os.path.exists(CONFIG_PATH):
    CONFIG_PATH = "config.yaml"
if not os.path.exists(CONFIG_PATH):
    raise FileNotFoundError(
        f"Configuration file not found. Tried {_SCRIPT_CONFIG_PATH} and {os.path.abspath(CONFIG_PATH)}. "
        f"Ensure 'config.yaml' exists and is accessible."
    )

print(f"Loading model with config: {CONFIG_PATH}")
try:
    sentiment_inferer = SentimentInference(config_path=CONFIG_PATH)
    print("Sentiment model loaded successfully.")
except Exception as e:
    # Keep the UI up even without a model; predict_sentiment checks for None
    # and reports the failure to the user.
    print(f"Error loading sentiment model: {e}")
    sentiment_inferer = None

# --- Load IMDB Dataset ---
print("Loading IMDB dataset for samples...")
try:
    imdb_dataset = load_dataset("imdb", split="test")
    print("IMDB dataset loaded successfully.")
except Exception as e:
    # Sampling is optional; load_random_imdb_sample handles imdb_dataset being None.
    print(f"Failed to load IMDB dataset: {e}. Sample loading will be disabled.")
    imdb_dataset = None


def load_random_imdb_sample():
    """Load a random sample text from the IMDB test split.

    Returns:
        A ``(text, label)`` tuple taken from one random row of the dataset.
        ``predict_sentiment`` displays a label of ``1`` as positive and any
        other value as negative. When the dataset is unavailable or empty, a
        message string and ``None`` are returned instead.
    """
    if imdb_dataset is None or len(imdb_dataset) == 0:
        return "IMDB dataset not available. Cannot load sample.", None
    sample = imdb_dataset[random.randrange(len(imdb_dataset))]
    return sample["text"], sample["label"]


def _load_sample_for_ui():
    """Load a random sample and also return its text for the ``sample_text`` state.

    Returns:
        ``(text, label, remembered_text)``. ``remembered_text`` is the sample
        text, or ``None`` when no real sample was loaded (label is ``None``).
    """
    text, label = load_random_imdb_sample()
    return text, label, (text if label is not None else None)


def _same_text(a, b):
    """Return True if *a* and *b* are equal ignoring whitespace differences.

    Whitespace-insensitive so that newline/spacing normalisation by the
    browser textbox does not make an untouched sample look edited.
    """
    return a.split() == b.split()


def predict_sentiment(text_input, true_label_state, sample_text_state=None):
    """Predict sentiment for ``text_input`` and format the result for display.

    Args:
        text_input: Review text from the input textbox.
        true_label_state: Integer label of the last loaded IMDB sample, or
            ``None``. ``1`` is shown as positive, any other value as negative.
        sample_text_state: Text of the last loaded IMDB sample, or ``None``.
            When given and the textbox no longer holds that text (the user
            edited it or picked an example), the stored label does not apply
            to the current input and is not shown.

    Returns:
        ``(result_text, new_true_label_state)``. The label state is cleared
        (``None``) after a successful prediction and passed through unchanged
        on validation or prediction errors.
    """
    if sentiment_inferer is None:
        return "Error: Sentiment model could not be loaded. Please check the logs.", true_label_state
    if not text_input or not text_input.strip():
        return "Please enter some text for analysis.", true_label_state

    try:
        prediction = sentiment_inferer.predict(text_input)
        sentiment = prediction['sentiment']

        # Only report the ground-truth label if it belongs to the text being analysed.
        label_applies = (
            true_label_state is not None
            and (sample_text_state is None or _same_text(text_input, sample_text_state))
        )

        result = f"Predicted Sentiment: {sentiment.capitalize()}"
        if label_applies:
            true_sentiment = "positive" if true_label_state == 1 else "negative"
            result += f"\nTrue IMDB Label: {true_sentiment.capitalize()}"
        return result, None  # Reset true label state after display
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error during prediction: {str(e)}", true_label_state


# --- Gradio Interface ---
with gr.Blocks() as demo:
    # Label of the most recently loaded IMDB sample (None when not applicable).
    true_label = gr.State()
    # Text of the most recently loaded IMDB sample, used to detect whether the
    # textbox still holds that sample when the user presses Analyze.
    sample_text = gr.State()

    gr.Markdown("## IMDb Sentiment Analyzer")
    gr.Markdown("Enter a movie review to classify its sentiment as Positive or Negative, or load a random sample from the IMDb dataset.")

    with gr.Row():
        input_textbox = gr.Textbox(lines=7, placeholder="Enter movie review here...", label="Movie Review", scale=3)
        output_text = gr.Text(label="Analysis Result", scale=1)

    with gr.Row():
        submit_button = gr.Button("Analyze Sentiment")
        load_sample_button = gr.Button("Load Random IMDB Sample")

    gr.Examples(
        examples=[
            ["This movie was absolutely fantastic! The acting was superb and the plot was gripping."],
            ["I was really disappointed with this film. It was boring and the story made no sense."],
            ["An average movie, had some good parts but overall quite forgettable."],
            ["Wow so I don't think I've ever seen a movie quite like that. The plot was... interesting, and the acting was, well, hmm."],
        ],
        inputs=input_textbox,
    )

    # Wire actions
    submit_button.click(
        fn=predict_sentiment,
        inputs=[input_textbox, true_label, sample_text],
        outputs=[output_text, true_label],
    )
    load_sample_button.click(
        fn=_load_sample_for_ui,
        inputs=None,
        outputs=[input_textbox, true_label, sample_text],
    )

if __name__ == '__main__':
    print("Launching Gradio interface...")
    demo.launch(share=False)