amiguel committed on
Commit 1dfd2ed · verified
1 Parent(s): 4268d48

Update app.py

Files changed (1)
  1. app.py +6 -14
app.py CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 from transformers import (
     AutoTokenizer,
-    AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM,  # Use AutoModelForCausalLM instead of AutoModelForSeq2SeqLM
     TextIteratorStreamer
 )
 from huggingface_hub import login
@@ -18,7 +18,7 @@ if not HF_TOKEN:
     raise ValueError("Missing Hugging Face Token. Please set the HF_TOKEN environment variable.")
 
 # ✅ Only PT-T5 Model
-MODEL_NAME = "amiguel/Meta-Llama-3.1-8B-Instruct-lei-geral-trabalho"  # "amiguel/mistral-angolan-laborlaw" # "amiguel/mistral-angolan-laborlaw-ptt5"
+MODEL_NAME = "amiguel/Meta-Llama-3.1-8B-Instruct-lei-geral-trabalho"
 
 # UI Setup
 st.set_page_config(page_title="Assistente LGT | Angola", page_icon="🚀", layout="centered")
@@ -54,7 +54,7 @@ def load_model():
     try:
         login(token=HF_TOKEN)
         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, use_fast=False)
-        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).to("cuda" if torch.cuda.is_available() else "cpu")
+        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).to("cuda" if torch.cuda.is_available() else "cpu")
         return model, tokenizer
     except Exception as e:
         st.error(f"🤖 Erro ao carregar o modelo: {str(e)}")
@@ -63,10 +63,8 @@ def load_model():
 # Streaming response generation
 def generate_response(prompt, context, model, tokenizer):
     full_prompt = f"Contexto:\n{context}\n\nPergunta: {prompt}\nResposta:"
-
     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(model.device)
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
     generation_kwargs = {
         "input_ids": inputs["input_ids"],
         "attention_mask": inputs["attention_mask"],
@@ -78,7 +76,6 @@ def generate_response(prompt, context, model, tokenizer):
         "use_cache": True,
         "streamer": streamer
     }
-
     Thread(target=model.generate, kwargs=generation_kwargs).start()
     return streamer
 
@@ -100,7 +97,7 @@ if prompt := st.chat_input("Faca uma pergunta sobre a LGT..."):
 
     # Load model if not loaded
     if "model" not in st.session_state:
-        with st.spinner("🔄 A carregar o modelo PT-T5..."):
+        with st.spinner("🔄 A carregar o modelo ..."):
             model, tokenizer = load_model()
             if not model:
                 st.stop()
@@ -119,25 +116,20 @@ if prompt := st.chat_input("Faca uma pergunta sobre a LGT..."):
         try:
             start_time = time.time()
             streamer = generate_response(prompt, context, model, tokenizer)
-
             for chunk in streamer:
                 full_response += chunk.strip() + " "
                 response_box.markdown(full_response + "▌", unsafe_allow_html=True)
-
             end_time = time.time()
             input_tokens = len(tokenizer(prompt)["input_ids"])
             output_tokens = len(tokenizer(full_response)["input_ids"])
             speed = output_tokens / (end_time - start_time)
-            cost_usd = ((input_tokens / 1e6) * 5) + ((output_tokens / 1e6) * 15)
+            cost_usd = ((input_tokens / 1e6) * 0.0001) + ((output_tokens / 1e6) * 0.0001)
             cost_aoa = cost_usd * 1160
-
             st.caption(
                 f"🔑 Tokens: {input_tokens} → {output_tokens} | 🕒 Velocidade: {speed:.1f}t/s | "
                 f"💰 USD: ${cost_usd:.4f} | 🇦🇴 AOA: {cost_aoa:.2f}"
             )
-
             response_box.markdown(full_response.strip())
             st.session_state.messages.append({"role": "assistant", "content": full_response.strip()})
-
         except Exception as e:
-            st.error(f"⚡ Erro ao gerar resposta: {str(e)}")
+            st.error(f"⚡ Erro ao gerar resposta: {str(e)}")
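A note on the main fix: amiguel/Meta-Llama-3.1-8B-Instruct-lei-geral-trabalho is a Llama 3.1 fine-tune, i.e. a decoder-only checkpoint, so transformers cannot load it through AutoModelForSeq2SeqLM; AutoModelForCausalLM is the matching Auto class. Below is a minimal sketch of how the class could be picked from the checkpoint config instead of hard-coded; the helper pick_auto_class is illustrative and not part of this app.

# Sketch only: choose the Auto class by inspecting the checkpoint config.
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM

def pick_auto_class(model_name, token=None):
    config = AutoConfig.from_pretrained(model_name, token=token)
    # Encoder-decoder checkpoints (T5, BART, ...) set is_encoder_decoder=True;
    # decoder-only checkpoints such as Llama leave it False.
    if getattr(config, "is_encoder_decoder", False):
        return AutoModelForSeq2SeqLM
    return AutoModelForCausalLM

# For this repo the helper would return AutoModelForCausalLM, matching the commit.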
 
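The pricing constants change as well: the estimate drops from 5 and 15 USD per million input/output tokens to a flat 0.0001 USD per million on both sides. A quick worked example of the updated formula, with hypothetical token counts (50 in, 200 out) and the 1160 AOA/USD rate already in the script:

# Worked example of the updated cost formula; the token counts are made up.
input_tokens, output_tokens = 50, 200
cost_usd = ((input_tokens / 1e6) * 0.0001) + ((output_tokens / 1e6) * 0.0001)  # 2.5e-08 USD
cost_aoa = cost_usd * 1160                                                     # 2.9e-05 AOA
print(f"USD: ${cost_usd:.4f} | AOA: {cost_aoa:.2f}")                           # USD: $0.0000 | AOA: 0.00

With the app's {cost_usd:.4f} formatting this now renders as $0.0000 for typical exchanges, whereas the old rates would have shown on the order of $0.003 for the same counts.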