Skip to content

Commit 8168578

Browse files
committed
Add Runner API
1 parent a57c593 commit 8168578

3 files changed

Lines changed: 137 additions & 22 deletions

File tree

OpGenerator/OpGenerator.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,19 +410,22 @@ void Run ()
410410
if (oper.attr.Any (attr => CSharpType (attr.type) == null)) {
411411
var attr = oper.attr.First (a => CSharpType (a.type) == null);
412412

413-
//Console.WriteLine ($"Skip: {oper.name} due to attribute ({attr.type} {attr.name}) lacking a mapping to C#");
413+
Console.WriteLine ($"SkipTYPE: {oper.name} due to attribute ({attr.type} {attr.name}) lacking a mapping to C#");
414414
continue;
415415
}
416416

417417
// Ignore reference types as well (per go's binding)
418418
if (oper.input_arg.Any (ia => ia.is_ref)) {
419-
//Console.WriteLine ($"Skip: {oper.name} due to presence of an input argument that is a reference");
419+
var pars = String.Join (", ", oper.input_arg.Where (x => x.is_ref).Select (x => $"{x.type} {x.name}"));
420+
Console.WriteLine ($"SkipInREF: {oper.name} parameters with is_ref: {pars}");
420421
continue;
421422
}
422423

423424
// Ignore reference types as well (per go's binding)
424425
if (oper.output_arg.Any (ia => ia.is_ref)) {
425-
//Console.WriteLine ($"Skip: {oper.name} due to presence of an output argument that is a reference");
426+
var pars = String.Join (", ", oper.input_arg.Where (x => x.is_ref).Select (x => $"{x.type} {x.name}"));
427+
Console.WriteLine ($"SkipOutREF: {oper.name} parameters with is_ref: {pars}");
428+
426429
continue;
427430
}
428431

SampleTest/SampleTest.cs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,9 @@ public void TestSession ()
113113
var input_values = new TFTensor [] {
114114
3
115115
};
116+
var add_output = new TFOutput (add, 0);
116117
var outputs = new TFOutput [] {
117-
new TFOutput (add, 0)
118-
};
119-
var output_values = new TFTensor [] {
120-
3
118+
add_output
121119
};
122120

123121
var results = session.Run ( runOptions: null,
@@ -134,6 +132,17 @@ public void TestSession ()
134132
Assert (res.TensorByteSize == (UIntPtr) 4);
135133
Assert (Marshal.ReadInt32 (res.Data) == 3 + 2);
136134

135+
// Use runner API
136+
var runner = session.GetRunner ();
137+
runner.AddInput (new TFOutput (feed, 0), 3);
138+
runner.Fetch (add_output);
139+
results = runner.Run (status: status);
140+
res = results [0];
141+
Assert (res.TensorType == TFDataType.Int32);
142+
Assert (res.NumDims == 0); // Scalar
143+
Assert (res.TensorByteSize == (UIntPtr)4);
144+
Assert (Marshal.ReadInt32 (res.Data) == 3 + 2);
145+
137146

138147
}
139148
}

TensorFlowSharp/Tensorflow.cs

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,13 @@ string MakeName (string operName, string userName)
14631463
}
14641464
}
14651465

1466+
/// <summary>
1467+
/// TFGraph name scope handle
1468+
/// </summary>
1469+
/// <remarks>
1470+
/// Instances of this class when disposed restore the CurrentNameScope to the
1471+
/// value they had when the TFGraph.WithScope method was called.
1472+
/// </remarks>
14661473
public class TFScope : IDisposable
14671474
{
14681475
TFGraph container;
@@ -1874,9 +1881,22 @@ public TFOperation FinishOperation (TFStatus status = null)
18741881
}
18751882
}
18761883

1884+
/// <summary>
1885+
/// Tensorflow operations attached to a <see cref="T:Tensorflow.TFGraph"/>.
1886+
/// </summary>
1887+
/// <remarks>
1888+
/// TFOperations are usually created by invoking one of the methods in
1889+
/// <see cref="T:Tensorflow.TFGraph"/>, but they can also be constructed
1890+
/// manually using the low-level <see cref="T:Tensorflow.TFOperationDesc"/> API.
1891+
/// </remarks>
18771892
public partial class TFOperation
18781893
{
18791894
internal IntPtr handle;
1895+
1896+
/// <summary>
1897+
/// Gets the handle to the unmanaged TF_Operation object.
1898+
/// </summary>
1899+
/// <value>The handle.</value>
18801900
public IntPtr Handle => handle;
18811901

18821902
// Pointer to the graph, to keep it from collecting if there are TFOperations alive.
@@ -1892,6 +1912,10 @@ internal TFOperation (TFGraph graph, IntPtr handle)
18921912
[DllImport (NativeBinding.TensorFlowLibrary)]
18931913
static extern unsafe IntPtr TF_OperationName (TF_Operation oper);
18941914

1915+
/// <summary>
1916+
/// The name for this operation/
1917+
/// </summary>
1918+
/// <value>The name.</value>
18951919
public string Name => handle == IntPtr.Zero ? "<ObjectDisposed>" : TF_OperationName (handle).GetStr ();
18961920

18971921
// extern const char * TF_OperationOpType (TF_Operation *oper);
@@ -1910,6 +1934,10 @@ internal TFOperation (TFGraph graph, IntPtr handle)
19101934
[DllImport (NativeBinding.TensorFlowLibrary)]
19111935
static extern unsafe int TF_OperationNumOutputs (TF_Operation oper);
19121936

1937+
/// <summary>
1938+
/// Gets the number of outputs on this operation.
1939+
/// </summary>
1940+
/// <value>The number outputs.</value>
19131941
public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs (handle);
19141942

19151943

@@ -1931,6 +1959,10 @@ public int OutputListLength (string argName, TFStatus status = null)
19311959
[DllImport (NativeBinding.TensorFlowLibrary)]
19321960
static extern unsafe int TF_OperationNumInputs (TF_Operation oper);
19331961

1962+
/// <summary>
1963+
/// Gets the number of inputs for this operation.
1964+
/// </summary>
1965+
/// <value>The number inputs.</value>
19341966
public int NumInputs => TF_OperationNumInputs (handle);
19351967

19361968

@@ -2081,6 +2113,14 @@ public TFAttributeMetadata GetAttributeMetadata (string attrName, TFStatus statu
20812113
// extern void TF_OperationToNodeDef (TF_Operation *oper, TF_Buffer *output_node_def, TF_Status *status);
20822114
[DllImport (NativeBinding.TensorFlowLibrary)]
20832115
static extern unsafe void TF_OperationToNodeDef (TF_Operation oper, LLBuffer* output_node_def, TF_Status status);
2116+
2117+
/// <summary>
2118+
/// Encodes the TFOperation as a protocol buffer payload
2119+
/// </summary>
2120+
/// <returns>The buffer with the encoded operation in the protocol buffer format.</returns>
2121+
/// <param name="status">Status.</param>
2122+
/// <remarks>
2123+
/// </remarks>
20842124
public TFBuffer ToNodeDef (TFStatus status = null)
20852125
{
20862126
if (handle == IntPtr.Zero)
@@ -2139,16 +2179,29 @@ public void SetPrefix (string prefix)
21392179

21402180
}
21412181

2182+
/// <summary>
2183+
/// Drives the execution of a graph
2184+
/// </summary>
2185+
/// <remarks>
2186+
/// This creates a new context to execute a TFGraph. You can use the
2187+
/// constructo to create an empty session, or you can load an existing
2188+
/// model using the FromSAvedModel static method in this class.
2189+
/// </remarks>
21422190
public class TFSession : TFDisposable
21432191
{
21442192
// extern TF_Session * TF_NewSession (TF_Graph *graph, const TF_SessionOptions *opts, TF_Status *status);
21452193
[DllImport (NativeBinding.TensorFlowLibrary)]
21462194
static extern unsafe TF_Session TF_NewSession (TF_Graph graph, TF_SessionOptions opts, TF_Status status);
2195+
TFGraph graph;
21472196

2148-
TFSession (IntPtr handle) : base (handle) { }
2197+
TFSession (IntPtr handle, TFGraph graph) : base (handle)
2198+
{
2199+
this.graph = graph;
2200+
}
21492201

21502202
public TFSession (TFGraph graph, TFSessionOptions sessionOptions, TFStatus status = null) : base (IntPtr.Zero)
21512203
{
2204+
this.graph = graph;
21522205
var cstatus = TFStatus.Setup (status);
21532206
var h = TF_NewSession (graph.handle, sessionOptions.handle, cstatus.handle);
21542207
cstatus.CheckMaybeRaise (status);
@@ -2157,6 +2210,7 @@ public TFSession (TFGraph graph, TFSessionOptions sessionOptions, TFStatus statu
21572210

21582211
public TFSession (TFGraph graph, TFStatus status = null) : base (IntPtr.Zero)
21592212
{
2213+
this.graph = graph;
21602214
var cstatus = TFStatus.Setup (status);
21612215
var empty = TFSessionOptions.TF_NewSessionOptions ();
21622216
var h = TF_NewSession (graph.handle, empty, cstatus.handle);
@@ -2186,8 +2240,9 @@ public TFSession FromSavedModel (TFSessionOptions sessionOptions, TFBuffer runOp
21862240
{
21872241
var h = TF_LoadSessionFromSavedModel (sessionOptions.handle, runOptions.LLBuffer, exportDir, tags, tags.Length, graph.handle, metaGraphDef.LLBuffer, cstatus.handle);
21882242

2189-
if (cstatus.CheckMaybeRaise (status))
2190-
return new TFSession (h);
2243+
if (cstatus.CheckMaybeRaise (status)) {
2244+
return new TFSession (h, graph);
2245+
}
21912246
}
21922247
return null;
21932248
}
@@ -2229,26 +2284,74 @@ internal override void NativeDispose (IntPtr handle)
22292284
[DllImport (NativeBinding.TensorFlowLibrary)]
22302285
static extern unsafe void TF_SessionRun (TF_Session session, LLBuffer* run_options, TFOutput [] inputs, TF_Tensor [] input_values, int ninputs, TFOutput [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, LLBuffer* run_metadata, TF_Status status);
22312286

2232-
#if false
2233-
public struct Input
2287+
public class Runner
22342288
{
2235-
public TFOutput InputTF;
2236-
public TFTensor Value;
2237-
public Input (TFOutput input, TFTensor value)
2289+
List<TFOutput> inputs = new List<TFOutput> (), outputs = new List<TFOutput> ();
2290+
List<TFTensor> inputValues = new List<TFTensor> ();
2291+
List<TFOperation> targets = new List<TFOperation> ();
2292+
TFSession session;
2293+
2294+
public Runner (TFSession session)
2295+
{
2296+
this.session = session;
2297+
}
2298+
2299+
public Runner AddInput (TFOutput input, TFTensor value)
2300+
{
2301+
if (value == null)
2302+
throw new ArgumentNullException (nameof (value));
2303+
inputs.Add (input);
2304+
inputValues.Add (value);
2305+
return this;
2306+
}
2307+
2308+
2309+
public Runner AddTarget (params TFOperation [] targets)
2310+
{
2311+
foreach (var t in targets)
2312+
this.targets.Add (t);
2313+
return this;
2314+
}
2315+
2316+
public Runner AddTarget (params string [] targetNames)
22382317
{
2239-
InputTF = input;
2240-
Value = value;
2318+
foreach (var tn in targetNames)
2319+
this.targets.Add (session.graph [tn]);
2320+
return this;
2321+
}
2322+
2323+
public Runner Fetch (string operation, int index = 0)
2324+
{
2325+
var op = session.graph [operation];
2326+
outputs.Add (op [index]);
2327+
return this;
2328+
}
2329+
2330+
public Runner Fetch (TFOutput output)
2331+
{
2332+
outputs.Add (output);
2333+
return this;
2334+
}
2335+
2336+
public TFTensor [] Run (TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null)
2337+
{
2338+
return session.Run (inputs.ToArray (), inputValues.ToArray (), outputs.ToArray (), targets.ToArray (), runMetadata, runOptions, status);
22412339
}
22422340
}
22432341

2244-
public TFTensor [] Run (IEnumerable<Input> x)
2342+
/// <summary>
2343+
/// Gets a new runner, this provides a simpler API to prepare the inputs to run on a session
2344+
/// </summary>
2345+
/// <returns>The runner.</returns>
2346+
/// <remarks>
2347+
/// The runner has a simple API that allows developers to call the AddTarget, AddInput, AddOutput and Fetch
2348+
/// to construct the parameters that will be passed to the TFSession.Run method.
2349+
/// </remarks>
2350+
public Runner GetRunner ()
22452351
{
2246-
// This API call would look liek this:
2247-
Run (new [] { new Input (default (TFOutput), null) , new Input (default (TFOutput), null)});
2352+
return new Runner (this);
22482353
}
22492354

2250-
#endif
2251-
22522355
public TFTensor [] Run (TFOutput [] inputs, TFTensor [] inputValues, TFOutput [] outputs, TFOperation [] targetOpers = null, TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null)
22532356
{
22542357
if (handle == IntPtr.Zero)

0 commit comments

Comments
 (0)