Commit 7af539a

Update Movie Recommendations Notebook
1 parent 28d6822 commit 7af539a

1 file changed

+104
-21
lines changed

Experimental/Movielens Recommendation.ipynb

Lines changed: 104 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,45 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Recommendation Model with Approximate Item Matching\n",
8+
"\n",
9+
"This notebook shows how to train a simple Neural Collaborative Filtering model for recommending movies to users. It also shows how the learnt movie embeddings are stored in an approximate similarity matching index, built with Spotify's [Annoy library](https://github.com/spotify/annoy), so that we can quickly find and recommend the most relevant movies to a given customer. Finally, it shows how to use this index to search for similar movies.\n",
10+
"\n",
11+
"In essence, this tutorial works as follows:\n",
12+
"1. Download the MovieLens dataset.\n",
13+
"2. Train a simple Neural Collaborative Filtering model using a TensorFlow custom estimator.\n",
14+
"3. Extract the learnt movie embeddings.\n",
15+
"4. Build an approximate similarity matching index for the movie embeddings.\n",
16+
"5. Export the trained model, which receives a user Id and outputs the user embedding.\n",
17+
"\n",
18+
"Recommendations are served as follows:\n",
19+
"1. Receive a user Id.\n",
20+
"2. Get the user embedding from the exported model.\n",
21+
"3. Find the movie embeddings most similar to the user embedding in the index.\n",
22+
"4. Return the movie Ids of these embeddings as recommendations.\n",
23+
"\n",
24+
"<a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/tf-estimator-tutorials/blob/master/Experimental/Movielens%20Recommendation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
25+
]
26+
},
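The serving steps listed in the introduction cell can be sketched in plain Python. This is a minimal illustration, not the notebook's implementation: it assumes toy embeddings held in dictionaries (the names `movie_embeddings`, `user_embeddings`, and `recommend` are hypothetical), and it replaces the Annoy approximate index with an exact cosine-similarity scan so that it runs without extra dependencies.

```python
import math

# Hypothetical toy data. In the notebook, movie embeddings are extracted from
# the trained model and user embeddings come from the exported model.
movie_embeddings = {
    10: [0.9, 0.1, 0.0],
    20: [0.0, 1.0, 0.2],
    30: [0.8, 0.3, 0.1],
    40: [-0.5, 0.2, 0.9],
}
user_embeddings = {7: [1.0, 0.2, 0.0]}


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def recommend(user_id, k=2):
    """Return the Ids of the k movies whose embeddings best match the user."""
    # Steps 1-2: receive a user Id and look up the user embedding.
    user_vector = user_embeddings[user_id]
    # Step 3: score every movie (exact scan; Annoy would approximate this).
    scored = [
        (cosine_similarity(user_vector, vector), movie_id)
        for movie_id, vector in movie_embeddings.items()
    ]
    scored.sort(reverse=True)
    # Step 4: return the movie Ids of the closest embeddings.
    return [movie_id for _, movie_id in scored[:k]]


print(recommend(7, k=2))
```

An exact scan costs time linear in the number of movies per request, which is the cost the approximate index in the notebook is meant to avoid on larger catalogues.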
27+
{
28+
"cell_type": "markdown",
29+
"metadata": {},
30+
"source": [
31+
"## Setup"
32+
]
33+
},
34+
{
35+
"cell_type": "code",
36+
"execution_count": null,
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"!pip install annoy"
41+
]
42+
},
343
{
444
"cell_type": "code",
545
"execution_count": 1,
@@ -32,7 +72,7 @@
3272
"cell_type": "markdown",
3373
"metadata": {},
3474
"source": [
35-
"## Download Data"
75+
"## 1. Download Data"
3676
]
3777
},
3878
{
@@ -373,7 +413,14 @@
373413
"cell_type": "markdown",
374414
"metadata": {},
375415
"source": [
376-
"## Define Metadata"
416+
"## 2. Build the TensorFlow Model"
417+
]
418+
},
419+
{
420+
"cell_type": "markdown",
421+
"metadata": {},
422+
"source": [
423+
"### 2.1 Define Metadata"
377424
]
378425
},
379426
{
@@ -393,12 +440,12 @@
393440
"cell_type": "markdown",
394441
"metadata": {},
395442
"source": [
396-
"## Define Data Input Function"
443+
"### 2.2 Define Data Input Function"
397444
]
398445
},
399446
{
400447
"cell_type": "code",
401-
"execution_count": 18,
448+
"execution_count": null,
402449
"metadata": {},
403450
"outputs": [],
404451
"source": [
@@ -418,10 +465,7 @@
418465
" num_epochs=num_epochs,\n",
419466
" shuffle= (mode==tf.estimator.ModeKeys.TRAIN)\n",
420467
" )\n",
421-
" \n",
422-
" iterator = dataset.make_one_shot_iterator()\n",
423-
" features, target = iterator.get_next()\n",
424-
" return features, target\n",
468+
" return dataset\n",
425469
" \n",
426470
" return _input_fn"
427471
]
@@ -430,7 +474,7 @@
430474
"cell_type": "markdown",
431475
"metadata": {},
432476
"source": [
433-
"## Create Feature Columns"
477+
"### 2.3 Create Feature Columns"
434478
]
435479
},
436480
{
@@ -466,7 +510,7 @@
466510
"cell_type": "markdown",
467511
"metadata": {},
468512
"source": [
469-
"## Define Model Function"
513+
"### 2.4 Define Model Function"
470514
]
471515
},
472516
{
@@ -506,14 +550,14 @@
506550
" mode=mode,\n",
507551
" loss=loss,\n",
508552
" train_op=train_op\n",
509-
" )\n"
553+
" )"
510554
]
511555
},
512556
{
513557
"cell_type": "markdown",
514558
"metadata": {},
515559
"source": [
516-
"## Create Estimator"
560+
"### 2.5 Create Estimator"
517561
]
518562
},
519563
{
@@ -537,7 +581,7 @@
537581
"cell_type": "markdown",
538582
"metadata": {},
539583
"source": [
540-
"## Define Experiment"
584+
"### 2.6 Define Experiment"
541585
]
542586
},
543587
{
@@ -612,7 +656,7 @@
612656
"cell_type": "markdown",
613657
"metadata": {},
614658
"source": [
615-
"## Run Experiment with Parameters"
659+
"### 2.7 Run Experiment with Parameters"
616660
]
617661
},
618662
{
@@ -710,7 +754,7 @@
710754
"cell_type": "markdown",
711755
"metadata": {},
712756
"source": [
713-
"## Extract Movie Embeddings "
757+
"## 3. Extract Movie Embeddings "
714758
]
715759
},
716760
{
@@ -766,7 +810,7 @@
766810
"cell_type": "markdown",
767811
"metadata": {},
768812
"source": [
769-
"## Build Annoy Index"
813+
"## 4. Build Annoy Index"
770814
]
771815
},
772816
{
@@ -1145,7 +1189,7 @@
11451189
"cell_type": "markdown",
11461190
"metadata": {},
11471191
"source": [
1148-
"## Export the Model\n",
1192+
"## 5. Export the Model\n",
11491193
"The exported model needs to receive a userId and produce the embedding for that user."
11501194
]
11511195
},
@@ -1234,6 +1278,13 @@
12341278
"print(output)"
12351279
]
12361280
},
1281+
{
1282+
"cell_type": "markdown",
1283+
"metadata": {},
1284+
"source": [
1285+
"## Serve Movie Recommendations to a User"
1286+
]
1287+
},
12371288
{
12381289
"cell_type": "code",
12391290
"execution_count": 190,
@@ -1276,11 +1327,43 @@
12761327
]
12771328
},
12781329
{
1279-
"cell_type": "code",
1280-
"execution_count": null,
1330+
"cell_type": "markdown",
12811331
"metadata": {},
1282-
"outputs": [],
1283-
"source": []
1332+
"source": [
1333+
"## License"
1334+
]
1335+
},
1336+
{
1337+
"cell_type": "markdown",
1338+
"metadata": {},
1339+
"source": [
1340+
"---\n",
1341+
"\n",
1342+
"Author: Khalid Salama\n",
1343+
"\n",
1344+
"\n",
1345+
"---\n",
1346+
"***Disclaimer***: This is not an official Google product. This sample code is provided for educational purposes.\n",
1347+
"\n",
1348+
"---\n",
1349+
"\n",
1350+
"Copyright 2019 Google LLC\n",
1351+
"\n",
1352+
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
1353+
"you may not use this file except in compliance with the License.\n",
1354+
"You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.\n",
1355+
"\n",
1356+
"Unless required by applicable law or agreed to in writing, software\n",
1357+
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
1358+
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
1359+
"See the License for the specific language governing permissions and\n",
1360+
"limitations under the License.\n",
1361+
"\n",
1362+
"\n",
1363+
"---\n",
1364+
"\n",
1365+
"\n"
1366+
]
12841367
}
12851368
],
12861369
"metadata": {
