import textwrap

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline,
)
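
# Minimal Gradio app: summarise long transcripts with Mixtral-8x7B-Instruct.
# Loading in 4-bit via bitsandbytes shrinks the ~47 B parameters to roughly a
# quarter of their half-precision footprint, so the model fits in far less
# GPU memory.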
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"   # ~47 B-parameter MoE
bnb_cfg  = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,   # keeps mat-mul fast
)
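# Optional: bnb_4bit_quant_type="nf4" and bnb_4bit_use_double_quant=True are
# common BitsAndBytesConfig additions that trade a little compute for better
# 4-bit quality and memory use.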

tok   = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",                      # shard across available devices
    trust_remote_code=True,
    quantization_config=bnb_cfg,            # quantize to 4-bit at load time
)

prompt_tpl = (
    "Summarise the following transcript in one or two short paragraphs, "
    "point-wise, without missing any key information.\n\n"
    "### Transcript:\n{chunk}\n\n"
    "### Summary:"
)
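# Note: Mixtral-Instruct was fine-tuned on "[INST] ... [/INST]" prompts; the
# plain "###" headers above work, but tok.apply_chat_template() would match
# the model's training format more closely.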

gen = pipeline("text-generation", model=model, tokenizer=tok,
               max_new_tokens=256, do_sample=True,   # sampling must be on
               temperature=0.3)                      # for temperature to apply

MAX_CHUNK = 6_000   # characters, roughly 1.5 k tokens at ~4 chars/token

def summarize(txt: str) -> str:
    """Map-reduce: summarise each chunk, then summarise the joined summaries."""
    parts = textwrap.wrap(txt, MAX_CHUNK, break_long_words=False)
    partials = [
        gen(prompt_tpl.format(chunk=p))[0]["generated_text"]
        .split("### Summary:")[-1].strip()
        for p in parts
    ]
    joined = " ".join(partials)
    return (
        gen(prompt_tpl.format(chunk=joined))[0]["generated_text"]
        .split("### Summary:")[-1].strip()
    )
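# Chunking by characters is a rough heuristic; a tighter variant would count
# tokens directly, e.g. split tok(txt)["input_ids"] at a fixed token budget,
# so no chunk can overflow the context window.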

demo = gr.Interface(fn=summarize,
                    inputs=gr.Textbox(lines=20, label="Transcript"),
                    outputs="text",
                    title="Free Transcript Summariser – Mixtral-8Γ—7B")

if __name__ == "__main__":
    demo.launch()
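
# Tip: demo.launch(share=True) also prints a temporary public URL, handy for
# testing the app from another machine.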