PyTorch Quickstart




This guide describes how to scale out PyTorch programs using Orca in 4 simple steps (Steps 1 to 4 below), preceded by environment setup and followed by saving and loading the model.

Step 0: Prepare Environment

Conda is needed to prepare the Python environment for running this example. Please refer to the install guide for more details.

conda create -n py37 python=3.7  # "py37" is the conda environment name; you can use any name you like.
conda activate py37
pip install bigdl-orca
pip install torch==1.7.1 torchvision==0.8.2
pip install six cloudpickle
pip install jep==3.9.0
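After running the commands above, it can help to confirm that the packages are importable from the active environment. The following is a minimal sketch, not part of Orca: `missing_modules` is a hypothetical helper, and the module names are assumed to match the import names used later in this guide (`bigdl`, `torch`, `torchvision`) plus the extra pip packages.

```python
import importlib.util

# Import names assumed to correspond to the pip packages installed in Step 0.
REQUIRED = ["bigdl", "torch", "torchvision", "six", "cloudpickle", "jep"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found in this environment."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for broken or partially initialised modules.
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Missing modules:", missing_modules(REQUIRED))
```

An empty list means every module was found; any name that is printed should be installed again inside the activated conda environment.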

Step 1: Init Orca Context

from bigdl.orca import init_orca_context, stop_orca_context

cluster_mode = "local"
if cluster_mode == "local":  # For local machine
    init_orca_context(cores=4, memory="10g")
elif cluster_mode == "k8s":  # For K8s cluster
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
    init_orca_context(
        cluster_mode="yarn", cores=2, num_nodes=2, memory="10g",
        driver_memory="10g", driver_cores=1,
        conf={"spark.rpc.message.maxSize": "1024",
              "spark.task.maxFailures": "1",
              "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"})

This is the only place where you need to specify local or distributed mode. View Orca Context for more details.

Note: You should export HADOOP_CONF_DIR=/path/to/hadoop/conf/dir when running on a Hadoop YARN cluster. View Hadoop User Guide for more details.
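The requirement in the note can be checked before the Orca context is created, so that a missing variable fails fast with a clear message. This is a minimal sketch under stated assumptions: `require_hadoop_conf` is a hypothetical helper and not part of Orca, and the path used in the example call is a placeholder.

```python
import os


def require_hadoop_conf(cluster_mode, env=None):
    """Return HADOOP_CONF_DIR for YARN mode, or None for any other mode.

    Hypothetical helper (not an Orca API). Raises RuntimeError when
    cluster_mode is "yarn" and HADOOP_CONF_DIR is unset or empty.
    """
    env = os.environ if env is None else env
    if cluster_mode != "yarn":
        return None
    conf_dir = env.get("HADOOP_CONF_DIR")
    if not conf_dir:
        raise RuntimeError(
            "HADOOP_CONF_DIR must be exported before using cluster_mode='yarn'")
    return conf_dir


# Example with an explicit environment mapping (placeholder path).
print(require_hadoop_conf("yarn", {"HADOOP_CONF_DIR": "/path/to/hadoop/conf/dir"}))
```

Calling the helper just before `init_orca_context` keeps the check next to the single place where the cluster mode is chosen.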

Step 2: Define the Model

You may define your model, loss and optimizer in the same way as in any standard (single node) PyTorch program.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = LeNet()
model.train()
criterion = nn.NLLLoss()
adam = torch.optim.Adam(model.parameters(), lr=0.001)
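The input size of `fc1` (`4*4*50`) follows from the layer parameters in the model above. The sketch below retraces that arithmetic in plain Python, assuming 28x28 MNIST inputs and no padding; `conv_out` and `pool_out` are illustrative helpers, not PyTorch functions.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Spatial output size of a max-pooling layer along one dimension."""
    return (size - kernel) // stride + 1


size = 28                        # MNIST images are 28x28
size = conv_out(size, kernel=5)  # conv1: 28 -> 24
size = pool_out(size, 2, 2)      # max_pool2d: 24 -> 12
size = conv_out(size, kernel=5)  # conv2: 12 -> 8
size = pool_out(size, 2, 2)      # max_pool2d: 8 -> 4

flat_features = size * size * 50  # conv2 produces 50 channels
print(flat_features)              # 800, i.e. 4*4*50
```

If the input resolution or a kernel size changes, the same calculation gives the new value to use in `nn.Linear` and in `x.view`.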

Step 3: Define Train Dataset

You can define the dataset using a standard PyTorch DataLoader.

import torch
from torchvision import datasets, transforms

torch.manual_seed(0)
data_dir = './'  # where MNIST is downloaded; avoids shadowing the built-in dir()

batch_size = 64
test_batch_size = 64
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=False)
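The two transforms in the loaders above are applied in sequence: `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then subtracts the mean and divides by the standard deviation. The sketch below reproduces that arithmetic for a single pixel in plain Python; `normalize_pixel` is an illustrative helper, and treating 0.1307 and 0.3081 as the MNIST mean and standard deviation is an assumption based on common usage, since the guide does not say where they come from.

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize


def normalize_pixel(raw, mean=MEAN, std=STD):
    """Map one 8-bit pixel value (0..255) to its normalized value."""
    scaled = raw / 255.0          # effect of ToTensor on uint8 image data
    return (scaled - mean) / std  # effect of Normalize on one channel


for raw in (0, 128, 255):
    print(raw, "->", normalize_pixel(raw))
```

Black pixels map to a negative value and white pixels to a positive one, so the network sees inputs centred near zero.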

Alternatively, you can use a Data Creator Function or Orca XShards as the input data, especially when the data size is very large.

Step 4: Fit with Orca Estimator

First, create an Estimator:

from bigdl.orca.learn.pytorch import Estimator 
from bigdl.orca.learn.metrics import Accuracy

est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion, metrics=[Accuracy()])

Next, fit and evaluate using the Estimator:

from bigdl.orca.learn.trigger import EveryEpoch 

est.fit(data=train_loader, epochs=10, validation_data=test_loader,
        checkpoint_trigger=EveryEpoch())

result = est.evaluate(data=test_loader)
for r in result:
    print(r, ":", result[r])

Step 5: Save and Load the Model

Save the Estimator states (including model and optimizer) to the provided model path.

est.save("mnist_model")

Load the Estimator states (the model and, possibly, the optimizer) from the provided model path.

est.load("mnist_model")

Note: You should call stop_orca_context() when your application finishes.
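One way to make sure the context is stopped even when training raises an exception is to wrap the application in a context manager. This is a minimal sketch, not an Orca feature: `orca_session` is a hypothetical wrapper, and `fake_init` and `fake_stop` are stand-ins so the example runs without Orca installed. In a real program you would pass `init_orca_context` and `stop_orca_context` instead.

```python
from contextlib import contextmanager


@contextmanager
def orca_session(init_fn, stop_fn, **init_kwargs):
    """Call init_fn(**init_kwargs) on entry and stop_fn() on exit, even on error."""
    init_fn(**init_kwargs)
    try:
        yield
    finally:
        stop_fn()


# Stand-ins that only record calls (hypothetical, for illustration).
calls = []


def fake_init(**kwargs):
    calls.append(("init", kwargs))


def fake_stop():
    calls.append(("stop", {}))


try:
    with orca_session(fake_init, fake_stop, cores=4, memory="10g"):
        raise ValueError("simulated training failure")
except ValueError:
    pass

print([name for name, _ in calls])  # the stop call is recorded despite the error
```

A plain `try`/`finally` around the training code achieves the same result; the wrapper only keeps the pairing of the two calls in one place.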