import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import os # --- 모델 로드 --- # 모델 경로 설정 (Hugging Face 모델 ID) model_id = "microsoft/bitnet-b1.58-2B-4T" # 모델 로드 시 경고 메시지를 최소화하기 위해 로깅 레벨 설정 os.environ["TRANSFORMERS_VERBOSITY"] = "error" # AutoModelForCausalLM과 AutoTokenizer를 로드합니다. # BitNet 모델은 trust_remote_code=True가 필요합니다. # bf16은 메모리 사용량을 줄이고 속도를 향상시킬 수 있습니다 (GPU 지원 시). # CPU만 사용하는 경우 torch_dtype을 생략하거나 torch.float32로 설정할 수 있습니다. try: print(f"모델 로딩 중: {model_id}...") tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) # GPU가 사용 가능하면 bf16 사용 if torch.cuda.is_available(): model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, trust_remote_code=True ).to("cuda") # GPU로 모델 이동 print("GPU를 사용하여 모델 로드 완료.") else: model = AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True ) print("CPU를 사용하여 모델 로드 완료. 성능이 느릴 수 있습니다.") except Exception as e: print(f"모델 로드 중 오류 발생: {e}") tokenizer = None model = None print("모델 로드에 실패했습니다. 애플리케이션이 제대로 동작하지 않을 수 있습니다.") # --- 텍스트 생성 함수 --- def generate_text(prompt, max_length=100, temperature=0.7): if model is None or tokenizer is None: return "모델 로드에 실패하여 텍스트 생성을 할 수 없습니다." try: # 프롬프트 토큰화 inputs = tokenizer(prompt, return_tensors="pt") # GPU 사용 가능 시 GPU로 입력 이동 if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} # 텍스트 생성 # LLaMA 3 토크나이저를 사용하므로 chat template 적용 가능 (선택 사항) # 메시지 형식을 사용하지 않고 직접 프롬프트 입력 시 아래 코드 사용 outputs = model.generate( **inputs, max_new_tokens=max_length, temperature=temperature, do_sample=True, # 샘플링 활성화 pad_token_id=tokenizer.eos_token_id # 패딩 토큰 ID 설정 (필요시) ) # 생성된 텍스트 디코딩 # 입력 프롬프트 부분을 제외하고 생성된 부분만 디코딩 generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True) return generated_text except Exception as e: return f"텍스트 생성 중 오류 발생: {e}" # --- Gradio 인터페이스 설정 --- if model is not None and tokenizer is not None: interface = gr.Interface( fn=generate_text, inputs=[ gr.Textbox(lines=2, placeholder="텍스트를 입력하세요...", label="입력 프롬프트"), gr.Slider(minimum=10, maximum=500, value=100, label="최대 생성 길이"), gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (창의성)") ], outputs=gr.Textbox(label="생성된 텍스트"), title="BitNet b1.58-2B-4T 텍스트 생성 데모", description="BitNet b1.58-2B-4T 모델을 사용하여 텍스트를 생성합니다." ) # Gradio 앱 실행 # share=True를 하면 임시 공개 링크가 생성됩니다. interface.launch(share=False) else: print("모델 로드 실패로 인해 Gradio 인터페이스를 실행할 수 없습니다.")