Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import logging
|
|
6 |
from typing import Tuple, Literal
|
7 |
import functools
|
8 |
import spaces
|
|
|
9 |
# Set up logging
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
logger = logging.getLogger(__name__)
|
@@ -16,6 +17,8 @@ model = None
|
|
16 |
tokenizer = None
|
17 |
device = None
|
18 |
|
|
|
|
|
19 |
def load_model():
|
20 |
global model, tokenizer, device
|
21 |
logger.info("Loading model and tokenizer...")
|
@@ -48,8 +51,8 @@ def get_card_info(hub_id: str) -> Tuple[str, str]:
|
|
48 |
raise ValueError(f"Could not find model or dataset with id {hub_id}")
|
49 |
|
50 |
@spaces.GPU
|
51 |
-
def
|
52 |
-
"""
|
53 |
# Determine prefix based on card type
|
54 |
prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
|
55 |
|
@@ -84,6 +87,11 @@ def generate_summary(card_text: str, card_type: str) -> str:
|
|
84 |
|
85 |
return summary
|
86 |
|
|
|
|
|
|
|
|
|
|
|
87 |
def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
|
88 |
"""Interface function for Gradio."""
|
89 |
try:
|
@@ -100,6 +108,7 @@ def summarize(hub_id: str = "", card_type: str = "model", content: str = "") ->
|
|
100 |
else:
|
101 |
return "Error: Either hub_id or content must be provided"
|
102 |
|
|
|
103 |
summary = generate_summary(card_text, card_type)
|
104 |
return summary
|
105 |
|
|
|
6 |
from typing import Tuple, Literal
|
7 |
import functools
|
8 |
import spaces
|
9 |
+
from cashews import cache
|
10 |
# Set up logging
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
logger = logging.getLogger(__name__)
|
|
|
17 |
tokenizer = None
|
18 |
device = None
|
19 |
|
20 |
+
cache.setup("mem://", size_limit="4gb")
|
21 |
+
|
22 |
def load_model():
|
23 |
global model, tokenizer, device
|
24 |
logger.info("Loading model and tokenizer...")
|
|
|
51 |
raise ValueError(f"Could not find model or dataset with id {hub_id}")
|
52 |
|
53 |
@spaces.GPU
|
54 |
+
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
|
55 |
+
"""Internal function that runs on GPU."""
|
56 |
# Determine prefix based on card type
|
57 |
prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
|
58 |
|
|
|
87 |
|
88 |
return summary
|
89 |
|
90 |
+
@cache(ttl="6h")
|
91 |
+
def generate_summary(card_text: str, card_type: str) -> str:
|
92 |
+
"""Cached wrapper for generate_summary."""
|
93 |
+
return _generate_summary_gpu(card_text, card_type)
|
94 |
+
|
95 |
def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
|
96 |
"""Interface function for Gradio."""
|
97 |
try:
|
|
|
108 |
else:
|
109 |
return "Error: Either hub_id or content must be provided"
|
110 |
|
111 |
+
# Use the cached wrapper
|
112 |
summary = generate_summary(card_text, card_type)
|
113 |
return summary
|
114 |
|