alex-abb commited on
Commit
089d26e
·
verified ·
1 Parent(s): 81bb439

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -1,23 +1,55 @@
1
  import gradio as gr
2
  import spaces
3
  import transformers
 
4
  from transformers import pipeline
5
  import torch
6
  import os
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
-
9
-
10
 
11
  api_token = os.environ.get("APIKEY")
12
 
13
 
14
 
 
15
  @spaces.GPU(duration=240)
16
- # Load model directly
17
 
18
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
19
- model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct",token=api_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  # Fonction de génération de texte
 
1
  import gradio as gr
2
  import spaces
3
  import transformers
4
+ from transformers import AutoTokenizer,AutoModelForCausalLM
5
  from transformers import pipeline
6
  import torch
7
  import os
 
 
 
8
 
9
  api_token = os.environ.get("APIKEY")
10
 
11
 
12
 
13
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
14
  @spaces.GPU(duration=240)
 
15
 
16
+ # Charger le modèle en spécifiant le token d'accès
17
+
18
+ pipeline = transformers.pipeline(
19
+ "text-generation",
20
+ model=model_id,
21
+ token = api_token,
22
+ model_kwargs={"torch_dtype": torch.bfloat16},
23
+ device_map="auto",
24
+ )
25
+
26
+ # Créer un pipeline pour la génération de texte
27
+ pipeline = transformers.pipeline(
28
+ "text-generation",
29
+ model=model,
30
+ tokenizer=model.config.tokenizer,
31
+ device_map="auto",
32
+ )
33
+
34
+ messages = [
35
+ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
36
+ {"role": "user", "content": "Who are you?"},
37
+ ]
38
+
39
+ terminators = [
40
+ pipeline.tokenizer.eos_token_id,
41
+ pipeline.tokenizer.convert_tokens_to_ids("")
42
+ ]
43
 
44
+ # Utiliser le pipeline pour générer du texte
45
+ outputs = pipeline(
46
+ messages,
47
+ max_new_tokens=256,
48
+ eos_token_id=terminators,
49
+ do_sample=True,
50
+ temperature=0.6,
51
+ top_p=0.9,
52
+ )
53
 
54
 
55
  # Fonction de génération de texte