
eos_ignores -> emits_eos

niansa 2023-06-10 13:36:32 +02:00
parent 3a1bf94744
commit 581bc89fe4
6 changed files with 20 additions and 15 deletions

@@ -1,6 +1,7 @@
#include "chat.hpp"
#include <iostream>
#include <limits>
#include <commoncpp/utils.hpp>
#include <justlm.hpp>
@@ -9,12 +10,11 @@ namespace LM {
 namespace Chat {
 LM::Inference::Params Inference::get_params() const {
     LM::Inference::Params fres;
+    fres.n_eos_ignores = config.emits_eos?0:std::numeric_limits<unsigned>::max();
     if (config.max_context_size.has_value())
         fres.n_ctx = config.max_context_size.value();
     if (config.repeat_last.has_value())
         fres.n_repeat_last = config.repeat_last.value();
-    if (config.eos_ignores.has_value())
-        fres.n_eos_ignores = config.eos_ignores.value();
     if (config.top_k.has_value())
         fres.top_k = config.top_k.value();
     if (config.top_p.has_value())
@@ -52,12 +52,14 @@ LM_SCHEDULABLE(bool) Inference::reset() {
 LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) {
     auto inference = LM_COAWAIT get_underlaying();
     bool non_cancelled = true;
-    // Append prompt prefix
-    LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
-        if (on_evaluate)
-            on_evaluate(progress/3.0f, true);
-        return true; // Can't be cancelled here
-    }), LM_BOOL_ERROR, {LM_CORETURN "";});
+    // Append prompt prefix if needed
+    if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) {
+        LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
+            if (on_evaluate)
+                on_evaluate(progress/3.0f, true);
+            return true; // Can't be cancelled here
+        }), LM_BOOL_ERROR, {LM_CORETURN "";});
+    }
     // Append prompt
     LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool {
         return non_cancelled = on_evaluate?on_evaluate(progress/3.0f+100.0f/3.0f, false):true;
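
For context, here is a minimal standalone sketch (not library code) of the two behaviour changes in this file: the new emits_eos flag now drives n_eos_ignores, and the prompt prefix is only appended when it is not already present. The ends_with helper below is only a stand-in for common::utils::ends_with, whose exact signature in commoncpp is assumed.

// Standalone sketch of the behaviour above; not part of libjustchat.
#include <iostream>
#include <limits>
#include <string>

// Stand-in for common::utils::ends_with (semantics assumed).
static bool ends_with(const std::string &str, const std::string &suffix) {
    return str.size() >= suffix.size() &&
           str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    // A model that emits EOS keeps the default (ignore no EOS tokens);
    // a model that never emits EOS effectively ignores EOS forever.
    bool emits_eos = false;
    unsigned n_eos_ignores = emits_eos ? 0 : std::numeric_limits<unsigned>::max();
    std::cout << "n_eos_ignores = " << n_eos_ignores << '\n';

    // The prompt prefix is appended only if the current prompt does not
    // already end with it, so repeated append() calls don't duplicate it.
    std::string prompt = "### Instruction:\n";
    const std::string prefix = "### Instruction:\n";
    if (!ends_with(prompt, prefix))
        prompt += prefix;
    std::cout << prompt;
}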

@@ -1 +1 @@
-Subproject commit ec148c3c3447b26b3213b0277566445456693bdf
+Subproject commit 54f9a8e586d6c710999cb7f033f6915cff92e35f

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.5)
-project(anyproc LANGUAGES CXX)
+project(justchat LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

@@ -14,7 +14,10 @@ int main(int argc, char **argv) {
     const auto model_config = argv[1];
     // Construct chat model
     LM::Chat::Inference model(model_config);
-    model.start();
+    if (!model.start()) {
+        std::cerr << "Failed to load model." << std::endl;
+        return EXIT_FAILURE;
+    }
     std::string instruction;
     for (;;) {
         std::cout << "> ";

@@ -30,11 +30,11 @@ public:
         language = "en";
     Prompt mutable prompt;
     bool allow_system_prompt = true,
-        strict_prompt = false;
+        strict_prompt = false,
+        emits_eos = true; // Whether the model emits an eos at the end of each response
     std::optional<unsigned> max_context_size,
         repeat_last, // How many tokens to repeat-penalize
-        eos_ignores, // How many times to ignore EOS
        top_k,
        mirostat_version; // Version of mirostat to use; 0 for none
     std::optional<float> top_p,

@@ -25,8 +25,8 @@ void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
             max_context_size = std::stoul(value);
         else if (key == "repeat_last")
             repeat_last = std::stoul(value);
-        else if (key == "eos_ignores")
-            eos_ignores = std::stoi(value); // -1 for "infinite"
+        else if (key == "emits_eos")
+            emits_eos = parse_bool(value);
         else if (key == "top_k")
             top_k = std::stoi(value);
         else if (key == "mirostat_version")
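
For comparison, a small standalone sketch of the old and the new configuration value. parse_bool is provided elsewhere in the project, so the spellings accepted by the stand-in below are an assumption.

// Standalone sketch; not part of libjustchat.
#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for the project's parse_bool (accepted spellings assumed).
static bool parse_bool_sketch(const std::string &value) {
    if (value == "true" || value == "1" || value == "yes") return true;
    if (value == "false" || value == "0" || value == "no") return false;
    throw std::runtime_error("Not a boolean: " + value);
}

int main() {
    // Old key: eos_ignores = -1   (integer; -1 meant "ignore EOS forever")
    // New key: emits_eos = false  (boolean; the model never emits an EOS)
    const bool emits_eos = parse_bool_sketch("false");
    std::cout << "emits_eos = " << std::boolalpha << emits_eos << '\n';
}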