Skip to content

Commit 62e2e7b

Browse files
committed
Update with some work on Training
1 parent c199fba commit 62e2e7b

2 files changed

Lines changed: 93 additions & 0 deletions

File tree

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//
2+
// Port of the checkpointable code from Python.
3+
//
4+
// Authors:
5+
// Miguel de Icaza
6+
//
7+
using System;
8+
using System.Collections.Generic;
9+
10+
namespace TensorFlow
11+
{
12+
/// <summary>
13+
/// Checkpointable reference.
14+
/// </summary>
15+
public class CheckpointableReference
16+
{
17+
/// <summary>
18+
/// Local name for the dependency
19+
/// </summary>
20+
/// <value>The name.</value>
21+
public string Name { get; private set; }
22+
23+
/// <summary>
24+
/// The Checkpointable object being referenced.
25+
/// </summary>
26+
/// <value>The reference.</value>
27+
public CheckpointableBase Reference { get; private set; }
28+
29+
public CheckpointableReference (string name, CheckpointableBase reference)
30+
{
31+
Name = name;
32+
Reference = reference;
33+
}
34+
}
35+
36+
/// <summary>
37+
/// Indicates a position within a Checkpoint
38+
/// </summary>
39+
public class CheckpointPosition
40+
{
41+
}
42+
43+
/// <summary>
44+
/// Base class for `Checkpointable` objects without automatic dependencies.
45+
/// </summary>
46+
/// <remarks>
47+
/// Dependencies must be added explicitly, unless attribute assignment
48+
/// is performance-critical use <see cref="T:TensorFlow.Checkpointable"/>
49+
/// </remarks>
50+
public class CheckpointableBase
51+
{
52+
List<CheckpointableReference> _unconditional_checkpoint_dependencies;
53+
Dictionary<string, CheckpointableReference> _unconditional_dependency_names;
54+
Dictionary<string, CheckpointPosition> _deferred_dependencies;
55+
int update_uid;
56+
57+
public void MaybeInitializeCheckpointTable ()
58+
{
59+
// If we have already been initialized
60+
if (_unconditional_checkpoint_dependencies != null)
61+
return;
62+
63+
_unconditional_checkpoint_dependencies = new List<CheckpointableReference> ();
64+
_unconditional_dependency_names = new Dictionary<string, CheckpointableReference> ();
65+
_deferred_dependencies = new Dictionary<string, CheckpointPosition> ();
66+
update_uid = -1;
67+
}
68+
69+
/// <summary>
70+
/// All dependencies for this object
71+
/// </summary>
72+
/// <value>A list of CheckpointableReference objects indicating named Checkpointable dependencies which should be saved along with this object.</value>
73+
/// <remarks>
74+
/// <para>
75+
/// May be overridden to include conditional dependencies.
76+
/// </para>
77+
/// <para>
78+
/// </para>
79+
/// </remarks>
80+
public virtual IList<CheckpointableReference> CheckPointDependencies => _unconditional_checkpoint_dependencies;
81+
82+
/// <summary>
83+
/// Look up a dependency by name, may be overridden to include conditional dependencies.
84+
/// </summary>
85+
/// <param name="name">Name.</param>
86+
public virtual CheckpointableReference LookupDependency (string name) => _unconditional_dependency_names.TryGetValue (name, out var res) ? res : null;
87+
88+
89+
}
90+
}

TensorFlowSharp/Training/Optimizer.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ protected override object UpdateOp (Optimizer optimizer, TFGraph graph)
9191
}
9292
}
9393

94+
/// <summary>
95+
/// Optimizer.
96+
/// </summary>
9497
public class Optimizer {
9598

9699
// What should the parameter be? The Python code calls

0 commit comments

Comments
 (0)