Skip to content

Commit 1bed2e1

Browse files
committed
TFTensor constructor can now take arbitrary arrays
1 parent 7acec0d commit 1bed2e1

1 file changed

Lines changed: 68 additions & 61 deletions

File tree

TensorFlowSharp/Tensor.cs

Lines changed: 68 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,73 @@ public static TFTensor FromBuffer (TFShape shape, Complex [] data, int start, in
221221
{
222222
return new TFTensor (SetupTensor (TFDataType.Complex128, shape, data, start, count, size: 16));
223223
}
224-
224+
225+
/// <summary>
226+
/// Creates a constant tensor from an array, the shape reflects the shape of the C# array and the underlying type reflects the C# type.
227+
/// </summary>
228+
public unsafe TFTensor (Array array)
229+
{
230+
if (array == null)
231+
throw new ArgumentNullException (nameof (array));
232+
// TODO: ensure that we do not have arrays of arrays.
233+
var t = array.GetType ().GetElementType ();
234+
var tc = Type.GetTypeCode (t);
235+
TFDataType dt;
236+
long size = 0;
237+
switch (tc) {
238+
case TypeCode.Boolean:
239+
dt = TFDataType.Bool;
240+
size = 1;
241+
break;
242+
case TypeCode.SByte:
243+
dt = TFDataType.Int8;
244+
size = 1;
245+
break;
246+
case TypeCode.Byte:
247+
dt = TFDataType.UInt8;
248+
size = 1;
249+
break;
250+
case TypeCode.Int16:
251+
dt = TFDataType.Int16;
252+
size = 2;
253+
break;
254+
case TypeCode.UInt16:
255+
dt = TFDataType.UInt16;
256+
size = 2;
257+
break;
258+
case TypeCode.Int32:
259+
dt = TFDataType.Int32;
260+
size = 4;
261+
break;
262+
case TypeCode.Int64:
263+
dt = TFDataType.Int64;
264+
size = 8;
265+
break;
266+
case TypeCode.Single:
267+
dt = TFDataType.Float;
268+
size = 4;
269+
break;
270+
case TypeCode.Double:
271+
dt = TFDataType.Double;
272+
size = 8;
273+
break;
274+
default:
275+
// Check types that are not handled by the typecode
276+
if (t.IsAssignableFrom (typeof (Complex))) {
277+
size = 16;
278+
dt = TFDataType.Complex128;
279+
} else
280+
throw new ArgumentException ($"The data type {t} is not supported");
281+
break;
282+
}
283+
284+
var dims = new long [array.Rank];
285+
for (int i = 0; i < array.Rank; i++) {
286+
dims [i] = array.GetLength (i);
287+
size *= (int)dims [i];
288+
}
289+
handle = SetupMulti (dt, dims, array, size);
290+
}
225291

226292
/// <summary>
227293
/// Creates a constant tensor with a single dimension from an integer value.
@@ -541,67 +607,8 @@ unsafe public static implicit operator TFTensor (byte value)
541607
/// </remarks>
542608
unsafe public static implicit operator TFTensor (Array array)
543609
{
544-
if (array == null)
545-
throw new ArgumentNullException (nameof (array));
546-
// TODO: ensure that we do not have arrays of arrays.
547-
var t = array.GetType ().GetElementType ();
548-
var tc = Type.GetTypeCode (t);
549-
TFDataType dt;
550-
long size = 0;
551-
switch (tc) {
552-
case TypeCode.Boolean:
553-
dt = TFDataType.Bool;
554-
size = 1;
555-
break;
556-
case TypeCode.SByte:
557-
dt = TFDataType.Int8;
558-
size = 1;
559-
break;
560-
case TypeCode.Byte:
561-
dt = TFDataType.UInt8;
562-
size = 1;
563-
break;
564-
case TypeCode.Int16:
565-
dt = TFDataType.Int16;
566-
size = 2;
567-
break;
568-
case TypeCode.UInt16:
569-
dt = TFDataType.UInt16;
570-
size = 2;
571-
break;
572-
case TypeCode.Int32:
573-
dt = TFDataType.Int32;
574-
size = 4;
575-
break;
576-
case TypeCode.Int64:
577-
dt = TFDataType.Int64;
578-
size = 8;
579-
break;
580-
case TypeCode.Single:
581-
dt = TFDataType.Float;
582-
size = 4;
583-
break;
584-
case TypeCode.Double:
585-
dt = TFDataType.Double;
586-
size = 8;
587-
break;
588-
default:
589-
// Check types that are not handled by the typecode
590-
if (t.IsAssignableFrom (typeof (Complex))){
591-
size = 16;
592-
dt = TFDataType.Complex128;
593-
} else
594-
throw new ArgumentException ($"The data type {t} is not supported");
595-
break;
596-
}
610+
return new TFTensor (array);
597611

598-
var dims = new long [array.Rank];
599-
for (int i = 0; i < array.Rank; i++) {
600-
dims [i] = array.GetLength (i);
601-
size *= (int)dims [i];
602-
}
603-
var newTensor = new TFTensor (SetupMulti (dt, dims, array, size));
604-
return newTensor;
605612
}
606613

607614
// General purpose constructor, specifies data type and gets pointer to buffer

0 commit comments

Comments
 (0)