rahul7star committed
Commit 3233665 · verified · 1 Parent(s): 1094929

Update App1.py

Files changed (1)
  1. App1.py  +69 -23
App1.py CHANGED
@@ -1,34 +1,80 @@
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+import logging
+import time
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
+# ---------------- CONFIG ----------------
 MODEL_ID = "goonsai-com/civitaiprompts"
-MODEL_VARIANT = "Q4_K_M"  # The quantized version
+MODEL_VARIANT = "Q4_K_M"  # This is the HF tag for the quantized model
+MODEL_NAME = "CivitAI-Prompts-Q4_K_M"
+
+# ---------------- LOGGING ----------------
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+logger = logging.getLogger(__name__)
+logger.info("Starting Gradio chatbot...")
 
-print("Loading model...")
-tokenizer = AutoTokenizer.from_pretrained(f"hf.co/{MODEL_ID}:{MODEL_VARIANT}")
+# ---------------- LOAD MODEL ----------------
+logger.info(f"Loading tokenizer from {MODEL_ID} (revision={MODEL_VARIANT})")
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_ID,
+    revision=MODEL_VARIANT,
+    trust_remote_code=True
+)
+
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+logger.info(f"Loading model with dtype {dtype}")
 model = AutoModelForCausalLM.from_pretrained(
-    f"hf.co/{MODEL_ID}:{MODEL_VARIANT}",
-    torch_dtype=torch.float16,
-    device_map="auto"
+    MODEL_ID,
+    revision=MODEL_VARIANT,
+    torch_dtype=dtype,
+    device_map="auto",
+    trust_remote_code=True
 )
+logger.info("Model loaded successfully.")
+
+# ---------------- CHAT FUNCTION ----------------
+def chat_fn(message):
+    logger.info(f"Received message: {message}")
+
+    # Build prompt
+    full_text = f"User: {message}\nAssistant:"
+    logger.info(f"Full prompt for generation:\n{full_text}")
 
-def chat(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    output = model.generate(
+    start_time = time.time()
+    # Tokenize input
+    inputs = tokenizer([full_text], return_tensors="pt", truncation=True, max_length=1024).to(model.device)
+    logger.info("Tokenized input.")
+
+    # Generate response
+    logger.info("Generating response...")
+    reply_ids = model.generate(
         **inputs,
-        max_length=200,
+        max_new_tokens=512,
+        do_sample=True,
         temperature=0.7,
-        do_sample=True
+        top_p=0.9
     )
-    return tokenizer.decode(output[0], skip_special_tokens=True)
-
-iface = gr.Interface(
-    fn=chat,
-    inputs="text",
-    outputs="text",
-    title="CivitaI Prompt Model",
-    description="Type a prompt and get a response."
-)
+    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
+    assistant_reply = response.split("Assistant:")[-1].strip()
+    logger.info(f"Assistant reply: {assistant_reply}")
+    logger.info(f"Generation time: {time.time() - start_time:.2f}s")
+
+    return assistant_reply
+
+# ---------------- GRADIO BLOCKS UI ----------------
+with gr.Blocks() as demo:
+    gr.Markdown(f"# 🤖 {MODEL_NAME} (Stateless)")
+
+    with gr.Row():
+        with gr.Column():
+            message = gr.Textbox(label="Type your message...", placeholder="Hello!")
+            send_btn = gr.Button("Send")
+        with gr.Column():
+            output = gr.Textbox(label="Assistant Response", lines=10)
+
+    send_btn.click(chat_fn, inputs=[message], outputs=[output])
+    message.submit(chat_fn, inputs=[message], outputs=[output])
 
-iface.launch()
+logger.info("Launching Gradio app...")
+demo.launch()
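
A note on the new chat_fn: model.generate returns the prompt tokens plus the continuation, so the decoded string still contains the full "User: ... Assistant:" prompt, and the reply is recovered by splitting on the last "Assistant:" marker. A self-contained sketch of just that extraction step (the response string below is a made-up example, not real model output):

# Illustration of the reply-extraction step used in chat_fn.
# batch_decode yields the whole conversation text, so everything
# after the final "Assistant:" marker is treated as the reply.
response = "User: hello\nAssistant: Hi there! How can I help?"
assistant_reply = response.split("Assistant:")[-1].strip()
print(assistant_reply)  # -> Hi there! How can I help?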