Skip to content

Commit 60800c9

Browse files
committed
TFTensor.GetValue(), tracked the source of errros - must support TF_STRING encoding
1 parent 34432e7 commit 60800c9

File tree

4 files changed

+213
-20
lines changed

4 files changed

+213
-20
lines changed

ExampleInceptionInference/Program.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public static void Main (string [] args)
106106
Environment.Exit (1);
107107
}
108108
int nlabels = (int) rshape [1];
109+
var val = result.GetValue ();
109110

110111
}
111112
}
@@ -116,7 +117,7 @@ static TFTensor CreateTensorFromImageFile (string file)
116117
var contents = File.ReadAllBytes (file);
117118

118119
// DecodeJpeg uses a scalar String-valued tensor as input.
119-
var tensor = (TFTensor) contents;
120+
var tensor = TFTensor.CreateString (contents);
120121

121122
TFGraph graph;
122123
TFOutput input, output;
@@ -158,16 +159,17 @@ static void ConstructGraphToNormalizeImage (out TFGraph graph, out TFOutput inpu
158159

159160
graph = new TFGraph ();
160161
input = graph.Placeholder (TFDataType.String);
162+
161163
output = graph.Div (
162164
x: graph.Sub (
163165
x: graph.ResizeBilinear (
164166
images: graph.ExpandDims (
165167
input: graph.Cast (
166168
graph.DecodeJpeg (contents: input, channels: 3), DstT: TFDataType.Float),
167169
dim: graph.Const (0, "make_batch")),
168-
size: graph.Const (new int [] { W, H })),
169-
y: graph.Const (Mean)),
170-
y: graph.Const (Scale));
170+
size: graph.Const (new int [] { W, H }, "size")),
171+
y: graph.Const (Mean, "mean")),
172+
y: graph.Const (Scale, "scale"));
171173
}
172174

173175
//

SampleTest/SampleTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public void TestSession ()
131131
var res = results [0];
132132
Assert (res.TensorType == TFDataType.Int32);
133133
Assert (res.NumDims == 0); // Scalar
134-
Assert (res.ByteSize == (UIntPtr) 4);
134+
Assert (res.TensorByteSize == (UIntPtr) 4);
135135
Assert (Marshal.ReadInt32 (res.Data) == 3 + 2);
136136

137137

TensorFlowSharp/TensorFlowSharp.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<OutputType>Library</OutputType>
88
<RootNamespace>TensorFlowSharp</RootNamespace>
99
<AssemblyName>TensorFlowSharp</AssemblyName>
10-
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
10+
<TargetFrameworkVersion>v4.6.1</TargetFrameworkVersion>
1111
</PropertyGroup>
1212
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
1313
<DebugSymbols>true</DebugSymbols>

TensorFlowSharp/Tensorflow.cs

Lines changed: 205 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public class TFStatus : TFDisposable
111111
{
112112
// extern TF_Status * TF_NewStatus ();
113113
[DllImport (NativeBinding.TensorFlowLibrary)]
114-
static extern unsafe TF_Status TF_NewStatus ();
114+
internal static extern unsafe TF_Status TF_NewStatus ();
115115

116116
[ThreadStatic] public static TFStatus Default = new TFStatus ();
117117

@@ -121,7 +121,7 @@ public TFStatus () : base (TF_NewStatus ())
121121

122122
// extern void TF_DeleteStatus (TF_Status *);
123123
[DllImport (NativeBinding.TensorFlowLibrary)]
124-
static extern unsafe void TF_DeleteStatus (TF_Status status);
124+
internal static extern unsafe void TF_DeleteStatus (TF_Status status);
125125

126126
internal override void NativeDispose (IntPtr handle)
127127
{
@@ -140,7 +140,7 @@ public void SetStatusCode (TFCode code, string msg)
140140

141141
// extern TF_Code TF_GetCode (const TF_Status *s);
142142
[DllImport (NativeBinding.TensorFlowLibrary)]
143-
static extern unsafe TFCode TF_GetCode (TF_Status s);
143+
internal static extern unsafe TFCode TF_GetCode (TF_Status s);
144144

145145
public TFCode StatusCode {
146146
get {
@@ -292,6 +292,7 @@ public byte [] ToArray ()
292292
}
293293
}
294294

295+
295296
public delegate void TFTensorDeallocator (IntPtr data, IntPtr size, IntPtr deallocatorData);
296297

297298
public class TFTensor : TFDisposable
@@ -317,6 +318,7 @@ internal static void FreeTensorHandle (IntPtr data, IntPtr len, IntPtr closure)
317318

318319
// TODO: Other overloads we could add: String, Complex (float), Bool, QInt8, QUInt8, QInt32, Bfloat16,
319320
// QInt16, QUint16, Half, Resource
321+
// TODO: not clear that this is very useful (the dims versions).
320322
public TFTensor (long [] dims, sbyte [] data, int start, int count) : base (SetupTensor (TFDataType.Int8, dims, data, start, count, size: 2)) { }
321323
public TFTensor (long [] dims, byte [] data, int start, int count) : base (SetupTensor (TFDataType.UInt8, dims, data, start, count, size: 1)) { }
322324
public TFTensor (long [] dims, short [] data, int start, int count) : base (SetupTensor (TFDataType.Int16, dims, data, start, count, size: 2)) { }
@@ -336,6 +338,32 @@ public TFTensor (long [] dims, double [] data) : base (SetupTensor (TFDataType.D
336338
public TFTensor (long [] dims, long [] data) : base (SetupTensor (TFDataType.Int64, dims, data, size: 8)) { }
337339
public TFTensor (long [] dims, Complex [] data) : base (SetupTensor (TFDataType.Complex128, dims, data, size: 16)) { }
338340

341+
public unsafe static TFTensor CreateString (byte [] buffer)
342+
{
343+
if (buffer == null)
344+
throw new ArgumentNullException (nameof (buffer));
345+
//
346+
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by
347+
// TF_StringEncode-encoded bytes.
348+
//
349+
var size = TFString.TF_StringEncodedSize ((UIntPtr) buffer.Length);
350+
IntPtr handle = TF_AllocateTensor (TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
351+
352+
// Clear offset table
353+
IntPtr dst = TF_TensorData (handle);
354+
Marshal.WriteInt64 (dst, 0);
355+
var status = TFStatus.TF_NewStatus ();
356+
fixed (byte *src = &buffer [0])
357+
{
358+
TFString.TF_StringEncode (src, (UIntPtr) buffer.Length, (sbyte *)(dst + 8), size, status);
359+
var ok = TFStatus.TF_GetCode (status) == TFCode.Ok;
360+
TFStatus.TF_DeleteStatus (status);
361+
if (!ok)
362+
return null;
363+
}
364+
return new TFTensor (handle);
365+
}
366+
339367
// Convenience function to factor out the setup of a new tensor from an array
340368
static IntPtr SetupTensor (TFDataType dt, long [] dims, Array data, int size)
341369
{
@@ -418,6 +446,7 @@ unsafe public static implicit operator TFTensor (Array array)
418446
{
419447
if (array == null)
420448
throw new ArgumentNullException (nameof (array));
449+
// TODO: ensure that we do not have arrays of arrays.
421450
var t = array.GetType ().GetElementType ();
422451
var tc = Type.GetTypeCode (t);
423452
TFDataType dt;
@@ -474,7 +503,6 @@ unsafe public static implicit operator TFTensor (Array array)
474503
size *= (int) dims [i];
475504
}
476505
var newTensor = new TFTensor (SetupMulti (dt, dims, array, size));
477-
var s = newTensor.Shape;
478506
return newTensor;
479507
}
480508

@@ -496,13 +524,15 @@ internal override void NativeDispose (IntPtr handle)
496524

497525
// extern TF_Tensor * TF_AllocateTensor (TF_DataType, const int64_t *dims, int num_dims, size_t len);
498526
[DllImport (NativeBinding.TensorFlowLibrary)]
499-
static extern unsafe TF_Tensor TF_AllocateTensor (TFDataType dataType, [In] ref long [] dims, int num_dims, IntPtr len);
527+
static extern unsafe TF_Tensor TF_AllocateTensor (TFDataType dataType, long [] dims, int num_dims, size_t len);
528+
[DllImport (NativeBinding.TensorFlowLibrary)]
529+
static extern unsafe TF_Tensor TF_AllocateTensor (TFDataType dataType, IntPtr zeroDim, int num_dims, size_t len);
500530

501531
public TFTensor (TFDataType dataType, long [] dims, int size) : base (IntPtr.Zero)
502532
{
503533
if (dims == null)
504534
throw new ArgumentNullException ("dims");
505-
handle = TF_AllocateTensor (dataType, ref dims, dims.Length, (IntPtr)size);
535+
handle = TF_AllocateTensor (dataType, dims, dims.Length, (size_t)size);
506536
}
507537

508538
// extern void TF_DeleteTensor (TF_Tensor *);
@@ -534,14 +564,30 @@ public long GetTensorDimension (int dimIndex)
534564
[DllImport (NativeBinding.TensorFlowLibrary)]
535565
static extern unsafe size_t TF_TensorByteSize (TF_Tensor tensor);
536566

537-
public size_t ByteSize => TF_TensorByteSize (handle);
567+
public size_t TensorByteSize => TF_TensorByteSize (handle);
538568

539569
// extern void * TF_TensorData (const TF_Tensor *);
540570
[DllImport (NativeBinding.TensorFlowLibrary)]
541571
static extern unsafe IntPtr TF_TensorData (TF_Tensor tensor);
542572

573+
/// <summary>
574+
/// Returns a pointer to the raw data in the tensor.
575+
/// </summary>
576+
/// <remarks>
577+
/// The contents of the Data must be interpreted according to the type of the
578+
/// data as described by the DataType property. The amount of data
579+
/// is given by the the TensorByteSize property.
580+
/// </remarks>
543581
public IntPtr Data => TF_TensorData (handle);
544582

583+
/// <summary>
584+
/// Returns the tensor shape, this is an array whose size determines the number of dimensions on the tensor, and each element is the size of the dimension
585+
/// </summary>
586+
/// <remarks>
587+
/// An array of size 0 is used for constants, an array of size 1 is used
588+
/// for single-dimension arrays, where the dimension is the value of the
589+
/// first element. And so on.
590+
/// </remarks>
545591
public long [] Shape {
546592
get {
547593
var dims = new long [TF_NumDims (handle)];
@@ -551,22 +597,167 @@ public long [] Shape {
551597
return dims;
552598
}
553599
}
600+
601+
static Type TypeFromTensorType (TFDataType type)
602+
{
603+
switch (type) {
604+
case TFDataType.Float:
605+
return typeof (float);
606+
case TFDataType.Double:
607+
return typeof (double);
608+
case TFDataType.Int32:
609+
return typeof (int);
610+
case TFDataType.UInt8:
611+
return typeof (byte);
612+
case TFDataType.Int16:
613+
return typeof (short);
614+
case TFDataType.Int8:
615+
return typeof (sbyte);
616+
case TFDataType.String:
617+
return typeof (TFString);
618+
case TFDataType.Int64:
619+
return typeof (long);
620+
case TFDataType.Bool:
621+
return typeof (bool);
622+
case TFDataType.UInt16:
623+
return typeof (ushort);
624+
case TFDataType.Complex128:
625+
return typeof (Complex);
626+
default:
627+
return null;
628+
}
629+
}
630+
631+
static unsafe object FetchSimple (TFDataType dt, IntPtr data)
632+
{
633+
switch (dt) {
634+
case TFDataType.Float:
635+
return *(float*)data;
636+
case TFDataType.Double:
637+
return *(double*)data;
638+
case TFDataType.Int32:
639+
return *(int*)data;
640+
case TFDataType.UInt8:
641+
return *(byte*)data;
642+
case TFDataType.Int16:
643+
return *(short*)data;
644+
case TFDataType.Int8:
645+
return *(sbyte*)data;
646+
case TFDataType.String:
647+
throw new NotImplementedException ();
648+
case TFDataType.Int64:
649+
return *(long*)data;
650+
case TFDataType.Bool:
651+
return *(bool*)data;
652+
case TFDataType.UInt16:
653+
return *(ushort*)data;
654+
case TFDataType.Complex128:
655+
return *(Complex*)data;
656+
default:
657+
return null;
658+
}
659+
}
660+
661+
unsafe static void Copy (IntPtr src, void* target, int size)
662+
{
663+
Buffer.MemoryCopy ((void*)src, target, size, size);
664+
}
665+
666+
static unsafe void FetchArray (Array target, TFDataType dt, IntPtr data)
667+
{
668+
int len = target.Length;
669+
switch (dt) {
670+
case TFDataType.Int8:
671+
var asbyte = (sbyte [])target;
672+
fixed (sbyte* p = &asbyte [0])
673+
Copy (data, p, len);
674+
return;
675+
case TFDataType.Bool:
676+
var abool = (bool [])target;
677+
fixed (bool* p = &abool [0])
678+
Copy (data, p, len);
679+
return;
680+
case TFDataType.UInt16:
681+
var aushort = (ushort [])target;
682+
fixed (ushort* p = &aushort [0])
683+
Copy (data, p, len * 2);
684+
return;
685+
case TFDataType.Complex128:
686+
var acomplex = (Complex [])target;
687+
fixed (Complex* p = &acomplex [0])
688+
Copy (data, p, len * sizeof (Complex));
689+
return;
690+
case TFDataType.Float:
691+
var afloat = (float [])target;
692+
fixed (float* p = &afloat [0])
693+
Copy (data, p, len * sizeof(float));
694+
return;
695+
case TFDataType.Double:
696+
var adouble = (double [])target;
697+
fixed (double* p = &adouble [0])
698+
Copy (data, p, len * sizeof (double));
699+
return;
700+
case TFDataType.Int32:
701+
var aint = (int [])target;
702+
fixed (int* p = &aint [0])
703+
Copy (data, p, len * sizeof (double));
704+
return;
705+
case TFDataType.UInt8:
706+
var abyte = (byte [])target;
707+
fixed (byte* p = &abyte [0])
708+
Copy (data, p, len * sizeof (byte));
709+
return;
710+
case TFDataType.Int16:
711+
var ashort = (short [])target;
712+
fixed (short* p = &ashort [0])
713+
Copy (data, p, len * sizeof (short));
714+
return;
715+
case TFDataType.Int64:
716+
var along = (long [])target;
717+
fixed (long* p = &along [0])
718+
Copy (data, p, len * sizeof (long));
719+
return;
720+
case TFDataType.String:
721+
// need to return an array of TFStrings []
722+
throw new NotImplementedException ();
723+
default:
724+
throw new NotImplementedException ();
725+
}
726+
}
727+
728+
/// <summary>
729+
/// Returns the value of the Tensor as a C# type if possible, or null if the data type can not be represented in C#
730+
/// </summary>
731+
/// <returns>The value encodes the contents of the tensor, and could include simple values, arrays and multi-dimensional values</returns>
732+
public object GetValue ()
733+
{
734+
var dims = NumDims;
735+
if (dims == 0)
736+
return FetchSimple (TensorType, Data);
737+
738+
var t = TypeFromTensorType (TensorType);
739+
if (t == null)
740+
return null;
741+
742+
var result = Array.CreateInstance (t, Shape);
743+
FetchArray (result, TensorType, Data);
744+
return result;
745+
}
554746
}
555747

556-
// TODO: All these
557-
static partial class NativeBinding
748+
public class TFString
558749
{
559750
// extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
560751
[DllImport (NativeBinding.TensorFlowLibrary)]
561-
static extern unsafe size_t TF_StringEncode (sbyte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);
562-
752+
internal static extern unsafe size_t TF_StringEncode (byte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);
753+
563754
// extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
564755
[DllImport (NativeBinding.TensorFlowLibrary)]
565-
static extern unsafe size_t TF_StringDecode (sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);
756+
internal static extern unsafe size_t TF_StringDecode (sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);
566757

567758
// extern size_t TF_StringEncodedSize (size_t len);
568759
[DllImport (NativeBinding.TensorFlowLibrary)]
569-
static extern size_t TF_StringEncodedSize (size_t len);
760+
internal static extern size_t TF_StringEncodedSize (size_t len);
570761
}
571762

572763
public class TFSessionOptions : TFDisposable
@@ -1800,7 +1991,7 @@ public TFInput [] OutputConsumers {
18001991
public TFOperation Operation => new TFOperation (null, LLOperation);
18011992
public override string ToString ()
18021993
{
1803-
return string.Format ("[TFOutput: LLOperation={0:X} Index={0} NumConsumers={0}, OutputType={1}, Operation={2}]", (long) LLOperation, Index, NumConsumers, OutputType, Operation);
1994+
return string.Format ("[TFOutput: LLOperation=0x{0:X} Index={1} Operation={2}]", (long) LLOperation, Index, Operation);
18041995
}
18051996
}
18061997

0 commit comments

Comments
 (0)