# App code with the necessary imports and settings to use the Accelerate library effectively.
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import infer_auto_device_map, init_empty_weights

# Name of the model to load
model_name = "ai4bharat/Airavata"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# infer_auto_device_map expects a model instance, not a name string, so
# instantiate the architecture on the "meta" device (no weights allocated)
# and compute the device map from that.
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)
device_map = infer_auto_device_map(empty_model)

# Load the model with the device map
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    load_in_8bit=True,  # 8-bit quantization (requires bitsandbytes) for reduced memory usage
)
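
# Note: if computing the device map by hand is not required, transformers can
# delegate placement to Accelerate directly. A simpler equivalent, assuming the
# same 8-bit setup, would be:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name, device_map="auto", load_in_8bit=True
#     )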
# Define the inference function
def generate_text(prompt):
    # Move the inputs to the model's device before generating
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Without max_new_tokens, generate() falls back to a very short default length
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
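
# Quick sanity check outside the UI (illustrative prompt; uncomment to run):
# print(generate_text("भारत की राजधानी क्या है?"))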
# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Airavata Text Generation Model",
    description="This is the AI4Bharat Airavata model for text generation in Indic languages.",
)
# Launch the interface
interface.launch()
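
# The app can also be queried programmatically once it is running. A minimal
# sketch using the gradio_client package (the local URL below is an
# assumption, not part of this Space):
#
#     from gradio_client import Client
#     client = Client("http://127.0.0.1:7860")
#     print(client.predict("नमस्ते", api_name="/predict"))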