Skip to content

Commit a796177

Browse files
committed
Add support for variables, based on the guidance from Asim Shankar, and https://github.com/asimshankar/tensorflow/tree/govars/tensorflow/go/v
1 parent 7ceaae8 commit a796177

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

SampleTest/SampleTest.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,40 @@ void BasicVariables ()
386386
}
387387
}
388388

389+
//
390+
// Shows the use of Variable
391+
//
392+
void TestVariable ()
393+
{
394+
Console.WriteLine ("Variables");
395+
var status = new TFStatus ();
396+
using (var g = new TFGraph ()) {
397+
var initValue = g.Const (1.5);
398+
var increment = g.Const (0.5);
399+
TFOperation init;
400+
TFOutput value;
401+
var handle = g.Variable (initValue, out init, out value);
402+
403+
// Add 0.5 and assign to the variable.
404+
// Perhaps using op.AssignAddVariable would be better,
405+
// but demonstrating with Add and Assign for now.
406+
var update = g.AssignVariableOp (handle, g.Add (value, increment));
407+
408+
var s = new TFSession (g);
409+
// Must first initialize all the variables.
410+
s.GetRunner ().AddTarget (init).Run (status);
411+
Assert (status);
412+
// Now print the value, run the update op and repeat
413+
// Ignore errors.
414+
for (int i = 0; i < 5; i++) {
415+
// Read and update
416+
var result = s.GetRunner ().Fetch (value).AddTarget (update).Run ();
417+
418+
Console.WriteLine ("Result of variable read {0} -> {1}", i, result [0].GetValue ());
419+
}
420+
}
421+
}
422+
389423
void BasicMatrix ()
390424
{
391425
Console.WriteLine ("Basic matrix");
@@ -466,6 +500,7 @@ public static void Main (string [] args)
466500
t.TestImportGraphDef ();
467501
t.TestSession ();
468502
t.TestOperationOutputListSize ();
503+
t.TestVariable ();
469504

470505
// Current failing test
471506
t.TestOutputShape ();

TensorFlowSharp/OperationsExtras.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,30 @@ public TFOutput ReduceSum (TFOutput input, TFOutput? axis = null, bool? keep_dim
6060
{
6161
return Sum (input, this.ReduceDims (input, axis), keep_dims, operName);
6262
}
63+
64+
/// <summary>
65+
/// Variable node, with a starting initial value.
66+
/// </summary>
67+
/// <param name="initialValue">Initial value.</param>
68+
/// <param name="init">Returns the operation that initializes the value of the variable.</param>
69+
/// <param name="value">Returns the value of the variable.</param>
70+
/// <param name="operName">Operation name, optional.</param>
71+
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
72+
public TFOutput Variable (TFOutput initialValue, out TFOperation init, out TFOutput value, string operName = null)
73+
{
74+
var scopeName = MakeName ("Variable", operName);
75+
76+
using (var newScope = WithScope (scopeName)) {
77+
var type = initialValue.OutputType;
78+
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
79+
using (var aScope = WithScope ("Assign")) {
80+
init = AssignVariableOp (handle, initialValue);
81+
using (var rScope = WithScope ("Read")) {
82+
value = ReadVariableOp (handle, type);
83+
return handle;
84+
}
85+
}
86+
}
87+
}
6388
}
6489
}

0 commit comments

Comments
 (0)