/* IMPORT */
import Buffer from '../buffer.js';
import AbstractHidden from './abstract_hidden.js';
/* MAIN */
/**
 * Dropout layer.
 *
 * During training every activation is independently zeroed with probability
 * `probability`. During prediction nothing is dropped; activations are scaled
 * by the keep probability (`1 - probability`) so their magnitude matches the
 * expected value seen during training.
 */
class Dropout extends AbstractHidden {
    /* CONSTRUCTOR */
    /**
     * @param {object} options - Layer options. `options.probability` is the
     *   chance of dropping each activation; it defaults to 0.5 when nullish.
     * @param {*} prev - Forwarded unchanged to the AbstractHidden constructor.
     */
    constructor(options, prev) {
        super(options, prev);
        this.probability = options.probability ?? 0.5;
        // One flag per activation: 1 when the latest training pass dropped it, 0 otherwise.
        // NOTE(review): osx/osy/osz are presumably set by AbstractHidden - confirm.
        this.dropped = new Buffer(this.osx * this.osy * this.osz);
    }
    /* API */
    /**
     * Runs the layer on a clone of `input`, leaving `input` itself unmodified.
     *
     * @param {object} input - Tensor-like value exposing `length`, `w` and `clone()`.
     * @param {boolean} isTraining - When truthy, activations are randomly dropped
     *   and the drop mask is recorded; otherwise activations are rescaled.
     * @returns {object} The cloned tensor holding the layer output (also stored in `this.ot`).
     */
    forward(input, isTraining) {
        this.it = input;
        const output = input.clone();
        const weights = output.w;
        const count = input.length;
        if (isTraining) {
            const dropProbability = this.probability;
            const mask = this.dropped;
            for (let i = 0; i < count; i++) {
                if (Math.random() < dropProbability) {
                    weights[i] = 0;
                    mask[i] = 1;
                }
                else {
                    mask[i] = 0;
                }
            }
        }
        else {
            // `probability` is the DROP probability, so a unit survives training with
            // probability (1 - probability); scale by that to match the training-time expectation.
            const keepProbability = 1 - this.probability;
            for (let i = 0; i < count; i++) {
                weights[i] *= keepProbability;
            }
        }
        this.ot = output;
        return output;
    }
    /**
     * Propagates gradients from the last output back to the last input.
     *
     * Replaces `input.dw` with a fresh buffer and copies `output.dw[i]` into it
     * for every position whose drop flag is not set; flagged positions keep the
     * buffer's initial value (presumably zero - depends on Buffer).
     */
    backward() {
        const input = this.it;
        const output = this.ot;
        const count = input.length;
        const mask = this.dropped;
        const gradients = new Buffer(count);
        input.dw = gradients;
        for (let i = 0; i < count; i++) {
            if (!mask[i]) {
                gradients[i] = output.dw[i];
            }
        }
    }
}
/* EXPORT */
export default Dropout;
