@@ -92,6 +92,7 @@ bool IsListArg (OpDef.ArgDef arg)
9292 //
9393 Dictionary < string , bool > inferred_input_args ;
9494 List < OpDef . AttrDef > required_attrs , optional_attrs ;
95+ bool return_is_tfoutput ;
9596
9697 void SetupArguments ( OpDef def )
9798 {
@@ -117,6 +118,22 @@ void SetupArguments (OpDef def)
117118 else
118119 optional_attrs . Add ( attr ) ;
119120 }
121+ // API: currently, if we have a single ref TFOutput result, we make the signature of the
122+ // function return that TFOutput instead of the TFOperation (as you can get the TFOperation
123+ // from the TFOutput anyways.
124+ //
125+ // When we move to tuples, we could probably put everything in a Tuple result, but for now
126+ // mult-return functions will just return all outputs on ref variables, instead of the first
127+ // as a ref, and the rest as TFOutputs.
128+ //
129+ // This means that we generate methods like this:
130+ // TFOutput Constant (....)
131+ // when there is a single output
132+ //
133+ // TFOperation Foo (..)
134+ // When there is no result or more than one result.
135+ return_is_tfoutput = def . output_arg . Count == 1 ;
136+
120137 }
121138
122139 // Generates arguments:
@@ -134,10 +151,12 @@ string FillArguments (OpDef def)
134151 foreach ( var attr in required_attrs )
135152 sb . AppendFormat ( $ ", { CSharpType ( attr . type ) } { ParamMap ( attr . name ) } ") ;
136153
137- foreach ( var arg in def . output_arg ) {
138- string type = "TFOutput" + ( IsListArg ( arg ) ? "[]" : "" ) ;
154+ if ( ! return_is_tfoutput ) {
155+ foreach ( var arg in def . output_arg ) {
156+ string type = "TFOutput" + ( IsListArg ( arg ) ? "[]" : "" ) ;
139157
140- sb . AppendFormat ( $ ", ref { type } { ParamMap ( arg . name ) } ") ;
158+ sb . AppendFormat ( $ ", ref { type } { ParamMap ( arg . name ) } ") ;
159+ }
141160 }
142161
143162 int n = 0 ;
@@ -161,6 +180,7 @@ void Comment (string text)
161180 }
162181 }
163182
183+
164184 // Produces the C# inline documentation
165185 void GenDocs ( OpDef oper )
166186 {
@@ -174,12 +194,14 @@ void GenDocs (OpDef oper)
174194 Comment ( input . description ) ;
175195 p ( $ "/// </param>") ;
176196 }
177- foreach ( var attr in oper . output_arg ) {
178- if ( String . IsNullOrEmpty ( attr . description ) )
179- continue ;
180- p ( $ "/// <param name=\" { ParamMap ( attr . name ) } \" >") ;
181- Comment ( attr . description ) ;
182- p ( $ "/// </param>") ;
197+ if ( ! return_is_tfoutput ) {
198+ foreach ( var attr in oper . output_arg ) {
199+ if ( String . IsNullOrEmpty ( attr . description ) )
200+ continue ;
201+ p ( $ "/// <param name=\" { ParamMap ( attr . name ) } \" >") ;
202+ Comment ( attr . description ) ;
203+ p ( $ "/// </param>") ;
204+ }
183205 }
184206 p ( "/// <param name=\" operName\" >" ) ;
185207 p ( $ "/// If specified, the created operation in the graph will be this one, otherwise it will be named '{ oper . name } '.") ;
@@ -193,6 +215,11 @@ void GenDocs (OpDef oper)
193215 p ( $ "/// </param>") ;
194216 }
195217
218+ if ( return_is_tfoutput ) {
219+ p ( $ "/// <returns>") ;
220+ Comment ( oper . output_arg . First ( ) . description ) ;
221+ p ( $ "/// </returns>") ;
222+ }
196223 if ( ! String . IsNullOrEmpty ( oper . description ) ) {
197224 p ( "/// <remarks>" ) ;
198225 Comment ( oper . description ) ;
@@ -246,15 +273,26 @@ void SetAttribute (string type, string attrName, string csAttrName)
246273 /// <param name="oper">Oper.</param>
247274 void Generate ( OpDef oper )
248275 {
276+
249277 SetupArguments ( oper ) ;
250278 GenDocs ( oper ) ;
251279
252280 var name = oper . name ;
281+ string retType ;
282+
283+ if ( return_is_tfoutput ) {
284+ if ( oper . output_arg . Any ( x => IsListArg ( x ) ) )
285+ retType = "TFOutput []" ;
286+ else
287+ retType = "TFOutput" ;
288+ } else
289+ retType = "TFOperation" ;
253290
254- p ( $ "public TFOperation { name } (Scope scope{ FillArguments ( oper ) } , string operName = null)") ;
291+
292+ p ( $ "public { retType } { name } (Scope scope{ FillArguments ( oper ) } , string operName = null)") ;
255293 pi ( "{" ) ;
256294 bool needStatus = required_attrs . Concat ( optional_attrs ) . Any ( attr => attr . type . Contains ( "TFTensor" ) ) ;
257- p ( $ "var desc = new TFOperationDesc (this, operName , operName == null ? \" { oper . name } \" : operName);") ;
295+ p ( $ "var desc = new TFOperationDesc (this, \" { oper . name } \" , operName == null ? \" { oper . name } \" : operName);") ;
258296 foreach ( var arg in oper . input_arg ) {
259297 if ( IsListArg ( arg ) )
260298 p ( $ "desc.AddInputs ({ ParamMap ( arg . name ) } );") ;
@@ -285,24 +323,42 @@ void Generate (OpDef oper)
285323 if ( oper . output_arg . Any ( x => IsListArg ( x ) ) ) {
286324 p ( "int _idx = 0, _n = 0;" ) ;
287325 foreach ( var arg in oper . output_arg ) {
288-
326+ string retDecl = "" , retOutput ;
327+
328+ if ( return_is_tfoutput ) {
329+ retDecl = "var " ;
330+ retOutput = "_ret" ;
331+ } else
332+ retOutput = ParamMap ( arg . name ) ;
333+
289334 if ( IsListArg ( arg ) ) {
290335 var outputs = new StringBuilder ( ) ;
291- p ( "_n = op.InputListLength (\" arg.name\" );" ) ;
292- p ( $ "{ ParamMap ( arg . name ) } = new TFOutput [_n];") ;
336+ p ( $ "_n = op.OutputListLength (\" { arg . name } \" );") ;
337+ p ( $ "{ retDecl } { retOutput } = new TFOutput [_n];") ;
293338 pi ( "for (int i = 0; i < _n; i++)" ) ;
294- p ( $ "{ ParamMap ( arg . name ) } [i] = new TFOutput (op, _idx++);") ;
339+ p ( $ "{ retOutput } [i] = new TFOutput (op, _idx++);") ;
295340 pd ( "" ) ;
296- } else
297- p ( $ "{ ParamMap ( arg . name ) } = new TFOutput (op, _idx++);") ;
341+ if ( return_is_tfoutput )
342+ p ( $ "return { retOutput } ;") ;
343+ } else {
344+ if ( return_is_tfoutput ) {
345+ p ( $ "return new TFOutput (op, _idx++);") ;
346+ } else {
347+ p ( $ "{ retOutput } = new TFOutput (op, _idx++);") ;
348+ }
349+ }
298350 }
299351 } else {
300352 int idx = 0 ;
301353 foreach ( var arg in oper . output_arg ) {
302- p ( $ "{ ParamMap ( arg . name ) } = new TFOutput (op, { idx ++ } );") ;
354+ if ( return_is_tfoutput )
355+ p ( $ "return new TFOutput (op, 0);") ;
356+ else
357+ p ( $ "{ ParamMap ( arg . name ) } = new TFOutput (op, { idx ++ } );") ;
303358 }
304359 }
305- p ( "return op;" ) ;
360+ if ( ! return_is_tfoutput )
361+ p ( "return op;" ) ;
306362 pd ( "}\n " ) ;
307363 }
308364
0 commit comments