// libjustlm/include/justlm.hpp
// Mirror of https://gitlab.com/niansa/libjustlm.git
#ifndef JUSTLM_HPP
#define JUSTLM_HPP
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include <functional>
#include <memory>
#include <thread>
#include <stdexcept>
#include <exception>
#include <cstdint>
#include <ctime>
#ifdef LM_NOEXCEPT
# define LM_NOEXCEPTDECL noexcept
# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
# define LM_LAST_ERROR_STORAGE mutable std::string last_error;
# define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
# define LM_ERRBOOL bool
# define LM_BOOL_ERROR false
# define LM_BOOL_SUCCESS true
# define LM_RETHROW(x) return x
# define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return v;} while (0)
#else
# define LM_NOEXCEPTDECL
# define LM_THROW(t, r) throw Exception(t)
# define LM_LAST_ERROR_STORAGE
# define LM_LAST_ERROR_GETTER
# define LM_ERRBOOL void
# define LM_BOOL_ERROR
# define LM_BOOL_SUCCESS
# define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
# define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
# define LM_ERROR_FORWARD(x, errval) {x;}
#endif
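// Usage sketch (illustrative only): the macros above select between two error
// reporting styles. With LM_NOEXCEPT defined, failing calls return
// LM_BOOL_ERROR and the message is retrieved via get_last_error(); otherwise
// the same call throws LM::Inference::Exception. A hypothetical caller:
//
//   #ifdef LM_NOEXCEPT
//       if (!lm->append(prompt))
//           std::cerr << lm->get_last_error() << '\n';
//   #else
//       try {
//           lm->append(prompt);
//       } catch (const LM::Inference::Exception& e) {
//           std::cerr << e.what() << '\n';
//       }
//   #endif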
#ifdef _MSC_VER
#include <BaseTsd.h>
#endif
namespace LM {
#ifdef _MSC_VER
using ssize_t = SSIZE_T;
#endif
using GenerateCallback = std::function<bool (const char *generated)>;
using AppendCallback = std::function<bool (float progress)>;
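// Callback sketches (illustrative): judging by the bool return type, returning
// false from either callback is assumed to cancel the ongoing operation.
//
//   LM::GenerateCallback print_token = [](const char *generated) {
//       std::cout << generated << std::flush;
//       return true; // keep generating
//   };
//   LM::AppendCallback report_progress = [](float progress) {
//       std::cout << "\rProgress: " << progress << std::flush;
//       return true; // false would presumably abort ingestion
//   };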
class Inference {
protected:
    AppendCallback on_scroll = nullptr;

    void *generic_state = nullptr;

    LM_LAST_ERROR_STORAGE

public:
    struct Exception : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct Params {
        int seed = 0; // RNG seed; 0 picks a time-based seed
        unsigned n_threads = 0; // Number of threads to use; immutable after the Inference has been constructed
        unsigned n_ctx = 2024; // Context size
        unsigned n_ctx_window_top_bar = 0; // Top bar of the context window; must be smaller than the context size
        unsigned n_batch = 8; // Batch size
        unsigned n_repeat_last = 0;
        unsigned n_eos_ignores = 0;

        float scroll_keep = 0.0f; // 0.4f to keep 40% of the context below the top bar when scrolling; 0.0f to remove everything after the top bar

        unsigned top_k = 40;
        float top_p = 0.9f;
        float temp = 0.72f;
        float mirostat_learning_rate = 0.1f; // mirostat specific
        float mirostat_target_entropy = 5.0f; // mirostat specific
        float repeat_penalty = 1.0f;

        unsigned n_gpu_layers = 38;
        bool use_mlock = true; // llama specific
        int prefer_mirostat = 0; // Use the given mirostat version if available (see is_mirostat_available()); llama specific
    } params;
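    // A minimal configuration sketch; the values below are arbitrary examples:
    //
    //   LM::Inference::Params params;
    //   params.n_ctx = 2024;  // context size
    //   params.temp = 0.72f;  // sampling temperature
    //   params.seed = 42;     // fixed seed for reproducible sampling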
    struct Savestate {
        std::vector<uint8_t> buf;
        std::vector<int> tokens;
        std::string prompt;
        void *ctx = nullptr;

        bool is_valid() const {
            return ctx != nullptr;
        }
    };
    Inference(const Params& p) : params(p) {
        // Fill in defaults: a time-based seed and half the hardware threads
        params.seed = params.seed ? params.seed : static_cast<int>(time(nullptr));
        params.n_threads = params.n_threads ? params.n_threads : (static_cast<unsigned>(std::thread::hardware_concurrency()) / 2);
    }
    virtual ~Inference() {}

    Inference(const Inference&) = delete;
    Inference(Inference&) = delete;
    Inference(Inference&& o)
            : generic_state(o.generic_state)
            , params(o.params) {
        o.generic_state = nullptr;
    }

    static
    Inference *construct(const std::string& weights_path, const Params& p);
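    // construct() is a factory that picks a backend implementation based on
    // the weights file. A hypothetical call (the path is a placeholder):
    //
    //   std::unique_ptr<LM::Inference> lm{
    //       LM::Inference::construct("/path/to/weights.bin", LM::Inference::Params{})};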
    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
        on_scroll = scroll_cb;
    }
    // This must be called with a non-empty prompt!
    virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;

    // append() must have been called at least once before calling this!
    virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
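    // Typical call order (sketch): feed the prompt with append(), then
    // generate with run(); the stop string and callback below are illustrative.
    //
    //   lm->append("User: Hello!\nBot:");
    //   std::string reply = lm->run("\n", [](const char *token) {
    //       std::cout << token << std::flush;
    //       return true; // false is assumed to stop generation early
    //   });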
    virtual unsigned get_context_size() const noexcept = 0;

    virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
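    // Savestate sketch (hypothetical flow): capture the state before a
    // speculative generation, then roll back if the result is discarded.
    //
    //   LM::Inference::Savestate ss;
    //   lm->create_savestate(ss);
    //   auto draft = lm->run("\n");
    //   if (draft_rejected) // hypothetical condition
    //       lm->restore_savestate(ss);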
    virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
    virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
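    // Serialization sketch: serialize()/deserialize() stream the inference
    // state, so a session can be persisted across process restarts. The file
    // name is illustrative; the caller would also need <fstream>.
    //
    //   std::ofstream out("session.bin", std::ios::binary);
    //   lm->serialize(out);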
    virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
    }
    virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
        LM_THROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
    }

    virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;

    virtual bool is_mirostat_available() const noexcept { return false; }
    virtual bool is_grammar_available() const noexcept { return false; }
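    // Capability-check sketch: backend-specific features should be probed
    // before use, since load_grammar() fails on backends without support.
    //
    //   if (lm->is_grammar_available())
    //       lm->load_grammar(grammar_text); // grammar_text: hypothetical grammar string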
    LM_LAST_ERROR_GETTER
};
struct Implementation {
    bool is_fallback = false;
};
} // namespace LM
#endif // JUSTLM_HPP