# Source captured from a Hugging Face Space (Space status: Paused).
| import yaml | |
| import sys | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| from .api import LLMApi | |
| from .routes import router, init_router | |
| from utils.logging import setup_logger | |
| from huggingface_hub import login | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| import os | |
def validate_hf(config=None):
    """Validate Hugging Face authentication.

    Loads environment variables from ``.env`` in the current working
    directory when that file exists, then logs in to the Hugging Face Hub
    with the ``HF_TOKEN`` environment variable.

    Args:
        config: Parsed configuration passed to ``setup_logger``. When omitted,
            it is read with ``load_config()``.

    Returns:
        bool: True if ``login`` succeeded; False if ``HF_TOKEN`` is unset or
        empty, or if ``login`` raised. Failures are logged, never raised.
    """
    # Do not depend on a module-level ``config`` global: it is only assigned
    # when this file runs as a script, so any other caller got a NameError.
    if config is None:
        config = load_config()
    logger = setup_logger(config, "hf_validation")

    # Check for a .env file in the current working directory.
    env_path = Path(".env")
    if env_path.exists():
        logger.info("Found .env file, loading environment variables")
        # Load exactly the file whose existence was just checked. A bare
        # load_dotenv() searches upward from the calling module's directory,
        # which can be a different file than ./.env.
        load_dotenv(dotenv_path=env_path)
    else:
        logger.warning("No .env file found. Fine if you're on Huggingface, but you need one to run locally on your PC.")

    # Check for the HF token (from .env or the process environment).
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.error("No HF_TOKEN found in environment variables")
        return False

    try:
        login(token=hf_token)
    except Exception as e:
        # Boundary catch: report the failure to the caller as False instead of
        # crashing startup.
        logger.error(f"Failed to authenticate with Hugging Face: {e}")
        return False
    logger.info("Successfully authenticated with Hugging Face")
    return True
def load_config(path="main/config.yaml"):
    """Load the YAML configuration file and return its parsed contents.

    Args:
        path: Location of the YAML file. Defaults to ``"main/config.yaml"``,
            resolved relative to the current working directory (the previous
            hard-coded location).

    Returns:
        The object produced by ``yaml.safe_load`` for the file. Callers in
        this module index it as a nested mapping (e.g. ``config["api"]``).

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale
    # encoding. safe_load (not yaml.load) avoids constructing arbitrary
    # Python objects from the file.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def create_app():
    """Build and return the configured FastAPI application.

    Settings come from ``load_config()``:

    * ``api.version`` is the application version and part of the route prefix.
    * ``api.cors.origins`` and ``api.cors.credentials`` configure the CORS
      middleware. All methods and headers are allowed.
    * The shared ``router`` is initialized with the full config and mounted
      under ``"{api.prefix}/{api.version}"``.
    """
    settings = load_config()
    log = setup_logger(settings, "main")
    log.info("Starting LLM API server")

    api_cfg = settings["api"]
    application = FastAPI(
        title="LLM API",
        description="API for Large Language Model operations",
        version=api_cfg["version"],
    )

    cors_cfg = api_cfg["cors"]
    application.add_middleware(
        CORSMiddleware,
        allow_origins=cors_cfg["origins"],
        allow_credentials=cors_cfg["credentials"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Hand the settings to the shared router, then mount it under the
    # versioned prefix.
    init_router(settings)
    mount_prefix = f"{api_cfg['prefix']}/{api_cfg['version']}"
    application.include_router(router, prefix=mount_prefix)

    log.info("FastAPI application created successfully")
    return application
def test_locally():
    """Run an end-to-end smoke test of ``LLMApi`` for development/debugging.

    Uses the model named in ``config["model"]["defaults"]["model_name"]``:

    1. Download and initialize the model.
    2. Generate an embedding and log its length and first five values.
    3. Run a blocking generation and log the response.
    4. Run a streaming generation and print chunks to stdout as they arrive.
    """
    cfg = load_config()
    log = setup_logger(cfg, "test")
    log.info("Starting local tests")

    llm = LLMApi(cfg)
    name = cfg["model"]["defaults"]["model_name"]
    log.info(f"Testing with model: {name}")

    # Fetch, then load, the model.
    log.info("Testing model download...")
    llm.download_model(name)
    log.info("Download complete")

    log.info("Initializing model...")
    llm.initialize_model(name)
    log.info("Model initialized")

    # Embedding check. The sample sentence is Norwegian: "This is a test of
    # embeddings generation from a technical text about HSE (health, safety
    # and environment) routines in the workplace."
    sample = "Dette er en test av embeddings generering fra en teknisk tekst om HMS rutiner på arbeidsplassen."
    log.info("Testing embedding generation...")
    vector = llm.generate_embedding(sample)
    log.info(f"Generated embedding of length: {len(vector)}")
    log.info(f"First few values: {vector[:5]}")

    system_prompt = "You are a helpful assistant."
    prompts = ("Tell me what happens in a nuclear reactor.",)

    # Blocking generation for every prompt.
    log.info("Testing regular generation:")
    for text in prompts:
        log.info(f"Prompt: {text}")
        answer = llm.generate_response(prompt=text, system_message=system_prompt)
        log.info(f"Response: {answer}")

    # Streaming generation for the first prompt, echoed without buffering.
    first = prompts[0]
    log.info("Testing streaming generation:")
    log.info(f"Prompt: {first}")
    stream = llm.generate_stream(prompt=first, system_message=system_prompt)
    for piece in stream:
        print(piece, end="", flush=True)
    print("\n")

    log.info("Local tests completed")
# ASGI application object. The uvicorn launch below refers to it by the
# import string "main.app:app".
app = create_app()

if __name__ == "__main__":
    # Bound as a module-level name so that a validate_hf() that reads a global
    # ``config`` still finds it.
    config = load_config()
    validate_hf()

    if sys.argv[1:2] != ["test"]:
        server_cfg = config["server"]
        # Auto-reload, single worker, trace-level logs, no access log.
        uvicorn.run(
            "main.app:app",
            host=server_cfg["host"],
            port=server_cfg["port"],
            log_level="trace",
            reload=True,
            workers=1,
            access_log=False,
            use_colors=True,
        )
    else:
        # First CLI argument "test": run the local smoke tests instead of serving.
        test_locally()