#include "chat.hpp"
|
|
|
|
#include <iostream>
|
|
#include <limits>
|
|
#include <commoncpp/utils.hpp>
|
|
#include <justlm.hpp>
|
|
|
|
|
|
namespace LM {
namespace Chat {

// Build the justlm inference parameters from the chat config,
// overriding justlm's defaults only for values that are explicitly set.
LM::Inference::Params Inference::get_params() const {
    LM::Inference::Params fres;
    // Ignore end-of-sequence tokens indefinitely unless the model is expected to emit them
    fres.n_eos_ignores = config.emits_eos ? 0 : std::numeric_limits<unsigned>::max();
    if (config.max_context_size.has_value())
        fres.n_ctx = config.max_context_size.value();
    if (config.repeat_last.has_value())
        fres.n_repeat_last = config.repeat_last.value();
    if (config.top_k.has_value())
        fres.top_k = config.top_k.value();
    if (config.top_p.has_value())
        fres.top_p = config.top_p.value();
    if (config.temp.has_value())
        fres.temp = config.temp.value();
    if (config.mirostat_learning_rate.has_value())
        fres.mirostat_learning_rate = config.mirostat_learning_rate.value();
    if (config.mirostat_target_entropy.has_value())
        fres.mirostat_target_entropy = config.mirostat_target_entropy.value();
    if (config.repeat_penalty.has_value())
        fres.repeat_penalty = config.repeat_penalty.value();
    if (config.mirostat_version.has_value())
        fres.prefer_mirostat = config.mirostat_version.value();
    return fres;
}

Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) {
    config.parse(config_path);
}

Inference::Inference(const std::string& config_path) {
    config.parse(config_path);
}

// Drop any existing inference state and create a fresh one,
// either inside the pool or as a locally owned instance.
LM_SCHEDULABLE(bool) Inference::reset() {
    if (pool) {
        const auto id = get_id();
        pool->delete_inference(id);
        LM_CORETURN (LM_COAWAIT pool->create_inference(id, config.model_file, get_params())) != nullptr;
    } else {
        inference.reset(LM::Inference::construct(config.model_file, get_params()));
    }
    LM_CORETURN inference != nullptr;
}

LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) {
    auto inference = LM_COAWAIT get_underlaying();
    bool non_cancelled = true;
    // Append prompt prefix if needed; evaluation progress is reported within the first third of the range
    if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) {
        LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool {
            if (on_evaluate)
                on_evaluate(progress/3.0f, true);
            return true; // Can't be cancelled here
        }), LM_BOOL_ERROR, {LM_CORETURN false;});
    }
    // Append prompt; this is the only phase the callback may cancel
    LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool {
        return non_cancelled = on_evaluate ? on_evaluate(progress/3.0f + 100.0f/3.0f, false) : true;
    }), LM_BOOL_ERROR, {LM_CORETURN false;});
    // Append prompt suffix
    LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.suffix, [on_evaluate] (float progress) -> bool {
        if (on_evaluate)
            on_evaluate(progress/3.0f + (100.0f/3.0f*2.0f), true);
        return true; // Can't be cancelled here
    }), LM_BOOL_ERROR, {LM_CORETURN false;});
    // Final result
    LM_CORETURN non_cancelled;
}

LM_SCHEDULABLE(bool) Inference::append(const std::string &message, const EvaluateCallback &on_evaluate) {
    LM_CORETURN LM_COAWAIT append(config.prompt, message, on_evaluate);
}

LM_SCHEDULABLE(std::string) LM::Chat::Inference::generate(const ModelConfig::Prompt &promptConfig, const GenerateCallback &on_generate) {
    auto inference = LM_COAWAIT get_underlaying();
    // Run generation, stopping at the prompt prefix, streaming each token to the callback
    LM_CORETURN LM_COAWAIT inference->run(promptConfig.prefix, [on_generate] (const char *token) {
        return on_generate(token);
    });
}

LM_SCHEDULABLE(std::string) LM::Chat::Inference::generate(const GenerateCallback &on_generate) {
    LM_CORETURN LM_COAWAIT generate(config.prompt, on_generate);
}

LM_SCHEDULABLE(std::string) Inference::prompt(const ModelConfig::Prompt &promptConfig, const std::string &message, const GenerateCallback &on_generate, const EvaluateCallback &on_evaluate) {
    if (!LM_COAWAIT append(promptConfig, message, on_evaluate)) LM_CORETURN "";
    LM_CORETURN LM_COAWAIT generate(promptConfig, on_generate);
}

LM_SCHEDULABLE(std::string) Inference::prompt(const std::string& message, const GenerateCallback& on_generate, const EvaluateCallback& on_evaluate) {
    LM_CORETURN LM_COAWAIT prompt(config.prompt, message, on_generate, on_evaluate);
}

LM_SCHEDULABLE(Inference::OptionallySharedInference) LM::Chat::Inference::get_underlaying() const {
    // Prefer the locally owned instance, otherwise fall back to the pool-managed one
    if (inference) LM_CORETURN inference.get();
    if (pool) LM_CORETURN LM_COAWAIT pool->get_inference(get_id());
    common::utils::unreachable();
}
}
}
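
// Illustrative usage sketch (not part of the library): assuming a hypothetical config
// file path ("model.json"), the callback signatures used above, and that the caller
// runs inside the coroutine/scheduler environment implied by LM_SCHEDULABLE,
// a single chat turn could look roughly like this:
//
//     LM::Chat::Inference chat("model.json");   // parse the model config
//     LM_COAWAIT chat.reset();                  // load the model with the configured parameters
//     auto reply = LM_COAWAIT chat.prompt("Hello!",
//         [] (const char *token) { std::cout << token << std::flush; return true; }, // stream tokens
//         [] (float progress, bool) { return true; });                               // ignore evaluation progress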