libjustchat/chat.cpp

#include "chat.hpp"
#include <iostream>
#include <limits>
#include <commoncpp/utils.hpp>
#include <justlm.hpp>
namespace LM {
namespace Chat {
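// Translate the chat-level ModelConfig into justlm inference parameters.
// Optional fields that are unset keep justlm's defaults; models that never
// emit an EOS token get the ignore counter maxed out so generation is not cut short.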
LM::Inference::Params Inference::get_params() const {
    LM::Inference::Params fres;
    fres.n_eos_ignores = config.emits_eos ? 0 : std::numeric_limits<unsigned>::max();
    if (config.max_context_size.has_value())
        fres.n_ctx = config.max_context_size.value();
    if (config.repeat_last.has_value())
        fres.n_repeat_last = config.repeat_last.value();
    if (config.top_k.has_value())
        fres.top_k = config.top_k.value();
    if (config.top_p.has_value())
        fres.top_p = config.top_p.value();
    if (config.temp.has_value())
        fres.temp = config.temp.value();
    if (config.mirostat_learning_rate.has_value())
        fres.mirostat_learning_rate = config.mirostat_learning_rate.value();
    if (config.mirostat_target_entropy.has_value())
        fres.mirostat_target_entropy = config.mirostat_target_entropy.value();
    if (config.repeat_penalty.has_value())
        fres.repeat_penalty = config.repeat_penalty.value();
    if (config.mirostat_version.has_value())
        fres.prefer_mirostat = config.mirostat_version.value();
    return fres;
}
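// Construct from a model config file, either backed by a shared inference pool
// or owning its inference directly.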
Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) {
    config.parse(config_path);
}
Inference::Inference(const std::string& config_path) {
    config.parse(config_path);
}
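// Discard any existing model state and recreate the inference from the
// configured model file; returns whether construction succeeded.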
LM_SCHEDULABLE(bool) Inference::reset() {
    if (pool) {
        const auto id = get_id();
        pool->delete_inference(id);
        LM_CORETURN (LM_COAWAIT pool->create_inference(id, config.model_file, get_params())) != nullptr;
    } else {
        inference.reset(LM::Inference::construct(config.model_file, get_params()));
    }
    LM_CORETURN inference != nullptr;
}
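// Evaluate a user message wrapped in the configured prompt prefix/suffix.
// Progress is reported in thirds (prefix, message, suffix); only the message
// step can be cancelled through the callback's return value.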
LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) {
    auto inference = LM_COAWAIT get_underlaying();
    bool non_cancelled = true;
    // Append prompt prefix if needed
    if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) {
        LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
            if (on_evaluate)
                on_evaluate(progress/3.0f, true);
            return true; // Can't be cancelled here
        }), LM_BOOL_ERROR, {LM_CORETURN false;});
    }
    // Append prompt
    LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool {
        return non_cancelled = on_evaluate ? on_evaluate(progress/3.0f+100.0f/3.0f, false) : true;
    }), LM_BOOL_ERROR, {LM_CORETURN false;});
    // Append prompt suffix
    LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.suffix, [on_evaluate] (float progress) -> bool {
        if (on_evaluate)
            on_evaluate(progress/3.0f+(100.0f/3.0f*2.0f), true);
        return true; // Can't be cancelled here
    }), LM_BOOL_ERROR, {LM_CORETURN false;});
    // Final result
    LM_CORETURN non_cancelled;
}
LM_SCHEDULABLE(bool) Inference::append(const std::string &message, const EvaluateCallback &on_evaluate) {
    LM_CORETURN LM_COAWAIT append(config.prompt, message, on_evaluate);
}
LM_SCHEDULABLE(std::string) Inference::generate(const ModelConfig::Prompt &promptConfig, const GenerateCallback &on_generate) {
    auto inference = LM_COAWAIT get_underlaying();
    LM_CORETURN LM_COAWAIT inference->run(promptConfig.prefix, [on_generate] (const char *token) {
        return on_generate(token);
    });
}
LM_SCHEDULABLE(std::string) Inference::generate(const GenerateCallback &on_generate) {
    LM_CORETURN LM_COAWAIT generate(config.prompt, on_generate);
}
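// Convenience wrapper: evaluate the message, then generate and return the
// model's reply. Returns an empty string if evaluation was cancelled or failed.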
LM_SCHEDULABLE(std::string) Inference::prompt(const ModelConfig::Prompt &promptConfig, const std::string &message, const GenerateCallback &on_generate, const EvaluateCallback &on_evaluate) {
    if (!LM_COAWAIT append(promptConfig, message, on_evaluate)) LM_CORETURN "";
    LM_CORETURN LM_COAWAIT generate(promptConfig, on_generate);
}
LM_SCHEDULABLE(std::string) Inference::prompt(const std::string& message, const GenerateCallback& on_generate, const EvaluateCallback& on_evaluate) {
    LM_CORETURN LM_COAWAIT prompt(config.prompt, message, on_generate, on_evaluate);
}
LM_SCHEDULABLE(Inference::OptionallySharedInference) Inference::get_underlaying() const {
    if (inference) LM_CORETURN inference.get();
    if (pool) LM_CORETURN LM_COAWAIT pool->get_inference(get_id());
    common::utils::unreachable();
}
} // namespace Chat
} // namespace LM
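
// Example usage (illustrative sketch only; the config file name and callbacks
// below are made up, and the LM_SCHEDULABLE calls may need LM_COAWAIT depending
// on how the scheduling macros are configured):
//
//     LM::Chat::Inference chat("model.cfg");   // hypothetical config file path
//     chat.reset();                            // load the configured model
//     auto reply = chat.prompt("Hello!",
//         [] (const char *token) { std::cout << token << std::flush; return true; },
//         [] (float /*progress*/, bool /*uncancellable*/) { return true; });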