varun500 commited on
Commit
8f08840
·
1 Parent(s): 042db08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -1,25 +1,31 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import streamlit as st
4
- model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
5
- tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
 
 
6
  def main():
7
- st.title("Dragon Text Generator")
8
- st.write("Enter a prompt and generate dragon-inspired text!")
9
 
10
- # Input prompt
11
- prompt = st.text_area("Enter your prompt", value="", height=150)
12
 
13
- if st.button("Generate Text"):
14
- if prompt.strip() != "":
15
- # Generate text based on the provided prompt
16
- inputs = tokenizer(prompt, return_tensors="pt")
17
- output = model.generate(inputs["input_ids"], max_new_tokens=20)
18
- generated_text = tokenizer.decode(output[0].tolist())
19
- st.markdown("## Generated Text")
 
20
  st.write(generated_text)
21
  else:
22
- st.warning("Please enter a prompt.")
 
 
 
23
 
24
  if __name__ == "__main__":
25
  main()
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import streamlit as st
4
+ model_id = "RWKV/rwkv-raven-1b5"
5
+ model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
6
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
7
+
8
  def main():
9
+ st.title("Raven Text Generator")
10
+ st.write("Ask a question about ravens and get a response!")
11
 
12
+ # Input question
13
+ question = st.text_input("Ask a question")
14
 
15
+ if st.button("Generate Response"):
16
+ if question.strip() != "":
17
+ # Generate response based on the provided question
18
+ prompt = f"### Instruction: {question}\n### Response:"
19
+ inputs = tokenizer(prompt, return_tensors="pt").to(0)
20
+ output = model.generate(inputs["input_ids"], max_new_tokens=100)
21
+ generated_text = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
22
+ st.markdown("## Generated Response")
23
  st.write(generated_text)
24
  else:
25
+ st.warning("Please enter a question.")
26
+
27
+ if __name__ == "__main__":
28
+ main()
29
 
30
  if __name__ == "__main__":
31
  main()