import os

# hf_transfer speeds up Hub downloads; the flag is read when huggingface_hub is
# imported, so it has to be set before that import happens.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import functools
import json
import logging
from typing import Tuple

import gradio as gr
import spaces
import torch
from cachetools.func import ttl_cache
from huggingface_hub import DatasetCard, ModelCard, dataset_info, model_info
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None
CACHE_TTL = 6 * 60 * 60 # 6 hours in seconds
CACHE_MAXSIZE = 100
def load_model():
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

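# get_card_info results are memoized with an LRU cache, so repeated lookups of the
# same hub_id avoid extra Hub API calls; generated summaries are cached separately
# with a TTL (see generate_summary below).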
@functools.lru_cache(maxsize=100)
def get_card_info(hub_id: str) -> Tuple[str, str]:
    """Get card information from a Hugging Face hub_id."""
    try:
        info = model_info(hub_id)
        card = ModelCard.load(hub_id)
        return "model", card.text
    except Exception as e:
        logger.error(f"Error fetching model card for {hub_id}: {e}")
        try:
            info = dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            return "dataset", card.text
        except Exception as e:
            logger.error(f"Error fetching dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find model or dataset with id {hub_id}")

@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate with optimized settings
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Extract and clean up the summary
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # Extract just the summary part
    try:
        summary = response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()
    except IndexError:
        summary = response.strip()

    return summary

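# The TTL cache wraps this thin function rather than _generate_summary_gpu itself, so
# repeated requests for the same card are answered from the cache without re-entering
# the @spaces.GPU-decorated function.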
@ttl_cache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper around _generate_summary_gpu with a TTL."""
    return _generate_summary_gpu(card_text, card_type)

def summarize(hub_id: str = "") -> str:
    """Interface function for Gradio. Returns a JSON string."""
    try:
        if hub_id:
            # Fetch the card and infer its type (model or dataset) automatically
            card_type, card_text = get_card_info(hub_id)
            # Use the cached wrapper so repeated requests skip generation
            summary = generate_summary(card_text, card_type)
            # json.dumps escapes quotes/newlines in the summary, keeping the output valid JSON
            return json.dumps({"summary": summary, "type": card_type, "hub_id": hub_id})
        return json.dumps({"error": "Error: Hub ID must be provided"})
    except Exception as e:
        return json.dumps({"error": str(e)})

def create_interface():
    interface = gr.Interface(
        fn=summarize,
        inputs=gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
        outputs=gr.JSON(label="Output"),
        title="Hugging Face Hub TLDR Generator",
        description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
    )
    return interface

if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")