/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * A {@link NoBatchifyTranslator} for a SAM2-style, point-prompted image segmentation model.
 *
 * <p>The input is an image plus prompt points, with one integer label per point. The image is
 * resized to {@value #IMAGE_SIZE}x{@value #IMAGE_SIZE}, converted to a tensor and normalized.
 * The prompt points are rescaled into the same {@value #IMAGE_SIZE}-pixel coordinate space.
 *
 * <p>The output is a {@link DetectedObjects} holding exactly one full-image {@link Mask}. That
 * mask is the candidate with the highest model score, resized back to the original image size.
 */
public class Sam2Translator
        implements NoBatchifyTranslator<Sam2Translator.Sam2Input, DetectedObjects> {

    /** Side length, in pixels, of the square image fed to the model. */
    private static final int IMAGE_SIZE = 1024;

    private static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    private static final float[] STD = {0.229f, 0.224f, 0.225f};

    /** Context attachment keys used to carry the original image size to processOutput. */
    private static final String WIDTH_KEY = "width";

    private static final String HEIGHT_KEY = "height";

    private final Pipeline pipeline = new Pipeline();

    /** Creates a translator with a resize, to-tensor and normalize preprocessing pipeline. */
    public Sam2Translator() {
        pipeline.add(new Resize(IMAGE_SIZE, IMAGE_SIZE));
        pipeline.add(new ToTensor());
        pipeline.add(new Normalize(MEAN, STD));
    }

    /**
     * Converts a {@link Sam2Input} into the model's input tensors.
     *
     * <p>The original image width and height are stored as context attachments. This lets
     * {@link #processOutput} resize the predicted mask back to the source resolution.
     *
     * @param ctx the translator context
     * @param input the image together with its prompt points and labels
     * @return an {@link NDList} of {@code [image, locations, labels]}, where:
     *     <ul>
     *       <li>{@code image} is the preprocessed image with a leading batch axis added;
     *       <li>{@code locations} holds the prompt points scaled to the {@value #IMAGE_SIZE}-pixel
     *           space, with shape {@code (1, numPoints, 2)};
     *       <li>{@code labels} is a single-row int tensor of the point labels.
     *     </ul>
     * @throws Exception if image conversion or tensor creation fails
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Exception {
        Image image = input.getImage();
        int width = image.getWidth();
        int height = image.getHeight();
        ctx.setAttachment(WIDTH_KEY, width);
        ctx.setAttachment(HEIGHT_KEY, height);

        int numPoints = input.getPoints().size();
        float[] buf = input.toLocationArray(width, height);

        NDManager manager = ctx.getNDManager();
        NDArray array = image.toNDArray(manager, Image.Flag.COLOR);
        // The pipeline transforms a single image; expandDims(0) adds the batch axis.
        array = pipeline.transform(new NDList(array)).get(0).expandDims(0);

        NDArray locations = manager.create(buf, new Shape(1, numPoints, 2));
        NDArray labels = manager.create(input.getLabels());
        return new NDList(array, locations, labels);
    }

    /**
     * Selects the highest-scoring predicted mask and returns it as a full-image {@link Mask}.
     *
     * @param ctx the translator context that carries the original image width and height
     * @param list the model outputs: {@code list.get(0)} is the mask logits and
     *     {@code list.get(1)} is the per-mask scores
     * @return a {@link DetectedObjects} with a single entry, whose class name is empty, holding
     *     the best mask and its score
     * @throws Exception if a tensor operation fails
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception {
        NDArray logits = list.get(0);
        // NOTE(review): assumes scores is (1, numMasks) and logits is
        // (1, numMasks, h, w) -- confirm against the exported model.
        NDArray scores = list.get(1).squeeze(0);
        long best = scores.argMax().getLong();

        int width = (Integer) ctx.getAttachment(WIDTH_KEY);
        int height = (Integer) ctx.getAttachment(HEIGHT_KEY);

        // Resize the logits to the original image resolution.
        long[] size = {height, width};
        // The engine-level interpolation API identifies the mode by the enum's ordinal.
        int mode = Image.Interpolation.BILINEAR.ordinal();
        // NOTE(review): the trailing boolean is presumably align-corners -- confirm.
        logits = logits.getNDArrayInternal().interpolation(size, mode, false);
        // Threshold the logits at 0 to get binary masks, then drop the leading axis.
        NDArray masks = logits.gt(0f).squeeze(0);

        float[][] dist = Mask.toMask(masks.get(best).toType(DataType.FLOAT32, true));
        Mask mask = new Mask(0, 0, width, height, dist, true);
        double probability = scores.getFloat(best);

        List<String> classes = Collections.singletonList("");
        List<Double> probabilities = Collections.singletonList(probability);
        List<BoundingBox> boxes = Collections.singletonList(mask);
        return new DetectedObjects(classes, probabilities, boxes);
    }

    /** The input to {@link Sam2Translator}: an image plus prompt points and per-point labels. */
    public static final class Sam2Input {

        private final Image image;
        private final List<Point> points;
        private final List<Integer> labels;

        /**
         * Constructs a {@code Sam2Input}.
         *
         * <p>The lists are stored by reference, not copied. Callers must not resize them
         * afterwards.
         *
         * @param image the image to segment
         * @param points the prompt points, in pixel coordinates of {@code image}
         * @param labels one label per prompt point. The {@code newInstance} factories use
         *     {@code 1}; its exact meaning is defined by the model.
         * @throws NullPointerException if any argument is {@code null}
         * @throws IllegalArgumentException if {@code points} and {@code labels} differ in size
         */
        public Sam2Input(Image image, List<Point> points, List<Integer> labels) {
            this.image = Objects.requireNonNull(image, "image");
            this.points = Objects.requireNonNull(points, "points");
            this.labels = Objects.requireNonNull(labels, "labels");
            // The locations and labels tensors must describe the same set of points.
            if (points.size() != labels.size()) {
                throw new IllegalArgumentException(
                        "Expected one label per point, got "
                                + points.size()
                                + " points and "
                                + labels.size()
                                + " labels");
            }
        }

        /**
         * Returns the image to segment.
         *
         * @return the image
         */
        public Image getImage() {
            return image;
        }

        /**
         * Returns the prompt points, in pixel coordinates of the image.
         *
         * @return the prompt points
         */
        public List<Point> getPoints() {
            return points;
        }

        /**
         * Flattens the prompt points into {@code [x0, y0, x1, y1, ...]}.
         *
         * <p>Each coordinate is rescaled from the original image size to the
         * {@value Sam2Translator#IMAGE_SIZE}-pixel model input size.
         *
         * @param width the original image width
         * @param height the original image height
         * @return the scaled coordinates, two per point
         */
        float[] toLocationArray(int width, int height) {
            float[] ret = new float[points.size() * 2];
            int i = 0;
            for (Point point : points) {
                ret[i++] = (float) point.getX() / width * IMAGE_SIZE;
                ret[i++] = (float) point.getY() / height * IMAGE_SIZE;
            }
            return ret;
        }

        /**
         * Returns the labels as a single-row 2-D array, which gives the label tensor its leading
         * batch axis.
         *
         * @return a {@code 1 x labels.size()} array of labels
         */
        int[][] getLabels() {
            return new int[][] {labels.stream().mapToInt(Integer::intValue).toArray()};
        }

        /**
         * Creates an input from an image URL and a single prompt point labelled {@code 1}.
         *
         * @param url the image URL
         * @param x the point's x coordinate, in pixels
         * @param y the point's y coordinate, in pixels
         * @return a new {@code Sam2Input}
         * @throws IOException if the image cannot be loaded
         */
        public static Sam2Input newInstance(String url, int x, int y) throws IOException {
            return ofSinglePoint(ImageFactory.getInstance().fromUrl(url), x, y);
        }

        /**
         * Creates an input from an image file and a single prompt point labelled {@code 1}.
         *
         * @param path the image file
         * @param x the point's x coordinate, in pixels
         * @param y the point's y coordinate, in pixels
         * @return a new {@code Sam2Input}
         * @throws IOException if the image cannot be loaded
         */
        public static Sam2Input newInstance(Path path, int x, int y) throws IOException {
            return ofSinglePoint(ImageFactory.getInstance().fromFile(path), x, y);
        }

        /** Builds an input with one prompt point at (x, y) carrying label {@code 1}. */
        private static Sam2Input ofSinglePoint(Image image, int x, int y) {
            List<Point> points = Collections.singletonList(new Point(x, y));
            List<Integer> labels = Collections.singletonList(1);
            return new Sam2Input(image, points, labels);
        }
    }
}

