Skip to content

Commit 37f4539

Browse files
zeahmedmigueldeicaza
authored andcommitted
Setting variable shared_name property to avoid variables sharing. (migueldeicaza#396)
1 parent ae900e3 commit 37f4539

3 files changed

Lines changed: 66 additions & 19 deletions

File tree

TensorFlowSharp/OperationsExtras.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Variable MakeVariable (TFOutput initialValue, bool trainable, string operName)
119119

120120
using (var newScope = WithScope (scopeName)) {
121121
var type = initialValue.OutputType;
122-
var variableHandle = VarHandleOp (type, new TFShape (GetShape (initialValue)));
122+
var variableHandle = VarHandleOp (type, new TFShape (GetShape (initialValue)), shared_name: operName);
123123
using (var aScope = WithScope ("Assign")) {
124124
var assignOp = AssignVariableOp (variableHandle, initialValue);
125125
using (var rScope = WithScope ("Read")) {
@@ -208,22 +208,22 @@ public TFOperation [] GetGlobalVariablesInitializer ()
208208
return res;
209209
}
210210

211-
/// <summary>
212-
/// Variable node, with a starting initial value. Convenience that registers the init variable to a global queue.
213-
/// </summary>
214-
/// <param name="initialValue">Initial value.</param>
215-
/// <param name="value">Returns the value of the variable.</param>
216-
/// <param name="trainable">If true, this add the variable to the graph's TrainableVariables, this collection is intended to be used by the Optimizer classes.</param>
217-
/// <param name="operName">Operation name, optional.</param>
218-
/// <returns>The returning Variable contains the variable, with three nodes with the operations making up the variable assignment.</returns>
219-
/// <remarks>
220-
/// Variables need to be initialized before the main execution so you will typically want to
221-
/// run the session on the variable.
222-
///
223-
/// The init sequence for the variable is stored in the graph, you must manually initialize
224-
/// those by running the session on the global variables.
225-
/// </remarks>
226-
public Variable Variable (TFOutput initialValue, out TFOutput value, bool trainable = true, string operName = null)
211+
/// <summary>
212+
/// Variable node, with a starting initial value. Convenience that registers the init variable to a global queue.
213+
/// </summary>
214+
/// <param name="initialValue">Initial value.</param>
215+
/// <param name="value">Returns the value of the variable.</param>
216+
/// <param name="trainable">If true, this add the variable to the graph's TrainableVariables, this collection is intended to be used by the Optimizer classes.</param>
217+
/// <param name="operName">Operation name, optional.</param>
218+
/// <returns>The returning Variable contains the variable, with three nodes with the operations making up the variable assignment.</returns>
219+
/// <remarks>
220+
/// Variables need to be initialized before the main execution so you will typically want to
221+
/// run the session on the variable.
222+
///
223+
/// The init sequence for the variable is stored in the graph, you must manually initialize
224+
/// those by running the session on the global variables.
225+
/// </remarks>
226+
public Variable Variable (TFOutput initialValue, out TFOutput value, bool trainable = true, string operName = null)
227227
{
228228
var nv = MakeVariable (initialValue, trainable, operName);
229229
value = nv.Read;

tests/TensorFlowSharp.Tests.CSharp/TensorFlowSharp.Tests.CSharp.csproj

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
<DefineConstants>DEBUG;TRACE</DefineConstants>
2626
<ErrorReport>prompt</ErrorReport>
2727
<WarningLevel>4</WarningLevel>
28-
<DocumentationFile></DocumentationFile>
28+
<DocumentationFile>
29+
</DocumentationFile>
2930
</PropertyGroup>
3031
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
3132
<DebugType>pdbonly</DebugType>
@@ -34,7 +35,8 @@
3435
<DefineConstants>TRACE</DefineConstants>
3536
<ErrorReport>prompt</ErrorReport>
3637
<WarningLevel>4</WarningLevel>
37-
<AssemblyName></AssemblyName>
38+
<AssemblyName>
39+
</AssemblyName>
3840
</PropertyGroup>
3941
<ItemGroup>
4042
<Reference Include="Microsoft.DotNet.InternalAbstractions, Version=1.0.500.0, Culture=neutral, PublicKeyToken=adb9793829ddae60, processorArchitecture=MSIL">
@@ -79,6 +81,7 @@
7981
<Compile Include="PartialRunTests.cs" />
8082
<Compile Include="Properties\AssemblyInfo.cs" />
8183
<Compile Include="TestUtils.cs" />
84+
<Compile Include="VariableTests.cs" />
8285
</ItemGroup>
8386
<ItemGroup>
8487
<None Include="packages.config" />
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using System;
2+
using System.Collections;
3+
using System.Collections.Generic;
4+
using TensorFlow;
5+
using Xunit;
6+
7+
8+
namespace TensorFlowSharp.Tests.CSharp
9+
{
10+
public class VariableTests
11+
{
12+
[Fact]
13+
public void ShouldNotShareVariablesSameType()
14+
{
15+
using (var graph = new TFGraph())
16+
{
17+
var v1 = graph.Variable(graph.Const(0.5f), operName: "v1");
18+
var v2 = graph.Variable(graph.Const(0.6f), operName: "v2");
19+
20+
using (var session = new TFSession(graph))
21+
{
22+
var result = session.GetRunner().AddTarget(graph.GetGlobalVariablesInitializer()).Fetch(v1.Read, v2.Read).Run();
23+
Assert.NotEqual(result[0].GetValue(), result[1].GetValue());
24+
}
25+
}
26+
}
27+
28+
[Fact]
29+
public void ShouldNotShareVariablesDifferentType()
30+
{
31+
using (var graph = new TFGraph())
32+
{
33+
var v1 = graph.Variable(graph.Const(0.5f), operName: "v1");
34+
var v2 = graph.Variable(graph.Const(0L), operName: "v2");
35+
36+
using (var session = new TFSession(graph))
37+
{
38+
var result = session.GetRunner().AddTarget(graph.GetGlobalVariablesInitializer()).Fetch(v1.Read, v2.Read).Run();
39+
Assert.NotEqual(result[0].TensorType, result[1].TensorType);
40+
}
41+
}
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)