[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 10, 2024
commit 651a1267a438fb3cd3358a436586c464bdfa4024
32 changes: 16 additions & 16 deletions in test/3x/torch/algorithms/fp8_quant/test_basic.py
@@ -1,29 +1,31 @@
 import os
 import sys
-import torch
 import time
+
 import habana_frameworks.torch.core as htcore
-
-from torch.utils.data import DataLoader
-from torchvision import transforms, datasets
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
 
+
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
-        self.fc1 = nn.Linear(784, 256)
-        self.fc2 = nn.Linear(256, 64)
-        self.fc3 = nn.Linear(64, 10)
+        self.fc1 = nn.Linear(784, 256)
+        self.fc2 = nn.Linear(256, 64)
+        self.fc3 = nn.Linear(64, 10)
 
     def forward(self, x):
-        out = x.view(-1,28*28)
+        out = x.view(-1, 28 * 28)
         out = F.relu(self.fc1(out))
         out = F.relu(self.fc2(out))
         out = self.fc3(out)
         out = F.log_softmax(out, dim=1)
         return out
 
+
 def test_hpu():
     model = Net()
     model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth"
@@ -36,14 +38,12 @@ def test_hpu():
 
     model = model.to("hpu")
 
-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))])
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
 
-    data_path = './data'
-    test_kwargs = {'batch_size': 32}
+    data_path = "./data"
+    test_kwargs = {"batch_size": 32}
     dataset1 = datasets.MNIST(data_path, train=False, download=True, transform=transform)
-    test_loader = torch.utils.data.DataLoader(dataset1,**test_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset1, **test_kwargs)
 
     correct = 0
     for batch_idx, (data, label) in enumerate(test_loader):
@@ -52,5 +52,5 @@ def test_hpu():
         htcore.mark_step()
         correct += output.max(1)[1].eq(label).sum()
 
-    accuracy = 100. * correct / (len(test_loader) * 32)
-    assert accuracy > 90
+    accuracy = 100.0 * correct / (len(test_loader) * 32)
+    assert accuracy > 90