#include "chat.hpp" #include #include #include #include namespace LM { namespace Chat { LM::Inference::Params Inference::get_params() const { LM::Inference::Params fres; fres.n_eos_ignores = config.emits_eos?0:std::numeric_limits::max(); if (config.max_context_size.has_value()) fres.n_ctx = config.max_context_size.value(); if (config.repeat_last.has_value()) fres.n_repeat_last = config.repeat_last.value(); if (config.top_k.has_value()) fres.top_k = config.top_k.value(); if (config.top_p.has_value()) fres.top_p = config.top_p.value(); if (config.temp.has_value()) fres.temp = config.temp.value(); if (config.mirostat_learning_rate.has_value()) fres.mirostat_learning_rate = config.mirostat_learning_rate.value(); if (config.mirostat_target_entropy.has_value()) fres.mirostat_target_entropy = config.mirostat_target_entropy.value(); if (config.repeat_penalty.has_value()) fres.repeat_penalty = config.repeat_penalty.value(); if (config.mirostat_version.has_value()) fres.prefer_mirostat = config.mirostat_version.value(); return fres; } Inference::Inference(LM::InferencePool& pool, const std::string& config_path) : pool(&pool) { config.parse(config_path); } Inference::Inference(const std::string &config_path) { config.parse(config_path); } LM_SCHEDULABLE(bool) Inference::reset() { if (pool) { const auto id = get_id(); pool->delete_inference(id); LM_CORETURN (LM_COAWAIT pool->create_inference(id, config.model_file, get_params())) != nullptr; } else inference.reset(LM::Inference::construct(config.model_file, get_params())); LM_CORETURN inference != nullptr; } LM_SCHEDULABLE(bool) Inference::append(const ModelConfig::Prompt &promptConfig, const std::string &message, const EvaluateCallback &on_evaluate) { auto inference = LM_COAWAIT get_underlaying(); bool non_cancelled = true; // Append prompt prefix if needed if (!common::utils::ends_with(inference->get_prompt(), promptConfig.prefix)) { LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.prefix, [on_evaluate] (float progress) -> bool { if (on_evaluate) on_evaluate(progress/3.0f, true); return true; // Can't be cancelled here }), LM_BOOL_ERROR, {LM_CORETURN "";}); } // Append prompt LM_ERROR_CATCH(LM_COAWAIT inference->append(message, [on_evaluate, &non_cancelled] (float progress) -> bool { return non_cancelled = on_evaluate?on_evaluate(progress/3.0f+100.0f/3.0f, false):true; }), LM_BOOL_ERROR, {LM_CORETURN "";}); // Append prompt suffix LM_ERROR_CATCH(LM_COAWAIT inference->append(promptConfig.suffix, [on_evaluate] (float progress) -> bool { if (on_evaluate) on_evaluate(progress/3.0f+(100.0f/3.0f*2.0f), true); return true; // Can't be cancelled here }), LM_BOOL_ERROR, {LM_CORETURN "";}); // Final result LM_CORETURN non_cancelled; } LM_SCHEDULABLE(bool) Inference::append(const std::string &message, const EvaluateCallback &on_evaluate) { LM_CORETURN LM_COAWAIT append(config.prompt, message, on_evaluate); } LM_SCHEDULABLE(std::string) LM::Chat::Inference::generate(const ModelConfig::Prompt &promptConfig, const GenerateCallback &on_generate) { LM_CORETURN LM_COAWAIT inference->run(promptConfig.prefix, [on_generate] (const char *token) { return on_generate(token); }); } LM_SCHEDULABLE(std::string) LM::Chat::Inference::generate(const GenerateCallback &on_generate) { LM_CORETURN LM_COAWAIT generate(config.prompt, on_generate); } LM_SCHEDULABLE(std::string) Inference::prompt(const ModelConfig::Prompt &promptConfig, const std::string &message, const GenerateCallback &on_generate, const EvaluateCallback &on_evaluate) { if (!LM_COAWAIT append(promptConfig, 
message, on_evaluate)) LM_CORETURN ""; LM_CORETURN LM_COAWAIT generate(on_generate); } LM_SCHEDULABLE(std::string) Inference::prompt(const std::string& message, const GenerateCallback& on_generate, const EvaluateCallback& on_evaluate) { LM_CORETURN LM_COAWAIT prompt(config.prompt, message, on_generate, on_evaluate); } LM_SCHEDULABLE(Inference::OptionallySharedInference) LM::Chat::Inference::get_underlaying() const { if (inference) LM_CORETURN inference.get(); if (pool) LM_CORETURN LM_COAWAIT pool->get_inference(get_id()); common::utils::unreachable(); } } }
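
// Usage sketch (illustrative, not part of this translation unit): one plausible way a
// caller could drive Chat::Inference, based only on the signatures defined above. The
// function name example_chat_turn, the greeting text, and the meaning of the evaluate
// callback's second argument are assumptions; the sketch is kept inside `#if 0` because
// the real coroutine plumbing behind LM_SCHEDULABLE/LM_COAWAIT lives in the project's headers.
#if 0
#include <iostream>

LM_SCHEDULABLE(bool) example_chat_turn(LM::Chat::Inference &chat) {
    // (Re)create the underlying model instance before the first exchange
    if (!LM_COAWAIT chat.reset()) LM_CORETURN false;

    // Evaluate the user message, then stream generated tokens to stdout
    const std::string reply = LM_COAWAIT chat.prompt(
        "Hello! Who are you?",
        [] (const char *token) -> bool {
            std::cout << token << std::flush;
            return true; // keep generating
        },
        [] (float progress, bool /*assumed: non-cancellable prompt markup flag*/) -> bool {
            std::cerr << "\rEvaluating: " << progress << '%' << std::flush;
            return true; // don't cancel evaluation
        });
    LM_CORETURN !reply.empty();
}
#endif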