Skip to content

Commit 6bb28e9

Browse files
committed
TFTensor, add convenience boolean overloads
1 parent bd60f0e commit 6bb28e9

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

TensorFlowSharp/Tensor.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,16 @@ public unsafe TFTensor (int value)
233233
handle = TF_NewTensor (TFDataType.Int32, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (int), deallocator: FreeTensorData, deallocator_arg: IntPtr.Zero);
234234
}
235235

236+
/// <summary>
237+
/// Creates a constant tensor with a single dimension from a boolean value.
238+
/// </summary>
239+
public unsafe TFTensor (bool value)
240+
{
241+
var v = (bool*)Marshal.AllocHGlobal (sizeof (bool));
242+
*v = value;
243+
handle = TF_NewTensor (TFDataType.Bool, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (int), deallocator: FreeTensorData, deallocator_arg: IntPtr.Zero);
244+
}
245+
236246
/// <summary>
237247
/// Creates a constant tensor with a single dimension from an sbyte value.
238248
/// </summary>
@@ -315,11 +325,16 @@ public unsafe TFTensor (long value)
315325

316326
// Convenience, should I add T[,] and T[,,] as more convenience ones?
317327

328+
/// <summary>
329+
/// Creates a 1 dimensional tensor from an array of booleans.
330+
/// </summary>
331+
/// <param name="data">Data.</param>
332+
public TFTensor (bool [] data) : base (SetupTensor (TFDataType.Bool, data, size: 1)) { }
318333
/// <summary>
319334
/// Creates a 1 dimensional tensor from an array of sbytes.
320335
/// </summary>
321336
/// <param name="data">Data.</param>
322-
public TFTensor (sbyte [] data) : base (SetupTensor (TFDataType.Int8, data, size: 2)) { }
337+
public TFTensor (sbyte [] data) : base (SetupTensor (TFDataType.Int8, data, size: 1)) { }
323338
/// <summary>
324339
/// Creates a 1 dimensional tensor from an array of bytes.
325340
/// </summary>
@@ -453,6 +468,16 @@ public static implicit operator TFTensor (int value)
453468
return new TFTensor (value);
454469
}
455470

471+
/// <summary>
472+
/// Converts a boolean into a 1-dimensional, 1-valued tensor.
473+
/// </summary>
474+
/// <returns>The tensor representing the integer value.</returns>
475+
/// <param name="value">Value to initialize the tensor with.</param>
476+
public static implicit operator TFTensor (bool value)
477+
{
478+
return new TFTensor (value);
479+
}
480+
456481
/// <summary>
457482
/// Converts a long into a 1-dimensional, 1-valued tensor.
458483
/// </summary>

0 commit comments

Comments
 (0)