Skip to content

Commit 2404a53

Browse files
committed
Complete optional attribute support, started work to load Mnist datasets
1 parent 9ee2559 commit 2404a53

10 files changed

Lines changed: 4992 additions & 231 deletions

File tree

Learn/DataConverter.cs

Lines changed: 1862 additions & 0 deletions
Large diffs are not rendered by default.

Learn/Datasets/Helper.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.IO;
3+
using System.Net;
4+
5+
namespace Learn
6+
{
7+
public class Helper
8+
{
9+
public static Stream MaybeDownload (string urlBase, string trainDir, string file)
10+
{
11+
if (!Directory.Exists (trainDir))
12+
Directory.CreateDirectory (trainDir);
13+
var target = Path.Combine (trainDir, file);
14+
if (!File.Exists (target)) {
15+
var wc = new WebClient ();
16+
wc.DownloadFile (urlBase + file, target);
17+
}
18+
return File.OpenRead (target);
19+
}
20+
21+
22+
}
23+
}

Learn/Datasets/MNIST.cs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//
2+
// Code to download and load the MNIST data.
3+
//
4+
5+
using System;
6+
using System.IO;
7+
using System.IO.Compression;
8+
using Mono;
9+
using TensorFlow;
10+
11+
namespace Learn
12+
{
13+
public class DataSet
14+
{
15+
16+
}
17+
18+
public class Mnist
19+
{
20+
public DataSet Train { get; private set; }
21+
public DataSet Validation { get; private set; }
22+
public DataSet Test { get; private set; }
23+
24+
const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
25+
const string TrainImages = "train-images-idx3-ubyte.gz";
26+
const string TrainLabels = "train-labels-idx1-ubyte.gz";
27+
const string TestImages = "t10k-images-idx3-ubyte.gz";
28+
const string TestLabels = "t10k-labels-idx1-ubyte.gz";
29+
30+
31+
int Read32 (Stream s)
32+
{
33+
var x = new byte [4];
34+
s.Read (x, 0, 4);
35+
return DataConverter.BigEndian.GetInt32 (x);
36+
}
37+
38+
void ExtractImages (Stream input, string file)
39+
{
40+
var gz = new GZipStream (input, CompressionMode.Decompress);
41+
if (Read32 (gz) != 2051)
42+
throw new Exception ("Invalid magic number found on the MNIST " + file);
43+
var count = Read32 (gz);
44+
var rows = Read32 (gz);
45+
var cols = Read32 (gz);
46+
var buffer = new byte [rows * cols * count];
47+
48+
49+
}
50+
51+
public void ReadDataSets (string trainDir, bool fakeData = false, bool oneHot = false, TFDataType dtype = TFDataType.Float, bool reshape = true, int validationSize = 5000)
52+
{
53+
if (fakeData) {
54+
return;
55+
}
56+
57+
ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TrainImages), TrainImages);
58+
59+
}
60+
}
61+
}
Lines changed: 113 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,19 @@
1-
using System;
1+
//
2+
// This is the driver for the operation generator, this takes data that
3+
// is provided by the Tensorflow runtime to produce strongly-typed and
4+
// high level methods on the TFGraph class.
5+
//
6+
// The result is generated into a partial class that is lined with the
7+
// main TensorFlowSharp library
8+
//
9+
// Authors:
10+
// Miguel de Icaza
11+
//
12+
// Copyright 2017, the year of downfall, Microsoft Inc
13+
//
14+
#pragma warning disable RECS0063 // Warns when a culture-aware 'StartsWith' call is used by default.
15+
16+
using System;
217
using System.Collections.Generic;
318
using System.IO;
419
using ProtoBuf;
@@ -9,8 +24,6 @@
924

1025
class OpGenerator
1126
{
12-
StreamWriter output;
13-
1427
//
1528
// Maps a TensorFlow type to a C# type
1629
//
@@ -46,6 +59,15 @@ string CSharpType (string tfType)
4659
return cstype + (list ? "[]" : "");
4760
}
4861

62+
bool IsReferenceType (string tfType)
63+
{
64+
if (tfType.StartsWith ("list("))
65+
return true;
66+
if (tfType == "tensor" || tfType == "string" || tfType == "shape")
67+
return true;
68+
return false;
69+
}
70+
4971
// Maps a parameter name to a C# acceptable name, to avoid clashes with
5072
// language keywords
5173
string ParamMap (string paramName)
@@ -118,10 +140,14 @@ string FillArguments (OpDef def)
118140
sb.AppendFormat ($", ref {type} {ParamMap (arg.name)}");
119141
}
120142

121-
// FIXME: finish this part
122143
int n = 0;
123-
foreach (var attr in optional_attrs)
124-
sb.AppendFormat ($", object optional{n++}");
144+
foreach (var attr in optional_attrs) {
145+
bool reftype = IsReferenceType (attr.type);
146+
var cstype = CSharpType (attr.type);
147+
var cstypesuffix = reftype ? "" : "?";
148+
149+
sb.AppendFormat ($", {cstype}{cstypesuffix} {attr.name} = null");
150+
}
125151
return sb.ToString ();
126152
}
127153

@@ -158,13 +184,62 @@ void GenDocs (OpDef oper)
158184
p ("/// <param name=\"operName\">");
159185
p ($"/// If specified, the created operation in the graph will be this one, otherwise it will be named '{oper.name}'.");
160186
p ("/// </param>");
187+
foreach (var attr in optional_attrs) {
188+
if (String.IsNullOrEmpty (attr.description))
189+
continue;
190+
p ($"/// <param name=\"{ParamMap (attr.name)}\">");
191+
Comment ("Optional argument");
192+
Comment (attr.description);
193+
p ($"/// </param>");
194+
}
195+
161196
if (!String.IsNullOrEmpty (oper.description)) {
162197
p ("/// <remarks>");
163198
Comment (oper.description);
164199
p ("/// </remarks>");
165200
}
166201
}
167202

203+
void SetAttribute (string type, string attrName, string csAttrName)
204+
{
205+
if (type == "shape") {
206+
p ($"desc.SetAttrShape (\"{attrName}\", {csAttrName});");
207+
return;
208+
}
209+
if (type.StartsWith ("list(shape")) {
210+
p ($"desc.SetAttrShape (\"{attrName}\", {csAttrName});");
211+
return;
212+
}
213+
214+
var cstype = CSharpType (type);
215+
switch (cstype) {
216+
case "long":
217+
case "long[]":
218+
case "string":
219+
case "string[]":
220+
case "float":
221+
case "float[]":
222+
case "bool":
223+
case "bool[]":
224+
p ($"desc.SetAttr (\"{attrName}\", {csAttrName});");
225+
break;
226+
case "TFDataType":
227+
case "TFDataType[]":
228+
p ($"desc.SetAttrType (\"{attrName}\", {csAttrName});");
229+
break;
230+
231+
// This should pass the cstatus, but requires the
232+
// function to take a TFStatus as well, so need to weave that
233+
// in the parameters
234+
case "TFTensor":
235+
case "TFTensor[]":
236+
p ($"desc.SetAttr (\"{attrName}\", {csAttrName} /* cstatus */);");
237+
break;
238+
default:
239+
throw new Exception ("Unexpected type: " + cstype);
240+
}
241+
}
242+
168243
/// <summary>
169244
/// Generate the specified oper.
170245
/// </summary>
@@ -190,33 +265,22 @@ void Generate (OpDef oper)
190265
// If we have attributes
191266
if (required_attrs.Count > 0 || optional_attrs.Count > 0) {
192267
foreach (var attr in required_attrs) {
193-
var cstype = CSharpType (attr.type);
194-
switch (cstype) {
195-
case "int":
196-
case "int[]":
197-
case "string":
198-
case "string[]":
199-
case "float":
200-
case "float[]":
201-
case "bool":
202-
case "bool[]":
203-
p ($"desc.SetAttr (\"{attr.name}\", {ParamMap(attr.name)});");
204-
break;
205-
case "TFDataType":
206-
case "TFDataType[]":
207-
p ($"desc.SetAttrType (\"{attr.name}\", {ParamMap (attr.name)});");
208-
break;
209-
210-
// This should pass the cstatus, but requires the
211-
// function to take a TFStatus as well, so need to weave that
212-
// in the parameters
213-
case "TFTensor":
214-
case "TFTensor[]":
215-
p ($"desc.SetAttr (\"{attr.name}\", {ParamMap (attr.name)} /* cstatus */);");
216-
break;
217-
}
268+
SetAttribute (attr.type, attr.name, ParamMap (attr.name));
269+
}
270+
271+
foreach (var attr in optional_attrs) {
272+
var reftype = IsReferenceType (attr.type);
273+
var csattr = ParamMap (attr.name);
274+
if (reftype)
275+
pi ($"if ({csattr} != null)");
276+
else
277+
pi ($"if ({csattr}.HasValue)");
278+
SetAttribute (attr.type, attr.name, csattr + (reftype ? "" : ".Value"));
279+
pd ("");
280+
218281
}
219282
}
283+
220284
p ("var op = desc.FinishOperation ();");
221285
if (oper.output_arg.Any (x => IsListArg (x))) {
222286
p ("int _idx = 0, _n = 0;");
@@ -248,6 +312,8 @@ void Run ()
248312
output = File.CreateText ("../../../TensorFlowSharp/Operations.cs");
249313

250314
var operations = Serializer.Deserialize<List<OpDef>> (new MemoryStream (TFCore.GetAllOpList ().ToArray ()));
315+
p ("using System;\n");
316+
251317
pi ("namespace TensorFlow {");
252318
pi ("public partial class TFGraph {");
253319
foreach (var oper in operations){
@@ -256,16 +322,24 @@ void Run ()
256322
continue;
257323

258324
// Ignore functions where we lack a C# type mapping
259-
if (oper.attr.Any (attr => CSharpType (attr.type) == null))
325+
if (oper.attr.Any (attr => CSharpType (attr.type) == null)) {
326+
var attr = oper.attr.First (a => CSharpType (a.type) == null);
327+
328+
//Console.WriteLine ($"Skip: {oper.name} due to attribute ({attr.type} {attr.name}) lacking a mapping to C#");
260329
continue;
330+
}
261331

262332
// Ignore reference types as well (per go's binding)
263-
if (oper.input_arg.Any (ia => ia.is_ref))
333+
if (oper.input_arg.Any (ia => ia.is_ref)) {
334+
//Console.WriteLine ($"Skip: {oper.name} due to presence of an input argument that is a reference");
264335
continue;
265-
336+
}
337+
266338
// Ignore reference types as well (per go's binding)
267-
if (oper.output_arg.Any (ia => ia.is_ref))
339+
if (oper.output_arg.Any (ia => ia.is_ref)) {
340+
//Console.WriteLine ($"Skip: {oper.name} due to presence of an output argument that is a reference");
268341
continue;
342+
}
269343

270344
// Undocumented operation, perhaps we should not surface
271345
if (oper.summary == "")
@@ -278,8 +352,12 @@ void Run ()
278352
output.Close ();
279353
}
280354

355+
// The output file
356+
StreamWriter output;
357+
281358
int indent = 0;
282359

360+
// Convenience methods to generate output
283361
void pi (string fmt, params object [] args)
284362
{
285363
p (fmt, args);

OpGenerator/OpGenerator.csproj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
</Reference>
3838
</ItemGroup>
3939
<ItemGroup>
40-
<Compile Include="Program.cs" />
40+
<Compile Include="OpGenerator.cs" />
4141
<Compile Include="Properties\AssemblyInfo.cs" />
4242
<Compile Include="Opdefs.cs" />
4343
</ItemGroup>
@@ -50,5 +50,8 @@
5050
<Name>TensorFlowSharp</Name>
5151
</ProjectReference>
5252
</ItemGroup>
53+
<ItemGroup>
54+
<Folder Include="New Folder\" />
55+
</ItemGroup>
5356
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
5457
</Project>

0 commit comments

Comments
 (0)