@@ -111,7 +111,7 @@ public class TFStatus : TFDisposable
111111 {
112112 // extern TF_Status * TF_NewStatus ();
113113 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
114- static extern unsafe TF_Status TF_NewStatus ( ) ;
114+ internal static extern unsafe TF_Status TF_NewStatus ( ) ;
115115
116116 [ ThreadStatic ] public static TFStatus Default = new TFStatus ( ) ;
117117
@@ -121,7 +121,7 @@ public TFStatus () : base (TF_NewStatus ())
121121
122122 // extern void TF_DeleteStatus (TF_Status *);
123123 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
124- static extern unsafe void TF_DeleteStatus ( TF_Status status ) ;
124+ internal static extern unsafe void TF_DeleteStatus ( TF_Status status ) ;
125125
126126 internal override void NativeDispose ( IntPtr handle )
127127 {
@@ -140,7 +140,7 @@ public void SetStatusCode (TFCode code, string msg)
140140
141141 // extern TF_Code TF_GetCode (const TF_Status *s);
142142 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
143- static extern unsafe TFCode TF_GetCode ( TF_Status s ) ;
143+ internal static extern unsafe TFCode TF_GetCode ( TF_Status s ) ;
144144
145145 public TFCode StatusCode {
146146 get {
@@ -292,6 +292,7 @@ public byte [] ToArray ()
292292 }
293293 }
294294
295+
295296 public delegate void TFTensorDeallocator ( IntPtr data , IntPtr size , IntPtr deallocatorData ) ;
296297
297298 public class TFTensor : TFDisposable
@@ -317,6 +318,7 @@ internal static void FreeTensorHandle (IntPtr data, IntPtr len, IntPtr closure)
317318
318319 // TODO: Other overloads we could add: String, Complex (float), Bool, QInt8, QUInt8, QInt32, Bfloat16,
319320 // QInt16, QUint16, Half, Resource
321+ // TODO: not clear that this is very useful (the dims versions).
320322 public TFTensor ( long [ ] dims , sbyte [ ] data , int start , int count ) : base ( SetupTensor ( TFDataType . Int8 , dims , data , start , count , size : 2 ) ) { }
321323 public TFTensor ( long [ ] dims , byte [ ] data , int start , int count ) : base ( SetupTensor ( TFDataType . UInt8 , dims , data , start , count , size : 1 ) ) { }
322324 public TFTensor ( long [ ] dims , short [ ] data , int start , int count ) : base ( SetupTensor ( TFDataType . Int16 , dims , data , start , count , size : 2 ) ) { }
@@ -336,6 +338,32 @@ public TFTensor (long [] dims, double [] data) : base (SetupTensor (TFDataType.D
336338 public TFTensor ( long [ ] dims , long [ ] data ) : base ( SetupTensor ( TFDataType . Int64 , dims , data , size : 8 ) ) { }
337339 public TFTensor ( long [ ] dims , Complex [ ] data ) : base ( SetupTensor ( TFDataType . Complex128 , dims , data , size : 16 ) ) { }
338340
341+ public unsafe static TFTensor CreateString ( byte [ ] buffer )
342+ {
343+ if ( buffer == null )
344+ throw new ArgumentNullException ( nameof ( buffer ) ) ;
345+ //
346+ // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
347+ // TF_StringEncode-encoded bytes.
348+ //
349+ var size = TFString . TF_StringEncodedSize ( ( UIntPtr ) buffer . Length ) ;
350+ IntPtr handle = TF_AllocateTensor ( TFDataType . String , IntPtr . Zero , 0 , ( UIntPtr ) ( ( ulong ) size + 8 ) ) ;
351+
352+ // Clear offset table
353+ IntPtr dst = TF_TensorData ( handle ) ;
354+ Marshal . WriteInt64 ( dst , 0 ) ;
355+ var status = TFStatus . TF_NewStatus ( ) ;
356+ fixed ( byte * src = & buffer [ 0 ] )
357+ {
358+ TFString . TF_StringEncode ( src , ( UIntPtr ) buffer . Length , ( sbyte * ) ( dst + 8 ) , size , status ) ;
359+ var ok = TFStatus . TF_GetCode ( status ) == TFCode . Ok ;
360+ TFStatus . TF_DeleteStatus ( status ) ;
361+ if ( ! ok )
362+ return null ;
363+ }
364+ return new TFTensor ( handle ) ;
365+ }
366+
339367 // Convenience function to factor out the setup of a new tensor from an array
340368 static IntPtr SetupTensor ( TFDataType dt , long [ ] dims , Array data , int size )
341369 {
@@ -418,6 +446,7 @@ unsafe public static implicit operator TFTensor (Array array)
418446 {
419447 if ( array == null )
420448 throw new ArgumentNullException ( nameof ( array ) ) ;
449+ // TODO: ensure that we do not have arrays of arrays.
421450 var t = array . GetType ( ) . GetElementType ( ) ;
422451 var tc = Type . GetTypeCode ( t ) ;
423452 TFDataType dt ;
@@ -474,7 +503,6 @@ unsafe public static implicit operator TFTensor (Array array)
474503 size *= ( int ) dims [ i ] ;
475504 }
476505 var newTensor = new TFTensor ( SetupMulti ( dt , dims , array , size ) ) ;
477- var s = newTensor . Shape ;
478506 return newTensor ;
479507 }
480508
@@ -496,13 +524,15 @@ internal override void NativeDispose (IntPtr handle)
496524
497525 // extern TF_Tensor * TF_AllocateTensor (TF_DataType, const int64_t *dims, int num_dims, size_t len);
498526 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
499- static extern unsafe TF_Tensor TF_AllocateTensor ( TFDataType dataType , [ In ] ref long [ ] dims , int num_dims , IntPtr len ) ;
527+ static extern unsafe TF_Tensor TF_AllocateTensor ( TFDataType dataType , long [ ] dims , int num_dims , size_t len ) ;
528+ [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
529+ static extern unsafe TF_Tensor TF_AllocateTensor ( TFDataType dataType , IntPtr zeroDim , int num_dims , size_t len ) ;
500530
501531 public TFTensor ( TFDataType dataType , long [ ] dims , int size ) : base ( IntPtr . Zero )
502532 {
503533 if ( dims == null )
504534 throw new ArgumentNullException ( "dims" ) ;
505- handle = TF_AllocateTensor ( dataType , ref dims , dims . Length , ( IntPtr ) size ) ;
535+ handle = TF_AllocateTensor ( dataType , dims , dims . Length , ( size_t ) size ) ;
506536 }
507537
508538 // extern void TF_DeleteTensor (TF_Tensor *);
@@ -534,14 +564,30 @@ public long GetTensorDimension (int dimIndex)
534564 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
535565 static extern unsafe size_t TF_TensorByteSize ( TF_Tensor tensor ) ;
536566
537- public size_t ByteSize => TF_TensorByteSize ( handle ) ;
567+ public size_t TensorByteSize => TF_TensorByteSize ( handle ) ;
538568
539569 // extern void * TF_TensorData (const TF_Tensor *);
540570 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
541571 static extern unsafe IntPtr TF_TensorData ( TF_Tensor tensor ) ;
542572
573+ /// <summary>
574+ /// Returns a pointer to the raw data in the tensor.
575+ /// </summary>
576+ /// <remarks>
577+ /// The contents of the Data must be interpreted according to the type of the
578+ /// data as described by the DataType property. The amount of data
579+ /// is given by the the TensorByteSize property.
580+ /// </remarks>
543581 public IntPtr Data => TF_TensorData ( handle ) ;
544582
583+ /// <summary>
584+ /// Returns the tensor shape, this is an array whose size determines the number of dimensions on the tensor, and each element is the size of the dimension
585+ /// </summary>
586+ /// <remarks>
587+ /// An array of size 0 is used for constants, an array of size 1 is used
588+ /// for single-dimension arrays, where the dimension is the value of the
589+ /// first element. And so on.
590+ /// </remarks>
545591 public long [ ] Shape {
546592 get {
547593 var dims = new long [ TF_NumDims ( handle ) ] ;
@@ -551,22 +597,167 @@ public long [] Shape {
551597 return dims ;
552598 }
553599 }
600+
601+ static Type TypeFromTensorType ( TFDataType type )
602+ {
603+ switch ( type ) {
604+ case TFDataType . Float :
605+ return typeof ( float ) ;
606+ case TFDataType . Double :
607+ return typeof ( double ) ;
608+ case TFDataType . Int32 :
609+ return typeof ( int ) ;
610+ case TFDataType . UInt8 :
611+ return typeof ( byte ) ;
612+ case TFDataType . Int16 :
613+ return typeof ( short ) ;
614+ case TFDataType . Int8 :
615+ return typeof ( sbyte ) ;
616+ case TFDataType . String :
617+ return typeof ( TFString ) ;
618+ case TFDataType . Int64 :
619+ return typeof ( long ) ;
620+ case TFDataType . Bool :
621+ return typeof ( bool ) ;
622+ case TFDataType . UInt16 :
623+ return typeof ( ushort ) ;
624+ case TFDataType . Complex128 :
625+ return typeof ( Complex ) ;
626+ default :
627+ return null ;
628+ }
629+ }
630+
631+ static unsafe object FetchSimple ( TFDataType dt , IntPtr data )
632+ {
633+ switch ( dt ) {
634+ case TFDataType . Float :
635+ return * ( float * ) data ;
636+ case TFDataType . Double :
637+ return * ( double * ) data ;
638+ case TFDataType . Int32 :
639+ return * ( int * ) data ;
640+ case TFDataType . UInt8 :
641+ return * ( byte * ) data ;
642+ case TFDataType . Int16 :
643+ return * ( short * ) data ;
644+ case TFDataType . Int8 :
645+ return * ( sbyte * ) data ;
646+ case TFDataType . String :
647+ throw new NotImplementedException ( ) ;
648+ case TFDataType . Int64 :
649+ return * ( long * ) data ;
650+ case TFDataType . Bool :
651+ return * ( bool * ) data ;
652+ case TFDataType . UInt16 :
653+ return * ( ushort * ) data ;
654+ case TFDataType . Complex128 :
655+ return * ( Complex * ) data ;
656+ default :
657+ return null ;
658+ }
659+ }
660+
661+ unsafe static void Copy ( IntPtr src , void * target , int size )
662+ {
663+ Buffer . MemoryCopy ( ( void * ) src , target , size , size ) ;
664+ }
665+
666+ static unsafe void FetchArray ( Array target , TFDataType dt , IntPtr data )
667+ {
668+ int len = target . Length ;
669+ switch ( dt ) {
670+ case TFDataType . Int8 :
671+ var asbyte = ( sbyte [ ] ) target ;
672+ fixed ( sbyte * p = & asbyte [ 0 ] )
673+ Copy ( data , p , len ) ;
674+ return ;
675+ case TFDataType . Bool :
676+ var abool = ( bool [ ] ) target ;
677+ fixed ( bool * p = & abool [ 0 ] )
678+ Copy ( data , p , len ) ;
679+ return ;
680+ case TFDataType . UInt16 :
681+ var aushort = ( ushort [ ] ) target ;
682+ fixed ( ushort * p = & aushort [ 0 ] )
683+ Copy ( data , p , len * 2 ) ;
684+ return ;
685+ case TFDataType . Complex128 :
686+ var acomplex = ( Complex [ ] ) target ;
687+ fixed ( Complex * p = & acomplex [ 0 ] )
688+ Copy ( data , p , len * sizeof ( Complex ) ) ;
689+ return ;
690+ case TFDataType . Float :
691+ var afloat = ( float [ ] ) target ;
692+ fixed ( float * p = & afloat [ 0 ] )
693+ Copy ( data , p , len * sizeof ( float ) ) ;
694+ return ;
695+ case TFDataType . Double :
696+ var adouble = ( double [ ] ) target ;
697+ fixed ( double * p = & adouble [ 0 ] )
698+ Copy ( data , p , len * sizeof ( double ) ) ;
699+ return ;
700+ case TFDataType . Int32 :
701+ var aint = ( int [ ] ) target ;
702+ fixed ( int * p = & aint [ 0 ] )
703+ Copy ( data , p , len * sizeof ( double ) ) ;
704+ return ;
705+ case TFDataType . UInt8 :
706+ var abyte = ( byte [ ] ) target ;
707+ fixed ( byte * p = & abyte [ 0 ] )
708+ Copy ( data , p , len * sizeof ( byte ) ) ;
709+ return ;
710+ case TFDataType . Int16 :
711+ var ashort = ( short [ ] ) target ;
712+ fixed ( short * p = & ashort [ 0 ] )
713+ Copy ( data , p , len * sizeof ( short ) ) ;
714+ return ;
715+ case TFDataType . Int64 :
716+ var along = ( long [ ] ) target ;
717+ fixed ( long * p = & along [ 0 ] )
718+ Copy ( data , p , len * sizeof ( long ) ) ;
719+ return ;
720+ case TFDataType . String :
721+ // need to return an array of TFStrings []
722+ throw new NotImplementedException ( ) ;
723+ default :
724+ throw new NotImplementedException ( ) ;
725+ }
726+ }
727+
728+ /// <summary>
729+ /// Returns the value of the Tensor as a C# type if possible, or null if the data type can not be represented in C#
730+ /// </summary>
731+ /// <returns>The value encodes the contents of the tensor, and could include simple values, arrays and multi-dimensional values</returns>
732+ public object GetValue ( )
733+ {
734+ var dims = NumDims ;
735+ if ( dims == 0 )
736+ return FetchSimple ( TensorType , Data ) ;
737+
738+ var t = TypeFromTensorType ( TensorType ) ;
739+ if ( t == null )
740+ return null ;
741+
742+ var result = Array . CreateInstance ( t , Shape ) ;
743+ FetchArray ( result , TensorType , Data ) ;
744+ return result ;
745+ }
554746 }
555747
556- // TODO: All these
557- static partial class NativeBinding
748+ public class TFString
558749 {
559750 // extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
560751 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
561- static extern unsafe size_t TF_StringEncode ( sbyte * src , size_t src_len , sbyte * dst , size_t dst_len , TF_Status status ) ;
562-
752+ internal static extern unsafe size_t TF_StringEncode ( byte * src , size_t src_len , sbyte * dst , size_t dst_len , TF_Status status ) ;
753+
563754 // extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
564755 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
565- static extern unsafe size_t TF_StringDecode ( sbyte * src , size_t src_len , sbyte * * dst , size_t * dst_len , TF_Status status ) ;
756+ internal static extern unsafe size_t TF_StringDecode ( sbyte * src , size_t src_len , sbyte * * dst , size_t * dst_len , TF_Status status ) ;
566757
567758 // extern size_t TF_StringEncodedSize (size_t len);
568759 [ DllImport ( NativeBinding . TensorFlowLibrary ) ]
569- static extern size_t TF_StringEncodedSize ( size_t len ) ;
760+ internal static extern size_t TF_StringEncodedSize ( size_t len ) ;
570761 }
571762
572763 public class TFSessionOptions : TFDisposable
@@ -1800,7 +1991,7 @@ public TFInput [] OutputConsumers {
18001991 public TFOperation Operation => new TFOperation ( null , LLOperation ) ;
18011992 public override string ToString ( )
18021993 {
1803- return string . Format ( "[TFOutput: LLOperation={0:X} Index={0} NumConsumers={0}, OutputType={1}, Operation={2}]" , ( long ) LLOperation , Index , NumConsumers , OutputType , Operation ) ;
1994+ return string . Format ( "[TFOutput: LLOperation=0x {0:X} Index={1} Operation={2}]" , ( long ) LLOperation , Index , Operation ) ;
18041995 }
18051996 }
18061997
0 commit comments