Fix a critical bug in regularization
Fix a critical bug in regularization that caused weights to blow up.
Also fix several other bugs.
suquark committed Feb 13, 2017
1 parent 78fd151 commit 4e7a2c0
Showing 9 changed files with 23 additions and 33 deletions.
12 changes: 2 additions & 10 deletions RL/deepqlearn.js
@@ -122,18 +122,10 @@ class DQN {
         // compute the value of doing any action in this state
         // and return the argmax action and its value
         let action_values = this.value_net.forward(new Vol(s));
-        let maxk = action_values.max_index();
+        let maxk = action_values.max_index;
         return { action: maxk, value: action_values.w[maxk] };
     }
 
-    _toarray(arr) {
-        let a = [];
-        for (let i in arr) {
-            a.push(arr[i]);
-        }
-        return a;
-    }
-
     getNetInput(xt) {
         // return s = (x, a, x, a, x, a, xt) state vector.
         // It's a concatenation of last window_size (x,a) pairs and current state x
@@ -150,7 +142,7 @@ class DQN {
             // we dont want weight regularization to undervalue this information, as it only exists once
             let action1ofk = one_hot(this.num_actions, action, 1.0 * this.num_states);
 
-            w = w.concat(this._toarray(action1ofk)); // do not concat array & floatarray
+            w = w.concat(Array.prototype.slice.call(action1ofk)); // do not concat array & floatarray
         }
         return w;
     }
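The deleted `_toarray` helper and the new `Array.prototype.slice.call` do the same job: flattening a typed array into a plain array before `concat`. A standalone sketch of why that conversion is needed, assuming `action1ofk` is a Float64Array (hypothetical values):

// Hypothetical values; a one-hot action vector stored as a Float64Array.
let w = [0.1, 0.2];
let action1ofk = new Float64Array([0, 3, 0]);

// Typed arrays are not concat-spreadable, so concat appends the whole
// Float64Array as one element instead of its three numbers.
let wrong = w.concat(action1ofk);
console.log(wrong.length); // 3, last element is the Float64Array itself

// Converting to a plain array first, as the fixed line does, flattens it.
let right = w.concat(Array.prototype.slice.call(action1ofk));
console.log(right.length); // 5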
12 changes: 4 additions & 8 deletions backend.js
@@ -6,16 +6,13 @@ function TensorVectorProduct(ov, m, v) {
     let ncol = m.axis(-1) | 0;
     let nrow = m.axis(-2) | 0;
     let new_shape = m.shape.slice(); new_shape.pop();
-    let bs = ncol * nrow | 0;
-    let N = (m.size / bs) | 0;
+    let N = (m.size / ncol) | 0;
 
     let mw = m.w, vw = v.w, ow = ov.w;
     ow.fill(0.);
-    for (let z = 0; z < N; z++) {
-        for (let i = 0; i < nrow; i++) {
-            for (let j = 0; j < ncol; j++) {
-                ow[z * nrow + i] += mw[z * bs + i * ncol + j] * vw[j];
-            }
+    for (let i = 0; i < N; i++) {
+        for (let j = 0; j < ncol; j++) {
+            ow[i] += mw[i * ncol + j] * vw[j];
         }
     }
 }
@@ -49,7 +46,6 @@ function TransposedTensorVectorProductAdd(ov, m, v) {
 }
 
 
-
 /**
  * HadmardProduct apply to self
  */
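The rewritten kernel collapses every leading axis into N = m.size / ncol rows and contracts each row with v in a single doubly nested loop. A minimal sketch of the fixed loop on hypothetical flat arrays:

// ow[i] = sum_j mw[i * ncol + j] * vw[j], with m flattened row-major.
// Hypothetical 2x3 matrix (N = 2 rows, ncol = 3) times a length-3 vector.
let mw = [1, 2, 3,
          4, 5, 6];
let vw = [1, 0, 1];
let ncol = 3;
let N = mw.length / ncol;

let ow = new Float64Array(N); // zero-filled, like ow.fill(0.)
for (let i = 0; i < N; i++) {
    for (let j = 0; j < ncol; j++) {
        ow[i] += mw[i * ncol + j] * vw[j];
    }
}
console.log(Array.from(ow)); // [4, 10]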
2 changes: 1 addition & 1 deletion layers/layer.js
@@ -19,7 +19,7 @@ class Layer {
 
     compile(options) {
         // setup objects for training
-        this.updated.forEach(function(V) {
+        this.updated.forEach((V) => {
             V.dw = V.zeros_like();
             V.optimizer = get_optimizer(V.size, options);
             if (V.allow_regl) V.regularizer = new Regularization(options.l2_decay, options.l1_decay, this.l2_decay_mul, this.l1_decay_mul);
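The switch to an arrow function is a `this`-binding fix: a plain `function(V) { ... }` callback gets its own `this`, so `this.l2_decay_mul` and `this.l1_decay_mul` inside it would not refer to the layer, while an arrow function inherits `this` from the enclosing `compile`. A self-contained illustration (class and values are hypothetical):

class Demo {
    constructor() { this.scale = 2; }

    brokenScale(xs) {
        // Plain function callback: `this` is undefined inside it (strict
        // mode), so `this.scale` throws a TypeError when called.
        return xs.map(function (x) { return this.scale * x; });
    }

    workingScale(xs) {
        // Arrow function: `this` is the Demo instance, as in the fixed compile().
        return xs.map((x) => this.scale * x);
    }
}

console.log(new Demo().workingScale([1, 2, 3])); // [2, 4, 6]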
4 changes: 2 additions & 2 deletions objective.js
@@ -6,9 +6,9 @@
 function meanSquaredError(x, y) {
     let N = x.size;
     let loss = 0.;
-    let aw = a.w, yw = x.w, adw = a.dw;
+    let xw = x.w, yw = y.w, xdw = x.dw;
     for (let i = 0; i < N; i++) {
-        let dx = aw[i] - yw[i];
+        let dx = xw[i] - yw[i];
         xdw[i] += dx;
         loss += 0.5 * dx * dx;
     }
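The renamed locals make the fixed function compute loss = sum of 0.5 * (x[i] - y[i])^2 and accumulate the matching gradient x[i] - y[i] into x.dw; the old body referenced an undefined `a`, so it threw a ReferenceError when called. A standalone sketch with hypothetical plain-object stand-ins for Vol:

// Minimal stand-ins for Vol objects: w holds values, dw accumulates gradients.
let x = { size: 3, w: [1, 2, 3], dw: [0, 0, 0] };
let y = { size: 3, w: [1, 0, 3], dw: [0, 0, 0] };

let loss = 0.;
for (let i = 0; i < x.size; i++) {
    let dx = x.w[i] - y.w[i];
    x.dw[i] += dx;          // gradient of 0.5 * dx^2 with respect to x.w[i]
    loss += 0.5 * dx * dx;
}
console.log(loss, x.dw);    // 2, [0, 2, 0]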
2 changes: 1 addition & 1 deletion regularization.js
@@ -23,7 +23,7 @@ class Regularization {
             decay_loss += l1_decay * Math.abs(p);
             let l1grad = l1_decay * (p > 0 ? 1 : -1);
             let l2grad = l2_decay * (p);
-            dx[i] -= (l2grad + l1grad);
+            dx[i] += (l2grad + l1grad);
         }
         return decay_loss;
     }
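This one-character change is the critical bug from the commit title. `punish` accumulates the decay gradient into `dx`, and the optimizer afterwards steps against `dx`; with the old `-=` the decay gradient entered with the wrong sign, so every update pushed weights away from zero and they grew without bound. A toy sketch of the corrected behaviour, assuming the optimizer applies plain SGD (`p -= lr * dx`, an assumption about the surrounding code):

// One weight-decay step on a single weight p, no data gradient.
let p = 5.0, lr = 0.1, l2_decay = 0.5;
let dx = 0.;            // accumulated gradient

dx += l2_decay * p;     // fixed sign: gradient of 0.5 * l2 * p^2
p -= lr * dx;           // SGD step shrinks the weight toward zero
console.log(p);         // 4.75, smaller in magnitude than 5

// With the old `dx -= (l2grad + l1grad)`, the same step yields 5.25,
// and iterating multiplies |p| by 1.05 every update: the blow-up
// named in the commit message.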
2 changes: 1 addition & 1 deletion topology/vallia.js
@@ -80,7 +80,7 @@ class Net {
 
     // this is a convenience function for returning the argmax
     // return index of the class with highest class probability
-    get prediction(x) {
+    prediction(x) {
         if (typeof x !== 'undefined') this.forward(x);
         // assume output is a vector
         return this.output.max_index;
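`get prediction(x)` is invalid JavaScript, since accessor properties cannot declare parameters; the class would fail to parse with a SyntaxError. As an ordinary method it keeps the optional-input behaviour. A hedged usage sketch (`net` and `x` are hypothetical stand-ins):

// Hypothetical usage; `net` is a constructed Net and `x` an input Vol.
let action = net.prediction(x); // runs forward(x), then returns output.max_index
let repeat = net.prediction();  // no argument: reuses the previous forward pass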
4 changes: 3 additions & 1 deletion trainer.js
@@ -40,6 +40,8 @@ class Trainer {
         let updates = this.net.trainables;
         for (let i in updates) {
             let T = updates[i];
+
+
             if (T.regularizer) regular_loss += T.regularizer.punish(T);
             // make raw batch gradient
             T.batchGrad(this.batch_size);
@@ -51,7 +53,7 @@
 
         return {
             fwd_time: timer.getTime('forward'),
-            bwd_time: getTime('backward'),
+            bwd_time: timer.getTime('backward'),
 
             regular_loss: regular_loss,
             cost_loss: cost_loss,
16 changes: 8 additions & 8 deletions util/timing.js
@@ -2,25 +2,25 @@ class Timer {
     constructor() {
         this.lasttime = {};
         this.sum = {};
-        if (performance.now) {
-            this.get_time = performance.now;
-        } else {
-            this.get_time = new Date.now;
-        }
+        // if (performance.now) {
+        //     this.get_time = performance.now;
+        // } else {
+        //     this.get_time = new Date.now;
+        // }
     }
 
     start(name) {
         if (!this.sum[name]) this.sum[name]
         this.lastname = name;
-        lasttime[name] = this.get_time();
+        this.lasttime[name] = performance.now();
     }
 
     stop(name) {
-        this.sum[name] += this.get_time() - this.lasttime[name];
+        this.sum[name] += performance.now() - this.lasttime[name];
     }
 
     stoplast() {
-        this.sum[this.lastname] += this.get_time() - this.lasttime[this.lastname];
+        this.stop(this.lastname);
     }
 
     passto(name) {
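With the `get_time` indirection commented out, the class calls `performance.now()` directly. The old code was doubly broken: a detached `performance.now` throws an "Illegal invocation" when called without `performance` as its receiver, and `new Date.now` is not a valid constructor call in the first place. A hedged usage sketch, assuming `getTime(name)` (referenced by the trainer above) returns the accumulated `this.sum[name]`:

// Hypothetical usage; `net`, `x`, and `y` are stand-ins, and getTime(name)
// is assumed to return the accumulated milliseconds in this.sum[name].
let timer = new Timer();

timer.start('forward');
net.forward(x);
timer.stop('forward');

timer.start('backward');
net.backward(y);
timer.stop('backward');

console.log(timer.getTime('forward'), timer.getTime('backward'));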
2 changes: 1 addition & 1 deletion vol.js
@@ -28,7 +28,7 @@ class Vol {
             }
         }
 
-        // this.dw = this.zeros_like(); // -- save memory, allocmem at training?
+        this.dw = this.zeros_like(); // -- save memory, allocmem at training?
         this.length = this.size;
     }
