Skip to content

Commit 1810de4

Browse files
committed
Fixes TFTensor.GetValue for multidimensional arrays.
TFTensor.GetValue was updating the target in column major order, instead of the (documented) row major order. This commit aligns the multidimensional array behaviour with that when jagged: true is specified.
1 parent d74e95d commit 1810de4

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

SampleTest/SampleTest.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,29 @@ void TestVariable ()
421421
}
422422
}
423423

424+
void BasicMultidimensionalArray ()
425+
{
426+
Console.WriteLine ("Basic multidimensional array");
427+
using (var g = new TFGraph ()) {
428+
var s = new TFSession (g);
429+
430+
var var_a = g.Placeholder (TFDataType.Int32);
431+
var mul = g.Mul (var_a, g.Const (2));
432+
433+
var a = new int[,,] { { { 0, 1 } , { 2, 3 } } , { { 4, 5 }, { 6, 7 } } };
434+
var result = s.GetRunner ().AddInput (var_a, a).Fetch (mul).Run () [0];
435+
436+
var actual = (int[,,])result.GetValue ();
437+
var expected = new int[,,] { { { 0, 2 } , { 4, 6 } } , { { 8, 10 }, { 12, 14 } } };
438+
439+
Console.WriteLine ("Actual: " + RowOrderJoin (actual));
440+
Console.WriteLine ("Expected: " + RowOrderJoin (expected));
441+
Assert(expected.Cast<int> ().SequenceEqual (actual.Cast<int> ()));
442+
};
443+
}
444+
445+
private static string RowOrderJoin(int[,,] array) => string.Join (", ", array.Cast<int> ());
446+
424447
void BasicMatrix ()
425448
{
426449
Console.WriteLine ("Basic matrix");
@@ -573,6 +596,7 @@ public static void Main (string [] args)
573596

574597
t.BasicConstantOps ();
575598
t.BasicVariables ();
599+
t.BasicMultidimensionalArray ();
576600
t.BasicMatrix ();
577601

578602
t.NearestNeighbor ();

TensorFlowSharp/Tensor.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -824,18 +824,17 @@ static void FetchMultiDimensionalArray (Array target, TFDataType dt, IntPtr data
824824
for (int i = 0; i < shape.Length; i++) {
825825
if (shape [i] > Int32.MaxValue)
826826
throw new ArgumentOutOfRangeException ("Shape can not be longer than 32 bits");
827-
idx [i] = (int)shape [i];
828827
}
829-
Copy (target, dt, shape, idx, shape.Length - 1, ref data);
828+
Copy (target, dt, shape, idx, 0, ref data);
830829
}
831830

832831
static unsafe void Copy (Array target, TFDataType dt, long [] shape, int [] idx, int level, ref IntPtr data)
833832
{
834-
if (level > 0) {
833+
if (level < shape.Length - 1) {
835834
for (idx [level] = 0; idx [level] < shape [level]; idx [level]++)
836-
Copy (target, dt, shape, idx, level - 1, ref data);
835+
Copy (target, dt, shape, idx, level + 1, ref data);
837836
} else {
838-
for (idx [0] = 0; idx [0] < shape [0]; idx [0]++) {
837+
for (idx [level] = 0; idx [level] < shape [level]; idx [level]++) {
839838
switch (dt) {
840839
case TFDataType.Float:
841840
target.SetValue ((*(float*)data), idx);

0 commit comments

Comments
 (0)