@@ -221,7 +221,73 @@ public static TFTensor FromBuffer (TFShape shape, Complex [] data, int start, in
221221 {
222222 return new TFTensor ( SetupTensor ( TFDataType . Complex128 , shape , data , start , count , size : 16 ) ) ;
223223 }
224-
224+
225+ /// <summary>
226+ /// Creates a constant tensor from an array, the shape reflects the shape of the C# array and the underlying type reflects the C# type.
227+ /// </summary>
228+ public unsafe TFTensor ( Array array )
229+ {
230+ if ( array == null )
231+ throw new ArgumentNullException ( nameof ( array ) ) ;
232+ // TODO: ensure that we do not have arrays of arrays.
233+ var t = array . GetType ( ) . GetElementType ( ) ;
234+ var tc = Type . GetTypeCode ( t ) ;
235+ TFDataType dt ;
236+ long size = 0 ;
237+ switch ( tc ) {
238+ case TypeCode . Boolean :
239+ dt = TFDataType . Bool ;
240+ size = 1 ;
241+ break ;
242+ case TypeCode . SByte :
243+ dt = TFDataType . Int8 ;
244+ size = 1 ;
245+ break ;
246+ case TypeCode . Byte :
247+ dt = TFDataType . UInt8 ;
248+ size = 1 ;
249+ break ;
250+ case TypeCode . Int16 :
251+ dt = TFDataType . Int16 ;
252+ size = 2 ;
253+ break ;
254+ case TypeCode . UInt16 :
255+ dt = TFDataType . UInt16 ;
256+ size = 2 ;
257+ break ;
258+ case TypeCode . Int32 :
259+ dt = TFDataType . Int32 ;
260+ size = 4 ;
261+ break ;
262+ case TypeCode . Int64 :
263+ dt = TFDataType . Int64 ;
264+ size = 8 ;
265+ break ;
266+ case TypeCode . Single :
267+ dt = TFDataType . Float ;
268+ size = 4 ;
269+ break ;
270+ case TypeCode . Double :
271+ dt = TFDataType . Double ;
272+ size = 8 ;
273+ break ;
274+ default :
275+ // Check types that are not handled by the typecode
276+ if ( t . IsAssignableFrom ( typeof ( Complex ) ) ) {
277+ size = 16 ;
278+ dt = TFDataType . Complex128 ;
279+ } else
280+ throw new ArgumentException ( $ "The data type { t } is not supported") ;
281+ break ;
282+ }
283+
284+ var dims = new long [ array . Rank ] ;
285+ for ( int i = 0 ; i < array . Rank ; i ++ ) {
286+ dims [ i ] = array . GetLength ( i ) ;
287+ size *= ( int ) dims [ i ] ;
288+ }
289+ handle = SetupMulti ( dt , dims , array , size ) ;
290+ }
225291
226292 /// <summary>
227293 /// Creates a constant tensor with a single dimension from an integer value.
@@ -541,67 +607,8 @@ unsafe public static implicit operator TFTensor (byte value)
541607 /// </remarks>
542608 unsafe public static implicit operator TFTensor ( Array array )
543609 {
544- if ( array == null )
545- throw new ArgumentNullException ( nameof ( array ) ) ;
546- // TODO: ensure that we do not have arrays of arrays.
547- var t = array . GetType ( ) . GetElementType ( ) ;
548- var tc = Type . GetTypeCode ( t ) ;
549- TFDataType dt ;
550- long size = 0 ;
551- switch ( tc ) {
552- case TypeCode . Boolean :
553- dt = TFDataType . Bool ;
554- size = 1 ;
555- break ;
556- case TypeCode . SByte :
557- dt = TFDataType . Int8 ;
558- size = 1 ;
559- break ;
560- case TypeCode . Byte :
561- dt = TFDataType . UInt8 ;
562- size = 1 ;
563- break ;
564- case TypeCode . Int16 :
565- dt = TFDataType . Int16 ;
566- size = 2 ;
567- break ;
568- case TypeCode . UInt16 :
569- dt = TFDataType . UInt16 ;
570- size = 2 ;
571- break ;
572- case TypeCode . Int32 :
573- dt = TFDataType . Int32 ;
574- size = 4 ;
575- break ;
576- case TypeCode . Int64 :
577- dt = TFDataType . Int64 ;
578- size = 8 ;
579- break ;
580- case TypeCode . Single :
581- dt = TFDataType . Float ;
582- size = 4 ;
583- break ;
584- case TypeCode . Double :
585- dt = TFDataType . Double ;
586- size = 8 ;
587- break ;
588- default :
589- // Check types that are not handled by the typecode
590- if ( t . IsAssignableFrom ( typeof ( Complex ) ) ) {
591- size = 16 ;
592- dt = TFDataType . Complex128 ;
593- } else
594- throw new ArgumentException ( $ "The data type { t } is not supported") ;
595- break ;
596- }
610+ return new TFTensor ( array ) ;
597611
598- var dims = new long [ array . Rank ] ;
599- for ( int i = 0 ; i < array . Rank ; i ++ ) {
600- dims [ i ] = array . GetLength ( i ) ;
601- size *= ( int ) dims [ i ] ;
602- }
603- var newTensor = new TFTensor ( SetupMulti ( dt , dims , array , size ) ) ;
604- return newTensor ;
605612 }
606613
607614 // General purpose constructor, specifies data type and gets pointer to buffer
0 commit comments