Skip to content

Commit 974784d

Browse files
committed
Add
1 parent e135329 commit 974784d

File tree

2 files changed

+73
-18
lines changed

2 files changed

+73
-18
lines changed

Learn/Datasets/MNIST.cs

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,101 @@ public class DataSet
1515

1616
}
1717

18+
public struct MnistImage
19+
{
20+
public int Cols, Rows;
21+
public byte [] Data;
22+
23+
public MnistImage (int cols, int rows, byte [] data)
24+
{
25+
Cols = cols;
26+
Rows = rows;
27+
Data = data;
28+
}
29+
}
30+
1831
public class Mnist
1932
{
2033
public DataSet Train { get; private set; }
2134
public DataSet Validation { get; private set; }
2235
public DataSet Test { get; private set; }
2336

24-
const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
25-
const string TrainImages = "train-images-idx3-ubyte.gz";
26-
const string TrainLabels = "train-labels-idx1-ubyte.gz";
27-
const string TestImages = "t10k-images-idx3-ubyte.gz";
28-
const string TestLabels = "t10k-labels-idx1-ubyte.gz";
29-
37+
public MnistImage [] TrainImages, TestImages, ValidationImages;
38+
public byte [] TrainLabels, TestLabels, ValidationLabels;
3039

3140
int Read32 (Stream s)
3241
{
3342
var x = new byte [4];
3443
s.Read (x, 0, 4);
35-
return DataConverter.BigEndian.GetInt32 (x);
44+
return DataConverter.BigEndian.GetInt32 (x, 0);
3645
}
3746

38-
void ExtractImages (Stream input, string file)
47+
MnistImage [] ExtractImages (Stream input, string file)
3948
{
40-
var gz = new GZipStream (input, CompressionMode.Decompress);
41-
if (Read32 (gz) != 2051)
42-
throw new Exception ("Invalid magic number found on the MNIST " + file);
43-
var count = Read32 (gz);
44-
var rows = Read32 (gz);
45-
var cols = Read32 (gz);
46-
var buffer = new byte [rows * cols * count];
49+
using (var gz = new GZipStream (input, CompressionMode.Decompress)) {
50+
if (Read32 (gz) != 2051)
51+
throw new Exception ("Invalid magic number found on the MNIST " + file);
52+
var count = Read32 (gz);
53+
var rows = Read32 (gz);
54+
var cols = Read32 (gz);
4755

56+
var result = new MnistImage [count];
57+
for (int i = 0; i < count; i++) {
58+
var size = rows * cols;
59+
var data = new byte [size];
60+
gz.Read (data, 0, size);
4861

62+
result [i] = new MnistImage (cols, rows, data);
63+
}
64+
return result;
65+
}
66+
}
67+
68+
69+
byte [] ExtractLabels (Stream input, string file)
70+
{
71+
using (var gz = new GZipStream (input, CompressionMode.Decompress)) {
72+
if (Read32 (gz) != 2049)
73+
throw new Exception ("Invalid magic number found on the MNIST " + file);
74+
var count = Read32 (gz);
75+
var labels = new byte [count];
76+
gz.Read (labels, 0, count);
77+
78+
return labels;
79+
}
80+
}
81+
82+
T [] Pick<T> (T [] source, int first, int last)
83+
{
84+
if (last == 0)
85+
last = source.Length;
86+
var count = last - first;
87+
var result = new T [count];
88+
Array.Copy (source, first, result, 0, count);
89+
return result;
4990
}
5091

5192
public void ReadDataSets (string trainDir, bool fakeData = false, bool oneHot = false, TFDataType dtype = TFDataType.Float, bool reshape = true, int validationSize = 5000)
5293
{
94+
const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
95+
const string TrainImagesName = "train-images-idx3-ubyte.gz";
96+
const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
97+
const string TestImagesName = "t10k-images-idx3-ubyte.gz";
98+
const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";
99+
53100
if (fakeData) {
54101
return;
55102
}
56103

57-
ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TrainImages), TrainImages);
104+
TrainImages = ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TrainImagesName), TrainImagesName);
105+
TestImages = ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TestImagesName), TestImagesName);
106+
TrainLabels = ExtractLabels (Helper.MaybeDownload (SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
107+
TestLabels = ExtractLabels (Helper.MaybeDownload (SourceUrl, trainDir, TestLabelsName), TestLabelsName);
108+
109+
ValidationImages = Pick (TrainImages, 0, validationSize);
110+
ValidationLabels = Pick (TrainLabels, 0, validationSize);
111+
TrainImages = Pick (TrainImages, validationSize, 0);
112+
TrainLabels = Pick (TrainLabels, validationSize, 0);
58113

59114
}
60115
}

SampleTest/SampleTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ public static void Main (string [] args)
213213
t.TestSession ();
214214
//t.AttributesTest ();
215215

216-
//var n = new Mnist ();
217-
//n.ReadDataSets ("/Users/miguel/Downloads");
216+
var n = new Mnist ();
217+
n.ReadDataSets ("/Users/miguel/Downloads");
218218
}
219219
}
220220
}

0 commit comments

Comments
 (0)