SameerArz commited on
Commit
52a55fa
·
verified ·
1 Parent(s): d8cb74d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -15,20 +15,21 @@ load_dotenv()
15
  # Constants
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
18
- IMAGE_GENERATION_SPACE_NAME = "stabilityai/stable-diffusion-3.5-large-turbo"
 
19
 
20
- # Initialize Groq client with minimal parameters
21
  try:
22
  groq_client = Groq(api_key=GROQ_API_KEY)
23
  except Exception as e:
24
  st.error(f"Failed to initialize Groq client: {e}")
25
  groq_client = None
26
 
27
- # LLM Models (free options)
28
  LLM_MODELS = {
29
  "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
30
  "Mistral 7B (HF)": "mistralai/Mixtral-7B-Instruct-v0.1",
31
- "LLaMA 13B (HF)": "meta-llama/Llama-13b-hf" # Note: May require approval; replace if needed
32
  }
33
 
34
  # Utility Functions
@@ -45,7 +46,7 @@ def generate_tutor_output(subject, difficulty, student_input, model):
45
  Format your response as a JSON object with keys: "lesson", "question", "feedback"
46
  """
47
 
48
- if model.startswith("mixtral") and groq_client: # Groq model
49
  try:
50
  completion = groq_client.chat.completions.create(
51
  messages=[{
@@ -61,8 +62,8 @@ def generate_tutor_output(subject, difficulty, student_input, model):
61
  return json.loads(completion.choices[0].message.content)
62
  except Exception as e:
63
  st.error(f"Groq error: {e}")
64
- return {"lesson": "Error generating lesson", "question": "N/A", "feedback": "N/A"}
65
- else: # Hugging Face models
66
  try:
67
  client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
68
  response = client.predict(prompt, api_name="/generate")
@@ -71,18 +72,16 @@ def generate_tutor_output(subject, difficulty, student_input, model):
71
  st.warning(f"HF model {model} failed, falling back to Mixtral.")
72
  if groq_client:
73
  return generate_tutor_output(subject, difficulty, student_input, "mixtral-8x7b-32768")
74
- return {"lesson": "Error generating lesson", "question": "N/A", "feedback": "N/A"}
75
 
76
  def generate_image(prompt, path='temp_image.png'):
77
  try:
78
- client = Client(IMAGE_GENERATION_SPACE_NAME, hf_token=HF_TOKEN)
79
- result = client.predict(
80
- prompt=prompt,
81
- width=512,
82
- height=512,
83
- api_name="/predict"
84
- )
85
- image = Image.open(result)
86
  image.save(path)
87
  return path
88
  except Exception as e:
@@ -91,6 +90,9 @@ def generate_image(prompt, path='temp_image.png'):
91
 
92
  def generate_video(images, audio_text, language, speaker, path='temp_video.mp4'):
93
  try:
 
 
 
94
  audio_client = Client("habib926653/Multilingual-TTS")
95
  audio_result = audio_client.predict(
96
  text=audio_text,
@@ -106,8 +108,11 @@ def generate_video(images, audio_text, language, speaker, path='temp_video.mp4')
106
  f.write(audio_bytes)
107
 
108
  audio_clip = mp.AudioFileClip(audio_path)
109
- duration_per_image = audio_clip.duration / len(images)
110
  image_clips = [mp.ImageClip(img).set_duration(duration_per_image) for img in images if img]
 
 
 
111
  video = mp.concatenate_videoclips(image_clips, method="compose").set_audio(audio_clip)
112
  video.write_videofile(path, fps=24, codec='libx264')
113
  return path
 
15
  # Constants
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
18
+ # Switching to HF Inference API for stability
19
+ IMAGE_GENERATION_API = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
20
 
21
+ # Initialize Groq client
22
  try:
23
  groq_client = Groq(api_key=GROQ_API_KEY)
24
  except Exception as e:
25
  st.error(f"Failed to initialize Groq client: {e}")
26
  groq_client = None
27
 
28
+ # LLM Models
29
  LLM_MODELS = {
30
  "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
31
  "Mistral 7B (HF)": "mistralai/Mixtral-7B-Instruct-v0.1",
32
+ "LLaMA 13B (HF)": "meta-llama/Llama-13b-hf"
33
  }
34
 
35
  # Utility Functions
 
46
  Format your response as a JSON object with keys: "lesson", "question", "feedback"
47
  """
48
 
49
+ if model.startswith("mixtral") and groq_client:
50
  try:
51
  completion = groq_client.chat.completions.create(
52
  messages=[{
 
62
  return json.loads(completion.choices[0].message.content)
63
  except Exception as e:
64
  st.error(f"Groq error: {e}")
65
+ return {"lesson": "Sorry, unable to generate lesson due to API issue.", "question": "N/A", "feedback": "Please try again or check your input."}
66
+ else:
67
  try:
68
  client = Client("https://api-inference.huggingface.co/models/" + model, hf_token=HF_TOKEN)
69
  response = client.predict(prompt, api_name="/generate")
 
72
  st.warning(f"HF model {model} failed, falling back to Mixtral.")
73
  if groq_client:
74
  return generate_tutor_output(subject, difficulty, student_input, "mixtral-8x7b-32768")
75
+ return {"lesson": "Sorry, unable to generate lesson.", "question": "N/A", "feedback": "N/A"}
76
 
77
  def generate_image(prompt, path='temp_image.png'):
78
  try:
79
+ client = Client(IMAGE_GENERATION_API, hf_token=HF_TOKEN)
80
+ result = client.predict(prompt, api_name="/predict")
81
+ if isinstance(result, str): # Handle file path or binary data
82
+ image = Image.open(result)
83
+ else:
84
+ image = Image.open(result)
 
 
85
  image.save(path)
86
  return path
87
  except Exception as e:
 
90
 
91
  def generate_video(images, audio_text, language, speaker, path='temp_video.mp4'):
92
  try:
93
+ if not images or all(img is None for img in images):
94
+ st.error("No valid images to create video.")
95
+ return None
96
  audio_client = Client("habib926653/Multilingual-TTS")
97
  audio_result = audio_client.predict(
98
  text=audio_text,
 
108
  f.write(audio_bytes)
109
 
110
  audio_clip = mp.AudioFileClip(audio_path)
111
+ duration_per_image = audio_clip.duration / len([img for img in images if img])
112
  image_clips = [mp.ImageClip(img).set_duration(duration_per_image) for img in images if img]
113
+ if not image_clips:
114
+ st.error("No image clips generated.")
115
+ return None
116
  video = mp.concatenate_videoclips(image_clips, method="compose").set_audio(audio_clip)
117
  video.write_videofile(path, fps=24, codec='libx264')
118
  return path