/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;

/**
 * An {@link Initializer} that fills arrays with samples from a truncated normal distribution.
 *
 * <p>Values are drawn from a normal distribution with mean {@code 0} and standard deviation
 * {@code sigma}. Any sample that does not lie strictly inside {@code (-2 * sigma, 2 * sigma)} is
 * rejected and redrawn.
 */
public class TruncatedNormalInitializer
implements Initializer {
    /** Maximum number of rejection-sampling rounds before giving up. */
    private static final int MAX_STEPS = 10;

    private final float sigma;

    /** Creates an initializer with a standard deviation of {@code 0.01}. */
    public TruncatedNormalInitializer() {
        this(0.01f);
    }

    /**
     * Creates an initializer with the given standard deviation.
     *
     * @param sigma the standard deviation of the underlying (untruncated) normal distribution;
     *     must be a positive, finite number
     * @throws IllegalArgumentException if {@code sigma} is not positive and finite
     */
    public TruncatedNormalInitializer(float sigma) {
        // A non-positive or NaN sigma leaves an empty acceptance interval (-2*sigma, 2*sigma), so
        // every initialize() call would exhaust its rounds and fail; reject it up front instead.
        if (!Float.isFinite(sigma) || sigma <= 0.0f) {
            throw new IllegalArgumentException(
                    "sigma must be a positive finite number, got: " + sigma);
        }
        this.sigma = sigma;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Samples are drawn in rounds. Values inside {@code (-2 * sigma, 2 * sigma)} are kept
     * until at least {@code shape.size()} have been collected. The surplus is then dropped and
     * the result is reshaped to {@code shape}. All intermediate arrays are created in a
     * sub-manager that is closed before this method returns, including when it throws.
     *
     * @throws IllegalArgumentException if the size of {@code shape} is undetermined
     * @throws IllegalStateException if the maximum number of sampling rounds is exhausted
     *     before enough values were accepted
     */
    @Override
    public NDArray initialize(NDManager baseManager, Shape shape, DataType dataType) {
        long size = shape.size();
        if (size < 0L) {
            throw new IllegalArgumentException("Shape is not determined.");
        }
        // try-with-resources guarantees the sub-manager (and every intermediate array it owns)
        // is released even if sampling throws.
        try (NDManager manager = baseManager.newSubManager()) {
            NDArray result;
            if (size == 0L) {
                // Nothing to sample; still honour the requested shape and data type.
                result = manager.zeros(shape, dataType);
            } else {
                result = sampleTruncated(manager, size, dataType)
                        .get(new NDIndex().addSliceDim(0L, size)) // drop surplus samples
                        .reshape(shape);
            }
            // Move the result out of the sub-manager before it is closed.
            result.attach(baseManager);
            return result;
        }
    }

    /**
     * Draws at least {@code size} accepted samples as a 1-D array owned by {@code manager}.
     *
     * @param manager the manager that owns all arrays created here
     * @param size the minimum number of accepted samples required; must be positive
     * @param dataType the data type of the sampled values
     * @return a 1-D array with at least {@code size} values inside {@code (-2 * sigma, 2 * sigma)}
     * @throws IllegalStateException if {@code MAX_STEPS} rounds are not enough
     */
    private NDArray sampleTruncated(NDManager manager, long size, DataType dataType) {
        // Scalar bounds avoid creating float32 arrays that may not match dataType.
        float lowerBound = -2.0f * this.sigma;
        float upperBound = 2.0f * this.sigma;
        NDArray accepted = null;
        long collected = 0L;
        for (int step = 0; collected < size; step++) {
            // Check before drawing so no round is wasted once the budget is spent.
            if (step >= MAX_STEPS) {
                throw new IllegalStateException("Initialization of truncated normal takes too long - This is incredibly unlikely, something must be seriously wrong.");
            }
            long remaining = size - collected;
            // About 4.6% of normal samples fall outside +-2 sigma, so request ~10% extra
            // (plus one, so that small remainders are oversampled as well).
            long samplesToCreate = remaining + remaining / 10L + 1L;
            NDArray normalDistribution = manager.randomNormal(
                    0.0f, this.sigma, new Shape(samplesToCreate), dataType, manager.getDevice());
            NDArray withinBounds = normalDistribution.gt(lowerBound)
                    .logicalAnd(normalDistribution.lt(upperBound));
            NDArray truncated = normalDistribution.get(withinBounds);
            // Start from the first sampled batch so the accumulator already has dataType.
            accepted = accepted == null ? truncated : accepted.concat(truncated);
            collected = accepted.size();
        }
        return accepted;
    }
}

