bhkkhjgkk committed
Commit dad25ac · verified · 1 Parent(s): 473963a

Update main.py

Files changed (1):
  main.py +17 -13
main.py CHANGED
@@ -1,8 +1,9 @@
 from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 import uvicorn
-
+import asyncio
 
 app = FastAPI()
 
@@ -25,10 +26,8 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def generate(item: Item):
-    temperature = float(item.temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
+async def generate(item: Item):
+    temperature = max(float(item.temperature), 1e-2)  # Ensure temperature is not too low
     top_p = float(item.top_p)
 
     generate_kwargs = dict(
@@ -41,14 +40,19 @@ def generate(item: Item):
     )
 
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-    return output
+
+    # Stream the response from the model
+    async def event_stream():
+        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
+        async for response in stream:
+            yield response.token.text  # Yield each token as it is received
+
+            # Optional: Add a small delay to simulate streaming effect (if needed)
+            await asyncio.sleep(0.1)
+
+    return event_stream()
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return {"response": generate(item)}
-
+    return StreamingResponse(generate(item), media_type="text/event-stream")
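
The commit replaces the blocking generate helper, which collected every token into a string before returning it, with an async generator wrapped in a StreamingResponse so tokens reach the caller as they are produced. Below is a minimal sketch of how a client might consume the new endpoint; the URL (uvicorn's default localhost:8000) and the payload values are assumptions, with only the field names taken from the diff (prompt, history, system_prompt, temperature, top_p).

import requests

# Hypothetical payload: values are illustrative, field names come from the diff above
payload = {
    "prompt": "Write a haiku about FastAPI.",
    "history": [],
    "system_prompt": "You are a helpful assistant.",
    "temperature": 0.7,
    "top_p": 0.95,
}

# stream=True makes requests hand back body chunks as the server yields tokens
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"), end="", flush=True)

One caveat worth noting: the synchronous InferenceClient.text_generation(..., stream=True) returns a regular generator, so the async for inside event_stream may need to become a plain for loop (or the client swapped for huggingface_hub's AsyncInferenceClient), and the endpoint would likely need to await generate(item) before handing the resulting generator to StreamingResponse.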