Mirror of https://gitlab.com/niansa/discord_llama.git, synced 2025-03-06 20:48:25 +01:00
Use libjustlm

commit 553697d65e
parent 766700602b

5 changed files with 37 additions and 197 deletions
.gitmodules (vendored): 6 lines changed

@@ -1,6 +1,6 @@
 [submodule "DPP"]
     path = DPP
     url = https://github.com/brainboxdotcc/DPP.git
-[submodule "llama.cpp"]
-    path = llama.cpp
-    url = https://github.com/ggerganov/llama.cpp.git
+[submodule "libjustlm"]
+    path = libjustlm
+    url = https://gitlab.com/niansa/libjustlm.git
CMakeLists.txt

@@ -5,11 +5,11 @@ project(discord_llama LANGUAGES C CXX)
 set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
-add_subdirectory(llama.cpp)
+add_subdirectory(libjustlm)
 add_subdirectory(DPP)
 
 add_executable(discord_llama main.cpp)
-target_link_libraries(discord_llama PUBLIC dpp pthread llama ggml)
+target_link_libraries(discord_llama PUBLIC dpp pthread libjustlm ggml)
 
 install(TARGETS discord_llama
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
libjustlm (new submodule): 1 line changed

@@ -0,0 +1 @@
+Subproject commit dc7fe7f9f01681544916da4d41ef704d254a4ca4

llama.cpp (submodule removed): 1 line changed

@@ -1 +0,0 @@
-Subproject commit 19726169b379bebc96189673a19b89ab1d307659
main.cpp: 222 lines changed

@@ -15,8 +15,7 @@
 #include <mutex>
 #include <memory>
 #include <dpp/dpp.h>
-#include <ggml.h>
-#include <llama.h>
+#include <justlm.hpp>
 
 
 
@@ -43,194 +42,33 @@ void str_replace_in_place(std::string& subject, std::string_view search,
     }
 }
 
 
-class LLM {
-    struct {
-        std::string model = "7B-ggml-model-quant.bin";
-
-        int32_t seed; // RNG seed
-        int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 4;
-        int32_t n_ctx = 2024; // Context size
-        int32_t n_batch = 8; // Batch size, unused for now
-
-        int32_t top_k = 40;
-        float top_p = 0.5f;
-        float temp = 0.72f;
-
-        bool no_repeat = true;
-    } params;
-
-    struct State {
-        std::string prompt;
-        std::vector<llama_token> embd;
-        int n_ctx;
-        std::string last_result;
-        int repeats;
-    } state;
-
-    llama_context *ctx = nullptr;
-    std::mutex lock;
-
-    static inline
-    std::string clean_string(const std::string& str) {
-        std::string fres;
-        for (const auto c : str) {
-            if ((c >= 0x20 && c <= 0x7E)
-                    || c == '\n'
-                    || c == "ä"[0] || c == "ä"[1] || c == "ä"[2]
-                    || c == "ö"[0] || c == "ö"[1] || c == "ö"[2]
-                    || c == "ü"[0] || c == "ü"[1] || c == "ü"[2]
-                    || c == "Ä"[0] || c == "Ä"[1] || c == "Ä"[2]
-                    || c == "Ö"[0] || c == "Ö"[1] || c == "Ö"[2]
-                    || c == "Ü"[0] || c == "Ü"[1] || c == "Ü"[2]
-                    || c == "ß"[0] || c == "ß"[1] || c == "ß"[2]) {
-                fres.push_back(c);
-            }
-        }
-        return fres;
-    }
-
-    void init() {
-        // Get llama parameters
-        auto lparams = llama_context_default_params();
-        lparams.seed = params.seed;
-        lparams.n_ctx = 2024;
-
-        // Create context
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-        if (!ctx) {
-            throw Exception("Failed to initialize llama from file");
-        }
-
-        // Initialize some variables
-        state.n_ctx = llama_n_ctx(ctx);
-        state.repeats = 0;
-    }
-
-public:
-    struct Exception : public std::runtime_error {
-        using std::runtime_error::runtime_error;
-    };
-    struct ContextLengthException : public Exception {
-        ContextLengthException() : Exception("Max. context length exceeded") {}
-    };
-
-
-    LLM(int32_t seed = 0) {
-        // Set random seed
-        params.seed = seed?seed:time(NULL);
-
-        // Initialize llama
-        init();
-    }
-    ~LLM() {
-        std::scoped_lock L(lock);
-        if (ctx) llama_free(ctx);
-    }
-
-    void append(std::string prompt, const std::function<bool (float progress)>& on_tick = nullptr) {
-        std::scoped_lock L(lock);
-
-        // Remove non-printables
-        prompt = clean_string(prompt);
-
-        // Check if prompt was empty
-        const bool was_empty = state.prompt.empty();
-
-        // Append to current prompt
-        state.prompt.append(prompt);
-
-        // Debug
-        std::ofstream("prompt.txt") << state.prompt;
-
-        // Resize buffer for tokens
-        const auto old_token_count = state.embd.size();
-        state.embd.resize(old_token_count+state.prompt.size()+1);
-
-        // Run tokenizer
-        const auto token_count = llama_tokenize(ctx, prompt.data(), state.embd.data()+old_token_count, state.embd.size()-old_token_count, was_empty);
-        state.embd.resize(old_token_count+token_count);
-
-        // Make sure limit is far from being hit
-        if (state.embd.size() > state.n_ctx-6) {
-            // Yup. *this MUST be decomposed now.
-            throw ContextLengthException();
-        }
-
-        // Evaluate new tokens
-        // TODO: Larger batch size
-        std::cout << "Context size: " << old_token_count << '+' << token_count << '=' << state.embd.size() << '/' << state.n_ctx << std::endl;
-        for (int it = old_token_count; it != state.embd.size(); it++) {
-            std::cout << llama_token_to_str(ctx, state.embd.data()[it]) << std::flush;
-            llama_eval(ctx, state.embd.data()+it, 1, it, params.n_threads);
-
-            // Tick
-            if (on_tick) {
-                // Calculate progress
-                auto progress = float(it) / (state.embd.size()) * 100.f;
-                // Run callback
-                if (!on_tick(progress)) break;
-            }
-        }
-        std::cout << std::endl;
-    }
-
-    std::string run(std::string_view end, const std::function<bool ()>& on_tick = nullptr) {
-        std::scoped_lock L(lock);
-        std::string fres;
-
-        // Loop until done
-        bool abort = false;
-        while (!abort && !fres.ends_with(end)) {
-            // Sample top p and top k
-            bool has_repeated = state.repeats>=4;
-            const auto id = llama_sample_top_p_top_k(ctx, nullptr, 0, params.top_k, has_repeated?(params.top_p+0.15f):params.top_p, has_repeated?(params.temp+0.4f):params.temp, 1.0f);
-
-            // Add token
-            state.embd.push_back(id);
-
-            // Get token as string
-            const auto str = llama_token_to_str(ctx, id);
-
-            // Debug
-            std::cout << str << std::flush;
-
-            // Append string to function result
-            fres.append(str);
-
-            // Evaluate token
-            // TODO: Respect batch size
-            llama_eval(ctx, state.embd.data()+state.embd.size()-1, 1, state.embd.size()-1, params.n_threads);
-
-            // Tick
-            if (on_tick && !on_tick()) abort = true;
-        }
-
-        // Create final string
-        state.prompt.append(fres);
-        fres = std::string(fres.data(), fres.size()-end.size());
-
-        // Check for repetition
-        if (state.last_result == fres && params.no_repeat) {
-            state.repeats++;
-        } else {
-            state.repeats = 0;
-            state.last_result = fres;
-        }
-
-        // Return final string
-        return fres;
-    }
-};
+static inline
+std::string clean_string(std::string_view str) {
+    std::string fres;
+    for (const auto c : str) {
+        if ((c >= 0x20 && c <= 0x7E)
+                || c == '\n'
+                || c == "ä"[0] || c == "ä"[1] || c == "ä"[2]
+                || c == "ö"[0] || c == "ö"[1] || c == "ö"[2]
+                || c == "ü"[0] || c == "ü"[1] || c == "ü"[2]
+                || c == "Ä"[0] || c == "Ä"[1] || c == "Ä"[2]
+                || c == "Ö"[0] || c == "Ö"[1] || c == "Ö"[2]
+                || c == "Ü"[0] || c == "Ü"[1] || c == "Ü"[2]
+                || c == "ß"[0] || c == "ß"[1] || c == "ß"[2]) {
+            fres.push_back(c);
+        }
+    }
+    return fres;
+}
 
 
 class Bot {
     RandomGenerator rng;
     Timer last_message_timer;
     std::shared_ptr<bool> stopping;
-    std::unique_ptr<LLM> llm;
+    std::unique_ptr<LM::Inference> llm;
     std::vector<dpp::snowflake> my_messages;
-    std::mutex llm_init_lock;
+    std::mutex llm_lock;
 
     dpp::cluster bot;
     dpp::channel channel;
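The hunks below replace every use of the removed LLM wrapper with libjustlm's LM::Inference. A minimal self-contained sketch of the usage pattern this commit switches to, assembled only from calls visible in this diff (a constructor taking a model path, append() with a bool(float) progress callback, run() with an end string and a bool(std::string_view) callback, and LM::Inference::ContextLengthException); any detail beyond that, such as exact return types, is an assumption rather than documented libjustlm API:

// Sketch only: LM::Inference as used by this commit, pieced together from
// the hunks below. Details not visible in the diff are assumptions.
#include <iostream>
#include <string>
#include <string_view>
#include <justlm.hpp>

int main() {
    try {
        // Model file name as used further down in this diff
        LM::Inference llm("7B-ggml-model-quant.bin");

        // Feed prompt text; the callback receives a progress value and may
        // return false to abort early (the Bot code uses this for timeouts).
        llm.append("User: Hello\n", [] (float progress) {
            std::cout << "\rIngesting prompt: " << progress << std::flush;
            return true;
        });

        // Generate until the end string ("\n") is produced; the callback
        // receives each generated piece and may return false to stop.
        auto output = llm.run("\n", [] (std::string_view piece) {
            std::cout << piece << std::flush;
            return true;
        });
        std::cout << '\n' << output << std::endl;
    } catch (const LM::Inference::ContextLengthException&) {
        // Same recovery as in the Bot code below: drop the instance and re-init.
        std::cerr << "Context length exceeded" << std::endl;
    }
}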
@@ -269,8 +107,8 @@ class Bot {
         if (!llm) {
             // Make sure llm is initialized
             {
-                std::unique_lock L(llm_init_lock);
-                llm = std::make_unique<LLM>();
+                std::unique_lock L(llm_lock);
+                llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin");
             }
             // Create message for reporting progress
             dpp::message msg(channel_id, "Wird initialisiert...");
@@ -295,6 +133,7 @@ class Bot {
                 return true;
             };
             // Add initial context
+            std::unique_lock L(llm_lock);
             llm->append("Verlauf des #"+channel.name+" Kanals.\n"
                         "Notiz 1: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet. Er ist freundlich und hilfsbereit und antwortet immer sofort. Er hat guten Humor und mag jeden. Sein Alter ist 16 und er wurde 2007 geboren.\n"
                         "Notiz 2: Ecki heisst in Wirklichkeit Eckhard Kohlhuber und kommt aus Bayern.\n"
@@ -311,9 +150,10 @@ class Bot {
                 return;
             }
             // Format and append line
+            std::unique_lock L(llm_lock);
             for (const auto line : str_split(msg.content, '\n')) {
                 Timer timeout;
-                llm->append(msg.author.username+": "+std::string(line)+'\n', [&] (float) {
+                llm->append(msg.author.username+": "+clean_string(line)+'\n', [&] (float) {
                     if (timeout.get<std::chrono::minutes>() > 1) {
                         std::cerr << "\nWarning: Timeout reached processing message" << std::endl;
                         return false;
@@ -321,15 +161,16 @@ class Bot {
                     return true;
                 });
             }
-        } catch (const LLM::ContextLengthException&) {
+        } catch (const LM::Inference::ContextLengthException&) {
             llm.reset();
             llm_init();
         }
     }
     void prompt_add_trigger() {
         try {
+            std::unique_lock L(llm_lock);
             llm->append(bot.me.username+':');
-        } catch (const LLM::ContextLengthException&) {
+        } catch (const LM::Inference::ContextLengthException&) {
             llm.reset();
             llm_init();
         }
@@ -346,14 +187,16 @@ class Bot {
             // Run model
             Timer timeout;
             bool timed_out = false;
-            auto output = llm->run("\n", [&] () {
+            auto output = llm->run("\n", [&] (std::string_view str) {
+                std::cout << str << std::flush;
                 if (timeout.get<std::chrono::minutes>() > 2) {
                     timed_out = true;
-                    std::cerr << "\nWarning: Timeout reached generating message" << std::endl;
+                    std::cerr << "\nWarning: Timeout reached generating message";
                     return false;
                 }
                 return true;
             });
             std::cout << std::endl;
             if (timed_out) output = "Fehler: Zeitüberschreitung";
             // Send resulting message
             msg.content = output;
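With the new run() signature the per-piece callback both streams the generated text and enforces the two-minute limit. A small sketch of that guard, using std::chrono in place of the project's Timer helper (Timer is not shown in this diff, so the chrono-based stand-in and the run_with_timeout wrapper are assumptions for illustration):

// Sketch only: timeout-guarded generation as in the hunk above, with the
// project's Timer helper replaced by std::chrono. run_with_timeout is a
// hypothetical wrapper, not part of the repository.
#include <chrono>
#include <iostream>
#include <string>
#include <string_view>

template <typename RunFn>
std::string run_with_timeout(RunFn&& run, std::chrono::minutes limit) {
    const auto start = std::chrono::steady_clock::now();
    bool timed_out = false;
    std::string output = run([&] (std::string_view piece) {
        std::cout << piece << std::flush;               // stream as it arrives
        if (std::chrono::steady_clock::now() - start > limit) {
            timed_out = true;                           // remember why we stopped
            std::cerr << "\nWarning: Timeout reached generating message";
            return false;                               // false cancels generation
        }
        return true;
    });
    std::cout << std::endl;
    if (timed_out) output = "Fehler: Zeitüberschreitung";
    return output;
}

// Usage, mirroring the hunk above (llm being the Bot's LM::Inference):
//   msg.content = run_with_timeout(
//       [&] (auto&& cb) { return llm->run("\n", cb); },
//       std::chrono::minutes(2));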
@@ -463,9 +306,6 @@ public:
 
 
 int main(int argc, char **argv) {
-    // Init GGML
-    ggml_time_init();
-
     // Check arguments
     if (argc < 3) {
         std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;