abdullahalioo committed on
Commit d7f32ed · verified · 1 Parent(s): 3e272d7

Update main.py

Files changed (1)
  1. main.py +33 -67
main.py CHANGED
@@ -2,34 +2,35 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
-import os
+from hugchat import hugchat
+from hugchat.login import Login
 import asyncio
+import os
+from dotenv import load_dotenv
 
-# Set cache directories
-cache_dir = "/tmp/hf_home"
-os.environ["HF_HOME"] = cache_dir
-os.environ["TRANSFORMERS_CACHE"] = cache_dir
-os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
+# Load environment variables from .env file
+load_dotenv()
 
-os.makedirs(cache_dir, exist_ok=True)
-os.chmod(cache_dir, 0o777)
+# Read credentials from environment variables
+EMAIL = os.getenv("EMAIL")
+PASSWD = os.getenv("PASSWD")
 
-# Load model and tokenizer
-model_name = "EleutherAI/gpt-neo-1.3B"
-tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+# Cookie storage
+cookie_path_dir = "./cookies/"
+os.makedirs(cookie_path_dir, exist_ok=True)
 
-# Set pad token if not defined
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
+# HugChat login
+sign = Login(EMAIL, PASSWD)
+cookies = sign.login(cookie_dir_path=cookie_path_dir, save_cookies=True)
 
-# Set device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
+# Create chatbot instance
+chatbot = hugchat.ChatBot(cookies=cookies.get_dict())
+
+# Optional: Use assistant ID
+ASSISTANT_ID = "66017fca58d60bd7d5c5c26c" # Replace if needed
+chatbot.new_conversation(assistant=ASSISTANT_ID, switch_to=True)
 
-# Initialize FastAPI
+# FastAPI setup
 app = FastAPI()
 
 # Enable CORS
@@ -41,57 +42,22 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Request model
 class Question(BaseModel):
     question: str
 
-SYSTEM_PROMPT = "You are a helpful, professional, and highly persuasive sales assistant..."
-
-chat_history_ids = None
-
-async def generate_response_chunks(prompt: str):
-    global chat_history_ids
-
-    # Combine system prompt and user input
-    input_text = SYSTEM_PROMPT + "\nUser: " + prompt + "\nBot:"
-    new_input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
-
-    # Create attention mask (handle case where pad_token_id might be None)
-    attention_mask = torch.ones_like(new_input_ids)
-
-    if chat_history_ids is not None:
-        input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
-        attention_mask = torch.cat([
-            torch.ones_like(chat_history_ids),
-            attention_mask
-        ], dim=-1)
-    else:
-        input_ids = new_input_ids
-
-    # Generate response
-    output_ids = model.generate(
-        input_ids,
-        attention_mask=attention_mask,
-        max_new_tokens=200,
-        do_sample=True,
-        top_p=0.9,
-        temperature=0.7,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    # Update chat history
-    chat_history_ids = output_ids
-
-    # Decode only the new tokens
-    response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
-
-    # Stream the response
-    for word in response.split():
-        yield word + " "
-        await asyncio.sleep(0.03)
+# Token stream function
+async def generate_response_stream(prompt: str):
+    for chunk in chatbot.chat(prompt, stream=True):
+        token = chunk.get("token", "")
+        if token:
+            yield token
+            await asyncio.sleep(0.02)
 
+# Endpoint
 @app.post("/ask")
 async def ask(question: Question):
     return StreamingResponse(
-        generate_response_chunks(question.question),
+        generate_response_stream(question.question),
         media_type="text/plain"
-    )
+    )
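
The new version reads EMAIL and PASSWD through python-dotenv rather than hardcoding them. For local testing, a minimal .env file next to main.py would look like the sketch below (placeholder values, not real credentials):

    EMAIL=you@example.com
    PASSWD=your-hf-password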
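
One caveat: hugchat's chatbot.chat(prompt, stream=True) is a synchronous generator, so each chunk fetched inside the async generate_response_stream blocks the event loop while the request is in flight. A minimal non-blocking sketch, relying on the fact that Starlette's StreamingResponse iterates plain (sync) generators in a threadpool; the _sync name is illustrative, not part of this commit:

    # Sketch only: same hugchat streaming calls as in the commit, but
    # written as a sync generator. StreamingResponse detects a non-async
    # iterable and drives it in a threadpool, keeping the event loop free.
    def generate_response_stream_sync(prompt: str):
        for chunk in chatbot.chat(prompt, stream=True):
            token = chunk.get("token", "")
            if token:
                yield token

The /ask endpoint would pass this generator to StreamingResponse exactly as it does the async one.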
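
To exercise the endpoint end to end, a small client sketch using requests; the localhost URL and port are assumptions about how the app is served, and the payload shape follows the Question model above:

    import requests

    # POST a question and print tokens as they stream back.
    resp = requests.post(
        "http://localhost:8000/ask",
        json={"question": "What can you do?"},
        stream=True,
    )
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="", flush=True)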