nyasukun commited on
Commit
bec1fe0
·
1 Parent(s): 9290385

update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -6
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  import torch, pandas as pd
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
  # ZeroGPU support
9
  try:
@@ -24,12 +24,41 @@ except ImportError:
24
  MODEL_NAME = "fdtn-ai/Foundation-Sec-8B"
25
  #MODEL_NAME = "sshleifer/tiny-gpt2"
26
 
27
- # Initialize tokenizer and model
28
  print(f"Loading model: {MODEL_NAME}")
29
- tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
32
- ).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Log device information
35
  if hasattr(model, 'device'):
 
3
  import torch, pandas as pd
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
 
8
  # ZeroGPU support
9
  try:
 
24
  MODEL_NAME = "fdtn-ai/Foundation-Sec-8B"
25
  #MODEL_NAME = "sshleifer/tiny-gpt2"
26
 
27
+ # Initialize tokenizer and model using pipeline approach
28
  print(f"Loading model: {MODEL_NAME}")
29
+ try:
30
+ print(f"Initializing text generation model: {MODEL_NAME}")
31
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
32
+ text_pipeline = pipeline(
33
+ "text-generation",
34
+ model=MODEL_NAME,
35
+ tokenizer=tokenizer,
36
+ torch_dtype=torch.bfloat16,
37
+ device_map="auto",
38
+ trust_remote_code=True
39
+ )
40
+ print(f"Model initialized successfully: {MODEL_NAME}")
41
+
42
+ # Extract model and tokenizer from pipeline for direct access
43
+ model = text_pipeline.model
44
+ tok = text_pipeline.tokenizer
45
+
46
+ except Exception as e:
47
+ print(f"Error initializing model {MODEL_NAME}: {str(e)}")
48
+ print("Falling back to tiny-gpt2...")
49
+ MODEL_NAME = "sshleifer/tiny-gpt2"
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
51
+ text_pipeline = pipeline(
52
+ "text-generation",
53
+ model=MODEL_NAME,
54
+ tokenizer=tokenizer,
55
+ torch_dtype=torch.bfloat16,
56
+ device_map="auto",
57
+ trust_remote_code=True
58
+ )
59
+ model = text_pipeline.model
60
+ tok = text_pipeline.tokenizer
61
+ print(f"Fallback model loaded: {MODEL_NAME}")
62
 
63
  # Log device information
64
  if hasattr(model, 'device'):