Skip to content

Commit 6ba299e

Browse files
cesarsouzamigueldeicaza
authored andcommitted
migueldeicazaGH-118: Adding ReduceMean and related unit tests. (migueldeicaza#126)
1 parent b81eba9 commit 6ba299e

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

TensorFlowSharp/OperationsExtras.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,32 @@ public TFOutput ReduceSum (TFOutput input, TFOutput? axis = null, bool? keep_dim
6262
return Sum (input, this.ReduceDims (input, axis), keep_dims, operName);
6363
}
6464

65+
/// <summary>
66+
/// Computes the mean of elements across dimensions of a tensor.
67+
/// </summary>
68+
/// <returns>The reduced tensor.</returns>
69+
/// <param name="input">The tensor to reduce. Should have numeric type.</param>
70+
/// <param name="axis">The dimensions to reduce. If not set (the default), reduces all dimensions.</param>
71+
/// <param name="keep_dims">If set to <c>true</c> retains reduced dimensions with length 1.</param>
72+
/// <param name="operName">A name for the operation, optional.</param>
73+
/// <remarks>
74+
/// <para>
75+
/// Reduces input_tensor along the dimensions given in axis.
76+
/// Unless keep_dims is true, the rank of the tensor is reduced by 1 for each
77+
/// entry in axis. If keep_dims is true, the reduced dimensions
78+
/// are retained with length 1.</para>
79+
///
80+
/// <para>
81+
/// If axis has no entries, all dimensions are reduced, and a
82+
/// tensor with a single element is returned.</para>
83+
/// </remarks>
84+
public TFOutput ReduceMean (TFOutput input, TFOutput? axis = null, bool? keep_dims = false, string operName = null)
85+
{
86+
if (input.OutputType == TFDataType.Bool)
87+
input = this.Cast (input, TFDataType.Int8);
88+
return this.Mean (input, this.ReduceDims (input, axis), keep_dims, operName);
89+
}
90+
6591
/// <summary>
6692
/// Variable node, with a starting initial value.
6793
/// </summary>
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
using System.Collections.Generic;
2+
using TensorFlow;
3+
using Xunit;
4+
5+
namespace TensorFlowSharp.Tests.CSharp
6+
{
7+
public class MathTests
8+
{
9+
private static IEnumerable<object []> reduceMeanData ()
10+
{
11+
// Example from https://www.tensorflow.org/api_docs/python/tf/reduce_mean
12+
// # 'x' is [[1., 1.]
13+
// # [2., 2.]]
14+
// tf.reduce_mean (x) ==> 1.5
15+
// tf.reduce_mean (x, 0) ==> [1.5, 1.5]
16+
// tf.reduce_mean (x, 1) ==> [1., 2.]
17+
18+
var x = new double [,] { { 1, 1 },
19+
{ 2, 2 } };
20+
21+
yield return new object [] { x, null, 1.5 };
22+
yield return new object [] { x, 0, new double [] { 1.5, 1.5 } };
23+
yield return new object [] { x, 1, new double [] { 1, 2 } };
24+
}
25+
26+
[Theory]
27+
[MemberData (nameof (reduceMeanData))]
28+
public void Should_ReduceMean (double [,] input, int? axis, object expected)
29+
{
30+
using (var graph = new TFGraph ())
31+
using (var session = new TFSession (graph)) {
32+
var tinput = graph.Placeholder (TFDataType.Double, new TFShape (2, 2));
33+
34+
TFTensor [] result;
35+
if (axis != null) {
36+
var taxis = graph.Const (axis.Value);
37+
TFOutput y = graph.ReduceMean (tinput, taxis);
38+
result = session.Run (new [] { tinput, taxis }, new TFTensor [] { input, axis }, new [] { y });
39+
40+
double [] actual = (double [])result [0].GetValue ();
41+
TestUtils.MatrixEqual (expected, actual, precision: 8);
42+
} else {
43+
TFOutput y = graph.ReduceMean (tinput, axis: null);
44+
result = session.Run (new [] { tinput }, new TFTensor [] { input }, new [] { y });
45+
46+
double actual = (double)result [0].GetValue ();
47+
TestUtils.MatrixEqual (expected, actual, precision: 8);
48+
}
49+
}
50+
}
51+
52+
}
53+
}

tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
<Compile Include="ArrayTests.cs" />
6464
<Compile Include="ClipTests.cs" />
6565
<Compile Include="BitwiseOperationTests.cs" />
66+
<Compile Include="MathTests.cs" />
6667
<Compile Include="PartialRunTests.cs" />
6768
<Compile Include="Properties\AssemblyInfo.cs" />
6869
<Compile Include="TestUtils.cs" />

tests/TensorFlowSharp.Tests.CSharp/TestUtils.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,22 @@ public static void MatrixEqual(Array expected, Array actual, int precision)
4343
Assert.True(Object.Equals(ei.Current, ai.Current));
4444
}
4545
}
46-
}
46+
47+
public static void MatrixEqual (object expected, object actual, int precision)
48+
{
49+
if (expected is Array) {
50+
MatrixEqual (expected as Array, actual as Array, precision);
51+
return;
52+
}
53+
var expectedType = expected.GetType ();
54+
55+
if (expectedType == typeof (double)) {
56+
Assert.Equal ((double)expected, (double)actual, precision: precision);
57+
} else if (expectedType == typeof (float)) {
58+
Assert.Equal ((float)expected, (float)actual, precision: precision);
59+
} else {
60+
Assert.True (Object.Equals (expected, actual));
61+
}
62+
}
63+
}
4764
}

0 commit comments

Comments
 (0)