g0th committed on
Commit
eb36578
·
verified ·
1 Parent(s): 2c0a962

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -4,15 +4,16 @@ import json
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
  from ppt_parser import transfer_to_structure
 
7
 
8
- # ✅ Hugging Face token for gated model access
9
  hf_token = os.getenv("HF_TOKEN")
10
 
11
- # ✅ Load summarization pipeline
12
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
13
 
14
- # ✅ Load Mistral 7B Instruct model
15
- @gr.cache()
16
  def load_mistral():
17
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
18
  model = AutoModelForCausalLM.from_pretrained(
@@ -21,12 +22,11 @@ def load_mistral():
21
  device_map="auto",
22
  token=hf_token
23
  )
24
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
25
- return pipe
26
 
27
  mistral_pipe = load_mistral()
28
 
29
- # ✅ Global text buffer
30
  extracted_text = ""
31
 
32
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
 
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
  from ppt_parser import transfer_to_structure
7
+ from functools import lru_cache
8
 
9
+ # ✅ Get Hugging Face token from Space Secrets
10
  hf_token = os.getenv("HF_TOKEN")
11
 
12
+ # ✅ Load summarization model (BART)
13
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
14
 
15
+ # ✅ Load Mistral model (memoized to avoid reloading)
16
+ @lru_cache(maxsize=1)
17
  def load_mistral():
18
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
19
  model = AutoModelForCausalLM.from_pretrained(
 
22
  device_map="auto",
23
  token=hf_token
24
  )
25
+ return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
 
26
 
27
  mistral_pipe = load_mistral()
28
 
29
+ # ✅ Global variable to hold extracted content
30
  extracted_text = ""
31
 
32
  def extract_text_from_pptx_json(parsed_json: dict) -> str: