Aneeshmishra commited on
Commit
1f33001
Β·
verified Β·
1 Parent(s): 5bf0c89

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, textwrap, torch, gradio as gr
2
+ from transformers import (
3
+ pipeline,
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ BitsAndBytesConfig,
7
+ )
8
+
9
+ # βœ… 1. Use the *Instruct* checkpoint
10
+ MODEL_ID = os.getenv(
11
+ "MODEL_ID",
12
+ "mistralai/Mixtral-8x7B-Instruct-v0.3" # correct model name
13
+ )
14
+
15
+ # βœ… 2. Load in 4-bit so it fits on Hugging-Face ZeroGPU (<15 GB)
16
+ bnb_cfg = BitsAndBytesConfig(load_in_4bit=True)
17
+
18
+ tok = AutoTokenizer.from_pretrained(MODEL_ID)
19
+
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ MODEL_ID,
22
+ quantization_config=bnb_cfg, # 4-bit
23
+ device_map="auto",
24
+ torch_dtype=torch.float16, # ZeroGPU has a single T4/L4
25
+ trust_remote_code=True, # required for Mixtral
26
+ )
27
+
28
+ # βœ… 3. Use *text-generation* with an explicit prompt template
29
+ prompt_tmpl = (
30
+ "Summarise the following transcript in short in 1 or 2 paragraph and point wise and don't miss any key information cover all"
31
+ )
32
+
33
+ gen = pipeline(
34
+ "text-generation",
35
+ model=model,
36
+ tokenizer=tok,
37
+ max_new_tokens=256,
38
+ temperature=0.1,
39
+ )
40
+
41
+ MAX_CHUNK = 6_000 # β‰ˆ4 k tokens
42
+
43
+ def summarize(txt):
44
+ parts = textwrap.wrap(txt, MAX_CHUNK, break_long_words=False)
45
+ partials = [
46
+ gen(prompt_tmpl.format(chunk=p))[0]["generated_text"].split("### Summary:")[-1].strip()
47
+ for p in parts
48
+ ]
49
+ return gen(prompt_tmpl.format(chunk=" ".join(partials)))[0]["generated_text"]\
50
+ .split("### Summary:")[-1].strip()
51
+
52
+ demo = gr.Interface(
53
+ fn=summarize,
54
+ inputs=gr.Textbox(lines=20, label="Transcript"),
55
+ outputs="text",
56
+ title="Mixtral-8Γ—7B Transcript Summariser",
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ demo.launch()