|
5 | 5 | // Miguel de Icaza ([email protected]) |
6 | 6 | // |
7 | 7 | using System; |
| 8 | +using System.Collections; |
| 9 | +using System.Collections.Generic; |
| 10 | +using System.Linq; |
8 | 11 | using System.Numerics; |
9 | 12 | using System.Runtime.InteropServices; |
10 | 13 | using System.Text; |
@@ -238,7 +241,22 @@ public unsafe TFTensor (Array array) |
238 | 241 | { |
239 | 242 | if (array == null) |
240 | 243 | throw new ArgumentNullException (nameof (array)); |
241 | | - // TODO: ensure that we do not have arrays of arrays. |
| 244 | + |
| 245 | + // Ensure that, if we have arrays of arrays, we can handle them accordingly: |
| 246 | + if (isJagged (array.GetType ())) { |
| 247 | + Type elementType = getInnerMostType (array); |
| 248 | + int [] length = getLength (array); |
| 249 | + Array multidimensional = Array.CreateInstance (elementType, length); |
| 250 | + Array flatten = deepFlatten (array); |
| 251 | + Buffer.BlockCopy (flatten, 0, multidimensional, 0, flatten.Length * Marshal.SizeOf (elementType)); |
| 252 | + createFromMultidimensionalArrays (multidimensional); |
| 253 | + } else { |
| 254 | + createFromMultidimensionalArrays (array); |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + private unsafe void createFromMultidimensionalArrays (Array array) |
| 259 | + { |
242 | 260 | var t = array.GetType ().GetElementType (); |
243 | 261 | var tc = Type.GetTypeCode (t); |
244 | 262 | TFDataType dt; |
@@ -1115,6 +1133,152 @@ public override string ToString () |
1115 | 1133 | sb.Append ("]"); |
1116 | 1134 | return sb.ToString (); |
1117 | 1135 | } |
| 1136 | + |
| 1137 | + |
| 1138 | + |
| 1139 | + |
| 1140 | + |
| 1141 | + |
| 1142 | + |
| 1143 | + |
| 1144 | + |
| 1145 | + private static int [] getLength (Array array, bool deep = true, bool max = false) |
| 1146 | + { |
| 1147 | + // This function gets the length of all dimensions in a multidimensional, jagged, or mixed array. |
| 1148 | + // https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1118 |
| 1149 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1150 | + |
| 1151 | + if (array.Rank == 0) |
| 1152 | + return new int [0]; |
| 1153 | + |
| 1154 | + if (deep && isJagged (array)) { |
| 1155 | + if (array.Length == 0) |
| 1156 | + return new int [0]; |
| 1157 | + |
| 1158 | + int [] rest; |
| 1159 | + if (!max) { |
| 1160 | + rest = getLength (array.GetValue (0) as Array, deep); |
| 1161 | + } else { |
| 1162 | + // find the max |
| 1163 | + rest = getLength (array.GetValue (0) as Array, deep); |
| 1164 | + for (int i = 1; i < array.Length; i++) { |
| 1165 | + int [] r = getLength (array.GetValue (i) as Array, deep); |
| 1166 | + |
| 1167 | + for (int j = 0; j < r.Length; j++) { |
| 1168 | + if (r [j] > rest [j]) |
| 1169 | + rest [j] = r [j]; |
| 1170 | + } |
| 1171 | + } |
| 1172 | + } |
| 1173 | + |
| 1174 | + return new [] { array.Length }.Concat (rest).ToArray (); |
| 1175 | + } |
| 1176 | + |
| 1177 | + int [] vector = new int [array.Rank]; |
| 1178 | + for (int i = 0; i < vector.Length; i++) |
| 1179 | + vector [i] = array.GetUpperBound (i) + 1; |
| 1180 | + return vector; |
| 1181 | + } |
| 1182 | + |
| 1183 | + private static Array deepFlatten (Array array) |
| 1184 | + { |
| 1185 | + // This function converts multidimensional, jagged, or mixed arrays into a single unidimensional array (i.e. flattens the mixed array). |
| 1186 | + // https://github.com/accord-net/framework/blob/f78181b82eb6ee6cc7fd10d2a7a55334982c40df/Sources/Accord.Math/Matrix/Matrix.Common.cs#L1625 |
| 1187 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1188 | + int totalLength = getTotalLength (array, deep: true); |
| 1189 | + var elementType = getInnerMostType (array); |
| 1190 | + Array result = Array.CreateInstance (elementType, totalLength); |
| 1191 | + |
| 1192 | + int k = 0; |
| 1193 | + foreach (object v in enumerateJagged (array)) |
| 1194 | + result.SetValue (v, k++); |
| 1195 | + return result; |
| 1196 | + } |
| 1197 | + |
| 1198 | + private static IEnumerable enumerateJagged (Array array) |
| 1199 | + { |
| 1200 | + // This function can enumerate all elements in a multidimensional ,jagged, or mixed array. |
| 1201 | + // From https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Jagged.Construction.cs#L1202 |
| 1202 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1203 | + var arrays = new Stack<Array> (); |
| 1204 | + var counters = new Stack<int> (); |
| 1205 | + |
| 1206 | + arrays.Push (array); |
| 1207 | + counters.Push (0); |
| 1208 | + int depth = 1; |
| 1209 | + |
| 1210 | + Array a = array; |
| 1211 | + int i = 0; |
| 1212 | + |
| 1213 | + while (arrays.Count > 0) { |
| 1214 | + if (i >= a.Length) { |
| 1215 | + a = arrays.Pop (); |
| 1216 | + i = counters.Pop () + 1; |
| 1217 | + depth--; |
| 1218 | + } else { |
| 1219 | + Object e = a.GetValue (i); |
| 1220 | + Array next = e as Array; |
| 1221 | + if (next == null) { |
| 1222 | + yield return e; |
| 1223 | + i++; |
| 1224 | + } else { |
| 1225 | + arrays.Push (a); |
| 1226 | + counters.Push (i); |
| 1227 | + a = next; |
| 1228 | + i = 0; |
| 1229 | + depth++; |
| 1230 | + } |
| 1231 | + } |
| 1232 | + } |
| 1233 | + } |
| 1234 | + |
| 1235 | + private static int getTotalLength (Array array, bool deep = true, bool rectangular = true) |
| 1236 | + { |
| 1237 | + // From https://github.com/accord-net/framework/blob/b4990721a61f03602d04c12b148215c7eca1b7ac/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1087 |
| 1238 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1239 | + if (deep && isJagged (array.GetType ())) { |
| 1240 | + if (rectangular) { |
| 1241 | + int rest = getTotalLength (array.GetValue (0) as Array, deep); |
| 1242 | + return array.Length * rest; |
| 1243 | + } else { |
| 1244 | + int sum = 0; |
| 1245 | + for (int i = 0; i < array.Length; i++) |
| 1246 | + sum += getTotalLength (array.GetValue (i) as Array, deep); |
| 1247 | + return sum; |
| 1248 | + } |
| 1249 | + } |
| 1250 | + |
| 1251 | + return array.Length; |
| 1252 | + } |
| 1253 | + |
| 1254 | + private static bool isJagged (Array array) |
| 1255 | + { |
| 1256 | + // From https://github.com/accord-net/framework/blob/f78181b82eb6ee6cc7fd10d2a7a55334982c40df/Sources/Accord.Math/Matrix/Matrix.Construction.cs#L1204 |
| 1257 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1258 | + if (array.Length == 0) |
| 1259 | + return array.Rank == 1; |
| 1260 | + return array.Rank == 1 && array.GetValue (0) is Array; |
| 1261 | + } |
| 1262 | + |
| 1263 | + private static bool isJagged (Type type) |
| 1264 | + { |
| 1265 | + // From https://github.com/accord-net/framework/blob/eb371fbc540a41c1a711b6ab1ebd49889316e7f7/Sources/Accord.Math/Matrix/Matrix.Common.cs#L84 |
| 1266 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1267 | + return type.IsArray && type.GetElementType ().IsArray; |
| 1268 | + } |
| 1269 | + |
| 1270 | + private static Type getInnerMostType (Array array) |
| 1271 | + { |
| 1272 | + // From https://github.com/accord-net/framework/blob/eb371fbc540a41c1a711b6ab1ebd49889316e7f7/Sources/Accord.Math/Matrix/Matrix.Common.cs#L95 |
| 1273 | + // Relicensed under the MIT license by the original author for inclusion in TensorFlowSharp and any derived projects, see the MIT license for details |
| 1274 | + Type type = array.GetType (); |
| 1275 | + |
| 1276 | + while (type.IsArray) |
| 1277 | + type = type.GetElementType (); |
| 1278 | + |
| 1279 | + return type; |
| 1280 | + } |
| 1281 | + |
1118 | 1282 | } |
1119 | 1283 |
|
1120 | 1284 | } |
0 commit comments