Skip to content

Commit 8dd21fc

Browse files
committed
TensorFlowSharp 1.11 API
1 parent 08002f5 commit 8dd21fc

File tree

3 files changed

+118
-55
lines changed

3 files changed

+118
-55
lines changed

TensorFlowSharp/Operations.g.cs

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4280,19 +4280,25 @@ public TFOutput CacheDataset (TFOutput input_dataset, TFOutput filename, TFDataT
42804280
/// <param name="operName">
42814281
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'Cast'.
42824282
/// </param>
4283+
/// <param name="Truncate">
4284+
/// Optional argument
4285+
/// </param>
42834286
/// <param name="DstT">
42844287
/// </param>
42854288
/// <returns>
42864289
/// The TFOperation can be fetched from the resulting TFOutput, by fethching the Operation property from the result.
42874290
/// </returns>
4288-
public TFOutput Cast (TFOutput x, TFDataType DstT, string operName = null)
4291+
public TFOutput Cast (TFOutput x, TFDataType DstT, bool? Truncate = null, string operName = null)
42894292
{
42904293
var desc = new TFOperationDesc (this, "Cast", MakeName ("Cast", operName));
42914294
desc.AddInput (x);
42924295
foreach ( TFOperation control in CurrentDependencies )
42934296
desc.AddControlInput (control);
42944297

42954298
desc.SetAttrType ("DstT", DstT);
4299+
if (Truncate.HasValue)
4300+
desc.SetAttr ("Truncate", Truncate.Value);
4301+
42964302
var op = desc.FinishOperation ();
42974303
int _idx = 0;
42984304
var y = new TFOutput (op, _idx++);
@@ -6017,14 +6023,14 @@ public TFOutput Cross (TFOutput a, TFOutput b, string operName = null)
60176023
/// <param name="input">
60186024
/// The local input to the sum.
60196025
/// </param>
6026+
/// <param name="group_assignment">
6027+
/// An int32 tensor with shape
6028+
/// [num_groups, num_replicas_per_group]. <c>group_assignment[i]</c> represents the
6029+
/// replica ids in the ith subgroup.
6030+
/// </param>
60206031
/// <param name="operName">
60216032
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'CrossReplicaSum'.
60226033
/// </param>
6023-
/// <param name="group_assignment">
6024-
/// Optional argument
6025-
/// The list of group ids. <c>group_assignment[i]</c> represents the
6026-
/// group id of replica i.
6027-
/// </param>
60286034
/// <returns>
60296035
/// The sum of all the distributed inputs.
60306036
/// The TFOperation can be fetched from the resulting TFOutput, by fethching the Operation property from the result.
@@ -6034,20 +6040,19 @@ public TFOutput Cross (TFOutput a, TFOutput b, string operName = null)
60346040
/// each is the sum of all the inputs, otherwise the output of each is the sum of
60356041
/// the inputs belonging to the same group.
60366042
///
6037-
/// For example, suppose there are 4 TPU instances: <c>[A, B, C, D]</c>. Passing
6038-
/// group_assignment=<c>[0,1,0,1]</c> sets <c>A, C</c> as group 0, and <c>B, D</c> as group 1.
6039-
/// Thus we get the outputs: <c>[A+C, B+D, A+C, B+D]</c>.
6043+
/// For example, suppose there are 8 TPU instances: <c>[A, B, C, D, E, F, G, H]</c>.
6044+
/// Passing group_assignment=<c>[[0,2,4,6],[1,3,5,7]]</c> sets <c>A, C, E, G</c> as group 0,
6045+
/// and <c>B, D, F, H</c> as group 1. Thus we get the outputs:
6046+
/// <c>[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]</c>.
60406047
/// </remarks>
6041-
public TFOutput CrossReplicaSum (TFOutput input, long[] group_assignment = null, string operName = null)
6048+
public TFOutput CrossReplicaSum (TFOutput input, TFOutput group_assignment, string operName = null)
60426049
{
60436050
var desc = new TFOperationDesc (this, "CrossReplicaSum", MakeName ("CrossReplicaSum", operName));
60446051
desc.AddInput (input);
6052+
desc.AddInput (group_assignment);
60456053
foreach ( TFOperation control in CurrentDependencies )
60466054
desc.AddControlInput (control);
60476055

6048-
if (group_assignment != null)
6049-
desc.SetAttr ("group_assignment", group_assignment);
6050-
60516056
var op = desc.FinishOperation ();
60526057
int _idx = 0;
60536058
var output = new TFOutput (op, _idx++);
@@ -30729,46 +30734,6 @@ public TFOutput Slice (TFOutput input, TFOutput begin, TFOutput size, string ope
3072930734
return output;
3073030735
}
3073130736

30732-
/// <summary>
30733-
/// Creates a dataset that passes a sliding window over <c>input_dataset</c>.
30734-
/// </summary>
30735-
/// <param name="input_dataset">
30736-
/// </param>
30737-
/// <param name="window_size">
30738-
/// A scalar representing the number of elements in the
30739-
/// sliding window.
30740-
/// </param>
30741-
/// <param name="stride">
30742-
/// A scalar representing the steps moving the sliding window
30743-
/// forward in one iteration. It must be in <c>[1, window_size)</c>.
30744-
/// </param>
30745-
/// <param name="operName">
30746-
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'SlideDataset'.
30747-
/// </param>
30748-
/// <param name="output_types">
30749-
/// </param>
30750-
/// <param name="output_shapes">
30751-
/// </param>
30752-
/// <returns>
30753-
/// The TFOperation can be fetched from the resulting TFOutput, by fethching the Operation property from the result.
30754-
/// </returns>
30755-
public TFOutput SlideDataset (TFOutput input_dataset, TFOutput window_size, TFOutput stride, TFDataType[] output_types, TFShape[] output_shapes, string operName = null)
30756-
{
30757-
var desc = new TFOperationDesc (this, "SlideDataset", MakeName ("SlideDataset", operName));
30758-
desc.AddInput (input_dataset);
30759-
desc.AddInput (window_size);
30760-
desc.AddInput (stride);
30761-
foreach ( TFOperation control in CurrentDependencies )
30762-
desc.AddControlInput (control);
30763-
30764-
desc.SetAttrType ("output_types", output_types);
30765-
desc.SetAttrShape ("output_shapes", output_shapes);
30766-
var op = desc.FinishOperation ();
30767-
int _idx = 0;
30768-
var handle = new TFOutput (op, _idx++);
30769-
return handle;
30770-
}
30771-
3077230737
/// <summary>
3077330738
/// Returns a copy of the input tensor.
3077430739
/// </summary>

TensorFlowSharp/TensorFlowSharp.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<DocumentationFile>bin\Debug\TensorFlowSharp.xml</DocumentationFile>
1010
<GenerateDocumentationFile Condition=" '$(Configuration)' == 'Release' ">true</GenerateDocumentationFile>
1111
<ReleaseVersion>0.2</ReleaseVersion>
12-
<TensorFlowRuntimeVersion>1.10.0</TensorFlowRuntimeVersion>
12+
<TensorFlowRuntimeVersion>1.11.0</TensorFlowRuntimeVersion>
1313
</PropertyGroup>
1414

1515
<PropertyGroup>
@@ -23,7 +23,7 @@
2323
<Description>.NET Bindings for TensorFlow</Description>
2424
<Owners>Miguel de Icaza</Owners>
2525
<Summary>.NET API for TensorFlow, Google's Machine Intelligence framework</Summary>
26-
<PackageReleaseNotes>1.10.0-pre1 adds support for the TensorFlow 1.10 release</PackageReleaseNotes>
26+
<PackageReleaseNotes>1.11.0 adds support for the TensorFlow 1.11 release</PackageReleaseNotes>
2727
</PropertyGroup>
2828

2929
<ItemGroup>

TensorFlowSharp/Tensorflow.cs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,47 @@ public static TFBuffer GetAllOpList ()
9999
return new TFBuffer (TF_GetAllOpList ());
100100
}
101101

102+
103+
[DllImport (NativeBinding.TensorFlowLibrary)]
104+
static extern unsafe IntPtr TF_GetAllRegisteredKernels (TF_Status status);
105+
106+
/// <summary>
107+
/// Returns a serialized KernelList protocol buffer containing KernelDefs for all registered kernels
108+
/// </summary>
109+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
110+
/// <returns>The all registered kernels.</returns>
111+
public static TFBuffer GetAllRegisteredKernels (TFStatus status = null)
112+
{
113+
var cstatus = TFStatus.Setup (status);
114+
115+
var r = TF_GetAllRegisteredKernels (cstatus.Handle);
116+
if (!cstatus.CheckMaybeRaise (status, last: false))
117+
return null;
118+
return new TFBuffer (r);
119+
}
120+
121+
[DllImport (NativeBinding.TensorFlowLibrary)]
122+
static extern unsafe IntPtr TF_GetRegisteredKernelsForOp (string name, TF_Status status);
123+
/// <summary>
124+
/// Returns a serialized KernelList protocol buffer containing KernelDefs for all
125+
/// kernels registered for the operation specified.
126+
/// </summary>
127+
/// <param name="name">The operation to look up.</param>
128+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
129+
/// <returns>The registered kernels for the specified operation.</returns>
130+
public static TFBuffer GetAllRegisteredKernels (string name, TFStatus status = null)
131+
{
132+
if (name == null)
133+
throw new ArgumentNullException (nameof (name));
134+
var cstatus = TFStatus.Setup (status);
135+
136+
var r = TF_GetRegisteredKernelsForOp (name, cstatus.Handle);
137+
if (!cstatus.CheckMaybeRaise (status, last: false))
138+
return null;
139+
return new TFBuffer (r);
140+
}
141+
142+
102143
static void CheckSize ()
103144
{
104145
unsafe {
@@ -1090,6 +1131,58 @@ public TFOutput [] AddGradients (TFOutput [] y, TFOutput [] x, TFOutput [] dx =
10901131
return ret;
10911132
}
10921133

1134+
[DllImport (NativeBinding.TensorFlowLibrary)]
1135+
static extern unsafe void TF_AddGradientsWithPrefix (TF_Graph graph, string prefix, TFOutput* ys, int ny, TFOutput* xs, int nx, TFOutput* dx, TF_Status status, TFOutput* dy);
1136+
/// <summary>
1137+
/// Adds a gradient: the operations needed to compute the partial derivatives of sum of <paramref name="y"/>` wrt to <paramref name="x"/>.
1138+
/// </summary>
1139+
/// <returns>The partial derivatives, the size of the array is the same as the length of the <paramref name="y"/> array.</returns>
1140+
/// <param name="prefix">names the scope into which all gradients operations are being added. This must be unique within
1141+
/// the provided graph otherwise this operation will fail. If the value is null, the default prefixing behaviour takes
1142+
/// place, see AddGradients for more details.
1143+
/// </param>
1144+
/// <param name="y">The y elements.</param>
1145+
/// <param name="x">The x elements.</param>
1146+
/// <param name="dx">Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. <paramref name="y"/> ).
1147+
/// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in <paramref name="y"/></param>
1148+
/// <param name="status">Status buffer, if specified a status code will be left here, if not specified, a <see cref="T:TensorFlow.TFException"/> exception is raised if there is an error.</param>
1149+
/// <remarks>
1150+
/// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z...
1151+
/// </remarks>
1152+
public TFOutput [] AddGradients (string prefix, TFOutput [] y, TFOutput [] x, TFOutput [] dx = null, TFStatus status = null)
1153+
{
1154+
if (y == null)
1155+
throw new ArgumentNullException (nameof (y));
1156+
if (x == null)
1157+
throw new ArgumentNullException (nameof (x));
1158+
if (dx != null) {
1159+
if (dx.Length != y.Length)
1160+
throw new ArgumentException ("If dx is not null, the size of the gradients must match the size of y", nameof (dx));
1161+
}
1162+
1163+
var cstatus = TFStatus.Setup (status);
1164+
1165+
var ret = new TFOutput [x.Length];
1166+
unsafe {
1167+
fixed (TFOutput* pret = &ret [0]) {
1168+
fixed (TFOutput* py = &y [0]) {
1169+
fixed (TFOutput* px = &x [0]) {
1170+
if (dx == null) {
1171+
TF_AddGradientsWithPrefix (handle, prefix, py, y.Length, px, x.Length, (TFOutput*)null, cstatus.Handle, pret);
1172+
} else {
1173+
fixed (TFOutput* pdx = &dx [0]) {
1174+
TF_AddGradientsWithPrefix (handle, prefix, py, y.Length, px, x.Length, pdx, cstatus.Handle, pret);
1175+
}
1176+
}
1177+
}
1178+
}
1179+
}
1180+
}
1181+
if (!cstatus.CheckMaybeRaise (status, last: false))
1182+
return null;
1183+
return ret;
1184+
}
1185+
10931186
[DllImport (NativeBinding.TensorFlowLibrary)]
10941187
static extern unsafe void TF_GraphCopyFunction (TF_Graph graph, TF_Function func, TF_Function grad, TF_Status status);
10951188

@@ -1388,6 +1481,11 @@ public TFFunction ImportFunctionDef (byte [] proto, TFStatus status = null)
13881481
return new TFFunction (handle);
13891482
}
13901483
}
1484+
1485+
[DllImport (NativeBinding.TensorFlowLibrary)]
1486+
static extern unsafe IntPtr TF_FunctionName (IntPtr handle);
1487+
1488+
public string Name => Marshal.PtrToStringAnsi (handle);
13911489
}
13921490

13931491
/// <summary>

0 commit comments

Comments
 (0)