/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.jetty.ee10.websocket.jakarta.common;

import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Decoder;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.MessageHandler;
import jakarta.websocket.PongMessage;
import jakarta.websocket.server.ServerEndpointConfig;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.eclipse.jetty.ee10.websocket.jakarta.common.ClientEndpointConfigWrapper;
import org.eclipse.jetty.ee10.websocket.jakarta.common.ConfiguredEndpoint;
import org.eclipse.jetty.ee10.websocket.jakarta.common.EndpointConfigWrapper;
import org.eclipse.jetty.ee10.websocket.jakarta.common.JakartaWebSocketContainer;
import org.eclipse.jetty.ee10.websocket.jakarta.common.JakartaWebSocketFrameHandlerFactory;
import org.eclipse.jetty.ee10.websocket.jakarta.common.JakartaWebSocketMessageMetadata;
import org.eclipse.jetty.ee10.websocket.jakarta.common.JakartaWebSocketPongMessage;
import org.eclipse.jetty.ee10.websocket.jakarta.common.JakartaWebSocketSession;
import org.eclipse.jetty.ee10.websocket.jakarta.common.PutListenerMap;
import org.eclipse.jetty.ee10.websocket.jakarta.common.RegisteredMessageHandler;
import org.eclipse.jetty.ee10.websocket.jakarta.common.ServerEndpointConfigWrapper;
import org.eclipse.jetty.ee10.websocket.jakarta.common.UpgradeRequest;
import org.eclipse.jetty.ee10.websocket.jakarta.common.decoders.AvailableDecoders;
import org.eclipse.jetty.ee10.websocket.jakarta.common.decoders.RegisteredDecoder;
import org.eclipse.jetty.ee10.websocket.jakarta.common.messages.DecodedBinaryMessageSink;
import org.eclipse.jetty.ee10.websocket.jakarta.common.messages.DecodedBinaryStreamMessageSink;
import org.eclipse.jetty.ee10.websocket.jakarta.common.messages.DecodedTextMessageSink;
import org.eclipse.jetty.ee10.websocket.jakarta.common.messages.DecodedTextStreamMessageSink;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.thread.AutoLock;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
import org.eclipse.jetty.websocket.core.messages.MessageSink;
import org.eclipse.jetty.websocket.core.messages.PartialByteArrayMessageSink;
import org.eclipse.jetty.websocket.core.messages.PartialByteBufferMessageSink;
import org.eclipse.jetty.websocket.core.messages.PartialStringMessageSink;
import org.eclipse.jetty.websocket.core.util.InvokerUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Jetty core {@link FrameHandler} that drives a Jakarta WebSocket endpoint instance.
 *
 * <p>The endpoint's open, close, error and pong methods are supplied as {@link MethodHandle}s and are
 * bound to the {@link JakartaWebSocketSession} created in {@link #onOpen(CoreSession, Callback)}.
 * TEXT and BINARY data frames (and their continuation frames) are routed to the {@link MessageSink}
 * for that message type, built either from the metadata passed to the constructor or by
 * {@code addMessageHandler}.
 *
 * <p>Inside this class, every read and write of the registered-handler map happens under an internal
 * lock ({@link #getMessageHandlerMap()} exposes the map without it). Other state is not synchronized
 * by this class.
 */
public class JakartaWebSocketFrameHandler implements FrameHandler {
    private final AutoLock lock = new AutoLock();
    private final Logger logger;
    private final JakartaWebSocketContainer container;
    private final Object endpointInstance;
    /** Ensures the endpoint's close method is invoked at most once. */
    private final AtomicBoolean closeNotified = new AtomicBoolean();
    private MethodHandle openHandle;
    private MethodHandle closeHandle;
    private MethodHandle errorHandle;
    private MethodHandle pongHandle;
    private JakartaWebSocketMessageMetadata textMetadata;
    private JakartaWebSocketMessageMetadata binaryMetadata;
    private final UpgradeRequest upgradeRequest;
    private EndpointConfig endpointConfig;
    /** Handlers added through addMessageHandler, keyed by basic opcode (TEXT, BINARY or PONG); guarded by {@link #lock}. */
    private final Map<Byte, RegisteredMessageHandler> messageHandlerMap = new HashMap<>();
    private MessageSink textSink;
    private MessageSink binarySink;
    /** Sink for the data message currently being received; cleared once its final frame is accepted. */
    private MessageSink activeMessageSink;
    private JakartaWebSocketSession session;
    private CoreSession coreSession;
    /** Opcode (TEXT or BINARY) of the data message in progress, used to route continuation frames. */
    protected byte dataType = OpCode.UNDEFINED;

    /**
     * Creates a frame handler for an (already unwrapped) endpoint instance.
     *
     * @param container the owning container, used to notify session listeners
     * @param upgradeRequest the upgrade request that created this connection
     * @param endpointInstance the endpoint object; must not be a {@link ConfiguredEndpoint}
     * @param openHandle handle for the endpoint's open method, or {@code null}
     * @param closeHandle handle for the endpoint's close method, or {@code null}
     * @param errorHandle handle for the endpoint's error method, or {@code null}
     * @param textMetadata metadata describing the TEXT message method, or {@code null}
     * @param binaryMetadata metadata describing the BINARY message method, or {@code null}
     * @param pongHandle handle for the endpoint's pong method, or {@code null}
     * @param endpointConfig the endpoint configuration
     * @throws RuntimeException if {@code endpointInstance} is still a {@link ConfiguredEndpoint}
     */
    public JakartaWebSocketFrameHandler(JakartaWebSocketContainer container, UpgradeRequest upgradeRequest, Object endpointInstance, MethodHandle openHandle, MethodHandle closeHandle, MethodHandle errorHandle, JakartaWebSocketMessageMetadata textMetadata, JakartaWebSocketMessageMetadata binaryMetadata, MethodHandle pongHandle, EndpointConfig endpointConfig) {
        this.logger = LoggerFactory.getLogger(endpointInstance.getClass());
        this.container = container;
        this.upgradeRequest = upgradeRequest;
        if (endpointInstance instanceof ConfiguredEndpoint) {
            RuntimeException oops = new RuntimeException("ConfiguredEndpoint needs to be unwrapped");
            this.logger.warn("Unexpected ConfiguredEndpoint", oops);
            throw oops;
        }
        this.endpointInstance = endpointInstance;
        this.openHandle = openHandle;
        this.closeHandle = closeHandle;
        this.errorHandle = errorHandle;
        this.textMetadata = textMetadata;
        this.binaryMetadata = binaryMetadata;
        this.pongHandle = pongHandle;
        this.endpointConfig = endpointConfig;
    }

    /** @return the endpoint instance this handler dispatches to */
    public Object getEndpoint() {
        return endpointInstance;
    }

    /** @return the endpoint configuration (the wrapped configuration once {@code onOpen} has run) */
    public EndpointConfig getEndpointConfig() {
        return endpointConfig;
    }

    /** @return the Jakarta session, or {@code null} until {@code onOpen} has created it */
    public JakartaWebSocketSession getSession() {
        return session;
    }

    /**
     * Creates the {@link JakartaWebSocketSession}, binds the endpoint method handles to it, builds the
     * TEXT/BINARY message sinks from the endpoint metadata, and invokes the endpoint's open method.
     *
     * <p>On success, session listeners are notified (if the session is still open), the callback is
     * succeeded and {@code coreSession.demand()} is called. Any failure is wrapped in a
     * {@link WebSocketException} and reported through {@code callback.failed}.
     */
    @Override
    public void onOpen(CoreSession coreSession, Callback callback) {
        this.coreSession = coreSession;
        try {
            // Wrap the config so that writes of org.eclipse.jetty.websocket.* user properties are
            // forwarded to configListener.
            endpointConfig = getWrappedEndpointConfig();
            session = new JakartaWebSocketSession(container, coreSession, this, endpointConfig);
            if (!session.isOpen()) {
                throw new IllegalStateException("Session is not open");
            }

            openHandle = InvokerUtils.bindTo(openHandle, session, endpointConfig);
            closeHandle = InvokerUtils.bindTo(closeHandle, session);
            errorHandle = InvokerUtils.bindTo(errorHandle, session);
            pongHandle = InvokerUtils.bindTo(pongHandle, session);

            // NOTE(review): copyOf is expected to return a copy (null when no metadata was supplied),
            // so binding below does not mutate the metadata passed to the constructor.
            JakartaWebSocketMessageMetadata actualTextMetadata = JakartaWebSocketMessageMetadata.copyOf(textMetadata);
            if (actualTextMetadata != null) {
                if (actualTextMetadata.isMaxMessageSizeSet()) {
                    session.setMaxTextMessageBufferSize(actualTextMetadata.getMaxMessageSize());
                }
                MethodHandle textHandle = InvokerUtils.bindTo(actualTextMetadata.getMethodHandle(), endpointInstance, endpointConfig, session);
                textHandle = JakartaWebSocketFrameHandlerFactory.wrapNonVoidReturnType(textHandle, session);
                actualTextMetadata.setMethodHandle(textHandle);
                textSink = JakartaWebSocketFrameHandlerFactory.createMessageSink(session, actualTextMetadata);
                textMetadata = actualTextMetadata;
            }

            JakartaWebSocketMessageMetadata actualBinaryMetadata = JakartaWebSocketMessageMetadata.copyOf(binaryMetadata);
            if (actualBinaryMetadata != null) {
                if (actualBinaryMetadata.isMaxMessageSizeSet()) {
                    session.setMaxBinaryMessageBufferSize(actualBinaryMetadata.getMaxMessageSize());
                }
                MethodHandle binaryHandle = InvokerUtils.bindTo(actualBinaryMetadata.getMethodHandle(), endpointInstance, endpointConfig, session);
                binaryHandle = JakartaWebSocketFrameHandlerFactory.wrapNonVoidReturnType(binaryHandle, session);
                actualBinaryMetadata.setMethodHandle(binaryHandle);
                binarySink = JakartaWebSocketFrameHandlerFactory.createMessageSink(session, actualBinaryMetadata);
                binaryMetadata = actualBinaryMetadata;
            }

            if (openHandle != null) {
                openHandle.invoke();
            }
            // The open method may have closed the session; only announce sessions that are still open.
            if (session.isOpen()) {
                container.notifySessionListeners(listener -> listener.onJakartaWebSocketSessionOpened(session));
            }
            callback.succeeded();
            coreSession.demand();
        } catch (Throwable cause) {
            WebSocketException wse = new WebSocketException(endpointInstance.getClass().getSimpleName() + " OPEN method error: " + cause.getMessage(), cause);
            callback.failed(wse);
        }
    }

    /**
     * Wraps the current endpoint configuration in a wrapper of the matching kind (server, client or
     * plain). The wrapper's user-properties map is a {@link PutListenerMap} over the original user
     * properties with {@link #configListener(String, Object)} as its listener.
     *
     * @return the wrapped configuration
     */
    private EndpointConfig getWrappedEndpointConfig() {
        final PutListenerMap listenerMap = new PutListenerMap(endpointConfig.getUserProperties(), this::configListener);

        if (endpointConfig instanceof ServerEndpointConfig) {
            return new ServerEndpointConfigWrapper((ServerEndpointConfig) endpointConfig) {
                @Override
                public Map<String, Object> getUserProperties() {
                    return listenerMap;
                }
            };
        }

        if (endpointConfig instanceof ClientEndpointConfig) {
            return new ClientEndpointConfigWrapper((ClientEndpointConfig) endpointConfig) {
                @Override
                public Map<String, Object> getUserProperties() {
                    return listenerMap;
                }
            };
        }

        return new EndpointConfigWrapper(endpointConfig) {
            @Override
            public Map<String, Object> getUserProperties() {
                return listenerMap;
            }
        };
    }

    /**
     * Dispatches a frame by opcode. TEXT and BINARY frames also record {@link #dataType} so that
     * following CONTINUATION frames are routed to the same kind of sink. The record is reset after a
     * final (FIN) data frame.
     */
    @Override
    public void onFrame(Frame frame, Callback callback) {
        switch (frame.getOpCode()) {
            case OpCode.TEXT -> {
                dataType = OpCode.TEXT;
                onText(frame, callback);
            }
            case OpCode.BINARY -> {
                dataType = OpCode.BINARY;
                onBinary(frame, callback);
            }
            case OpCode.CONTINUATION -> onContinuation(frame, callback);
            case OpCode.PING -> onPing(frame, callback);
            case OpCode.PONG -> onPong(frame, callback);
            case OpCode.CLOSE -> onClose(frame, callback);
            default -> callback.failed(new IllegalStateException("Unsupported opcode " + OpCode.name(frame.getOpCode())));
        }
        if (frame.isFin() && !frame.isControlFrame()) {
            dataType = OpCode.UNDEFINED;
        }
    }

    /** Handles a received CLOSE frame by notifying the endpoint's close method with the frame's status. */
    public void onClose(Frame frame, Callback callback) {
        notifyOnClose(CloseStatus.getCloseStatus(frame), callback);
    }

    /**
     * Called when the connection has closed. It notifies the endpoint's close method (if not already
     * done), notifies session listeners, and closes the session's decoders and encoders.
     */
    @Override
    public void onClosed(CloseStatus closeStatus, Callback callback) {
        notifyOnClose(closeStatus, callback);
        // session is still null if onOpen failed before creating it; there is nothing more to release.
        if (session == null) {
            return;
        }
        container.notifySessionListeners(listener -> listener.onJakartaWebSocketSessionClosed(session));
        session.getDecoders().close();
        session.getEncoders().close();
    }

    /**
     * Invokes the endpoint's close method at most once. Later calls only succeed the callback. If the
     * close method throws (including an unmappable close code), the callback fails with a
     * {@link WebSocketException}.
     */
    private void notifyOnClose(CloseStatus closeStatus, Callback callback) {
        if (!closeNotified.compareAndSet(false, true)) {
            callback.succeeded();
            return;
        }
        try {
            if (closeHandle != null) {
                CloseReason closeReason = new CloseReason(CloseReason.CloseCodes.getCloseCode(closeStatus.getCode()), closeStatus.getReason());
                closeHandle.invoke(closeReason);
            }
            callback.succeeded();
        } catch (Throwable cause) {
            callback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " CLOSE method error: " + cause.getMessage(), cause));
        }
    }

    /**
     * Delivers an error to the endpoint's error method, or logs it as a warning when there is none.
     * If the error method itself throws, the callback fails with a {@link WebSocketException}. That
     * exception is caused by the thrown error and has the original {@code cause} as a suppressed
     * exception.
     */
    @Override
    public void onError(Throwable cause, Callback callback) {
        try {
            if (errorHandle != null) {
                errorHandle.invoke(cause);
            } else {
                logger.warn("Unhandled Error: {}", endpointInstance, cause);
            }
            callback.succeeded();
        } catch (Throwable t) {
            WebSocketException wsError = new WebSocketException(endpointInstance.getClass().getSimpleName() + " ERROR method error: " + cause.getMessage(), t);
            wsError.addSuppressed(cause);
            callback.failed(wsError);
        }
    }

    /** @return an unmodifiable snapshot of the handlers registered via {@code addMessageHandler} */
    public Set<MessageHandler> getMessageHandlers() {
        try (AutoLock l = lock.lock()) {
            return messageHandlerMap.values().stream()
                .map(RegisteredMessageHandler::getMessageHandler)
                .collect(Collectors.toUnmodifiableSet());
        }
    }

    /** @return the live (mutable, unsynchronized) map of registered handlers keyed by basic opcode */
    public Map<Byte, RegisteredMessageHandler> getMessageHandlerMap() {
        return messageHandlerMap;
    }

    /** @return the metadata of the current BINARY message handler, or {@code null} */
    public JakartaWebSocketMessageMetadata getBinaryMetadata() {
        return binaryMetadata;
    }

    /** @return the metadata of the current TEXT message handler, or {@code null} */
    public JakartaWebSocketMessageMetadata getTextMetadata() {
        return textMetadata;
    }

    /**
     * Registers a partial-message handler. {@code byte[]} and {@link ByteBuffer} handlers receive
     * BINARY messages; {@link String} handlers receive TEXT messages.
     *
     * @throws RuntimeException if {@code clazz} is not one of the supported types
     * @throws IllegalStateException if a handler for the same basic message type is already registered,
     *     or the handler's {@code onMessage} method cannot be looked up
     */
    public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Partial<T> handler) {
        try {
            MethodHandle methodHandle = JakartaWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Partial.class, "onMessage", MethodType.methodType(void.class, Object.class, boolean.class))
                .bindTo(handler);
            JakartaWebSocketMessageMetadata metadata = new JakartaWebSocketMessageMetadata();
            metadata.setMethodHandle(methodHandle);

            byte basicType;
            if (byte[].class.isAssignableFrom(clazz)) {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteArrayMessageSink.class);
            } else if (ByteBuffer.class.isAssignableFrom(clazz)) {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteBufferMessageSink.class);
            } else if (String.class.isAssignableFrom(clazz)) {
                basicType = OpCode.TEXT;
                metadata.setSinkClass(PartialStringMessageSink.class);
            } else {
                throw new RuntimeException("Unable to add " + handler.getClass().getName() + " with type " + clazz + ": only supported types byte[], " + ByteBuffer.class.getName() + ", " + String.class.getName());
            }
            registerMessageHandler(clazz, handler, basicType, metadata);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException("Unable to find method", e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to access " + handler.getClass().getName(), e);
        }
    }

    /**
     * Registers a whole-message handler. A {@link PongMessage} handler becomes the pong handler. For
     * any other type, the first registered decoder for {@code clazz} is used, and the Decoder
     * interface it implements (Binary, BinaryStream, Text, TextStream) selects whether the handler
     * receives BINARY or TEXT messages.
     *
     * @throws IllegalStateException if no decoder exists for {@code clazz}, a handler for the same
     *     basic message type is already registered, or {@code onMessage} cannot be looked up
     * @throws RuntimeException if the decoder implements none of the recognized Decoder interfaces
     */
    public <T> void addMessageHandler(Class<T> clazz, MessageHandler.Whole<T> handler) {
        try {
            MethodHandle methodHandle = JakartaWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class))
                .bindTo(handler);

            if (PongMessage.class.isAssignableFrom(clazz)) {
                assertBasicTypeNotRegistered(OpCode.PONG, handler);
                // Register first so a failed registration leaves pongHandle untouched.
                registerMessageHandler(OpCode.PONG, clazz, handler, null);
                pongHandle = methodHandle;
                return;
            }

            AvailableDecoders availableDecoders = session.getDecoders();
            RegisteredDecoder registeredDecoder = availableDecoders.getFirstRegisteredDecoder(clazz);
            if (registeredDecoder == null) {
                throw new IllegalStateException("Unable to find Decoder for type: " + clazz);
            }

            JakartaWebSocketMessageMetadata metadata = new JakartaWebSocketMessageMetadata();
            metadata.setMethodHandle(methodHandle);

            byte basicType;
            if (registeredDecoder.implementsInterface(Decoder.Binary.class)) {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(availableDecoders.getBinaryDecoders(clazz));
                metadata.setSinkClass(DecodedBinaryMessageSink.class);
            } else if (registeredDecoder.implementsInterface(Decoder.BinaryStream.class)) {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(availableDecoders.getBinaryStreamDecoders(clazz));
                metadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
            } else if (registeredDecoder.implementsInterface(Decoder.Text.class)) {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(availableDecoders.getTextDecoders(clazz));
                metadata.setSinkClass(DecodedTextMessageSink.class);
            } else if (registeredDecoder.implementsInterface(Decoder.TextStream.class)) {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(availableDecoders.getTextStreamDecoders(clazz));
                metadata.setSinkClass(DecodedTextStreamMessageSink.class);
            } else {
                throw new RuntimeException("Unable to add " + handler.getClass().getName() + ": type " + clazz + " is unrecognized by declared decoders");
            }
            registerMessageHandler(clazz, handler, basicType, metadata);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException("Unable to find method", e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to access " + handler.getClass().getName(), e);
        }
    }

    /**
     * Fails if a sink (TEXT/BINARY) or pong handle (PONG) is already set for the given basic type.
     *
     * @throws IllegalStateException if the type is already registered, or is not TEXT, BINARY or PONG
     */
    private void assertBasicTypeNotRegistered(byte basicWebSocketType, MessageHandler replacement) {
        boolean registered = switch (basicWebSocketType) {
            case OpCode.TEXT -> textSink != null;
            case OpCode.BINARY -> binarySink != null;
            case OpCode.PONG -> pongHandle != null;
            default -> throw new IllegalStateException();
        };
        if (registered) {
            throw new IllegalStateException("Cannot register " + replacement.getClass().getName() + ": Basic WebSocket type " + OpCode.name(basicWebSocketType) + " is already registered");
        }
    }

    /**
     * Creates the message sink for {@code metadata} and installs it, and the metadata, as the TEXT or
     * BINARY handler.
     *
     * @throws IllegalStateException if the type is already registered or is neither TEXT nor BINARY
     */
    private void registerMessageHandler(Class<?> clazz, MessageHandler handler, byte basicMessageType, JakartaWebSocketMessageMetadata metadata) {
        assertBasicTypeNotRegistered(basicMessageType, handler);
        MessageSink messageSink = JakartaWebSocketFrameHandlerFactory.createMessageSink(session, metadata);
        switch (basicMessageType) {
            case OpCode.TEXT -> {
                textSink = registerMessageHandler(OpCode.TEXT, clazz, handler, messageSink);
                textMetadata = metadata;
            }
            case OpCode.BINARY -> {
                binarySink = registerMessageHandler(OpCode.BINARY, clazz, handler, messageSink);
                binaryMetadata = metadata;
            }
            default -> throw new IllegalStateException();
        }
    }

    /**
     * Records {@code handler} in the handler map under the lock.
     *
     * @return {@code messageSink}, unchanged, for the caller to install
     * @throws IllegalStateException if a handler is already recorded for the basic message type
     */
    private <T> MessageSink registerMessageHandler(byte basicWebSocketMessageType, Class<T> handlerType, MessageHandler handler, MessageSink messageSink) {
        try (AutoLock l = lock.lock()) {
            RegisteredMessageHandler existing = messageHandlerMap.get(basicWebSocketMessageType);
            if (existing != null) {
                throw new IllegalStateException(String.format("Cannot register %s: Basic WebSocket type %s is already registered to %s", handler.getClass().getName(), OpCode.name(basicWebSocketMessageType), existing.getMessageHandler().getClass().getName()));
            }
            RegisteredMessageHandler registeredHandler = new RegisteredMessageHandler(basicWebSocketMessageType, handlerType, handler);
            messageHandlerMap.put(registeredHandler.getWebsocketMessageType(), registeredHandler);
            return messageSink;
        }
    }

    /**
     * Removes {@code handler}, if registered, and clears the sink/metadata (or pong handle) for its
     * basic message type. Does nothing when the handler is not registered.
     *
     * @throws IllegalStateException if the recorded type is not TEXT, BINARY or PONG
     */
    public void removeMessageHandler(MessageHandler handler) {
        try (AutoLock l = lock.lock()) {
            Optional<Map.Entry<Byte, RegisteredMessageHandler>> optionalEntry = messageHandlerMap.entrySet().stream()
                .filter(entry -> entry.getValue().getMessageHandler().equals(handler))
                .findFirst();
            if (optionalEntry.isEmpty()) {
                return;
            }

            byte key = optionalEntry.get().getKey();
            messageHandlerMap.remove(key);
            switch (key) {
                case OpCode.PONG -> pongHandle = null;
                case OpCode.TEXT -> {
                    textMetadata = null;
                    textSink = null;
                }
                case OpCode.BINARY -> {
                    binaryMetadata = null;
                    binarySink = null;
                }
                default -> throw new IllegalStateException("Invalid MessageHandler type " + OpCode.name(key));
            }
        }
    }

    @Override
    public String toString() {
        String endpointName = (endpointInstance == null) ? "<null>" : endpointInstance.getClass().getName();
        return getClass().getSimpleName() + '@' + Integer.toHexString(hashCode()) + "[endpoint=" + endpointName + ']';
    }

    /**
     * Passes a data frame to the active sink. The active sink is cleared on the final frame. When
     * there is no sink, the frame is dropped: the callback is succeeded and more data is demanded.
     */
    private void acceptMessage(Frame frame, Callback callback) {
        MessageSink messageSink = activeMessageSink;
        if (messageSink == null) {
            callback.succeeded();
            coreSession.demand();
            return;
        }
        if (frame.isFin()) {
            activeMessageSink = null;
        }
        messageSink.accept(frame, callback);
    }

    /**
     * Answers a PING with a PONG carrying the same payload. When the send completes, the callback is
     * succeeded and demand is issued, whether or not the send succeeded.
     */
    public void onPing(Frame frame, Callback callback) {
        coreSession.sendFrame(new Frame(OpCode.PONG).setPayload(frame.getPayload()), Callback.from(() -> {
            callback.succeeded();
            coreSession.demand();
        }, x -> {
            // A failed PONG send is deliberately not reported as a failure of the PING frame.
            callback.succeeded();
            coreSession.demand();
        }), false);
    }

    /**
     * Delivers a PONG frame to the pong handle as a {@link JakartaWebSocketPongMessage}, using an empty
     * buffer when the frame has no payload. Without a pong handle, the frame is dropped.
     */
    public void onPong(Frame frame, Callback callback) {
        if (pongHandle == null) {
            callback.succeeded();
            coreSession.demand();
            return;
        }
        try {
            ByteBuffer payload = frame.getPayload();
            if (payload == null) {
                payload = BufferUtil.EMPTY_BUFFER;
            }
            pongHandle.invoke(new JakartaWebSocketPongMessage(payload));
            callback.succeeded();
            coreSession.demand();
        } catch (Throwable cause) {
            callback.failed(new WebSocketException(endpointInstance.getClass().getSimpleName() + " PONG method error: " + cause.getMessage(), cause));
        }
    }

    /** Routes a TEXT (or TEXT continuation) frame; a new message starts on the text sink. */
    public void onText(Frame frame, Callback callback) {
        if (activeMessageSink == null) {
            activeMessageSink = textSink;
        }
        acceptMessage(frame, callback);
    }

    /** Routes a BINARY (or BINARY continuation) frame; a new message starts on the binary sink. */
    public void onBinary(Frame frame, Callback callback) {
        if (activeMessageSink == null) {
            activeMessageSink = binarySink;
        }
        acceptMessage(frame, callback);
    }

    /**
     * Routes a CONTINUATION frame according to {@link #dataType}. If no data message is in progress,
     * the callback fails with a {@link ProtocolException}.
     */
    public void onContinuation(Frame frame, Callback callback) {
        switch (dataType) {
            case OpCode.TEXT -> onText(frame, callback);
            case OpCode.BINARY -> onBinary(frame, callback);
            default -> callback.failed(new ProtocolException("Unable to process continuation during dataType " + dataType));
        }
    }

    /** @return the upgrade request that created this connection */
    public UpgradeRequest getUpgradeRequest() {
        return upgradeRequest;
    }

    /**
     * Listener registered on the {@link PutListenerMap} that wraps the endpoint's user properties.
     * Presumably it is invoked on each put (confirm against PutListenerMap). Recognized
     * {@code org.eclipse.jetty.websocket.*} keys are applied to the core session; other keys are
     * ignored.
     *
     * <p>Numeric settings accept any {@link Number} (for example an {@code Integer} literal for
     * {@code maxFrameSize}). The buffer sizes must fit in an {@code int}. {@code autoFragment}
     * requires a {@link Boolean}. Values must be non-null.
     *
     * @throws ClassCastException if a value has the wrong type
     * @throws ArithmeticException if a buffer size does not fit in an {@code int}
     */
    private void configListener(String key, Object value) {
        if (!key.startsWith("org.eclipse.jetty.websocket.")) {
            return;
        }
        switch (key) {
            case "org.eclipse.jetty.websocket.autoFragment" -> coreSession.setAutoFragment((Boolean) value);
            case "org.eclipse.jetty.websocket.maxFrameSize" -> coreSession.setMaxFrameSize(((Number) value).longValue());
            case "org.eclipse.jetty.websocket.outputBufferSize" -> coreSession.setOutputBufferSize(Math.toIntExact(((Number) value).longValue()));
            case "org.eclipse.jetty.websocket.inputBufferSize" -> coreSession.setInputBufferSize(Math.toIntExact(((Number) value).longValue()));
            default -> {
                // Other org.eclipse.jetty.websocket.* keys are not handled here.
            }
        }
    }
}

