/**
 * Copyright 2014-2019 XebiaLabs Inc. and its affiliates. Use is subject to terms of the enclosed Legal Notice.
 */
package com.xebialabs.xltype.serialization.xml;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.WildcardType;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import javax.ws.rs.Consumes;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.ext.MessageBodyReader;
import javax.ws.rs.ext.MessageBodyWriter;
import javax.ws.rs.ext.Provider;
import javax.ws.rs.ext.Providers;
import javax.xml.bind.annotation.XmlElementWrapper;
import org.jboss.resteasy.core.NoMessageBodyWriterFoundFailure;

/**
 * JAX-RS provider that reads and writes {@link Stream} entities as XML lists.
 * <p>
 * Reading delegates to the {@link MessageBodyReader} registered for {@code List<T>} (where {@code T} is the
 * stream's element type) and exposes the resulting list as a stream. Writing emits a wrapper start tag,
 * delegates every element to the {@link MessageBodyWriter} registered for the element type, and finally emits
 * the matching end tag. The wrapper element is {@code <list>} unless an {@link XmlElementWrapper} with an
 * explicit name is among the annotations handed to {@link #writeTo}.
 */
@Provider
@Consumes(MediaType.APPLICATION_XML)
@Produces(MediaType.APPLICATION_XML)
public class StreamXmlReaderWriter implements MessageBodyReader<Stream<Object>>, MessageBodyWriter<Stream<Object>> {

    /** Value JAXB uses for {@link XmlElementWrapper#name()} when no name was given explicitly. */
    private static final String JAXB_DEFAULT_NAME = "##default";

    /** Wrapper name used when no explicit {@link XmlElementWrapper} name is available. */
    private static final String DEFAULT_LIST_NAME = "list";

    private static final byte[] DEFAULT_START_LIST = createStartElement(DEFAULT_LIST_NAME);
    private static final byte[] DEFAULT_END_LIST = createEndElement(DEFAULT_LIST_NAME);

    /**
     * Encoded start/end tags keyed by wrapper element name. JAX-RS providers are normally singletons shared by
     * concurrent requests, so this cache must be safe for concurrent {@code computeIfAbsent} calls.
     */
    private final Map<String, ElementHolder> elementsCache = new ConcurrentHashMap<>();

    @Context
    private Providers providers;

    /**
     * Accepts any {@link Stream} type whose generic type carries an element type argument.
     */
    @Override
    public boolean isReadable(final Class<?> type, final Type genericType, final Annotation[] annotations, final MediaType mediaType) {
        return Stream.class.isAssignableFrom(type) && genericType instanceof ParameterizedType;
    }

    /**
     * Reads the entity as a {@code List} of the stream's element type and returns a stream over it.
     * <p>
     * The whole list is materialized by the delegate reader before the stream is returned.
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Stream<Object> readFrom(final Class<Stream<Object>> type, final Type genericType, final Annotation[] annotations, final MediaType mediaType, final MultivaluedMap<String, String> httpHeaders, final InputStream entityStream) throws IOException, WebApplicationException {
        // Re-express Stream<T> as List<T> so an existing List reader can do the actual parsing.
        ParameterizedType listType = new ParameterizedType() {

            @Override
            public Type[] getActualTypeArguments() {
                return ((ParameterizedType) genericType).getActualTypeArguments();
            }

            @Override
            public Type getRawType() {
                return List.class;
            }

            @Override
            public Type getOwnerType() {
                return null;
            }
        };
        MessageBodyReader<List> listReader = providers.getMessageBodyReader(List.class, listType, annotations, mediaType);
        List<Object> list = listReader.readFrom(List.class, listType, annotations, mediaType, httpHeaders, entityStream);
        return list.stream();
    }

    /**
     * Accepts any {@link Stream} type whose generic type carries an element type argument.
     */
    @Override
    public boolean isWriteable(final Class<?> type, final Type genericType, final Annotation[] annotations, final MediaType mediaType) {
        return Stream.class.isAssignableFrom(type) && genericType instanceof ParameterizedType;
    }

    /**
     * Returns {@code -1}: the size of a streamed body is not known in advance.
     */
    @Override
    public long getSize(final Stream<Object> stream, final Class<?> type, final Type genericType, final Annotation[] annotations, final MediaType mediaType) {
        return -1;
    }

    /**
     * Writes the stream as a wrapper element containing each element serialized by the element type's writer.
     *
     * @throws NoMessageBodyWriterFoundFailure if no writer is registered for the element type and media type
     * @throws IOException                     if writing to {@code entityStream} fails, including failures raised
     *                                         by the element writer
     */
    @Override
    @SuppressWarnings("rawtypes")
    public void writeTo(final Stream<Object> stream, final Class<?> type, final Type genericType, final Annotation[] annotations, final MediaType mediaType, final MultivaluedMap<String, Object> httpHeaders, final OutputStream entityStream) throws IOException, WebApplicationException {
        byte[] startList = DEFAULT_START_LIST;
        byte[] endList = DEFAULT_END_LIST;

        XmlElementWrapper wrapper = findAnnotation(annotations, XmlElementWrapper.class);
        // An unnamed @XmlElementWrapper reports "##default"; writing that verbatim would produce invalid XML.
        if (wrapper != null && !wrapper.name().isEmpty() && !JAXB_DEFAULT_NAME.equals(wrapper.name())) {
            ElementHolder holder = elementsCache.computeIfAbsent(wrapper.name(), ElementHolder::create);
            startList = holder.startList;
            endList = holder.endList;
        }

        final Type argumentType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
        // The element type may itself be parameterized (Stream<List<X>>) or a wildcard (Stream<? extends X>).
        final Class<?> argumentClass = rawClassOf(argumentType);

        MessageBodyWriter writer = providers.getMessageBodyWriter(argumentClass, argumentType, annotations, mediaType);
        if (writer == null) {
            throw new NoMessageBodyWriterFoundFailure(argumentClass, mediaType);
        }

        doWrite(stream, argumentClass, argumentType, annotations, mediaType, httpHeaders, entityStream, startList, endList, writer);
    }

    /**
     * Writes {@code startList}, every element of {@code stream} in encounter order via {@code actualWriter}, then
     * {@code endList}. The stream is always closed.
     * <p>
     * NOTE(review): assumes {@code actualWriter} emits an XML fragment (no XML declaration) per element — confirm
     * for the registered element writers.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void doWrite(final Stream<Object> stream, final Class actualClass, final Type argumentType, final Annotation[] annotations, final MediaType mediaType, final MultivaluedMap<String, Object> httpHeaders, final OutputStream entityStream, final byte[] startList, final byte[] endList, final MessageBodyWriter actualWriter) throws IOException {
        try (Stream<Object> s = stream) {
            entityStream.write(startList);
            s.forEachOrdered(o -> {
                try {
                    actualWriter.writeTo(o, actualClass, argumentType, annotations, mediaType, httpHeaders, entityStream);
                } catch (IOException e) {
                    // Lambdas cannot throw checked exceptions; tunnel the IOException out and unwrap it below.
                    throw new UncheckedIOException(e);
                }
            });
            entityStream.write(endList);
        } catch (UncheckedIOException e) {
            throw e.getCause();
        }
    }

    /**
     * Resolves the raw class of a generic type.
     *
     * @return the class itself, the raw type of a parameterized type, the first upper bound of a wildcard, or
     *         {@code Object.class} for anything else
     */
    private static Class<?> rawClassOf(final Type type) {
        if (type instanceof Class) {
            return (Class<?>) type;
        }
        if (type instanceof ParameterizedType) {
            return rawClassOf(((ParameterizedType) type).getRawType());
        }
        if (type instanceof WildcardType) {
            final Type[] upperBounds = ((WildcardType) type).getUpperBounds();
            return upperBounds.length > 0 ? rawClassOf(upperBounds[0]) : Object.class;
        }
        return Object.class;
    }

    /**
     * Returns the first annotation that is an instance of {@code type}, or {@code null} if there is none (or the
     * array itself is {@code null}).
     */
    private static <T extends Annotation> T findAnnotation(Annotation[] annotations, Class<T> type) {
        if (annotations == null) {
            return null;
        }
        for (Annotation annotation : annotations) {
            if (type.isInstance(annotation)) {
                return type.cast(annotation);
            }
        }
        return null;
    }

    /**
     * Encodes {@code <elementName>} as UTF-8 bytes.
     */
    static byte[] createStartElement(final String elementName) {
        return encodeTag("<", elementName);
    }

    /**
     * Encodes {@code </elementName>} as UTF-8 bytes.
     */
    static byte[] createEndElement(final String elementName) {
        return encodeTag("</", elementName);
    }

    /**
     * Builds {@code prefix + name + ">"} and encodes it as UTF-8 (XML's default encoding). Sizing the result by
     * encoded byte length keeps non-ASCII element names intact.
     */
    private static byte[] encodeTag(final String prefix, final String elementName) {
        return (prefix + elementName + ">").getBytes(StandardCharsets.UTF_8);
    }

    /** Pre-encoded start and end tags for one wrapper element name. */
    private static final class ElementHolder {
        private final byte[] startList;
        private final byte[] endList;

        private ElementHolder(final byte[] startList, final byte[] endList) {
            this.startList = startList;
            this.endList = endList;
        }

        private static ElementHolder create(String s) {
            return new ElementHolder(createStartElement(s), createEndElement(s));
        }
    }

}
