mike23415 commited on
Commit
1f2df23
·
verified ·
1 Parent(s): 6e146d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
4
+ import torch
5
+ import os
6
+ import json
7
+
8
+ app = Flask(__name__)
9
+ CORS(app) # Enable CORS for all routes
10
+
11
+ # Set Hugging Face cache to ephemeral storage
12
+ os.environ["HF_HOME"] = "/data/.huggingface"
13
+
14
+ # Load Qwen2.5-1.5B model and tokenizer
15
+ model_name = "Qwen/Qwen2.5-1.5B-Instruct"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
18
+
19
+ # Move to GPU if available
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model.to(device)
22
+
23
+ # Data file for preloaded and dynamic data
24
+ data_file = "data/train_data.json"
25
+
26
+ # Load or initialize dataset
27
+ if os.path.exists(data_file):
28
+ with open(data_file, 'r') as f:
29
+ train_texts = json.load(f)
30
+ else:
31
+ train_texts = []
32
+ os.makedirs(os.path.dirname(data_file), exist_ok=True)
33
+ with open(data_file, 'w') as f:
34
+ json.dump(train_texts, f)
35
+ print(f"Loaded {len(train_texts)} examples from {data_file}")
36
+
37
+ # Model save directory
38
+ model_save_dir = "./results/model"
39
+
40
+ @app.route('/adapt', methods=['POST'])
41
+ def adapt_model():
42
+ try:
43
+ data = request.json
44
+ user_input = data.get('text', '')
45
+
46
+ if not user_input:
47
+ return jsonify({'error': 'No input provided'}), 400
48
+
49
+ # Generate self-edit
50
+ prompt = f"Rephrase this: {user_input}"
51
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)
52
+ self_edit_output = model.generate(**inputs, max_length=150, num_return_sequences=1)
53
+ self_edit = tokenizer.decode(self_edit_output[0], skip_special_tokens=True)
54
+
55
+ # Add to training data and save to disk
56
+ train_texts.append({"prompt": user_input, "completion": self_edit})
57
+ with open(data_file, 'w') as f:
58
+ json.dump(train_texts, f, indent=2)
59
+
60
+ # Prepare dataset for fine-tuning
61
+ encodings = tokenizer(
62
+ [t["prompt"] + " " + t["completion"] for t in train_texts],
63
+ truncation=True,
64
+ padding=True,
65
+ max_length=256,
66
+ return_tensors="pt"
67
+ )
68
+ dataset = [
69
+ {
70
+ "input_ids": encodings["input_ids"][i],
71
+ "attention_mask": encodings["attention_mask"][i],
72
+ "labels": encodings["input_ids"][i]
73
+ } for i in range(len(train_texts))
74
+ ]
75
+
76
+ # Fine-tune model
77
+ training_args = TrainingArguments(
78
+ output_dir=model_save_dir,
79
+ num_train_epochs=1,
80
+ per_device_train_batch_size=2,
81
+ gradient_accumulation_steps=4,
82
+ logging_steps=10,
83
+ save_steps=10,
84
+ save_total_limit=1, # Keep only latest checkpoint
85
+ disable_tqdm=True,
86
+ fp16=True if torch.cuda.is_available() else False
87
+ )
88
+ trainer = Trainer(
89
+ model=model,
90
+ args=training_args,
91
+ train_dataset=dataset
92
+ )
93
+ trainer.train()
94
+
95
+ # Save model weights
96
+ trainer.save_model(model_save_dir)
97
+ tokenizer.save_pretrained(model_save_dir)
98
+
99
+ # Generate response
100
+ response_inputs = tokenizer(user_input, return_tensors="pt", truncation=True, max_length=128).to(device)
101
+ response_output = model.generate(**response_inputs, max_length=200, num_return_sequences=1)
102
+ response = tokenizer.decode(response_output[0], skip_special_tokens=True)
103
+
104
+ return jsonify({
105
+ 'input': user_input,
106
+ 'self_edit': self_edit,
107
+ 'response': response
108
+ })
109
+
110
+ except Exception as e:
111
+ return jsonify({'error': str(e)}), 500
112
+
113
+ if __name__ == '__main__':
114
+ app.run(host='0.0.0.0', port=7860)