Skip to content

Commit bb50d3a

Browse files
committed
Make Variable a class that holds the three convenience operations involved
1 parent 3f08892 commit bb50d3a

5 files changed

Lines changed: 135 additions & 58 deletions

File tree

SampleTest/SampleTest.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,9 @@ void LinearRegression ()
291291
var X = g.Placeholder (TFDataType.Float);
292292
var Y = g.Placeholder (TFDataType.Float);
293293

294-
TFOutput readW, readB;
295-
var W = g.Variable (g.Const ((float)rng.Next ()), out readW, operName: "weight");
296-
var b = g.Variable (g.Const ((float) rng.Next ()), out readB, operName: "bias");
297-
var pred = g.Add (g.Mul (X, readW, "x*w"), readB);
294+
var W = g.Variable (g.Const ((float)rng.Next ()), operName: "weight");
295+
var b = g.Variable (g.Const ((float) rng.Next ()), operName: "bias");
296+
var pred = g.Add (g.Mul (X, W.Read, "x*w"), b.Read);
298297

299298
var first = g.Pow (g.Sub (pred, Y), g.Const ((float)2));
300299
var cost = g.Div (g.ReduceSum (g.Pow (g.Sub (pred, Y), g.Const (2f))), g.Mul (g.Const (2f), g.Const ((float)n_samples), "2*n_samples"));

TensorFlowSharp/OperationsExtras.cs

Lines changed: 41 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,30 @@ public TFOutput ReduceMean (TFOutput input, TFOutput? axis = null, bool? keep_di
8585
return this.Mean (input, this.ReduceDims (input, axis), keep_dims, operName);
8686
}
8787

88+
// Helper method to create a variable and track it.
89+
Variable MakeVariable (TFOutput initialValue, bool trainable, string operName)
90+
{
91+
var scopeName = MakeName ("Variable", operName);
92+
93+
using (var newScope = WithScope (scopeName)) {
94+
var type = initialValue.OutputType;
95+
var variableHandle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
96+
using (var aScope = WithScope ("Assign")) {
97+
var assignOp = AssignVariableOp (variableHandle, initialValue);
98+
using (var rScope = WithScope ("Read")) {
99+
var readHandle = ReadVariableOp (variableHandle, type);
100+
101+
var nv = new Variable (variableHandle, readHandle, assignOp);
102+
if (trainable)
103+
AddTrainableVariable (nv);
104+
AddInitVariable (assignOp);
105+
return nv;
106+
}
107+
}
108+
}
109+
110+
}
111+
88112
/// <summary>
89113
/// Variable node, with a starting initial value.
90114
/// </summary>
@@ -93,32 +117,21 @@ public TFOutput ReduceMean (TFOutput input, TFOutput? axis = null, bool? keep_di
93117
/// <param name="value">Returns the value of the variable.</param>
94118
/// <param name="trainable">If true, this add the variable to the graph's TrainableVariables, this collection is intended to be used by the Optimizer classes.</param>
95119
/// <param name="operName">Operation name, optional.</param>
96-
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
120+
/// <returns>The returning Variable contains the variable, with three nodes with the operations making up the variable assignment.</returns>
97121
/// <remarks>
98122
/// Variables need to be initialized before the main execution so you will typically want to
99123
/// run the session on the variable
100124
/// </remarks>
101-
public TFOutput Variable (TFOutput initialValue, out TFOperation init, out TFOutput value, bool trainable = true, string operName = null)
125+
public Variable Variable (TFOutput initialValue, out TFOperation init, out TFOutput value, bool trainable = true, string operName = null)
102126
{
103-
var scopeName = MakeName ("Variable", operName);
104-
105-
using (var newScope = WithScope (scopeName)) {
106-
var type = initialValue.OutputType;
107-
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
108-
using (var aScope = WithScope ("Assign")) {
109-
init = AssignVariableOp (handle, initialValue);
110-
if (trainable)
111-
AddTrainableVariable (handle.Operation);
112-
using (var rScope = WithScope ("Read")) {
113-
value = ReadVariableOp (handle, type);
114-
return handle;
115-
}
116-
}
117-
}
127+
var nv = MakeVariable (initialValue, trainable, operName);
128+
init = nv.Assign;
129+
value = nv.Read;
130+
return nv;
118131
}
119132

120133
List<TFOperation> pending_init_variables;
121-
List<TFOperation> trainable_variables;
134+
List<Variable> trainable_variables;
122135

123136
/// <summary>
124137
/// Registers a specified variable as an initialization variable.
@@ -144,10 +157,10 @@ public void AddInitVariable (TFOperation variable)
144157
}
145158

146159
// TODO: finalize semantics, when should we clear these?
147-
internal void AddTrainableVariable (TFOperation variable)
160+
internal void AddTrainableVariable (Variable variable)
148161
{
149162
if (trainable_variables == null)
150-
trainable_variables = new List<TFOperation> ();
163+
trainable_variables = new List<Variable> ();
151164
trainable_variables.Add (variable);
152165
}
153166

@@ -173,32 +186,19 @@ public TFOperation [] GetGlobalVariablesInitializer ()
173186
/// <param name="value">Returns the value of the variable.</param>
174187
/// <param name="trainable">If true, this add the variable to the graph's TrainableVariables, this collection is intended to be used by the Optimizer classes.</param>
175188
/// <param name="operName">Operation name, optional.</param>
176-
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
189+
/// <returns>The returning Variable contains the variable, with three nodes with the operations making up the variable assignment.</returns>
177190
/// <remarks>
178191
/// Variables need to be initialized before the main execution so you will typically want to
179192
/// run the session on the variable.
180193
///
181194
/// The init sequence for the variable is stored in the graph, you must manually initialize
182195
/// those by running the session on the global variables.
183196
/// </remarks>
184-
public TFOutput Variable (TFOutput initialValue, out TFOutput value, bool trainable = true, string operName = null)
197+
public Variable Variable (TFOutput initialValue, out TFOutput value, bool trainable = true, string operName = null)
185198
{
186-
var scopeName = MakeName ("Variable", operName);
187-
188-
using (var newScope = WithScope (scopeName)) {
189-
var type = initialValue.OutputType;
190-
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
191-
using (var aScope = WithScope ("Assign")) {
192-
var init = AssignVariableOp (handle, initialValue);
193-
AddInitVariable (init);
194-
if (trainable)
195-
AddTrainableVariable (handle.Operation);
196-
using (var rScope = WithScope ("Read")) {
197-
value = ReadVariableOp (handle, type);
198-
return handle;
199-
}
200-
}
201-
}
199+
var nv = MakeVariable (initialValue, trainable, operName);
200+
value = nv.Read;
201+
return nv;
202202
}
203203

204204
/// <summary>
@@ -207,29 +207,17 @@ public TFOutput Variable (TFOutput initialValue, out TFOutput value, bool traina
207207
/// <param name="initialValue">Initial value.</param>
208208
/// <param name="trainable">If true, this add the variable to the graph's TrainableVariables, this collection is intended to be used by the Optimizer classes.</param>
209209
/// <param name="operName">Operation name, optional.</param>
210-
/// <returns>The returning TFOutput returns the handle to the variable, this is a VarHandleOp, if you want to read it, create a ReadVariableOp on result.</returns>
210+
/// <returns>The returning Variable contains the variable, with three nodes with the operations making up the variable assignment.</returns>
211211
/// <remarks>
212212
/// Variables need to be initialized before the main execution so you will typically want to
213213
/// run the session on the variable.
214214
///
215215
/// The init sequence for the variable is stored in the graph, you must manually initialize
216216
/// those by running the session on the global variables.
217217
/// </remarks>
218-
public TFOutput Variable (TFOutput initialValue, bool trainable = true, string operName = null)
218+
public Variable Variable (TFOutput initialValue, bool trainable = true, string operName = null)
219219
{
220-
var scopeName = MakeName ("Variable", operName);
221-
using (var newScope = WithScope (scopeName)) {
222-
var type = initialValue.OutputType;
223-
224-
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
225-
using (var aScope = WithScope ("Assign")) {
226-
var init = AssignVariableOp (handle, initialValue);
227-
AddInitVariable (init);
228-
if (trainable)
229-
AddTrainableVariable (handle.Operation);
230-
return handle;
231-
}
232-
}
220+
return MakeVariable (initialValue, trainable, operName);
233221
}
234222

235223
//

TensorFlowSharp/TensorFlowSharp.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ In addition to bringing the 1.3 API, this contains the following high-level APIs
7575
<Compile Include="OperationsExtras.cs" />
7676
<Compile Include="Buffer.cs" />
7777
<Compile Include="Tensor.cs" />
78+
<Compile Include="Variable.cs" />
7879
</ItemGroup>
7980
<ItemGroup>
8081
<PackageReference Include="NuGet.Build.Packaging">

TensorFlowSharp/Tensorflow.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ public TFScope WithScope (string nameScopeDesc)
781781

782782
Dictionary<string, int> values = new Dictionary<string, int> ();
783783

784-
string MakeName (string operName, string userName)
784+
internal string MakeName (string operName, string userName)
785785
{
786786
if (userName == null) {
787787
var k = CurrentNameScope == "" ? operName : CurrentNameScope + "/" + operName;

TensorFlowSharp/Variable.cs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//
2+
// TensorFlow.cs; Bindings to the TensorFlow C API for .NET
3+
//
4+
// Authors:
5+
// Miguel de Icaza ([email protected])
6+
//
7+
using System;
8+
using System.Numerics;
9+
using System.Runtime.InteropServices;
10+
using System.Text;
11+
using size_t = System.UIntPtr;
12+
using TF_Tensor = System.IntPtr;
13+
14+
namespace TensorFlow
15+
{
16+
/// <summary>
17+
/// The Variable class holds the TFOutput nodes that are used to initialize, read and assign a value to a variable.
18+
/// </summary>
19+
/// <remarks>
20+
/// A variable maintains state in the graph across calls to `run()`. You add a
21+
/// variable to the graph by constructing an instance of the class `Variable`.
22+
///
23+
/// The `Variable()` constructor requires an initial value for the variable,
24+
/// which can be a `Tensor` of any type and shape. The initial value defines the
25+
/// type and shape of the variable. After construction, the type and shape of
26+
/// the variable are fixed. The value can be changed using one of the assign
27+
/// methods.
28+
///
29+
/// When a variable is created a VarHandleOp is created which is returned as
30+
/// the VariableOp property, an assign operation is created that can be accessed
31+
/// using the assignHandle and you can read the value of the variable using the
32+
/// ReadHandle.
33+
///
34+
/// When you launch the graph, variables have to be explicitly initialized before
35+
/// you can run Ops that use their value. You can initialize a variable by
36+
/// running its *initializer op*, restoring the variable from a save file, or
37+
/// simply running an `assign` Op that assigns a value to the variable. In fact,
38+
/// the variable *initializer op* is just an `assign` Op that assigns the
39+
/// variable's initial value to the variable itself.
40+
///
41+
/// There is an implicit conversion from the Variable into the VarHandleOp if
42+
/// used.
43+
/// </remarks>
44+
public class Variable
45+
{
46+
TFOutput variableHandle;
47+
TFOutput readHandle;
48+
TFOperation assignOp;
49+
50+
/// <summary>
51+
/// Returns the ReadVariableOp that is used to fetch the value of the variable from the graph.
52+
/// </summary>
53+
/// <value>The read op.</value>
54+
public TFOutput Read => readHandle;
55+
56+
/// <summary>
57+
/// Returns the AssignVariableOp that is used to assign the initial value to the variable from the graph.
58+
/// </summary>
59+
/// <value>The assign op.</value>
60+
public TFOperation Assign => assignOp;
61+
62+
/// <summary>
63+
/// Returns the VarHandleOp that was created using the shape of the initial value.
64+
/// </summary>
65+
/// <value>The variable op.</value>
66+
public TFOutput VariableOp => variableHandle;
67+
68+
internal Variable (TFOutput variableHandle, TFOutput readHandle, TFOperation assignOp)
69+
{
70+
this.variableHandle = variableHandle;
71+
this.readHandle = readHandle;
72+
this.assignOp = assignOp;
73+
}
74+
75+
/// <summary>
76+
/// Returns the VarHandleOp (the VariableOp property).
77+
/// </summary>
78+
/// <returns>The variable handle created for the variable.</returns>
79+
/// <param name="variable">Variable reference.</param>
80+
/// <remarks>
81+
/// This implicit operator exists to preserve the compatibility with code that
82+
/// created Variables and expected the result to be the VariableOp.
83+
/// </remarks>
84+
public static implicit operator TFOutput (Variable variable)
85+
{
86+
return variable.VariableOp;
87+
}
88+
}
89+
}

0 commit comments

Comments
 (0)