ASG Models committed on
Commit e6bbc72 · verified · 1 Parent(s): 76d6781

Update app.py

Files changed (1)
  1. app.py +25 -2
app.py CHANGED
@@ -7,6 +7,29 @@ import requests
 from genai_chat_ai import AI,create_chat_session
 api_key = os.environ.get("Id_mode_vits")
 headers = {"Authorization": f"Bearer {api_key}"}
+from transformers import pipeline
+from transformers import AutoTokenizer,VitsModel
+import torch
+models = {}
+tokenizer = AutoTokenizer.from_pretrained("asg2024/vits-ar-sa-huba",token=api_key)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+def get_model(name_model):
+    global models
+    if name_model not in models:
+        models[name_model] = VitsModel.from_pretrained(name_model,token=api_key).to(device)
+    return models[name_model]
+
+def genrate_speech(text,name_model):
+    inputs = tokenizer(text,return_tensors="pt")
+    model = get_model(name_model)
+    with torch.no_grad():
+        wav = model(
+            input_ids=inputs.input_ids.to(device),
+            attention_mask=inputs.attention_mask.to(device),
+            speaker_id=0
+        ).waveform.cpu().numpy().reshape(-1)
+    return model.config.sampling_rate,wav
+
 
 def remove_extra_spaces(text):
     """
@@ -69,9 +92,9 @@ with gr.Blocks() as demo: # Use gr.Blocks to wrap the entire interface
     API_URL = f"https://api-inference.huggingface.co/models/{model_choice}"
     text_answer = get_answer_ai(text)
     text_answer = remove_extra_spaces(text_answer)
-    data_ai = query(text_answer, API_URL)
+    data_ai = genrate_speech(text_answer,model_choice)  # query(text_answer, API_URL)
     if generate_user_audio: # Generate user audio if needed
-        data_user = query(text, API_URL)
+        data_user = genrate_speech(text,model_choice)  # query(text, API_URL)
         return data_user, data_ai, text_answer
     else:
         return data_ai # Return None for user_audio
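For context, a minimal sketch of how the helpers added above could be exercised on their own. It assumes get_model and genrate_speech from app.py are in scope, that the Id_mode_vits token grants access to the asg2024/vits-ar-sa-huba checkpoint, and that soundfile is available; the sample text and output filename are illustrative only, not part of the commit.

# Sketch only: exercise the new VITS helpers outside of the Gradio app.
# Assumes genrate_speech from app.py is importable and the token is valid.
import soundfile as sf  # illustrative dependency for saving the waveform

sampling_rate, wav = genrate_speech(
    "السلام عليكم",               # hypothetical Arabic input text
    "asg2024/vits-ar-sa-huba",    # model repo id; loaded once and cached by get_model()
)
sf.write("answer.wav", wav, sampling_rate)  # mono float waveform at the model's sampling rate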
 
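The (sampling_rate, waveform) tuple returned by genrate_speech matches the (sample_rate, numpy_array) format that Gradio audio components accept as output, which is presumably why the handler can return it directly. A minimal, self-contained sketch of that wiring; the speak() wrapper, the Textbox/Audio components, and the hard-coded model id are assumptions, not code from app.py.

# Sketch only: returning (sampling_rate, np.ndarray) straight into a gr.Audio output.
import gradio as gr

def speak(text):
    # genrate_speech is the helper added in this commit (assumed importable)
    return genrate_speech(text, "asg2024/vits-ar-sa-huba")

demo = gr.Interface(fn=speak, inputs=gr.Textbox(label="Text"), outputs=gr.Audio(label="Speech"))
demo.launch()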