#include "justlm.hpp"

#include <fstream>
#include <random>
#include <cstring>
#include <algorithm>
#include "gptj/gptj.hpp"
#include "g4a-common.hpp"


namespace LM {
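// GPT-J backend for the LM::Inference interface.
// Wraps the ggml-based gptj_* functions from gptj/gptj.hpp and keeps all per-instance
// data (vocab, model, token window, logits, RNG) in a heap-allocated State struct.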
class GPTJInference final : public Inference {
    std::string weights_path;

    struct State {
        gpt_vocab vocab;
        gptj_model model;
        std::string prompt; // Mostly here for easy "debugging"
        std::vector<int> tokens;
        std::vector<float> logits;
        size_t mem_per_token = 0;
        std::mt19937 rng;

        State(int32_t seed) : rng(seed) {}
    };

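    // generic_state (inherited from Inference) is used as opaque storage for this
    // backend's State pointer, hence the reinterpret_casts below.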
    State*& get_state() LM_NOEXCEPTDECL {
        return *reinterpret_cast<State**>(&generic_state);
    }
    State* const& get_state() const LM_NOEXCEPTDECL {
        return *reinterpret_cast<State* const*>(&generic_state);
    }

    LM_ERRBOOL init(const std::string& _weights_path, std::ifstream& f) LM_NOEXCEPTDECL {
        auto& state = get_state();
        weights_path = _weights_path;

        // Allocate state
        state = new State(params.seed);

        // Load model
        if (!gptj_model_load(weights_path, f, state->model, state->vocab)) {
            LM_THROW("Failed to initialize gptj from file", LM_BOOL_ERROR);
        }

        // Determine the memory required per token by evaluating a small dummy batch
        if (!gptj_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token)) {
            LM_THROW("Failed to determine memory required per token", LM_BOOL_ERROR);
        }

        return LM_BOOL_SUCCESS;
    }
    void deinit() LM_NOEXCEPTDECL {
        auto& state = get_state();

        if (state) {
            if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
            delete state;
        }
    }

    // Shrinks the token vector according to the scrolling parameters.
    // If scrolling was needed, all remaining tokens are re-evaluated and true is returned.
    LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
        auto &state = get_state();
        // Check that we actually need to scroll
        if (state->tokens.size() <= params.n_ctx) {
            // Nope
            LM_CORETURN false;
        }
        // Start scrolling
        if (params.scroll_keep > 0.0f) {
            // "Scroll" down the context window...
            unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * params.scroll_keep; // Keep the configured fraction of the most recent tokens
            // Get vector of tokens to keep
            std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
            // Cut down tokens vector size
            state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
            // Overwrite tokens after top bar with tokens in view
            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
        } else {
            // Cut down tokens vector size to top bar
            state->tokens.resize(params.n_ctx_window_top_bar);
        }
        // Evaluate tokens
        LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
        LM_CORETURN true;
    }

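    // Evaluates tokens[starting_offset..] in batches of params.n_batch, forwarding progress
    // (0-100) to on_tick; returning false from on_tick (or a failed task yield) stops early.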
    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
        auto& state = get_state();

        // Evaluate tokens in batches
        size_t it;
        for (it = starting_offset; it + params.n_batch < state->tokens.size(); it += params.n_batch) {

            // Evaluate
            std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
            if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
                LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
            }

            // Tick
            if (on_tick) {
                // Calculate progress
                auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
                // Tick and yield
                if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
                else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
            }
        }

        // Evaluate remaining tokens
        if (it < state->tokens.size()) {
            for (; it != state->tokens.size(); it++) {
                //TODO: This is extremely inefficient! Don't do that...
                std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
                if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
                    LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
                }
            }
        }

        // Notify about completion
        if (on_tick) on_tick(100.f);

        LM_CORETURN LM_BOOL_SUCCESS;
    }

public:
    GPTJInference(const std::string& weights_path, std::ifstream& f, const Params& p) : Inference(p) {
        init(weights_path, f);
    }
    ~GPTJInference() LM_NOEXCEPTDECL override {
        deinit();
    }

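    // Tokenizes `prompt`, appends it to the context and evaluates only the newly added tokens
    // (or the whole window, if scrolling was triggered).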
    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
        auto& state = get_state();

        // Append to current prompt
        state->prompt.append(prompt);

        // Remember the old token count so only the newly appended tokens get evaluated
        const auto old_token_count = state->tokens.size();

        // Run tokenizer and append the new tokens to the context
        const auto tokens = gpt_tokenize(state->vocab, prompt);
        state->tokens.insert(state->tokens.end(), tokens.begin(), tokens.end());

        // Make sure token limit isn't being hit
        if (LM_COAWAIT window_scroll()) {
            // That function already has evaluated our tokens since scrolling was needed
            LM_CORETURN LM_BOOL_SUCCESS;
        }

        // Evaluate new tokens
        LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
    }

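    // Generates tokens until the `end` marker is produced or generation is aborted (via the
    // callback, a failed task yield or too many end-of-text tokens); returns the generated
    // text with the end marker stripped unless generation was aborted.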
    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        std::string fres;

        // Loop until done
        bool abort = false;
        unsigned eos_count = 0;
        while (!abort && !ends_with(fres, end)) {
            // Sample (top k / top p); clamp the repetition window so it never reaches past the start of the token buffer
            const size_t n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
            auto id = gpt_sample_top_k_top_p(state->vocab,
                                             n_repeat_last ? (state->tokens.data()+state->tokens.size()-n_repeat_last) : nullptr,
                                             n_repeat_last,
                                             state->logits,
                                             params.top_k, params.top_p, params.temp,
                                             params.repeat_penalty,
                                             state->rng);

            if (id == 50256) { // 50256 is GPT-J's end-of-text token (<|endoftext|>)
                if (eos_count++ == params.eos_ignores) {
                    abort = true;
                    continue;
                }
                id = gpt_tokenize(state->vocab, "\n")[0];
                state->tokens.push_back(id);
            } else {
                // Add token
                state->tokens.push_back(id);
            }

            // Make sure token limit isn't being hit
            LM_COAWAIT window_scroll();

            // Get token as string
            const auto str = state->vocab.id_to_token[id];

            // Append string to function result
            fres.append(str);

            // Evaluate token
            //  TODO: Respect batch size
            std::vector<int> batch(state->tokens.end()-1, state->tokens.end());
            if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
                LM_COTHROW("Failed to evaluate new tokens", "");
            }

            // Tick
            if (on_tick && !on_tick(str.c_str())) abort = true;
            else if (!LM_TASKYIELD) abort = true;
        }

        // Record the generated text in the prompt and strip the end marker from the returned string  TODO: Could be optimized
        state->prompt.append(fres);
        if (!abort) {
            fres = std::string(fres.data(), fres.size()-end.size());
        }

        // Return final string
        LM_CORETURN fres;
    }

    unsigned get_context_size() const noexcept override {
        return get_state()->tokens.size();
    }

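    // Savestates capture the ggml model state, RNG, token window and prompt; they can only be
    // restored into the context that created them (checked via sv.ctx).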
    LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
        auto& state = get_state();
        sv.buf.resize(gptj_get_state_size(state->model));
        gptj_copy_state_data(state->model, state->rng, sv.buf.data());
        sv.tokens = state->tokens;
        sv.prompt = state->prompt;
        sv.ctx = generic_state;
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        if (sv.ctx != generic_state)
            LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
        gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
        state->tokens = sv.tokens;
        state->prompt = sv.prompt;
        LM_CORETURN LM_BOOL_SUCCESS;
    }

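    // On-disk layout: three uint32_t sizes (token count, prompt size, model state size),
    // followed by the raw tokens, the prompt bytes and the ggml model state blob.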
    LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
        auto& state = get_state();
        // Get state size
        auto state_size = gptj_get_state_size(state->model);
        // Write sizes
        for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
            if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
                LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
            }
        }
        // Write tokens
        if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
            LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
        }
        // Write prompt
        if (!o.write(state->prompt.data(), state->prompt.size())) {
            LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
        }
        // Write state
        std::vector<uint8_t> state_buf(state_size);
        gptj_copy_state_data(state->model, state->rng, state_buf.data());
        if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
            LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
        }
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        uint32_t embd_size, prompt_size, state_size;
        // Initialization to prevent compiler complaints
        embd_size = prompt_size = state_size = 0;
        // Read sizes
        for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
            if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
                LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
            }
        }
        // Read tokens
        state->tokens.resize(embd_size);
        if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
            LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
        }
        // Read prompt
        state->prompt.resize(prompt_size);
        if (!i.read(state->prompt.data(), state->prompt.size())) {
            LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
        }
        // Read state
        std::vector<uint8_t> state_buf(state_size);
        if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
            LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
        }
        gptj_set_state_data(&state->model, &state->rng, state_buf.data());
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    const std::string &get_prompt() const LM_NOEXCEPTDECL override {
        return get_state()->prompt;
    }
};
}
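// Usage sketch (illustrative only): instances are normally reached through the justlm.hpp
// front-end, and LM_SCHEDULABLE methods may be awaitable depending on build configuration.
// Assuming direct construction with a hypothetical model path:
//
//   std::ifstream f("/path/to/gptj-model.bin", std::ios::binary);
//   LM::Inference::Params p;   // wherever Params is declared in justlm.hpp
//   p.seed = 42;
//   LM::GPTJInference inf("/path/to/gptj-model.bin", f, p);
//   inf.append("Hello");                  // tokenize + evaluate
//   auto continuation = inf.run("\n");    // generate until a newline is produced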