Skip to content

Commit 4202ec1

Browse files
committed
update how labels are handled
1 parent 43aed50 commit 4202ec1

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/public/web_model/*
2+
/public/labels.json
23

34
# Logs
45
logs

src/index.js

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import ReactDOM from 'react-dom'
44
import * as tf from '@tensorflow/tfjs'
55
import './styles.css'
66

7-
const MODEL_URL = './web_model/tensorflowjs_model.pb'
8-
const WEIGHTS_URL = './web_model/weights_manifest.json'
9-
const LABELS = ['Pepsi', 'Mountain Dew', 'Coke']
7+
const LABELS_URL = process.env.PUBLIC_URL + '/labels.json'
8+
const MODEL_URL = process.env.PUBLIC_URL + '/web_model/tensorflowjs_model.pb'
9+
const WEIGHTS_URL = process.env.PUBLIC_URL + '/web_model/weights_manifest.json'
1010

1111
const TFWrapper = model => {
1212
const calculateMaxScores = (scores, numBoxes, numClasses) => {
@@ -141,30 +141,31 @@ class App extends React.Component {
141141
}
142142
})
143143
})
144-
145144
const modelPromise = tf.loadFrozenModel(MODEL_URL, WEIGHTS_URL)
146-
Promise.all([modelPromise, webCamPromise])
145+
const labelsPromise = fetch(LABELS_URL).then(data => data.json())
146+
Promise.all([modelPromise, labelsPromise, webCamPromise])
147147
.then(values => {
148-
this.detectFrame(this.videoRef.current, values[0])
148+
const [model, labels] = values
149+
this.detectFrame(this.videoRef.current, model, labels)
149150
})
150151
.catch(error => {
151152
console.error(error)
152153
})
153154
}
154155
}
155156

156-
detectFrame = (video, model) => {
157+
detectFrame = (video, model, labels) => {
157158
TFWrapper(model)
158159
.detect(video)
159160
.then(predictions => {
160-
this.renderPredictions(predictions)
161+
this.renderPredictions(predictions, labels)
161162
requestAnimationFrame(() => {
162-
this.detectFrame(video, model)
163+
this.detectFrame(video, model, labels)
163164
})
164165
})
165166
}
166167

167-
renderPredictions = predictions => {
168+
renderPredictions = (predictions, labels) => {
168169
const ctx = this.canvasRef.current.getContext('2d')
169170
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height)
170171
// Font options.
@@ -176,7 +177,7 @@ class App extends React.Component {
176177
const y = prediction.bbox[1]
177178
const width = prediction.bbox[2]
178179
const height = prediction.bbox[3]
179-
const label = LABELS[parseInt(prediction.class)]
180+
const label = labels[parseInt(prediction.class)]
180181
// Draw the bounding box.
181182
ctx.strokeStyle = '#00FFFF'
182183
ctx.lineWidth = 4
@@ -191,7 +192,7 @@ class App extends React.Component {
191192
predictions.forEach(prediction => {
192193
const x = prediction.bbox[0]
193194
const y = prediction.bbox[1]
194-
const label = LABELS[parseInt(prediction.class)]
195+
const label = labels[parseInt(prediction.class)]
195196
// Draw the text last to ensure it's on top.
196197
ctx.fillStyle = '#000000'
197198
ctx.fillText(label, x, y)

0 commit comments

Comments
 (0)