mirror of https://gitlab.com/niansa/libjustchat.git synced 2025-03-06 20:48:31 +01:00

eos_ignores -> emits_eos

niansa 2023-06-10 13:36:32 +02:00
parent 3a1bf94744
commit 581bc89fe4
6 changed files with 20 additions and 15 deletions

View file

@@ -1,6 +1,7 @@
 #include "chat.hpp"
 #include <iostream>
+#include <limits>
 #include <commoncpp/utils.hpp>
 #include <justlm.hpp>
@@ -9,12 +10,11 @@ namespace LM {
 namespace Chat {
 LM::Inference::Params Inference::get_params() const {
     LM::Inference::Params fres;
+    fres.n_eos_ignores = config.emits_eos?0:std::numeric_limits<unsigned>::max();
     if (config.max_context_size.has_value())
         fres.n_ctx = config.max_context_size.value();
     if (config.repeat_last.has_value())
         fres.n_repeat_last = config.repeat_last.value();
-    if (config.eos_ignores.has_value())
-        fres.n_eos_ignores = config.eos_ignores.value();
     if (config.top_k.has_value())
         fres.top_k = config.top_k.value();
     if (config.top_p.has_value())
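Note on the hunk above: the user-visible eos_ignores count is gone; the inference parameter n_eos_ignores is now derived from the single emits_eos flag. A minimal standalone sketch of that mapping (the struct and function names below are illustrative, not part of justlm's API):

#include <limits>

// Illustrative stand-ins for ModelConfig::emits_eos and LM::Inference::Params::n_eos_ignores.
struct Params { unsigned n_eos_ignores = 0; };

Params derive_eos_params(bool emits_eos) {
    Params p;
    // A model that emits an EOS token stops normally (ignore it zero times);
    // a model that never emits one gets an effectively infinite ignore count,
    // replacing the old "eos_ignores = -1 for infinite" convention.
    p.n_eos_ignores = emits_eos ? 0 : std::numeric_limits<unsigned>::max();
    return p;
}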
@@ -52,12 +52,14 @@ LM_SCHEDULABLE(bool) Inference::reset() {
 LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) {
     auto inference = LM_COAWAIT get_underlaying();
     bool non_cancelled = true;
-    // Append prompt prefix
-    LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
-        if (on_evaluate)
-            on_evaluate(progress/3.0f, true);
-        return true; // Can't be cancelled here
-    }), LM_BOOL_ERROR, {LM_CORETURN "";});
+    // Append prompt prefix if needed
+    if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) {
+        LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
+            if (on_evaluate)
+                on_evaluate(progress/3.0f, true);
+            return true; // Can't be cancelled here
+        }), LM_BOOL_ERROR, {LM_CORETURN "";});
+    }
     // Append prompt
     LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool {
         return non_cancelled = on_evaluate?on_evaluate(progress/3.0f+100.0f/3.0f, false):true;
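The new guard uses commoncpp's ends_with helper to skip re-appending the prompt prefix when the current prompt already ends with it, so repeated append() calls do not stack duplicate prefixes. That helper is not part of this diff; since the project targets C++17 (std::string::ends_with only exists from C++20), a rough equivalent might look like the following sketch (an assumption, not the commoncpp implementation):

#include <string_view>

// Rough C++17 ends_with helper; commoncpp's real version is not shown in this commit.
inline bool ends_with(std::string_view text, std::string_view suffix) {
    return text.size() >= suffix.size() &&
           text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
}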

@@ -1 +1 @@
-Subproject commit ec148c3c3447b26b3213b0277566445456693bdf
+Subproject commit 54f9a8e586d6c710999cb7f033f6915cff92e35f

View file

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.5)
-project(anyproc LANGUAGES CXX)
+project(justchat LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

View file

@@ -14,7 +14,10 @@ int main(int argc, char **argv) {
     const auto model_config = argv[1];
     // Construct chat model
     LM::Chat::Inference model(model_config);
-    model.start();
+    if (!model.start()) {
+        std::cerr << "Failed to load model." << std::endl;
+        return EXIT_FAILURE;
+    }
     std::string instruction;
     for (;;) {
         std::cout << "> ";

View file

@@ -30,11 +30,11 @@ public:
                 language = "en";
     Prompt mutable prompt;
     bool allow_system_prompt = true,
-         strict_prompt = false;
+         strict_prompt = false,
+         emits_eos = true; // Whether the model emits an eos at the end of each response
     std::optional<unsigned> max_context_size,
                             repeat_last, // How many tokens to repeat-penalize
-                            eos_ignores, // How many times to ignore EOS
                             top_k,
                             mirostat_version; // Version of mirostat to use; 0 for none
     std::optional<float> top_p,

View file

@@ -25,8 +25,8 @@ void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
         max_context_size = std::stoul(value);
     else if (key == "repeat_last")
         repeat_last = std::stoul(value);
-    else if (key == "eos_ignores")
-        eos_ignores = std::stoi(value); // -1 for "infinite"
+    else if (key == "emits_eos")
+        emits_eos = parse_bool(value);
     else if (key == "top_k")
         top_k = std::stoi(value);
     else if (key == "mirostat_version")