Skip to content

Commit ae900e3

Browse files
zeahmed authored and migueldeicaza committed
Added methods to support string input and output from TensorFlow. (migueldeicaza#394)
1 parent d3c2537 commit ae900e3

3 files changed

Lines changed: 192 additions & 33 deletions

File tree

TensorFlowSharp/Tensor.cs

Lines changed: 137 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -494,37 +494,144 @@ public TFTensor (long [] data) : base (SetupTensor (TFDataType.Int64, data, size
494494
/// <param name="data">Data.</param>
495495
public TFTensor (Complex [] data) : base (SetupTensor (TFDataType.Complex128, data, size: 16)) { }
496496

497-
/// <summary>
498-
/// Creates a single-dimension tensor from a byte buffer. This is different than creating a tensor from a byte array that produces a tensor with as many elements as the byte array.
497+
/// <summary>
/// Creates a scalar (zero-dimension) string tensor from a byte buffer. This is different than creating a tensor from a byte array, which produces a tensor with as many elements as the byte array.
/// </summary>
/// <param name="buffer">The raw bytes of the string. May be empty, must not be null.</param>
/// <returns>The new tensor, or null if TensorFlow reports that the bytes could not be encoded.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="buffer"/> is null.</exception>
public static unsafe TFTensor CreateString(byte[] buffer)
{
    if (buffer == null)
        throw new ArgumentNullException(nameof(buffer));
    //
    // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
    // TF_StringEncode-encoded bytes.
    //
    var size = TFString.TF_StringEncodedSize((UIntPtr)buffer.Length);
    IntPtr handle = TF_AllocateTensor(TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

    // Wrap the native handle right away so that it can be released on the failure path below.
    var tensor = new TFTensor(handle);

    // Single entry in the offset table: the only string starts at offset 0.
    IntPtr dst = TF_TensorData(handle);
    Marshal.WriteInt64(dst, 0);

    // Taking the address of element 0 of an empty array throws, so pin a one-byte
    // placeholder instead; the length passed to the encoder is still buffer.Length (0).
    byte[] pinned = buffer.Length == 0 ? new byte[1] : buffer;

    using (var status = new TFStatus())
    {
        fixed (byte* src = pinned)
        {
            TFString.TF_StringEncode(src, (UIntPtr)buffer.Length, (byte*)(dst + 8), size, status.handle);
        }
        if (status.StatusCode != TFCode.Ok)
        {
            // Do not leak the allocated tensor when encoding fails.
            tensor.Dispose();
            return null;
        }
    }
    return tensor;
}
525+
526+
/// <summary>
527+
/// Converts a string tensor holding a single element into a byte buffer. The byte array can be further decoded into a string using an appropriate encoding scheme, e.g. UTF-8.
499528
/// </summary>
500-
public unsafe static TFTensor CreateString (byte [] buffer)
{
	if (buffer == null)
		throw new ArgumentNullException (nameof (buffer));
	//
	// TF_STRING tensors are encoded with a table of 8-byte offsets followed by
	// TF_StringEncode-encoded bytes.
	//
	var size = TFString.TF_StringEncodedSize ((UIntPtr)buffer.Length);
	IntPtr handle = TF_AllocateTensor (TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

	// Wrap the native handle right away so that it can be released if encoding fails.
	var tensor = new TFTensor (handle);

	// Single entry in the offset table: the only string starts at offset 0.
	IntPtr dst = TF_TensorData (handle);
	Marshal.WriteInt64 (dst, 0);

	// &buffer [0] throws for an empty array, so pin a one-byte placeholder instead;
	// the length passed to the encoder is still buffer.Length (0).
	byte [] pinned = buffer.Length == 0 ? new byte [1] : buffer;

	using (var status = new TFStatus ()) {
		fixed (byte* src = pinned) {
			TFString.TF_StringEncode (src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);
		}
		if (status.StatusCode != TFCode.Ok) {
			// Do not leak the allocated tensor when encoding fails.
			tensor.Dispose ();
			return null;
		}
	}
	return tensor;
}
525-
526-
// Convenience function to factor out the setup of a new tensor from an array
527-
static IntPtr SetupTensor (TFDataType dt, long [] dims, Array data, int size)
529+
public static unsafe byte[] DecodeString(TFTensor tensor)
{
    if (tensor == null)
        throw new ArgumentNullException(nameof(tensor));

    // Layout of a TF_STRING tensor buffer: a table of 8-byte offsets, then the
    // TF_StringEncode-encoded payloads:
    // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
    // Only one offset entry is skipped here, so only the first string is decoded.
    IntPtr data = TF_TensorData(tensor.handle);

    using (var status = new TFStatus())
    {
        IntPtr decoded = IntPtr.Zero;
        UIntPtr decodedLength = UIntPtr.Zero;

        // Step over the 8-byte offset entry and decode the string that follows it.
        TFString.TF_StringDecode((byte*)(data + 8), tensor.TensorByteSize - 8, (byte**)&decoded, &decodedLength, status.handle);

        // A decoding failure is reported to the caller as a null result.
        if (status.StatusCode != TFCode.Ok)
            return null;

        // Copy the decoded bytes out of native memory into a managed array.
        var result = new byte[(int)decodedLength];
        Marshal.Copy(decoded, result, 0, result.Length);
        return result;
    }
}
551+
552+
/// <summary>
/// Creates a multi-dimension string tensor from an array of byte buffers. The bytes for string[i] are represented as buffer[i][:].
/// </summary>
/// <param name="buffer">One byte buffer per tensor element, in the order the elements are laid out in the tensor. Individual buffers may be empty but not null.</param>
/// <param name="shape">Shape of the tensor to create; its element count must equal the number of buffers.</param>
/// <returns>The new tensor, or null if TensorFlow reports that one of the buffers could not be encoded.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="buffer"/> or <paramref name="shape"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when a buffer element is null, or when the shape does not describe exactly buffer.Length elements.</exception>
public static unsafe TFTensor CreateString(byte[][] buffer, TFShape shape)
{
    if (buffer == null)
        throw new ArgumentNullException(nameof(buffer));
    if (shape == null)
        throw new ArgumentNullException(nameof(shape));
    if (shape.dims == null)
        throw new ArgumentException("The shape must have known dimensions.", nameof(shape));

    // The offset table gets one entry per buffer, so the shape has to describe exactly
    // that many elements; otherwise the resulting tensor would be malformed.
    long elementCount = 1;
    foreach (long dim in shape.dims)
        elementCount *= dim;
    if (elementCount != buffer.Length)
        throw new ArgumentException("The number of elements described by the shape must match the number of byte buffers.", nameof(shape));

    //
    // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
    // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
    //
    // Sizes are accumulated as 64-bit values so that large inputs cannot overflow an int.
    ulong encodedSize = 0;
    for (int i = 0; i < buffer.Length; i++)
    {
        if (buffer[i] == null)
            throw new ArgumentException("Individual byte buffers must not be null.", nameof(buffer));
        encodedSize += TFString.TF_StringEncodedSize((UIntPtr)buffer[i].Length).ToUInt64();
    }
    ulong totalSize = encodedSize + (ulong)buffer.Length * 8;

    IntPtr handle = TF_AllocateTensor(TFDataType.String, shape.dims, shape.dims.Length, (UIntPtr)totalSize);

    // Wrap the native handle right away so that it can be released on the failure path below.
    var tensor = new TFTensor(handle);

    // pOffset walks the offset table, dst walks the encoded string area that follows it.
    IntPtr pOffset = TF_TensorData(handle);
    IntPtr dst = pOffset + buffer.Length * 8;
    ulong offset = 0;
    ulong remaining = encodedSize;

    // Taking the address of element 0 of an empty array throws, so empty strings are
    // pinned through this one-byte placeholder (the length passed to the encoder stays 0).
    var emptyPlaceholder = new byte[1];

    using (var status = new TFStatus())
    {
        for (int i = 0; i < buffer.Length; i++)
        {
            // Offsets are relative to the start of the encoded string area.
            Marshal.WriteInt64(pOffset, (long)offset);

            byte[] pinned = buffer[i].Length == 0 ? emptyPlaceholder : buffer[i];
            ulong written;
            fixed (byte* src = pinned)
            {
                written = TFString.TF_StringEncode(src, (UIntPtr)buffer[i].Length, (byte*)dst, (UIntPtr)remaining, status.handle).ToUInt64();
            }
            if (status.StatusCode != TFCode.Ok)
            {
                // Do not leak the allocated tensor when encoding fails.
                tensor.Dispose();
                return null;
            }

            pOffset += 8;
            dst = new IntPtr(dst.ToInt64() + (long)written);
            offset += written;
            remaining -= written;
        }
    }
    return tensor;
}
595+
596+
/// <summary>
/// Converts a multi-dimension string tensor into an array of byte buffers, one per tensor element. Each byte array can be further decoded into a string using an appropriate encoding scheme, e.g. UTF-8.
/// </summary>
/// <param name="tensor">The string tensor to decode.</param>
/// <returns>One byte array per element, in the order the elements are stored in the tensor, or null if the tensor contents could not be decoded.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tensor"/> is null.</exception>
public static unsafe byte[][] DecodeMultiDimensionString(TFTensor tensor)
{
    if (tensor == null)
        throw new ArgumentNullException(nameof(tensor));
    //
    // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
    // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
    //
    long count = 1;
    foreach (var dim in tensor.Shape)
        count *= dim;

    var result = new byte[count][];

    IntPtr tableStart = TF_TensorData(tensor.handle);
    long byteSize = (long)tensor.TensorByteSize;
    long tableSize = count * 8;

    // The buffer must at least be able to hold the offset table.
    if (tableSize > byteSize)
        return null;

    long dataSize = byteSize - tableSize;
    long dataStart = tableStart.ToInt64() + tableSize;
    long limit = tableStart.ToInt64() + byteSize;

    // A single status object is reused for every element instead of allocating one per iteration.
    using (var status = new TFStatus())
    {
        for (int i = 0; i < result.Length; i++)
        {
            // Locate each string through its offset-table entry (relative to the start of
            // the encoded string area) rather than assuming the strings are laid out
            // back-to-back in element order.
            long offset = Marshal.ReadInt64(tableStart, i * 8);
            if (offset < 0 || offset >= dataSize)
                return null;

            IntPtr src = new IntPtr(dataStart + offset);
            IntPtr dst = IntPtr.Zero;
            UIntPtr dstLen = UIntPtr.Zero;
            TFString.TF_StringDecode((byte*)src, (UIntPtr)(ulong)(limit - src.ToInt64()), (byte**)&dst, &dstLen, status.handle);
            if (status.StatusCode != TFCode.Ok)
                return null;

            // Copy the decoded bytes out of native memory into a managed array.
            result[i] = new byte[(int)dstLen];
            Marshal.Copy(dst, result[i], 0, result[i].Length);
        }
    }
    return result;
}
632+
633+
// Convenience overload that sets up a new tensor from a whole array: it forwards to the
// ranged SetupTensor overload with start 0 and a count of data.Length.
static IntPtr SetupTensor (TFDataType dt, long [] dims, Array data, int size)
	=> SetupTensor (dt, dims, data, start: 0, count: data.Length, size: size);

TensorFlowSharp/Tensorflow.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,11 @@ internal class TFString
456456
{
457457
// extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
458458
[DllImport (NativeBinding.TensorFlowLibrary)]
459-
internal static extern unsafe size_t TF_StringEncode (byte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);
459+
internal static extern unsafe size_t TF_StringEncode (byte* src, size_t src_len, byte* dst, size_t dst_len, TF_Status status);
460460

461461
// extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
462462
[DllImport (NativeBinding.TensorFlowLibrary)]
463-
internal static extern unsafe size_t TF_StringDecode (sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);
463+
internal static extern unsafe size_t TF_StringDecode (byte* src, size_t src_len, byte** dst, size_t* dst_len, TF_Status status);
464464

465465
// extern size_t TF_StringEncodedSize (size_t len);
466466
[DllImport (NativeBinding.TensorFlowLibrary)]

tests/TensorFlowSharp.Tests.CSharp/TensorTests.cs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using TensorFlow;
4+
using System.Text;
45
using Xunit;
56

67
namespace TensorFlowSharp.Tests.CSharp
@@ -60,5 +61,56 @@ public void Should_MultidimensionalAndJaggedBeEqual (Array jagged, Array multidi
6061
}
6162
}
6263

63-
}
64+
[Fact]
public void StringTestWithMultiDimStringTensorAsInputOutput()
{
    // Round-trips a 2x2 string tensor through an Identity op and checks every element.
    using (var graph = new TFGraph())
    using (var session = new TFSession(graph))
    {
        var W = graph.Placeholder(TFDataType.String, new TFShape(-1, 2));
        var identityW = graph.Identity(W);

        var dataW = new string[,] { { "This is fine.", "That's ok." }, { "This is fine.", "That's ok." } };
        int rows = dataW.GetLength(0);
        int cols = dataW.GetLength(1);

        // Flatten the 2-D string array row by row into UTF-8 byte buffers.
        var bytes = new byte[rows * cols][];
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                bytes[r * cols + c] = Encoding.UTF8.GetBytes(dataW[r, c]);
            }
        }
        var tensorW = TFTensor.CreateString(bytes, new TFShape(2,2));

        var outputTensor = session.Run(new TFOutput[] { W }, new TFTensor[] { tensorW }, new[] { identityW });

        // The decoded buffers come back in the same flattened order.
        var outputW = TFTensor.DecodeMultiDimensionString(outputTensor[0]);
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                Assert.Equal(dataW[r, c], Encoding.UTF8.GetString(outputW[r * cols + c]));
            }
        }
    }
}
90+
91+
[Fact]
public void StringTestWithMultiDimStringTensorAsInputAndScalarStringAsOutput()
{
    // Feeds a vector of '/'-separated strings, splits them in the graph and joins all
    // pieces with spaces, then decodes the single resulting string.
    using (var graph = new TFGraph())
    using (var session = new TFSession(graph))
    {
        var X = graph.Placeholder(TFDataType.String, new TFShape(-1));
        var delimiter = graph.Const(TFTensor.CreateString(Encoding.UTF8.GetBytes("/")));
        var indices = graph.Const(0);
        var Y = graph.ReduceJoin(graph.StringSplit(X, delimiter).values, indices, separator: " ");

        var dataX = new string[] { "Thank/you/very/much!.", "I/am/grateful/to/you.", "So/nice/of/you." };

        // One UTF-8 byte buffer per input string.
        var bytes = new byte[dataX.Length][];
        for (int i = 0; i < dataX.Length; i++)
        {
            bytes[i] = Encoding.UTF8.GetBytes(dataX[i]);
        }
        var tensorX = TFTensor.CreateString(bytes, new TFShape(3));

        var outputTensors = session.Run(new TFOutput[] { X }, new TFTensor[] { tensorX }, new[] { Y });

        var expected = string.Join(" ", dataX).Replace("/", " ");
        var outputY = Encoding.UTF8.GetString(TFTensor.DecodeString(outputTensors[0]));
        Assert.Equal(expected, outputY);
    }
}
115+
}
64116
}

0 commit comments

Comments
 (0)