Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            import gradio as gr | 
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
         | 
| 4 |  | 
| @@ -10,48 +10,41 @@ bnb_config = BitsAndBytesConfig( | |
| 10 | 
             
                bnb_4bit_compute_dtype=torch.bfloat16
         | 
| 11 | 
             
            )
         | 
| 12 |  | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 21 |  | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
                
         | 
| 31 | 
            -
                 | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
                    bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids            
         | 
| 36 | 
            -
                    chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)            
         | 
| 37 | 
            -
                    self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])            
         | 
| 38 | 
            -
                    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)     
         | 
| 39 | 
            -
                    return response      
         | 
| 40 | 
            -
              
         | 
| 41 | 
            -
            bot = ChatBot() 
         | 
| 42 | 
            -
              
         | 
| 43 | 
            -
            title = "👋🏻Welcome to Tonic's EZ Chat🚀"    
         | 
| 44 | 
            -
            description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on [Discord](https://discord.gg/fpEPNZGsbt) to build together."    
         | 
| 45 | 
            -
            examples = [["What is the boiling point of nitrogen?"]]    
         | 
| 46 | 
            -
              
         | 
| 47 | 
            -
            iface = gr.Interface(    
         | 
| 48 | 
            -
                fn=bot.predict,    
         | 
| 49 | 
            -
                title=title,    
         | 
| 50 | 
            -
                description=description,    
         | 
| 51 | 
            -
                examples=examples,    
         | 
| 52 | 
            -
                inputs="text",    
         | 
| 53 | 
            -
                outputs="text", 
         | 
| 54 | 
             
                theme="ParityError/Anime"
         | 
| 55 | 
            -
            ) | 
| 56 | 
            -
             | 
| 57 | 
            -
            iface.launch() | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
             
            import torch
         | 
| 3 | 
             
            from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
         | 
| 4 |  | 
|  | |
| 10 | 
             
                bnb_4bit_compute_dtype=torch.bfloat16
         | 
| 11 | 
             
            )
         | 
| 12 |  | 
| 13 | 
            +
            # Load the fine-tuned model "Tonic/mistralmed"
         | 
| 14 | 
            +
            model = AutoModelForCausalLM.from_pretrained("Tonic/mistralmed", quantization_config=bnb_config)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            tokenizer = AutoTokenizer.from_pretrained("Tonic/mistralmed", trust_remote_code=True)
         | 
| 17 | 
            +
            tokenizer.pad_token = tokenizer.eos_token
         | 
| 18 | 
            +
            tokenizer.padding_side = 'left'
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            class ChatBot:
         | 
| 21 | 
            +
                def __init__(self):
         | 
| 22 | 
            +
                    self.history = []
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def predict(self, input):
         | 
| 25 | 
            +
                    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
         | 
| 26 | 
            +
                    flat_history = [item for sublist in self.history for item in sublist]
         | 
| 27 | 
            +
                    flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
         | 
| 28 | 
            +
                    bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
         | 
| 29 | 
            +
                    chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
         | 
| 30 | 
            +
                    self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
         | 
| 31 | 
            +
                    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
         | 
| 32 | 
            +
                    return response
         | 
| 33 |  | 
| 34 | 
            +
            bot = ChatBot()
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            title = "👋🏻Welcome to Tonic's EZ Chat🚀"
         | 
| 37 | 
            +
            description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on [Discord](https://discord.gg/fpEPNZGsbt) to build together."
         | 
| 38 | 
            +
            examples = [["What is the boiling point of nitrogen"]]
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            iface = gr.Interface(
         | 
| 41 | 
            +
                fn=bot.predict,
         | 
| 42 | 
            +
                title=title,
         | 
| 43 | 
            +
                description=description,
         | 
| 44 | 
            +
                examples=examples,
         | 
| 45 | 
            +
                inputs="text",
         | 
| 46 | 
            +
                outputs="text",
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                theme="ParityError/Anime"
         | 
| 48 | 
            +
            )
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            iface.launch()
         | 
 
			
