AI Engineering: Scaling your models with Ray Train for Blazing-Fast Performance


Note: This guide is aimed at those who are learning deep networks and are just starting to parallelize their models. If you're already familiar with distributed training and looking for more advanced concepts, feel free to skip ahead to Section 4.

1. Introduction

The landscape of machine learning has seen a monumental shift in the last decade, driven by the need to scale neural network training over increasingly larger datasets and more complex models. Neural networks, specifically Convolutional Neural Networks (CNNs), have proven to be highly effective in solving image classification tasks, but their computational complexity requires efficient use of resources like GPUs and CPUs, especially as the scale of both models and data grows.

This evolution is best exemplified by the transformation from traditional vanilla neural networks, through the inception of CNNs, and finally to the use of distributed computing frameworks such as Ray Train, which enables parallel training of these networks across multiple GPUs or even clusters of machines. This technical writeup examines this progression with a detailed focus on the transition to distributed training using Ray Train in the context of a practical task: image classification on the MNIST dataset.


2. Vanilla Neural Networks: The Starting Point

2.1 Overview of Vanilla Neural Networks

A vanilla neural network, also called a fully connected neural network, is characterized by layers of neurons where every neuron in one layer is connected to every neuron in the subsequent layer. The network consists of an input layer, hidden layers, and an output layer. While this architecture is effective for small, structured datasets, it becomes inefficient for high-dimensional data like images, due to the sheer number of parameters involved in fully connecting each pixel to neurons in the next layer.

2.2 Limitations of Vanilla Neural Networks in Image Processing

For image data such as a 28x28 grayscale MNIST digit, a vanilla network needs 784 input values (one per pixel). With even a moderate first hidden layer of 128 neurons, that layer alone has 784 x 128 weights plus 128 biases, or 100,480 parameters to learn, which slows training and makes the model prone to overfitting. Moreover, because the image is flattened into a vector, vanilla neural networks discard spatial information (e.g., which pixels are neighbors), which is crucial for tasks like image recognition.

Example: Vanilla Neural Network for MNIST:

import torch
import torch.nn as nn

class VanillaNN(nn.Module):
    def __init__(self):
        super(VanillaNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input image
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
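To make the parameter count concrete, here is a minimal sketch that tallies the weights and biases of the VanillaNN above using plain arithmetic. The helper name dense_params is an illustrative assumption, not part of PyTorch.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Layer sizes taken from the VanillaNN class above: 784 -> 128 -> 64 -> 10
layer_sizes = [28 * 28, 128, 64, 10]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [100480, 8256, 650]
print(total)      # 109386
```

Almost all of the parameters sit in the first layer, the one that connects every pixel to every hidden neuron.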

2.3 Scalability Challenges

Vanilla networks become computationally expensive as the input size and number of neurons increase. For example, a 224x224 RGB image has 150,528 input values, so a first hidden layer of just 128 neurons would already need roughly 19.3 million weights. Training such a network on a single CPU or GPU quickly becomes impractical for large datasets or higher-resolution images. To overcome this limitation, Convolutional Neural Networks (CNNs) were introduced.


3. Convolutional Neural Networks (CNNs): Introducing Spatial Awareness

3.1 Architecture and Operation of CNNs

CNNs address the limitations of vanilla neural networks by introducing convolutional layers, which apply learnable filters (kernels) across the image to detect local features (edges, textures). Instead of connecting every pixel to every neuron in the next layer, CNNs leverage the spatial locality of the data. This greatly reduces the number of parameters and allows the model to better capture hierarchical patterns in images, such as edges in early layers and complex shapes in deeper layers.
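As a rough illustration of what "applying a filter across the image" means, the sketch below slides a hand-picked 3x3 vertical-edge kernel over a toy 5x5 image in pure Python. The function name conv2d_valid and the toy data are assumptions made for this example; in a real CNN the kernel weights are learned and the operation is performed by layers such as nn.Conv2d.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and return the feature map."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

# Toy 5x5 image: dark on the left (0), bright on the right (1)
image = [[0, 0, 1, 1, 1] for _ in range(5)]

# Hand-picked vertical-edge filter; in a CNN these 9 weights would be learned
kernel = [[-1, 0, 1],
          [-1, 0, 1],
          [-1, 0, 1]]

feature_map = conv2d_valid(image, kernel)
for row in feature_map:
    print(row)  # each row prints [3, 3, 0]
```

The same 9 weights are reused at every position: the response is strong where the window covers the dark-to-bright edge and zero over the uniform region.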

CNN for MNIST:

import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 1 input channel, 32 output channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))  # First convolutional layer with ReLU
        x = torch.max_pool2d(x, 2)     # 2x2 Max pooling to reduce spatial dimensions
        x = torch.relu(self.conv2(x))  # Second convolutional layer with ReLU
        x = torch.max_pool2d(x, 2)     # Another 2x2 Max pooling
        x = x.view(-1, 64 * 7 * 7)     # Flatten for fully connected layer
        x = torch.relu(self.fc1(x))    # First fully connected layer
        x = self.fc2(x)                # Output layer (10 classes)
        return x
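The 64 * 7 * 7 input size of fc1 follows from how each layer changes the spatial size of the image. A small sketch of that arithmetic, assuming the standard output-size formula for convolution and pooling layers (the helper conv_out is named here for illustration):

```python
def conv_out(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Output height/width of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                    # MNIST input is 28x28
size = conv_out(size, kernel=3, padding=1)   # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)    # max_pool2d(x, 2) halves to 14
size = conv_out(size, kernel=3, padding=1)   # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)    # max_pool2d(x, 2) halves to 7

flattened = 64 * size * size                 # 64 channels come out of conv2
print(size, flattened)  # 7 3136
```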
        

3.2 Efficiency Gains in CNNs

Convolutional layers need far fewer parameters than fully connected layers. For example, a convolutional layer with one input channel, 32 filters, and a 3x3 kernel has only 320 parameters (32 filters * (3x3 weights + 1 bias term per filter)), compared with the 100,480 parameters of the first fully connected layer in Section 2. Because the same filter weights are reused at every image position, this count does not grow with image resolution. The local receptive fields of CNNs also capture spatial dependencies, making them particularly effective for image classification.
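The per-layer comparison can be checked with a short sketch. The layer shapes are taken from the CNN class in Section 3.1; the helper names conv_params and dense_params are assumptions for this example.

```python
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    """Each filter has in_ch * k * k weights plus one bias."""
    return out_ch * (in_ch * k * k + 1)

def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

cnn_layers = {
    "conv1": conv_params(1, 32, 3),
    "conv2": conv_params(32, 64, 3),
    "fc1": dense_params(64 * 7 * 7, 128),
    "fc2": dense_params(128, 10),
}
for name, count in cnn_layers.items():
    print(f"{name}: {count:,}")
print(f"total: {sum(cnn_layers.values()):,}")
```

For this particular model the two convolutional layers together hold fewer than 19,000 parameters, while the fully connected fc1 layer holds about 400,000. The savings come from the convolutional layers themselves, not from the classifier head.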

3.3 Scaling Issues with CNNs

Despite their efficiency, CNNs still face challenges when dealing with large datasets or models, such as ImageNet-scale classification tasks. Training deep CNNs like ResNet or EfficientNet requires significant computational resources, and on a single GPU it can take days or even weeks. As model complexity grows, single-GPU training time becomes prohibitive.

To handle these challenges, distributed training strategies such as data parallelism are needed. Data parallelism splits the training data across multiple GPUs, CPUs, or even multiple nodes in a cluster, which can significantly reduce training time.


4. Introduction to Distributed Training and Ray Train

Distributed training has become essential for training large models on massive datasets. The core idea is to split the dataset across multiple workers (devices or nodes), each of which trains a copy of the model on a subset of the data. After each batch, the gradients are averaged across the workers, ensuring that all models stay synchronized.
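Why averaging works can be seen on a toy problem. The sketch below, in pure Python with a one-parameter linear model, shows that averaging the gradients from equal-sized shards gives the same result as computing the gradient on the full batch. The data, the model, and the function name mse_grad are assumptions chosen for illustration.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]   # toy target: y = 2x

num_workers = 4
# Each worker gets an equal, disjoint shard (samples rank, rank + 4, ...)
shards = [(xs[r::num_workers], ys[r::num_workers]) for r in range(num_workers)]

# In real training these run in parallel, one per worker
local_grads = [mse_grad(w, sx, sy) for sx, sy in shards]
averaged = sum(local_grads) / num_workers   # the averaging (all-reduce) step
full_batch = mse_grad(w, xs, ys)

print(averaged, full_batch)  # -76.5 -76.5
```

Because every worker applies the same averaged gradient to the same starting weights, the model replicas remain identical after each step.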


4.1 Distributed Training Strategies

1. Data Parallelism:

Each worker (e.g., a GPU) holds a full replica of the model and processes a subset of the data. Gradients are calculated locally and averaged across workers after each forward-backward pass.

2. Model Parallelism:

The model itself is partitioned across multiple devices. Each device is responsible for a subset of the model's parameters. This strategy is effective when the model is too large to fit into the memory of a single GPU.

Ray Train is a distributed training library built on top of Ray. It scales machine learning workloads across multiple GPUs, CPUs, or nodes in a cluster. Its PyTorch integration, used in this guide, centers on data parallelism; sharded or model-parallel training is typically handled by pairing Ray Train with libraries such as PyTorch FSDP or DeepSpeed. Ray Train takes care of much of the setup involved in distributed training (launching workers, configuring the process group used for gradient synchronization, and distributed data loading) while integrating with deep learning frameworks like PyTorch.


5. Ray Train: Scalable Distributed Training

5.1 Overview of Ray Train

Ray Train provides a high-level API that simplifies distributed training. It abstracts the intricacies of data splitting, model synchronization, and multi-GPU utilization, making it easier for researchers and engineers to scale their neural networks across multiple devices. Ray Train integrates natively with PyTorch and TensorFlow, offering both flexibility and scalability.

Core Concepts in Ray Train:

  • Workers: Each worker is an independent process that runs a copy of the training loop on a CPU or GPU.
  • Gradient Synchronization: With PyTorch, Ray Train sets up the distributed process group and wraps the model in DistributedDataParallel, which averages gradients across workers during each backward pass so that every replica applies the same update.
  • Fault Tolerance: If a worker fails, Ray Train can restart the training workers and resume from the latest saved checkpoint, provided the training loop saves checkpoints and failure handling is configured.
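The resume-from-checkpoint idea behind fault tolerance can be sketched without any framework. The example below is a framework-agnostic toy, not Ray Train's checkpoint API: the function train, the JSON file format, and the fail_at parameter are all assumptions made for illustration.

```python
import json
import os
import tempfile

def train(total_epochs, ckpt_path, fail_at=None):
    """Toy loop that saves its state every epoch and resumes if a checkpoint exists."""
    state = {"epoch": 0, "weight": 0.0}
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)           # resume from the last saved epoch
    for epoch in range(state["epoch"], total_epochs):
        if fail_at is not None and epoch == fail_at:
            raise RuntimeError("simulated worker failure")
        state["weight"] += 1.0             # stand-in for one epoch of training
        state["epoch"] = epoch + 1
        with open(ckpt_path, "w") as f:
            json.dump(state, f)            # checkpoint after every epoch
    return state

ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
try:
    train(5, ckpt_path, fail_at=3)         # crashes after finishing epochs 0..2
except RuntimeError:
    pass
final_state = train(5, ckpt_path)          # the restart picks up at epoch 3
print(final_state)  # {'epoch': 5, 'weight': 5.0}
```

Only the two epochs after the failure are repeated work; without the checkpoint the restart would begin again from epoch 0.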

5.2 Parallel Training with Ray Train

Ray Train's data-parallel strategy distributes the dataset across workers, each processing its own subset in parallel. Each worker computes gradients on its mini-batch, and the gradients are averaged across all workers on every iteration, so model parameters stay consistent across workers. Section 6 applies this to training a CNN on MNIST.


6. Distributed CNN Training on MNIST with Ray Train

6.1 Installation and Setup

Before we dive into the code, make sure to install Ray and its relevant components.

pip install "ray[train]" torch torchvision

6.2 Ray Train Configuration for MNIST

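Here, we'll use Ray Train to parallelize the training of a CNN across 4 workers, each processing a subset of the MNIST dataset. The same training loop runs on CPUs or GPUs; the use_gpu flag in ScalingConfig decides which. Note that with use_gpu=True each worker reserves one GPU by default, so 4 workers need 4 GPUs in the cluster. On a machine with fewer GPUs, reduce num_workers or set use_gpu=False.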

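import ray
import ray.train.torch  # provides prepare_model and prepare_data_loader
from ray.train import ScalingConfig  # current location; older releases used ray.air.config
from ray.train.torch import TorchTrainer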
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import torch.optim as optim

# Initialize Ray
ray.init()

# CNN Model Definition
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

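# Training loop; Ray Train runs this function once on every worker
def train_loop_per_worker(config):
    model = CNN()
    # Moves the model to this worker's device and wraps it in DistributedDataParallel
    model = ray.train.torch.prepare_model(model)
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])

    # batch_size is per worker; prepare_data_loader adds a DistributedSampler
    # and moves each batch to this worker's device
    train_loader = DataLoader(config["train_dataset"], batch_size=config["batch_size"], shuffle=True)
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    for epoch in range(config["epochs"]):
        model.train()
        # Reshuffle the per-worker shards each epoch when running distributed
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = nn.functional.cross_entropy(output, target)
            loss.backward()  # gradients are averaged across workers here
            optimizer.step()
        # Report the last batch loss of the epoch back to the driver
        ray.train.report({"epoch": epoch, "loss": loss.item()})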

# Define the MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./mnist_data', train=True, download=True, transform=transform)

# Ray Train Configuration
scaling_config = ScalingConfig(num_workers=4, use_gpu=torch.cuda.is_available())

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=scaling_config,
    train_loop_config={"train_dataset": train_dataset, "batch_size": 64, "lr": 0.01, "epochs": 5}
)

# Start Training
result = trainer.fit()

# Shut down Ray
ray.shutdown()

6.3 Main Components

1. Model Preparation: prepare_model moves the model to the worker's device and wraps it in PyTorch's DistributedDataParallel, which handles gradient synchronization across workers.

2. DataLoader Preparation: prepare_data_loader adds a distributed sampler so that each worker iterates over its own subset of the dataset, and moves each batch to the worker's device.

3. Gradient Synchronization: Gradients are averaged across all workers during each backward pass, so every worker applies the same update to its copy of the model parameters.

4. Scaling Configuration: ScalingConfig specifies the number of workers (e.g., 4) and whether they use GPUs. Ray Train launches the workers and sets up the communication between them.
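How item 2 gives each worker its own subset can be sketched in a few lines. This is a simplified illustration of rank-based sharding, not the code Ray Train or PyTorch actually runs; the function name shard_indices and the tiny dataset size are assumptions.

```python
def shard_indices(dataset_size: int, num_workers: int, rank: int) -> list:
    """Indices seen by worker `rank`: every num_workers-th sample, starting at rank."""
    return list(range(rank, dataset_size, num_workers))

dataset_size, num_workers = 10, 4
shards = [shard_indices(dataset_size, num_workers, r) for r in range(num_workers)]
for rank, indices in enumerate(shards):
    print(rank, indices)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

PyTorch's DistributedSampler follows the same strided idea but additionally shuffles the indices each epoch and pads (or drops) samples so that every worker receives the same number of them.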

6.4 Parallel Training Workflow

  • Data Parallelism: The dataset is sharded across the 4 workers, and each worker draws mini-batches from its own shard. Each worker computes gradients on its mini-batch, and the gradients are averaged across workers before the optimizer step so that all workers update the model consistently.
  • Worker Initialization: Ray starts the number of worker processes given by num_workers in the ScalingConfig, assigning each one a GPU when use_gpu is true and CPU resources otherwise. Each worker runs its own copy of the training loop.
  • Gradient Synchronization: The averaging happens on every batch, so the model weights on all workers remain identical throughout training.
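One practical consequence of this workflow is that the effective batch size grows with the number of workers. A small sketch of the arithmetic for the configuration above, assuming the batch_size of 64 applies per worker and that the 60,000 MNIST training images are split evenly:

```python
import math

dataset_size = 60_000     # MNIST training set
num_workers = 4           # ScalingConfig(num_workers=4)
per_worker_batch = 64     # batch_size passed to each worker's DataLoader

global_batch = num_workers * per_worker_batch
samples_per_worker = math.ceil(dataset_size / num_workers)
steps_per_epoch = math.ceil(samples_per_worker / per_worker_batch)

print(global_batch, samples_per_worker, steps_per_epoch)  # 256 15000 235
```

Under these assumptions each optimizer step consumes 256 samples in total, so an epoch takes 235 steps instead of the 938 that a single worker with batch size 64 would need. Because the global batch is larger, the learning rate may need retuning.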


7. Conclusion

The evolution from vanilla neural networks to convolutional neural networks (CNNs) and eventually to distributed training frameworks such as Ray Train reflects the growing need to scale deep learning models efficiently. Vanilla networks were limited in their ability to process high-dimensional inputs like images due to their lack of spatial awareness and high parameter count. CNNs addressed these issues with convolutional layers, which share weights across image positions and capture spatial hierarchies. However, as data and models continue to grow in complexity, even CNNs require distributed training solutions to manage training times and computational resources effectively.

Ray Train is a robust and flexible framework that simplifies distributed training by abstracting the complexities of parallelism, gradient synchronization, and data distribution. With minimal code changes, models like CNNs can be scaled to utilize multiple GPUs or even clusters, making it easier than ever to train large models on massive datasets.

By using Ray Train, we can unlock the full potential of modern hardware to accelerate the development of state-of-the-art deep learning models, enabling faster iterations, more efficient resource usage, and ultimately, better performance in real-world applications.




More articles by Vijay Raghavan Ph.D., M.B.A.,
