Skip to content

Commit 5ba4051

Browse files
committed
Add Gradient support, add import options, RemapControlDependency
1 parent 40f0457 commit 5ba4051

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,4 @@ the terms of Apache 2 License, in particular all the generated documentation
158158
for the various operations that is generated by using the tensorflow reflection
159159
APIs.
160160

161-
Last API update: 998cb32c4f69c0a71faaa8d7ca5cc5bcd48a0585
161+
Last API update: a4b352bfddd518b540c30e456f3bc0027ba9351f

TensorFlowSharp/Tensorflow.cs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,55 @@ public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFSt
813813
}
814814
}
815815

816+
[DllImport (NativeBinding.TensorFlowLibrary)]
817+
static extern unsafe void TF_AddGradients (TF_Graph graph, TFOutput* ys, int ny, TFOutput* xs, int nx, TFOutput* dx, TF_Status status, TFOutput* dy);
818+
819+
/// <summary>
820+
/// Adds a gradient: the operations needed to compute the partial derivatives of sum of <paramref name="y"/>` wrt to <paramref name="x"/>.
821+
/// </summary>
822+
/// <returns>The partial derivatives, the size of the array is the same as the length of the <paramref name="y"/> array.</returns>
823+
/// <param name="y">The y elements.</param>
824+
/// <param name="x">The x elements.</param>
825+
/// <param name="dx">Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. <paramref name="y"/> ).
826+
/// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in <paramref name="y"/></param>
827+
/// <param name="status">Status.</param>
828+
/// <remarks>
829+
/// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z...
830+
/// </remarks>
831+
public TFOutput [] AddGradients (TFOutput [] y, TFOutput [] x, TFOutput [] dx = null, TFStatus status = null)
832+
{
833+
if (y == null)
834+
throw new ArgumentNullException (nameof (y));
835+
if (x == null)
836+
throw new ArgumentNullException (nameof (x));
837+
if (dx != null) {
838+
if (dx.Length != y.Length)
839+
throw new ArgumentException ("If dx is not null, the size of the gradients must match the size of y", nameof (dx));
840+
}
841+
842+
var cstatus = TFStatus.Setup (status);
816843

844+
var ret = new TFOutput [x.Length];
845+
unsafe
846+
{
847+
fixed (TFOutput* pret = &ret [0]) {
848+
fixed (TFOutput* py = &y [0]) {
849+
fixed (TFOutput* px = &x [0]) {
850+
if (dx == null) {
851+
TF_AddGradients (handle, py, y.Length, px, x.Length, (TFOutput*)null, status.Handle, pret);
852+
} else {
853+
fixed (TFOutput* pdx = &dx [0]) {
854+
TF_AddGradients (handle, py, y.Length, px, x.Length, pdx, status.Handle, pret);
855+
}
856+
}
857+
}
858+
}
859+
}
860+
}
861+
if (!cstatus.CheckMaybeRaise (status, last: false))
862+
return null;
863+
return ret;
864+
}
817865
}
818866

819867
//
@@ -1625,6 +1673,28 @@ public int NumReturnOutputs {
16251673
}
16261674
}
16271675

1676+
[DllImport (NativeBinding.TensorFlowLibrary)]
1677+
extern static void TF_ImportGraphDefOptionsRemapControlDependency (TF_ImportGraphDefOptions opts, string srcName, TF_Operation dst);
1678+
1679+
/// <summary>
1680+
/// Sets any imported nodes with a given control input to have it replaced with an operation
1681+
/// </summary>
1682+
/// <param name="srcName">Node in the graph to be imported.</param>
1683+
/// <param name="destination">References an operation that already exists in the graph being imported.</param>
1684+
/// <remarks>
1685+
/// Set any imported nodes with control input <paramref name="srcName"/> to have that input
1686+
/// replaced with <paramref name="dst"/>.
1687+
/// </remarks>
1688+
public void RemapControlDependency (string srcName, TFOperation destination)
1689+
{
1690+
if (srcName == null)
1691+
throw new ArgumentNullException (nameof (srcName));
1692+
if (destination == null)
1693+
throw new ArgumentNullException (nameof (destination));
1694+
if (destination.Handle == IntPtr.Zero)
1695+
throw new ObjectDisposedException (nameof (destination));
1696+
TF_ImportGraphDefOptionsRemapControlDependency (handle, srcName, destination.Handle);
1697+
}
16281698
}
16291699

16301700
/// <summary>

0 commit comments

Comments
 (0)