Commit 287b917 (verified) · Parent: 043ea5c
Author: davanstrien (HF Staff)

Update app.py

Files changed (1): app.py (+11, -2)
app.py CHANGED
@@ -6,6 +6,7 @@ import logging
 from typing import Tuple, Literal
 import functools
 import spaces
+from cashews import cache
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -16,6 +17,8 @@ model = None
 tokenizer = None
 device = None
 
+cache.setup("mem://", size_limit="4gb")
+
 def load_model():
     global model, tokenizer, device
     logger.info("Loading model and tokenizer...")
@@ -48,8 +51,8 @@ def get_card_info(hub_id: str) -> Tuple[str, str]:
     raise ValueError(f"Could not find model or dataset with id {hub_id}")
 
 @spaces.GPU
-def generate_summary(card_text: str, card_type: str) -> str:
-    """Generate a summary for the given card text."""
+def _generate_summary_gpu(card_text: str, card_type: str) -> str:
+    """Internal function that runs on GPU."""
     # Determine prefix based on card type
     prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"
 
@@ -84,6 +87,11 @@ def generate_summary(card_text: str, card_type: str) -> str:
 
     return summary
 
+@cache(ttl="6h")
+def generate_summary(card_text: str, card_type: str) -> str:
+    """Cached wrapper for generate_summary."""
+    return _generate_summary_gpu(card_text, card_type)
+
 def summarize(hub_id: str = "", card_type: str = "model", content: str = "") -> str:
     """Interface function for Gradio."""
     try:
@@ -100,6 +108,7 @@ def summarize(hub_id: str = "", card_type: str = "model", content: str = "") ->
         else:
            return "Error: Either hub_id or content must be provided"
 
+        # Use the cached wrapper
        summary = generate_summary(card_text, card_type)
        return summary
 
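The shape of the change: the @spaces.GPU worker is renamed to _generate_summary_gpu, and a thin generate_summary wrapper carries the cache decorator, so a repeat request for the same card text and type is served from the in-memory cache without touching the GPU. Below is a minimal, self-contained sketch of that wrapper pattern; it uses functools.lru_cache from the standard library in place of cashews, and _expensive_gpu_summary is a hypothetical stand-in for the real model call:

import functools

def _expensive_gpu_summary(card_text: str, card_type: str) -> str:
    # Hypothetical stand-in for the @spaces.GPU-decorated worker; the real
    # app.py tokenizes the card text and runs generation on the GPU here.
    return f"summary of a {card_type} card ({len(card_text)} chars)"

@functools.lru_cache(maxsize=1024)  # app.py uses cashews' @cache(ttl="6h") in this role
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper: identical arguments reuse the memoized result."""
    return _expensive_gpu_summary(card_text, card_type)

generate_summary("some card text", "model")  # first call: computed
generate_summary("some card text", "model")  # second call: served from cache
print(generate_summary.cache_info())         # CacheInfo(hits=1, misses=1, ...)

Unlike lru_cache, the cashews setup in the commit also gives the cache a time-to-live (ttl="6h") and a size budget (size_limit="4gb" on the mem:// backend), so stale summaries expire and memory stays bounded.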