mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)
#ifndef JUSTLM_HPP
#define JUSTLM_HPP
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include <functional>
#include <memory>
#include <thread>
#include <stdexcept>
#include <exception>
#include <cstdint>
#include <ctime>

#ifdef LM_NOEXCEPT
// Exceptions disabled: errors are stored in last_error and signalled through
// the boolean return value.
# define LM_NOEXCEPTDECL noexcept
# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
# define LM_LAST_ERROR_STORAGE mutable std::string last_error;
# define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
# define LM_ERRBOOL bool
# define LM_BOOL_ERROR false
# define LM_BOOL_SUCCESS true
# define LM_RETHROW(x) return x
# define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0)
#else
// Exceptions enabled: errors are thrown as Inference::Exception.
# define LM_NOEXCEPTDECL
# define LM_THROW(t, r) throw Exception(t)
# define LM_LAST_ERROR_STORAGE
# define LM_LAST_ERROR_GETTER
# define LM_ERRBOOL void
# define LM_BOOL_ERROR
# define LM_BOOL_SUCCESS
# define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
# define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
# define LM_ERROR_FORWARD(x, errval) {x;}
#endif
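
// Illustrative sketch (not part of this header): how the two error-reporting
// modes look from the caller's side, assuming `lm` points to an Inference
// obtained via Inference::construct() below.
//
//     // With LM_NOEXCEPT defined, failures are reported via the return value:
//     if (lm->append(prompt) == LM_BOOL_ERROR)
//         std::cerr << lm->get_last_error() << '\n';
//
//     // Without LM_NOEXCEPT, failures are thrown as LM::Inference::Exception:
//     try {
//         lm->append(prompt);
//     } catch (const LM::Inference::Exception& e) {
//         std::cerr << e.what() << '\n';
//     }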

#ifdef _MSC_VER
#include <BaseTsd.h>
#endif

namespace LM {
#ifdef _MSC_VER
using ssize_t = SSIZE_T;
#endif

using GenerateCallback = std::function<bool (const char *generated)>;
using AppendCallback = std::function<bool (float progress)>;
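
// Callback sketch (illustrative): both callback types return a bool, which
// presumably signals whether to continue; the exact semantics are defined by
// the backends and are an assumption here, not documented in this header.
//
//     LM::GenerateCallback print_token = [](const char *generated) {
//         std::cout << generated << std::flush;
//         return true; // assumption: true means "keep generating"
//     };
//     LM::AppendCallback report_progress = [](float progress) {
//         std::cout << "progress: " << progress << '\n';
//         return true;
//     };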

class Inference {
protected:
    AppendCallback on_scroll = nullptr;

    void *generic_state = nullptr;

    LM_LAST_ERROR_STORAGE

public:
    struct Exception : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };

    struct Params {
        int seed = 0; // RNG seed; 0 means "derive from current time"
        unsigned n_threads = 0; // Number of threads to use; immutable after the Inference has been constructed
        unsigned n_ctx = 2024; // Context size
        unsigned n_ctx_window_top_bar = 0; // Top bar of the context window; must be smaller than the context size
        unsigned n_batch = 8; // Batch size
        unsigned n_repeat_last = 0;
        unsigned n_eos_ignores = 0;

        float scroll_keep = 0.0f; // 0.4f to keep 40% of the context below the top bar when scrolling; 0.0f to remove everything after the top bar

        unsigned top_k = 40;
        float top_p = 0.9f;
        float temp = 0.72f;
        float mirostat_learning_rate = 0.1f; // mirostat specific
        float mirostat_target_entropy = 5.0f; // mirostat specific
        float repeat_penalty = 1.0f;

        unsigned n_gpu_layers = 38;
        bool use_mlock = true; // llama specific
        int prefer_mirostat = 0; // Use the given mirostat version if available (see is_mirostat_available()); llama specific
    } params;
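
    // Configuration sketch: zero-valued seed and n_threads are auto-filled by
    // the constructor below; all other fields are taken as-is.
    //
    //     LM::Inference::Params p;
    //     p.n_ctx = 2048;      // request a larger context window
    //     p.temp = 0.8f;       // higher sampling temperature
    //     p.use_mlock = false; // llama specific: don't pin weights in RAM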

    struct Savestate {
        std::vector<uint8_t> buf;
        std::vector<int> tokens;
        std::string prompt;
        void *ctx = nullptr;

        bool is_valid() const {
            return ctx != nullptr;
        }
    };

    Inference(const Params& p) : params(p) {
        // Fall back to a time-based seed and half the hardware threads if unset
        params.seed = params.seed ? params.seed : time(NULL);
        params.n_threads = params.n_threads ? params.n_threads : (static_cast<unsigned>(std::thread::hardware_concurrency()) / 2);
    }
    virtual ~Inference() {}
    Inference(const Inference&) = delete;
    Inference(Inference&) = delete;
    Inference(Inference&& o)
            : generic_state(o.generic_state)
            , params(o.params) {
        // Take over ownership of the backend state
        o.generic_state = nullptr;
    }

    static
    Inference *construct(const std::string& weights_path, const Params& p);

    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
        on_scroll = scroll_cb;
    }

    // This must be called with a non-empty prompt!
    virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;

    // append() must have been called at least once before calling this!
    virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
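
    // Typical generation flow (sketch; the weights path and prompt format are
    // placeholders, and error handling is elided):
    //
    //     auto *lm = LM::Inference::construct("/path/to/weights.bin", params);
    //     lm->append("User: Say hi!\nAssistant:");
    //     std::string reply = lm->run("\n", [](const char *token) {
    //         std::cout << token << std::flush;
    //         return true;
    //     });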

    virtual unsigned get_context_size() const noexcept = 0;

    virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
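
    // Savestate round-trip sketch (assumes the backend implements savestates):
    //
    //     LM::Inference::Savestate sv;
    //     lm->create_savestate(sv);
    //     lm->append("Some text that should be rolled back later");
    //     if (sv.is_valid())
    //         lm->restore_savestate(sv);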

    virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
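
    // Persistence sketch: serialize() and deserialize() accept any std::ostream
    // or std::istream, e.g. a binary file stream (requires <fstream>):
    //
    //     std::ofstream out("state.bin", std::ios::binary);
    //     lm->serialize(out);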

    virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
    }
    virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
    }

    virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;

    virtual bool is_mirostat_available() const noexcept {return false;}
    virtual bool is_grammar_available() const noexcept {return false;}
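
    // Capability-check sketch: probe backend features before relying on them.
    //
    //     if (lm->is_grammar_available())
    //         lm->load_grammar(grammar_text); // grammar_text: a backend-defined grammar string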

    LM_LAST_ERROR_GETTER
};


struct Implementation {
    bool is_fallback = false;
};
}
#endif // JUSTLM_HPP