@@ -1498,6 +1498,116 @@ string MakeUnique (string name)
14981498 return name + val ;
14991499 }
15001500
1501+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
1502+ unsafe extern static void TF_GraphImportGraphDefWithReturnOutputs (
1503+ TF_Graph graph , LLBuffer * graph_def ,
1504+ TF_ImportGraphDefOptions options , TFOutput * return_outputs ,
1505+ int num_return_outputs , TF_Status status ) ;
1506+
1507+ /// <summary>
1508+ /// Imports a graph serialized into the graph
1509+ /// </summary>
1510+ /// <param name="graphDef">Serialized graph definition (in protocol buffer format).</param>
1511+ /// <param name="options">Import options.</param>
1512+ /// <param name="returnOutputs">Array large enough to contain all the return options.</param>
1513+ /// <param name="status">Status, optional.</param>
1514+ public void ImportGraphDef ( TFBuffer graphDef , TFImportGraphDefOptions options , TFOutput [ ] returnOutputs , TFStatus status = null )
1515+ {
1516+ if ( handle == IntPtr . Zero )
1517+ ObjectDisposedException ( ) ;
1518+ if ( graphDef == null )
1519+ throw new ArgumentNullException ( nameof ( graphDef ) ) ;
1520+ if ( options == null )
1521+ throw new ArgumentNullException ( nameof ( options ) ) ;
1522+ var cstatus = TFStatus . Setup ( status ) ;
1523+
1524+ unsafe
1525+ {
1526+ if ( returnOutputs == null ) {
1527+ TF_GraphImportGraphDefWithReturnOutputs ( handle , graphDef . LLBuffer , options . handle , null , 0 , cstatus . handle ) ;
1528+ } else {
1529+ fixed ( TFOutput * first = & returnOutputs [ 0 ] )
1530+ {
1531+ TF_GraphImportGraphDefWithReturnOutputs ( handle , graphDef . LLBuffer , options . handle , first , returnOutputs . Length , cstatus . handle ) ;
1532+ }
1533+ }
1534+ }
1535+ }
1536+
1537+ [ StructLayout ( LayoutKind . Sequential ) ]
1538+ unsafe struct TFWhileParams
1539+ {
1540+ int ninputs ;
1541+ TF_Graph cond_graph ;
1542+ TFOutput * cond_inputs ;
1543+ TFOutput cond_output ;
1544+ TF_Graph * body_graph ;
1545+ TFOutput * body_inputs ;
1546+ TFOutput * body_outputs ;
1547+ IntPtr charPtrName ;
1548+ }
1549+
1550+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
1551+ static extern unsafe TFWhileParams TF_NewWhile ( TF_Graph g , TFOutput [ ] inputs , int ninputs , TF_Status status ) ;
1552+
1553+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
1554+ static extern void TF_AbortWhile ( ref TFWhileParams pars ) ;
1555+
1556+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
1557+ static extern unsafe void TF_FinishWhile ( ref TFWhileParams pars , TF_Status status , TFOutput [ ] outputs ) ;
1558+
1559+ /// <summary>
1560+ /// Signature of the method that will be invoked by the TFGraph.While method to construct a while loop
1561+ /// </summary>
1562+ /// <remarks>
1563+ /// The method should build up the condition on the conditionGraph and the body of the while
1564+ /// loop in the provided bodyGraph. It should set the condOutput to the value used as the
1565+ /// condition output and the array of values in bodyOutputs to the final outputs as well as the
1566+ /// name to be used, if not set, one will be assigned.
1567+ /// </remarks>
1568+ public delegate void WhileConstructor ( TFGraph conditionGraph , TFGraph bodyGraph , TFOutput [ ] condInputs , TFOutput [ ] bodyInputs , ref TFOutput condOutput , ref TFOutput [ ] bodyOutputs , ref string name ) ;
1569+
1570+ /// <summary>
1571+ /// Constructs a while loop with the specified inputs and a callback that composes the while loop
1572+ /// </summary>
1573+ /// <param name="inputs">Inputs.</param>
1574+ /// <param name="constructor">Callback method that fills out the various while loop parameters.</param>
1575+ /// <returns>
1576+ /// true on success, or false if it was not possible to create the while loop.
1577+ /// </returns>
1578+ public TFOutput [ ] While ( TFOutput [ ] inputs , WhileConstructor constructor , TFStatus status = null )
1579+ {
1580+ if ( handle == IntPtr . Zero )
1581+ ObjectDisposedException ( ) ;
1582+ if ( inputs == null )
1583+ throw new ArgumentNullException ( nameof ( inputs ) ) ;
1584+ if ( constructor == null )
1585+ throw new ArgumentNullException ( nameof ( constructor ) ) ;
1586+ var s = TFStatus . Setup ( status ) ;
1587+ var result = TF_NewWhile ( handle , inputs , inputs . Length , s . handle ) ;
1588+ if ( s . Error )
1589+ return null ;
1590+ try {
1591+ //
1592+ // Call constructor here
1593+ // Wrap the various TF_graphs (with owns=false)
1594+ // Marshal the condInputs, bodyInputs
1595+ //
1596+ // TODO:
1597+ throw new NotImplementedException ( ) ;
1598+
1599+ // On return, copy the condOutput and bodyOututs
1600+ // Set the name
1601+ var ret = new TFOutput [ inputs . Length ] ;
1602+ TF_FinishWhile ( ref result , s . handle , ret ) ;
1603+ return ret ;
1604+ } catch {
1605+ TF_AbortWhile ( ref result ) ;
1606+ return null ;
1607+ }
1608+ }
1609+
1610+
15011611 }
15021612
15031613 /// <summary>
@@ -2213,6 +2323,81 @@ public void SetPrefix (string prefix)
22132323 TF_ImportGraphDefOptionsSetPrefix ( handle , prefix ) ;
22142324 }
22152325
2326+ // extern void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions *opts, const char* src_name, int src_index, TF_Output dst);
2327+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
2328+ static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping ( TF_ImportGraphDefOptions opts , string src_name , int src_index , TFOutput dst ) ;
2329+
2330+
2331+ /// <summary>
2332+ /// Adds an input mapping from a source name and index to a destination output
2333+ /// </summary>
2334+ /// <param name="srcName">Source name.</param>
2335+ /// <param name="srcIndex">Source index (in the source).</param>
2336+ /// <param name="dst">Replacement value for the srcName:srcIndex.</param>
2337+ /// <remarks>
2338+ /// Set any imported nodes with input `src_name:src_index` to have that input
2339+ /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
2340+ /// `dst` references a node already existing in the graph being imported into.
2341+ /// </remarks>
2342+ public void AddInputMapping ( string srcName , int srcIndex , TFOutput dst )
2343+ {
2344+ if ( handle == IntPtr . Zero )
2345+ ObjectDisposedException ( ) ;
2346+ TF_ImportGraphDefOptionsAddInputMapping ( handle , srcName , srcIndex , dst ) ;
2347+ }
2348+
2349+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
2350+ extern static void TF_ImportGraphDefOptionsAddControlDependency ( TF_ImportGraphDefOptions opts , TF_Operation oper ) ;
2351+
2352+ /// <summary>
2353+ /// Cause the imported graph to have a control dependency on the provided operation.
2354+ /// </summary>
2355+ /// <param name="operation">This operation should exist in the graph being imported to.</param>
2356+ public void AddControlDependency ( TFOperation operation )
2357+ {
2358+ if ( operation == null )
2359+ throw new ArgumentNullException ( nameof ( operation ) ) ;
2360+ if ( handle == IntPtr . Zero )
2361+ ObjectDisposedException ( ) ;
2362+
2363+ TF_ImportGraphDefOptionsAddControlDependency ( handle , operation . handle ) ;
2364+ }
2365+
2366+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
2367+ extern static void TF_ImportGraphDefOptionsAddReturnOutput ( TF_ImportGraphDefOptions opts , string oper_name , int index ) ;
2368+
2369+ /// <summary>
2370+ /// Add an output in the graph definition to be returned via the return outputs parameter.
2371+ /// </summary>
2372+ /// <param name="operName">Operation name.</param>
2373+ /// <param name="index">Operation index.</param>
2374+ /// <remarks>
2375+ /// If the output is remapped via an input
2376+ /// mapping, the corresponding existing tensor in graph will be returned.
2377+ /// </remarks>
2378+ public void AddReturnOutput ( string operName , int index )
2379+ {
2380+ if ( operName == null )
2381+ throw new ArgumentNullException ( nameof ( operName ) ) ;
2382+ if ( handle == IntPtr . Zero )
2383+ ObjectDisposedException ( ) ;
2384+ TF_ImportGraphDefOptionsAddReturnOutput ( handle , operName , index ) ;
2385+ }
2386+
2387+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
2388+ extern static int TF_ImportGraphDefOptionsNumReturnOutputs ( TF_ImportGraphDefOptions opts ) ;
2389+
2390+ /// <summary>
2391+ /// Gets the number return outputs added via AddReturnOutput.
2392+ /// </summary>
2393+ /// <value>The number return outputs.</value>
2394+ public int NumReturnOutputs {
2395+ get {
2396+ if ( handle == IntPtr . Zero )
2397+ ObjectDisposedException ( ) ;
2398+ return TF_ImportGraphDefOptionsNumReturnOutputs ( handle ) ;
2399+ }
2400+ }
22162401
22172402 }
22182403
@@ -2854,4 +3039,6 @@ public override string ToString ()
28543039 }
28553040 }
28563041
3042+
3043+
28573044}
0 commit comments