Skip to content

Commit 60299b7

Browse files
cesarsouzamigueldeicaza
authored andcommitted
Adding support for jagged arrays as inputs to TFTensor (migueldeicaza#129)
* Adding a possible implementation for migueldeicazaGH-33. * Reverting unintended formatting changes.
1 parent 763a264 commit 60299b7

3 files changed

Lines changed: 230 additions & 1 deletion

File tree

TensorFlowSharp/Tensor.cs

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
// Miguel de Icaza ([email protected])
66
//
77
using System;
8+
using System.Collections;
9+
using System.Collections.Generic;
10+
using System.Linq;
811
using System.Numerics;
912
using System.Runtime.InteropServices;
1013
using System.Text;
@@ -238,7 +241,22 @@ public unsafe TFTensor (Array array)
238241
{
239242
if (array == null)
240243
throw new ArgumentNullException (nameof (array));
241-
// TODO: ensure that we do not have arrays of arrays.
244+
245+
// Ensure that, if we have arrays of arrays, we can handle them accordingly:
246+
if (isJagged (array.GetType ())) {
247+
Type elementType = getInnerMostType (array);
248+
int [] length = getLength (array);
249+
Array multidimensional = Array.CreateInstance (elementType, length);
250+
Array flatten = deepFlatten (array);
251+
Buffer.BlockCopy (flatten, 0, multidimensional, 0, flatten.Length * Marshal.SizeOf (elementType));
252+
createFromMultidimensionalArrays (multidimensional);
253+
} else {
254+
createFromMultidimensionalArrays (array);
255+
}
256+
}
257+
258+
private unsafe void createFromMultidimensionalArrays (Array array)
259+
{
242260
var t = array.GetType ().GetElementType ();
243261
var tc = Type.GetTypeCode (t);
244262
TFDataType dt;
@@ -1115,6 +1133,152 @@ public override string ToString ()
11151133
sb.Append ("]");
11161134
return sb.ToString ();
11171135
}
1136+
1137+
1138+
1139+
1140+
1141+
1142+
1143+
1144+
1145+
private static int [] getLength (Array array, bool deep = true, bool max = false)
1146+
{
1147+
// This function gets the length of all dimensions in a multidimensional, jagged, or mixed array.
1148+
// https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1118
1149+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1150+
1151+
if (array.Rank == 0)
1152+
return new int [0];
1153+
1154+
if (deep && isJagged (array)) {
1155+
if (array.Length == 0)
1156+
return new int [0];
1157+
1158+
int [] rest;
1159+
if (!max) {
1160+
rest = getLength (array.GetValue (0) as Array, deep);
1161+
} else {
1162+
// find the max
1163+
rest = getLength (array.GetValue (0) as Array, deep);
1164+
for (int i = 1; i < array.Length; i++) {
1165+
int [] r = getLength (array.GetValue (i) as Array, deep);
1166+
1167+
for (int j = 0; j < r.Length; j++) {
1168+
if (r [j] > rest [j])
1169+
rest [j] = r [j];
1170+
}
1171+
}
1172+
}
1173+
1174+
return new [] { array.Length }.Concat (rest).ToArray ();
1175+
}
1176+
1177+
int [] vector = new int [array.Rank];
1178+
for (int i = 0; i < vector.Length; i++)
1179+
vector [i] = array.GetUpperBound (i) + 1;
1180+
return vector;
1181+
}
1182+
1183+
private static Array deepFlatten (Array array)
1184+
{
1185+
// This function converts multidimensional, jagged, or mixed arrays into a single unidimensional array (i.e. flattens the mixed array).
1186+
// https://github.com/accord-net/framework/blob/f78181b82eb6ee6cc7fd10d2a7a55334982c40df/Sources/Accord.Math/Matrix/Matrix.Common.cs#L1625
1187+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1188+
int totalLength = getTotalLength (array, deep: true);
1189+
var elementType = getInnerMostType (array);
1190+
Array result = Array.CreateInstance (elementType, totalLength);
1191+
1192+
int k = 0;
1193+
foreach (object v in enumerateJagged (array))
1194+
result.SetValue (v, k++);
1195+
return result;
1196+
}
1197+
1198+
private static IEnumerable enumerateJagged (Array array)
1199+
{
1200+
// This function can enumerate all elements in a multidimensional ,jagged, or mixed array.
1201+
// From https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Jagged.Construction.cs#L1202
1202+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1203+
var arrays = new Stack<Array> ();
1204+
var counters = new Stack<int> ();
1205+
1206+
arrays.Push (array);
1207+
counters.Push (0);
1208+
int depth = 1;
1209+
1210+
Array a = array;
1211+
int i = 0;
1212+
1213+
while (arrays.Count > 0) {
1214+
if (i >= a.Length) {
1215+
a = arrays.Pop ();
1216+
i = counters.Pop () + 1;
1217+
depth--;
1218+
} else {
1219+
Object e = a.GetValue (i);
1220+
Array next = e as Array;
1221+
if (next == null) {
1222+
yield return e;
1223+
i++;
1224+
} else {
1225+
arrays.Push (a);
1226+
counters.Push (i);
1227+
a = next;
1228+
i = 0;
1229+
depth++;
1230+
}
1231+
}
1232+
}
1233+
}
1234+
1235+
private static int getTotalLength (Array array, bool deep = true, bool rectangular = true)
1236+
{
1237+
// From https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1087
1238+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1239+
if (deep && isJagged (array.GetType ())) {
1240+
if (rectangular) {
1241+
int rest = getTotalLength (array.GetValue (0) as Array, deep);
1242+
return array.Length * rest;
1243+
} else {
1244+
int sum = 0;
1245+
for (int i = 0; i < array.Length; i++)
1246+
sum += getTotalLength (array.GetValue (i) as Array, deep);
1247+
return sum;
1248+
}
1249+
}
1250+
1251+
return array.Length;
1252+
}
1253+
1254+
private static bool isJagged (Array array)
1255+
{
1256+
// From https://github.com/accord-net/framework/blob/f78181b82eb6ee6cc7fd10d2a7a55334982c40df/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1204
1257+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1258+
if (array.Length == 0)
1259+
return array.Rank == 1;
1260+
return array.Rank == 1 && array.GetValue (0) is Array;
1261+
}
1262+
1263+
private static bool isJagged (Type type)
1264+
{
1265+
// From https://github.com/accord-net/framework/blob/eb371fbc540a41c1a711b6ab1ebd49889316e7f7/Sources/Accord.Math/Matrix/Matrix.Common.cs#L84
1266+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1267+
return type.IsArray && type.GetElementType ().IsArray;
1268+
}
1269+
1270+
private static Type getInnerMostType (Array array)
1271+
{
1272+
// From https://github.com/accord-net/framework/blob/eb371fbc540a41c1a711b6ab1ebd49889316e7f7/Sources/Accord.Math/Matrix/Matrix.Common.cs#L95
1273+
// Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details
1274+
Type type = array.GetType ();
1275+
1276+
while (type.IsArray)
1277+
type = type.GetElementType ();
1278+
1279+
return type;
1280+
}
1281+
11181282
}
11191283

11201284
}

tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
</ItemGroup>
6565
<ItemGroup>
6666
<Compile Include="ArrayTests.cs" />
67+
<Compile Include="TensorTests.cs" />
6768
<Compile Include="ClipTests.cs" />
6869
<Compile Include="BitwiseOperationTests.cs" />
6970
<Compile Include="MathTests.cs" />
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using TensorFlow;
4+
using Xunit;
5+
6+
namespace TensorFlowSharp.Tests.CSharp
7+
{
8+
public class TensorTests
9+
{
10+
private static IEnumerable<object []> jaggedData ()
11+
{
12+
yield return new object [] {
13+
new double [][] { new [] { 1.0, 2.0 }, new [] { 3.0, 4.0 } },
14+
new double [,] { { 1.0, 2.0}, { 3.0, 4.0 } },
15+
true
16+
};
17+
18+
yield return new object [] {
19+
new double [][] { new [] { 1.0, 2.0 }, new [] { 1.0, 4.0 } },
20+
new double [,] { { 1.0, 2.0}, { 3.0, 4.0 } },
21+
false
22+
};
23+
24+
yield return new object [] {
25+
new double [][][] { new [] { new [] { 1.0 }, new[] { 2.0 } }, new [] { new [] { 3.0 }, new [] { 4.0 } } },
26+
new double [,,] { { { 1.0 }, { 2.0 } }, { { 3.0 }, { 4.0 } } },
27+
true
28+
};
29+
30+
yield return new object [] {
31+
new double [][][] { new [] { new [] { 1.0 }, new[] { 2.0 } }, new [] { new [] { 1.0 }, new [] { 4.0 } } },
32+
new double [,,] { { { 1.0 }, { 2.0 } }, { { 3.0 }, { 4.0 } } },
33+
false
34+
};
35+
}
36+
37+
38+
[Theory]
39+
[MemberData (nameof (jaggedData))]
40+
public void Should_MultidimensionalAndJaggedBeEqual (Array jagged, Array multidimensional, bool expected)
41+
{
42+
using (var graph = new TFGraph ())
43+
using (var session = new TFSession (graph)) {
44+
var tjagged = graph.Const (new TFTensor (jagged));
45+
var tmultidimensional = graph.Const (new TFTensor (multidimensional));
46+
47+
TFOutput y = graph.Equal (tjagged, tmultidimensional);
48+
TFOutput r;
49+
if (multidimensional.Rank == 2)
50+
r = graph.All (y, graph.Const (new [] { 0, 1 }));
51+
else if (multidimensional.Rank == 3)
52+
r = graph.All (y, graph.Const (new [] { 0, 1, 2 }));
53+
else
54+
throw new System.Exception ("If you want to test Ranks > 3 please handle this extra case manually.");
55+
56+
TFTensor [] result = session.Run (new TFOutput [] { }, new TFTensor [] { }, new [] { r });
57+
58+
bool actual = (bool)result [0].GetValue ();
59+
Assert.Equal (expected, actual);
60+
}
61+
}
62+
63+
}
64+
}

0 commit comments

Comments
 (0)