
Implemented proper scrolling

This commit is contained in:
niansa 2023-04-28 18:04:07 +02:00
parent 4f03b75560
commit d236e36d26
3 changed files with 136 additions and 63 deletions

View file

@@ -11,6 +11,8 @@
namespace LM {
class Inference {
protected:
std::function<bool (float)> on_scroll = nullptr;
void *generic_state = nullptr;
static inline
@@ -20,6 +22,7 @@ protected:
}
public:
struct Exception : public std::runtime_error {
using std::runtime_error::runtime_error;
};
@@ -32,6 +35,8 @@ public:
unsigned n_batch = 8; // Batch size
unsigned n_repeat_last = 0; // llama.cpp specific
float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
unsigned top_k = 40;
float top_p = 0.9f;
float temp = 0.72f;
@@ -69,6 +74,10 @@ public:
static
Inference *construct(const std::string& weights_path, const Params& p);
void set_scroll_callback(const std::function<bool (float)>& scroll_cb) {
on_scroll = scroll_cb;
}
// This must be called with a non-empty prompt!
virtual void append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) = 0;
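
For context, a minimal usage sketch of the scrolling API added above. The header name justlm.hpp and the model path model.bin are assumptions, and Params is assumed to be the nested struct shown in the hunk above; construct, set_scroll_callback, append and run are taken from this diff:

#include <iostream>
#include <memory>
#include "justlm.hpp" // assumed header name

int main() {
    LM::Inference::Params params;             // assumed nested Params struct
    params.scroll_keep = 0.4f;                // keep 40% of the context below the top bar when scrolling
    std::unique_ptr<LM::Inference> lm(
        LM::Inference::construct("model.bin", params)); // hypothetical model path
    // Receives re-evaluation progress whenever the context window is scrolled
    lm->set_scroll_callback([] (float progress) {
        std::cout << "Scrolling: " << progress << "%\r" << std::flush;
        return true;                          // returning false aborts the re-evaluation
    });
    lm->append("Hello! ", [] (float progress) {
        std::cout << "Ingesting prompt: " << progress << "%\r" << std::flush;
        return true;
    });
    std::cout << lm->run("\n") << std::endl;
}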

View file

@@ -65,15 +65,67 @@ class GPTJInference final : public Inference {
}
}
void window_scroll() {
// This function reduces the size of our tokens vector according to some parameters
// If scrolling was needed, all tokens are re-evaluated and true is returned
bool window_scroll() {
auto &state = get_state();
if (state->tokens.size() > params.n_ctx) {
// "Scroll" down the context window...
unsigned overflow = state->tokens.size() - params.n_ctx;
std::vector<int> tokens_in_view(state->tokens.begin()+params.n_ctx_window_top_bar+overflow, state->tokens.end());
state->tokens.resize(params.n_ctx);
std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size());
// Check that we actually need to scroll
if (state->tokens.size() <= params.n_ctx) {
// Nope
return false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
// "Scroll" down the context window...
unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * params.scroll_keep; // Keep the configured fraction of the context
// Get vector of tokens to keep
std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
// Cut down tokens vector size
state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
// Overwrite tokens after top bar with tokens in view
std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
} else {
// Cut down tokens vector size to top bar
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
evaluate_tokens(0, on_scroll);
// We've scrolled!
return true;
}
void evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) {
auto& state = get_state();
// Evaluate tokens in batches
unsigned it;
for (it = starting_offset; ; it += params.n_batch) {
if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
// Tick
if (on_tick) {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Run callback
if (!on_tick(progress)) break;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
}
}
// Notify about completion
if (on_tick) on_tick(100.f);
}
public:
@ -102,33 +154,13 @@ public:
);
// Make sure token limit isn't being hit
window_scroll();
// Evaluate new tokens in batches
int it;
for (it = old_token_count; ; it += params.n_batch) {
if (it >= ssize_t(state->tokens.size()) - params.n_batch) break;
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
// Tick
if (on_tick) {
// Calculate progress
auto progress = float(it-old_token_count) / (state->tokens.size()-old_token_count) * 100.f;
// Run callback
if (!on_tick(progress)) break;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token);
}
if (window_scroll()) {
// window_scroll() has already evaluated our tokens since scrolling was needed
return;
}
// Evaluate new tokens
evaluate_tokens(old_token_count, on_tick);
}
std::string run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) override {
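
To make the new behaviour easier to follow, here is the scrolling arithmetic pulled out as a standalone sketch (illustrative only, not library code); n_ctx, n_ctx_window_top_bar and scroll_keep stand in for the corresponding Params fields:

#include <cstring>
#include <vector>

// Shrinks the token buffer once it exceeds n_ctx: the first n_ctx_window_top_bar
// tokens are always kept, followed by the last scroll_keep fraction of the rest.
bool window_scroll(std::vector<int>& tokens, unsigned n_ctx,
                   unsigned n_ctx_window_top_bar, float scroll_keep) {
    if (tokens.size() <= n_ctx) return false; // no overflow, nothing to do
    if (scroll_keep > 0.0f) {
        const unsigned keep_count = float(tokens.size() - n_ctx_window_top_bar) * scroll_keep;
        std::vector<int> tokens_in_view(tokens.end() - keep_count, tokens.end());
        tokens.resize(n_ctx_window_top_bar + keep_count);
        std::memcpy(tokens.data() + n_ctx_window_top_bar,
                    tokens_in_view.data(), tokens_in_view.size() * sizeof(int));
    } else {
        tokens.resize(n_ctx_window_top_bar); // drop everything after the top bar
    }
    return true; // caller must re-evaluate from offset 0
}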

View file

@@ -43,15 +43,64 @@ class LLaMaInference final : public Inference {
state->n_ctx = llama_n_ctx(state->ctx);
}
void window_scroll() {
// This function reduces the size of our tokens vector according to some parameters
// If scrolling was needed, all tokens are re-evaluated and true is returned
bool window_scroll() {
auto &state = get_state();
if (state->tokens.size() > state->n_ctx) {
// "Scroll" down the context window...
unsigned overflow = state->tokens.size() - state->n_ctx;
std::vector<int> tokens_in_view(state->tokens.begin()+params.n_ctx_window_top_bar+overflow, state->tokens.end());
state->tokens.resize(state->n_ctx);
std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size());
// Check that we actually need to scroll
if (state->tokens.size() <= state->n_ctx) {
// Nope
return false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
// "Scroll" down the context window...
unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * params.scroll_keep; // Keep the configured fraction of the context
// Get vector of tokens to keep
std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
// Cut down tokens vector size
state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
// Overwrite tokens after top bar with tokens in view
std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
} else {
// Cut down tokens vector size to top bar
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
evaluate_tokens(0, on_scroll);
// We've scrolled!
return true;
}
void evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) {
auto& state = get_state();
// Evaluate tokens in batches
unsigned it;
for (it = starting_offset; ; it += params.n_batch) {
if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
// Evaluate
llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads);
// Tick
if (on_tick) {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Run callback
if (!on_tick(progress)) break;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads);
}
}
// Notify about completion
if (on_tick) on_tick(100.f);
}
public:
@@ -85,30 +134,13 @@ public:
state->tokens.resize(old_token_count+token_count);
// Make sure token limit isn't being hit
window_scroll();
// Evaluate new tokens in batches
int it;
for (it = old_token_count; ; it += params.n_batch) {
if (it >= ssize_t(state->tokens.size()) - params.n_batch) break;
// Evaluate
llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads);
// Tick
if (on_tick) {
// Calculate progress
auto progress = float(it-old_token_count) / (state->tokens.size()-old_token_count) * 100.f;
// Run callback
if (!on_tick(progress)) break;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads);
}
if (window_scroll()) {
// window_scroll() has already evaluated our tokens since scrolling was needed
return;
}
// Evaluate new tokens
evaluate_tokens(old_token_count, on_tick);
}
std::string run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) override {
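
Finally, the batched-evaluation pattern shared by both backends can be sketched generically as follows; eval_fn stands in for gptj_eval/llama_eval and every name here is illustrative rather than library API:

#include <cstddef>
#include <functional>
#include <vector>

void evaluate_tokens(const std::vector<int>& tokens, size_t starting_offset, unsigned n_batch,
                     const std::function<void(const int*, unsigned, size_t)>& eval_fn,
                     const std::function<bool(float)>& on_tick = nullptr) {
    size_t it;
    // Evaluate full batches first
    for (it = starting_offset; it + n_batch < tokens.size(); it += n_batch) {
        eval_fn(tokens.data() + it, n_batch, it); // one batch starting at position `it`
        if (on_tick) {
            const float progress = float(it - starting_offset)
                                 / (tokens.size() - starting_offset) * 100.f;
            if (!on_tick(progress)) return; // in this sketch, the callback aborts evaluation entirely
        }
    }
    // Evaluate the remaining tokens one at a time
    for (; it < tokens.size(); it++)
        eval_fn(tokens.data() + it, 1, it);
    if (on_tick) on_tick(100.f); // notify about completion
}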