diff --git a/gptj.cpp b/gptj.cpp index 6bdd7c3..f01d5b9 100644 --- a/gptj.cpp +++ b/gptj.cpp @@ -14,7 +14,9 @@ const LM::Implementation *get_justlm_implementation() { return &fres; } -bool magic_match(uint32_t magic) { +bool magic_match(std::istream& f) { + uint32_t magic; + f.read(reinterpret_cast(&magic), sizeof(magic)); return magic == 0x67676d6c; } diff --git a/justlm.cpp b/justlm.cpp index ad05ea9..97903e6 100644 --- a/justlm.cpp +++ b/justlm.cpp @@ -9,7 +9,7 @@ static -Dlhandle get_implementation(uint32_t magic) { +Dlhandle get_implementation(std::ifstream& input_f) { Dlhandle matching; Dlhandle fallback; // Iterate over all libraries @@ -32,8 +32,9 @@ Dlhandle get_implementation(uint32_t magic) { continue; } // Set if matching magic - auto magic_match = dl.get("magic_match"); - if (magic_match && magic_match(magic)) { + input_f.seekg(0); + auto magic_match = dl.get("magic_match"); + if (magic_match && magic_match(input_f)) { matching = std::move(dl); continue; } @@ -48,13 +49,11 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P static std::vector dls; // Read magic std::ifstream f(weights_path, std::ios::binary); - uint32_t magic; - if (!f.read(reinterpret_cast(&magic), sizeof(magic))) { + if (!f) { throw Exception("Failed to open weights file for reading at "+weights_path); } - f.seekg(0); // Get correct implementation - auto impl = get_implementation(magic); + auto impl = get_implementation(f); if (!impl) return nullptr; // Get inference constructor auto constructor = impl.get("construct"); @@ -62,5 +61,6 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P // Back up Dlhandle dls.push_back(std::move(impl)); // Construct inference + f.seekg(0); return constructor(weights_path, f, p); } diff --git a/llama.cpp b/llama.cpp index 650558c..510f8b6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,10 +10,21 @@ extern "C" { const LM::Implementation *get_justlm_implementation() { - static LM::Implementation fres{true}; + static LM::Implementation fres{false}; return &fres; } +bool magic_match(std::istream& f) { + // Check magic + uint32_t magic; + f.read(reinterpret_cast(&magic), sizeof(magic)); + if (magic != 0x67676a74) return false; + // Check version + uint32_t version = 0; + f.read(reinterpret_cast(&version), sizeof(version)); + return version >= 2; +} + LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) { f.close(); return new LM::LLaMAInference(weights_path, p); diff --git a/llama_old.cpp b/llama_old.cpp index b4496d7..582e8ea 100644 --- a/llama_old.cpp +++ b/llama_old.cpp @@ -10,14 +10,10 @@ extern "C" { const LM::Implementation *get_justlm_implementation() { - static LM::Implementation fres{false}; + static LM::Implementation fres{true}; return &fres; } -bool magic_match(uint32_t magic) { - return magic == 0x67676a74; -} - LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) { f.close(); return new LM::LLaMAInference(weights_path, p); diff --git a/mpt.cpp b/mpt.cpp index e952c99..04140aa 100644 --- a/mpt.cpp +++ b/mpt.cpp @@ -14,7 +14,9 @@ const LM::Implementation *get_justlm_implementation() { return &fres; } -bool magic_match(uint32_t magic) { +bool magic_match(std::istream& f) { + uint32_t magic; + f.read(reinterpret_cast(&magic), sizeof(magic)); return magic == 0x67676d6d; }