#ifndef JUSTLM_HPP
#define JUSTLM_HPP
#include <cstdint>
#include <ctime>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

// Error-handling policy macros.
//
// With LM_NOEXCEPT defined, failures are reported through return values plus a
// per-object `last_error` string. Without it, failures are reported by throwing
// `Exception`, which must be visible at the expansion site.
//
//   LM_NOEXCEPTDECL             `noexcept` in LM_NOEXCEPT mode, empty otherwise
//   LM_THROW(t, r)              report error text `t`; LM_NOEXCEPT mode stores it in
//                               this->last_error and returns `r`, otherwise throws Exception(t)
//   LM_LAST_ERROR_STORAGE       declares the `last_error` member (LM_NOEXCEPT mode only)
//   LM_LAST_ERROR_GETTER        declares get_last_error() (LM_NOEXCEPT mode only)
//   LM_ERRBOOL                  return type of fallible functions: bool or void
//   LM_BOOL_ERROR / _SUCCESS    values to return from an LM_ERRBOOL function (empty when void)
//   LM_RETHROW(x)               return `x` in LM_NOEXCEPT mode, otherwise rethrow the
//                               exception currently being handled
//   LM_ERROR_CATCH(x, e, ...)   evaluate `x`; run the trailing block if it yielded `e`
//                               (LM_NOEXCEPT mode) or threw (throwing mode)
//   LM_ERROR_FORWARD(x, e)      evaluate `x` once; in LM_NOEXCEPT mode return its result
//                               from the enclosing function if it equals `e`
#ifdef LM_NOEXCEPT
#   define LM_NOEXCEPTDECL noexcept
#   define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
#   define LM_LAST_ERROR_STORAGE mutable std::string last_error;
#   define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
#   define LM_ERRBOOL bool
#   define LM_BOOL_ERROR false
#   define LM_BOOL_SUCCESS true
#   define LM_RETHROW(x) return x
#   define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
// Return the captured result `v`: returning `x` again would evaluate the
// (already failed) expression a second time.
#   define LM_ERROR_FORWARD(x, errval) do {auto v = (x); if (v == (errval)) return v;} while (0)
#else
#   define LM_NOEXCEPTDECL
#   define LM_THROW(t, r) throw Exception(t)
#   define LM_LAST_ERROR_STORAGE
#   define LM_LAST_ERROR_GETTER
#   define LM_ERRBOOL void
#   define LM_BOOL_ERROR
#   define LM_BOOL_SUCCESS
#   define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
#   define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
// do/while keeps this a single statement, matching the LM_NOEXCEPT variant, so
// it composes with an unbraced if/else at the call site.
#   define LM_ERROR_FORWARD(x, errval) do {x;} while (0)
#endif

#if _MSC_VER
#include <BaseTsd.h>
#endif


namespace LM {
#if _MSC_VER
// MSVC builds take ssize_t from SSIZE_T (BaseTsd.h is included above for this)
using ssize_t = SSIZE_T;
#endif

// Callback handed a piece of generated text as a C string; returns a bool
using GenerateCallback = std::function<bool(const char* generated)>;
// Callback handed a progress value as a float; returns a bool
using AppendCallback = std::function<bool(float progress)>;

class Inference {
protected:
    AppendCallback on_scroll = nullptr;

    void *generic_state = nullptr;

    LM_LAST_ERROR_STORAGE

public:
    struct Exception : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };

    struct Params {
        int seed = 0; // RNG seed
        unsigned n_threads = 0; // Amount of threads to use, immutable after Inference was constructed
        unsigned n_ctx = 2024; // Context size
        unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
        unsigned n_batch = 8; // Batch size
        unsigned n_repeat_last = 0;
        unsigned n_eos_ignores = 0;

        float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar

        unsigned top_k = 40;
        float top_p = 0.9f;
        float temp = 0.72f;
        float mirostat_learning_rate = 0.1f; // mirostat specific
        float mirostat_target_entropy = 5.0f; // mirostat specific
        float repeat_penalty = 1.0f;

        unsigned n_gpu_layers = 38;
        bool use_mlock = true; // llama specific
        int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
    } params;

    struct Savestate {
        std::vector<uint8_t> buf;
        std::vector<int> tokens;
        std::string prompt;
        void *ctx = nullptr;

        bool is_valid() const {
            return ctx != nullptr;
        }
    };

    Inference(const Params& p) : params(p) {
        // Set random seed
        params.seed = params.seed?params.seed:time(NULL);
        params.n_threads = params.n_threads?params.n_threads:(static_cast<unsigned>(std::thread::hardware_concurrency()) / 2);
    }
    virtual ~Inference() {}
    Inference(const Inference&) = delete;
    Inference(Inference&) = delete;
    Inference(Inference&& o)
            : generic_state(o.generic_state)
            , params(o.params) {
        o.generic_state = nullptr;
    }

    static
    Inference *construct(const std::string& weights_path, const Params& p);

    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
        on_scroll = scroll_cb;
    }

    // This must be called with a non-empty prompt!
    virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;

    // append() must have been called at least once before calling this!
    virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;

    virtual unsigned get_context_size() const noexcept = 0;

    virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;

    virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;

    virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
    }
    virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
    }

    virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;

    virtual bool is_mirostat_available() const noexcept {return false;}
    virtual bool is_grammar_available() const noexcept {return false;}

    LM_LAST_ERROR_GETTER
};


// Plain descriptor carrying a single flag.
struct Implementation {
    // Defaults to false. NOTE(review): presumably marks a fallback backend — confirm with the code that sets it
    bool is_fallback{false};
};
}
#endif // JUSTLM_HPP