@@ -15,46 +15,101 @@ public class DataSet
1515
1616 }
1717
/// <summary>
/// A single image from an MNIST image file: its dimensions plus the raw
/// bytes that were read for it.
/// </summary>
public struct MnistImage
{
    /// <summary>Number of columns, as declared in the image file header.</summary>
    public int Cols;

    /// <summary>Number of rows, as declared in the image file header.</summary>
    public int Rows;

    /// <summary>
    /// Raw image bytes. The loader allocates Rows * Cols bytes per image.
    /// The array is stored by reference, not copied.
    /// </summary>
    public byte[] Data;

    /// <summary>Creates an image wrapper around an existing byte buffer.</summary>
    /// <param name="cols">Number of columns.</param>
    /// <param name="rows">Number of rows.</param>
    /// <param name="data">Raw image bytes; kept by reference.</param>
    public MnistImage(int cols, int rows, byte[] data)
    {
        this.Cols = cols;
        this.Rows = rows;
        this.Data = data;
    }
}
30+
/// <summary>
/// Loader for the MNIST data files. <see cref="ReadDataSets"/> downloads
/// (via Helper.MaybeDownload) and decodes the gzip-compressed image and label
/// files, then splits the training data into a validation part and a
/// training part.
/// </summary>
/// <remarks>
/// The <see cref="Train"/>, <see cref="Validation"/> and <see cref="Test"/>
/// properties are not assigned anywhere in this class; the decoded data is
/// exposed through the public array fields instead.
/// </remarks>
public class Mnist
{
    public DataSet Train { get; private set; }
    public DataSet Validation { get; private set; }
    public DataSet Test { get; private set; }

    /// <summary>Decoded images; null until <see cref="ReadDataSets"/> has loaded real data.</summary>
    public MnistImage[] TrainImages, TestImages, ValidationImages;

    /// <summary>Decoded labels, one byte per label; null until <see cref="ReadDataSets"/> has loaded real data.</summary>
    public byte[] TrainLabels, TestLabels, ValidationLabels;

    /// <summary>
    /// Fills <paramref name="buffer"/>[0..count) from the stream, looping until
    /// all bytes have arrived. A single Stream.Read call may legitimately
    /// return fewer bytes than requested (decompression streams do this
    /// routinely), so its return value must not be ignored.
    /// </summary>
    /// <exception cref="EndOfStreamException">The stream ended before <paramref name="count"/> bytes were read.</exception>
    static void ReadFully(Stream s, byte[] buffer, int count)
    {
        int offset = 0;
        while (offset < count) {
            int n = s.Read(buffer, offset, count - offset);
            if (n <= 0)
                throw new EndOfStreamException("Unexpected end of MNIST stream: expected " + count + " bytes, got " + offset);
            offset += n;
        }
    }

    /// <summary>Reads a big-endian 32-bit integer from the stream.</summary>
    /// <exception cref="EndOfStreamException">Fewer than four bytes were available.</exception>
    int Read32(Stream s)
    {
        var x = new byte[4];
        ReadFully(s, x, 4);
        return DataConverter.BigEndian.GetInt32(x, 0);
    }

    /// <summary>
    /// Decodes a gzip-compressed MNIST image file: magic number 2051, then the
    /// image count, row count and column count (each a big-endian int32),
    /// followed by rows * cols bytes per image.
    /// </summary>
    /// <param name="input">Compressed stream; it is disposed together with the decompressor.</param>
    /// <param name="file">File name, used only in error messages.</param>
    /// <returns>One <see cref="MnistImage"/> per image in the file.</returns>
    /// <exception cref="InvalidDataException">Wrong magic number or negative header values.</exception>
    /// <exception cref="EndOfStreamException">The file is truncated.</exception>
    MnistImage[] ExtractImages(Stream input, string file)
    {
        using (var gz = new GZipStream(input, CompressionMode.Decompress)) {
            if (Read32(gz) != 2051)
                throw new InvalidDataException("Invalid magic number found on the MNIST " + file);
            var count = Read32(gz);
            var rows = Read32(gz);
            var cols = Read32(gz);
            if (count < 0 || rows < 0 || cols < 0)
                throw new InvalidDataException("Invalid dimensions found on the MNIST " + file);

            // Per-image size is loop-invariant; checked so that a corrupt
            // header cannot silently wrap around to a small or negative size.
            int size = checked(rows * cols);

            var result = new MnistImage[count];
            for (int i = 0; i < count; i++) {
                var data = new byte[size];
                ReadFully(gz, data, size);
                result[i] = new MnistImage(cols, rows, data);
            }
            return result;
        }
    }

    /// <summary>
    /// Decodes a gzip-compressed MNIST label file: magic number 2049, then the
    /// label count (big-endian int32), followed by one byte per label.
    /// </summary>
    /// <param name="input">Compressed stream; it is disposed together with the decompressor.</param>
    /// <param name="file">File name, used only in error messages.</param>
    /// <returns>The label bytes, in file order.</returns>
    /// <exception cref="InvalidDataException">Wrong magic number or negative label count.</exception>
    /// <exception cref="EndOfStreamException">The file is truncated.</exception>
    byte[] ExtractLabels(Stream input, string file)
    {
        using (var gz = new GZipStream(input, CompressionMode.Decompress)) {
            if (Read32(gz) != 2049)
                throw new InvalidDataException("Invalid magic number found on the MNIST " + file);
            var count = Read32(gz);
            if (count < 0)
                throw new InvalidDataException("Invalid label count found on the MNIST " + file);

            var labels = new byte[count];
            ReadFully(gz, labels, count);
            return labels;
        }
    }

    /// <summary>
    /// Copies the elements source[first..last) into a new array.
    /// </summary>
    /// <remarks>
    /// <paramref name="last"/> is an explicit exclusive end index. There is
    /// deliberately no "0 means end of array" sentinel, because that would make
    /// an empty slice starting at index 0 impossible to express.
    /// </remarks>
    static T[] Pick<T>(T[] source, int first, int last)
    {
        var count = last - first;
        var result = new T[count];
        Array.Copy(source, first, result, 0, count);
        return result;
    }

    /// <summary>
    /// Downloads (if needed) and decodes the MNIST training and test files, then
    /// moves the first <paramref name="validationSize"/> training samples into
    /// the validation arrays and keeps the remainder as the training arrays.
    /// </summary>
    /// <param name="trainDir">Directory passed to Helper.MaybeDownload for each file.</param>
    /// <param name="fakeData">When true the method returns immediately without loading anything.</param>
    /// <param name="oneHot">Currently unused.</param>
    /// <param name="dtype">Currently unused.</param>
    /// <param name="reshape">Currently unused.</param>
    /// <param name="validationSize">
    /// Number of leading training samples to set aside for validation. Zero
    /// yields empty validation arrays and leaves the training set whole.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="validationSize"/> is negative or larger than the training set.
    /// </exception>
    /// <exception cref="InvalidDataException">A file is malformed, or the training image and label counts differ.</exception>
    public void ReadDataSets(string trainDir, bool fakeData = false, bool oneHot = false, TFDataType dtype = TFDataType.Float, bool reshape = true, int validationSize = 5000)
    {
        const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
        const string TrainImagesName = "train-images-idx3-ubyte.gz";
        const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
        const string TestImagesName = "t10k-images-idx3-ubyte.gz";
        const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

        if (fakeData) {
            return;
        }

        // Reject an impossible split before doing any download or decode work.
        if (validationSize < 0)
            throw new ArgumentOutOfRangeException("validationSize", "Validation size must not be negative.");

        var trainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
        var testImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
        var trainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
        var testLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

        // The split below indexes images and labels in lockstep.
        if (trainImages.Length != trainLabels.Length)
            throw new InvalidDataException("MNIST training image and label counts differ: " + trainImages.Length + " vs " + trainLabels.Length);
        if (validationSize > trainImages.Length)
            throw new ArgumentOutOfRangeException("validationSize", "Validation size exceeds the number of training samples (" + trainImages.Length + ").");

        // Assign the public fields only after everything has been loaded and
        // validated, so a failure cannot leave them half-populated.
        TestImages = testImages;
        TestLabels = testLabels;
        ValidationImages = Pick(trainImages, 0, validationSize);
        ValidationLabels = Pick(trainLabels, 0, validationSize);
        TrainImages = Pick(trainImages, validationSize, trainImages.Length);
        TrainLabels = Pick(trainLabels, validationSize, trainLabels.Length);
    }
}
0 commit comments