import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os

print("🚀 7Gen - Advanced MNIST Generator System 🚀")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters
batch_size = 64
latent_dim = 100
num_classes = 10
num_epochs = 100
lr = 0.0002

# Scale pixels to [-1, 1] to match the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
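
# Optional sanity check (sketch, not part of the original training flow): peek at one
# real batch to confirm shapes and the normalized pixel range of roughly [-1, 1].
_imgs, _labels = next(iter(dataloader))
print(f'Batch shape: {tuple(_imgs.shape)}, pixel range: [{_imgs.min().item():.2f}, {_imgs.max().item():.2f}]')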


class Generator(nn.Module):
    """Conditional generator: maps a noise vector and a class label to a 1x28x28 image."""

    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),

            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Condition on the class by concatenating the label embedding with the noise vector.
        label_embedding = self.label_emb(labels)
        gen_input = torch.cat((noise, label_embedding), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), 1, 28, 28)
        return img
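
# Optional sanity check (sketch, not part of the original flow): one CPU forward pass
# to confirm the conditional generator produces images of the expected shape.
with torch.no_grad():
    _g = Generator()
    _z = torch.randn(2, latent_dim)
    _y = torch.randint(0, num_classes, (2,))
    assert _g(_z, _y).shape == (2, 1, 28, 28)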


class Discriminator(nn.Module):
    """Conditional discriminator: scores an image/label pair as real (1) or fake (0)."""

    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            nn.Linear(784 + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # Flatten the image and concatenate the label embedding before scoring.
        img_flat = img.view(img.size(0), -1)
        label_embedding = self.label_emb(labels)
        d_input = torch.cat((img_flat, label_embedding), -1)
        validity = self.model(d_input)
        return validity


generator = Generator().to(device)
discriminator = Discriminator().to(device)

adversarial_loss = nn.BCELoss()
# Note: Adam's default betas are kept here; betas=(0.5, 0.999) is a common choice for GAN training.
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

os.makedirs('generated_images', exist_ok=True)

print("\n🔥 7Gen training is starting...")

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        imgs = imgs.to(device)
        labels = labels.to(device)
        batch_size = imgs.size(0)

        # Target labels for the adversarial loss: 1 for real, 0 for fake.
        valid = torch.ones(batch_size, 1).to(device)
        fake = torch.zeros(batch_size, 1).to(device)

        # --- Train the generator ---
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_imgs = generator(z, gen_labels)

        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # --- Train the discriminator ---
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs, labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

    print(f"Epoch {epoch+1}/{num_epochs} - D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

    # Every 10 epochs, save a 10x10 grid of generated samples (one row per digit class).
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(100, latent_dim).to(device)
            labels = torch.tensor([i for i in range(10) for _ in range(10)]).to(device)
            gen_imgs = generator(z, labels)
            gen_imgs = (gen_imgs + 1) / 2  # Rescale from [-1, 1] back to [0, 1] for display.

        fig, axes = plt.subplots(10, 10, figsize=(10, 10))
        for i in range(10):
            for j in range(10):
                idx = i * 10 + j
                axes[i, j].imshow(gen_imgs[idx][0].cpu().numpy(), cmap='gray')
                axes[i, j].axis('off')
        plt.savefig(f'generated_images/7gen_epoch_{epoch+1}.png')
        plt.close()

os.makedirs('models', exist_ok=True)
torch.save(generator.state_dict(), 'models/7gen_generator.pth')
torch.save(discriminator.state_dict(), 'models/7gen_discriminator.pth')

print("\n✅ 7Gen training complete!")


def generate_digit(digit, num_samples=5):
    """Generate and save a few samples of the requested digit using the trained generator."""
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        labels = torch.full((num_samples,), digit, dtype=torch.long).to(device)
        gen_imgs = generator(z, labels)
        gen_imgs = (gen_imgs + 1) / 2  # Rescale from [-1, 1] to [0, 1] for display.

    plt.figure(figsize=(10, 2))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i+1)
        plt.imshow(gen_imgs[i][0].cpu().numpy(), cmap='gray')
        plt.axis('off')
    plt.savefig(f'generated_images/digit_{digit}_samples.png')
    plt.show()


print("\n🎯 Generating test samples...")
for digit in range(10):
    generate_digit(digit, num_samples=5)

print("\n🎉 7Gen is ready! Check the generated_images folder.")