
Commit c3ae9f0

Add files via upload

1 parent 11b02c6 commit c3ae9f0

File tree: how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel

1 file changed: 238 additions, 0 deletions
# Copyright (c) 2017 Facebook, Inc. All rights reserved.
# BSD 3-Clause License
#
# Script adapted from:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# ==============================================================================

# imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import argparse

# define network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        # feature-map sizes for 32x32 CIFAR-10 inputs (3x3 kernels, stride 1,
        # no padding; each pool halves the spatial size):
        # conv1: 32->30, conv2: 30->28, pool -> 14, conv3: 14->12, pool -> 6,
        # hence fc1 takes 128 * 6 * 6 input features
        self.fc1 = nn.Linear(128 * 6 * 6, 120)
        self.dropout = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 6 * 6)
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train(train_loader, model, criterion, optimizer, epoch, device, print_freq, rank):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % print_freq == 0:  # print every print_freq mini-batches
            print(
                "Rank %d: [%d, %5d] loss: %.3f"
                % (rank, epoch + 1, i + 1, running_loss / print_freq)
            )
            running_loss = 0.0

def evaluate(test_loader, model, device):
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    model.eval()

    correct = 0
    total = 0
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # accumulate per-class statistics over the actual batch size, so
            # a final, smaller batch cannot index out of range
            c = predicted == labels
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    # print total test set accuracy
    print(
        "Accuracy of the network on the 10000 test images: %d %%"
        % (100 * correct / total)
    )

    # print test accuracy for each of the classes
    for i in range(10):
        print(
            "Accuracy of %5s : %2d %%"
            % (classes[i], 100 * class_correct[i] / class_total[i])
        )

def main(args):
    # get PyTorch environment variables (set for each worker by the
    # distributed launch utility, e.g. torchrun)
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    distributed = world_size > 1

    # set device
    if distributed:
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # initialize distributed process group using default env:// method
    if distributed:
        torch.distributed.init_process_group(backend="nccl")

    # define train and test dataset DataLoaders
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    train_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=True, download=False, transform=transform
    )

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        sampler=train_sampler,
    )

    test_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=False, download=False, transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers
    )

    model = Net().to(device)

    # wrap model with DDP
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=args.learning_rate, momentum=args.momentum
    )

    # train the model
    for epoch in range(args.epochs):
        print("Rank %d: Starting epoch %d" % (rank, epoch))
        if distributed:
            # reshuffle the sampler so each epoch partitions data differently
            train_sampler.set_epoch(epoch)
        model.train()
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            device,
            args.print_freq,
            rank,
        )

    print("Rank %d: Finished Training" % (rank))

    if not distributed or rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        model_path = os.path.join(args.output_dir, "cifar_net.pt")
        # note: when distributed, model is the DDP wrapper, so the saved
        # parameter names carry a "module." prefix
        torch.save(model.state_dict(), model_path)

        # evaluate on full test dataset
        evaluate(test_loader, model, device)

if __name__ == "__main__":
    # setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir", type=str, help="directory containing CIFAR-10 dataset"
    )
    parser.add_argument("--epochs", default=10, type=int, help="number of epochs")
    parser.add_argument(
        "--batch-size",
        default=16,
        type=int,
        help="mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--workers",
        default=2,
        type=int,
        help="number of data loading workers for each gpu/process",
    )
    parser.add_argument(
        "--learning-rate", default=0.001, type=float, help="learning rate"
    )
    parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
    parser.add_argument(
        "--output-dir", default="outputs", type=str, help="directory to save model to"
    )
    parser.add_argument(
        "--print-freq",
        default=200,
        type=int,
        help="frequency of printing training statistics",
    )
    args = parser.parse_args()

    main(args)
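
The script reads WORLD_SIZE, RANK, and LOCAL_RANK from the environment, so it is normally started through a launcher that sets them, such as torchrun. Below is a minimal launch sketch, assuming the file is saved as train.py (the diff truncates the file name) and that CIFAR-10 has already been downloaded to ./data, since the script passes download=False:

import os
import subprocess

# Single-process smoke test: WORLD_SIZE=1 makes the script skip process-group
# initialization entirely, so this also runs on a CPU-only machine.
env = dict(os.environ, WORLD_SIZE="1", RANK="0", LOCAL_RANK="0")
subprocess.run(
    ["python", "train.py", "--data-dir", "./data", "--epochs", "1"],
    env=env,
    check=True,
)

# Multi-GPU run: torchrun sets WORLD_SIZE/RANK/LOCAL_RANK for every worker.
subprocess.run(
    ["torchrun", "--nproc_per_node=2", "train.py", "--data-dir", "./data"],
    check=True,
)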

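One caveat when reusing the saved checkpoint: in the distributed case rank 0 saves the state dict of the DDP-wrapped model, so every parameter name carries a "module." prefix. A loading sketch (paths follow the script's defaults; Net is the class defined above):

import torch

state = torch.load("outputs/cifar_net.pt", map_location="cpu")
# strip the prefix added by DistributedDataParallel; a checkpoint from a
# non-distributed run loads unchanged
state = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state.items()
}
net = Net()
net.load_state_dict(state)
net.eval()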