1- using System ;
1+ //
2+ // This is the driver for the operation generator, this takes data that
3+ // is provided by the Tensorflow runtime to produce strongly-typed and
4+ // high level methods on the TFGraph class.
5+ //
6+ // The result is generated into a partial class that is lined with the
7+ // main TensorFlowSharp library
8+ //
9+ // Authors:
10+ // Miguel de Icaza
11+ //
12+ // Copyright 2017, the year of downfall, Microsoft Inc
13+ //
14+ #pragma warning disable RECS0063 // Warns when a culture-aware 'StartsWith' call is used by default.
15+
16+ using System ;
217using System . Collections . Generic ;
318using System . IO ;
419using ProtoBuf ;
924
1025class OpGenerator
1126{
12- StreamWriter output ;
13-
1427 //
1528 // Maps a TensorFlow type to a C# type
1629 //
@@ -46,6 +59,15 @@ string CSharpType (string tfType)
4659 return cstype + ( list ? "[]" : "" ) ;
4760 }
4861
62+ bool IsReferenceType ( string tfType )
63+ {
64+ if ( tfType . StartsWith ( "list(" ) )
65+ return true ;
66+ if ( tfType == "tensor" || tfType == "string" || tfType == "shape" )
67+ return true ;
68+ return false ;
69+ }
70+
4971 // Maps a parameter name to a C# acceptable name, to avoid clashes with
5072 // language keywords
5173 string ParamMap ( string paramName )
@@ -118,10 +140,14 @@ string FillArguments (OpDef def)
118140 sb . AppendFormat ( $ ", ref { type } { ParamMap ( arg . name ) } ") ;
119141 }
120142
121- // FIXME: finish this part
122143 int n = 0 ;
123- foreach ( var attr in optional_attrs )
124- sb . AppendFormat ( $ ", object optional{ n ++ } ") ;
144+ foreach ( var attr in optional_attrs ) {
145+ bool reftype = IsReferenceType ( attr . type ) ;
146+ var cstype = CSharpType ( attr . type ) ;
147+ var cstypesuffix = reftype ? "" : "?" ;
148+
149+ sb . AppendFormat ( $ ", { cstype } { cstypesuffix } { attr . name } = null") ;
150+ }
125151 return sb . ToString ( ) ;
126152 }
127153
@@ -158,13 +184,62 @@ void GenDocs (OpDef oper)
158184 p ( "/// <param name=\" operName\" >" ) ;
159185 p ( $ "/// If specified, the created operation in the graph will be this one, otherwise it will be named '{ oper . name } '.") ;
160186 p ( "/// </param>" ) ;
187+ foreach ( var attr in optional_attrs ) {
188+ if ( String . IsNullOrEmpty ( attr . description ) )
189+ continue ;
190+ p ( $ "/// <param name=\" { ParamMap ( attr . name ) } \" >") ;
191+ Comment ( "Optional argument" ) ;
192+ Comment ( attr . description ) ;
193+ p ( $ "/// </param>") ;
194+ }
195+
161196 if ( ! String . IsNullOrEmpty ( oper . description ) ) {
162197 p ( "/// <remarks>" ) ;
163198 Comment ( oper . description ) ;
164199 p ( "/// </remarks>" ) ;
165200 }
166201 }
167202
203+ void SetAttribute ( string type , string attrName , string csAttrName )
204+ {
205+ if ( type == "shape" ) {
206+ p ( $ "desc.SetAttrShape (\" { attrName } \" , { csAttrName } );") ;
207+ return ;
208+ }
209+ if ( type . StartsWith ( "list(shape" ) ) {
210+ p ( $ "desc.SetAttrShape (\" { attrName } \" , { csAttrName } );") ;
211+ return ;
212+ }
213+
214+ var cstype = CSharpType ( type ) ;
215+ switch ( cstype ) {
216+ case "long" :
217+ case "long[]" :
218+ case "string" :
219+ case "string[]" :
220+ case "float" :
221+ case "float[]" :
222+ case "bool" :
223+ case "bool[]" :
224+ p ( $ "desc.SetAttr (\" { attrName } \" , { csAttrName } );") ;
225+ break ;
226+ case "TFDataType" :
227+ case "TFDataType[]" :
228+ p ( $ "desc.SetAttrType (\" { attrName } \" , { csAttrName } );") ;
229+ break ;
230+
231+ // This should pass the cstatus, but requires the
232+ // function to take a TFStatus as well, so need to weave that
233+ // in the parameters
234+ case "TFTensor" :
235+ case "TFTensor[]" :
236+ p ( $ "desc.SetAttr (\" { attrName } \" , { csAttrName } /* cstatus */);") ;
237+ break ;
238+ default :
239+ throw new Exception ( "Unexpected type: " + cstype ) ;
240+ }
241+ }
242+
168243 /// <summary>
169244 /// Generate the specified oper.
170245 /// </summary>
@@ -190,33 +265,22 @@ void Generate (OpDef oper)
190265 // If we have attributes
191266 if ( required_attrs . Count > 0 || optional_attrs . Count > 0 ) {
192267 foreach ( var attr in required_attrs ) {
193- var cstype = CSharpType ( attr . type ) ;
194- switch ( cstype ) {
195- case "int" :
196- case "int[]" :
197- case "string" :
198- case "string[]" :
199- case "float" :
200- case "float[]" :
201- case "bool" :
202- case "bool[]" :
203- p ( $ "desc.SetAttr (\" { attr . name } \" , { ParamMap ( attr . name ) } );") ;
204- break ;
205- case "TFDataType" :
206- case "TFDataType[]" :
207- p ( $ "desc.SetAttrType (\" { attr . name } \" , { ParamMap ( attr . name ) } );") ;
208- break ;
209-
210- // This should pass the cstatus, but requires the
211- // function to take a TFStatus as well, so need to weave that
212- // in the parameters
213- case "TFTensor" :
214- case "TFTensor[]" :
215- p ( $ "desc.SetAttr (\" { attr . name } \" , { ParamMap ( attr . name ) } /* cstatus */);") ;
216- break ;
217- }
268+ SetAttribute ( attr . type , attr . name , ParamMap ( attr . name ) ) ;
269+ }
270+
271+ foreach ( var attr in optional_attrs ) {
272+ var reftype = IsReferenceType ( attr . type ) ;
273+ var csattr = ParamMap ( attr . name ) ;
274+ if ( reftype )
275+ pi ( $ "if ({ csattr } != null)") ;
276+ else
277+ pi ( $ "if ({ csattr } .HasValue)") ;
278+ SetAttribute ( attr . type , attr . name , csattr + ( reftype ? "" : ".Value" ) ) ;
279+ pd ( "" ) ;
280+
218281 }
219282 }
283+
220284 p ( "var op = desc.FinishOperation ();" ) ;
221285 if ( oper . output_arg . Any ( x => IsListArg ( x ) ) ) {
222286 p ( "int _idx = 0, _n = 0;" ) ;
@@ -248,6 +312,8 @@ void Run ()
248312 output = File . CreateText ( "../../../TensorFlowSharp/Operations.cs" ) ;
249313
250314 var operations = Serializer . Deserialize < List < OpDef > > ( new MemoryStream ( TFCore . GetAllOpList ( ) . ToArray ( ) ) ) ;
315+ p ( "using System;\n " ) ;
316+
251317 pi ( "namespace TensorFlow {" ) ;
252318 pi ( "public partial class TFGraph {" ) ;
253319 foreach ( var oper in operations ) {
@@ -256,16 +322,24 @@ void Run ()
256322 continue ;
257323
258324 // Ignore functions where we lack a C# type mapping
259- if ( oper . attr . Any ( attr => CSharpType ( attr . type ) == null ) )
325+ if ( oper . attr . Any ( attr => CSharpType ( attr . type ) == null ) ) {
326+ var attr = oper . attr . First ( a => CSharpType ( a . type ) == null ) ;
327+
328+ //Console.WriteLine ($"Skip: {oper.name} due to attribute ({attr.type} {attr.name}) lacking a mapping to C#");
260329 continue ;
330+ }
261331
262332 // Ignore reference types as well (per go's binding)
263- if ( oper . input_arg . Any ( ia => ia . is_ref ) )
333+ if ( oper . input_arg . Any ( ia => ia . is_ref ) ) {
334+ //Console.WriteLine ($"Skip: {oper.name} due to presence of an input argument that is a reference");
264335 continue ;
265-
336+ }
337+
266338 // Ignore reference types as well (per go's binding)
267- if ( oper . output_arg . Any ( ia => ia . is_ref ) )
339+ if ( oper . output_arg . Any ( ia => ia . is_ref ) ) {
340+ //Console.WriteLine ($"Skip: {oper.name} due to presence of an output argument that is a reference");
268341 continue ;
342+ }
269343
270344 // Undocumented operation, perhaps we should not surface
271345 if ( oper . summary == "" )
@@ -278,8 +352,12 @@ void Run ()
278352 output . Close ( ) ;
279353 }
280354
355+ // The output file
356+ StreamWriter output ;
357+
281358 int indent = 0 ;
282359
360+ // Convenience methods to generate output
283361 void pi ( string fmt , params object [ ] args )
284362 {
285363 p ( fmt , args ) ;
0 commit comments