File size: 2,686 Bytes
4d19455
c92f73a
 
5215f12
4d19455
5215f12
 
4d19455
c56a1eb
 
4d19455
c92f73a
 
 
4d19455
 
 
c92f73a
 
 
5215f12
c92f73a
 
 
5215f12
c92f73a
5215f12
c92f73a
 
 
 
 
 
ac31724
c92f73a
 
 
 
 
ac31724
4d19455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c92f73a
 
4d19455
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from peft import AutoPeftModelForCausalLM, PeftConfig
from transformers import AutoTokenizer, AutoConfig
from huggingface_hub import login, snapshot_download
import torch
import os
import json

# Authenticate with the Hub using the Space secret (never hard-code tokens).
# The token is read once and reused for every authenticated call below; a
# missing secret fails fast here with a KeyError naming HF_TOKEN.
token = os.environ["HF_TOKEN"]
login(token)

# Model setup (runs once, at import / Space startup).
model_id = "agarkovv/CryptoTrader-LM"  # PEFT adapter repo
base_model_id = "mistralai/Ministral-8B-Instruct-2410"  # base model repo; tokenizer source
MAX_LENGTH = 32768  # max prompt tokens and max new tokens per generation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # GPU if one is visible at startup

# Download the adapter and drop the `model_type` key from adapter_config.json.
# NOTE(review): presumably the installed PEFT version rejects this key when it
# builds the adapter config — confirm, and remove this patch once unneeded.
adapter_local_dir = snapshot_download(repo_id=model_id)
config_path = os.path.join(adapter_local_dir, "adapter_config.json")
with open(config_path, "r", encoding="utf-8") as f:
    adapter_config = json.load(f)
if "model_type" in adapter_config:
    del adapter_config["model_type"]
    # Files in a snapshot dir are usually symlinks into the shared,
    # content-addressed HF cache; writing through the link would rewrite the
    # cached blob itself. Replace the link with a regular file instead, and
    # only touch the file when its content actually changes.
    if os.path.islink(config_path):
        os.remove(config_path)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(adapter_config, f)

# Fetch only the base model's config.json (kept available for inspection).
# It is deliberately NOT passed to AutoPeftModelForCausalLM.from_pretrained:
# that method's `config` parameter expects a PeftConfig, and PEFT raises a
# ValueError for any other config type. This download does not avoid gated
# access for the base weights — the adapter loader presumably still resolves
# them from the base repo, which is why the token is passed below.
base_local_dir = snapshot_download(repo_id=base_model_id, allow_patterns="config.json")
base_config_path = os.path.join(base_local_dir, "config.json")
base_config = AutoConfig.from_pretrained(base_config_path)

# Load base model + adapter, and the base model's tokenizer.
# NOTE(review): no torch_dtype is given, so weights load in the transformers
# default dtype — confirm the 8B model fits in the Space's memory.
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_local_dir,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=token)
model = model.to(DEVICE)
model.eval()  # inference behavior for dropout and similar layers

def predict_trading_decision(prompt: str) -> str:
    """Predict daily trading decision (buy, sell, or hold) for BTC or ETH based on news and historical prices.

    Args:
        prompt: Input prompt containing cryptocurrency news and historical price data. It is wrapped in the model's [INST]...[/INST] instruction tags automatically; a prompt that is already wrapped in those tags is used as-is.

    Returns:
        Generated trading decision as text (e.g., 'Buy BTC at $62k'), without the input prompt echoed back.
    """
    if prompt is None:
        prompt = ""

    # Instruction format expected by the model. Do not double-wrap when the
    # caller already supplied the [INST]...[/INST] tags.
    stripped = prompt.strip()
    if stripped.startswith("[INST]") and stripped.endswith("[/INST]"):
        formatted_prompt = stripped
    else:
        formatted_prompt = f"[INST]{prompt}[/INST]"

    # NOTE(review): truncation keeps at most MAX_LENGTH tokens; if the
    # tokenizer truncates on the right (presumably the default), an over-long
    # prompt loses its closing [/INST] tag — confirm.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", padding=False, max_length=MAX_LENGTH, truncation=True
    )
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    res = model.generate(
        **inputs,
        use_cache=True,
        max_new_tokens=MAX_LENGTH,
    )
    # For a causal LM, generate() returns the prompt tokens followed by the
    # new tokens; decode only the continuation so the tool returns the
    # decision rather than an echo of the (possibly very long) input.
    output = tokenizer.decode(res[0][prompt_length:], skip_special_tokens=True)
    return output.strip()

# Web UI around predict_trading_decision: a single text box in, a single text box out.
_prompt_box = gr.Textbox(label="Input Prompt (News + Prices)")
_decision_box = gr.Textbox(label="Trading Decision")

demo = gr.Interface(
    fn=predict_trading_decision,
    inputs=_prompt_box,
    outputs=_decision_box,
    title="CryptoTrader-LM MCP Tool",
    description="Predict buy/sell/hold for BTC/ETH.",
)

# Start the app with Gradio's MCP server enabled.
demo.launch(mcp_server=True)