Skip to content

Commit 549ceee

Browse files
committed
exported one_hot_tensor
1 parent e769919 commit 549ceee

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

automlp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__all__ = "automlp parameters".split()
22

3-
from .automlp import ArrayDataset, RebatchingLoader, Trainer, AutoMLP, GridSearch, smart_target
3+
from .automlp import ArrayDataset, RebatchingLoader, Trainer, AutoMLP, GridSearch, smart_target, one_hot_tensor
44
from .parameters import Parameter, Constant, UniformParameter, QuantizedParameter
55
from .parameters import LogParameter, QuantizedLogParameter, ParameterSet, Exploration

automlp/automlp.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,13 @@ def as_tensor(images):
162162
raise ValueError("unknown dtype: {}".format(images.dtype))
163163

164164

165-
@deprecated
166-
def one_hot_tensor(classes, nclasses=None, value=1.0):
167-
classes = as_class_tensor(classes)
168-
if nclasses is None:
169-
nclasses = 1 + np.amax(classes)
165+
def one_hot_tensor(classes, nclasses, value=1.0):
166+
classes = classes.to(torch.int64)
167+
targets = torch.zeros(len(classes), nclasses)
168+
targets = targets.scatter(
169+
1, classes.unsqueeze(1).to(
170+
targets.device), value)
171+
return targets
170172

171173

172174
def getbatch_random(dataset, batch_size, convert=from_numpy):
@@ -253,7 +255,7 @@ def __init__(self, model, mode="crossentropy", device=None,
253255

254256
self.make_optimizer = make_optimizer
255257
self.device = device or guess_input_device(model)
256-
set_mode(self, mode)
258+
self.set_mode(mode)
257259

258260

259261
def set_mode(self, mode):

0 commit comments

Comments
 (0)