"""Fine-tune ibm-granite/granite-3.3-8b-instruct with LoRA and publish artifacts.

Flat script intended to run inside an HF Space / container:
  1. read the HF token from the ``granite`` env var (format ``HF_TOKEN=<token>``),
  2. load tokenizer + fp16 model and wrap it with a LoRA adapter,
  3. fine-tune on a private prompt/completion dataset via TRL's SFTTrainer,
  4. save the model, tar it, and upload both to private Hub repositories.

Raises:
    ValueError: if the ``granite`` env var is missing or malformed.
"""

import logging
import os
import tarfile

import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Debug environment variables; token-like values are masked so secrets never
# reach the logs.
logger.info(
    "Environment variables: %s",
    {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()},
)

model_path = "ibm-granite/granite-3.3-8b-instruct"
dataset_path = "mycholpath/ascii-json"
output_dir = "/app/granite-8b-finetuned-ascii"
output_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_repo = "mycholpath/granite-8b-finetuned-ascii"
artifact_repo = "mycholpath/granite-finetuned-artifacts"

# Get HF token from the `granite` environment variable (expected: HF_TOKEN=<token>).
granite_var = os.getenv("granite")
if not granite_var or not granite_var.startswith("HF_TOKEN="):
    logger.error("granite environment variable is not set or invalid. Expected format: HF_TOKEN=.")
    raise ValueError(
        "granite environment variable is not set or invalid. "
        "Please set it in HF Space settings."
    )
# Strip only the leading prefix; str.replace() would also corrupt a token that
# happened to contain the substring "HF_TOKEN=" anywhere inside it.
hf_token = granite_var[len("HF_TOKEN="):]
logger.info("HF_TOKEN extracted from granite (value hidden for security)")

logger.info("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        token=hf_token,
        cache_dir="/tmp/hf_cache",
        trust_remote_code=True,
    )
    # No dedicated pad token on this model: reuse EOS, pad on the right for
    # causal-LM training.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'
except Exception as e:
    logger.error(f"Failed to load tokenizer: {str(e)}")
    raise

logger.info("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        token=hf_token,
        torch_dtype=torch.float16,
        device_map="auto",
        cache_dir="/tmp/hf_cache",
        trust_remote_code=True,
    )
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    raise

# LoRA: adapt only the attention query/value projections; keeps trainable
# parameter count small.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

logger.info("Preparing to load private dataset...")
logger.info("Using HF_TOKEN from granite for private dataset authentication")
try:
    dataset = load_dataset(dataset_path, split="train", token=hf_token)
    logger.info(f"Dataset loaded successfully: {len(dataset)} examples")
except Exception as e:
    logger.error(f"Failed to load dataset: {str(e)}")
    raise


def formatting_prompts_func(example):
    """Render one dataset row as its training text: prompt, newline, completion.

    Returns a single-element list because TRL's formatting_func contract
    expects a list of strings. NOTE(review): assumes the dataset exposes
    'prompt' and 'completion' columns — confirm against the dataset schema.
    """
    formatted = f"{example['prompt']}\n{example['completion']}"
    return [formatted]


# Use SFTConfig for training arguments (replaces TrainingArguments for SFTTrainer).
sft_config = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    eval_strategy="no",
    save_steps=50,
    logging_steps=10,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    max_seq_length=768,
    dataset_text_field=None,  # text is produced by formatting_prompts_func instead
    packing=False,
)

logger.info("Starting training...")
try:
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        eval_dataset=None,
        formatting_func=formatting_prompts_func,
        args=sft_config,
    )
except Exception as e:
    logger.error(f"Failed to initialize SFTTrainer: {str(e)}")
    raise
trainer.train()

logger.info("Saving fine-tuned model...")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

# Create tarball for local retrieval.
try:
    with tarfile.open(output_tarball, "w:gz") as tar:
        tar.add(output_dir, arcname=os.path.basename(output_dir))
    logger.info(f"Model tarball created: {output_tarball}")
except Exception as e:
    logger.error(f"Failed to create model tarball: {str(e)}")
    raise

# Upload model to HF Hub. Best-effort: a failure here still lets the tarball
# upload below proceed, so the run's output is not lost.
try:
    api = HfApi()
    logger.info(f"Creating model repository: {model_repo}")
    api.create_repo(
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        private=True,
        exist_ok=True,
    )
    logger.info(f"Uploading model to {model_repo}")
    api.upload_folder(
        folder_path=output_dir,
        repo_id=model_repo,
        repo_type="model",
        token=hf_token,
        create_pr=False,
    )
    logger.info(f"Fine-tuned model uploaded to {model_repo}")
except Exception as e:
    logger.error(f"Failed to upload model to HF Hub: {str(e)}")
    logger.warning("Continuing to tarball upload despite model upload failure")

# Upload tarball to an HF Hub dataset repository; this one is fatal on failure.
try:
    api = HfApi()
    logger.info(f"Creating dataset repository: {artifact_repo}")
    api.create_repo(
        repo_id=artifact_repo,
        repo_type="dataset",
        token=hf_token,
        private=True,
        exist_ok=True,
    )
    logger.info(f"Uploading tarball to {artifact_repo}")
    api.upload_file(
        path_or_fileobj=output_tarball,
        path_in_repo="granite-8b-finetuned-ascii.tar.gz",
        repo_id=artifact_repo,
        repo_type="dataset",  # fixed: original was missing the comma here (syntax error)
        token=hf_token,
    )
    logger.info(f"Tarball uploaded to {artifact_repo}/granite-8b-finetuned-ascii.tar.gz")
except Exception as e:
    logger.error(f"Failed to upload tarball to HF Hub: {str(e)}")
    raise