@@ -187,9 +187,9 @@ void BasicMatrix ()
187187 } ;
188188 }
189189
190- int ArgMax ( byte [ , ] array , int idx )
190+ int ArgMax ( float [ , ] array , int idx )
191191 {
192- int max = - 1 ;
192+ float max = - 1 ;
193193 int maxIdx = - 1 ;
194194 var l = array . GetLength ( 1 ) ;
195195 for ( int i = 0 ; i < l ; i ++ )
@@ -198,7 +198,17 @@ int ArgMax (byte [,] array, int idx)
198198 max = array [ idx , i ] ;
199199 }
200200 return maxIdx ;
201- }
201+ }
202+
203+ public float [ ] Extract ( float [ , ] array , int index )
204+ {
205+ var n = array . GetLength ( 1 ) ;
206+ var ret = new float [ n ] ;
207+
208+ for ( int i = 0 ; i < n ; i ++ )
209+ ret [ i ] = array [ index , i ] ;
210+ return ret ;
211+ }
202212
203213 // This sample has a bug, I suspect the data loaded is incorrect, because the returned
204214 // values in distance is wrong, and so is the prediction computed from it.
@@ -211,25 +221,21 @@ void NearestNeighbor ()
211221 // 5000 for training
212222 const int trainCount = 5000 ;
213223 const int testCount = 200 ;
214- var Xtr = mnist . GetBatchReader ( mnist . TrainImages ) . ReadAsTensor ( trainCount ) ;
215- var Ytr = mnist . OneHotTrainLabels ;
216- var Xte = mnist . GetBatchReader ( mnist . TestImages ) . Read ( testCount ) ;
217- var Yte = mnist . OneHotTestLabels ;
218-
219-
224+ ( var trainingImages , var trainingLabels ) = mnist . GetTrainReader ( ) . NextBatch ( trainCount ) ;
225+ ( var testImages , var testLabels ) = mnist . GetTestReader ( ) . NextBatch ( testCount ) ;
220226
221227 Console . WriteLine ( "Nearest neighbor on Mnist images" ) ;
222228 using ( var g = new TFGraph ( ) ) {
223229 var s = new TFSession ( g ) ;
224230
225231
226- TFOutput xtr = g . Placeholder ( TFDataType . Float , new TFShape ( - 1 , 784 ) ) ;
232+ TFOutput trainingInput = g . Placeholder ( TFDataType . Float , new TFShape ( - 1 , 784 ) ) ;
227233
228234 TFOutput xte = g . Placeholder ( TFDataType . Float , new TFShape ( 784 ) ) ;
229235
230236 // Nearest Neighbor calculation using L1 Distance
231237 // Calculate L1 Distance
232- TFOutput distance = g . ReduceSum ( g . Abs ( g . Add ( xtr , g . Neg ( xte ) ) ) , axis : g . Const ( 1 ) ) ;
238+ TFOutput distance = g . ReduceSum ( g . Abs ( g . Add ( trainingInput , g . Neg ( xte ) ) ) , axis : g . Const ( 1 ) ) ;
233239
234240 // Prediction: Get min distance index (Nearest neighbor)
235241 TFOutput pred = g . ArgMin ( distance , g . Const ( 0 ) ) ;
@@ -241,15 +247,15 @@ void NearestNeighbor ()
241247
242248 // Get nearest neighbor
243249
244- var result = runner . Fetch ( pred ) . Fetch ( distance ) . AddInput ( xtr , Xtr ) . AddInput ( xte , Xte [ i ] . DataFloat ) . Run ( ) ;
250+ var result = runner . Fetch ( pred ) . Fetch ( distance ) . AddInput ( trainingInput , trainingImages ) . AddInput ( xte , Extract ( testImages , i ) ) . Run ( ) ;
245251 var r = result [ 0 ] . GetValue ( ) ;
246252 var tr = result [ 1 ] . GetValue ( ) ;
247253 var nn_index = ( int ) ( long ) result [ 0 ] . GetValue ( ) ;
248254
249255 // Get nearest neighbor class label and compare it to its true label
250- Console . WriteLine ( $ "Test { i } : Prediction: { ArgMax ( Ytr , nn_index ) } True class: { ArgMax ( Yte , i ) } (nn_index={ nn_index } )") ;
251- if ( ArgMax ( Ytr , nn_index ) == ArgMax ( Yte , i ) )
252- accuracy += 1f / Xte . Length ;
256+ Console . WriteLine ( $ "Test { i } : Prediction: { ArgMax ( trainingLabels , nn_index ) } True class: { ArgMax ( testLabels , i ) } (nn_index={ nn_index } )") ;
257+ if ( ArgMax ( trainingLabels , nn_index ) == ArgMax ( testLabels , i ) )
258+ accuracy += 1f / testImages . Length ;
253259 }
254260 Console . WriteLine ( "Accuracy: " + accuracy ) ;
255261 }
0 commit comments