bluenevus committed on
Commit
1668d21
·
verified ·
1 Parent(s): ca79387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -36
app.py CHANGED
@@ -16,45 +16,48 @@ import logging
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
- # Initialize Gemini AI
20
- genai.configure(api_key='YOUR_GEMINI_API_KEY')
21
-
22
  # Set up device
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
 
25
- # Load Orpheus model
26
- print("Loading Orpheus model...")
27
- model_name = "canopylabs/orpheus-3b-0.1-ft"
28
-
29
- HF_TOKEN = "YOUR_HUGGINGFACE_TOKEN"
30
- login(token=HF_TOKEN)
31
-
32
- snapshot_download(
33
- repo_id=model_name,
34
- use_auth_token=HF_TOKEN,
35
- allow_patterns=[
36
- "config.json",
37
- "*.safetensors",
38
- "model.safetensors.index.json",
39
- ],
40
- ignore_patterns=[
41
- "optimizer.pt",
42
- "pytorch_model.bin",
43
- "training_args.bin",
44
- "scheduler.pt",
45
- "tokenizer.json",
46
- "tokenizer_config.json",
47
- "special_tokens_map.json",
48
- "vocab.json",
49
- "merges.txt",
50
- "tokenizer.*"
51
- ]
52
- )
53
-
54
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
55
- model.to(device)
56
- tokenizer = AutoTokenizer.from_pretrained(model_name)
57
- print(f"Orpheus model loaded to {device}")
 
 
 
 
 
 
58
 
59
  def generate_podcast_script(api_key, content, duration, num_hosts):
60
  genai.configure(api_key=api_key)
@@ -94,6 +97,7 @@ def generate_podcast_script(api_key, content, duration, num_hosts):
94
  return clean_text
95
 
96
  def text_to_speech(text, voice):
 
97
  inputs = tokenizer(text, return_tensors="pt").to(device)
98
  with torch.no_grad():
99
  output = model.generate(**inputs, max_new_tokens=256)
@@ -135,6 +139,10 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
135
  with gr.Blocks() as demo:
136
  gr.Markdown("# AI Podcast Generator")
137
 
 
 
 
 
138
  api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
139
 
140
  with gr.Row():
@@ -159,6 +167,8 @@ with gr.Blocks() as demo:
159
  render_btn = gr.Button("Render Podcast")
160
  audio_output = gr.Audio(label="Generated Podcast")
161
 
 
 
162
  def generate_script_wrapper(api_key, content, duration, num_hosts):
163
  return generate_podcast_script(api_key, content, duration, num_hosts)
164
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
 
 
19
  # Set up device
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # Initialize model and tokenizer as None
23
+ model = None
24
+ tokenizer = None
25
+
26
+ def load_model(hf_token):
27
+ global model, tokenizer
28
+
29
+ print("Loading Orpheus model...")
30
+ model_name = "canopylabs/orpheus-3b-0.1-ft"
31
+
32
+ login(token=hf_token)
33
+
34
+ snapshot_download(
35
+ repo_id=model_name,
36
+ use_auth_token=hf_token,
37
+ allow_patterns=[
38
+ "config.json",
39
+ "*.safetensors",
40
+ "model.safetensors.index.json",
41
+ ],
42
+ ignore_patterns=[
43
+ "optimizer.pt",
44
+ "pytorch_model.bin",
45
+ "training_args.bin",
46
+ "scheduler.pt",
47
+ "tokenizer.json",
48
+ "tokenizer_config.json",
49
+ "special_tokens_map.json",
50
+ "vocab.json",
51
+ "merges.txt",
52
+ "tokenizer.*"
53
+ ]
54
+ )
55
+
56
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
57
+ model.to(device)
58
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
59
+ print(f"Orpheus model loaded to {device}")
60
+ return "Model loaded successfully"
61
 
62
  def generate_podcast_script(api_key, content, duration, num_hosts):
63
  genai.configure(api_key=api_key)
 
97
  return clean_text
98
 
99
  def text_to_speech(text, voice):
100
+ global model, tokenizer
101
  inputs = tokenizer(text, return_tensors="pt").to(device)
102
  with torch.no_grad():
103
  output = model.generate(**inputs, max_new_tokens=256)
 
139
  with gr.Blocks() as demo:
140
  gr.Markdown("# AI Podcast Generator")
141
 
142
+ hf_token_input = gr.Textbox(label="Enter your Hugging Face API Token", type="password")
143
+ load_model_btn = gr.Button("Load Orpheus Model")
144
+ model_status = gr.Markdown("Model not loaded")
145
+
146
  api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
147
 
148
  with gr.Row():
 
167
  render_btn = gr.Button("Render Podcast")
168
  audio_output = gr.Audio(label="Generated Podcast")
169
 
170
+ load_model_btn.click(load_model, inputs=[hf_token_input], outputs=[model_status])
171
+
172
  def generate_script_wrapper(api_key, content, duration, num_hosts):
173
  return generate_podcast_script(api_key, content, duration, num_hosts)
174