Spaces:
Sleeping
Sleeping
import spaces | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
import torch | |
import logging | |
import sys | |
from accelerate import infer_auto_device_map, init_empty_weights | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s' | |
) | |
logger = logging.getLogger(__name__) | |
# Define the model name | |
model_name = "meta-llama/Llama-2-7b-hf" | |
try: | |
logger.info("Starting model initialization...") | |
# Check CUDA availability | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
logger.info(f"Using device: {device}") | |
# Configure PyTorch settings | |
if device == "cuda": | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
# Load tokenizer | |
logger.info("Loading tokenizer...") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
trust_remote_code=True, | |
use_auth_token=True | |
) | |
logger.info("Tokenizer loaded successfully") | |
# Load model with 8-bit quantization | |
logger.info("Loading model...") | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
torch_dtype=torch.float16 if device == "cuda" else torch.float32, | |
trust_remote_code=True, | |
use_auth_token=True, | |
load_in_8bit=True, | |
device_map="auto" | |
) | |
logger.info("Model loaded successfully") | |
# Create pipeline | |
logger.info("Creating generation pipeline...") | |
model_gen = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
max_new_tokens=256, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9, | |
repetition_penalty=1.1, | |
device_map="auto" | |
) | |
logger.info("Pipeline created successfully") | |
except Exception as e: | |
logger.error(f"Error during initialization: {str(e)}") | |
raise | |
# Configure system message | |
system_message = """You are AQuaBot, an AI assistant aware of environmental impact. | |
You help users with any topic while raising awareness about water consumption | |
in AI. Did you know that training GPT-3 consumed 5.4 million liters of water, | |
equivalent to the daily consumption of a city of 10,000 people?""" | |
# Llama 2 specific tokens | |
B_INST, E_INST = "[INST]", "[/INST]" | |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" | |
# Constants for water consumption calculation | |
WATER_PER_TOKEN = { | |
"input_training": 0.0000309, | |
"output_training": 0.0000309, | |
"input_inference": 0.05, | |
"output_inference": 0.05 | |
} | |
# Initialize variables | |
total_water_consumption = 0 | |
def calculate_tokens(text): | |
try: | |
return len(tokenizer.encode(text)) | |
except Exception as e: | |
logger.error(f"Error calculating tokens: {str(e)}") | |
return len(text.split()) + len(text) // 4 # Fallback to approximation | |
def calculate_water_consumption(text, is_input=True): | |
tokens = calculate_tokens(text) | |
if is_input: | |
return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"]) | |
return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"]) | |
def format_prompt(user_input, chat_history): | |
""" | |
Format the prompt according to Llama 2 specific style | |
""" | |
prompt = f"{B_INST}{B_SYS}{system_message}{E_SYS}" | |
if chat_history: | |
for user_msg, assistant_msg in chat_history: | |
prompt += f"{user_msg}{E_INST}{assistant_msg}{B_INST}" | |
prompt += f"{user_input}{E_INST}" | |
return prompt | |
def generate_response(user_input, chat_history): | |
try: | |
logger.info("Generating response for user input...") | |
global total_water_consumption | |
# Calculate water consumption for input | |
input_water_consumption = calculate_water_consumption(user_input, True) | |
total_water_consumption += input_water_consumption | |
# Format prompt for Llama 2 | |
prompt = format_prompt(user_input, chat_history) | |
logger.info("Generating model response...") | |
outputs = model_gen( | |
prompt, | |
max_new_tokens=256, | |
return_full_text=False, | |
pad_token_id=tokenizer.eos_token_id, | |
) | |
logger.info("Model response generated successfully") | |
assistant_response = outputs[0]['generated_text'].strip() | |
# Calculate water consumption for output | |
output_water_consumption = calculate_water_consumption(assistant_response, False) | |
total_water_consumption += output_water_consumption | |
# Update chat history | |
chat_history.append([user_input, assistant_response]) | |
# Prepare water consumption message | |
water_message = f""" | |
<div style="position: fixed; top: 20px; right: 20px; | |
background-color: white; padding: 15px; | |
border: 2px solid #ff0000; border-radius: 10px; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
<div style="color: #ff0000; font-size: 24px; font-weight: bold;"> | |
💧 {total_water_consumption:.4f} ml | |
</div> | |
<div style="color: #666; font-size: 14px;"> | |
Water Consumed | |
</div> | |
</div> | |
""" | |
return chat_history, water_message | |
except Exception as e: | |
logger.error(f"Error in generate_response: {str(e)}") | |
error_message = f"An error occurred: {str(e)}" | |
chat_history.append([user_input, error_message]) | |
return chat_history, show_water | |
# Create Gradio interface | |
try: | |
logger.info("Creating Gradio interface...") | |
with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo: | |
gr.HTML(""" | |
<div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;"> | |
<h1 style="color: #2d333a;">AQuaBot</h1> | |
<p style="color: #4a5568;"> | |
Welcome to AQuaBot - An AI assistant powered by Llama 2 that helps raise awareness | |
about water consumption in language models. | |
</p> | |
</div> | |
""") | |
chatbot = gr.Chatbot() | |
message = gr.Textbox( | |
placeholder="Type your message here...", | |
show_label=False | |
) | |
show_water = gr.HTML(f""" | |
<div style="position: fixed; top: 20px; right: 20px; | |
background-color: white; padding: 15px; | |
border: 2px solid #ff0000; border-radius: 10px; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
<div style="color: #ff0000; font-size: 24px; font-weight: bold;"> | |
💧 0.0000 ml | |
</div> | |
<div style="color: #666; font-size: 14px;"> | |
Water Consumed | |
</div> | |
</div> | |
""") | |
clear = gr.Button("Clear Chat") | |
# Add footer with citation, disclaimer, and credits | |
gr.HTML(""" | |
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px; | |
background-color: #f8f9fa; border-radius: 10px;"> | |
<div style="margin-bottom: 15px;"> | |
<p style="color: #666; font-size: 14px; font-style: italic;"> | |
Water consumption calculations are based on the study:<br> | |
Li, P. et al. (2023). Making AI Less Thirsty: Uncovering and Addressing the Secret Water | |
Footprint of AI Models. ArXiv Preprint, | |
<a href="https://arxiv.org/abs/2304.03271" target="_blank">https://arxiv.org/abs/2304.03271</a> | |
</p> | |
</div> | |
<div style="border-top: 1px solid #ddd; padding-top: 15px;"> | |
<p style="color: #666; font-size: 14px;"> | |
<strong>Model Information:</strong> This application uses Meta's Llama 2 (7B) model, | |
a state-of-the-art language model fine-tuned for chat interactions. Water consumption | |
calculations are based on the methodology from the cited paper. | |
</p> | |
</div> | |
<div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;"> | |
<p style="color: #666; font-size: 14px;"> | |
Created by Camilo Vega - AI Consultant<br> | |
<a href="https://github.com/vegadevs/aquabot" target="_blank">GitHub Repository</a> | |
</p> | |
</div> | |
</div> | |
""") | |
def submit(user_input, chat_history): | |
return generate_response(user_input, chat_history) | |
# Configure event handlers | |
message.submit(submit, [message, chatbot], [chatbot, show_water]) | |
clear.click( | |
lambda: ([], f""" | |
<div style="position: fixed; top: 20px; right: 20px; | |
background-color: white; padding: 15px; | |
border: 2px solid #ff0000; border-radius: 10px; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
<div style="color: #ff0000; font-size: 24px; font-weight: bold;"> | |
💧 0.0000 ml | |
</div> | |
<div style="color: #666; font-size: 14px;"> | |
Water Consumed | |
</div> | |
</div> | |
"""), | |
None, | |
[chatbot, show_water] | |
) | |
logger.info("Gradio interface created successfully") | |
# Launch the application | |
logger.info("Launching application...") | |
demo.launch() | |
except Exception as e: | |
logger.error(f"Error in Gradio interface creation: {str(e)}") | |
raise |