mirror of
https://gitlab.com/niansa/libjustchat.git
synced 2025-03-06 20:48:31 +01:00
eos_ignores -> emits_eos
This commit is contained in:
parent
3a1bf94744
commit
581bc89fe4
6 changed files with 20 additions and 15 deletions
18
chat.cpp
18
chat.cpp
|
@ -1,6 +1,7 @@
|
|||
#include "chat.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <commoncpp/utils.hpp>
|
||||
#include <justlm.hpp>
|
||||
|
||||
|
@ -9,12 +10,11 @@ namespace LM {
|
|||
namespace Chat {
|
||||
LM::Inference::Params Inference::get_params() const {
|
||||
LM::Inference::Params fres;
|
||||
fres.n_eos_ignores = config.emits_eos?0:std::numeric_limits<unsigned>::max();
|
||||
if (config.max_context_size.has_value())
|
||||
fres.n_ctx = config.max_context_size.value();
|
||||
if (config.repeat_last.has_value())
|
||||
fres.n_repeat_last = config.repeat_last.value();
|
||||
if (config.eos_ignores.has_value())
|
||||
fres.n_eos_ignores = config.eos_ignores.value();
|
||||
if (config.top_k.has_value())
|
||||
fres.top_k = config.top_k.value();
|
||||
if (config.top_p.has_value())
|
||||
|
@ -52,12 +52,14 @@ LM_SCHEDULABLE(bool) Inference::reset() {
|
|||
LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) {
|
||||
auto inference = LM_COAWAIT get_underlaying();
|
||||
bool non_cancelled = true;
|
||||
// Append prompt prefix
|
||||
LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
|
||||
if (on_evaluate)
|
||||
on_evaluate(progress/3.0f, true);
|
||||
return true; // Can't be cancelled here
|
||||
}), LM_BOOL_ERROR, {LM_CORETURN "";});
|
||||
// Append prompt prefix if needed
|
||||
if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) {
|
||||
LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
|
||||
if (on_evaluate)
|
||||
on_evaluate(progress/3.0f, true);
|
||||
return true; // Can't be cancelled here
|
||||
}), LM_BOOL_ERROR, {LM_CORETURN "";});
|
||||
}
|
||||
// Append prompt
|
||||
LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool {
|
||||
return non_cancelled = on_evaluate?on_evaluate(progress/3.0f+100.0f/3.0f, false):true;
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit ec148c3c3447b26b3213b0277566445456693bdf
|
||||
Subproject commit 54f9a8e586d6c710999cb7f033f6915cff92e35f
|
|
@ -1,6 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.5)
|
||||
|
||||
project(anyproc LANGUAGES CXX)
|
||||
project(justchat LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
|
|
@ -14,7 +14,10 @@ int main(int argc, char **argv) {
|
|||
const auto model_config = argv[1];
|
||||
// Construct chat model
|
||||
LM::Chat::Inference model(model_config);
|
||||
model.start();
|
||||
if (!model.start()) {
|
||||
std::cerr << "Failed to load model." << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
std::string instruction;
|
||||
for (;;) {
|
||||
std::cout << "> ";
|
||||
|
|
|
@ -30,11 +30,11 @@ public:
|
|||
language = "en";
|
||||
Prompt mutable prompt;
|
||||
bool allow_system_prompt = true,
|
||||
strict_prompt = false;
|
||||
strict_prompt = false,
|
||||
emits_eos = true; // Weather the model emits an eos at the end of each response
|
||||
|
||||
std::optional<unsigned> max_context_size,
|
||||
repeat_last, // How many tokens to repeat-penalize
|
||||
eos_ignores, // How many times to ignore EOS
|
||||
top_k,
|
||||
mirostat_version; // Version of mirostat to use; 0 for none
|
||||
std::optional<float> top_p,
|
||||
|
|
|
@ -25,8 +25,8 @@ void ModelConfig::fill(KeyValueMap &&map, bool ignore_extra) {
|
|||
max_context_size = std::stoul(value);
|
||||
else if (key == "repeat_last")
|
||||
repeat_last = std::stoul(value);
|
||||
else if (key == "eos_ignores")
|
||||
eos_ignores = std::stoi(value); // -1 for "infinite"
|
||||
else if (key == "emits_eos")
|
||||
emits_eos = parse_bool(value);
|
||||
else if (key == "top_k")
|
||||
top_k = std::stoi(value);
|
||||
else if (key == "mirostat_version")
|
||||
|
|
Loading…
Add table
Reference in a new issue