import os
import logging
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from openai import AsyncOpenAI
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

# Define valid models (replace with actual models supported by your AI server)
VALID_MODELS = ["default-model", "another-model"]  # Update this list

class GenerateRequest(BaseModel):
    prompt: str
    publisher: Optional[str] = None  # Allow publisher in the body if needed

async def generate_ai_response(prompt: str, model: str, publisher: Optional[str]):
    logger.debug(f"Received prompt: {prompt}, model: {model}, publisher: {publisher}")
    
    # Configuration for AI endpoint
    token = os.getenv("GITHUB_TOKEN")
    endpoint = os.getenv("AI_SERVER_URL", "https://models.github.ai/inference")
    default_publisher = os.getenv("DEFAULT_PUBLISHER", "abdullahalioo")  # Fallback publisher
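    # Note (added caveat): this function is an async generator consumed by
    # StreamingResponse, so by the time it runs the 200 status has already been
    # sent. The HTTPExceptions raised below therefore abort the stream rather
    # than produce a 4xx/5xx response; stricter pre-flight validation would
    # belong in the endpoint handler itself.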
    
    if not token:
        logger.error("GitHub token not configured")
        raise HTTPException(status_code=500, detail="GitHub token not configured")

    # Use provided publisher or fallback to environment variable
    final_publisher = publisher or default_publisher
    if not final_publisher:
        logger.error("Publisher is required")
        raise HTTPException(status_code=400, detail="Publisher is required")

    # Validate model
    if model not in VALID_MODELS:
        logger.error(f"Invalid model: {model}. Valid models: {VALID_MODELS}")
        raise HTTPException(status_code=400, detail=f"Invalid model. Valid models: {VALID_MODELS}")

    logger.debug(f"Using endpoint: {endpoint}, publisher: {final_publisher}")
    client = AsyncOpenAI(base_url=endpoint, api_key=token)

    try:
        # Include publisher in the request payload (modify as needed based on AI server requirements)
        stream = await client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant named Orion, created by Abdullah Ali"},
                {"role": "user", "content": prompt}
            ],
            model=model,
            temperature=1.0,
            top_p=1.0,
            stream=True,
            extra_body={"publisher": final_publisher}  # Add publisher to extra_body
        )

        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    except Exception as err:
        # The response has already started streaming at this point, so raising
        # HTTPException here cannot change the status code; surface the failure
        # as a final chunk and end the stream instead.
        logger.error(f"AI generation failed: {str(err)}")
        yield f"Error: {str(err)}"

@app.post("/generate", summary="Generate AI response", response_description="Streaming AI response")
async def generate_response(
    model: str = Query("default-model", description="The AI model to use"),
    prompt: Optional[str] = Query(None, description="The input text prompt for the AI"),
    publisher: Optional[str] = Query(None, description="Publisher identifier (optional, defaults to DEFAULT_PUBLISHER env var)"),
    request: Optional[GenerateRequest] = None
):
    """
    Generate a streaming AI response based on the provided prompt, model, and publisher.
    
    - **model**: The AI model to use (e.g., default-model)
    - **prompt**: The input text prompt for the AI (query param or body)
    - **publisher**: The publisher identifier (optional, defaults to DEFAULT_PUBLISHER env var)
    """
    logger.debug(f"Request received - model: {model}, prompt: {prompt}, publisher: {publisher}, body: {request}")
    
    # Determine prompt source: query parameter or request body
    final_prompt = prompt if prompt is not None else (request.prompt if request is not None else None)
    # Determine publisher source: query parameter or request body
    final_publisher = publisher if publisher is not None else (request.publisher if request is not None else None)
    
    if not final_prompt or not final_prompt.strip():
        logger.error("Prompt cannot be empty")
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    
    if not model or not model.strip():
        logger.error("Model cannot be empty")
        raise HTTPException(status_code=400, detail="Model cannot be empty")
    
    return StreamingResponse(
        generate_ai_response(final_prompt, model, final_publisher),
        media_type="text/event-stream"
    )
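
# Example of consuming the streaming endpoint (a sketch, assuming the server is
# running locally on port 8000; -N keeps curl from buffering the streamed chunks):
#   curl -N -X POST "http://localhost:8000/generate?model=default-model&prompt=Hello"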

def get_app():
    return app
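
# Minimal local-run sketch (an assumption, not part of the original module):
# uvicorn serves the app defined above; GITHUB_TOKEN must be set in the
# environment for generation to succeed. Adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)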