import os
from threading import Thread
from typing import Any, Dict, List, Optional

import docx
import pymupdf
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pptx import Presentation
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()

# Model and tokenizer initialization
MODEL_LIST = ["nikravan/glm-4vq"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4vq"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    token=HF_TOKEN,
    device_map="auto",  # requires accelerate; places weights on GPU when one is available
)

def extract_text(path):
    # Plain-text files (txt/py): read with an explicit encoding and close the handle.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        return f.read()

def extract_pdf(path):
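    # Concatenate the extracted text layer of every page.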
    doc = pymupdf.open(path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text

def extract_docx(path):
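    # Join all paragraph texts, separated by blank lines.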
    doc = docx.Document(path)
    data = [paragraph.text for paragraph in doc.paragraphs]
    return '\n\n'.join(data)

def extract_pptx(path):
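    # Collect text from every shape that exposes a .text attribute, slide by slide.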
    prs = Presentation(path)
    text = ""
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text += shape.text + "\n"
    return text

def mode_load(path):
    file_type = path.split(".")[-1].lower()
    if file_type in ["pdf", "txt", "py", "docx", "pptx"]:
        if file_type == "pdf":
            content = extract_pdf(path)
        elif file_type == "docx":
            content = extract_docx(path)
        elif file_type == "pptx":
            content = extract_pptx(path)
        else:
            content = extract_text(path)
        return "doc", content[:5000]
    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert('RGB')
        return "image", content
    else:
        raise HTTPException(status_code=400, detail="Unsupported file type")

@app.post("/test/")
async def test_endpoint(message: Dict[str, str]):
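    # Echo endpoint: validates the payload and mirrors the text back.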
    if "text" not in message:
        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
    
    response = {"message": f"Received your message: {message['text']}"}
    return response
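
# Example request (illustrative host/port):
#   curl -X POST http://localhost:8000/test/ \
#        -H "Content-Type: application/json" -d '{"text": "hello"}'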

@app.post("/chat/")
async def chat_endpoint(
    message: Dict[str, Any],
    history: List[List[Optional[str]]] = [],
    temperature: float = 0.8,
    max_length: int = 4096,
    top_p: float = 1.0,
    top_k: int = 10,
    penalty: float = 1.0
):
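    # Build the conversation from an optional uploaded file, prior history,
    # and the new message, then generate a completion with the streamer.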
    conversation = []
    if "files" in message and message["files"]:
        choice, contents = mode_load(message["files"][-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
        elif choice == "doc":
            format_msg = contents + "\n\n\n" + f"{len(message['files'])} files uploaded.\n" + message['text']
            conversation.append({"role": "user", "content": format_msg})
    else:
        if len(history) == 0:
            conversation.append({"role": "user", "content": message['text']})
        else:
            # Replay prior turns. A turn whose answer is None was a file
            # upload (its prompt holds the file path), so it is re-loaded
            # below instead of being echoed as plain text.
            for prompt, answer in history:
                if answer is None:
                    conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                else:
                    conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
            if history[-1][1] is None:
                choice, contents = mode_load(history[-1][0])
                if choice == "image":
                    conversation.append({"role": "user", "image": contents, "content": message['text']})
                else:
                    format_msg = contents + "\n\n\n" + "1 file uploaded.\n" + message['text']
                    conversation.append({"role": "user", "content": format_msg})
            else:
                conversation.append({"role": "user", "content": message['text']})

    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        max_length=max_length,
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
    )

    # Run generation on a background thread so the streamer can be drained
    # here; without this, iterating the streamer blocks until it times out
    # because nothing ever produces tokens.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
    thread.join()
    return {"response": buffer}
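
# Minimal local-run sketch (assumes uvicorn is installed; host and port are
# illustrative, not part of the original service):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)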