import gradio as gr

from peft import PeftModel, PeftConfig
from transformers import MistralForCausalLM, TextIteratorStreamer, AutoTokenizer, BitsAndBytesConfig
from threading import Thread
from torch import float16
import spaces

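# Dragoman is a LoRA adapter on top of Mistral-7B: fetch its PEFT config (not
# referenced again below) and define a 4-bit NF4 quantization config so the
# base model fits in limited GPU memory.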
config = PeftConfig.from_pretrained("lang-uk/dragoman")
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_use_double_quant=False,
)

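# Load the quantized Mistral-7B base model, attach the Dragoman adapter, and
# load the matching tokenizer (slow tokenizer, no automatic BOS token).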
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quant_config,
    # device_map="auto",
)
model = PeftModel.from_pretrained(model, "lang-uk/dragoman").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False)

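# ZeroGPU: allocate a GPU for at most 30 seconds per translation call.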
@spaces.GPU(duration=30)
def translate(input_text):
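    """Translate English input to Ukrainian line by line, streaming partial output."""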
    generated_text = ""
    input_text = input_text.strip()
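    # Translate each line separately; blank lines are passed through unchanged.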
    for chunk in input_text.split("\n"):
        if not chunk:
            generated_text += "\n"
            yield generated_text
            continue
        chunk = f"[INST] {chunk} [/INST]"
        inputs = tokenizer([chunk], return_tensors="pt").to(model.device)

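        # Run generation on a background thread and stream decoded tokens back,
        # skipping the prompt and special tokens in the output.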
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200)

        thread = Thread(target=model.generate, kwargs=generation_kwargs)

        thread.start()

        # Append each streamed chunk and yield the running translation.
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        # Make sure generation has finished before moving on to the next line.
        thread.join()
        generated_text += "\n"
        yield generated_text


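# Minimal text-in/text-out Gradio UI with a single example prompt.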
iface = gr.Interface(
    fn=translate,
    inputs="text",
    outputs="text",
    examples=[["who holds this neighborhood?"]],
)
iface.launch()