aaliyaan committed on
Commit
4451a94
·
1 Parent(s): dda6a92
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -2,10 +2,15 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
3
  from PyPDF2 import PdfReader
4
 
 
5
  # Models and Tokenizers Setup
6
  models = {
7
  "Text Generator (Bloom)": {
8
- "model": AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m"),
 
 
 
 
9
  "tokenizer": AutoTokenizer.from_pretrained("bigscience/bloom-560m"),
10
  },
11
  "PDF Summarizer (T5)": {
@@ -32,9 +37,19 @@ def chat_with_model(model_choice, user_message, chat_history, file=None):
32
  model = model_info["model"]
33
 
34
  # Tokenize Input
35
- inputs = tokenizer(user_message, return_tensors="pt", padding=True, truncation=True, max_length=512)
36
- # Generate Output
37
- outputs = model.generate(**inputs, max_length=150, num_beams=5, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
38
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
  # Update Chat History
@@ -50,7 +65,7 @@ def extract_text_from_pdf(file):
50
 
51
  # Interface Setup
52
  def create_chat_interface():
53
- with gr.Blocks(css="""
54
  .chatbox {
55
  background-color: #f7f7f8;
56
  border-radius: 12px;
@@ -102,4 +117,4 @@ def create_chat_interface():
102
 
103
  if __name__ == "__main__":
104
  interface = create_chat_interface()
105
- interface.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
3
  from PyPDF2 import PdfReader
4
 
5
+
6
  # Models and Tokenizers Setup
7
  models = {
8
  "Text Generator (Bloom)": {
9
+ "model": AutoModelForCausalLM.from_pretrained(
10
+ "bigscience/bloom-560m",
11
+ device_map="auto",
12
+ torch_dtype="auto"
13
+ ),
14
  "tokenizer": AutoTokenizer.from_pretrained("bigscience/bloom-560m"),
15
  },
16
  "PDF Summarizer (T5)": {
 
37
  model = model_info["model"]
38
 
39
  # Tokenize Input
40
+ inputs = tokenizer(user_message, return_tensors="pt", truncation=True, max_length=256)
41
+
42
+ # Adjust max_length and parameters for the PDF summarizer
43
+ max_length = 100
44
+ num_beams = 3
45
+ outputs = model.generate(
46
+ **inputs,
47
+ max_length=max_length,
48
+ num_beams=num_beams,
49
+ early_stopping=True,
50
+ no_repeat_ngram_size=2
51
+ )
52
+
53
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
54
 
55
  # Update Chat History
 
65
 
66
  # Interface Setup
67
  def create_chat_interface():
68
+ with gr.Blocks(css="""
69
  .chatbox {
70
  background-color: #f7f7f8;
71
  border-radius: 12px;
 
117
 
118
  if __name__ == "__main__":
119
  interface = create_chat_interface()
120
+ interface.launch(server_name="0.0.0.0", server_port=7860)