import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from termcolor import colored

# --- Model and Tokenizer Loading ---
MODEL_PATH = "01/medical_model_rl/final"
TOKENIZER_PATH = "01/medical_model_rl/final"

# Upper bound on the number of tokens generated for each reply. This is a
# budget for the answer alone; the prompt length does not eat into it.
MAX_NEW_TOKENS = 128

print("Loading model and tokenizer...")

# Bind the device before the try block so the name always exists, even when
# loading fails part-way through.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    model.resize_token_embeddings(len(tokenizer))

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.to(device)
    model.eval()
    print(colored("Model loaded successfully.", "green"))
except Exception as e:
    # Best-effort start-up: report the failure and leave both globals as None
    # so medical_chatbot() can answer with an error message instead of crashing.
    print(colored(f"Error loading model: {e}", "red"))
    model = None
    tokenizer = None


# --- Chatbot Inference Function ---
def medical_chatbot(message, history):
    """Generate a reply to *message* from the medical chatbot model.

    The question is wrapped in a ``Question: ...\\n\\nAnswer:`` prompt, a
    completion is sampled from the model, and only the newly generated text
    is returned.

    Args:
        message: The user's question.
        history: Previous conversation turns as passed by the chat UI.
            Not used; every question is answered independently.

    Returns:
        The generated answer with surrounding whitespace removed, or a fixed
        error message when the model is not loaded or generation fails.
    """
    # Explicit None checks: object truthiness is not a reliable "is loaded" test.
    if model is None or tokenizer is None:
        return "Error: Model is not loaded. Please check the console for errors."

    try:
        # Format the prompt
        full_prompt = f"Question: {message}\n\nAnswer:"

        # Tokenize the input
        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(device)
        # Number of prompt tokens; used below to separate prompt from completion.
        prompt_length = inputs["input_ids"].shape[1]

        # Generate a response
        with torch.no_grad():
            output_sequences = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # max_new_tokens rather than max_length: max_length also counts
                # the prompt, so a long question would leave little or no room
                # for the answer.
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the tokens that follow the prompt. Splitting the decoded
        # text on "Answer:" would drop part of the reply whenever the model
        # itself produces that marker.
        generated_tokens = output_sequences[0][prompt_length:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        return answer

    except Exception as e:
        # UI boundary: log the failure and give the user a friendly message.
        print(colored(f"An error occurred during inference: {e}", "red"))
        return "Sorry, I encountered an error. Please try again."
# --- Gradio UI ---
# Build the chat UI around `medical_chatbot` and start serving right away.
# NOTE: `chatbot_interface` is bound to whatever `.launch()` returns, not to
# the `gr.ChatInterface` object itself.
# NOTE(review): per the Gradio docs, share=True requests a publicly reachable
# link - confirm that this is intended for a medical assistant.
chatbot_interface = gr.ChatInterface(
    fn=medical_chatbot,
    theme="soft",
    title="Medical Chatbot",
    description="Ask any medical question, and the AI will try to answer.",
    # Sample questions offered to the user in the UI.
    examples=[
        ["What are the symptoms of diabetes?"],
        ["How does metformin work?"],
        ["What is the difference between a virus and a bacteria?"],
    ],
).launch(share=True)