File size: 1,258 Bytes
3008b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import requests
from typing import Optional

# 🔐 Endpoint and credentials
# The endpoint URL is a fixed literal; only the access token is read from the
# environment (HF_TOKEN is None when the variable is not set).
HF_TOKEN = os.getenv("HF_TOKEN")
LLAMA3_URL = "https://c5dk65n3sd14gjo1.us-east-1.aws.endpoints.huggingface.cloud"

# 📜 Headers
# Built once at import time, so the Authorization value reflects HF_TOKEN as it
# was when this module was loaded.
HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}

# 🧠 Prompt builder and caller for LLaMA 3 8B (QCA)
def call_llama3_8b(base_prompt: str, tail_prompt: str) -> Optional[str]:
    """Query the LLaMA 3 endpoint with a combined prompt and return its reply.

    The two prompt parts are joined with a blank line, wrapped as
    ``<s>[INST]...[/INST]</s>``, and POSTed as ``{"inputs": prompt}`` to
    ``LLAMA3_URL`` using the module-level ``HEADERS`` and a 60-second timeout.

    NOTE(review): the ``[INST]`` wrapper looks like the Llama 2 / Mistral
    instruct template rather than the Llama 3 chat template; confirm this is
    what the deployed endpoint expects.

    Args:
        base_prompt: Leading part of the instruction.
        tail_prompt: Trailing part of the instruction.

    Returns:
        The ``generated_text`` field of the JSON response (taken from the
        first item when the response is a list), stripped of surrounding
        whitespace. If the text contains ``[/INST]</s>``, only what follows
        its last occurrence is kept. Returns ``None`` when the response JSON
        is neither a non-empty list nor a dict, or when any exception is
        raised; in the latter case a failure message is printed.
    """
    prompt = f"<s>[INST]{base_prompt}\n\n{tail_prompt}[/INST]</s>"
    end_of_prompt = "[/INST]</s>"

    try:
        reply = requests.post(
            LLAMA3_URL,
            headers=HEADERS,
            json={"inputs": prompt},
            timeout=60,
        )
        reply.raise_for_status()
        payload = reply.json()

        # Pick the object that carries "generated_text": the dict itself, or
        # the first element of a non-empty list. Anything else is unusable.
        if isinstance(payload, dict):
            record = payload
        elif isinstance(payload, list) and payload:
            record = payload[0]
        else:
            return None
        generated = record.get("generated_text", "")

        # Keep only the text after the last end-of-prompt marker, if present
        # (presumably the endpoint may echo the prompt back; verify).
        if end_of_prompt in generated:
            generated = generated.split(end_of_prompt)[-1]
        return generated.strip()

    except Exception as exc:
        # Best-effort call: any failure (network, HTTP status, JSON decoding,
        # unexpected payload shape) is reported on stdout and mapped to None.
        print(f"⚠️ LLaMA 3.1 8B API call failed: {exc}")
        return None