from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
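
# Requires: flask, flask-cors, torch, and transformers; device_map="auto"
# below additionally needs the accelerate package installed.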

app = Flask(__name__)
CORS(app)
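
# Registry of Hugging Face repo IDs keyed by short internal names.
# Caveat: several entries (TTS, ASR, image, video, and music checkpoints)
# are not causal language models and will fail in AutoModelForCausalLM
# below, and some IDs appear to be community uploads rather than official
# releases; treat this table as a lookup registry only.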
MODEL_PATHS = {
    "gpt_oss_120b": "openai/gpt-oss-120b",
    "deepseek_v3": "deepseek-ai/DeepSeek-V3.1-Base",
    "gemini_25_pro": "afu4642tD/gemini-2.5-pro",
    "veo3": "sudip1987/Generate_videos_with_Veo3",
    "open_sora": "hpcai-tech/Open-Sora-v2",
    "usp_image": "GD-ML/USP-Image_Generation",
    "text_to_music": "sander-wood/text-to-music",
    "qwen_image": "Qwen/Qwen-Image",
    "qwen_image_diff": "Comfy-Org/Qwen-Image-DiffSynth-ControlNets",
    "coqui_tts": "sk0032/coqui-tts-model",
    "edge_tts": "sysf/Edge-TTS",
    "whisper_large": "openai/whisper-large-v3-turbo",
    "blip2_opt": "Salesforce/blip2-opt-2.7b",
    "mini_gpt4": "Vision-CAIR/MiniGPT-4",
    "glm_45": "zai-org/GLM-4.5",
    "chatglm3": "zai-org/chatglm3-6b",
    "gpt_oss_20b": "openai/gpt-oss-20b",
    "m2m100": "facebook/m2m100_1.2B",
    "tiny_marian": "onnx-community/tiny-random-MarianMTModel",
    "memory_transformer": "Grpp/memory-transformer-ru",
    "rl_memory_agent": "BytedTsinghua-SIA/RL-MemoryAgent-14B",
    "m3_agent": "ByteDance-Seed/M3-Agent-Memorization",
    "text_to_video": "ali-vilab/text-to-video-ms-1.7b",
}
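

# Loading a checkpoint per request keeps peak memory low at the cost of
# paying the full load time on every call; a longer-lived server would
# normally cache loaded models instead.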
def generate_single_answer(prompt, model_key):
    """Load a model, generate an answer for the prompt, then free memory."""
    model_name = MODEL_PATHS[model_key]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=200)
    answer = tokenizer.decode(output[0], skip_special_tokens=True)

    # Release the checkpoint so the next request starts from a clean slate.
    del model, tokenizer, inputs, output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return answer
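

# POST /ask expects a JSON body like:
#   {"prompt": "...", "models": ["gpt_oss_20b", ...]}
# Unknown model keys are skipped; all answers are joined with " | ".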
@app.route("/ask", methods=["POST"])
def ask():
    # get_json(silent=True) avoids a 500 when the body is missing or not JSON.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "prompt is required"}), 400
    selected_models = data.get("models", ["gpt_oss_120b", "deepseek_v3", "gemini_25_pro"])

    answers = []
    for model_key in selected_models:
        if model_key in MODEL_PATHS:
            try:
                ans = generate_single_answer(prompt, model_key)
                answers.append(ans)
            except Exception as e:
                # Record the failure but keep going, so one bad model does
                # not fail the whole request.
                answers.append(f"[Error loading {model_key}: {e}]")

    final_answer = " | ".join(answers)
    return jsonify({"answer": final_answer})

if __name__ == "__main__":
    # debug=True is convenient locally but must not be exposed publicly:
    # the Werkzeug debugger allows arbitrary code execution.
    app.run(host="0.0.0.0", port=5000, debug=True)
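
# Example request against a locally running instance:
#   curl -X POST http://localhost:5000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "models": ["gpt_oss_20b"]}'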