mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)

Implemented proper scrolling

parent 4f03b75560
commit d236e36d26
3 changed files with 136 additions and 63 deletions
@@ -11,6 +11,8 @@
 namespace LM {
 class Inference {
 protected:
+    std::function<bool (float)> on_scroll = nullptr;
+
     void *generic_state = nullptr;
 
     static inline
@@ -20,6 +22,7 @@ protected:
     }
 
 public:
+
     struct Exception : public std::runtime_error {
         using std::runtime_error::runtime_error;
     };
@@ -32,6 +35,8 @@ public:
     unsigned n_batch = 8; // Batch size
     unsigned n_repeat_last = 0; // llama.cpp specific
 
+    float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
+
     unsigned top_k = 40;
     float top_p = 0.9f;
     float temp = 0.72f;
@@ -69,6 +74,10 @@ public:
     static
     Inference *construct(const std::string& weights_path, const Params& p);
 
+    void set_scroll_callback(const std::function<bool (float)>& scroll_cb) {
+        on_scroll = scroll_cb;
+    }
+
     // This must be called with a non-empty prompt!
     virtual void append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) = 0;
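Taken together, the header changes add a per-instance scroll callback and a scroll_keep tuning knob. The following usage sketch is not part of the commit; it assumes Params is the nested parameter struct taken by construct(), and the header name and weights path are placeholders:

    #include "justlm.hpp" // header name assumed
    #include <iostream>

    int main() {
        LM::Inference::Params params; // nested struct name assumed
        params.scroll_keep = 0.4f;    // keep 40% of the context below the top bar when scrolling

        auto *inference = LM::Inference::construct("ggml-model.bin", params); // placeholder path

        // Called while window_scroll() re-evaluates the shrunken context;
        // return false to cut the batched re-evaluation short
        inference->set_scroll_callback([] (float progress) {
            std::cerr << "Scrolling: " << progress << "%\n";
            return true;
        });

        inference->append("Hello there!", [] (float progress) {
            std::cerr << "Ingesting prompt: " << progress << "%\n";
            return true;
        });

        delete inference;
    }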
@@ -65,15 +65,67 @@ class GPTJInference final : public Inference {
         }
     }
 
-    void window_scroll() {
+    // This function reduces the size of our tokens vector according to some parameters
+    // All tokens will be evaluated if scrolling was needed and true will be returned
+    bool window_scroll() {
         auto &state = get_state();
-        if (state->tokens.size() > params.n_ctx) {
-            // "Scroll" down the context window...
-            unsigned overflow = state->tokens.size() - params.n_ctx;
-            std::vector<int> tokens_in_view(state->tokens.begin()+params.n_ctx_window_top_bar+overflow, state->tokens.end());
-            state->tokens.resize(params.n_ctx);
-            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size());
+        // Check that we actually need to scroll
+        if (state->tokens.size() <= params.n_ctx) {
+            // Nope
+            return false;
         }
+        // Start scrolling
+        if (params.scroll_keep > 0.0f) {
+            // "Scroll" down the context window...
+            unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * params.scroll_keep; // Keep the fraction configured by scroll_keep
+            // Get vector of tokens to keep
+            std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
+            // Cut down tokens vector size
+            state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
+            // Overwrite tokens after top bar with tokens in view
+            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
+        } else {
+            // Cut down tokens vector size to top bar
+            state->tokens.resize(params.n_ctx_window_top_bar);
+        }
+        // Evaluate tokens
+        evaluate_tokens(0, on_scroll);
+        // We've scrolled!
+        return true;
     }
 
+    void evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) {
+        auto& state = get_state();
+
+        // Evaluate tokens in batches
+        unsigned it;
+        for (it = starting_offset; ; it += params.n_batch) {
+            if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
+
+            // Evaluate
+            std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
+            gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
+
+            // Tick
+            if (on_tick) {
+                // Calculate progress
+                auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
+                // Run callback
+                if (!on_tick(progress)) break;
+            }
+        }
+
+        // Evaluate remaining tokens
+        if (it < state->tokens.size()) {
+            for (; it != state->tokens.size(); it++) {
+                //TODO: This is extremely inefficient! Don't do that...
+                std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
+                gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
+            }
+        }
+
+        // Notify about completion
+        if (on_tick) on_tick(100.f);
+    }
+
 public:
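As an aside (not part of the commit), the trimming arithmetic above can be restated as a self-contained sketch over a plain std::vector; n_ctx, top_bar and scroll_keep mirror params.n_ctx, params.n_ctx_window_top_bar and params.scroll_keep:

    #include <cstring>
    #include <vector>

    // Standalone restatement of the window_scroll trimming step, for illustration only
    std::vector<int> scroll(std::vector<int> tokens, unsigned n_ctx,
                            unsigned top_bar, float scroll_keep) {
        if (tokens.size() <= n_ctx) return tokens; // nothing to do
        if (scroll_keep > 0.0f) {
            // Keep the top bar plus the last scroll_keep fraction of the context
            unsigned keep_count = float(tokens.size() - top_bar) * scroll_keep;
            std::vector<int> in_view(tokens.end() - keep_count, tokens.end());
            tokens.resize(top_bar + keep_count);
            std::memcpy(tokens.data() + top_bar, in_view.data(), in_view.size() * sizeof(int));
        } else {
            tokens.resize(top_bar); // drop everything after the top bar
        }
        return tokens;
    }

For example, with 2100 tokens, n_ctx = 2048, top_bar = 100 and scroll_keep = 0.4f, keep_count comes out to 800 and the vector shrinks to 900 tokens, so several more append() calls fit before the next scroll.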
@@ -102,33 +154,13 @@ public:
         );
 
         // Make sure token limit isn't being hit
-        window_scroll();
-
-        // Evaluate new tokens in batches
-        int it;
-        for (it = old_token_count; ; it += params.n_batch) {
-            if (it >= ssize_t(state->tokens.size()) - params.n_batch) break;
-
-            // Evaluate
-            std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
-            gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
-
-            // Tick
-            if (on_tick) {
-                // Calculate progress
-                auto progress = float(it-old_token_count) / (state->tokens.size()-old_token_count) * 100.f;
-                // Run callback
-                if (!on_tick(progress)) break;
-            }
-        }
-        // Evaluate remaining tokens
-        if (it < state->tokens.size()) {
-            for (; it != state->tokens.size(); it++) {
-                //TODO: This is extremely inefficient! Don't do that...
-                std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
-                gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
-            }
-        }
+        if (window_scroll()) {
+            // That function has already evaluated our tokens, since scrolling was needed
+            return;
+        }
+
+        // Evaluate new tokens
+        evaluate_tokens(old_token_count, on_tick);
     }
 
     std::string run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) override {
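The //TODO in evaluate_tokens flags the token-by-token tail loop. Since the diff already calls gptj_eval with batch sizes of both params.n_batch and 1, it is presumably safe to evaluate the whole tail as one batch; a possible replacement (a sketch, not part of this commit):

    // Hypothetical rewrite of the remainder loop
    if (it < state->tokens.size()) {
        // One gptj_eval call over all leftover tokens instead of one call per token
        std::vector<int> batch(state->tokens.begin()+it, state->tokens.end());
        gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
    }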
@@ -43,15 +43,64 @@ class LLaMaInference final : public Inference {
         state->n_ctx = llama_n_ctx(state->ctx);
     }
 
-    void window_scroll() {
+    // This function reduces the size of our tokens vector according to some parameters
+    // All tokens will be evaluated if scrolling was needed and true will be returned
+    bool window_scroll() {
         auto &state = get_state();
-        if (state->tokens.size() > state->n_ctx) {
-            // "Scroll" down the context window...
-            unsigned overflow = state->tokens.size() - state->n_ctx;
-            std::vector<int> tokens_in_view(state->tokens.begin()+params.n_ctx_window_top_bar+overflow, state->tokens.end());
-            state->tokens.resize(state->n_ctx);
-            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size());
+        // Check that we actually need to scroll
+        if (state->tokens.size() <= state->n_ctx) {
+            // Nope
+            return false;
         }
+        // Start scrolling
+        if (params.scroll_keep > 0.0f) {
+            // "Scroll" down the context window...
+            unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * params.scroll_keep; // Keep the fraction configured by scroll_keep
+            // Get vector of tokens to keep
+            std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
+            // Cut down tokens vector size
+            state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
+            // Overwrite tokens after top bar with tokens in view
+            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
+        } else {
+            // Cut down tokens vector size to top bar
+            state->tokens.resize(params.n_ctx_window_top_bar);
+        }
+        // Evaluate tokens
+        evaluate_tokens(0, on_scroll);
+        // We've scrolled!
+        return true;
     }
 
+    void evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) {
+        auto& state = get_state();
+
+        // Evaluate tokens in batches
+        unsigned it;
+        for (it = starting_offset; ; it += params.n_batch) {
+            if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
+
+            // Evaluate
+            llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads);
+
+            // Tick
+            if (on_tick) {
+                // Calculate progress
+                auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
+                // Run callback
+                if (!on_tick(progress)) break;
+            }
+        }
+
+        // Evaluate remaining tokens
+        if (it < state->tokens.size()) {
+            for (; it != state->tokens.size(); it++) {
+                llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads);
+            }
+        }
+
+        // Notify about completion
+        if (on_tick) on_tick(100.f);
+    }
+
 public:
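The LLaMA version of evaluate_tokens avoids the gpt-j TODO's temporary vector by passing pointers straight into llama_eval, but it still evaluates the tail one token per call. Under the same assumption as above, the remainder loop could collapse into a single call (a sketch, not part of this commit):

    // Hypothetical rewrite of the remainder loop
    if (it < state->tokens.size()) {
        // The pointer-based API needs no temporary batch vector
        llama_eval(state->ctx, state->tokens.data()+it, int(state->tokens.size()-it), it, params.n_threads);
    }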
@@ -85,30 +134,13 @@ public:
         state->tokens.resize(old_token_count+token_count);
 
         // Make sure token limit isn't being hit
-        window_scroll();
-
-        // Evaluate new tokens in batches
-        int it;
-        for (it = old_token_count; ; it += params.n_batch) {
-            if (it >= ssize_t(state->tokens.size()) - params.n_batch) break;
-
-            // Evaluate
-            llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads);
-
-            // Tick
-            if (on_tick) {
-                // Calculate progress
-                auto progress = float(it-old_token_count) / (state->tokens.size()-old_token_count) * 100.f;
-                // Run callback
-                if (!on_tick(progress)) break;
-            }
-        }
-        // Evaluate remaining tokens
-        if (it < state->tokens.size()) {
-            for (; it != state->tokens.size(); it++) {
-                llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads);
-            }
-        }
+        if (window_scroll()) {
+            // That function has already evaluated our tokens, since scrolling was needed
+            return;
+        }
+
+        // Evaluate new tokens
+        evaluate_tokens(old_token_count, on_tick);
     }
 
     std::string run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) override {
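For completeness, both backends keep the run() entry point shown above; its callback presumably receives each generated token as text. A hedged continuation of the earlier sketch (the "\n" end string is a made-up choice):

    // 'inference' as constructed in the earlier sketch
    std::string output = inference->run("\n", [] (const char *token) {
        std::cout << token << std::flush; // stream tokens as they are generated
        return true;                      // return false to stop generation early
    });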