Skip to content

Commit dd1eab5

Browse files
committed
Forgot to commit these convenience methods
1 parent a06b77a commit dd1eab5

2 files changed

Lines changed: 79 additions & 3 deletions

File tree

SampleTest/SampleTest.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ void LinearRegression ()
529529

530530
var X = g.Placeholder (TFDataType.Float);
531531
var Y = g.Placeholder (TFDataType.Float);
532-
var W = g.Variable (new TFShape (rng.Next ()), TFDataType.Float, operName: "weight");
533-
var b = g.Variable (new TFShape (rng.Next ()), TFDataType.Float, operName: "bias");
532+
var W = g.Variable (g.Const (rng.Next ()), operName: "weight");
533+
var b = g.Variable (g.Const (rng.Next ()), operName: "bias");
534534

535535
var pred = g.Add (g.Mul (X, W), b);
536536

@@ -545,7 +545,7 @@ void LinearRegression ()
545545
var cost = g.Div (g.ReduceSum (g.Pow (g.Sub (pred, Y), g.Const (2))), g.Mul (g.Const (2), g.Const (n_samples)));
546546

547547

548-
548+
// STuck here: need gradient support
549549
}
550550
}
551551
#endif

TensorFlowSharp/OperationsExtras.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Linq;
34

45
namespace TensorFlow
@@ -90,5 +91,80 @@ public TFOutput Variable (TFOutput initialValue, out TFOperation init, out TFOut
9091
}
9192
}
9293

94+
List<TFOperation> pending_init_variables;
95+
public void AddInitVariable (TFOperation variable)
96+
{
97+
if (pending_init_variables == null)
98+
pending_init_variables = new List<TFOperation> ();
99+
pending_init_variables.Add (variable);
100+
}
101+
102+
public TFOperation [] GetGlobalVariablesInitializer ()
103+
{
104+
var res = pending_init_variables.ToArray ();
105+
pending_init_variables.Clear ();
106+
return res;
107+
}
108+
109+
/// <summary>
110+
/// Variable node, with a starting initial value. Convenience that registers the init variable to a global queue.
111+
/// </summary>
112+
/// <param name="initialValue">Initial value.</param>
113+
/// <param name="value">Returns the value of the variable.</param>
114+
/// <param name="operName">Operation name, optional.</param>
115+
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
116+
/// <remarks>
117+
/// Variables need to be initialized before the main execution so you will typically want to
118+
/// run the session on the variable.
119+
///
120+
/// The init sequence for the variable is stored in the graph, you must manually initialize
121+
/// those by running the session on the global variables.
122+
/// </remarks>
123+
public TFOutput Variable (TFOutput initialValue, out TFOutput value, string operName = null)
124+
{
125+
var scopeName = MakeName ("Variable", operName);
126+
127+
using (var newScope = WithScope (scopeName)) {
128+
var type = initialValue.OutputType;
129+
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
130+
using (var aScope = WithScope ("Assign")) {
131+
var init = AssignVariableOp (handle, initialValue);
132+
AddInitVariable (init);
133+
using (var rScope = WithScope ("Read")) {
134+
value = ReadVariableOp (handle, type);
135+
return handle;
136+
}
137+
}
138+
}
139+
}
140+
141+
/// <summary>
142+
/// Variable node, with a starting initial value. Convenience that registers the init variable to a global queue.
143+
/// </summary>
144+
/// <param name="initialValue">Initial value.</param>
145+
/// <param name="operName">Operation name, optional.</param>
146+
/// <returns>The returning TFOutput returns the handle to the variable.</returns>
147+
/// <remarks>
148+
/// Variables need to be initialized before the main execution so you will typically want to
149+
/// run the session on the variable.
150+
///
151+
/// The init sequence for the variable is stored in the graph, you must manually initialize
152+
/// those by running the session on the global variables.
153+
/// </remarks>
154+
public TFOutput Variable (TFOutput initialValue, string operName = null)
155+
{
156+
var scopeName = MakeName ("Variable", operName);
157+
158+
using (var newScope = WithScope (scopeName)) {
159+
var type = initialValue.OutputType;
160+
var handle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
161+
using (var aScope = WithScope ("Assign")) {
162+
var init = AssignVariableOp (handle, initialValue);
163+
AddInitVariable (init);
164+
return handle;
165+
}
166+
}
167+
}
168+
93169
}
94170
}

0 commit comments

Comments
 (0)