Sam3838 committed on
Commit 3b9f744 · verified · 1 Parent(s): 3790a3b

Upload 4 files

Files changed (4):
  1. .gitignore +67 -0
  2. image_generator.py +100 -0
  3. models.py +74 -0
  4. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,67 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Model cache (if any)
+ models/
+
+ # Environment variables
+ .env
+ .env.local
+ .env.*.local
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Model cache
+ models/
+ generated_images/
+
+ # Logs
+ *.log
+ logs/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Virtual environment
+ venv/
+ env/
+ ENV/
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # pytest
+ .pytest_cache/
+
+ # Coverage
+ .coverage
+ htmlcov/
+
+ # Docker
+ .dockerignore
image_generator.py ADDED
@@ -0,0 +1,100 @@
+ import base64
+ import io
+ import os
+ import time
+ import logging
+ from typing import List
+ from PIL import Image
+ from huggingface_hub import InferenceClient
+ from config import config
+ from models import ImageGenerationRequest, ImageData, ResponseFormat
+
+ logger = logging.getLogger(__name__)
+
+
+ class ImageGenerator:
+     """Text-to-image generator using Hugging Face InferenceClient"""
+
+     def __init__(self):
+         self.client = None
+         self._ensure_output_dir()
+
+     def _ensure_output_dir(self):
+         """Ensure output directory exists"""
+         os.makedirs(config.OUTPUT_DIR, exist_ok=True)
+
+     def _get_client(self):
+         """Get or create the InferenceClient"""
+         if self.client is None:
+             self.client = InferenceClient(
+                 provider="replicate",
+                 api_key=config.HF_TOKEN,
+             )
+         return self.client
+
+     def _image_to_base64(self, image: Image.Image) -> str:
+         """Convert PIL Image to base64 string"""
+         buffer = io.BytesIO()
+         image.save(buffer, format="PNG")
+         img_str = base64.b64encode(buffer.getvalue()).decode()
+         return img_str
+
+     def _save_image(self, image: Image.Image, filename: str) -> str:
+         """Save image and return URL"""
+         filepath = os.path.join(config.OUTPUT_DIR, filename)
+         image.save(filepath)
+         return f"{config.BASE_URL}/images/{filename}"
+
+     async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
+         """Generate images based on the request"""
+         client = self._get_client()
+
+         # Generate images
+         results = []
+
+         for i in range(request.n):
+             try:
+                 logger.info(f"Generating image {i+1}/{request.n} for prompt: {request.prompt[:50]}...")
+
+                 # Generate the image using HuggingFace InferenceClient
+                 image = client.text_to_image(
+                     request.prompt,
+                     model=config.DEFAULT_MODEL,
+                 )
+
+                 # Create response based on format
+                 if request.response_format == ResponseFormat.B64_JSON:
+                     image_data = ImageData(
+                         b64_json=self._image_to_base64(image),
+                         revised_prompt=request.prompt
+                     )
+                 else:
+                     # Save image and return URL
+                     timestamp = int(time.time())
+                     filename = f"generated_{timestamp}_{i}.png"
+                     url = self._save_image(image, filename)
+                     image_data = ImageData(
+                         url=url,
+                         revised_prompt=request.prompt
+                     )
+
+                 results.append(image_data)
+                 logger.info(f"Successfully generated image {i+1}/{request.n}")
+
+             except Exception as e:
+                 logger.error(f"Failed to generate image {i+1}: {e}")
+                 # Continue with other images
+                 continue
+
+         if not results:
+             raise Exception("Failed to generate any images")
+
+         return results
+
+     def cleanup(self):
+         """Cleanup resources"""
+         self.client = None
+
+
+ # Global instance
+ image_generator = ImageGenerator()
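
A minimal usage sketch (not part of this commit) of how the generator above might be wired into a FastAPI route. The app instance, route path, and error handling are assumptions for illustration; it also presumes the `config` module imported by image_generator.py exists, since it is not among these four files.

# Hypothetical FastAPI wiring for image_generator; app and the route path
# are illustrative assumptions, not part of this commit.
import time

from fastapi import FastAPI, HTTPException

from image_generator import image_generator
from models import ImageGenerationRequest, ImageGenerationResponse

app = FastAPI()


@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_images(request: ImageGenerationRequest) -> ImageGenerationResponse:
    try:
        images = await image_generator.generate_images(request)
    except Exception as exc:
        # generate_images raises only when every image attempt failed
        raise HTTPException(status_code=500, detail=str(exc))
    return ImageGenerationResponse(created=int(time.time()), data=images)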
models.py ADDED
@@ -0,0 +1,74 @@
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Literal
+ from enum import Enum
+
+
+ class ImageSize(str, Enum):
+     """Supported image sizes (OpenAI compatible)"""
+     SMALL = "256x256"
+     MEDIUM = "512x512"
+     LARGE = "1024x1024"
+     WIDE = "1792x1024"
+     TALL = "1024x1792"
+
+
+ class ImageQuality(str, Enum):
+     """Image quality options"""
+     STANDARD = "standard"
+     HD = "hd"
+
+
+ class ImageStyle(str, Enum):
+     """Image style options"""
+     VIVID = "vivid"
+     NATURAL = "natural"
+
+
+ class ResponseFormat(str, Enum):
+     """Response format options"""
+     URL = "url"
+     B64_JSON = "b64_json"
+
+
+ class ImageGenerationRequest(BaseModel):
+     """OpenAI compatible image generation request"""
+     prompt: str = Field(..., description="A text description of the desired image(s)")
+     model: str = Field(default="dall-e-3", description="The model to use for image generation")
+     n: int = Field(default=1, ge=1, le=10, description="Number of images to generate")
+     quality: ImageQuality = Field(default=ImageQuality.STANDARD, description="Quality of the image")
+     response_format: ResponseFormat = Field(default=ResponseFormat.URL, description="Response format")
+     size: ImageSize = Field(default=ImageSize.LARGE, description="Size of the generated images")
+     style: ImageStyle = Field(default=ImageStyle.VIVID, description="Style of the generated images")
+     user: Optional[str] = Field(default=None, description="A unique identifier representing your end-user")
+
+
+ class ImageData(BaseModel):
+     """Individual image data in response"""
+     url: Optional[str] = Field(default=None, description="URL of the generated image")
+     b64_json: Optional[str] = Field(default=None, description="Base64 encoded image data")
+     revised_prompt: Optional[str] = Field(default=None, description="The revised prompt used for generation")
+
+
+ class ImageGenerationResponse(BaseModel):
+     """OpenAI compatible image generation response"""
+     created: int = Field(..., description="Unix timestamp of when the image was created")
+     data: List[ImageData] = Field(..., description="List of generated images")
+
+
+ class ErrorResponse(BaseModel):
+     """Error response format"""
+     error: dict = Field(..., description="Error details")
+
+
+ class ModelInfo(BaseModel):
+     """Model information"""
+     id: str
+     object: str = "model"
+     created: int
+     owned_by: str
+
+
+ class ModelsResponse(BaseModel):
+     """Models list response"""
+     object: str = "list"
+     data: List[ModelInfo]
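
As a quick illustration (not part of this commit), the request and response models above can be exercised directly, for example in a unit test; the prompt text and base64 placeholder below are made up.

# Hypothetical example exercising the Pydantic models; all values are placeholders.
import time

from models import ImageData, ImageGenerationRequest, ImageGenerationResponse, ResponseFormat

request = ImageGenerationRequest(
    prompt="a watercolor lighthouse at dusk",
    n=2,
    response_format=ResponseFormat.B64_JSON,
)

response = ImageGenerationResponse(
    created=int(time.time()),
    data=[ImageData(b64_json="<base64-encoded PNG>", revised_prompt=request.prompt)],
)

# Pydantic v2 (pinned in requirements.txt) serializes with model_dump_json()
print(response.model_dump_json(indent=2))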
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ pydantic==2.5.0
+ pillow==10.1.0
+ huggingface_hub
+ requests==2.31.0
+ python-multipart==0.0.6