bhkkhjgkk committed on
Commit 63b4fe7 · verified · 1 Parent(s): 68b47d7

Update main.py

Files changed (1)
  1. main.py +11 -18
main.py CHANGED
@@ -1,13 +1,12 @@
 from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
+from fastapi.responses import StreamingResponse
 import uvicorn
-import asyncio
 
 app = FastAPI()
 
-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
+client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
 class Item(BaseModel):
     prompt: str
@@ -26,8 +25,10 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-async def generate(item: Item):
-    temperature = max(float(item.temperature), 1e-2)  # Ensure temperature is not too low
+async def generate_stream(item: Item):
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
     top_p = float(item.top_p)
 
     generate_kwargs = dict(
@@ -40,19 +41,11 @@ async def generate(item: Item):
     )
 
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-
-    # Stream the response from the model
-    async def event_stream():
-        stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-
-        async for response in stream:
-            yield response.token.text  # Yield each token as it is received
-
-            # Optional: Add a small delay to simulate streaming effect (if needed)
-            await asyncio.sleep(0.1)
-
-    return event_stream()
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
+    for response in stream:
+        yield response.token.text  # Stream each token as it's received
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return StreamingResponse(generate(item), media_type="text/event-stream")
+    return StreamingResponse(generate_stream(item), media_type="text/plain")
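
After this change, generate_stream yields raw token text and StreamingResponse forwards each yielded chunk as it is produced, so clients can read the completion incrementally instead of waiting for the full response. A minimal client-side sketch of consuming the endpoint follows, assuming the app is served locally on port 8000; only the prompt field is confirmed by the visible diff, and the other payload fields (history, system_prompt, temperature, top_p, max_new_tokens) are assumptions inferred from the surrounding hunk context, so they may differ from the actual Item model.

import requests

# Hypothetical payload for POST /generate/; field names other than "prompt"
# are assumptions based on the diff context, not confirmed by this commit.
payload = {
    "prompt": "Explain token streaming in one sentence.",
    "history": [],
    "system_prompt": "You are a concise assistant.",
    "temperature": 0.7,
    "top_p": 0.95,
    "max_new_tokens": 128,
}

# stream=True keeps the connection open so chunks arrive as the server yields them
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
print()

Because the endpoint now returns media_type="text/plain" rather than "text/event-stream", the body is a plain chunked text stream, so no server-sent-events parsing is needed on the client side.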