import functools
import logging
from typing import Tuple

import gradio as gr
import torch
from huggingface_hub import DatasetCard, ModelCard, dataset_info, model_info
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None

def load_model():
    """Load the model and tokenizer once at startup, onto GPU if available."""
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

@functools.lru_cache(maxsize=100)
def get_card_info(hub_id: str) -> Tuple[str, str]:
    """Return the card type ("model" or "dataset") and card text for a hub_id."""
    try:
        # Raises if the repo does not exist as a model
        model_info(hub_id)
        card = ModelCard.load(hub_id)
        return "model", card.text
    except Exception as e:
        logger.error(f"Error fetching model card for {hub_id}: {e}")
        try:
            # Fall back to treating the repo as a dataset
            dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            return "dataset", card.text
        except Exception as e:
            logger.error(f"Error fetching dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find model or dataset with id {hub_id}") from e

@functools.lru_cache(maxsize=100)
def generate_summary(card_text: str, card_type: str) -> str:
    """Generate a summary for the given card text."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)
    # Generate with light sampling and KV caching for speed
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )
    # Decode only the newly generated tokens, keeping special tokens so the
    # <CARD_SUMMARY> markers can be located
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)
    # Extract just the summary; str.split cannot raise IndexError here, so
    # check for the marker explicitly instead of catching one
    if "<CARD_SUMMARY>" in response:
        summary = response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()
    else:
        summary = response.strip()
    return summary

def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
    """Interface function for Gradio."""
    try:
        if hub_id:
            # Fetch and validate card type
            inferred_type, card_text = get_card_info(hub_id)
            if card_type and card_type != inferred_type:
                return f"Error: Provided card_type '{card_type}' doesn't match inferred type '{inferred_type}'"
            card_type = inferred_type
        elif content:
            if not card_type:
                return "Error: card_type must be provided when using direct content"
            card_text = content
        else:
            return "Error: Either hub_id or content must be provided"
        summary = generate_summary(card_text, card_type)
        return summary
    except Exception as e:
        return f"Error: {str(e)}"

# Create the Gradio interface
def create_interface():
    with gr.Blocks(title="Hub TLDR") as interface:
        gr.Markdown("# Hugging Face Hub TLDR Generator")
        gr.Markdown(
            "Generate concise summaries of model and dataset cards from the Hugging Face Hub."
        )
        with gr.Tab("Summarize by Hub ID"):
            hub_id_input = gr.Textbox(
                label="Hub ID",
                placeholder="e.g., openai-community/gpt2",
            )
            hub_id_type = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type (optional)",
                value="model",
            )
            hub_id_button = gr.Button("Generate Summary")
            hub_id_output = gr.Textbox(label="Summary")
            hub_id_button.click(
                fn=summarize,
                inputs=[hub_id_input, hub_id_type],
                outputs=hub_id_output,
            )
        with gr.Tab("Summarize Custom Content"):
            content_input = gr.Textbox(
                label="Content",
                placeholder="Paste your model or dataset card content here...",
                lines=10,
            )
            content_type = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type",
                value="model",
            )
            content_button = gr.Button("Generate Summary")
            content_output = gr.Textbox(label="Summary")
            content_button.click(
                # summarize takes hub_id first, so pass custom content by keyword
                fn=lambda content, card_type: summarize(content=content, card_type=card_type),
                inputs=[content_input, content_type],
                outputs=content_output,
            )
    return interface

if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")