Skip to content

Commit 3f49044

Browse files
committed
Add support for iOS (callbacks need to be flagged) + documentation
1 parent 5c591e7 commit 3f49044

4 files changed

Lines changed: 98 additions & 20 deletions

File tree

TensorFlowSharp/Buffer.cs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,26 @@
1111

1212
namespace TensorFlow
1313
{
14+
/// <summary>
15+
/// This attribute can be applied to callback functions that will be invoked
16+
/// from unmanaged code to managed code.
17+
/// </summary>
18+
/// <remarks>
19+
/// <code>
20+
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
21+
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
22+
/// </code>
23+
/// </remarks>
24+
public sealed class MonoPInvokeCallbackAttribute : Attribute
25+
{
26+
/// <summary>
27+
/// Use this constructor to annotate the type of the callback function that
28+
/// will be invoked from unmanaged code.
29+
/// </summary>
30+
/// <param name="t">T.</param>
31+
public MonoPInvokeCallbackAttribute (Type t) { }
32+
}
33+
1434
[StructLayout (LayoutKind.Sequential)]
1535
internal struct LLBuffer
1636
{
@@ -64,7 +84,14 @@ unsafe public TFBuffer () : base ((IntPtr)TF_NewBuffer ())
6484
/// <remarks>
6585
/// Methods of this signature are invoked with the data pointer and the
6686
/// lenght pointer when then TFBuffer no longer needs to hold on to the
67-
/// data.
87+
/// data. If you are using this on platforms with static compilation
88+
/// like iOS, you need to annotate your callback with the MonoPInvokeCallbackAttribute,
89+
/// like this:
90+
///
91+
/// <code>
92+
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
93+
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
94+
/// </code>
6895
/// </remarks>
6996
public delegate void BufferReleaseFunc (IntPtr data, IntPtr lenght);
7097

@@ -91,7 +118,8 @@ unsafe public TFBuffer (IntPtr buffer, long size, BufferReleaseFunc release) : b
91118
buf->data_deallocator = Marshal.GetFunctionPointerForDelegate (release);
92119
}
93120

94-
internal static void FreeBlock (IntPtr data, IntPtr lenght)
121+
[MonoPInvokeCallback (typeof (BufferReleaseFunc))]
122+
internal static void FreeBlock (IntPtr data, IntPtr length)
95123
{
96124
Marshal.FreeHGlobal (data);
97125
}

TensorFlowSharp/OperationsExtras.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,38 @@ public TFOutput Variable (TFOutput initialValue, out TFOperation init, out TFOut
9292
}
9393

9494
List<TFOperation> pending_init_variables;
95+
96+
/// <summary>
97+
/// Registers a specified variable as an initialization variable.
98+
/// </summary>
99+
/// <param name="variable">Variable to register.</param>
100+
/// <remarks>
101+
/// <para>
102+
/// This is a convenience method to track the variables that need to be initialized in the graph,
103+
/// you can retrieve the list of all those variables by calling the <see cref="M:TensorFlow.TFGraph.GetGlobalVariablesInitializer"/>
104+
/// which will return this list and clear the state at that point.
105+
/// </para>
106+
/// <para>
107+
/// You typically use this method from helper methods to register all the variables that you want
108+
/// initialized, and a higher level method will retrieve all these variables and initialize them
109+
/// at their convenience.
110+
/// </para>
111+
/// </remarks>
95112
public void AddInitVariable (TFOperation variable)
96113
{
97114
if (pending_init_variables == null)
98115
pending_init_variables = new List<TFOperation> ();
99116
pending_init_variables.Add (variable);
100117
}
101118

119+
/// <summary>
120+
/// Gets the list of all registered global variables.
121+
/// </summary>
122+
/// <returns>The array of variables that should be initialized.</returns>
123+
/// <remarks>
124+
/// After this method is invoked the list of pending initialization variables
125+
/// is cleared.
126+
/// </remarks>
102127
public TFOperation [] GetGlobalVariablesInitializer ()
103128
{
104129
var res = pending_init_variables.ToArray ();
@@ -171,8 +196,6 @@ public TFOutput Variable (TFOutput initialValue, string operName = null)
171196
//
172197
TFOutput ShapeTensorOutput (TFShape shape)
173198
{
174-
Array a;
175-
176199
if (shape.IsLongArray)
177200
return Const (shape.ToArray (), TFDataType.Int64);
178201
else

TensorFlowSharp/Tensor.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public class TFTensor : TFDisposable
4848
static extern unsafe TF_Tensor TF_NewTensor (TFDataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);
4949

5050
internal TFTensor (IntPtr handle) : base (handle) { }
51+
52+
[MonoPInvokeCallback (typeof (Deallocator))]
5153
internal static void FreeTensorData (IntPtr data, IntPtr len, IntPtr closure)
5254
{
5355
Marshal.FreeHGlobal (data);
@@ -624,7 +626,7 @@ unsafe public static implicit operator TFTensor (Array array)
624626
/// <param name="dims">Describes the tensor shape, an array that indicates .</param>
625627
/// <param name="data">Pointer to the raw data that will be used to initialize the tensor.</param>
626628
/// <param name="dataSize">The size of the data being passed in.</param>
627-
/// <param name="deallocator">Deallocator method, it is invoked when the tensor is destroyed to release the data pointed to by <paramref name="data"/>.</param>
629+
/// <param name="deallocator">Deallocator method, it is invoked when the tensor is destroyed to release the data pointed to by <paramref name="data"/>. On platforms like iOS (or other static compilation platforms), yiou must annotate the method specified in the deallocator with a <see cref="T:TensorFlow.MonoPInvokeCallbackAttribute"/>.</param>
628630
/// <param name="deallocatorData">An optional argument of data that is passed to the deallocator method when the tensor is destroyed, you can use this to pass context information.</param>
629631
public TFTensor (TFDataType dataType, long [] dims, IntPtr data, size_t dataSize, Deallocator deallocator, IntPtr deallocatorData) : base (IntPtr.Zero)
630632
{

TensorFlowSharp/Tensorflow.cs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ public class TFSessionOptions : TFDisposable
382382
[DllImport (NativeBinding.TensorFlowLibrary)]
383383
internal static extern unsafe TF_SessionOptions TF_NewSessionOptions ();
384384

385+
/// <summary>
386+
/// Initializes a new instance of the <see cref="T:TensorFlow.TFSessionOptions"/> class.
387+
/// </summary>
385388
public TFSessionOptions () : base (TF_NewSessionOptions ()) { }
386389

387390
// extern void TF_DeleteSessionOptions (TF_SessionOptions *);
@@ -485,6 +488,12 @@ internal override void NativeDispose (IntPtr handle)
485488
[DllImport (NativeBinding.TensorFlowLibrary)]
486489
static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TFOutput output, IntPtr dims, int num_dims, TF_Status status);
487490

491+
/// <summary>
492+
/// Sets the tensor shape of the tensor referenced by <paramref name="output"/> to the shape described by <paramref name="dims"/>.
493+
/// </summary>
494+
/// <param name="output">The tensor on which this method will operate in the graph.</param>
495+
/// <param name="dims">The tensor shape, specified as an array of dimensions.</param>
496+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
488497
public void SetTensorShape (TFOutput output, long [] dims, TFStatus status = null)
489498
{
490499
if (handle == IntPtr.Zero)
@@ -516,6 +525,13 @@ public int GetTensorNumDims (TFOutput output, TFStatus status = null)
516525
[DllImport (NativeBinding.TensorFlowLibrary)]
517526
static extern unsafe void TF_GraphGetTensorShape (TF_Graph graph, TFOutput output, long [] dims, int num_dims, TF_Status status);
518527

528+
/// <summary>
529+
/// Returns the shape of a tensor specified in <paramref name="output"/>.
530+
/// </summary>
531+
///
532+
/// <returns>The tensor shape. If the number of dimensions in the shape is unknown or the shape is, a scalar, the values in the array will be zero. Otherwise, each element of will be set corresponding to the size of the dimension. An unknown dimension is represented by -1.</returns>
533+
/// <param name="output">The tensor that you want to look up. </param>
534+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
519535
public long [] GetTensorShape (TFOutput output, TFStatus status = null)
520536
{
521537
if (handle == IntPtr.Zero)
@@ -535,6 +551,11 @@ public long [] GetTensorShape (TFOutput output, TFStatus status = null)
535551
[DllImport (NativeBinding.TensorFlowLibrary)]
536552
static extern unsafe void TF_GraphToGraphDef (TF_Graph graph, LLBuffer* output_graph_def, TF_Status status);
537553

554+
/// <summary>
555+
/// Write out a serialized representation of the graph (as a GraphDef protocol buffer message) into <paramref name="outputGraphDef"/>.
556+
/// </summary>
557+
/// <param name="outputGraphDef">Target buffer where the graphs is serialized into.</param>
558+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
538559
public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null)
539560
{
540561
if (handle == IntPtr.Zero)
@@ -560,7 +581,7 @@ public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null)
560581
/// <returns>The import.</returns>
561582
/// <param name="graphDef">A buffer containing the serialized graph.</param>
562583
/// <param name="prefix">A prefix that will be prepended to names of nodes in the <paramref name="graphDef"/> when they are imported into the graph.</param>
563-
/// <param name="status">Status buffer.</param>
584+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
564585
public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = null)
565586
{
566587
if (handle == IntPtr.Zero)
@@ -582,7 +603,7 @@ public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = nul
582603
/// <returns>The import.</returns>
583604
/// <param name="graphDef">A buffer containing the serialized graph.</param>
584605
/// <param name="options">Importing graph options.</param>
585-
/// <param name="status">Status buffer.</param>
606+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
586607
public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus status = null)
587608
{
588609
if (handle == IntPtr.Zero)
@@ -605,8 +626,8 @@ public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus
605626
/// </summary>
606627
/// <returns>The import.</returns>
607628
/// <param name="buffer">A byte array containing the serialized graph.</param>
608-
/// <param name="prefix">A prefix that will be prepended to names of nodes in the <paramref name="graphDef"/> when they are imported into the graph.</param>
609-
/// <param name="status">Status buffer.</param>
629+
/// <param name="prefix">A prefix that will be prepended to names of nodes in the graph when they are imported into the graph.</param>
630+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
610631
public void Import (byte [] buffer, string prefix = "", TFStatus status = null)
611632
{
612633
if (handle == IntPtr.Zero)
@@ -627,7 +648,7 @@ public void Import (byte [] buffer, string prefix = "", TFStatus status = null)
627648
/// <returns>The import.</returns>
628649
/// <param name="buffer">A byte array containing the serialized graph.</param>
629650
/// <param name="options">Importing graph options.</param>
630-
/// <param name="status">Status buffer.</param>
651+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
631652
public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus status = null)
632653
{
633654
if (handle == IntPtr.Zero)
@@ -686,7 +707,7 @@ public IEnumerable<TFOperation> GetEnumerator ()
686707
/// </summary>
687708
/// <returns>null for single dimension, .</returns>
688709
/// <param name="output">The output operation to probe.</param>
689-
/// <param name="status">Status.</param>
710+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
690711
public long [] GetShape (TFOutput output, TFStatus status = null)
691712
{
692713
if (handle == IntPtr.Zero)
@@ -794,7 +815,7 @@ unsafe extern static void TF_GraphImportGraphDefWithReturnOutputs (
794815
/// <param name="graphDef">Serialized graph definition (in protocol buffer format).</param>
795816
/// <param name="options">Import options.</param>
796817
/// <param name="returnOutputs">Array large enough to contain all the return options.</param>
797-
/// <param name="status">Status, optional.</param>
818+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
798819
public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, TFOutput [] returnOutputs, TFStatus status = null)
799820
{
800821
if (handle == IntPtr.Zero)
@@ -878,6 +899,7 @@ static unsafe TFOutput [] CopyFrom (TFOutput* ptr, int n)
878899
/// </summary>
879900
/// <param name="inputs">Inputs.</param>
880901
/// <param name="constructor">Callback method that fills out the various while loop parameters.</param>
902+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
881903
/// <returns>
882904
/// An array of TFOutputs from creating the While loop, or null if there is an error creating the
883905
/// while loop, or if the constructor raised an exception when it was invoked.
@@ -951,7 +973,7 @@ public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFSt
951973
/// <param name="x">The x elements.</param>
952974
/// <param name="dx">Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. <paramref name="y"/> ).
953975
/// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in <paramref name="y"/></param>
954-
/// <param name="status">Status.</param>
976+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
955977
/// <remarks>
956978
/// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z...
957979
/// </remarks>
@@ -1727,7 +1749,7 @@ public TFAttributeMetadata GetAttributeMetadata (string attrName, TFStatus statu
17271749
/// Encodes the TFOperation as a protocol buffer payload
17281750
/// </summary>
17291751
/// <returns>The buffer with the encoded operation in the protocol buffer format.</returns>
1730-
/// <param name="status">Status.</param>
1752+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
17311753
/// <remarks>
17321754
/// </remarks>
17331755
public TFBuffer ToNodeDef (TFStatus status = null)
@@ -1918,6 +1940,10 @@ public class TFSession : TFDisposable
19181940
[DllImport (NativeBinding.TensorFlowLibrary)]
19191941
static extern unsafe TF_Session TF_NewSession (TF_Graph graph, TF_SessionOptions opts, TF_Status status);
19201942

1943+
/// <summary>
1944+
/// Gets the graph associated with this TensorFlow session.
1945+
/// </summary>
1946+
/// <value>The graph.</value>
19211947
public TFGraph Graph { get; private set; }
19221948

19231949
TFSession (IntPtr handle, TFGraph graph) : base (handle)
@@ -2115,7 +2141,7 @@ TFOutput ParseOutput (string operation)
21152141
/// Adds the specified operation names as the ones to be retrieved.
21162142
/// </summary>
21172143
/// <returns>An instance to the runner, so you can easily chain the operations together.</returns>
2118-
/// <param name="targets">One or more target names.</param>
2144+
/// <param name="targetNames">One or more target names.</param>
21192145
public Runner AddTarget (params string [] targetNames)
21202146
{
21212147
foreach (var tn in targetNames)
@@ -2142,7 +2168,6 @@ public Runner Fetch (string operation, int index)
21422168
/// <returns>The instance of runner, to allow chaining operations.</returns>
21432169
/// <param name="operation">The name of the operation in the graph, which might be a simple name, or it might be name:index,
21442170
/// where the index is the .</param>
2145-
/// <param name="index">The index of the output in the operation.</param>
21462171
public Runner Fetch (string operation)
21472172
{
21482173
var op = ParseOutput (operation);
@@ -2199,7 +2224,7 @@ public Runner Fetch (params string [] outputs)
21992224
/// Execute the graph fragments necessary to compute all requested fetches.
22002225
/// </summary>
22012226
/// <returns>One TFTensor for each call to Fetch that you made, in the order that you made them.</returns>
2202-
/// <param name="status">Status.</param>
2227+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
22032228
public TFTensor [] Run (TFStatus status = null)
22042229
{
22052230
return session.Run (inputs.ToArray (), inputValues.ToArray (), outputs.ToArray (), targets.ToArray (), RunMetadata, RunOptions, status);
@@ -2209,7 +2234,7 @@ public TFTensor [] Run (TFStatus status = null)
22092234
/// Run the specified operation, by adding it implicity to the output, single return value
22102235
/// </summary>
22112236
/// <param name="operation">The output of the operation.</param>
2212-
/// <param name="status">Optional, status.</param>
2237+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
22132238
/// <remarks>
22142239
/// This method is a convenience method, and when you call it, it will clear any
22152240
/// calls that you might have done to Fetch() and use the specified operation to Fetch
@@ -2251,7 +2276,7 @@ public Runner GetRunner ()
22512276
/// <param name="targetOpers">Target operations to execute.</param>
22522277
/// <param name="runMetadata">Run metadata.</param>
22532278
/// <param name="runOptions">Run options.</param>
2254-
/// <param name="status">Status code.</param>
2279+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
22552280
public TFTensor [] Run (TFOutput [] inputs, TFTensor [] inputValues, TFOutput [] outputs, TFOperation [] targetOpers = null, TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null)
22562281
{
22572282
if (handle == IntPtr.Zero)
@@ -2333,7 +2358,7 @@ void IDisposable.Dispose ()
23332358
/// <param name="inputs">Inputs.</param>
23342359
/// <param name="outputs">Outputs.</param>
23352360
/// <param name="targetOpers">Target operations to run.</param>
2336-
/// <param name="status">Status.</param>
2361+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
23372362
public PartialRunToken PartialRunSetup (TFOutput [] inputs, TFOutput [] outputs, TFOperation [] targetOpers, TFStatus status = null)
23382363
{
23392364
if (handle == IntPtr.Zero)

0 commit comments

Comments
 (0)