# GAN (Generative Adversarial Network)

Also known as: GAN, generative adversarial network, adversarial network, generative model
Tags: Deep Learning, Generative AI, Machine Learning, Computer Vision, Algorithms
Category: Deep Learning
Difficulty: advanced

Summary: A Generative Adversarial Network (GAN) is a machine learning architecture made of two neural networks, a generator and a discriminator, that are trained against each other in a game-theoretic setup. The generator learns to produce realistic synthetic data while the discriminator learns to tell real data from generated data; as each improves, the generator's samples become harder to distinguish from real ones. GANs have been applied to images, text, and audio.

## Overview

Generative Adversarial Networks (GANs) are an approach to generative modeling introduced by Ian Goodfellow and colleagues in 2014. Training is framed as a competitive game between two neural networks: the generator tries to produce samples that the discriminator accepts as real, and the discriminator tries to reject them. This adversarial process has produced realistic images, video, and audio, and it made GANs one of the foundational techniques of modern generative AI.

## Core Architecture

### The Adversarial Game

GANs consist of two competing neural networks:

```python
class GAN:
    def __init__(self, latent_dim, data_dim):
        # Generator: Creates fake data from random noise
        self.generator = Generator(latent_dim, data_dim)
        
        # Discriminator: Distinguishes real from fake data
        self.discriminator = Discriminator(data_dim)
        
    def adversarial_game(self):
        """The core adversarial training process"""
        
        # Generator's goal: Fool the discriminator
        # min_G E[log(1 - D(G(z)))]
        
        # Discriminator's goal: Correctly classify real vs fake
        # max_D E[log(D(x))] + E[log(1 - D(G(z)))]
        
        # This creates a minimax game:
        # min_G max_D V(D,G) = E[log(D(x))] + E[log(1-D(G(z)))]
        
        return "Nash equilibrium when G creates perfect fakes and D can't tell difference"
```
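
As a quick numerical check of the minimax objective above, the sketch below evaluates V(D, G) from discriminator outputs in plain Python. The name `value_function` and the toy inputs are illustrative assumptions, not part of any library. When the discriminator can do no better than guessing (D = 0.5 everywhere), the value is -log 4.

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on real samples, each in (0, 1)
    d_fake: discriminator outputs D(G(z)) on generated samples, each in (0, 1)
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that separates real from fake well yields a value near 0
confident = value_function([0.99, 0.98], [0.01, 0.02])

# A discriminator reduced to guessing yields -log(4)
guessing = value_function([0.5, 0.5], [0.5, 0.5])
```

The discriminator pushes this value up toward 0, while the generator pushes it down by making its samples harder to reject.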

### Generator Network

The generator transforms random noise into synthetic data:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Generator network for creating fake data"""
    
    def __init__(self, latent_dim=100, img_size=28, channels=1):
        super().__init__()
        
        self.img_size = img_size
        self.channels = channels
        
        # Calculate output dimensions
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2)
        )
        
        # Upsampling layers
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),  # keyword needed: the second positional argument is eps
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh()  # Output in range [-1, 1]
        )
        
    def forward(self, z):
        """Generate fake data from random noise z"""
        
        # Map noise to feature maps
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        
        # Generate image through upsampling
        img = self.conv_blocks(out)
        
        return img

# Example usage
generator = Generator(latent_dim=100, img_size=28, channels=1)

# Generate fake MNIST-sized digits
batch_size = 64
noise = torch.randn(batch_size, 100)  # Random noise, shape (64, 100)
fake_images = generator(noise)         # Generated images, shape (64, 1, 28, 28)

```

### Discriminator Network

The discriminator classifies data as real or fake:

```python
class Discriminator(nn.Module):
    """Discriminator network for real vs fake classification"""
    
    def __init__(self, img_size=28, channels=1):
        super().__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25)
            ]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            return block
        
        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
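        # Spatial size after four stride-2 convolutions (kernel 3, padding 1);
        # each one maps n to ceil(n / 2), so 28 -> 14 -> 7 -> 4 -> 2
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2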
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()  # Output probability
        )
        
    def forward(self, img):
        """Classify image as real (1) or fake (0)"""
        
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        
        return validity

# Example usage
discriminator = Discriminator(img_size=28, channels=1)

# Evaluate real and fake images
real_images = torch.randn(64, 1, 28, 28)       # Placeholder for a batch of real data
fake_images = generator(torch.randn(64, 100))  # Generated data

real_scores = discriminator(real_images)  # Pushed toward 1 as training progresses
fake_scores = discriminator(fake_images)  # Pushed toward 0 as training progresses

```
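
The size of the flattened feature map that feeds `adv_layer` depends on how each strided convolution rounds. The sketch below traces the spatial size with the standard convolution output formula; `conv_out_size` and `trace_sizes` are hypothetical helpers written for this illustration.

```python
def conv_out_size(n, kernel=3, stride=2, padding=1):
    """Output length of a convolution along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def trace_sizes(img_size, n_blocks=4):
    """Spatial size after each of the discriminator's stride-2 blocks."""
    sizes = []
    n = img_size
    for _ in range(n_blocks):
        n = conv_out_size(n)
        sizes.append(n)
    return sizes

mnist_sizes = trace_sizes(28)  # [14, 7, 4, 2]
pow2_sizes = trace_sizes(32)   # [16, 8, 4, 2]
```

For a 28x28 input the final map is 2x2, so the linear layer sees 128 * 2 * 2 = 512 features. Integer division by `2 ** 4` gives the same answer only when the image size is a multiple of 16.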

## Training Process

### Adversarial Training Loop

```python
from torchvision.utils import save_image


class GANTrainer:
    def __init__(self, generator, discriminator, lr=0.0002, b1=0.5, b2=0.999):
        self.generator = generator
        self.discriminator = discriminator
        
        # Separate optimizers for G and D
        self.optimizer_G = torch.optim.Adam(
            generator.parameters(), lr=lr, betas=(b1, b2)
        )
        self.optimizer_D = torch.optim.Adam(
            discriminator.parameters(), lr=lr, betas=(b1, b2)
        )
        
        # Loss function
        self.adversarial_loss = nn.BCELoss()
        
    def train_step(self, real_batch):
        """Single training step for both networks"""
        
        batch_size = real_batch.size(0)
        
        # Labels for real and fake data
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        self.optimizer_D.zero_grad()
        
        # Real data
        real_validity = self.discriminator(real_batch)
        real_loss = self.adversarial_loss(real_validity, real_labels)
        
        # Fake data
        z = torch.randn(batch_size, 100)  # Random noise
        fake_batch = self.generator(z).detach()  # Don't compute gradients for G
        fake_validity = self.discriminator(fake_batch)
        fake_loss = self.adversarial_loss(fake_validity, fake_labels)
        
        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        self.optimizer_D.step()
        
        # -----------------
        #  Train Generator  
        # -----------------
        
        self.optimizer_G.zero_grad()
        
        # Generate fake data
        z = torch.randn(batch_size, 100)
        fake_batch = self.generator(z)
        
        # Generator wants discriminator to classify fake as real
        fake_validity = self.discriminator(fake_batch)
        g_loss = self.adversarial_loss(fake_validity, real_labels)  # Use real_labels!
        
        g_loss.backward()
        self.optimizer_G.step()
        
        return {
            'discriminator_loss': d_loss.item(),
            'generator_loss': g_loss.item(),
            'real_accuracy': (real_validity > 0.5).float().mean().item(),
            'fake_accuracy': (fake_validity < 0.5).float().mean().item()
        }
    
    def train(self, dataloader, epochs):
        """Full training loop"""
        
        for epoch in range(epochs):
            epoch_d_loss = 0
            epoch_g_loss = 0
            
            for i, (real_batch, _) in enumerate(dataloader):
                
                # Training step
                metrics = self.train_step(real_batch)
                
                epoch_d_loss += metrics['discriminator_loss']
                epoch_g_loss += metrics['generator_loss']
                
                # Log progress
                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}]")
                    print(f"D_loss: {metrics['discriminator_loss']:.4f}, "
                          f"G_loss: {metrics['generator_loss']:.4f}")
            
            # Epoch summary
            avg_d_loss = epoch_d_loss / len(dataloader)
            avg_g_loss = epoch_g_loss / len(dataloader)
            
            print(f"Epoch {epoch} completed:")
            print(f"Average D_loss: {avg_d_loss:.4f}")
            print(f"Average G_loss: {avg_g_loss:.4f}")
            
            # Generate samples for visualization
            if epoch % 10 == 0:
                self.generate_samples(f"epoch_{epoch}.png")
    
    def generate_samples(self, filename, n_samples=64):
        """Generate and save sample images"""
        
        self.generator.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, 100)
            fake_images = self.generator(z)
            
            # Save images in a grid
            save_image(fake_images, filename, nrow=8, normalize=True)
        
        self.generator.train()

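# Training example (dataloader is assumed to yield (images, labels) batches,
# with images scaled to [-1, 1] to match the generator's Tanh output)
trainer = GANTrainer(generator, discriminator)
trainer.train(dataloader, epochs=200)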

```
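
The training step above labels generated samples as real when computing `g_loss`, which minimizes -log D(G(z)) rather than the log(1 - D(G(z))) term of the minimax objective. This is commonly called the non-saturating generator loss. A minimal plain-Python sketch, assuming a hypothetical discriminator logit `a` with D(G(z)) = sigmoid(a), compares the gradient each loss sends back when the discriminator confidently rejects a sample:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_loss(a):
    """Minimax generator loss: log(1 - D(G(z)))."""
    return math.log(1.0 - sigmoid(a))

def non_saturating_loss(a):
    """Loss used in train_step: -log D(G(z))."""
    return -math.log(sigmoid(a))

def numerical_grad(f, a, eps=1e-6):
    """Central finite-difference estimate of df/da."""
    return (f(a + eps) - f(a - eps)) / (2 * eps)

a = -6.0  # discriminator is confident the sample is fake: D(G(z)) is about 0.0025
grad_saturating = numerical_grad(saturating_loss, a)
grad_non_saturating = numerical_grad(non_saturating_loss, a)
```

Early in training the generator's samples are easy to reject, so the minimax loss passes back almost no gradient, while the non-saturating form keeps the gradient close to 1 in magnitude.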

### Training Challenges and Solutions

```python
import torch.nn.functional as F


class GANTrainingStabilizer:
    """Techniques to stabilize GAN training"""
    
    def __init__(self):
        # Registry of the techniques implemented below. Historical averaging
        # and minibatch discrimination are related techniques that this
        # sketch does not implement.
        self.training_techniques = {
            'label_smoothing': self.apply_label_smoothing,
            'feature_matching': self.feature_matching_loss,
            'spectral_normalization': self.spectral_normalization_conv,
            'wasserstein_loss': self.wasserstein_loss,
            'gradient_penalty': self.gradient_penalty
        }
    
    def apply_label_smoothing(self, labels, smoothing=0.1):
        """Apply label smoothing to prevent overconfident discriminator"""
        
        # Real labels: 1 -> 0.9, Fake labels: 0 -> 0.1
        if labels.mean() > 0.5:  # Real labels
            return labels - smoothing
        else:  # Fake labels
            return labels + smoothing
    
    def feature_matching_loss(self, real_features, fake_features):
        """Feature matching to prevent mode collapse"""
        
        # Match statistics of intermediate discriminator features
        real_mean = real_features.mean(dim=0)
        fake_mean = fake_features.mean(dim=0)
        
        feature_loss = F.mse_loss(fake_mean, real_mean)
        
        return feature_loss
    
    def spectral_normalization_conv(self, conv_layer):
        """Apply spectral normalization to convolution layer"""
        
        return nn.utils.spectral_norm(conv_layer)
    
    def wasserstein_loss(self, real_validity, fake_validity):
        """Wasserstein loss for improved training stability"""
        
        # WGAN loss: maximize real_scores - fake_scores
        wasserstein_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
        
        return wasserstein_loss
    
    def gradient_penalty(self, discriminator, real_data, fake_data, device):
        """Gradient penalty for WGAN-GP"""
        
        batch_size = real_data.size(0)
        
        # Random interpolation between real and fake data
        alpha = torch.rand(batch_size, 1, 1, 1).to(device)
        interpolated = alpha * real_data + (1 - alpha) * fake_data
        interpolated.requires_grad_(True)
        
        # Discriminator output for interpolated data
        d_interpolated = discriminator(interpolated)
        
        # Compute gradients
        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True
        )[0]
        
        # Gradient penalty
        gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
        penalty = torch.mean((gradient_norm - 1) ** 2)
        
        return penalty

class ImprovedGANTrainer(GANTrainer):
    """GAN trainer with stability improvements"""
    
    def __init__(self, generator, discriminator, **kwargs):
        super().__init__(generator, discriminator, **kwargs)
        self.stabilizer = GANTrainingStabilizer()
        self.lambda_gp = 10  # Gradient penalty coefficient
        
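    def train_step_wgan_gp(self, real_batch):
        """Training step with WGAN-GP improvements.

        Assumes self.discriminator acts as a critic: it should output an
        unbounded score (no final Sigmoid), and WGAN-GP is usually applied
        to a critic without batch normalization.
        """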
        
        batch_size = real_batch.size(0)
        
        # Train Discriminator
        for _ in range(5):  # Train D more than G
            self.optimizer_D.zero_grad()
            
            # Real data
            real_validity = self.discriminator(real_batch)
            
            # Fake data
            z = torch.randn(batch_size, 100)
            fake_batch = self.generator(z).detach()
            fake_validity = self.discriminator(fake_batch)
            
            # Wasserstein loss
            d_loss = self.stabilizer.wasserstein_loss(real_validity, fake_validity)
            
            # Gradient penalty
            gp = self.stabilizer.gradient_penalty(
                self.discriminator, real_batch, fake_batch, real_batch.device
            )
            
            # Total discriminator loss
            d_total_loss = d_loss + self.lambda_gp * gp
            d_total_loss.backward()
            self.optimizer_D.step()
        
        # Train Generator
        self.optimizer_G.zero_grad()
        
        z = torch.randn(batch_size, 100)
        fake_batch = self.generator(z)
        fake_validity = self.discriminator(fake_batch)
        
        # Generator loss (maximize fake_validity)
        g_loss = -torch.mean(fake_validity)
        g_loss.backward()
        self.optimizer_G.step()
        
        return {
            'discriminator_loss': d_total_loss.item(),
            'generator_loss': g_loss.item(),
            'gradient_penalty': gp.item()
        }

```
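
To make the gradient penalty concrete, consider a linear critic f(x) = w . x, whose gradient with respect to x is w at every input. The plain-Python sketch below evaluates the penalty term (||grad|| - 1)^2 for such a critic; `linear_critic_penalty` is a hypothetical helper for this illustration, not part of the trainer above.

```python
import math

def linear_critic_penalty(w):
    """WGAN-GP penalty (||grad f(x)|| - 1)^2 for the linear critic f(x) = w . x.

    The gradient of a linear critic is w everywhere, so the penalty does not
    depend on where the real/fake interpolation point falls.
    """
    grad_norm = math.sqrt(sum(wi * wi for wi in w))
    return (grad_norm - 1.0) ** 2

steep = linear_critic_penalty([3.0, 4.0])  # gradient norm 5 -> penalty 16
unit = linear_critic_penalty([0.6, 0.8])   # gradient norm 1 -> penalty close to 0
```

With `lambda_gp = 10`, the steep critic would add 160 to the discriminator loss, which pushes its gradient norm back toward 1.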

## GAN Variants

### Deep Convolutional GAN (DCGAN)

```python
class DCGANGenerator(nn.Module):
    """DCGAN Generator with architectural best practices"""
    
    def __init__(self, latent_dim=100, feature_maps=64, channels=3):
        super().__init__()
        
        self.main = nn.Sequential(
            # input is Z with shape (N, latent_dim, 1, 1), going into a transposed convolution
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            
            # state size: (feature_maps*8) x 4 x 4
            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            
            # state size: (feature_maps*4) x 8 x 8
            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            
            # state size: (feature_maps*2) x 16 x 16
            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            
            # state size: (feature_maps) x 32 x 32
            nn.ConvTranspose2d(feature_maps, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size: (channels) x 64 x 64
        )
        
    def forward(self, input):
        return self.main(input)

class DCGANDiscriminator(nn.Module):
    """DCGAN Discriminator with architectural best practices"""
    
    def __init__(self, channels=3, feature_maps=64):
        super().__init__()
        
        self.main = nn.Sequential(
            # input is (channels) x 64 x 64
            nn.Conv2d(channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size: (feature_maps) x 32 x 32
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size: (feature_maps*2) x 16 x 16
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size: (feature_maps*4) x 8 x 8
            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size: (feature_maps*8) x 4 x 4
            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)

# DCGAN architectural guidelines

dcgan_principles = {
    "generator": [
        "Use transposed convolutions for upsampling",
        "Use batch normalization in all layers except output",
        "Use ReLU activation in all layers except output (use Tanh)",
        "No fully connected layers except first layer"
    ],
    
    "discriminator": [
        "Use strided convolutions for downsampling",
        "Use batch normalization in all layers except first",
        "Use LeakyReLU activations",
        "No fully connected layers except last layer"
    ]
}

```
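
The state-size comments in `DCGANGenerator` follow from the transposed-convolution output formula. As a sketch (the helper name `conv_transpose_out_size` is an assumption for illustration), the spatial size after each layer can be traced like this:

```python
def conv_transpose_out_size(n, kernel, stride, padding):
    """Output length of a transposed convolution along one spatial dimension
    (no output padding, dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for the five ConvTranspose2d layers of DCGANGenerator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector is fed in as a 1x1 spatial map
generator_sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out_size(size, kernel, stride, padding)
    generator_sizes.append(size)
```

The trace gives 4, 8, 16, 32, 64, which is why the noise must be shaped `(batch, latent_dim, 1, 1)` before it is passed to `DCGANGenerator`.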

### Conditional GAN (cGAN)

```python
class ConditionalGenerator(nn.Module):
    """Generator that takes both noise and class label as input"""
    
    def __init__(self, latent_dim=100, n_classes=10, img_size=28, channels=1):
        super().__init__()
        
        self.img_size = img_size
        self.channels = channels
        
        # Embedding layer for class labels
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # Input dimension is noise + embedded label
        input_dim = latent_dim + n_classes
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(input_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, channels * img_size * img_size),
            nn.Tanh()
        )
        
    def forward(self, noise, labels):
        """Generate conditioned on class labels"""
        
        # Embed labels
        label_embedding = self.label_emb(labels)
        
        # Concatenate noise and label embedding
        gen_input = torch.cat((noise, label_embedding), -1)
        
        # Generate image
        img = self.model(gen_input)
        img = img.view(img.size(0), self.channels, self.img_size, self.img_size)
        
        return img

class ConditionalDiscriminator(nn.Module):
    """Discriminator that takes both image and class label"""
    
    def __init__(self, n_classes=10, img_size=28, channels=1):
        super().__init__()
        
        # Embedding for class labels
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # Input is image + embedded label
        input_dim = channels * img_size * img_size + n_classes
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img, labels):
        """Discriminate based on image and label"""
        
        # Flatten image
        img_flat = img.view(img.size(0), -1)
        
        # Embed labels
        label_embedding = self.label_emb(labels)
        
        # Concatenate image and label
        d_input = torch.cat((img_flat, label_embedding), -1)
        
        # Discriminate
        validity = self.model(d_input)
        
        return validity

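# Conditional GAN training

def train_conditional_gan(generator, discriminator, dataloader, epochs=200, lr=0.0002):
    """Training loop for conditional GAN"""
    
    # Loss and optimizers (same settings as GANTrainer)
    adversarial_loss = nn.BCELoss()
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    
    for epoch in range(epochs):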
        for i, (real_imgs, labels) in enumerate(dataloader):
            
            batch_size = real_imgs.size(0)
            
            # Train Discriminator
            optimizer_D.zero_grad()
            
            # Real images
            real_validity = discriminator(real_imgs, labels)
            real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))
            
            # Fake images
            z = torch.randn(batch_size, 100)
            fake_labels = torch.randint(0, 10, (batch_size,))  # Random labels
            fake_imgs = generator(z, fake_labels)
            fake_validity = discriminator(fake_imgs.detach(), fake_labels)
            fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))
            
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()
            
            # Train Generator
            optimizer_G.zero_grad()
            
            # Generate fake images
            z = torch.randn(batch_size, 100)
            fake_labels = torch.randint(0, 10, (batch_size,))
            fake_imgs = generator(z, fake_labels)
            fake_validity = discriminator(fake_imgs, fake_labels)
            
            g_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity))
            g_loss.backward()
            optimizer_G.step()

```

### StyleGAN

```python
class StyleGenerator(nn.Module):
    """Simplified StyleGAN-inspired generator"""
    
    def __init__(self, latent_dim=512, n_layers=8):
        super().__init__()
        
        # Mapping network: Z -> W
        self.mapping_network = nn.Sequential(
            *[nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                nn.LeakyReLU(0.2)
            ) for _ in range(8)]
        )
        
        # Synthesis network layers
        self.synthesis_layers = nn.ModuleList()
        
        for i in range(n_layers):
            # Each layer has adaptive instance normalization
            layer = StyleSynthesisLayer(
                in_channels=512 // (2 ** min(i, 4)),
                out_channels=512 // (2 ** min(i+1, 4)),
                style_dim=latent_dim
            )
            self.synthesis_layers.append(layer)
    
    def forward(self, z, inject_noise=True):
        """Generate image with style control"""
        
        # Map to style space
        w = self.mapping_network(z)
        
        # Start from a constant 4x4 input (StyleGAN learns this tensor;
        # a fixed tensor of ones keeps this sketch short)
        x = torch.ones(z.size(0), 512, 4, 4).to(z.device)
        
        # Apply synthesis layers
        for layer in self.synthesis_layers:
            x = layer(x, w, inject_noise)
        
        return x

class StyleSynthesisLayer(nn.Module):
    """Single synthesis layer with style modulation"""
    
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()
        
        # Convolution
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        
        # Style modulation
        self.style_scale = nn.Linear(style_dim, in_channels)
        self.style_shift = nn.Linear(style_dim, in_channels)
        
        # Noise injection
        self.noise_strength = nn.Parameter(torch.zeros(1))
        
        # Activation
        self.activation = nn.LeakyReLU(0.2)
        
    def forward(self, x, style, inject_noise=True):
        # Style modulation (Adaptive Instance Normalization)
        style_scale = self.style_scale(style).unsqueeze(2).unsqueeze(3)
        style_shift = self.style_shift(style).unsqueeze(2).unsqueeze(3)
        
        # Normalize features
        x_norm = F.instance_norm(x)
        
        # Apply style
        x_styled = style_scale * x_norm + style_shift
        
        # Convolution
        x = self.conv(x_styled)
        
        # Add noise
        if inject_noise:
            noise = torch.randn_like(x)
            x = x + self.noise_strength * noise
        
        # Activation
        x = self.activation(x)
        
        return x

def style_mixing_regularization(generator, z1, z2, mixing_prob=0.5):
    """Style mixing regularization for StyleGAN"""
    
    if torch.rand(1).item() >= mixing_prob:
        return generator(z1)
    
    # Map both latents to style space
    w1 = generator.mapping_network(z1)
    w2 = generator.mapping_network(z2)
    
    # Pick the layer index at which the style switches from w1 to w2
    n_layers = len(generator.synthesis_layers)
    crossover_point = torch.randint(1, n_layers, (1,)).item()
    
    # Run the synthesis layers, mirroring StyleGenerator.forward
    x = torch.ones(z1.size(0), 512, 4, 4, device=z1.device)
    for i, layer in enumerate(generator.synthesis_layers):
        w = w1 if i < crossover_point else w2
        x = layer(x, w, True)
    
    return x

```
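
The style modulation in `StyleSynthesisLayer` is adaptive instance normalization: each feature map is normalized to zero mean and unit variance, then rescaled and shifted by style-dependent values. A minimal plain-Python sketch for a single channel follows; `adain_channel` and the toy numbers are assumptions for illustration.

```python
import math

def adain_channel(values, scale, shift, eps=1e-5):
    """Adaptive instance normalization of one channel of one sample.

    values: flattened activations of a single feature map
    scale, shift: style-dependent parameters (in StyleSynthesisLayer these
    come from linear layers applied to the style vector w)
    """
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
    return [scale * (v - mean) / math.sqrt(var + eps) + shift for v in values]

feature_map = [1.0, 2.0, 3.0, 4.0]
styled = adain_channel(feature_map, scale=2.0, shift=5.0)

# Statistics of the output are set by the style, not by the input
styled_mean = sum(styled) / len(styled)
styled_std = math.sqrt(sum((v - styled_mean) ** 2 for v in styled) / len(styled))
```

After the operation the channel has mean close to `shift` and standard deviation close to `scale`, whatever the statistics of the incoming features were.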

## Applications and Use Cases

### Image Generation and Editing

```python
class ImageGANApplications:
    """Various applications of GANs in image processing"""
    
    def __init__(self):
        self.applications = {
            'face_generation': self.generate_faces,
            'image_super_resolution': self.super_resolution,
            'image_inpainting': self.inpaint_images,
            'domain_transfer': self.domain_transfer,
            'data_augmentation': self.augment_data
        }
    
    def generate_faces(self, generator, n_faces=100):
        """Generate realistic face images"""
        
        with torch.no_grad():
            z = torch.randn(n_faces, 100)
            fake_faces = generator(z)
            
        return fake_faces
    
    def super_resolution(self, lr_image, sr_generator):
        """Enhance image resolution using SRGAN"""
        
        with torch.no_grad():
            hr_image = sr_generator(lr_image)
            
        return hr_image
    
    def inpaint_images(self, masked_image, mask, inpainting_generator):
        """Fill missing parts of images"""
        
        with torch.no_grad():
            completed_image = inpainting_generator(masked_image, mask)
            
        return completed_image
    
    def domain_transfer(self, source_image, cyclegan):
        """Transfer image from one domain to another (e.g., day to night)"""
        
        with torch.no_grad():
            target_image = cyclegan.generator_AB(source_image)
            
        return target_image
    
    def augment_data(self, generator, n_augmentations=1000):
        """Generate synthetic training data"""
        
        augmented_data = []
        
        with torch.no_grad():
            for _ in range(n_augmentations):
                z = torch.randn(1, 100)
                synthetic_sample = generator(z)
                augmented_data.append(synthetic_sample)
        
        return torch.cat(augmented_data, dim=0)

# Text-to-Image Generation (Simplified)

class TextToImageGAN:
    """Simplified text-to-image GAN"""
    
    def __init__(self, text_encoder, generator):
        self.text_encoder = text_encoder  # E.g., CLIP text encoder
        self.generator = generator
        
    def generate_from_text(self, text_descriptions):
        """Generate images from text descriptions"""
        
        # Encode text to conditioning vector
        text_embeddings = self.text_encoder(text_descriptions)
        
        # Generate images conditioned on text
        z = torch.randn(len(text_descriptions), 100)
        generated_images = self.generator(z, text_embeddings)
        
        return generated_images
    
    def interpolate_between_texts(self, text1, text2, steps=10):
        """Generate interpolation between two text descriptions"""
        
        # Encode texts
        emb1 = self.text_encoder([text1])
        emb2 = self.text_encoder([text2])
        
        # Linear interpolation
        alphas = torch.linspace(0, 1, steps)
        interpolated_images = []
        
        for alpha in alphas:
            interp_emb = alpha * emb2 + (1 - alpha) * emb1
            z = torch.randn(1, 100)
            img = self.generator(z, interp_emb)
            interpolated_images.append(img)
        
        return torch.cat(interpolated_images, dim=0)

```

### Audio and Music Generation

```python
class AudioGAN:
    """GAN for audio waveform generation"""
    
    def __init__(self, sample_rate=16000, window_size=1024):
        self.sample_rate = sample_rate
        self.window_size = window_size
        
        # Generator: noise -> raw audio waveform
        self.generator = self.build_audio_generator()
        
        # Discriminator: multiple scales for temporal modeling
        self.discriminators = nn.ModuleList([
            self.build_discriminator(scale) for scale in [1, 2, 4]
        ])
    
    def build_audio_generator(self):
        """Build generator for raw audio waveforms"""
        
        return nn.Sequential(
            # Transpose convolutions for upsampling
            nn.ConvTranspose1d(100, 1024, 4, stride=1),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose1d(1024, 512, 4, stride=2, padding=1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose1d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose1d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            
            nn.ConvTranspose1d(128, 1, 4, stride=2, padding=1),
            nn.Tanh()  # Audio samples in [-1, 1]
        )
    
    def build_discriminator(self, scale):
        """Build multi-scale discriminator"""
        
        return nn.Sequential(
            nn.Conv1d(1, 64, 15, stride=scale, padding=7),
            nn.LeakyReLU(0.2),
            
            nn.Conv1d(64, 128, 41, stride=4, padding=20),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            
            nn.Conv1d(128, 256, 41, stride=4, padding=20),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            
            nn.Conv1d(256, 512, 41, stride=4, padding=20),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            
            nn.Conv1d(512, 1, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
    
    def generate_audio(self, duration_seconds=2.0):
        """Generate audio of specified duration"""
        
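        # Number of output samples requested
        target_samples = int(duration_seconds * self.sample_rate)
        
        # The first layer maps latent length L to L + 3 and each of the four
        # stride-2 layers doubles it, so the output has 16 * (L + 3) samples
        latent_len = max(1, -(-target_samples // 16) - 3)
        
        # Generate
        with torch.no_grad():
            z = torch.randn(1, 100, latent_len)  # Latent noise
            audio_waveform = self.generator(z)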
            
            # Trim or pad to desired length
            if audio_waveform.size(-1) > target_samples:
                audio_waveform = audio_waveform[:, :, :target_samples]
            elif audio_waveform.size(-1) < target_samples:
                padding = target_samples - audio_waveform.size(-1)
                audio_waveform = F.pad(audio_waveform, (0, padding))
        
        return audio_waveform.squeeze()

```
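
How much audio the generator emits depends on the latent length and the number of upsampling layers. The sketch below applies PyTorch's documented output-length formula for `ConvTranspose1d` to estimate it. The layer count of five and the 16 kHz sample rate are assumptions (only the tail of the generator is visible here), and the helper names are hypothetical.

```python
import math


def conv_transpose1d_out_len(length, kernel_size=4, stride=2, padding=1,
                             output_padding=0, dilation=1):
    """Output length of one ConvTranspose1d layer (PyTorch's documented formula)."""
    return ((length - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def upsampled_length(latent_length, num_layers=5):
    """Length after a stack of identical kernel=4, stride=2, padding=1 layers."""
    length = latent_length
    for _ in range(num_layers):
        length = conv_transpose1d_out_len(length)
    return length


def latent_length_for(target_samples, num_layers=5):
    """Smallest latent length whose output covers target_samples."""
    factor = upsampled_length(1, num_layers)  # each such layer doubles the length
    return math.ceil(target_samples / factor)


sample_rate = 16000  # assumed value; the class stores its own self.sample_rate
target = int(2.0 * sample_rate)

print(upsampled_length(1))        # samples produced from a length-1 latent
print(latent_length_for(target))  # latent length needed to cover 2 seconds
```

Under these assumptions a length-1 latent yields only a few dozen samples, so most of a 2-second clip would be zero padding. Sampling `z` with a longer last dimension, as estimated by `latent_length_for`, is one way to let the generator fill the requested duration.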

## Evaluation Metrics

### Quantitative Metrics

```python
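import numpy as np
import torch
import torch.nn.functional as F
from scipy.linalg import sqrtm


class GANEvaluator:
    """Comprehensive evaluation of GAN performance.
    
    Note: extract_inception_features and compute_pairwise_distances are
    assumed helper methods that are not defined in this listing.
    """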
    
    def __init__(self, device='cuda'):
        self.device = device
        self.inception_model = self.load_inception_model()
        
    def load_inception_model(self):
        """Load pre-trained Inception model for FID/IS computation"""
        
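        from torchvision.models import inception_v3, Inception_V3_Weights
        
        # `pretrained=True` is deprecated in torchvision >= 0.13; use `weights=` instead
        model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)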
        model.eval()
        model.to(self.device)
        
        return model
    
    def compute_fid(self, real_images, fake_images):
        """Compute Frechet Inception Distance"""
        
        # Extract features using Inception model
        real_features = self.extract_inception_features(real_images)
        fake_features = self.extract_inception_features(fake_images)
        
        # Compute statistics
        real_mean = np.mean(real_features, axis=0)
        fake_mean = np.mean(fake_features, axis=0)
        
        real_cov = np.cov(real_features, rowvar=False)
        fake_cov = np.cov(fake_features, rowvar=False)
        
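        # FID = ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 * (C_r C_f)^(1/2))
        diff = real_mean - fake_mean
        covmean = sqrtm(real_cov @ fake_cov)
        if np.iscomplexobj(covmean):
            covmean = covmean.real  # drop small imaginary parts caused by numerical error
        fid = np.dot(diff, diff) + np.trace(real_cov + fake_cov - 2 * covmean)
        
        return float(fid)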
    
    def compute_inception_score(self, fake_images, splits=10):
        """Compute Inception Score"""
        
        # Get predictions from Inception model
        with torch.no_grad():
            fake_images_resized = F.interpolate(fake_images, size=(299, 299), mode='bilinear')
            predictions = F.softmax(self.inception_model(fake_images_resized), dim=1)
        
        # Split into groups
        predictions = predictions.cpu().numpy()
        scores = []
        
        for i in range(splits):
            part = predictions[i * len(predictions) // splits:(i + 1) * len(predictions) // splits]
            
            # KL(p(y|x) || p(y)); eps guards against log(0) producing NaN
            eps = 1e-12
            marginal = np.mean(part, axis=0, keepdims=True)
            kl_div = part * (np.log(part + eps) - np.log(marginal + eps))
            kl_div = np.mean(np.sum(kl_div, axis=1))
            scores.append(np.exp(kl_div))
        
        return np.mean(scores), np.std(scores)
    
    def compute_lpips(self, real_images, fake_images):
        """Compute Learned Perceptual Image Patch Similarity"""
        
        import lpips
        
        # Initialize LPIPS model
        lpips_model = lpips.LPIPS(net='alex').to(self.device)
        
        # Compute perceptual distances pair by pair; LPIPS expects inputs scaled to [-1, 1]
        distances = []
        
        with torch.no_grad():
            for real_img, fake_img in zip(real_images, fake_images):
                real_img = real_img.unsqueeze(0)
                fake_img = fake_img.unsqueeze(0)
                
                distance = lpips_model(real_img, fake_img)
                distances.append(distance.item())
        
        return np.mean(distances)
    
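    def compute_precision_recall(self, real_features, fake_features, k=3):
        """Rough precision/recall heuristic from k-nearest-neighbor distances.
        
        Compares each sample's k-th nearest cross-set distance with the median of
        the opposite direction; treat the values as a relative diagnostic only.
        """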
        
        # Compute pairwise distances
        real_to_fake = self.compute_pairwise_distances(real_features, fake_features)
        fake_to_real = self.compute_pairwise_distances(fake_features, real_features)
        
        # Find k-nearest neighbors
        real_nn_fake = np.sort(real_to_fake, axis=1)[:, :k]
        fake_nn_real = np.sort(fake_to_real, axis=1)[:, :k]
        
        # Compute precision and recall
        precision = np.mean(fake_nn_real[:, -1] < np.median(real_nn_fake[:, -1]))
        recall = np.mean(real_nn_fake[:, -1] < np.median(fake_nn_real[:, -1]))
        
        return precision, recall
    
    def comprehensive_evaluation(self, generator, real_dataloader, n_samples=10000):
        """Run comprehensive evaluation suite"""
        
        # Generate samples
        fake_images = []
        real_images = []
        
        generator.eval()
        with torch.no_grad():
            # Generate fake samples
            for _ in range(n_samples // 100):
                z = torch.randn(100, 100).to(self.device)
                fake_batch = generator(z)
                fake_images.append(fake_batch)
            
            # Collect real samples
            for i, (real_batch, _) in enumerate(real_dataloader):
                real_images.append(real_batch.to(self.device))
                if (i + 1) * real_batch.size(0) >= n_samples:
                    break
        
        fake_images = torch.cat(fake_images, dim=0)[:n_samples]
        real_images = torch.cat(real_images, dim=0)[:n_samples]
        
        # Compute metrics
        results = {}
        
        print("Computing FID...")
        results['fid'] = self.compute_fid(real_images, fake_images)
        
        print("Computing Inception Score...")
        is_mean, is_std = self.compute_inception_score(fake_images)
        results['inception_score_mean'] = is_mean
        results['inception_score_std'] = is_std
        
        print("Computing LPIPS...")
        results['lpips'] = self.compute_lpips(real_images[:1000], fake_images[:1000])
        
        print("Computing Precision/Recall...")
        real_feats = self.extract_inception_features(real_images)
        fake_feats = self.extract_inception_features(fake_images)
        precision, recall = self.compute_precision_recall(real_feats, fake_feats)
        results['precision'] = precision
        results['recall'] = recall
        
        return results

# Usage example (assumes `generator` and `test_dataloader` are defined elsewhere)

evaluator = GANEvaluator()
metrics = evaluator.comprehensive_evaluation(generator, test_dataloader)

print("GAN Evaluation Results:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

```
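
To make the FID formula concrete, here is a minimal NumPy-only sketch that operates directly on feature arrays, so it can be checked without an Inception model. The function name `frechet_distance` and the toy data are assumptions for illustration. Instead of `sqrtm`, it uses the fact that the trace of the matrix square root of a product of two covariance matrices equals the sum of the square roots of that product's eigenvalues.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    feats_a, feats_b: arrays of shape [n_samples, n_features].
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)

    # Tr((C_a C_b)^(1/2)) is the sum of square roots of the eigenvalues of
    # C_a C_b, which are real and non-negative for covariance matrices.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
shifted = real + 3.0  # same covariance, mean moved by 3 in each of 8 dimensions

print(frechet_distance(real, real))     # close to 0
print(frechet_distance(real, shifted))  # close to 8 * 3**2 = 72
```

Identical feature sets give a distance near zero, and a pure mean shift contributes only the squared norm of the shift, which matches the two terms of the formula.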

### Qualitative Assessment

```python
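import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image


class QualitativeEvaluator: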
    """Tools for qualitative evaluation of GANs"""
    
    def __init__(self, generator):
        self.generator = generator
        
    def generate_interpolation_grid(self, n_steps=10, save_path='interpolation.png'):
        """Generate interpolation between random points in latent space"""
        
        # Random start and end points
        z1 = torch.randn(1, 100)
        z2 = torch.randn(1, 100)
        
        # Linear interpolation
        alphas = torch.linspace(0, 1, n_steps).unsqueeze(1)
        interpolated_z = alphas * z2 + (1 - alphas) * z1
        
        # Generate images
        with torch.no_grad():
            interpolated_images = self.generator(interpolated_z)
        
        # Save as grid
        save_image(interpolated_images, save_path, nrow=n_steps, normalize=True)
        
        return interpolated_images
    
    def analyze_mode_collapse(self, n_samples=1000):
        """Check for mode collapse by analyzing diversity"""
        
        # Generate many samples
        generated_samples = []
        
        with torch.no_grad():
            for _ in range(n_samples // 100):
                z = torch.randn(100, 100)
                batch = self.generator(z)
                generated_samples.append(batch)
        
        all_samples = torch.cat(generated_samples, dim=0)
        n = all_samples.size(0)  # may differ from n_samples if it is not a multiple of 100
        
        # Compute pairwise cosine similarities via a normalized matrix product
        # (avoids materializing an [n, n, features] tensor)
        flattened = F.normalize(all_samples.reshape(n, -1), dim=1)
        similarities = flattened @ flattened.t()
        
        # Exclude the diagonal (self-similarity) from the statistics
        off_diagonal = similarities[~torch.eye(n, dtype=torch.bool, device=similarities.device)]
        
        # Analyze statistics
        mean_similarity = off_diagonal.mean().item()
        max_similarity = off_diagonal.max().item()
        
        # High mean similarity might indicate mode collapse
        mode_collapse_score = mean_similarity
        
        return {
            'mean_similarity': mean_similarity,
            'max_similarity': max_similarity,
            'mode_collapse_score': mode_collapse_score,
            'likely_mode_collapse': mode_collapse_score > 0.8
        }
    
    def visualize_training_progress(self, checkpoints_dir):
        """Create visualization of training progress"""
        
        import os
        from PIL import Image
        
        checkpoint_files = sorted([f for f in os.listdir(checkpoints_dir) if f.endswith('.pt')])
        
        progress_images = []
        
        # Sample the noise once so every checkpoint renders the same latent points
        device = next(self.generator.parameters()).device
        z = torch.randn(16, 100, device=device)
        
        for checkpoint_file in checkpoint_files:
            # Load checkpoint
            checkpoint_path = os.path.join(checkpoints_dir, checkpoint_file)
            checkpoint = torch.load(checkpoint_path, map_location=device)
            self.generator.load_state_dict(checkpoint['generator'])
            
            # Generate sample from the fixed noise
            with torch.no_grad():
                sample = self.generator(z)
                
            # Convert to PIL and store
            grid = make_grid(sample, nrow=4, normalize=True)
            grid_np = grid.permute(1, 2, 0).cpu().numpy()
            grid_pil = Image.fromarray((grid_np * 255).astype(np.uint8))
            progress_images.append(grid_pil)
        
        # Create animation/video
        progress_images[0].save(
            'training_progress.gif',
            save_all=True,
            append_images=progress_images[1:],
            duration=500,
            loop=0
        )

```
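
The diversity statistic behind `analyze_mode_collapse` can be illustrated on toy arrays. This is a NumPy sketch under the assumption that the mean off-diagonal cosine similarity is the quantity of interest; `mean_offdiag_cosine` is a hypothetical helper, and the 0.8 threshold used above remains a heuristic rather than a calibrated value.

```python
import numpy as np


def mean_offdiag_cosine(samples):
    """Mean cosine similarity over all distinct pairs of samples."""
    n = len(samples)
    flat = samples.reshape(n, -1).astype(np.float64)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    unit = flat / np.clip(norms, 1e-12, None)  # guard against zero vectors
    sims = unit @ unit.T                       # [n, n] cosine similarities
    off_diag = sims[~np.eye(n, dtype=bool)]    # drop self-similarity entries
    return float(off_diag.mean())


collapsed = np.tile(np.array([[1.0, 2.0, 3.0]]), (4, 1))  # four identical samples
diverse = np.eye(4)                                       # four orthogonal samples

print(mean_offdiag_cosine(collapsed))  # close to 1: every pair is identical
print(mean_offdiag_cosine(diverse))    # close to 0: every pair is orthogonal
```

Natural images often have high raw-pixel cosine similarity even without mode collapse, so comparing the score against the same statistic computed on real data is likely more informative than a fixed threshold.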

## Common Challenges and Solutions

### Training Instability

```python
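import copy

import torch
import torch.nn as nn


class TrainingStabilizer:
    """Solutions for common GAN training problems"""
    
    def __init__(self):
        # handle_vanishing_gradients and stabilize_training are not defined in this
        # sketch; getattr() maps them to None rather than raising AttributeError here
        self.stabilization_techniques = {
            'mode_collapse': self.prevent_mode_collapse,
            'vanishing_gradients': getattr(self, 'handle_vanishing_gradients', None),
            'training_oscillation': getattr(self, 'stabilize_training', None),
            'discriminator_overpowering': self.balance_networks
        }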
    
    def prevent_mode_collapse(self, generator, discriminator):
        """Techniques to prevent mode collapse"""
        
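        # add_feature_matching and add_diversity_loss are not defined in this sketch;
        # getattr() maps them to None rather than raising AttributeError
        solutions = {
            'unrolled_gan': self.implement_unrolled_gan,
            'minibatch_discrimination': self.add_minibatch_discrimination,
            'feature_matching': getattr(self, 'add_feature_matching', None),
            'diverse_loss': getattr(self, 'add_diversity_loss', None)
        }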
        
        return solutions
    
    def implement_unrolled_gan(self, generator, discriminator, unroll_steps=5):
        """Unrolled GAN to prevent mode collapse"""
        
        # Store original discriminator state
        original_state = copy.deepcopy(discriminator.state_dict())
        
        # Unroll discriminator training
        for _ in range(unroll_steps):
            # Simulate discriminator training step
            # (simplified - actual implementation more complex)
            pass
        
        # Compute generator loss with "unrolled" discriminator
        generator_loss = self.compute_generator_loss(generator, discriminator)
        
        # Restore discriminator state
        discriminator.load_state_dict(original_state)
        
        return generator_loss
    
    def add_minibatch_discrimination(self, discriminator, num_features=100):
        """Add minibatch discrimination layer"""
        
        class MinibatchDiscrimination(nn.Module):
            def __init__(self, input_features, num_features, num_kernels):
                super().__init__()
                self.num_features = num_features
                self.num_kernels = num_kernels
                self.T = nn.Parameter(torch.randn(input_features, num_features * num_kernels))
                
            def forward(self, x):
                # Compute minibatch features
                M = torch.mm(x, self.T)
                M = M.view(-1, self.num_features, self.num_kernels)
                
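                # L1 distance between samples, taken over the kernel dimension
                diffs = M.unsqueeze(0) - M.unsqueeze(1)   # [N, N, num_features, num_kernels]
                abs_diffs = torch.abs(diffs).sum(dim=3)   # [N, N, num_features]
                
                # Sum exponential of negative distances over the batch
                minibatch_features = torch.exp(-abs_diffs).sum(dim=1)  # [N, num_features]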
                
                return torch.cat([x, minibatch_features], dim=1)
        
        # Add to discriminator architecture
        mb_disc_layer = MinibatchDiscrimination(
            input_features=discriminator.final_layer_size,
            num_features=num_features,
            num_kernels=5
        )
        
        return mb_disc_layer
    
    def balance_networks(self, d_loss, g_loss, d_lr, g_lr):
        """Dynamic learning rate balancing"""
        
        # If discriminator is too strong, reduce its learning rate
        if d_loss < 0.1 and g_loss > 2.0:
            d_lr *= 0.9
            g_lr *= 1.1
        
        # If generator is too strong, reduce its learning rate
        elif g_loss < 0.1 and d_loss > 2.0:
            g_lr *= 0.9
            d_lr *= 1.1
        
        return d_lr, g_lr
    
    def implement_progressive_growing(self, generator, discriminator, current_resolution):
        """Progressive growing of GANs for stable high-resolution training"""
        
        # Gradually increase resolution during training
        resolution_schedule = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
        
        if current_resolution in resolution_schedule:
            next_idx = resolution_schedule.index(current_resolution) + 1
            
            if next_idx < len(resolution_schedule):
                next_resolution = resolution_schedule[next_idx]
                
                # Add new layers to networks
                generator = self.add_resolution_layer(generator, next_resolution)
                discriminator = self.add_resolution_layer(discriminator, next_resolution)
        
        return generator, discriminator

```
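
The shapes involved in minibatch discrimination are easy to get wrong, so a small NumPy sketch may help. It assumes the same layout as the layer above (a projection reshaped to `[N, num_features, num_kernels]`, with the L1 norm taken over the kernel axis); the function name and toy inputs are hypothetical.

```python
import numpy as np


def minibatch_features(x, T, num_features, num_kernels):
    """x: [N, A] activations; T: [A, num_features * num_kernels] projection."""
    n = len(x)
    M = (x @ T).reshape(n, num_features, num_kernels)
    # L1 distance between every pair of samples, per feature row
    l1 = np.abs(M[:, None] - M[None, :]).sum(axis=3)  # [N, N, num_features]
    # Closeness of each sample to the whole batch (includes the self term exp(0) = 1)
    return np.exp(-l1).sum(axis=1)                    # [N, num_features]


rng = np.random.default_rng(0)
T = rng.normal(size=(16, 6 * 5))                        # 6 features, 5 kernels
collapsed = np.tile(rng.normal(size=(1, 16)), (8, 1))   # 8 identical samples
diverse = rng.normal(size=(8, 16))                      # 8 unrelated samples

print(minibatch_features(collapsed, T, 6, 5)[0])  # every entry equals the batch size
print(minibatch_features(diverse, T, 6, 5)[0])    # entries stay well below the batch size
```

A collapsed batch drives every feature to the batch size, which gives the discriminator a direct signal that the generator is producing near-duplicates.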

## Future Directions

### Emerging Architectures

```python
import torch
import torch.nn as nn


class NextGenGANs:
    """Emerging GAN architectures and techniques"""
    
    def __init__(self):
        self.emerging_techniques = {
            'progressive_distillation': self.progressive_distillation,
            'neural_ode_gans': self.neural_ode_generator,
            'transformer_gans': self.transformer_based_gan,
            # efficient_attention_gan is not defined in this sketch; getattr() avoids an AttributeError
            'efficient_attention': getattr(self, 'efficient_attention_gan', None)
        }
    
    def progressive_distillation(self, teacher_gan, student_gan):
        """Progressive distillation for efficient GANs"""
        
        # Student learns to match teacher with fewer parameters
        distillation_loss = nn.MSELoss()
        
        def distillation_step(real_batch):
            z = torch.randn(real_batch.size(0), 100)
            
            # Teacher and student outputs
            with torch.no_grad():
                teacher_output = teacher_gan.generator(z)
            
            student_output = student_gan.generator(z)
            
            # Knowledge distillation loss
            kd_loss = distillation_loss(student_output, teacher_output)
            
            # Add discriminator loss
            d_loss = student_gan.discriminator_loss(student_output, real_batch)
            
            total_loss = kd_loss + 0.1 * d_loss
            
            return total_loss
        
        return distillation_step
    
    def transformer_based_gan(self, seq_len=64, d_model=512):
        """GAN using transformer architecture"""
        
        class TransformerGenerator(nn.Module):
            def __init__(self, seq_len, d_model, nhead=8, num_layers=6):
                super().__init__()
                
                self.seq_len = seq_len
                self.d_model = d_model
                
                # Input projection
                self.input_proj = nn.Linear(100, d_model)
                
                # Positional encoding
                self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))
                
                # Transformer layers
                encoder_layer = nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=nhead,
                    dim_feedforward=2048,
                    dropout=0.1
                )
                self.transformer = nn.TransformerEncoder(
                    encoder_layer,
                    num_layers=num_layers
                )
                
                # Output projection
                self.output_proj = nn.Linear(d_model, 3)  # RGB channels
                
            def forward(self, z):
                batch_size = z.size(0)
                
                # Project input
                x = self.input_proj(z).unsqueeze(1)  # [batch, 1, d_model]
                
                # Expand to sequence
                x = x.repeat(1, self.seq_len, 1)  # [batch, seq_len, d_model]
                
                # Add positional encoding
                x = x + self.pos_encoding.unsqueeze(0)
                
                # Transformer forward pass
                x = x.transpose(0, 1)  # [seq_len, batch, d_model]
                x = self.transformer(x)
                x = x.transpose(0, 1)  # [batch, seq_len, d_model]
                
                # Output projection
                output = self.output_proj(x)  # [batch, seq_len, 3]
                
                # Reshape to image format (assumes seq_len is a perfect square)
                img_size = int(self.seq_len ** 0.5)
                output = output.reshape(batch_size, img_size, img_size, 3)
                output = output.permute(0, 3, 1, 2)  # [batch, 3, H, W]
                
                return torch.tanh(output)
        
        return TransformerGenerator(seq_len, d_model)
    
    def neural_ode_generator(self, latent_dim=100):
        """Generator using Neural ODEs for continuous generation"""
        
        try:
            from torchdiffeq import odeint
            
            class ODEFunc(nn.Module):
                def __init__(self, dim):
                    super().__init__()
                    self.net = nn.Sequential(
                        nn.Linear(dim, 256),
                        nn.Tanh(),
                        nn.Linear(256, 256),
                        nn.Tanh(),
                        nn.Linear(256, dim)
                    )
                
                def forward(self, t, y):
                    return self.net(y)
            
            class NeuralODEGenerator(nn.Module):
                def __init__(self, latent_dim):
                    super().__init__()
                    self.ode_func = ODEFunc(latent_dim)
                    self.decoder = nn.Sequential(
                        nn.Linear(latent_dim, 256),
                        nn.ReLU(),
                        nn.Linear(256, 512),
                        nn.ReLU(),
                        nn.Linear(512, 28*28),
                        nn.Tanh()
                    )
                
                def forward(self, z, integration_time=1.0):
                    # Solve ODE
                    t = torch.tensor([0., integration_time]).to(z.device)
                    trajectory = odeint(self.ode_func, z, t)
                    
                    # Use final state
                    final_state = trajectory[-1]
                    
                    # Decode to image
                    img = self.decoder(final_state)
                    img = img.view(-1, 1, 28, 28)
                    
                    return img
            
            return NeuralODEGenerator(latent_dim)
            
        except ImportError:
            print("torchdiffeq not available; returning None instead of a Neural ODE generator")
            return None

```

Generative Adversarial Networks have fundamentally transformed the landscape of generative modeling, enabling the
creation of highly realistic synthetic data across multiple domains. From their original formulation as a minimax game
between generator and discriminator networks to sophisticated variants like StyleGAN and beyond, GANs continue to push
the boundaries of what's possible in artificial data generation. While training stability remains a challenge, ongoing
research in architecture design, training techniques, and evaluation metrics continues to expand the capabilities and
applications of this powerful framework.