Fix bugs in optimizer and trainer
Fix several bugs in optimizer and trainer. But the weights still blow up now.
suquark committed Jan 27, 2017
1 parent d355661 commit 993a45c
Showing 7 changed files with 71 additions and 35 deletions.
2 changes: 2 additions & 0 deletions layers/conv.js
@@ -147,6 +147,8 @@ class ConvLayer extends Layer {

this.filters = json.filters.map(getVolFromJSON);
this.biases = getVolFromJSON(json.biases);
+// record updated values for updating
+this.updated = this.filters.concat([this.biases]);
}

}
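
The `this.updated` list is the hook that lets a trainer treat every layer uniformly. A minimal sketch of how a trainer might collect parameter/gradient pairs from it (an illustration only; it assumes, as in convnet.js, that the net exposes a `layers` array and each Vol exposes a `w` array of weights and a `dw` array of gradients):

function collectParamsAndGrads(net) {
    let response = [];
    for (let layer of net.layers) {
        if (!layer.updated) continue;  // layer has no trainable parameters
        for (let vol of layer.updated) {
            response.push({ params: vol.w, grads: vol.dw });
        }
    }
    return response;
}
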
4 changes: 3 additions & 1 deletion layers/fullconn.js
Expand Up @@ -17,9 +17,9 @@ class FullyConnLayer extends OutputLayer {
// initializations
this.num_inputs = opt.in_sx * opt.in_sy * opt.in_depth;
let bias = getopt(opt, 'bias_pref', 0.0);

this.filters = createMatrix(this.out_depth, this.num_inputs);
this.biases = createVector(this.out_depth, bias);

// record updated values for updating
this.updated = this.filters.concat([this.biases]);
}
@@ -77,6 +77,8 @@ class FullyConnLayer extends OutputLayer {

this.filters = json.filters.map(getVolFromJSON);
this.biases = getVolFromJSON(json.biases);
+// record updated values for updating
+this.updated = this.filters.concat([this.biases]);
}
}

22 changes: 11 additions & 11 deletions optimizer.js
@@ -23,7 +23,7 @@ class Adam {
this.xsum[j] = this.xsum[j] * this.beta2 + (1 - this.beta2) * g[j] * g[j]; // update biased second moment estimate
let biasCorr1 = this.gsum[j] * (1 - Math.pow(this.beta1, this.k)); // correct bias first moment estimate
let biasCorr2 = this.xsum[j] * (1 - Math.pow(this.beta2, this.k)); // correct bias second moment estimate
-dx[j] = -this.learning_rate * biasCorr1 / (Math.sqrt(biasCorr2) + this.eps);
+dx[j] = -this.learning_rate * biasCorr1 / (Math.sqrt(biasCorr2) + eps);
}
return dx;
}
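
Note that the two `biasCorr` lines above multiply the raw moments by `(1 - Math.pow(beta, k))`, whereas the Adam paper (Kingma & Ba, 2014) divides by that term. For comparison, a textbook single-parameter Adam step (a sketch, not this repository's API; assumes `k` starts at 1 and `eps` is a small constant such as 1e-8):

function adamStep(g, m, v, k, lr, beta1, beta2, eps) {
    m = beta1 * m + (1 - beta1) * g;          // biased first moment estimate
    v = beta2 * v + (1 - beta2) * g * g;      // biased second moment estimate
    let mHat = m / (1 - Math.pow(beta1, k));  // bias-corrected first moment
    let vHat = v / (1 - Math.pow(beta2, k));  // bias-corrected second moment
    let dx = -lr * mHat / (Math.sqrt(vHat) + eps);
    return { dx: dx, m: m, v: v };
}
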
@@ -41,7 +41,7 @@ class Adagrad {
let dx = new Array(g.length);
for (let j = 0; j < g.length; j++) {
this.gsum[j] += g[j] * g[j];
-dx[j] = - this.learning_rate / Math.sqrt(gsum[j] + this.eps) * g;
+dx[j] = - this.learning_rate / Math.sqrt(gsum[j] + eps) * g;
}
return dx;
}
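
The replacement line still indexes the bare `gsum` instead of `this.gsum`, and multiplies by the whole gradient array `g` rather than the element `g[j]`. A version consistent with the surrounding code would presumably read (a sketch, assuming `eps` is now a module-level constant):

for (let j = 0; j < g.length; j++) {
    this.gsum[j] += g[j] * g[j];  // accumulate squared gradients
    dx[j] = -this.learning_rate / Math.sqrt(this.gsum[j] + eps) * g[j];
}
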
@@ -61,9 +61,9 @@ class Windowgrad {
// adagrad update
let dx = new Array(g.length);
for (let j = 0; j < g.length; j++) {
-gsum[j] = this.ro * gsum[j] + (1 - this.ro) * g[j] * g[j];
+this.gsum[j] = this.ro * this.gsum[j] + (1 - this.ro) * g[j] * g[j];
// eps added for better conditioning
-dx[j] = - this.learning_rate / Math.sqrt(gsum[j] + this.eps) * g[j];
+dx[j] = - this.learning_rate / Math.sqrt(this.gsum[j] + eps) * g[j];
}
return dx;
}
@@ -82,9 +82,9 @@ class Adadelta {
let dx = new Array(g.length);
for (let j = 0; j < g.length; j++) {
let gij = g[j];
-gsum[j] = this.ro * gsum[j] + (1-this.ro) * gij * gij;
-dx[j] = - Math.sqrt((xsum[j] + this.eps)/(gsum[j] + this.eps)) * gij;
-xsum[j] = this.ro * xsum[j] + (1-this.ro) * dx[j] * dx[j]; // yes, xsum lags behind gsum by 1.
+this.gsum[j] = this.ro * this.gsum[j] + (1 - this.ro) * gij * gij;
+dx[j] = - Math.sqrt((this.xsum[j] + eps) / (this.gsum[j] + eps)) * gij;
+this.xsum[j] = this.ro * this.xsum[j] + (1 - this.ro) * dx[j] * dx[j]; // yes, xsum lags behind gsum by 1.
}
return dx;
}
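
The "lags behind" comment is inherent to Adadelta: the RMS of past updates that scales step t can only be built from updates up through step t-1, because dx is needed before xsum can be refreshed. Written out (a sketch of the recurrences, not code):

gsum_t = ro * gsum_{t-1} + (1 - ro) * g_t^2
dx_t   = -sqrt((xsum_{t-1} + eps) / (gsum_t + eps)) * g_t
xsum_t = ro * xsum_{t-1} + (1 - ro) * dx_t^2
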
@@ -104,8 +104,8 @@ class Nesterov {
for (let j = 0; j < g.length; j++) {
let gij = g[j];
let d = gsumi[j];
-gsumi[j] = gsumi[j] * this.momentum + this.learning_rate * gij;
-dx[j] = this.momentum * d - (1.0 + this.momentum) * gsumi[j];
+this.gsumi[j] = this.gsumi[j] * this.momentum + this.learning_rate * gij;
+dx[j] = this.momentum * d - (1.0 + this.momentum) * this.gsumi[j];
}
return dx;
}
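
One line above, `let d = gsumi[j];` still reads a bare `gsumi` while the updated lines go through `this.gsumi`; unless a local alias is declared earlier in the method, a consistent loop body would presumably be (a sketch):

for (let j = 0; j < g.length; j++) {
    let d = this.gsumi[j];  // previous velocity
    this.gsumi[j] = this.gsumi[j] * this.momentum + this.learning_rate * g[j];
    dx[j] = this.momentum * d - (1.0 + this.momentum) * this.gsumi[j];
}
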
@@ -125,8 +125,8 @@ class SGD {
if (this.momentum > 0.0) {
for (let j = 0; j < g.length; j++) {
// momentum update
-dx[j] = this.momentum * gsumi[j] - this.learning_rate * g[j]; // step
-gsumi[j] = dx[j]; // back this up for next iteration of momentum
+dx[j] = this.momentum * this.gsumi[j] - this.learning_rate * g[j]; // step
+this.gsumi[j] = dx[j]; // back this up for next iteration of momentum
}
} else {
// vanilla sgd
2 changes: 1 addition & 1 deletion regularization.js
@@ -16,7 +16,7 @@ class Regularization {
var l1_decay = this.l1_decay * l1_decay_mul;

// using map would be too slow because of closures, etc ...
-var lgrad = new Array(plist);
+var lgrad = new Array(plist.length);
for (let i in plist) {
let p = plist[i];
this.l2_decay_loss += l2_decay * p * p / 2; // accumulate weight decay loss
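
The fix works because the `Array` constructor is overloaded: a single numeric argument preallocates that many empty slots, while any other single argument becomes the sole element. A quick illustration:

new Array(3).length;                 // 3 - three empty slots
new Array([1, 2, 3]).length;         // 1 - one element: the array itself
new Array([1, 2, 3].length).length;  // 3 - the intended preallocation
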
45 changes: 40 additions & 5 deletions rl_demo.js
@@ -1,5 +1,33 @@
import { Vol, Net, Trainer, DQN, World, Agent } from 'convnet.js';

+// Load JSON text from a server-hosted file and return the parsed object
+function loadJSON(filePath) {
+    // load the json file
+    var json = loadTextFileAjaxSync(filePath, "application/json");
+    // parse the json text
+    return JSON.parse(json);
+}
+
+// Load text with Ajax synchronously: takes a path to the file and an optional MIME type
+function loadTextFileAjaxSync(filePath, mimeType) {
+    var xmlhttp = new XMLHttpRequest();
+    xmlhttp.open("GET", filePath, false);
+    if (mimeType != null) {
+        if (xmlhttp.overrideMimeType) {
+            xmlhttp.overrideMimeType(mimeType);
+        }
+    }
+    xmlhttp.send();
+    if (xmlhttp.status === 200) {
+        return xmlhttp.responseText;
+    } else {
+        // TODO Throw exception
+        return null;
+    }
+}
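
Synchronous XMLHttpRequest blocks the main thread and is deprecated in modern browsers; an asynchronous equivalent of the two helpers above could be sketched with the standard fetch API (`loadJSONAsync` is a hypothetical name, not part of this codebase):

function loadJSONAsync(filePath) {
    return fetch(filePath).then(function (response) {
        if (!response.ok) {
            throw new Error("HTTP " + response.status + " while loading " + filePath);
        }
        return response.json();  // parse the body as JSON
    });
}

// usage: loadJSONAsync("RL/helper/saving.json").then(json => net.fromJSON(json));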

function create_world(H, W) {
var num_inputs = 27; // 9 eyes, each sees 3 numbers (wall, green, red thing proximity)
@@ -19,17 +47,24 @@ function create_world(H, W) {
var net = new Net();
net.makeLayers(layer_defs);

+net.fromJSON(loadJSON("RL/helper/saving.json"));

+// $.getJSON("RL/helper/saving.json", function(json) {
+// net.fromJSON(json);
+// stoplearn();
+// goveryfast();
+// });

// options for the Temporal Difference learner that trains the above net
// by backpropping the temporal difference learning rule.
// var tdtrainer_options = {learning_rate:0.02, momentum:0.0,
// batch_size:64, l2_decay:0.01,
// method: 'adam'};

-var tdtrainer_options = {batch_size:64, l2_decay:0.01, method: 'adam'};

-var tdtrainer = new Trainer(net, tdtrainer_options);
+var tdtrainer_options = {learning_rate:0.001, momentum:0.0, batch_size:64, l2_decay:0.05};

+// var tdtrainer_options = { batch_size:64, l2_decay:0.01, method: 'adam' };

+var tdtrainer = new Trainer(net, tdtrainer_options);

var dqn_options = {
temporal_window: temporal_window,
27 changes: 12 additions & 15 deletions test_rl.html
@@ -13,7 +13,7 @@



<script src="transpiler/system.js"></script>
<script src="/transpiler/system.js"></script>
<script>
SystemJS.config({
map: {
@@ -125,29 +125,26 @@
}
canvas = document.getElementById("canvas");
w = create_world(canvas.width, canvas.height);
-//stoplearn();
-//goslow();
-startlearn();
-goveryfast();

-$.getJSON("RL/helper/saving.json", function(json) {
-    w.agent.brain.value_net.fromJSON(json);
-    stoplearn();
-    gonormal();
-});
+stoplearn();
+// goslow();
+// startlearn();
+gofast();

+// $.getJSON("RL/helper/saving.json", function(json) {
+//     w.agent.brain.value_net.fromJSON(json);
+//     stoplearn();
+//     goveryfast();
+// });

}

function loadfile() {
w.agent.brain.value_net.fromJSON(j);
stoplearn(); // also stop learning
gonormal();
}




start();


</script>

4 changes: 2 additions & 2 deletions trainer.js
@@ -43,12 +43,12 @@ class Trainer {
let p = pg.params, g = pg.grads;
let batch_grad = this.regular.get_punish(p, pg.l2_decay_mul, pg.l1_decay_loss);
// make raw batch gradient
-for (let i = 0; i < p.length; i++) {
+for (let i in p) {
batch_grad[i] = (batch_grad[i] + g[i]) / this.batch_size;
}
let update = pg.optimizer.grad(batch_grad);
// perform an update for all sets of weights
-for (let i = 0; i < p.length; i++) p[i] += update[i];
+for (let i in p) p[i] += update[i];
}
}

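One caveat with the switch to `for (let i in p)`: `for...in` yields property names as strings and also visits any extra enumerable properties, so the counted loop is generally safer for numeric arrays. A quick illustration:

let p = [1, 2, 3];
p.tag = "extra";                  // an enumerable non-index property
for (let i in p) console.log(i);  // "0", "1", "2", "tag" - tag is visited too
for (let i = 0; i < p.length; i++) console.log(i);  // 0, 1, 2 only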
