From e3f0ebc5180eb025c1019dfe25f8a44283398dac Mon Sep 17 00:00:00 2001
From: niansa
Date: Sun, 26 Nov 2023 22:06:56 +0100
Subject: [PATCH] Minor fixes and update to latest justlm

---
 examples/CMakeLists.txt |  5 +++-
 examples/repl.cpp       | 65 +++++++++++++++++++++++++++++++++++++++++
 include/anyproc.hpp     | 23 +++++++++------
 3 files changed, 83 insertions(+), 10 deletions(-)
 create mode 100644 examples/repl.cpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 5514ccd..b7fcd63 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(CPM.cmake)
-CPMAddPackage("gl:niansa/libjustlm@1.0")
+CPMAddPackage("gl:niansa/libjustlm@1.3")
 add_subdirectory(.. anyproc)
 
 add_executable(dictionary dictionary.cpp)
@@ -14,3 +14,6 @@ target_link_libraries(dictionary PUBLIC anyproc)
 
 add_executable(translator translator.cpp)
 target_link_libraries(translator PUBLIC anyproc)
+
+add_executable(repl repl.cpp)
+target_link_libraries(repl PUBLIC anyproc)
diff --git a/examples/repl.cpp b/examples/repl.cpp
new file mode 100644
index 0000000..db6fb6f
--- /dev/null
+++ b/examples/repl.cpp
@@ -0,0 +1,65 @@
+#include "anyproc.hpp"
+
+#include <iostream>
+#include <string>
+
+
+
+int main(int argc, char **argv) {
+    if (argc != 2) {
+        std::cerr << "Usage: " << argv[0] << " <weights path>" << std::endl;
+        return -1;
+    }
+
+    PyEval eval(argv[1]);
+
+    const auto progress_indicator = [] (float progress) {
+        if (progress < 2.5f || progress > 90.0f) {
+            std::cout << "\r \r";
+            return true;
+        }
+        std::cout << unsigned(progress) << "% \r" << std::flush;
+        return true;
+    };
+
+    bool had_nl = false,
+         had_begin = false,
+         first_token;
+    const auto result_printer = [&] (const char *token) {
+        std::cout << token << std::flush;
+        if (token[0] == '\n')
+            had_nl = true;
+        if ((had_nl || first_token) && token[0] == '>' && token[1] == '>') {
+            had_begin = true;
+            return false;
+        }
+        first_token = false;
+        return true;
+    };
+
+    for (;;) {
+        std::string expr;
+        std::cout << ">>> " << std::flush;
+        std::getline(std::cin, expr);
+        eval.begin();
+        if (expr.empty())
+            continue;
+        bool no_value = expr[0] == '$';
+        if (expr.find("import ") == 0) {
+            eval.load_module(expr.substr(7, expr.size()-7)).finish();
+        } else if (no_value) {
+            eval.expression(expr.substr(1, expr.size()-1)).finish();
+        } else {
+            first_token = true;
+            had_begin = false;
+            eval.expression(expr).run(progress_indicator, result_printer);
+            if (had_begin) {
+                eval.get_inference().append("> \n");
+                std::cout << "\r \r";
+            } else if (!had_nl) {
+                std::cout << '\n';
+            }
+            had_nl = false;
+        }
+    }
+}
diff --git a/include/anyproc.hpp b/include/anyproc.hpp
index b7d2c92..1d5d0c4 100644
--- a/include/anyproc.hpp
+++ b/include/anyproc.hpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <fstream> // for debugging purposes, to be removed
 #include
 
@@ -68,11 +69,7 @@ public:
     constexpr static LM::Inference::Params get_recommended_params() {
         LM::Inference::Params p;
         p.use_mlock = false;
-        p.repeat_penalty = 1.2f;
-        p.n_repeat_last = 64;
-        p.temp = 0.4f;
-        p.top_k = 1;
-        p.top_p = 0.f;
+        p.temp = 0.0f;
         return p;
     }
 
@@ -107,6 +104,7 @@ public:
         buffer += "\n";
         LM_COAWAIT i->append(buffer, on_append_tick);
         buffer.clear();
+        std::ofstream("prompt.txt") << i->get_prompt();
         LM_CORETURN LM_COAWAIT i->run("\n", on_generation_tick);
     }
     void example(std::string_view response = "") {
@@ -114,6 +112,9 @@
         buffer += response;
         if (!response.empty()) buffer += "\n";
     }
+    void finish() {
+        buffer += "\n";
+    }
 
     LM_SCHEDULABLE(void) create_savestate(LM::Inference::Savestate &sv, const std::function &on_append_tick = nullptr) {
         if (!buffer.empty()) {
@@ -142,13 +143,15 @@
         p.top_k = 5;
         p.top_p = 0.2f;
         p.temp = 0.2f;
+        p.repeat_penalty = 1.2f;
+        p.n_repeat_last = 64;
         return p;
     }
 
     Dictionary(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
         begin()
             .load_module("huge_dictionary", "dict")
-            .example("");
+            .finish();
         begin()
             .expression(word_lookup_exprgen("\"Treehouse\"")+".description")
             .example("'A small house, especially one for children to play in, built or placed up in the branches of a tree.'");
@@ -183,6 +186,8 @@ public:
         p.top_k = 5;
         p.top_p = 0.2f;
         p.temp = 0.1f;
+        p.repeat_penalty = 1.2f;
+        p.n_repeat_last = 64;
         p.n_eos_ignores = 1;
         return p;
     }
@@ -190,10 +195,10 @@ public:
     Translator(const std::string& weights_path, const LM::Inference::Params& params = get_params()) : PyEval(weights_path, params) {
         begin()
             .load_module("deepl_scrape", "translator")
-            .example("");
+            .finish();
         begin()
-            .expression("# Using translator.translate, we can reliably translate any language into any other language. The meaning of the text is going to stay the same and if the text already is in the target language, it remains unchanged.")
-            .example("");
+            .expression("# Using translator.translate, we can reliably translate any language into any other language. The meaning of the text is going to stay the same and if the text is already in the target language, it remains unchanged.")
+            .finish();
         begin()
             .expression(translation_exprgen("\"Treehouse\"", "\"DE\""))
             .example("'Baumhaus'");