|
|
|
import torch
|
|
import torch.nn as nn
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from PIL import Image
|
|
import os
|
|
|
|
|
|
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The label is embedded, concatenated to the noise vector and fed through
    a 4-layer MLP whose final ``Tanh`` keeps outputs in ``[-1, 1]``; the flat
    output is reshaped to ``(batch, *img_shape)``.

    The defaults reproduce the original hard-coded architecture (100-d
    noise, 10 classes, 1x28x28 images). The submodule names (``label_emb``,
    ``model``) and the layer order inside ``model`` determine the
    ``state_dict`` keys, so they must stay as they are for saved checkpoints
    to keep loading.

    Args:
        latent_dim: Length of the noise vector passed to ``forward``.
        n_classes: Number of distinct labels. The embedding width equals this
            value, matching the original ``nn.Embedding(10, 10)``.
        img_shape: ``(channels, height, width)`` of the generated images.
    """

    def __init__(self, latent_dim=100, n_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_classes = n_classes
        self.img_shape = tuple(img_shape)
        out_features = int(np.prod(self.img_shape))

        self.label_emb = nn.Embedding(n_classes, n_classes)

        # Linear -> LeakyReLU -> BatchNorm ordering matches the trained
        # checkpoints; reordering would renumber the ``model.<i>`` keys.
        # Note: in training mode BatchNorm1d needs more than one sample per
        # batch; inference code is expected to call ``.eval()`` first.
        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, out_features),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Integer (long) tensor of shape ``(batch,)`` with values in
                ``[0, n_classes)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in ``[-1, 1]``.
        """
        label_embedding = self.label_emb(labels)
        gen_input = torch.cat((noise, label_embedding), dim=-1)
        img = self.model(gen_input)
        return img.view(img.size(0), *self.img_shape)
|
|
|
|
|
|
class SevenGenInference:
    """Inference wrapper around a trained conditional ``Generator``.

    Loads generator weights from disk and offers helpers to sample digits,
    plot them with matplotlib, save them as PNG files, or run an interactive
    prompt loop.
    """

    # Number of digit classes the generator was trained on (labels 0-9).
    NUM_CLASSES = 10

    def __init__(self, model_path='models/7gen_generator.pth', device=None):
        """Build the generator and load its weights.

        Args:
            model_path: Path to a saved ``state_dict`` for ``Generator``.
            device: Optional ``torch.device`` or device string. Defaults to
                CUDA when available, otherwise CPU.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.latent_dim = 100

        self.generator = Generator().to(self.device)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; on
        # torch>=1.13 consider weights_only=True for plain state_dicts.
        state_dict = torch.load(model_path, map_location=self.device)
        self.generator.load_state_dict(state_dict)
        self.generator.eval()

        print(f"🚀 7Gen yüklendi! Cihaz: {self.device}")

    def generate_digit(self, digit, count=5):
        """Sample ``count`` images of ``digit``.

        Args:
            digit: Class label in ``[0, NUM_CLASSES)``.
            count: Number of images to generate (at least 1).

        Returns:
            CPU tensor in the generator's output layout, e.g.
            ``(count, 1, 28, 28)``. Values are mapped from the generator's
            ``[-1, 1]`` Tanh range to ``[0, 1]``.

        Raises:
            ValueError: If ``digit`` or ``count`` is out of range.
        """
        # Validate up front: a bad label would otherwise surface as an opaque
        # IndexError (CPU) or a device-side assert (CUDA) inside nn.Embedding.
        if not 0 <= digit < self.NUM_CLASSES:
            raise ValueError(
                f"Rakam 0-{self.NUM_CLASSES - 1} arasında olmalı (verilen: {digit})"
            )
        if count < 1:
            raise ValueError(f"Adet en az 1 olmalı (verilen: {count})")

        with torch.no_grad():
            z = torch.randn(count, self.latent_dim, device=self.device)
            # Explicit long dtype: nn.Embedding requires integer indices.
            labels = torch.full(
                (count,), digit, dtype=torch.long, device=self.device
            )
            images = self.generator(z, labels)
            images = (images + 1) / 2

        return images.cpu()

    @staticmethod
    def _save_figure(fig, save_path, message):
        """Save ``fig`` at 300 dpi, creating parent directories, and report it."""
        parent = os.path.dirname(save_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"{message}: {save_path}")

    def visualize_digits(self, digit, count=5, save_path=None):
        """Generate ``count`` samples of ``digit`` and show them in one row.

        Args:
            digit: Digit class to generate.
            count: Number of samples, one subplot each.
            save_path: Optional output path; when given, the figure is also
                saved there.
        """
        images = self.generate_digit(digit, count)

        # squeeze=False keeps ``axes`` 2-D even when count == 1.
        fig, axes = plt.subplots(
            1, count, figsize=(2 * count, 2), squeeze=False
        )
        for ax, image in zip(axes[0], images):
            # Fixed [0, 1] range so imshow doesn't stretch each image's contrast.
            ax.imshow(image[0].numpy(), cmap='gray', vmin=0, vmax=1)
            ax.axis('off')
            ax.set_title(f'Digit: {digit}')

        fig.tight_layout()

        if save_path:
            self._save_figure(fig, save_path, "💾 Görsel kaydedildi")

        plt.show()

    def generate_grid(self, samples_per_digit=10, save_path=None):
        """Plot a grid with one row per digit and ``samples_per_digit`` columns.

        Args:
            samples_per_digit: Number of samples (columns) for each digit.
            save_path: Optional output path; when given, the figure is also
                saved there.
        """
        n_rows = self.NUM_CLASSES
        rows = [
            self.generate_digit(digit, samples_per_digit)
            for digit in range(n_rows)
        ]

        # Width scales with the column count; the default of 10 columns keeps
        # the original 15x15 inch figure. squeeze=False keeps ``axes`` 2-D
        # even for a single column.
        fig, axes = plt.subplots(
            n_rows,
            samples_per_digit,
            figsize=(1.5 * samples_per_digit, 15),
            squeeze=False,
        )

        for digit, (row_images, row_axes) in enumerate(zip(rows, axes)):
            for ax, image in zip(row_axes, row_images):
                ax.imshow(image[0].numpy(), cmap='gray', vmin=0, vmax=1)
                # Hide ticks and frame instead of ax.axis('off'): turning the
                # axis off would also hide the y-label used as the row label.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
            row_axes[0].set_ylabel(f'{digit}', rotation=0, size=20, labelpad=20)

        fig.suptitle('7Gen - Üretilen Rakamlar', size=20)
        fig.tight_layout()

        if save_path:
            self._save_figure(fig, save_path, "💾 Grid kaydedildi")

        plt.show()

    def save_as_png(self, digit, count=1, output_dir='output'):
        """Generate ``count`` samples of ``digit`` and write each to a PNG.

        Files are named ``digit_<digit>_<k>.png`` (k starting at 1) inside
        ``output_dir``, which is created if needed. Existing files with the
        same name are overwritten.
        """
        os.makedirs(output_dir, exist_ok=True)

        images = self.generate_digit(digit, count)

        for index, image in enumerate(images, start=1):
            # Round (rather than truncate) to the nearest 8-bit grey level.
            pixels = np.clip(np.rint(image[0].numpy() * 255), 0, 255)
            pil_img = Image.fromarray(pixels.astype(np.uint8))
            filename = os.path.join(output_dir, f"digit_{digit}_{index}.png")
            pil_img.save(filename)

            print(f"💾 Kaydedildi: {filename}")

    def interactive_generate(self):
        """Repeatedly prompt for a digit and a count, then plot the samples.

        Typing ``q`` (case-insensitive), pressing Ctrl+C, or reaching end of
        input (Ctrl+D / closed stdin) exits the loop.
        """
        print("\n🎮 7Gen İnteraktif Mod")
        print("Çıkmak için 'q' yazın")

        while True:
            try:
                digit_input = input("\nHangi rakamı üretmek istersin? (0-9): ")

                if digit_input.strip().lower() == 'q':
                    print("👋 Görüşürüz!")
                    break

                digit = int(digit_input)
                if not 0 <= digit <= 9:
                    print("❌ 0-9 arası bir rakam gir!")
                    continue

                count = int(input("Kaç tane üreteyim? (1-20): "))
                if not 1 <= count <= 20:
                    print("❌ 1-20 arası bir sayı gir!")
                    continue

                self.visualize_digits(digit, count)

            except ValueError:
                print("❌ Geçerli bir sayı gir!")
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed or Ctrl+D; treat like Ctrl+C.
                print("\n👋 Görüşürüz!")
                break
|
|
|
|
|
|
def main():
    """Run the 7Gen demo: single-digit plot, grid, PNG export, interactive mode."""
    try:
        seven_gen = SevenGenInference()
    except FileNotFoundError as err:
        # Exit with a readable message instead of a traceback when the
        # checkpoint file is missing.
        raise SystemExit(f"❌ Model dosyası bulunamadı: {err}") from err

    print("\n📝 Örnek kullanımlar:")

    print("1. Tekil rakam üret")
    seven_gen.visualize_digits(digit=7, count=5)

    print("\n2. Grid oluştur")
    seven_gen.generate_grid(samples_per_digit=10, save_path='7gen_showcase.png')

    print("\n3. PNG olarak kaydet")
    seven_gen.save_as_png(digit=5, count=3, output_dir='output')

    print("\n4. İnteraktif mod")
    seven_gen.interactive_generate()


if __name__ == "__main__":
    main()