Skip to content

Commit ba0fb1b

Browse files
committed
Make TFOperationDesc API fluent, document Runner API, make Run(TFOperation) clear previous fetches and return a single-value, instead of an array
1 parent c3b1e5f commit ba0fb1b

File tree

4 files changed

+221
-42
lines changed

4 files changed

+221
-42
lines changed

Learn/Datasets/Helper.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,5 @@ public static Stream MaybeDownload (string urlBase, string trainDir, string file
1717
}
1818
return File.OpenRead (target);
1919
}
20-
21-
2220
}
2321
}

SampleTest/SampleTest.cs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//
1+
//
22
// This is just a dumping ground to exercise different capabilities
33
// of the API. Some idioms might be useful, some not, feel free to
44
//
@@ -58,12 +58,12 @@ void BasicConstantOps ()
5858

5959
// Add two constants
6060
var results = s.GetRunner ().Run (g.Add (a, b));
61-
var val = results [0].GetValue ();
61+
var val = results.GetValue ();
6262
Console.WriteLine ("a+b={0}", val);
6363

6464
// Multiply two constants
6565
results = s.GetRunner ().Run (g.Mul (a, b));
66-
Console.WriteLine ("a*b={0}", results [0].GetValue ());
66+
Console.WriteLine ("a*b={0}", results.GetValue ());
6767

6868
// TODO: API-wise, perhaps session.Run () can have a simple
6969
// overload where we only care about the fetched values,
@@ -92,13 +92,13 @@ void BasicVariables ()
9292
var runner = s.GetRunner ();
9393
runner.AddInput (var_a, new TFTensor ((short)3));
9494
runner.AddInput (var_b, new TFTensor ((short)2));
95-
Console.WriteLine ("a+b={0}", runner.Run (add) [0].GetValue ());
95+
Console.WriteLine ("a+b={0}", runner.Run (add).GetValue ());
9696

9797
runner = s.GetRunner ();
9898
runner.AddInput (var_a, new TFTensor ((short)3));
9999
runner.AddInput (var_b, new TFTensor ((short)2));
100100

101-
Console.WriteLine ("a*b={0}", runner.Run (mul) [0].GetValue ());
101+
Console.WriteLine ("a*b={0}", runner.Run (mul).GetValue ());
102102

103103
// TODO
104104
// Would be nice to have an API that allows me to pass the values at Run time, easily:
@@ -180,7 +180,7 @@ void BasicMatrix ()
180180
var product = g.MatMul (matrix1, matrix2);
181181

182182

183-
var result = s.GetRunner ().Run (product) [0];
183+
var result = s.GetRunner ().Run (product);
184184
Console.WriteLine ("Tensor ToString=" + result);
185185
Console.WriteLine ("Value [0,0]=" + ((double[,])result.GetValue ())[0,0]);
186186

@@ -200,6 +200,8 @@ int ArgMax (byte [,] array, int idx)
200200
return maxIdx;
201201
}
202202

203+
// This sample has a bug, I suspect the data loaded is incorrect, because the returned
204+
// values in distance is wrong, and so is the prediction computed from it.
203205
void NearestNeighbor ()
204206
{
205207
// Get the Mnist data
@@ -220,15 +222,17 @@ void NearestNeighbor ()
220222
using (var g = new TFGraph ()) {
221223
var s = new TFSession (g);
222224

223-
var xtr = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
224-
var xte = g.Placeholder (TFDataType.Float, new TFShape (784));
225+
226+
TFOutput xtr = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
227+
228+
TFOutput xte = g.Placeholder (TFDataType.Float, new TFShape (784));
225229

226230
// Nearest Neighbor calculation using L1 Distance
227231
// Calculate L1 Distance
228-
var distance = g.ReduceSum (g.Abs (g.Add (xtr, g.Neg (xte))), axis: g.Const (1));
232+
TFOutput distance = g.ReduceSum (g.Abs (g.Add (xtr, g.Neg (xte))), axis: g.Const (1));
229233

230234
// Prediction: Get min distance index (Nearest neighbor)
231-
var pred = g.ArgMin (distance, g.Const (0));
235+
TFOutput pred = g.ArgMin (distance, g.Const (0));
232236

233237
var accuracy = 0f;
234238
// Loop over the test data
@@ -237,11 +241,13 @@ void NearestNeighbor ()
237241

238242
// Get nearest neighbor
239243

240-
var result = runner.Fetch (pred).AddInput (xtr, Xtr).AddInput (xte, Xte [i].DataFloat).Run ();
244+
var result = runner.Fetch (pred).Fetch (distance).AddInput (xtr, Xtr).AddInput (xte, Xte [i].DataFloat).Run ();
245+
var r = result [0].GetValue ();
246+
var tr = result [1].GetValue ();
241247
var nn_index = (int)(long) result [0].GetValue ();
242248

243249
// Get nearest neighbor class label and compare it to its true label
244-
Console.WriteLine ($"Test {i}: Prediction: {nn_index} {ArgMax (Ytr, nn_index)} True class: {ArgMax (Yte, i)}");
250+
Console.WriteLine ($"Test {i}: Prediction: {ArgMax (Ytr, nn_index)} True class: {ArgMax (Yte, i)} (nn_index={nn_index})");
245251
if (ArgMax (Ytr, nn_index) == ArgMax (Yte, i))
246252
accuracy += 1f/ Xte.Length;
247253
}
@@ -300,6 +306,8 @@ public static void Main (string [] args)
300306

301307

302308
var t = new MainClass ();
309+
t.TestParametersWithIndexes ();
310+
t.AddControlInput ();
303311
t.TestImportGraphDef ();
304312
t.TestSession ();
305313
t.TestOperationOutputListSize ();

TensorFlowSharp/TensorFlowSharp.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
<AssemblyName>TensorFlowSharp</AssemblyName>
1111
<TargetFrameworkVersion>v4.6.1</TargetFrameworkVersion>
1212
<ReleaseVersion>0.2</ReleaseVersion>
13-
<PackOnBuild>true</PackOnBuild>
1413
<PackageId>TensorFlowSharp</PackageId>
1514
<PackageVersion>0.96</PackageVersion>
1615
<Authors>Miguel de Icaza</Authors>

0 commit comments

Comments
 (0)