Akbartus committed
Commit a95c82c · verified · 1 Parent(s): 8b56784

Update main.py

Files changed (1):
  main.py  +16 -37
main.py CHANGED
@@ -1,52 +1,31 @@
+import transformers
+import torch
 from fastapi import FastAPI
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient
 import uvicorn
 
 
 app = FastAPI()
 
-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
+
+
+model_id = "meta-llama/Meta-Llama-3-8B"
+
+pipeline = transformers.pipeline(
+    "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
+)
+
+
 
 class Item(BaseModel):
     prompt: str
-    history: list
-    system_prompt: str
-    temperature: float = 0.0
-    max_new_tokens: int = 1024
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
-
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+
+
+
 
 def generate(item: Item):
-    temperature = float(item.temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(item.top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=item.max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=item.repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-    return output
+    pipeline(item.prompt)
+
 
 @app.post("/generate/")
 async def generate_text(item: Item):
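
Note: as committed, the new generate() calls the pipeline but discards its result, so the /generate/ endpoint has nothing to serve. A minimal sketch of a return-value fix, assuming the endpoint is meant to return the generated text; the handler body and response shape below are illustrative assumptions, since the handler is outside this hunk:

def generate(item: Item):
    # Sketch, not part of the commit: a transformers "text-generation"
    # pipeline returns a list of dicts like [{"generated_text": "..."}].
    outputs = pipeline(item.prompt)
    return outputs[0]["generated_text"]

@app.post("/generate/")
async def generate_text(item: Item):
    # Assumed handler body; the original is not shown in this diff.
    return generate(item)

By default "generated_text" includes the prompt itself; passing return_full_text=False in the pipeline call yields only the completion.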