Skip to content

Commit a419c32

Browse files
committed
More overloads, improved API, more tests
1 parent 677d1b9 commit a419c32

File tree

9 files changed

+2579
-2336
lines changed

9 files changed

+2579
-2336
lines changed

Learn/Datasets/MNIST.cs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public class Mnist
2727
{
2828
public MnistImage [] TrainImages, TestImages, ValidationImages;
2929
public byte [] TrainLabels, TestLabels, ValidationLabels;
30+
public byte [,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;
3031

3132
int Read32 (Stream s)
3233
{
@@ -80,18 +81,31 @@ T [] Pick<T> (T [] source, int first, int last)
8081
return result;
8182
}
8283

83-
public void ReadDataSets (string trainDir, bool fakeData = false, bool oneHot = false, TFDataType dtype = TFDataType.Float, bool reshape = true, int validationSize = 5000)
84+
// Turn the labels array that contains values 0..numClasses-1 into
85+
// a One-hot encoded array
86+
byte [,] OneHot (byte [] labels, int numClasses)
87+
{
88+
var oneHot = new byte [labels.Length, numClasses];
89+
for (int i = 0; i < labels.Length; i++) {
90+
oneHot [i, labels [i]] = 1;
91+
}
92+
return oneHot;
93+
}
94+
95+
/// <summary>
96+
/// Reads the data sets.
97+
/// </summary>
98+
/// <param name="trainDir">Directory where the training data is downloaded to.</param>
99+
/// <param name="numClasses">Number of classes to use for one-hot encoding, or zero if this is not desired.</param>
100+
/// <param name="validationSize">Validation size.</param>
101+
public void ReadDataSets (string trainDir, int numClasses = 0, int validationSize = 5000)
84102
{
85103
const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
86104
const string TrainImagesName = "train-images-idx3-ubyte.gz";
87105
const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
88106
const string TestImagesName = "t10k-images-idx3-ubyte.gz";
89107
const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";
90108

91-
if (fakeData) {
92-
return;
93-
}
94-
95109
TrainImages = ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TrainImagesName), TrainImagesName);
96110
TestImages = ExtractImages (Helper.MaybeDownload (SourceUrl, trainDir, TestImagesName), TestImagesName);
97111
TrainLabels = ExtractLabels (Helper.MaybeDownload (SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
@@ -102,6 +116,11 @@ public void ReadDataSets (string trainDir, bool fakeData = false, bool oneHot =
102116
TrainImages = Pick (TrainImages, validationSize, 0);
103117
TrainLabels = Pick (TrainLabels, validationSize, 0);
104118

119+
if (numClasses != -1) {
120+
OneHotTrainLabels = OneHot (TrainLabels, numClasses);
121+
OneHotValidationLabels = OneHot (ValidationLabels, numClasses);
122+
OneHotTestLabels = OneHot (TestLabels, numClasses);
123+
}
105124
}
106125
}
107126
}

OpGenerator/OpGenerator.cs

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ bool IsListArg (OpDef.ArgDef arg)
9292
//
9393
Dictionary<string, bool> inferred_input_args;
9494
List<OpDef.AttrDef> required_attrs, optional_attrs;
95+
bool return_is_tfoutput;
9596

9697
void SetupArguments (OpDef def)
9798
{
@@ -117,6 +118,22 @@ void SetupArguments (OpDef def)
117118
else
118119
optional_attrs.Add (attr);
119120
}
121+
// API: currently, if we have a single ref TFOutput result, we make the signature of the
122+
// function return that TFOutput instead of the TFOperation (as you can get the TFOperation
123+
// from the TFOutput anyway).
124+
//
125+
// When we move to tuples, we could probably put everything in a Tuple result, but for now
126+
// multi-return functions will just return all outputs on ref variables, instead of the first
127+
// as a ref, and the rest as TFOutputs.
128+
//
129+
// This means that we generate methods like this:
130+
// TFOutput Constant (....)
131+
// when there is a single output
132+
//
133+
// TFOperation Foo (..)
134+
// When there is no result or more than one result.
135+
return_is_tfoutput = def.output_arg.Count == 1;
136+
120137
}
121138

122139
// Generates arguments:
@@ -134,10 +151,12 @@ string FillArguments (OpDef def)
134151
foreach (var attr in required_attrs)
135152
sb.AppendFormat ($", {CSharpType (attr.type)} {ParamMap (attr.name)}");
136153

137-
foreach (var arg in def.output_arg) {
138-
string type = "TFOutput" + (IsListArg (arg) ? "[]" : "");
154+
if (!return_is_tfoutput) {
155+
foreach (var arg in def.output_arg) {
156+
string type = "TFOutput" + (IsListArg (arg) ? "[]" : "");
139157

140-
sb.AppendFormat ($", ref {type} {ParamMap (arg.name)}");
158+
sb.AppendFormat ($", ref {type} {ParamMap (arg.name)}");
159+
}
141160
}
142161

143162
int n = 0;
@@ -161,6 +180,7 @@ void Comment (string text)
161180
}
162181
}
163182

183+
164184
// Produces the C# inline documentation
165185
void GenDocs (OpDef oper)
166186
{
@@ -174,12 +194,14 @@ void GenDocs (OpDef oper)
174194
Comment (input.description);
175195
p ($"/// </param>");
176196
}
177-
foreach (var attr in oper.output_arg) {
178-
if (String.IsNullOrEmpty (attr.description))
179-
continue;
180-
p ($"/// <param name=\"{ParamMap (attr.name)}\">");
181-
Comment (attr.description);
182-
p ($"/// </param>");
197+
if (!return_is_tfoutput) {
198+
foreach (var attr in oper.output_arg) {
199+
if (String.IsNullOrEmpty (attr.description))
200+
continue;
201+
p ($"/// <param name=\"{ParamMap (attr.name)}\">");
202+
Comment (attr.description);
203+
p ($"/// </param>");
204+
}
183205
}
184206
p ("/// <param name=\"operName\">");
185207
p ($"/// If specified, the created operation in the graph will be this one, otherwise it will be named '{oper.name}'.");
@@ -193,6 +215,11 @@ void GenDocs (OpDef oper)
193215
p ($"/// </param>");
194216
}
195217

218+
if (return_is_tfoutput) {
219+
p ($"/// <returns>");
220+
Comment (oper.output_arg.First ().description);
221+
p ($"/// </returns>");
222+
}
196223
if (!String.IsNullOrEmpty (oper.description)) {
197224
p ("/// <remarks>");
198225
Comment (oper.description);
@@ -246,15 +273,26 @@ void SetAttribute (string type, string attrName, string csAttrName)
246273
/// <param name="oper">Oper.</param>
247274
void Generate (OpDef oper)
248275
{
276+
249277
SetupArguments (oper);
250278
GenDocs (oper);
251279

252280
var name = oper.name;
281+
string retType;
282+
283+
if (return_is_tfoutput) {
284+
if (oper.output_arg.Any (x => IsListArg (x)))
285+
retType = "TFOutput []";
286+
else
287+
retType = "TFOutput";
288+
} else
289+
retType = "TFOperation";
253290

254-
p ($"public TFOperation {name} (Scope scope{FillArguments(oper)}, string operName = null)");
291+
292+
p ($"public {retType} {name} (Scope scope{FillArguments(oper)}, string operName = null)");
255293
pi ("{");
256294
bool needStatus = required_attrs.Concat (optional_attrs).Any (attr => attr.type.Contains ("TFTensor"));
257-
p ($"var desc = new TFOperationDesc (this, operName, operName == null ? \"{oper.name}\" : operName);");
295+
p ($"var desc = new TFOperationDesc (this, \"{oper.name}\", operName == null ? \"{oper.name}\" : operName);");
258296
foreach (var arg in oper.input_arg) {
259297
if (IsListArg (arg))
260298
p ($"desc.AddInputs ({ParamMap (arg.name)});");
@@ -285,24 +323,42 @@ void Generate (OpDef oper)
285323
if (oper.output_arg.Any (x => IsListArg (x))) {
286324
p ("int _idx = 0, _n = 0;");
287325
foreach (var arg in oper.output_arg) {
288-
326+
string retDecl = "", retOutput;
327+
328+
if (return_is_tfoutput){
329+
retDecl = "var ";
330+
retOutput = "_ret";
331+
} else
332+
retOutput = ParamMap (arg.name);
333+
289334
if (IsListArg (arg)) {
290335
var outputs = new StringBuilder ();
291-
p ("_n = op.InputListLength (\"arg.name\");");
292-
p ($"{ParamMap (arg.name)} = new TFOutput [_n];");
336+
p ($"_n = op.OutputListLength (\"{arg.name}\");");
337+
p ($"{retDecl}{retOutput} = new TFOutput [_n];");
293338
pi ("for (int i = 0; i < _n; i++)");
294-
p ($"{ParamMap (arg.name)} [i] = new TFOutput (op, _idx++);");
339+
p ($"{retOutput} [i] = new TFOutput (op, _idx++);");
295340
pd ("");
296-
} else
297-
p ($"{ParamMap (arg.name)} = new TFOutput (op, _idx++);");
341+
if (return_is_tfoutput)
342+
p ($"return {retOutput};");
343+
} else {
344+
if (return_is_tfoutput) {
345+
p ($"return new TFOutput (op, _idx++);");
346+
} else {
347+
p ($"{retOutput} = new TFOutput (op, _idx++);");
348+
}
349+
}
298350
}
299351
} else {
300352
int idx = 0;
301353
foreach (var arg in oper.output_arg) {
302-
p ($"{ParamMap (arg.name)} = new TFOutput (op, {idx++});");
354+
if (return_is_tfoutput)
355+
p ($"return new TFOutput (op, 0);");
356+
else
357+
p ($"{ParamMap (arg.name)} = new TFOutput (op, {idx++});");
303358
}
304359
}
305-
p ("return op;");
360+
if (!return_is_tfoutput)
361+
p ("return op;");
306362
pd ("}\n");
307363
}
308364

SampleTest/SampleTest.cs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
using TensorFlow;
55
using System.IO;
66
using System.Collections.Generic;
7-
using Learn;
7+
using Learn.Mnist;
8+
using CsvHelper;
89

910
namespace SampleTest
1011
{
@@ -139,6 +140,32 @@ public void TestSession ()
139140
}
140141
}
141142

143+
public void TestOperationOutputListSize ()
144+
{
145+
using (var graph = new TFGraph ()) {
146+
var c1 = graph.Const (null, TFTensor.Constant (1L), "c1");
147+
var c2 = graph.Const (null, TFTensor.Constant (new long [,] { { 1, 2 }, { 3, 4 } }), "c2");
148+
149+
var outputs = graph.ShapeN (null, new TFOutput [] { c1, c2 });
150+
var op = outputs [0].Operation;
151+
152+
Assert (op.OutputListLength ("output") == 2);
153+
Assert (op.NumOutputs == 2);
154+
}
155+
}
156+
157+
public void TestOutputShape ()
158+
{
159+
using (var graph = new TFGraph ()) {
160+
var c1 = graph.Const (null, TFTensor.Constant (0L), "c1");
161+
var s1 = graph.GetShape (c1);
162+
var c2 = graph.Const (null, TFTensor.Constant (new long [] { 1, 2, 3 }), "c2");
163+
var s2 = graph.GetShape (c2);
164+
var c3 = graph.Const (null, TFTensor.Constant (new long [,] { { 1, 2, 3 }, { 4, 5, 6 } }), "c3");
165+
var s3 = graph.GetShape (c3);
166+
}
167+
}
168+
142169
// For this to work, we need to surface REGISTER_OP from C++ to C
143170

144171
class AttributeTest : IDisposable
@@ -194,6 +221,7 @@ public void AttributesTest ()
194221

195222
}
196223

224+
197225
public static void p (string p)
198226
{
199227
Console.WriteLine (p);
@@ -211,10 +239,21 @@ public static void Main (string [] args)
211239
var t = new MainClass ();
212240
t.TestImportGraphDef ();
213241
t.TestSession ();
242+
t.TestOperationOutputListSize ();
243+
244+
// Current failing test
245+
t.TestOutputShape ();
214246
//t.AttributesTest ();
215247

248+
216249
var n = new Mnist ();
217-
n.ReadDataSets ("/Users/miguel/Downloads");
250+
const int img_size = 28;
251+
const int img_size_flat = img_size * img_size;
252+
const int num_channels = 1; // black and white
253+
const int numClasses = 10;
254+
255+
//n.ReadDataSets ("/Users/miguel/Downloads", numClasses: numClasses);
256+
218257
}
219258
}
220259
}

SampleTest/SampleTest.csproj

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
<ItemGroup>
3030
<Reference Include="System" />
3131
<Reference Include="System.Xml" />
32+
<Reference Include="CsvHelper">
33+
<HintPath>..\packages\CsvHelper.2.16.3.0\lib\net45\CsvHelper.dll</HintPath>
34+
</Reference>
35+
<Reference Include="mscorlib" />
36+
<Reference Include="System.Core" />
37+
<Reference Include="System.Numerics" />
3238
</ItemGroup>
3339
<ItemGroup>
3440
<Compile Include="SampleTest.cs" />
@@ -44,5 +50,8 @@
4450
<Name>Learn</Name>
4551
</ProjectReference>
4652
</ItemGroup>
53+
<ItemGroup>
54+
<None Include="packages.config" />
55+
</ItemGroup>
4756
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
4857
</Project>

SampleTest/packages.config

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<packages>
3+
<package id="CsvHelper" version="2.16.3.0" targetFramework="net45" />
4+
</packages>

0 commit comments

Comments
 (0)