@@ -813,7 +813,55 @@ public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFSt
813813 }
814814 }
815815
816+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
817+ static extern unsafe void TF_AddGradients ( TF_Graph graph , TFOutput * ys , int ny , TFOutput * xs , int nx , TFOutput * dx , TF_Status status , TFOutput * dy ) ;
818+
819+ /// <summary>
820+ /// Adds a gradient: the operations needed to compute the partial derivatives of sum of <paramref name="y"/>` wrt to <paramref name="x"/>.
821+ /// </summary>
822+ /// <returns>The partial derivatives, the size of the array is the same as the length of the <paramref name="y"/> array.</returns>
823+ /// <param name="y">The y elements.</param>
824+ /// <param name="x">The x elements.</param>
825+ /// <param name="dx">Initial gradients, which represent the symbolic partial derivatives of some loss function `L` w.r.t. <paramref name="y"/> ).
826+ /// If the parameter is null, the implementation will use dx for 'OnesLike' for all shapes in <paramref name="y"/></param>
827+ /// <param name="status">Status.</param>
828+ /// <remarks>
829+ /// d(y[0] + y[1]+ ...)/dx[0], d(y[0] + y[1] + ...)/dx[1]z...
830+ /// </remarks>
831+ public TFOutput [ ] AddGradients ( TFOutput [ ] y , TFOutput [ ] x , TFOutput [ ] dx = null , TFStatus status = null )
832+ {
833+ if ( y == null )
834+ throw new ArgumentNullException ( nameof ( y ) ) ;
835+ if ( x == null )
836+ throw new ArgumentNullException ( nameof ( x ) ) ;
837+ if ( dx != null ) {
838+ if ( dx . Length != y . Length )
839+ throw new ArgumentException ( "If dx is not null, the size of the gradients must match the size of y" , nameof ( dx ) ) ;
840+ }
841+
842+ var cstatus = TFStatus . Setup ( status ) ;
816843
844+ var ret = new TFOutput [ x . Length ] ;
845+ unsafe
846+ {
847+ fixed ( TFOutput * pret = & ret [ 0 ] ) {
848+ fixed ( TFOutput * py = & y [ 0 ] ) {
849+ fixed ( TFOutput * px = & x [ 0 ] ) {
850+ if ( dx == null ) {
851+ TF_AddGradients ( handle , py , y . Length , px , x . Length , ( TFOutput * ) null , status . Handle , pret ) ;
852+ } else {
853+ fixed ( TFOutput * pdx = & dx [ 0 ] ) {
854+ TF_AddGradients ( handle , py , y . Length , px , x . Length , pdx , status . Handle , pret ) ;
855+ }
856+ }
857+ }
858+ }
859+ }
860+ }
861+ if ( ! cstatus . CheckMaybeRaise ( status , last : false ) )
862+ return null ;
863+ return ret ;
864+ }
817865 }
818866
819867 //
@@ -1625,6 +1673,28 @@ public int NumReturnOutputs {
16251673 }
16261674 }
16271675
1676+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
1677+ extern static void TF_ImportGraphDefOptionsRemapControlDependency ( TF_ImportGraphDefOptions opts , string srcName , TF_Operation dst ) ;
1678+
1679+ /// <summary>
1680+ /// Sets any imported nodes with a given control input to have it replaced with an operation
1681+ /// </summary>
1682+ /// <param name="srcName">Node in the graph to be imported.</param>
1683+ /// <param name="destination">References an operation that already exists in the graph being imported.</param>
1684+ /// <remarks>
1685+ /// Set any imported nodes with control input <paramref name="srcName"/> to have that input
1686+ /// replaced with <paramref name="dst"/>.
1687+ /// </remarks>
1688+ public void RemapControlDependency ( string srcName , TFOperation destination )
1689+ {
1690+ if ( srcName == null )
1691+ throw new ArgumentNullException ( nameof ( srcName ) ) ;
1692+ if ( destination == null )
1693+ throw new ArgumentNullException ( nameof ( destination ) ) ;
1694+ if ( destination . Handle == IntPtr . Zero )
1695+ throw new ObjectDisposedException ( nameof ( destination ) ) ;
1696+ TF_ImportGraphDefOptionsRemapControlDependency ( handle , srcName , destination . Handle ) ;
1697+ }
16281698 }
16291699
16301700 /// <summary>
0 commit comments