@@ -4,9 +4,9 @@ import ReactDOM from 'react-dom'
4
4
import * as tf from '@tensorflow/tfjs'
5
5
import './styles.css'
6
6
7
- const MODEL_URL = './web_model/tensorflowjs_model.pb '
8
- const WEIGHTS_URL = '. /web_model/weights_manifest.json '
9
- const LABELS = [ 'Pepsi' , 'Mountain Dew' , 'Coke' ]
7
+ const LABELS_URL = process . env . PUBLIC_URL + '/labels.json '
8
+ const MODEL_URL = process . env . PUBLIC_URL + ' /web_model/tensorflowjs_model.pb '
9
+ const WEIGHTS_URL = process . env . PUBLIC_URL + '/web_model/weights_manifest.json'
10
10
11
11
const TFWrapper = model => {
12
12
const calculateMaxScores = ( scores , numBoxes , numClasses ) => {
@@ -141,30 +141,31 @@ class App extends React.Component {
141
141
}
142
142
} )
143
143
} )
144
-
145
144
const modelPromise = tf . loadFrozenModel ( MODEL_URL , WEIGHTS_URL )
146
- Promise . all ( [ modelPromise , webCamPromise ] )
145
+ const labelsPromise = fetch ( LABELS_URL ) . then ( data => data . json ( ) )
146
+ Promise . all ( [ modelPromise , labelsPromise , webCamPromise ] )
147
147
. then ( values => {
148
- this . detectFrame ( this . videoRef . current , values [ 0 ] )
148
+ const [ model , labels ] = values
149
+ this . detectFrame ( this . videoRef . current , model , labels )
149
150
} )
150
151
. catch ( error => {
151
152
console . error ( error )
152
153
} )
153
154
}
154
155
}
155
156
156
- detectFrame = ( video , model ) => {
157
+ detectFrame = ( video , model , labels ) => {
157
158
TFWrapper ( model )
158
159
. detect ( video )
159
160
. then ( predictions => {
160
- this . renderPredictions ( predictions )
161
+ this . renderPredictions ( predictions , labels )
161
162
requestAnimationFrame ( ( ) => {
162
- this . detectFrame ( video , model )
163
+ this . detectFrame ( video , model , labels )
163
164
} )
164
165
} )
165
166
}
166
167
167
- renderPredictions = predictions => {
168
+ renderPredictions = ( predictions , labels ) => {
168
169
const ctx = this . canvasRef . current . getContext ( '2d' )
169
170
ctx . clearRect ( 0 , 0 , ctx . canvas . width , ctx . canvas . height )
170
171
// Font options.
@@ -176,7 +177,7 @@ class App extends React.Component {
176
177
const y = prediction . bbox [ 1 ]
177
178
const width = prediction . bbox [ 2 ]
178
179
const height = prediction . bbox [ 3 ]
179
- const label = LABELS [ parseInt ( prediction . class ) ]
180
+ const label = labels [ parseInt ( prediction . class ) ]
180
181
// Draw the bounding box.
181
182
ctx . strokeStyle = '#00FFFF'
182
183
ctx . lineWidth = 4
@@ -191,7 +192,7 @@ class App extends React.Component {
191
192
predictions . forEach ( prediction => {
192
193
const x = prediction . bbox [ 0 ]
193
194
const y = prediction . bbox [ 1 ]
194
- const label = LABELS [ parseInt ( prediction . class ) ]
195
+ const label = labels [ parseInt ( prediction . class ) ]
195
196
// Draw the text last to ensure it's on top.
196
197
ctx . fillStyle = '#000000'
197
198
ctx . fillText ( label , x , y )
0 commit comments