Tridev / app.py
import os
import textwrap

import torch
import gradio as gr
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
# ✅ 1. Use the *Instruct* checkpoint
MODEL_ID = os.getenv(
    "MODEL_ID",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # the Instruct checkpoint published on the Hub
)
# ✅ 2. Load in 4-bit so the model fits in Hugging Face ZeroGPU memory
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True)
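# Optional, more explicit 4-bit settings (a sketch; the minimal config above also works).
# NF4 quantisation with fp16 compute is a common pairing for memory-constrained GPUs:
# bnb_cfg = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )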
tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=bnb_cfg, # 4-bit
device_map="auto",
    torch_dtype=torch.float16,  # fp16 compute for the non-quantised parts
    trust_remote_code=True,     # not strictly required for Mixtral, but harmless
)
# ✅ 3. Use *text-generation* with an explicit prompt template
prompt_tmpl = (
    "Summarise the following transcript in one or two short paragraphs, "
    "point-wise, without leaving out any key information.\n\n"
    "### Transcript:\n{chunk}\n\n"
    "### Summary:"
)
gen = pipeline(
"text-generation",
model=model,
tokenizer=tok,
    max_new_tokens=256,
    do_sample=True,     # temperature only takes effect when sampling is enabled
    temperature=0.1,
)
MAX_CHUNK = 6_000  # characters per chunk (roughly 1.5 k tokens)
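# The summariser below follows a simple map-reduce pattern: summarise each
# chunk independently, then summarise the concatenated partial summaries.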
def summarize(txt):
    # Split the transcript into character-based chunks that fit in the context window.
    parts = textwrap.wrap(txt, MAX_CHUNK, break_long_words=False)
    # Summarise each chunk; keep only the text after the "### Summary:" marker.
    partials = [
        gen(prompt_tmpl.format(chunk=p))[0]["generated_text"]
        .split("### Summary:")[-1]
        .strip()
        for p in parts
    ]
    # Summarise the joined partial summaries into the final answer.
    final = gen(prompt_tmpl.format(chunk=" ".join(partials)))[0]["generated_text"]
    return final.split("### Summary:")[-1].strip()
demo = gr.Interface(
fn=summarize,
inputs=gr.Textbox(lines=20, label="Transcript"),
outputs="text",
title="Mixtral-8Γ—7B Transcript Summariser",
)
if __name__ == "__main__":
demo.launch()