CJHauser commited on
Commit
9b32fcd
·
verified ·
1 Parent(s): 8853386

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +22 -17
src/streamlit_app.py CHANGED
@@ -2,26 +2,31 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load your model and tokenizer
5
- model_name = "CJHauser/PrisimAI-chat" # Your Hugging Face model
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  # Streamlit App
10
- st.title("PrisimAI Chatbot")
 
11
 
12
- # Input box for user prompt
13
- user_input = st.text_input("Ask a question:")
14
 
15
- # When the user enters a prompt
16
  if user_input:
17
- # Encode the input
18
- inputs = tokenizer(user_input, return_tensors="pt")
19
-
20
- # Generate the response
21
- outputs = model.generate(inputs['input_ids'], max_length=100)
22
-
23
- # Decode the output
24
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
-
26
- # Show the response
27
- st.write(f"Response: {response}")
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
  # Load your model and tokenizer
5
+ model_name = "CJHauser/PrisimAI-chat"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
  # Streamlit App
10
+ st.set_page_config(page_title="PrisimAI Chatbot")
11
+ st.title("🤖 PrisimAI Chatbot")
12
 
13
+ # User prompt
14
+ user_input = st.text_input("Ask something:", placeholder="e.g. What is AI?")
15
 
 
16
  if user_input:
17
+ with st.spinner("Thinking..."):
18
+ inputs = tokenizer(user_input, return_tensors="pt")
19
+ outputs = model.generate(
20
+ inputs["input_ids"],
21
+ max_length=150,
22
+ do_sample=True,
23
+ temperature=0.7,
24
+ top_p=0.9,
25
+ pad_token_id=tokenizer.eos_token_id
26
+ )
27
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
+
29
+ # Remove the original prompt from the response if repeated
30
+ response = response.replace(user_input, "").strip()
31
+
32
+ st.markdown(f"**Response:** {response}")