/*
 * Decompiled with CFR 0.152.
 */
package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.CohereApi;
import com.github.llmjava.cohere4j.CohereApiFactory;
import com.github.llmjava.cohere4j.CohereConfig;
import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.callback.StreamingCallback;
import com.github.llmjava.cohere4j.exception.CohereException;
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.DetectLanguageRequest;
import com.github.llmjava.cohere4j.request.DetokenizeRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.request.RerankRequest;
import com.github.llmjava.cohere4j.request.SummarizeRequest;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.DetectLanguageResponse;
import com.github.llmjava.cohere4j.response.DetokenizeResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.RerankResponse;
import com.github.llmjava.cohere4j.response.SummarizeResponse;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
import com.google.gson.Gson;
import java.io.IOException;
import java.io.UncheckedIOException;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;

/**
 * Client for the Cohere REST API.
 *
 * <p>Each endpoint is exposed as a blocking method that returns the parsed response body, and as an
 * {@code ...Async} method that enqueues the call and reports the outcome through an
 * {@link AsyncCallback}. Text generation can also be streamed via
 * {@link #generateStream(GenerateRequest, StreamingCallback)}.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class CohereClient {
    /** Retrofit-backed API interface produced by {@link CohereApiFactory}. */
    private final CohereApi api;
    /** Configuration the client was built from; not read by this class itself. */
    private final CohereConfig config;
    /** Gson instance obtained from the {@link CohereApiFactory}; used to parse streamed chunks. */
    private final Gson gson;

    CohereClient(Builder builder) {
        this.api = builder.api;
        this.config = builder.config;
        this.gson = builder.gson;
    }

    /**
     * Generates text synchronously.
     *
     * @param request the generation request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public GenerateResponse generate(GenerateRequest request) {
        return this.execute(this.api.generate(request));
    }

    /**
     * Generates text asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the generation request
     * @param callback receives the parsed body on success, or the failure
     */
    public void generateAsync(GenerateRequest request, AsyncCallback<GenerateResponse> callback) {
        this.execute(this.api.generate(request), callback);
    }

    /**
     * Streams a text generation, delivering each parsed chunk to {@code callback}.
     *
     * <p>Chunks whose {@code isFinished()} flag is {@code true} go to
     * {@link StreamingCallback#onComplete}; every other chunk goes to {@link StreamingCallback#onPart}.
     * Non-success HTTP responses, transport failures, and runtime exceptions raised while parsing or
     * dispatching the streamed body are reported through {@link StreamingCallback#onFailure}.
     *
     * @param request a request with streaming enabled
     * @param callback receives streamed parts, the final chunk, or a failure
     * @throws IllegalArgumentException if {@code request} does not have streaming enabled
     */
    public void generateStream(GenerateRequest request, final StreamingCallback<StreamGenerateResponse> callback) {
        // isStreaming() yields a Boolean; an unset (null) flag is treated as "not streaming"
        // rather than failing with a NullPointerException during unboxing.
        if (!Boolean.TRUE.equals(request.isStreaming())) {
            throw new IllegalArgumentException("Expected a streaming request");
        }
        final ResponseConverter converter = new ResponseConverter(this.gson);
        this.api.generateStream(request).enqueue(new Callback<String>() {

            @Override
            public void onResponse(Call<String> call, Response<String> response) {
                if (!response.isSuccessful()) {
                    callback.onFailure(CohereException.fromResponse(response));
                    return;
                }
                try {
                    for (StreamGenerateResponse part : converter.toStreamingGenerationResponse(response.body())) {
                        // Null-safe: a chunk without an isFinished flag is treated as a partial result.
                        if (Boolean.TRUE.equals(part.isFinished())) {
                            callback.onComplete(part);
                        } else {
                            callback.onPart(part);
                        }
                    }
                } catch (RuntimeException e) {
                    // Without this, an exception escaping onResponse (e.g. malformed streamed JSON)
                    // would never reach the caller, who would wait forever for onComplete/onFailure.
                    callback.onFailure(e);
                }
            }

            @Override
            public void onFailure(Call<String> call, Throwable throwable) {
                callback.onFailure(throwable);
            }
        });
    }

    /**
     * Computes embeddings synchronously.
     *
     * @param request the embedding request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public EmbedResponse embed(EmbedRequest request) {
        return this.execute(this.api.embed(request));
    }

    /**
     * Computes embeddings asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the embedding request
     * @param callback receives the parsed body on success, or the failure
     */
    public void embedAsync(EmbedRequest request, AsyncCallback<EmbedResponse> callback) {
        this.execute(this.api.embed(request), callback);
    }

    /**
     * Classifies inputs synchronously.
     *
     * @param request the classification request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public ClassifyResponse classify(ClassifyRequest request) {
        return this.execute(this.api.classify(request));
    }

    /**
     * Classifies inputs asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the classification request
     * @param callback receives the parsed body on success, or the failure
     */
    public void classifyAsync(ClassifyRequest request, AsyncCallback<ClassifyResponse> callback) {
        this.execute(this.api.classify(request), callback);
    }

    /**
     * Tokenizes text synchronously.
     *
     * @param request the tokenize request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public TokenizeResponse tokenize(TokenizeRequest request) {
        return this.execute(this.api.tokenize(request));
    }

    /**
     * Tokenizes text asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the tokenize request
     * @param callback receives the parsed body on success, or the failure
     */
    public void tokenizeAsync(TokenizeRequest request, AsyncCallback<TokenizeResponse> callback) {
        this.execute(this.api.tokenize(request), callback);
    }

    /**
     * Detokenizes tokens synchronously.
     *
     * @param request the detokenize request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public DetokenizeResponse detokenize(DetokenizeRequest request) {
        return this.execute(this.api.detokenize(request));
    }

    /**
     * Detokenizes tokens asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the detokenize request
     * @param callback receives the parsed body on success, or the failure
     */
    public void detokenizeAsync(DetokenizeRequest request, AsyncCallback<DetokenizeResponse> callback) {
        this.execute(this.api.detokenize(request), callback);
    }

    /**
     * Detects the language of inputs synchronously.
     *
     * @param request the language-detection request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public DetectLanguageResponse detectLanguage(DetectLanguageRequest request) {
        return this.execute(this.api.detectLanguage(request));
    }

    /**
     * Detects the language of inputs asynchronously; the result or failure is delivered to
     * {@code callback}.
     *
     * @param request the language-detection request
     * @param callback receives the parsed body on success, or the failure
     */
    public void detectLanguageAsync(DetectLanguageRequest request, AsyncCallback<DetectLanguageResponse> callback) {
        this.execute(this.api.detectLanguage(request), callback);
    }

    /**
     * Summarizes text synchronously.
     *
     * @param request the summarize request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public SummarizeResponse summarize(SummarizeRequest request) {
        return this.execute(this.api.summarize(request));
    }

    /**
     * Summarizes text asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the summarize request
     * @param callback receives the parsed body on success, or the failure
     */
    public void summarizeAsync(SummarizeRequest request, AsyncCallback<SummarizeResponse> callback) {
        this.execute(this.api.summarize(request), callback);
    }

    /**
     * Reranks documents synchronously.
     *
     * @param request the rerank request
     * @return the parsed response body
     * @throws UncheckedIOException if the call fails at the transport level
     */
    public RerankResponse rerank(RerankRequest request) {
        return this.execute(this.api.rerank(request));
    }

    /**
     * Reranks documents asynchronously; the result or failure is delivered to {@code callback}.
     *
     * @param request the rerank request
     * @param callback receives the parsed body on success, or the failure
     */
    public void rerankAsync(RerankRequest request, AsyncCallback<RerankResponse> callback) {
        this.execute(this.api.rerank(request), callback);
    }

    /**
     * Executes {@code action} on the calling thread and returns its body.
     *
     * <p>A non-success HTTP status is thrown as the exception built by
     * {@link CohereException#fromResponse}.
     *
     * @param action the prepared Retrofit call
     * @param <T> the response body type
     * @return the response body (may be {@code null} if the server sent none)
     * @throws UncheckedIOException wrapping any {@link IOException} raised while executing the call
     */
    private <T> T execute(Call<T> action) {
        try {
            Response<T> response = action.execute();
            if (response.isSuccessful()) {
                return response.body();
            }
            throw CohereException.fromResponse(response);
        } catch (IOException e) {
            // Preserve the cause and keep the I/O nature visible to callers; still a RuntimeException.
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Enqueues {@code action} and forwards its outcome to {@code callback}.
     *
     * <p>Successful responses go to {@link AsyncCallback#onSuccess} with the body. Non-success
     * statuses are converted with {@link CohereException#fromResponse}, and transport failures are
     * passed through unchanged; both go to {@link AsyncCallback#onFailure}.
     *
     * @param action the prepared Retrofit call
     * @param callback receives the body or the failure
     * @param <T> the response body type
     */
    private <T> void execute(Call<T> action, final AsyncCallback<T> callback) {
        action.enqueue(new Callback<T>() {

            @Override
            public void onResponse(Call<T> call, Response<T> response) {
                if (response.isSuccessful()) {
                    callback.onSuccess(response.body());
                } else {
                    callback.onFailure(CohereException.fromResponse(response));
                }
            }

            @Override
            public void onFailure(Call<T> call, Throwable throwable) {
                callback.onFailure(throwable);
            }
        });
    }

    /** Builder for {@link CohereClient}; {@link #withConfig(CohereConfig)} must be called first. */
    public static class Builder {
        private CohereApi api;
        private CohereConfig config;
        private Gson gson;

        /**
         * Sets the configuration and builds the underlying API interface and Gson instance from it
         * via {@link CohereApiFactory}.
         *
         * @param config the client configuration
         * @return this builder
         */
        public Builder withConfig(CohereConfig config) {
            this.config = config;
            CohereApiFactory factory = new CohereApiFactory();
            this.api = factory.createGson().createHttpClient(config).build();
            this.gson = factory.gson;
            return this;
        }

        /**
         * Creates the client.
         *
         * @return a new {@link CohereClient}
         * @throws IllegalStateException if {@link #withConfig(CohereConfig)} was not called, since
         *     the resulting client would otherwise fail on every request
         */
        public CohereClient build() {
            if (this.api == null) {
                throw new IllegalStateException("withConfig(CohereConfig) must be called before build()");
            }
            return new CohereClient(this);
        }
    }
}

