Skip to content

Commit 7002c29

Browse files
MetaHIN: Add embedding initialization script
1 parent ea4ecb7 commit 7002c29

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Import libraries
2+
import torch
3+
from torch.autograd import Variable
4+
5+
6+
class UserEmbedding(torch.nn.Module):
7+
"""
8+
Initialize user embedding class
9+
"""
10+
def __init__(self, config):
11+
"""
12+
Initialize the user class
13+
:param config: experiment configuration
14+
"""
15+
super(UserEmbedding, self).__init__()
16+
self.num_gender = config['num_gender'] # Number of genders
17+
self.num_age = config['num_age'] # Number of ages
18+
self.num_occupation = config['num_occupation'] # Number of occupations
19+
self.num_zipcode = config['num_zipcode'] # Number of zipcodes
20+
21+
self.embedding_dim = config['embedding_dim'] # Number of embedding dimensions
22+
23+
# Create gender embeddings
24+
self.embedding_gender = torch.nn.Embedding(
25+
num_embeddings=self.num_gender,
26+
embedding_dim=self.embedding_dim
27+
)
28+
29+
# Create age embeddings
30+
self.embedding_age = torch.nn.Embedding(
31+
num_embeddings=self.num_age,
32+
embedding_dim=self.embedding_dim
33+
)
34+
35+
# Create occupation embeddings
36+
self.embedding_occupation = torch.nn.Embedding(
37+
num_embeddings=self.num_occupation,
38+
embedding_dim=self.embedding_dim
39+
)
40+
41+
# Create zipcode area embeddings
42+
self.embedding_area = torch.nn.Embedding(
43+
num_embeddings=self.num_zipcode,
44+
embedding_dim=self.embedding_dim
45+
)
46+
47+
def forward(self, user_fea):
48+
"""
49+
Perform forward pass on user features
50+
:param user_fea: user features
51+
:return: one-dimensional embedding
52+
"""
53+
# Collect user features
54+
gender_idx = Variable(user_fea[:, 0], requires_grad=False)
55+
age_idx = Variable(user_fea[:, 1], requires_grad=False)
56+
occupation_idx = Variable(user_fea[:, 2], requires_grad=False)
57+
area_idx = Variable(user_fea[:, 3], requires_grad=False)
58+
59+
# Perform embedding processes based on the user input
60+
gender_emb = self.embedding_gender(gender_idx)
61+
age_emb = self.embedding_age(age_idx)
62+
occupation_emb = self.embedding_occupation(occupation_idx)
63+
area_emb = self.embedding_area(area_idx)
64+
65+
# Concatenate the embedded vectors
66+
return torch.cat((gender_emb, age_emb, occupation_emb, area_emb), 1) # (1, 4*32)
67+
68+
69+
class ItemEmbeddingML(torch.nn.Module):
70+
"""
71+
Initialize item embedding class
72+
"""
73+
def __init__(self, config):
74+
"""
75+
Initialize the item class
76+
:param config: experiment configuration
77+
"""
78+
super(ItemEmbeddingML, self).__init__()
79+
self.num_rate = config['num_rate'] # Number of rate levels
80+
self.num_genre = config['num_genre'] # Number of genres
81+
self.embedding_dim = config['embedding_dim'] # Number of embedding dimensions
82+
83+
# Create rate category embeddings
84+
self.embedding_rate = torch.nn.Embedding(
85+
num_embeddings=self.num_rate,
86+
embedding_dim=self.embedding_dim
87+
)
88+
89+
# Create genre embeddings
90+
self.embedding_genre = torch.nn.Linear(
91+
in_features=self.num_genre,
92+
out_features=self.embedding_dim,
93+
bias=False
94+
)
95+
96+
def forward(self, item_fea):
97+
"""
98+
Perform forward pass on item features
99+
:param item_fea: item features
100+
:return: one-dimensional embedding
101+
"""
102+
# Collect item features
103+
rate_idx = Variable(item_fea[:, 0], requires_grad=False)
104+
genre_idx = Variable(item_fea[:, 1:26], requires_grad=False)
105+
106+
# Perform embedding processes based on the item input
107+
rate_emb = self.embedding_rate(rate_idx) # (1,32)
108+
genre_emb = self.embedding_genre(genre_idx.float()) / torch.sum(genre_idx.float(), 1).view(-1, 1) # (1,32)
109+
110+
# Concatenate the embedded vectors
111+
return torch.cat((rate_emb, genre_emb), 1) # (1, 2*32)

0 commit comments

Comments
 (0)