Skip to content

Commit 8559b6b

Browse files
committed
Small sample addition
1 parent 68adf2a commit 8559b6b

4 files changed

Lines changed: 144 additions & 2 deletions

File tree

Learn/Datasets/MNIST.cs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,87 @@
77
using System.IO.Compression;
88
using Mono;
99
using TensorFlow;
10+
using System.Linq;
1011

1112
namespace Learn.Mnist
1213
{
14+
/// <summary>
/// Stores the per-image MNIST information loaded from disk.
/// </summary>
/// <remarks>
/// The pixel data is kept in two formats: the byte array exactly as it was handed to the
/// constructor, and a float array where each 0..255 value has been mapped to 0.0f..1.0f.
/// </remarks>
public struct MnistImage
{
    public int Cols, Rows;
    public byte [] Data;
    public float [] DataFloat;

    /// <summary>
    /// Creates an image from its dimensions and raw pixel bytes, and precomputes the
    /// normalized float version of the pixels.
    /// </summary>
    /// <param name="cols">Value stored in <see cref="Cols"/>.</param>
    /// <param name="rows">Value stored in <see cref="Rows"/>.</param>
    /// <param name="data">Raw pixel bytes; stored by reference in <see cref="Data"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public MnistImage (int cols, int rows, byte [] data)
    {
        // Fail with a descriptive exception instead of a NullReferenceException on data.Length
        if (data == null)
            throw new ArgumentNullException (nameof (data));

        Cols = cols;
        Rows = rows;
        Data = data;

        // Map every 0..255 byte to the 0.0f..1.0f range
        var scaled = new float [data.Length];
        for (int i = 0; i < data.Length; i++) {
            scaled [i] = data [i] / 255f;
        }
        DataFloat = scaled;
    }
}
2535

36+
// Helper class used to load and work with the Mnist data set
2637
public class Mnist
2738
{
39+
//
40+
// The loaded results
41+
//
2842
public MnistImage [] TrainImages, TestImages, ValidationImages;
29-
public byte [] TrainLabels, TestLabels, ValidationLabels;
43+
public byte [] TrainLabels, TestLabels, ValidationLabels;
3044
public byte [,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;
3145

46+
/// <summary>
/// Creates a simple batch reader used to get pieces of data from the dataset.
/// </summary>
/// <param name="source">The image array the reader will pull batches from.</param>
/// <returns>A new <see cref="BatchReader"/> constructed over <paramref name="source"/>.</returns>
public BatchReader GetBatchReader (MnistImage [] source)
{
    var reader = new BatchReader (source);
    return reader;
}
51+
52+
/// <summary>
/// Hands out consecutive batches of images from a source array, wrapping around to the
/// beginning of the array once the end has been reached.
/// </summary>
public class BatchReader
{
    // Number of float values copied per image by ReadAsTensor (second tensor dimension)
    const int ImageSize = 784;

    // Index in source of the first image of the next batch
    int start = 0;
    MnistImage [] source;

    /// <summary>
    /// Creates a reader positioned at the first element of <paramref name="source"/>.
    /// </summary>
    /// <param name="source">Images to read from; the array is referenced, not copied.</param>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public BatchReader (MnistImage [] source)
    {
        if (source == null)
            throw new ArgumentNullException (nameof (source));
        this.source = source;
    }

    /// <summary>
    /// Returns the next <paramref name="batchSize"/> images and advances the read position.
    /// </summary>
    /// <param name="batchSize">Number of images to return.</param>
    /// <returns>
    /// A new array with <paramref name="batchSize"/> images. When the end of the source is
    /// reached, reading continues from the first image, as many times as needed.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is negative.</exception>
    /// <exception cref="InvalidOperationException">Images were requested from an empty source.</exception>
    public MnistImage [] Read (int batchSize)
    {
        if (batchSize < 0)
            throw new ArgumentOutOfRangeException (nameof (batchSize));

        var result = new MnistImage [batchSize];
        if (batchSize == 0)
            return result;
        if (source.Length == 0)
            throw new InvalidOperationException ("Cannot read a batch from an empty source.");

        int copied = 0;
        while (copied < batchSize) {
            // Copy as much as is left in the source, or as much as is still needed
            int chunk = Math.Min (batchSize - copied, source.Length - start);
            Array.Copy (source, start, result, copied, chunk);
            copied += chunk;

            // Wrap to the beginning once the end of the source is consumed, so the
            // next batch starts right after the last image that was handed out
            start = (start + chunk) % source.Length;
        }
        return result;
    }

    /// <summary>
    /// Reads the next batch and packs the float pixel data of every image into a single
    /// two-dimensional tensor of shape [batchSize, 784].
    /// </summary>
    /// <param name="batchSize">Number of images to read.</param>
    /// <returns>The tensor created from the packed float data.</returns>
    public TFTensor ReadAsTensor (int batchSize)
    {
        var result = new float [batchSize, ImageSize];
        var images = Read (batchSize);

        // Buffer.BlockCopy works with byte offsets and byte counts, not element counts
        const int rowBytes = ImageSize * sizeof (float);
        int offset = 0;
        for (int i = 0; i < batchSize; i++) {
            Buffer.BlockCopy (images [i].DataFloat, 0, result, offset, rowBytes);
            offset += rowBytes;
        }
        return (TFTensor)result;
    }
}
90+
3291
int Read32 (Stream s)
3392
{
3493
var x = new byte [4];
@@ -98,7 +157,7 @@ T [] Pick<T> (T [] source, int first, int last)
98157
/// <param name="trainDir">Directory where the training data is downloaded to.</param>
99158
/// <param name="numClasses">Number of classes to use for one-hot encoding, or zero if this is not desired.</param>
100159
/// <param name="validationSize">Validation size.</param>
101-
public void ReadDataSets (string trainDir, int numClasses = 0, int validationSize = 5000)
160+
public void ReadDataSets (string trainDir, int numClasses = 10, int validationSize = 5000)
102161
{
103162
const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
104163
const string TrainImagesName = "train-images-idx3-ubyte.gz";
@@ -122,5 +181,12 @@ public void ReadDataSets (string trainDir, int numClasses = 0, int validationSiz
122181
OneHotTestLabels = OneHot (TestLabels, numClasses);
123182
}
124183
}
184+
185+
/// <summary>
/// Convenience loader: creates a new <see cref="Mnist"/> instance and reads the data sets
/// using "/tmp" as the data directory, with the default values for the other options.
/// </summary>
/// <returns>The instance on which <see cref="ReadDataSets"/> was called.</returns>
public static Mnist Load ()
{
    var mnist = new Mnist ();
    mnist.ReadDataSets (trainDir: "/tmp");
    return mnist;
}
125191
}
126192
}

SampleTest/SampleTest.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Collections.Generic;
77
using Learn.Mnist;
8+
using System.Linq;
89

910
namespace SampleTest
1011
{
@@ -441,6 +442,67 @@ void BasicMatrix ()
441442

442443
};
443444
}
445+
446+
// Returns the column index of the largest value in row `idx` of `array`.
//
// When several columns share the largest value the first one wins; when the row
// has no columns the result is -1.
int ArgMax (byte [,] array, int idx)
{
    int max = -1;
    int maxIdx = -1;
    var columns = array.GetLength (1);
    for (int i = 0; i < columns; i++) {
        int value = array [idx, i];
        if (value > max) {
            // Remember the running maximum, otherwise every later column would
            // compare against -1 and the last column would always be returned
            max = value;
            maxIdx = i;
        }
    }
    return maxIdx;
}
456+
457+
// Nearest neighbor sample: classifies MNIST test images by finding, for each one, the
// training image with the smallest L1 distance and comparing the labels of both.
// Prints one line per test image and the resulting accuracy.
void NearestNeighbor ()
{
    // Get the Mnist data
    var mnist = Mnist.Load ();

    // 5000 images for training, 200 for testing
    const int trainCount = 5000;
    const int testCount = 200;
    var Xtr = mnist.GetBatchReader (mnist.TrainImages).ReadAsTensor (trainCount);
    var Ytr = mnist.OneHotTrainLabels;
    var Xte = mnist.GetBatchReader (mnist.TestImages).Read (testCount);
    var Yte = mnist.OneHotTestLabels;

    Console.WriteLine ("Nearest neighbor on Mnist images");

    // Both the graph and the session are disposed when the sample finishes
    using (var g = new TFGraph ())
    using (var s = new TFSession (g)) {
        var xtr = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
        var xte = g.Placeholder (TFDataType.Float, new TFShape (784));

        // Nearest Neighbor calculation using L1 Distance
        // Calculate L1 Distance: sum over the pixel axis of |xtr - xte|
        var distance = g.ReduceSum (g.Abs (g.Add (xtr, g.Neg (xte))), axis: g.Const (1));

        // Prediction: Get min distance index (Nearest neighbor)
        var pred = g.ArgMin (distance, g.Const (0));

        // Count the hits and divide once at the end, this avoids accumulating
        // floating point rounding errors on every iteration
        int correct = 0;

        // Loop over the test data
        for (int i = 0; i < testCount; i++) {
            var runner = s.GetRunner ();

            // Get nearest neighbor
            var result = runner.Fetch (pred).AddInput (xtr, Xtr).AddInput (xte, Xte [i].DataFloat).Run ();
            var nn_index = (int)(long) result [0].GetValue ();

            // Get nearest neighbor class label and compare it to its true label
            var predictedClass = ArgMax (Ytr, nn_index);
            var trueClass = ArgMax (Yte, i);
            Console.WriteLine ($"Test {i}: Prediction: {nn_index} {predictedClass} True class: {trueClass}");
            if (predictedClass == trueClass)
                correct++;
        }

        var accuracy = (float) correct / Xte.Length;
        Console.WriteLine ("Accuracy: " + accuracy);
    }
}
505+
444506
#if true
445507
void LinearRegression ()
446508
{
@@ -514,6 +576,7 @@ public static void Main (string [] args)
514576
t.BasicVariables ();
515577
t.BasicMatrix ();
516578

579+
t.NearestNeighbor ();
517580

518581
}
519582
}

TensorFlowSharp/OperationsExtras.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ public TFOutput ReduceSum (TFOutput input, TFOutput? axis = null, bool? keep_dim
6969
/// <param name="value">Returns the value of the variable.</param>
7070
/// <param name="operName">Operation name, optional.</param>
7171
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
72+
/// <remarks>
73+
/// Variables need to be initialized before the main execution so you will typically want to
74+
/// run the session on the variable
75+
/// </remarks>
7276
public TFOutput Variable (TFOutput initialValue, out TFOperation init, out TFOutput value, string operName = null)
7377
{
7478
var scopeName = MakeName ("Variable", operName);

TensorFlowSharp/Tensorflow.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,6 +2038,15 @@ public TFOutput GetOutput (TFInput operIn)
20382038
/// <summary>
20392039
/// Represents a specific output of an operation on a tensor.
20402040
/// </summary>
2041+
/// <remarks>
2042+
/// TFOutput objects represent one of the outputs of an operation in the graph
2043+
/// (TFGraph). Outputs have a data type, and eventually a shape that you can
2044+
/// retrieve by calling the <see cref="M:TensorFlow.TFGraph.GetShape"/> method.
2045+
///
2046+
/// These can be passed as an input argument to a function for adding operations
2047+
/// to a graph, or to the TFSession's Run and GetRunner method as values to be
2048+
/// fetched.
2049+
/// </remarks>
20412050
[StructLayout (LayoutKind.Sequential)]
20422051
public struct TFOutput
20432052
{

0 commit comments

Comments
 (0)