1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00
libjustlm/libjustlm_gpt2.cpp
2023-03-30 07:03:33 -05:00

80 lines
2.3 KiB
C++

#include "justlm.hpp"
#include "gpt2/gpt2tc.h"
#include <filesystem>
#include <cstring>
// Backend state for the GPT-2 implementation. The LLM member functions in this
// file read and write these fields through a `state` pointer.
struct State {
// Prompt text accumulated so far (appended to by LLM::append and LLM::run)
std::string prompt;
// Path of the weights file, as passed to LLM::init
std::string model_path;
// Model variant selected in LLM::init from the size of the weights file
GPT2ModelEnum model;
// NOTE(review): the declarator below also defines a namespace-scope object
// named `state`. The member functions use `state->` and `delete state`, so
// they presumably refer to a pointer member declared in justlm.hpp rather
// than this object - confirm whether this declarator is intentional.
} state;
void LLM::init(const std::string& weights_path) {
state->model_path = weights_path;
// Get weight file size
auto weights_size = std::filesystem::file_size(weights_path);
// Determine weight size
switch (weights_size) {
case 250700242: state->model = GPT2_MODEL_117M; break;
case 3120522738: state->model = GPT2_MODEL_1558M; break;
case 712396722: state->model = GPT2_MODEL_345M; break;
case 1551900050: state->model = GPT2_MODEL_774M; break;
default: throw Exception("Unknown model size");
}
}
// Frees the object that `state` points to.
LLM::~LLM() {
delete state;
}
// Adds `prompt` to the end of the stored prompt text and echoes it to stdout,
// followed by a newline and a flush.
// `on_tick` is part of the signature but is not invoked by this backend.
void LLM::append(std::string_view prompt, const std::function<bool (float)> &on_tick) {
    std::string &accumulated = state->prompt;
    accumulated += prompt;
    std::cout << prompt;
    std::cout << std::endl;
}
std::string LLM::run(std::string_view end, const std::function<bool (const char *)> &on_tick) {
std::string fres;
TextCompleteGlobalState *tcs;
TextGenContext *ts;
int count;
struct timeval tv;
struct list_head ts_list;
// Initialize completion
tcs = text_complete_global_init(state->model, state->model_path.c_str());
// Run completion
ts = text_complete_start(tcs, state->prompt.c_str(), params.top_k, params.top_p, params.temp,
params.seed, params.n_prompt>0?params.n_prompt:0xfffffff - state->prompt.size());
bool abort = false;
while (!abort && !ends_with(fres, end)) {
// Run completion
init_list_head(&ts_list);
list_add_tail(&ts->link, &ts_list);
text_complete_next(tcs, &ts_list);
if (ts->out_text_len == 0)
break;
auto str = std::string_view{ts->out_text, static_cast<std::string_view::size_type>(ts->out_text_len)};
// Append result to fres
fres.append(str);
// Tick
if (on_tick && !on_tick(std::string(str).c_str()) /*Huge overhead in favor of llama.cpp*/) abort = true;
}
// End completion
text_complete_end(ts);
text_complete_global_end(tcs);
// Create final string TODO: Could be optimized
state->prompt.append(fres);
fres = std::string(fres.data(), fres.size()-end.size());
// Return final string
return fres;
}