# Source: CODEXspace / tinyllama_inference.py
# Last commit b72b033 (DEADLOCK007X): "Update JSON extraction in
# tinyllama_inference.py to select last JSON block"
import json
import re
import sys

from transformers import AutoTokenizer, AutoModelForCausalLM
# Module-level cache for the tokenizer/model pair. Both start empty and are
# populated by load_model() on first use, so importing this module does not
# load any weights.
tokenizer = None
model = None
def load_model():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use.

    The pair is created with ``from_pretrained`` only while either module-level
    global is still ``None``; afterwards the same objects are returned.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        # Already loaded: reuse the cached objects.
        return tokenizer, model
    # Use the DeepSeek instruct model for code evaluation.
    checkpoint = "deepseek-ai/deepseek-coder-1.3b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tokenizer, model
def _extract_evaluation(text):
    """Return the last JSON object in *text* that has "stars" and "feedback" keys.

    Scans every ``{`` and lets the real JSON decoder determine where the
    object ends. This correctly handles newlines, nested objects, and braces
    inside string values, which a non-greedy regex cannot. Returns ``None``
    when no qualifying object is found.
    """
    decoder = json.JSONDecoder()
    found = None
    start = text.find("{")
    while start != -1:
        next_from = start + 1
        try:
            candidate, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict) and "stars" in candidate and "feedback" in candidate:
            found = candidate  # keep scanning: the last valid object wins
            next_from = end  # skip past it so its own nested braces are ignored
        start = text.find("{", next_from)
    return found


def evaluate_code(question, code):
    """Ask the language model to grade *code* as a solution to *question*.

    Args:
        question: Problem statement embedded in the prompt.
        code: Candidate solution embedded in the prompt.

    Returns:
        The last JSON object in the model's *generated* text (the prompt is
        excluded) that contains both ``"stars"`` and ``"feedback"`` keys. The
        values are returned as the model emitted them and are not
        type-checked. If no such object is found, a fallback dict with
        ``"stars": 0`` and an explanatory ``"feedback"`` is returned instead.
    """
    prompt = f"""You are an expert code evaluator.
Evaluate the following solution for the given problem.
Respond with exactly one JSON object (with no extra text) that has exactly two keys:
"stars": an integer between 0 and 5 (0 means completely incorrect, 5 means excellent),
"feedback": a concise string message.
The JSON must start with '{{' and end with '}}'.
Do not output anything else.
Question: "{question}"
Solution: "{code}"
Your response:"""
    tokenizer, model = load_model()
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,  # Allow enough tokens for a complete response
        temperature=0.2,  # Low temperature: output stays close to greedy
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,  # temperature only takes effect when sampling is enabled
    )
    # generate() on a causal LM returns prompt tokens followed by new tokens.
    # Decode only the continuation. Otherwise braces in the prompt, or a JSON
    # object smuggled in via question/code, would be parsed as the verdict.
    prompt_length = inputs["input_ids"].shape[-1]
    response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # Debug output goes to stderr so stdout stays clean JSON for CLI callers.
    print("Raw model response:", response_text, file=sys.stderr)

    result = _extract_evaluation(response_text)
    if result is None:
        result = {"stars": 0, "feedback": "Evaluation failed. Unable to extract valid JSON from AI response."}
    return result
# Command-line entry point for manual testing:
#   python tinyllama_inference.py "<question>" "<code>"
# Prints the evaluation result (or a usage error) serialized with json.dumps.
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        # Missing arguments: report as JSON and exit with a non-zero status.
        print(json.dumps({"error": "Please provide a question and code as arguments"}))
        sys.exit(1)
    question, code = cli_args[0], cli_args[1]
    print(json.dumps(evaluate_code(question, code)))