7Gen / 7gen_inference.py
NextGenC's picture
Upload 24 files
8add310 verified
# 7Gen Inference - Rakam Üretme Arayüzü
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
# Model yapısı (eğitimde kullandığımız ile aynı olmalı)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(10, 10)
self.model = nn.Sequential(
nn.Linear(100 + 10, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, noise, labels):
label_embedding = self.label_emb(labels)
gen_input = torch.cat((noise, label_embedding), -1)
img = self.model(gen_input)
img = img.view(img.size(0), 1, 28, 28)
return img
# 7Gen sınıfı
class SevenGenInference:
def __init__(self, model_path='models/7gen_generator.pth'):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.latent_dim = 100
# Modeli yükle
self.generator = Generator().to(self.device)
self.generator.load_state_dict(torch.load(model_path, map_location=self.device))
self.generator.eval()
print(f"🚀 7Gen yüklendi! Cihaz: {self.device}")
def generate_digit(self, digit, count=5):
"""Belirli bir rakamdan istenen sayıda üret"""
with torch.no_grad():
z = torch.randn(count, self.latent_dim).to(self.device)
labels = torch.full((count,), digit).to(self.device)
images = self.generator(z, labels)
images = (images + 1) / 2 # [-1,1] -> [0,1]
return images.cpu()
def visualize_digits(self, digit, count=5, save_path=None):
"""Üretilen rakamları görselleştir"""
images = self.generate_digit(digit, count)
fig, axes = plt.subplots(1, count, figsize=(2*count, 2))
if count == 1:
axes = [axes]
for i, ax in enumerate(axes):
ax.imshow(images[i][0], cmap='gray')
ax.axis('off')
ax.set_title(f'Digit: {digit}')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"💾 Görsel kaydedildi: {save_path}")
plt.show()
def generate_grid(self, samples_per_digit=10, save_path=None):
"""Her rakamdan örneklerle 10x10 grid oluştur"""
all_images = []
for digit in range(10):
images = self.generate_digit(digit, samples_per_digit)
all_images.append(images)
all_images = torch.cat(all_images, dim=0)
fig, axes = plt.subplots(10, samples_per_digit, figsize=(15, 15))
for i in range(10):
for j in range(samples_per_digit):
idx = i * samples_per_digit + j
axes[i, j].imshow(all_images[idx][0], cmap='gray')
axes[i, j].axis('off')
if j == 0:
axes[i, j].set_ylabel(f'{i}', rotation=0, size=20, labelpad=20)
plt.suptitle('7Gen - Üretilen Rakamlar', size=20)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"💾 Grid kaydedildi: {save_path}")
plt.show()
def save_as_png(self, digit, count=1, output_dir='output'):
"""Tekil PNG dosyaları olarak kaydet"""
os.makedirs(output_dir, exist_ok=True)
images = self.generate_digit(digit, count)
for i in range(count):
img = images[i][0].numpy()
img = (img * 255).astype(np.uint8)
pil_img = Image.fromarray(img)
filename = f"{output_dir}/digit_{digit}_{i+1}.png"
pil_img.save(filename)
print(f"💾 Kaydedildi: {filename}")
def interactive_generate(self):
"""İnteraktif kullanım"""
print("\n🎮 7Gen İnteraktif Mod")
print("Çıkmak için 'q' yazın")
while True:
try:
digit_input = input("\nHangi rakamı üretmek istersin? (0-9): ")
if digit_input.lower() == 'q':
print("👋 Görüşürüz!")
break
digit = int(digit_input)
if 0 <= digit <= 9:
count = int(input("Kaç tane üreteyim? (1-20): "))
if 1 <= count <= 20:
self.visualize_digits(digit, count)
else:
print("❌ 1-20 arası bir sayı gir!")
else:
print("❌ 0-9 arası bir rakam gir!")
except ValueError:
print("❌ Geçerli bir sayı gir!")
except KeyboardInterrupt:
print("\n👋 Görüşürüz!")
break
# Ana kullanım
if __name__ == "__main__":
# 7Gen'i başlat
seven_gen = SevenGenInference()
# Örnekler
print("\n📝 Örnek kullanımlar:")
print("1. Tekil rakam üret")
seven_gen.visualize_digits(digit=7, count=5)
print("\n2. Grid oluştur")
seven_gen.generate_grid(samples_per_digit=10, save_path='7gen_showcase.png')
print("\n3. PNG olarak kaydet")
seven_gen.save_as_png(digit=5, count=3, output_dir='output')
print("\n4. İnteraktif mod")
seven_gen.interactive_generate()