libjustlm/justlm.hpp

#ifndef LLM_H
#define LLM_H
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include <functional>
#include <memory>
#include <stdexcept>
#include <thread>
#include <cstdint>
#include <ctime>
class LLM {
    struct {
        int32_t seed; // RNG seed
        int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 2;
        union {
            int32_t n_ctx; // Context size, llama.cpp specific
            int32_t n_prompt = -1; // Prompt size, gpt2 specific
        };
        int32_t n_batch = 8; // Batch size, unused
        int32_t top_k = 40;   // Sampling: keep only the top_k most likely tokens
        float top_p = 0.5f;   // Sampling: nucleus probability mass
        float temp = 0.72f;   // Sampling temperature
    } params;
    struct State *state; // Opaque backend state, defined in the implementation

    void init(const std::string& weights_path);

    static
    bool ends_with(std::string_view str, std::string_view suffix);

public:
    struct Exception : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct ContextLengthException : public Exception {
        ContextLengthException() : Exception("Max. context length exceeded") {}
    };

    LLM(const std::string& weights_path, int32_t seed = 0) {
        // Set random seed
        params.seed = seed ? seed : time(NULL);
        // Initialize llm
        init(weights_path);
    }
    ~LLM();

    // Feeds the given prompt into the model; on_tick, if set, reports ingestion
    // progress and may return false to abort early.
    void append(std::string_view prompt, const std::function<bool (float progress)>& on_tick = nullptr);

    // Generates text until the given end string is produced; on_tick, if set,
    // receives each generated piece and may return false to stop generation.
    std::string run(std::string_view end, const std::function<bool (const char *generated)>& on_tick = nullptr);
};
#endif // LLM_H
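
// Usage sketch (illustrative, not part of the library): how this class appears to be
// driven, based solely on the declarations above. The model file path and the exact
// callback semantics (returning true to continue) are assumptions.
//
//     #include "justlm.hpp"
//     #include <iostream>
//
//     int main() {
//         LLM llm("/path/to/model.bin"); // hypothetical weights path
//         // Feed the prompt; the callback could report ingestion progress.
//         llm.append("Hello! How are you?\n", [] (float progress) {
//             return true; // continue ingesting
//         });
//         // Generate until a newline, streaming each piece to stdout.
//         std::string reply = llm.run("\n", [] (const char *generated) {
//             std::cout << generated << std::flush;
//             return true; // continue generating
//         });
//     }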