import os
from typing import Optional

import requests

# Hugging Face Inference Endpoint serving LLaMA 3.1 8B.
LLAMA3_URL = "https://c5dk65n3sd14gjo1.us-east-1.aws.endpoints.huggingface.cloud"

# Access token for the endpoint, read from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}

def call_llama3_8b(base_prompt: str, tail_prompt: str) -> Optional[str]:
    """Send the combined prompt to the LLaMA 3.1 8B endpoint and return its reply.

    Returns None if the request fails or the response has an unexpected shape.
    """
    # Wrap both prompt pieces in a single [INST] ... [/INST] instruct block.
    prompt = f"<s>[INST]{base_prompt}\n\n{tail_prompt}[/INST]</s>"

    try:
        response = requests.post(
            LLAMA3_URL,
            headers=HEADERS,
            json={"inputs": prompt},
            timeout=60,
        )
        response.raise_for_status()
        data = response.json()

        # The endpoint may return either a list of generations or a single dict.
        if isinstance(data, list) and data:
            raw_output = data[0].get("generated_text", "")
        elif isinstance(data, dict):
            raw_output = data.get("generated_text", "")
        else:
            return None

        # If the prompt is echoed back in the output, keep only the completion.
        if "[/INST]</s>" in raw_output:
            return raw_output.split("[/INST]</s>")[-1].strip()
        return raw_output.strip()

    except Exception as e:
        print(f"⚠️ LLaMA 3.1 8B API call failed: {e}")
        return None
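

# Example usage sketch: assumes HF_TOKEN is exported and the endpoint above is
# reachable; the prompt strings below are illustrative placeholders only.
if __name__ == "__main__":
    answer = call_llama3_8b(
        "You are a concise technical assistant.",
        "Reply with a one-sentence greeting.",
    )
    print(answer if answer is not None else "No response from the endpoint.")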