1. Einführung in GANs
Hauptkomponenten von GANs
- Generator: Erzeugt neue Daten aus zufälligem Rauschen.
- Diskriminator: Unterscheidet zwischen echten Daten und vom Generator erzeugten Daten.
- Der Generator nimmt ein zufälliges Rauschen als Eingabe und erzeugt eine neue Dateninstanz.
- Der Diskriminator bewertet diese generierte Dateninstanz und entscheidet, ob sie echt oder generiert ist.
- Beide Netzwerke werden gleichzeitig trainiert, wobei der Generator versucht, den Diskriminator zu täuschen, und der Diskriminator versucht, die Fälschungen zu erkennen.
2. Aufbau eines einfachen GANs
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, output_size),
nn.Tanh()
)
def forward(self, x):
return self.main(x)
# Beispielhafter Generator
input_size = 100
hidden_size = 256
output_size = 784 # Für ein 28x28 Bild
gen = Generator(input_size, hidden_size, output_size)
print(gen)
Output:
Generator(
(main): Sequential(
(0): Linear(in_features=100, out_features=256, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): ReLU(inplace=True)
(4): Linear(in_features=256, out_features=784, bias=True)
(5): Tanh()
)
)
Diskriminator:
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, output_size),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
# Beispielhafter Diskriminator
input_size = 784 # Für ein 28x28 Bild
hidden_size = 256
output_size = 1
disc = Discriminator(input_size, hidden_size, output_size)
print(disc)
Output:
Discriminator(
(main): Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): ReLU(inplace=True)
(4): Linear(in_features=256, out_features=1, bias=True)
(5): Sigmoid()
)
)
3. Praktisches Beispiel: Bildgenerierung
In diesem Beispiel werden wir ein einfaches GAN verwenden, um handgeschriebene Ziffern aus dem MNIST-Datensatz zu generieren.
Schritt 1: Datenvorbereitung
import torchvision
import torchvision.transforms as transforms
# Datenvorbereitung
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# Laden des MNIST-Datensatzes
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
print("MNIST-Datensatz geladen")
Output:
MNIST-Datensatz geladen
Schritt 2: Training des GANs
import torch.optim as optim
# Hyperparameter
num_epochs = 50
batch_size = 64
learning_rate = 0.0002
# Verlustfunktion und Optimierer
criterion = nn.BCELoss()
optimizer_gen = optim.Adam(gen.parameters(), lr=learning_rate)
optimizer_disc = optim.Adam(disc.parameters(), lr=learning_rate)
# Training des GANs
for epoch in range(num_epochs):
for i, (images, _) in enumerate(train_loader):
# Bereiten der echten Daten
real_images = images.view(-1, 28*28)
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
# Trainieren des Diskriminators
outputs = disc(real_images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# Generieren von Fake-Daten
z = torch.randn(batch_size, 100)
fake_images = gen(z)
outputs = disc(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# Gesamter Diskriminatorverlust und Optimierung
d_loss = d_loss_real + d_loss_fake
optimizer_disc.zero_grad()
d_loss.backward()
optimizer_disc.step()
# Trainieren des Generators
z = torch.randn(batch_size, 100)
fake_images = gen(z)
outputs = disc(fake_images)
g_loss = criterion(outputs, real_labels)
optimizer_gen.zero_grad()
g_loss.backward()
optimizer_gen.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
print('Training abgeschlossen')
Output:
Epoch [1/50], d_loss: 0.2087, g_loss: 2.8590
Epoch [2/50], d_loss: 0.1201, g_loss: 3.3875
...
Epoch [49/50], d_loss: 0.2516, g_loss: 2.7628
Epoch [50/50], d_loss: 0.1844, g_loss: 3.0174
Training abgeschlossen
Schritt 3: Generierung von Bildern
import matplotlib.pyplot as plt
# Generieren von Bildern nach dem Training
z = torch.randn(16, 100)
fake_images = gen(z)
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
fake_images = fake_images.data
# Anzeigen der generierten Bilder
fig, axes = plt.subplots(4, 4, figsize=(5, 5))
for i, ax in enumerate(axes.flatten()):
ax.imshow(fake_images[i].squeeze().numpy(), cmap='gray')
ax.axis('off')
plt.show()
Output:
(Angezeigte generierte Bilder in einem 4x4-Grid)
Fazit
Generative Adversarial Networks (GANs) sind eine leistungsstarke Methode zur Generierung realistischer Daten. In diesem Artikel haben wir die Grundlagen von GANs erklärt, ein einfaches GAN mit PyTorch erstellt und ein praktisches Beispiel zur Bildgenerierung durchgespielt. Durch das Verständnis und die Anwendung dieser Techniken können Sie leistungsfähige Modelle für kreative und datenintensive Anwendungen entwickeln. Nutzen Sie die Flexibilität und Leistungsfähigkeit von PyTorch, um Ihre Machine-Learning-Projekte erfolgreich umzusetzen.