Spestly committed on
Commit
dc45496
·
verified ·
1 Parent(s): e043807

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -3,23 +3,28 @@ import spaces
3
  from transformers import pipeline
4
  import torch
5
 
6
- # Initialize the pipeline with the Orion model
 
 
7
  @spaces.GPU
8
  def initialize_model():
9
- return pipeline(
10
- "text-generation",
11
- model="apexion-ai/Orion-V1-4B",
12
- torch_dtype=torch.float16,
13
- device_map="auto"
14
- )
15
-
16
- # Load the model
17
- pipe = initialize_model()
18
 
19
  @spaces.GPU
20
  def generate_response(message, history, max_length=512, temperature=0.7, top_p=0.9):
21
  """Generate response using the Orion model"""
22
 
 
 
 
23
  # Format the conversation history
24
  messages = []
25
 
@@ -34,13 +39,13 @@ def generate_response(message, history, max_length=512, temperature=0.7, top_p=0
34
 
35
  # Generate response
36
  try:
37
- response = pipe(
38
  messages,
39
  max_length=max_length,
40
  temperature=temperature,
41
  top_p=top_p,
42
  do_sample=True,
43
- pad_token_id=pipe.tokenizer.eos_token_id
44
  )
45
 
46
  # Extract the generated text
 
3
  from transformers import pipeline
4
  import torch
5
 
6
+ # Global variable to store the pipeline
7
+ pipe = None
8
+
9
  @spaces.GPU
10
  def initialize_model():
11
+ global pipe
12
+ if pipe is None:
13
+ pipe = pipeline(
14
+ "text-generation",
15
+ model="apexion-ai/Orion-V1-4B",
16
+ torch_dtype=torch.float16,
17
+ device_map="auto"
18
+ )
19
+ return pipe
20
 
21
  @spaces.GPU
22
  def generate_response(message, history, max_length=512, temperature=0.7, top_p=0.9):
23
  """Generate response using the Orion model"""
24
 
25
+ # Initialize model inside the GPU-decorated function
26
+ model_pipe = initialize_model()
27
+
28
  # Format the conversation history
29
  messages = []
30
 
 
39
 
40
  # Generate response
41
  try:
42
+ response = model_pipe(
43
  messages,
44
  max_length=max_length,
45
  temperature=temperature,
46
  top_p=top_p,
47
  do_sample=True,
48
+ pad_token_id=model_pipe.tokenizer.eos_token_id
49
  )
50
 
51
  # Extract the generated text