|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import HTMLResponse |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI(title="Chatbot")

# One checkpoint identifier feeds both loaders, so the tokenizer and the model
# always come from the same source.
model_name = "microsoft/DialoGPT-medium"

# Loaded once at import time and shared by every call to the reply generator.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
def get_chatbot_response(user_input: str, max_length: int = 100) -> str:
    """Generate a single-turn reply to *user_input* with the loaded model.

    Args:
        user_input: The user's message. ``None``, empty, or whitespace-only
            input short-circuits with a canned prompt and never reaches the
            model.
        max_length: Upper bound on the number of tokens generated for the
            reply. The prompt's own tokens are not counted against it.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped.
    """
    if not user_input or not user_input.strip():
        return "Please say something!"

    # Encode the message followed by the tokenizer's end-of-sequence token.
    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    prompt_length = input_ids.shape[-1]

    chat_history_ids = model.generate(
        input_ids,
        # Bound only the newly generated tokens. Passing the budget as
        # ``max_length`` would count the prompt as well, so a long message
        # would leave no room for a reply.
        max_new_tokens=max_length,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    reply_ids = chat_history_ids[0, prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
|
|
|
|
|
# Single-page chat UI returned by the root route. The inline script posts the
# user's text as JSON to /chat and renders the "response" field of the answer.
HTML_CONTENT = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Chatbot</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            background-color: #f0f2f5;
            margin: 0;
            padding: 20px;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
        }
        .container {
            max-width: 800px;
            width: 100%;
            padding: 2rem;
            background: white;
            border-radius: 10px;
            box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        }
        h1 {
            color: #2c3e50;
            text-align: center;
            margin-bottom: 2rem;
        }
        .chat-area {
            max-height: 400px;
            overflow-y: auto;
            margin-bottom: 1.5rem;
            padding: 1rem;
            background: #f9f9f9;
            border: 2px solid #eee;
            border-radius: 5px;
        }
        .message {
            margin: 0.5rem 0;
            padding: 0.8rem;
            border-radius: 5px;
        }
        .user-message {
            background-color: #3498db;
            color: white;
            margin-left: 20%;
            text-align: right;
        }
        .bot-message {
            background-color: #ecf0f1;
            color: #2c3e50;
            margin-right: 20%;
        }
        .input-section {
            display: flex;
            gap: 1rem;
        }
        input[type="text"] {
            flex: 1;
            padding: 0.8rem;
            border: 2px solid #ddd;
            border-radius: 5px;
            font-size: 1rem;
        }
        button {
            padding: 0.8rem 1.5rem;
            background-color: #3498db;
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 1rem;
            transition: background-color 0.3s;
        }
        button:hover {
            background-color: #2980b9;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <div class="chat-area" id="chatArea"></div>
        <div class="input-section">
            <input type="text" id="userInput" placeholder="Type your message..." onkeypress="if(event.key === 'Enter') sendMessage();">
            <button onclick="sendMessage()">Send</button>
        </div>
    </div>
    <script>
        const chatArea = document.getElementById("chatArea");
        const userInput = document.getElementById("userInput");

        // Append a chat bubble and return it so the caller can remove it later.
        function addMessage(text, isUser = false) {
            const messageDiv = document.createElement("div");
            messageDiv.className = "message " + (isUser ? "user-message" : "bot-message");
            messageDiv.textContent = text;
            chatArea.appendChild(messageDiv);
            chatArea.scrollTop = chatArea.scrollHeight;
            return messageDiv;
        }

        async function sendMessage() {
            const text = userInput.value.trim();
            if (!text) return;

            addMessage(text, true);
            userInput.value = "";
            // Track this request's own placeholder. Removing whatever bubble
            // happens to be last would delete the wrong one when several
            // requests overlap.
            const placeholder = addMessage("Thinking...");

            try {
                const response = await fetch("/chat", {
                    method: "POST",
                    headers: { "Content-Type": "application/json" },
                    body: JSON.stringify({ message: text })
                });
                const data = await response.json();
                if (!response.ok) {
                    // detail is not always a string (validation errors
                    // return a list), so fall back to a generic message.
                    throw new Error(typeof data.detail === "string" ? data.detail : "Chat error");
                }

                placeholder.remove();
                addMessage(data.response);
            } catch (error) {
                placeholder.remove();
                addMessage(`Error: ${error.message}`);
            }
        }
    </script>
</body>
</html>
"""
|
|
|
# Root route: serves the embedded chat page. HTMLResponse makes the returned
# string go out as text/html rather than JSON.
@app.get("/", response_class=HTMLResponse)
async def read_root():
    # The page is a module-level constant, so nothing is rendered per request.
    return HTML_CONTENT
|
|
|
@app.post("/chat")
async def chat_endpoint(data: dict):
    """Handle a chat request and return the model's reply as JSON.

    Args:
        data: Parsed JSON request body; the user's text is read from its
            ``"message"`` key.

    Returns:
        A dict of the form ``{"response": <reply text>}``.

    Raises:
        HTTPException: 400 when ``"message"`` is missing, empty, or not a
            string; 500 when generating the reply fails.
    """
    # Imported locally so this handler brings its own helper into scope.
    from fastapi.concurrency import run_in_threadpool

    message = data.get("message", "")
    if not message:
        raise HTTPException(status_code=400, detail="No message provided")
    # A non-string payload (number, list, ...) is a client error; reject it
    # here instead of letting it fail deeper down and surface as a 500.
    if not isinstance(message, str):
        raise HTTPException(status_code=400, detail="Message must be a string")

    try:
        # Reply generation is synchronous; running it in a worker thread keeps
        # the event loop free to serve other requests in the meantime.
        response = await run_in_threadpool(get_chatbot_response, message)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}") from e
    return {"response": response}
|
|
|
# Script entry point: start the server only when this file is executed
# directly, not when it is imported by another module or an ASGI runner.
if __name__ == "__main__":
    # Imported here so importing this module does not require uvicorn.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)