Skip to content

Commit 006361b

Browse files
zeahmed
authored and migueldeicaza committed
Implemented adaptive SGD optimizer (Adagrad). (migueldeicaza#412)
* Setting variable shared_name property to avoid variables sharing. * Added SGD and MomentumSGD optimizers together with relevant tests. * Tests added for momentum and Nesterov SGD with and without lr decay. * Added MNIST multilayer test. * Added MNIST GPU test in disabled mode. * Added support to place an operation on a specific device. * Disabled 'DevicePlacementTest' because it requires GPUs. * Added MNIST multilayer test. * Updated comments. * Disabled MnistGPU test. * Removed unnecessary files. * Added Adagrad optimization algorithm.
1 parent 4a8356c commit 006361b

File tree

7 files changed

+467
-3
lines changed

7 files changed

+467
-3
lines changed

TensorFlowSharp/Optimizer.cs

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ public virtual (TFOutput gradient, Variable variable)[] ComputeGradient(TFOutput
7777
/// <param name="varList">list of variables to compute the gradients for.
7878
/// If null the gradient is computed for all the trainable variables in the graph.</param>
7979
/// <returns>An Operation that updates the variables.</returns>
80-
public abstract TFOperation[] Minimize(TFOutput loss, Variable[] varList = null);
80+
public virtual TFOperation[] Minimize(TFOutput loss, Variable[] varList = null)
81+
{
82+
return ApplyGradient(ComputeGradient(loss, varList));
83+
}
8184
}
8285

8386
/// <summary>
@@ -188,11 +191,103 @@ public override TFOperation[] ApplyGradient((TFOutput gradient, Variable variabl
188191
}
189192
return _updateOps.ToArray();
190193
}
194+
}
195+
196+
/// <summary>
197+
/// Adaptive stochastic gradient descent optimizer.
198+
/// </summary>
199+
public sealed class Adagrad : Optimizer
200+
{
201+
/// <summary>
202+
/// Variable to keep track of the number of iterations (mini-batches processed)
203+
/// </summary>
204+
public Variable Iterations { get; }
205+
206+
/// <summary>
207+
/// Variable to keep track of the learning rate.
208+
/// </summary>
209+
public Variable LearningRate { get; }
210+
211+
private readonly string _lrName = "LearningRate";
212+
private readonly IList<TFOperation> _updateOps = new List<TFOperation>();
213+
private float _initialAccumulatorValue;
214+
private TFOutput _epsilon;
215+
216+
/// <summary>
217+
/// Construct Adagrad optimizer.
218+
/// </summary>
219+
/// <param name="graph">The graph object.</param>
220+
/// <param name="learningRate">The learning rate for the SGD update.</param>
221+
/// <param name="decay">Learning rate decay over each update.</param>
222+
/// <param name="initialAccumulatorValue">A floating point value. Starting value for the accumulators, must be positive.</param>
223+
/// <param name="operName">Name of the optimizer. All the variables that are created in this class will be created under this scope.</param>
224+
public Adagrad(TFGraph graph, float learningRate, float decay = 0, float initialAccumulatorValue = 0.1f, string operName = "AdagradOptimizer") : base(graph, operName)
225+
{
226+
if (initialAccumulatorValue < 0)
227+
throw new ArgumentException($"Value must be positive. initialAccumulatorValue = {initialAccumulatorValue}");
228+
229+
using (var scope = _graph.WithScope(_optimizerName))
230+
{
231+
Iterations = _graph.Variable(_graph.Const(new TFTensor(0L)), trainable: false, operName: "iterations");
232+
_updateOps.Add(_graph.AssignAddVariableOp(Iterations, _graph.Const(1L)));
233+
var initialLearningRate = _graph.Const(learningRate);
234+
LearningRate = _graph.Variable(initialLearningRate, trainable: false, operName: _lrName);
235+
CreateDecayOps(decay, initialLearningRate);
236+
}
237+
_initialAccumulatorValue = initialAccumulatorValue;
238+
_epsilon = _graph.Const(1e-7f);
239+
}
240+
241+
private void CreateDecayOps(float decay, TFOutput initialLearningRate)
242+
{
243+
if (decay > 0)
244+
{
245+
var _decay = _graph.Const(decay, "Decay");
246+
var one = _graph.Const(1f);
247+
_updateOps.Add(_graph.AssignVariableOp(LearningRate,
248+
_graph.Mul(initialLearningRate,
249+
_graph.Div(one,
250+
_graph.Add(one,
251+
_graph.Mul(_decay,
252+
_graph.Cast(Iterations.Read, _decay.OutputType)
253+
)
254+
)
255+
)
256+
)));
257+
}
258+
}
259+
260+
private TFOutput[] InitMoments((TFOutput gradient, Variable variable)[] gradientsAndVariables)
261+
{
262+
var accumulators = new TFOutput[gradientsAndVariables.Length];
263+
for (int i = 0; i < gradientsAndVariables.Length; i++)
264+
{
265+
var gv = gradientsAndVariables[i];
266+
var varType = gv.variable.Read.OutputType;
267+
var varShape = _graph.GetTensorShape(gv.variable.Read);
268+
accumulators[i] = _graph.VariableV2(varShape, varType);
269+
_graph.AddInitVariable(_graph.Assign(accumulators[i], _graph.Constant(_initialAccumulatorValue, varShape, varType)).Operation);
270+
}
271+
return accumulators;
272+
}
191273

192274
/// <inheritdoc />
193-
public override TFOperation[] Minimize(TFOutput loss, Variable[] varList = null)
275+
public override TFOperation[] ApplyGradient((TFOutput gradient, Variable variable)[] gradientsAndVariables)
194276
{
195-
return ApplyGradient(ComputeGradient(loss, varList));
277+
var accumulators = InitMoments(gradientsAndVariables);
278+
for (int i = 0; i < gradientsAndVariables.Length; i++)
279+
{
280+
var gv = gradientsAndVariables[i];
281+
var lr = _graph.Cast(LearningRate.Read, gv.gradient.OutputType);
282+
// accum = accum + g ** 2;
283+
var accum = _graph.Add(accumulators[i], _graph.Square(gv.gradient));
284+
// accumulators[i] = accum
285+
_updateOps.Add(_graph.Assign(accumulators[i], accum).Operation);
286+
// w = w - lr * g / sqrt(accum + 1e-7)
287+
var denom = _graph.Div(_graph.Mul(lr, gv.gradient), _graph.Sqrt(_graph.Add(accum, _epsilon)));
288+
_updateOps.Add(_graph.AssignSubVariableOp(gv.variable, denom));
289+
}
290+
return _updateOps.ToArray();
196291
}
197292
}
198293
}

tests/TensorFlowSharp.Tests.CSharp/OptimizerTests.cs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,5 +556,118 @@ public void MNISTTwoHiddenLayerNetworkGPUTest()
556556
}
557557
}
558558
}
559+
560+
561+
[Fact]
562+
public void LinearRegresionTrainingWithAdagradTest()
563+
{
564+
Console.WriteLine("Linear regression");
565+
// Parameters
566+
var learning_rate = 0.01f;
567+
var training_epochs = 5;
568+
569+
// Training data
570+
var train_x = new float[] {
571+
3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
572+
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f
573+
};
574+
var train_y = new float[] {
575+
1.7f, 2.76f,2.09f,3.19f,1.694f,1.573f,3.366f,2.596f,2.53f,1.221f,
576+
2.827f,3.465f,1.65f,2.904f,2.42f,2.94f,1.3f
577+
};
578+
var n_samples = train_x.Length;
579+
using (var graph = new TFGraph())
580+
{
581+
var rng = new Random(0);
582+
// tf Graph Input
583+
584+
var X = graph.Placeholder(TFDataType.Float, TFShape.Scalar);
585+
var Y = graph.Placeholder(TFDataType.Float, TFShape.Scalar);
586+
587+
var W = graph.Variable(graph.Const(0.1f), operName: "weight");
588+
var b = graph.Variable(graph.Const(0.1f), operName: "bias");
589+
var pred = graph.Add(graph.Mul(X, W.Read, "x_w"), b.Read);
590+
591+
var cost = graph.Div(graph.ReduceSum(graph.Pow(graph.Sub(pred, Y), graph.Const(2f))), graph.Mul(graph.Const(2f), graph.Const((float)n_samples), "2_n_samples"));
592+
593+
var sgd = new Adagrad(graph, learning_rate);
594+
var updateOps = sgd.Minimize(cost);
595+
596+
using (var sesssion = new TFSession(graph))
597+
{
598+
sesssion.GetRunner().AddTarget(graph.GetGlobalVariablesInitializer()).Run();
599+
600+
var expectedLines = File.ReadAllLines(Path.Combine(_testDataPath, "Adagrad", "expected.txt"));
601+
for (int i = 0; i < training_epochs; i++)
602+
{
603+
for (int j = 0; j < n_samples; j++)
604+
{
605+
var tensors = sesssion.GetRunner()
606+
.AddInput(X, new TFTensor(train_x[j]))
607+
.AddInput(Y, new TFTensor(train_y[j]))
608+
.AddTarget(updateOps).Fetch(cost, W.Read, b.Read, pred).Run();
609+
var output = $"loss: {tensors[0].GetValue():F4}, W: {tensors[1].GetValue():F4}, b: {tensors[2].GetValue():F4}";
610+
Assert.Equal(expectedLines[i * n_samples + j], output);
611+
}
612+
}
613+
}
614+
}
615+
}
616+
617+
[Fact]
618+
public void LinearRegresionTrainingWithAdagradDecayTest()
619+
{
620+
Console.WriteLine("Linear regression");
621+
// Parameters
622+
var learning_rate = 0.01f;
623+
var training_epochs = 5;
624+
625+
// Training data
626+
var train_x = new float[] {
627+
3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
628+
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f
629+
};
630+
var train_y = new float[] {
631+
1.7f, 2.76f,2.09f,3.19f,1.694f,1.573f,3.366f,2.596f,2.53f,1.221f,
632+
2.827f,3.465f,1.65f,2.904f,2.42f,2.94f,1.3f
633+
};
634+
var n_samples = train_x.Length;
635+
using (var graph = new TFGraph())
636+
{
637+
var rng = new Random(0);
638+
// tf Graph Input
639+
640+
var X = graph.Placeholder(TFDataType.Float, TFShape.Scalar);
641+
var Y = graph.Placeholder(TFDataType.Float, TFShape.Scalar);
642+
643+
var W = graph.Variable(graph.Const(0.1f), operName: "weight");
644+
var b = graph.Variable(graph.Const(0.1f), operName: "bias");
645+
var pred = graph.Add(graph.Mul(X, W.Read, "x_w"), b.Read);
646+
647+
var cost = graph.Div(graph.ReduceSum(graph.Pow(graph.Sub(pred, Y), graph.Const(2f))), graph.Mul(graph.Const(2f), graph.Const((float)n_samples), "2_n_samples"));
648+
649+
var sgd = new Adagrad(graph, learning_rate, decay: 0.5f);
650+
var updateOps = sgd.Minimize(cost);
651+
652+
using (var sesssion = new TFSession(graph))
653+
{
654+
sesssion.GetRunner().AddTarget(graph.GetGlobalVariablesInitializer()).Run();
655+
656+
var expectedLines = File.ReadAllLines(Path.Combine(_testDataPath, "AdagradTimeDecay", "expected.txt"));
657+
for (int i = 0; i < training_epochs; i++)
658+
{
659+
for (int j = 0; j < n_samples; j++)
660+
{
661+
var tensors = sesssion.GetRunner()
662+
.AddInput(X, new TFTensor(train_x[j]))
663+
.AddInput(Y, new TFTensor(train_y[j]))
664+
.AddTarget(updateOps).Fetch(sgd.Iterations.Read, cost, W.Read, b.Read, sgd.LearningRate.Read).Run();
665+
var output = $"step: {tensors[0].GetValue():D}, loss: {tensors[1].GetValue():F4}, W: {tensors[2].GetValue():F4}, b: {tensors[3].GetValue():F4}, lr: {tensors[4].GetValue():F8}";
666+
Assert.Equal(expectedLines[i * n_samples + j], output);
667+
}
668+
}
669+
}
670+
}
671+
}
559672
}
560673
}

tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@
8686
</ItemGroup>
8787
<ItemGroup>
8888
<None Include="packages.config" />
89+
<None Include="TestData\AdagradTimeDecay\optimizer_lr_test.py">
90+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
91+
</None>
92+
<None Include="TestData\Adagrad\optimizer_lr_test.py">
93+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
94+
</None>
8995
<None Include="TestData\MomentumNesterovTimeDecay\optimizer_lr_test.py">
9096
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
9197
</None>
@@ -123,6 +129,12 @@
123129
</ItemGroup>
124130
<ItemGroup />
125131
<ItemGroup>
132+
<Content Include="TestData\AdagradTimeDecay\expected.txt">
133+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
134+
</Content>
135+
<Content Include="TestData\Adagrad\expected.txt">
136+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
137+
</Content>
126138
<Content Include="TestData\MomentumNesterovTimeDecay\expected.txt">
127139
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
128140
</Content>
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
loss: 0.0474, W: 0.1000, b: 0.1000
2+
loss: 0.1411, W: 0.1061, b: 0.1023
3+
loss: 0.0540, W: 0.1143, b: 0.1060
4+
loss: 0.1528, W: 0.1197, b: 0.1082
5+
loss: 0.0145, W: 0.1270, b: 0.1117
6+
loss: 0.0250, W: 0.1293, b: 0.1128
7+
loss: 0.1141, W: 0.1311, b: 0.1142
8+
loss: 0.0779, W: 0.1378, b: 0.1170
9+
loss: 0.0528, W: 0.1410, b: 0.1193
10+
loss: 0.0182, W: 0.1442, b: 0.1212
11+
loss: 0.0836, W: 0.1447, b: 0.1223
12+
loss: 0.0892, W: 0.1482, b: 0.1245
13+
loss: 0.0149, W: 0.1529, b: 0.1268
14+
loss: 0.0702, W: 0.1539, b: 0.1277
15+
loss: 0.0579, W: 0.1569, b: 0.1297
16+
loss: 0.0525, W: 0.1588, b: 0.1315
17+
loss: 0.0130, W: 0.1616, b: 0.1331
18+
loss: 0.0313, W: 0.1621, b: 0.1340
19+
loss: 0.1071, W: 0.1629, b: 0.1352
20+
loss: 0.0322, W: 0.1647, b: 0.1375
21+
loss: 0.1104, W: 0.1660, b: 0.1387
22+
loss: 0.0043, W: 0.1688, b: 0.1410
23+
loss: 0.0155, W: 0.1693, b: 0.1414
24+
loss: 0.0717, W: 0.1700, b: 0.1422
25+
loss: 0.0562, W: 0.1730, b: 0.1440
26+
loss: 0.0329, W: 0.1747, b: 0.1455
27+
loss: 0.0141, W: 0.1763, b: 0.1467
28+
loss: 0.0606, W: 0.1766, b: 0.1475
29+
loss: 0.0568, W: 0.1786, b: 0.1491
30+
loss: 0.0085, W: 0.1813, b: 0.1506
31+
loss: 0.0496, W: 0.1819, b: 0.1511
32+
loss: 0.0444, W: 0.1837, b: 0.1525
33+
loss: 0.0338, W: 0.1850, b: 0.1538
34+
loss: 0.0094, W: 0.1867, b: 0.1549
35+
loss: 0.0253, W: 0.1871, b: 0.1555
36+
loss: 0.0930, W: 0.1876, b: 0.1565
37+
loss: 0.0234, W: 0.1890, b: 0.1583
38+
loss: 0.0908, W: 0.1898, b: 0.1593
39+
loss: 0.0012, W: 0.1918, b: 0.1610
40+
loss: 0.0110, W: 0.1921, b: 0.1612
41+
loss: 0.0514, W: 0.1925, b: 0.1618
42+
loss: 0.0445, W: 0.1947, b: 0.1632
43+
loss: 0.0227, W: 0.1959, b: 0.1644
44+
loss: 0.0116, W: 0.1970, b: 0.1652
45+
loss: 0.0476, W: 0.1972, b: 0.1659
46+
loss: 0.0392, W: 0.1987, b: 0.1671
47+
loss: 0.0051, W: 0.2007, b: 0.1682
48+
loss: 0.0374, W: 0.2010, b: 0.1686
49+
loss: 0.0360, W: 0.2024, b: 0.1697
50+
loss: 0.0230, W: 0.2034, b: 0.1708
51+
loss: 0.0072, W: 0.2047, b: 0.1716
52+
loss: 0.0213, W: 0.2049, b: 0.1721
53+
loss: 0.0834, W: 0.2054, b: 0.1729
54+
loss: 0.0179, W: 0.2065, b: 0.1745
55+
loss: 0.0776, W: 0.2072, b: 0.1752
56+
loss: 0.0001, W: 0.2088, b: 0.1767
57+
loss: 0.0081, W: 0.2089, b: 0.1768
58+
loss: 0.0384, W: 0.2092, b: 0.1772
59+
loss: 0.0365, W: 0.2109, b: 0.1783
60+
loss: 0.0162, W: 0.2119, b: 0.1793
61+
loss: 0.0099, W: 0.2128, b: 0.1800
62+
loss: 0.0387, W: 0.2130, b: 0.1805
63+
loss: 0.0278, W: 0.2141, b: 0.1816
64+
loss: 0.0030, W: 0.2157, b: 0.1824
65+
loss: 0.0291, W: 0.2159, b: 0.1827
66+
loss: 0.0299, W: 0.2171, b: 0.1836
67+
loss: 0.0159, W: 0.2179, b: 0.1845
68+
loss: 0.0056, W: 0.2189, b: 0.1852
69+
loss: 0.0184, W: 0.2191, b: 0.1856
70+
loss: 0.0761, W: 0.2194, b: 0.1863
71+
loss: 0.0140, W: 0.2204, b: 0.1877
72+
loss: 0.0679, W: 0.2210, b: 0.1883
73+
loss: 0.0000, W: 0.2224, b: 0.1896
74+
loss: 0.0061, W: 0.2224, b: 0.1896
75+
loss: 0.0293, W: 0.2227, b: 0.1900
76+
loss: 0.0306, W: 0.2240, b: 0.1908
77+
loss: 0.0117, W: 0.2249, b: 0.1917
78+
loss: 0.0086, W: 0.2256, b: 0.1923
79+
loss: 0.0321, W: 0.2257, b: 0.1927
80+
loss: 0.0200, W: 0.2267, b: 0.1936
81+
loss: 0.0018, W: 0.2279, b: 0.1943
82+
loss: 0.0230, W: 0.2281, b: 0.1946
83+
loss: 0.0254, W: 0.2291, b: 0.1953
84+
loss: 0.0111, W: 0.2298, b: 0.1961
85+
loss: 0.0044, W: 0.2306, b: 0.1966
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# This script is used to create data file (expected.txt)
2+
# which is used to compare the output from TensorFlowSharp optimizer tests.
3+
4+
import tensorflow as tf
5+
6+
# Training data
7+
train_x =[
8+
3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
9+
7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1
10+
]
11+
train_y = [
12+
1.7, 2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
13+
2.827,3.465,1.65,2.904,2.42,2.94,1.3
14+
]
15+
n_samples = len(train_x)
16+
learning_rate = 0.01
17+
X = tf.placeholder(tf.float32)
18+
Y = tf.placeholder(tf.float32)
19+
20+
W = tf.Variable(tf.constant(0.1), dtype=tf.float32)
21+
b = tf.Variable(tf.constant(0.1), dtype=tf.float32)
22+
23+
pred = tf.add(tf.multiply(X,W), b)
24+
25+
cost = tf.divide(tf.reduce_sum(tf.pow(tf.subtract(pred, Y), 2.0)), tf.multiply(2.0, n_samples))
26+
optimizer = tf.train.AdagradOptimizer(learning_rate).minimize(cost, name = "AdagradOptimizer")
27+
28+
init = tf.global_variables_initializer()
29+
with tf.Session() as session:
30+
session.run(init)
31+
for e in range(5):
32+
for i in range(n_samples):
33+
_, cost_v, W_v, b_v, pred_v = session.run([optimizer, cost, W, b, pred], feed_dict = {X: train_x[i], Y: train_y[i]})
34+
print(f"loss: {cost_v:.4f}, W: {W_v:.4f}, b: {b_v:.4f}")
35+
#print("Prediction: %f == Actual: %f" % (pred_v, train_y[i]))

0 commit comments

Comments
 (0)