1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00
libjustlm/justlm.cpp
2023-05-15 14:46:19 +02:00

30 lines
761 B
C++

#include "justlm.hpp"
#include "justlm_llama.hpp"
#include "justlm_gptj.hpp"
#ifdef LM_MPT
# include "justlm_mpt.hpp"
#endif
#include <fstream>
/// Factory that picks an inference backend by inspecting the first four bytes
/// ("magic") of the weights file.
///
/// Dispatch, as implemented below:
///   - magic == 0x67676d6c            -> GPTJInference
///   - magic == 0x67676d6d (LM_MPT)   -> MPTInference
///   - anything else                  -> LLaMaInference (reopens the file by path)
///
/// If the file cannot be opened, or is shorter than four bytes, the magic is
/// treated as 0 and the call falls through to the LLaMa backend, which is then
/// responsible for handling the bad path itself.
///
/// @param weights_path Path to the model weights file.
/// @param p            Parameters forwarded unchanged to the backend constructor.
/// @return Pointer to a backend instance allocated with `new`
///         (presumably owned by the caller -- confirm against the header).
LM::Inference *LM::Inference::construct(const std::string &weights_path, const Params &p) {
    // Read magic
    std::ifstream f(weights_path, std::ios::binary);
    // Zero-initialised: previously an unopened or too-short file left this
    // variable indeterminate, and comparing it below was undefined behaviour.
    uint32_t magic = 0;
    if (!f.read(reinterpret_cast<char*>(&magic), sizeof(magic))) {
        // A short read may have stored some bytes; discard them so that a
        // truncated file can never be mistaken for a known format.
        magic = 0;
    }

    // Create inference instance
    if (magic == 0x67676d6c) {
        // Rewind so the backend sees the stream from the very beginning.
        f.seekg(0);
        return new GPTJInference(weights_path, f, p);
#   ifdef LM_MPT
    } else if (magic == 0x67676d6d) {
        // Rewind so the backend sees the stream from the very beginning.
        f.seekg(0);
        return new MPTInference(weights_path, f, p);
#   endif
    } else {
        // The LLaMa backend only takes the path, so release our handle first.
        f.close();
        return new LLaMaInference(weights_path, p);
    }
}