Skip to content
Merged
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 9, 2024
commit f5333b2b3a02485e328db3bbb339a995631148ac
33 changes: 16 additions & 17 deletions test/3x/torch/algorithms/fp8_quant/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,46 @@
import os
import sys
import torch
import time
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

import habana_frameworks.torch.core as htcore
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, 10)
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
out = x.view(-1,28*28)
out = x.view(-1, 28 * 28)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
out = F.log_softmax(out, dim=1)
return out


model = Net()
checkpoint = torch.load('mnist-epoch_20.pth')
checkpoint = torch.load("mnist-epoch_20.pth")
model.load_state_dict(checkpoint)

model = model.eval()

model = model.to("hpu")



model = torch.compile(model,backend="hpu_backend")
model = torch.compile(model, backend="hpu_backend")


transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

data_path = './data'
data_path = "./data"
test_dataset = datasets.MNIST(data_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)

Expand All @@ -57,4 +56,4 @@ def forward(self, x):
correct += output.argmax(1).eq(label).sum().item()

accuracy = correct / len(test_loader.dataset) * 100
print('Inference with torch.compile Completed. Accuracy: {:.2f}%'.format(accuracy))
print("Inference with torch.compile Completed. Accuracy: {:.2f}%".format(accuracy))