Skip to content

Commit c541944

Browse files
committed
Make it simpler to spot the usage samples
1 parent 958a2c0 commit c541944

File tree

4 files changed

+308
-287
lines changed

4 files changed

+308
-287
lines changed

SampleTest/LowLevelTests.cs

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
//
2+
// Low-level tests
3+
//
4+
using System;
5+
using System.Runtime.CompilerServices;
6+
using System.Runtime.InteropServices;
7+
using TensorFlow;
8+
using System.IO;
9+
using System.Collections.Generic;
10+
using Learn.Mnist;
11+
using System.Linq;
12+
13+
namespace SampleTest
14+
{
15+
partial class MainClass
16+
{
17+
TFOperation Placeholder (TFGraph graph, TFStatus s)
18+
{
19+
var desc = new TFOperationDesc (graph, "Placeholder", "feed");
20+
desc.SetAttrType ("dtype", TFDataType.Int32);
21+
Console.WriteLine ("Handle: {0}", desc.Handle);
22+
var j = desc.FinishOperation ();
23+
Console.WriteLine ("FinishHandle: {0}", j.Handle);
24+
return j;
25+
}
26+
27+
TFOperation ScalarConst (int v, TFGraph graph, TFStatus status)
28+
{
29+
var desc = new TFOperationDesc (graph, "Const", "scalar");
30+
desc.SetAttr ("value", v, status);
31+
if (status.StatusCode != TFCode.Ok)
32+
return null;
33+
desc.SetAttrType ("dtype", TFDataType.Int32);
34+
return desc.FinishOperation ();
35+
}
36+
37+
TFOperation Add (TFOperation left, TFOperation right, TFGraph graph, TFStatus status)
38+
{
39+
var op = new TFOperationDesc (graph, "AddN", "add");
40+
41+
op.AddInputs (new TFOutput (left, 0), new TFOutput (right, 0));
42+
return op.FinishOperation ();
43+
}
44+
45+
public void TestImportGraphDef ()
46+
{
47+
var status = new TFStatus ();
48+
TFBuffer graphDef;
49+
50+
// Create graph with two nodes, "x" and "3"
51+
using (var graph = new TFGraph ()) {
52+
Assert (status);
53+
Placeholder (graph, status);
54+
Assert (graph ["feed"] != null);
55+
56+
ScalarConst (3, graph, status);
57+
Assert (graph ["scalar"] != null);
58+
59+
// Export to GraphDef
60+
graphDef = new TFBuffer ();
61+
graph.ToGraphDef (graphDef, status);
62+
Assert (status);
63+
};
64+
65+
// Import it again, with a prefix, in a fresh graph
66+
using (var graph = new TFGraph ()) {
67+
using (var options = new TFImportGraphDefOptions ()) {
68+
options.SetPrefix ("imported");
69+
graph.Import (graphDef, options, status);
70+
Assert (status);
71+
}
72+
graphDef.Dispose ();
73+
74+
var scalar = graph ["imported/scalar"];
75+
var feed = graph ["imported/feed"];
76+
Assert (scalar != null);
77+
78+
Assert (feed != null);
79+
80+
// Can add nodes to the imported graph without trouble
81+
Add (feed, scalar, graph, status);
82+
Assert (status);
83+
}
84+
}
85+
86+
public void TestSession ()
87+
{
88+
var status = new TFStatus ();
89+
using (var graph = new TFGraph ()) {
90+
var feed = Placeholder (graph, status);
91+
var two = ScalarConst (2, graph, status);
92+
var add = Add (feed, two, graph, status);
93+
Assert (status);
94+
95+
// Create a session for this graph
96+
using (var session = new TFSession (graph, status)) {
97+
Assert (status);
98+
99+
// Run the graph
100+
var inputs = new TFOutput [] {
101+
new TFOutput (feed, 0)
102+
};
103+
var input_values = new TFTensor [] {
104+
3
105+
};
106+
var add_output = new TFOutput (add, 0);
107+
var outputs = new TFOutput [] {
108+
add_output
109+
};
110+
111+
var results = session.Run (runOptions: null,
112+
inputs: inputs,
113+
inputValues: input_values,
114+
outputs: outputs,
115+
targetOpers: null,
116+
runMetadata: null,
117+
status: status);
118+
Assert (status);
119+
var res = results [0];
120+
Assert (res.TensorType == TFDataType.Int32);
121+
Assert (res.NumDims == 0); // Scalar
122+
Assert (res.TensorByteSize == (UIntPtr)4);
123+
Assert (Marshal.ReadInt32 (res.Data) == 3 + 2);
124+
125+
// Use runner API
126+
var runner = session.GetRunner ();
127+
runner.AddInput (new TFOutput (feed, 0), 3);
128+
runner.Fetch (add_output);
129+
results = runner.Run (status: status);
130+
res = results [0];
131+
Assert (res.TensorType == TFDataType.Int32);
132+
Assert (res.NumDims == 0); // Scalar
133+
Assert (res.TensorByteSize == (UIntPtr)4);
134+
Assert (Marshal.ReadInt32 (res.Data) == 3 + 2);
135+
136+
137+
}
138+
}
139+
}
140+
141+
public void TestOperationOutputListSize ()
142+
{
143+
using (var graph = new TFGraph ()) {
144+
var c1 = graph.Const (1L, "c1");
145+
var cl = graph.Const (new int [] { 1, 2 }, "cl");
146+
var c2 = graph.Const (new long [,] { { 1, 2 }, { 3, 4 } }, "c2");
147+
148+
var outputs = graph.ShapeN (new TFOutput [] { c1, c2 });
149+
var op = outputs [0].Operation;
150+
151+
Assert (op.OutputListLength ("output") == 2);
152+
Assert (op.NumOutputs == 2);
153+
}
154+
}
155+
156+
public void TestOutputShape ()
157+
{
158+
using (var graph = new TFGraph ()) {
159+
var c1 = graph.Const (0L, "c1");
160+
var s1 = graph.GetShape (c1);
161+
var c2 = graph.Const (new long [] { 1, 2, 3 }, "c2");
162+
var s2 = graph.GetShape (c2);
163+
var c3 = graph.Const (new long [,] { { 1, 2, 3 }, { 4, 5, 6 } }, "c3");
164+
var s3 = graph.GetShape (c3);
165+
}
166+
}
167+
168+
class WhileTester : IDisposable
169+
{
170+
public TFStatus status;
171+
public TFGraph graph;
172+
public TFSession session;
173+
public TFSession.Runner runner;
174+
public TFOutput [] inputs, outputs;
175+
176+
public WhileTester ()
177+
{
178+
status = new TFStatus ();
179+
graph = new TFGraph ();
180+
}
181+
182+
public void Init (int ninputs, TFGraph.WhileConstructor constructor)
183+
{
184+
inputs = new TFOutput [ninputs];
185+
for (int i = 0; i < ninputs; ++i)
186+
inputs [i] = graph.Placeholder (TFDataType.Int32, operName: "p" + i);
187+
188+
Assert (status);
189+
outputs = graph.While (inputs, constructor, status);
190+
Assert (status);
191+
}
192+
193+
public TFTensor [] Run (params int [] inputValues)
194+
{
195+
Assert (inputValues.Length == inputs.Length);
196+
197+
session = new TFSession (graph);
198+
runner = session.GetRunner ();
199+
200+
for (int i = 0; i < inputs.Length; i++)
201+
runner.AddInput (inputs [i], (TFTensor)inputValues [i]);
202+
runner.Fetch (outputs);
203+
return runner.Run ();
204+
}
205+
206+
public void Dispose ()
207+
{
208+
if (session != null)
209+
session.Dispose ();
210+
if (graph != null)
211+
graph.Dispose ();
212+
}
213+
}
214+
215+
public void WhileTest ()
216+
{
217+
using (var j = new WhileTester ()) {
218+
219+
// Create loop: while (input1 < input2) input1 += input2 + 1
220+
j.Init (2, (TFGraph conditionGraph, TFOutput [] condInputs, out TFOutput condOutput, TFGraph bodyGraph, TFOutput [] bodyInputs, TFOutput [] bodyOutputs, out string name) => {
221+
Assert (bodyGraph.Handle != IntPtr.Zero);
222+
Assert (conditionGraph.Handle != IntPtr.Zero);
223+
224+
var status = new TFStatus ();
225+
var lessThan = conditionGraph.Less (condInputs [0], condInputs [1]);
226+
227+
Assert (status);
228+
condOutput = new TFOutput (lessThan.Operation, 0);
229+
230+
var add1 = bodyGraph.Add (bodyInputs [0], bodyInputs [1]);
231+
var one = bodyGraph.Const (1);
232+
var add2 = bodyGraph.Add (add1, one);
233+
bodyOutputs [0] = new TFOutput (add2, 0);
234+
bodyOutputs [1] = bodyInputs [1];
235+
236+
name = "Simple1";
237+
});
238+
239+
var res = j.Run (-9, 2);
240+
241+
Assert (3 == (int)res [0].GetValue ());
242+
Assert (2 == (int)res [1].GetValue ());
243+
};
244+
}
245+
246+
// For this to work, we need to surface REGISTER_OP from C++ to C
247+
248+
class AttributeTest : IDisposable
249+
{
250+
static int counter;
251+
public TFStatus Status;
252+
TFGraph graph;
253+
TFOperationDesc desc;
254+
255+
public AttributeTest ()
256+
{
257+
Status = new TFStatus ();
258+
graph = new TFGraph ();
259+
}
260+
261+
public TFOperationDesc Init (string op)
262+
{
263+
string opname = "AttributeTest";
264+
if (op.StartsWith ("list(")) {
265+
op = op.Substring (5, op.Length - 6);
266+
opname += "List";
267+
}
268+
opname += op;
269+
return new TFOperationDesc (graph, opname, "name" + (counter++));
270+
}
271+
272+
public void Dispose ()
273+
{
274+
graph.Dispose ();
275+
Status.Dispose ();
276+
}
277+
}
278+
279+
void ExpectMeta (TFOperation op, string name, int expectedListSize, TFAttributeType expectedType, int expectedTotalSize)
280+
{
281+
var meta = op.GetAttributeMetadata (name);
282+
Assert (meta.IsList == (expectedListSize >= 0 ? true : false));
283+
Assert (expectedListSize == meta.ListSize);
284+
Assert (expectedTotalSize == expectedTotalSize);
285+
Assert (expectedType == meta.Type);
286+
}
287+
288+
public void AttributesTest ()
289+
{
290+
using (var x = new AttributeTest ()) {
291+
var shape1 = new TFShape (new long [] { 1, 3 });
292+
var shape2 = new TFShape (2, 4, 6);
293+
var desc = x.Init ("list(shape)");
294+
desc.SetAttrShape ("v", new TFShape [] { shape1, shape2 });
295+
var op = desc.FinishOperation ();
296+
ExpectMeta (op, "v", 2, TFAttributeType.Shape, 5);
297+
}
298+
299+
}
300+
}
301+
}

SampleTest/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
This contains a simple test suite to exercise the low-level TensorFlowSharp API.
1+
This contains a simple test suite to exercise the low-level TensorFlowSharp API
2+
and ports of some simple examples on how to use the API.
3+
4+
The `LowLevelTests.cs` are the low-level tests, while the driver that
5+
shows how to use the API is in `SampleTest.cs`

0 commit comments

Comments
 (0)