import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import ModelCard, DatasetCard, model_info, dataset_info
import logging
from typing import Tuple
import functools
import spaces
from cashews import cache
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None

# cashews backend: in-process memory cache capped at 4 GB
cache.setup("mem://", size_limit="4gb")

def load_model():
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False
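
# The model is loaded once into module-level globals; on a CUDA machine the
# weights sit on the GPU, otherwise generation falls back to CPU.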

@functools.lru_cache(maxsize=100)
def get_card_info(hub_id: str) -> Tuple[str, str]:
    """Get card type and card text for a Hugging Face hub_id."""
    # Try the id as a model first; if that fails, fall back to a dataset.
    try:
        model_info(hub_id)  # raises if hub_id is not a model repo
        card = ModelCard.load(hub_id)
        return "model", card.text
    except Exception:
        logger.info(f"{hub_id} is not a model, trying dataset lookup")
        try:
            dataset_info(hub_id)  # raises if hub_id is not a dataset repo
            card = DatasetCard.load(hub_id)
            return "dataset", card.text
        except Exception as e:
            logger.error(f"Error fetching dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find model or dataset with id {hub_id}") from e
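
# Note: the model lookup runs first, so an id that exists as both a model and
# a dataset resolves as a model. lru_cache holds the last 100 lookups for the
# lifetime of the process.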

@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate with sampling at a low temperature for mild variety
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Decode only the newly generated tokens, keeping special tokens so the
    # <CARD_SUMMARY> markers survive for extraction
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # Slice out the text between the <CARD_SUMMARY> tags; str.split never
    # raises here, so a missing tag simply falls through to the raw response
    summary = response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()

    return summary

@cache(ttl="6h")
async def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper for _generate_summary_gpu.

    cashews caches async callables, so this wrapper is async even though the
    underlying GPU call is synchronous; results are kept for six hours.
    """
    return _generate_summary_gpu(card_text, card_type)
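
# Two caching layers are in play: lru_cache memoizes card lookups in-process,
# while cashews gives generated summaries a 6-hour TTL in the "mem://" backend
# configured above.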

async def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
    """Interface function for Gradio."""
    try:
        if hub_id:
            # Fetch the card and validate any requested type against it
            inferred_type, card_text = get_card_info(hub_id)
            if card_type and card_type != inferred_type:
                return f"Error: Provided card_type '{card_type}' doesn't match inferred type '{inferred_type}'"
            card_type = inferred_type
        elif content:
            if not card_type:
                return "Error: card_type must be provided when using direct content"
            card_text = content
        else:
            return "Error: Either hub_id or content must be provided"

        # Use the cached wrapper
        return await generate_summary(card_text, card_type)

    except Exception as e:
        return f"Error: {str(e)}"

# Create the Gradio interface
def create_interface():
    with gr.Blocks(title="Hub TLDR") as interface:
        gr.Markdown("# Hugging Face Hub TLDR Generator")
        gr.Markdown("Generate concise summaries of model and dataset cards from the Hugging Face Hub.")
        
        with gr.Tab("Summarize by Hub ID"):
            hub_id_input = gr.Textbox(
                label="Hub ID",
                placeholder="e.g., davanstrien/Smol-Hub-tldr"
            )
            hub_id_type = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type (optional, inferred from the Hub if unset)",
                value=None
            )
            hub_id_button = gr.Button("Generate Summary")
            hub_id_output = gr.Textbox(label="Summary")
            
            hub_id_button.click(
                fn=summarize,
                inputs=[hub_id_input, hub_id_type],
                outputs=hub_id_output
            )
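
            # Gradio passes inputs positionally: hub_id_input -> hub_id,
            # hub_id_type -> card_type (content defaults to "")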

        with gr.Tab("Summarize Custom Content"):
            content_input = gr.Textbox(
                label="Content",
                placeholder="Paste your model or dataset card content here...",
                lines=10
            )
            content_type = gr.Radio(
                choices=["model", "dataset"],
                label="Card Type",
                value="model"
            )
            content_button = gr.Button("Generate Summary")
            content_output = gr.Textbox(label="Summary")
            
            async def summarize_content(content: str, card_type: str) -> str:
                # Named async wrapper: a lambda would hand Gradio a coroutine
                # object instead of an awaitable handler
                return await summarize(content=content, card_type=card_type)

            content_button.click(
                fn=summarize_content,
                inputs=[content_input, content_type],
                outputs=content_output
            )

    return interface

if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")