Akbartus commited on
Commit
d82d16c
·
verified ·
1 Parent(s): c14e8e9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +5 -9
main.py CHANGED
@@ -1,19 +1,15 @@
1
- import transformers
2
- import torch
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  import uvicorn
6
 
7
 
8
- app = FastAPI()
9
-
10
 
11
 
12
- model_id = "meta-llama/Meta-Llama-3-8B"
13
 
14
- pipeline = transformers.pipeline(
15
- "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
16
- )
17
 
18
 
19
 
@@ -24,7 +20,7 @@ class Item(BaseModel):
24
 
25
 
26
  def generate(item: Item):
27
- pipeline(item.prompt)
28
 
29
 
30
  @app.post("/generate/")
 
1
+ from transformers import pipeline, set_seed
 
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  import uvicorn
5
 
6
 
7
+ generator = pipeline('text-generation', model='gpt2')
8
+ set_seed(42)
9
 
10
 
11
+ app = FastAPI()
12
 
 
 
 
13
 
14
 
15
 
 
20
 
21
 
22
  def generate(item: Item):
23
+ generator(item.prompt, max_length=30, num_return_sequences=5)
24
 
25
 
26
  @app.post("/generate/")