mirror of
https://gitlab.com/niansa/llama_any_server.git
synced 2025-03-06 20:53:35 +01:00
Initial commit
commit ea861bcac9 (parent 1e520b3182)
4 changed files with 124 additions and 5 deletions
.gitmodules (vendored, new file, 3 additions)
@@ -0,0 +1,3 @@
+[submodule "libjustlm"]
+	path = libjustlm
+	url = https://gitlab.com/niansa/libjustlm.git
CMakeLists.txt (4 additions)
@@ -5,7 +5,11 @@ project(llama_any_server LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
+add_subdirectory(libjustlm)
+
 add_executable(llama_any_server main.cpp)
 
+target_link_libraries(llama_any_server PUBLIC libjustlm)
+
 install(TARGETS llama_any_server
     LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
libjustlm (new submodule, 1 addition)
@@ -0,0 +1 @@
+Subproject commit 8b5a375f592099ef12d3988087fe699b3ecba3bb
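Since libjustlm is brought in as a git submodule and consumed via add_subdirectory() in CMakeLists.txt, a fresh checkout needs the submodule populated before CMake can configure the project, for example by cloning with git clone --recursive, or by running git submodule update --init after a plain clone.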
main.cpp (121 changes)
@@ -1,9 +1,120 @@
 #include <iostream>
+#include <justlm.hpp>
+#include <boost/asio.hpp>
 
 using namespace std;
 
-int main()
-{
-    cout << "Hello World!" << endl;
-    return 0;
-}
+namespace Application {
+using namespace boost::asio;
+
+class Server {
+    io_context service;
+    ip::tcp::endpoint endpoint;
+    ip::tcp::acceptor acceptor;
+
+    static inline
+    const LM::Inference::Params& get_params() noexcept {
+        static auto params = [] () {
+            LM::Inference::Params params;
+            params.n_batch = 8;
+            params.n_ctx = 2048;
+            params.n_repeat_last = 64;
+            params.repeat_penalty = 1.3f;
+            params.temp = 0.1f;
+            params.top_k = 40;
+            params.top_p = 0.95f;
+            params.use_mlock = false;
+            return params;
+        }();
+        return params;
+    }
+
+    void client_run(ip::tcp::socket& socket) {
+        uint8_t len;
+
+        // Create inference instance
+        LM::Inference inference("gpt4all-lora-unfiltered-quantized.bin"/*TODO: do not hardcode path*/, get_params());
+
+        for (bool first_run = true; ; first_run = false) {
+            // Receive prompt length
+            std::cout << "Receiving prompt length..." << std::endl;
+            socket.receive(mutable_buffer(&len, sizeof(len)));
+
+            // Receive prompt
+            std::string prompt;
+            prompt.resize(len);
+            std::cout << "Receiving prompt of length " << unsigned(len) << "..." << std::endl;
+            socket.receive(mutable_buffer(prompt.data(), len));
+
+            // Stop on zero length
+            if (len == 0) break;
+
+            // Append prompt
+            std::cout << "Evaluating prompt..." << std::endl;
+            uint8_t old_progress = 0;
+            inference.append(std::string(first_run?"Below is an instruction that describes a task. Write a response that appropriately completes the request.":"")+"\n\n### Instruction:\n\n"+prompt+"\n\n### Response:\n\n", [&old_progress, &socket] (float progress) {
+                uint8_t progress_i = progress;
+                // Report new progress
+                if (old_progress != progress_i) {
+                    socket.send(const_buffer(&progress_i, sizeof(progress_i)));
+                    // Set as old progress
+                    old_progress = progress_i;
+                }
+                return true;
+            });
+            // Report completion if needed
+            if (old_progress != 100) {
+                old_progress = 100;
+                socket.send(const_buffer(&old_progress, sizeof(old_progress)));
+            }
+
+            // Run inference
+            std::cout << "Running inference...\n" << std::endl;
+            auto result = inference.run("\n", [&socket] (const char *token) {
+                uint8_t len;
+                const auto token_len = strlen(token);
+                std::cout << token << std::flush;
+
+                // Send result length
+                len = token_len;
+                socket.send(const_buffer(&len, sizeof(len)));
+
+                // Send result
+                socket.send(const_buffer(token, token_len));
+                return true;
+            });
+            std::cout << std::endl;
+
+            // Send end-of-stream marker
+            len = 0xFF;
+            socket.send(const_buffer(&len, sizeof(len)));
+        }
+    }
+
+public:
+    Server() : endpoint(ip::tcp::v4(), (unsigned short)99181/*TODO: do not hardcode port*/), acceptor(service, endpoint) {}
+
+    void run() {
+        std::cout << "Waiting for connection..." << std::endl;
+
+        // Wait for connections infinitely
+        for (;;) {
+            // Accept connection immediately
+            ip::tcp::socket socket(service);
+            acceptor.accept(socket);
+
+            // Start thread for new connection
+            std::cout << "Accepted connection, starting connection thread..." << std::endl;
+            std::thread([this, socket = std::move(socket)] () mutable {
+                client_run(socket);
+            }).detach();
+        }
+    }
+};
+}
+
+
+int main() {
+    Application::Server server;
+    server.run();
+}
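For reference, the wire protocol implemented by client_run() above is: the client sends a one-byte prompt length followed by the prompt bytes; while evaluating the prompt the server streams one-byte progress values, ending with 100; it then streams the response as [one-byte length][token bytes] records and terminates the stream with a 0xFF length marker; a zero prompt length ends the session. Below is a minimal synchronous Boost.Asio client sketch for that protocol. It is not part of the commit: the address and prompt are placeholders, and the port assumes the server's (unsigned short)99181 initializer, which wraps to 33645 on platforms where unsigned short is 16 bits.

#include <boost/asio.hpp>
#include <cstdint>
#include <iostream>
#include <string>

using boost::asio::ip::tcp;

int main() {
    boost::asio::io_context io;

    // Connect to the server; 33645 is what (unsigned short)99181 wraps to (assumption, see above)
    tcp::socket socket(io);
    socket.connect(tcp::endpoint(boost::asio::ip::make_address("127.0.0.1"), 33645));

    // Prompts are length-prefixed with a single byte, so at most 255 bytes
    std::string prompt = "Write a haiku about TCP sockets."; // placeholder prompt
    if (prompt.size() > 255) prompt.resize(255);

    // Send the one-byte prompt length, then the prompt itself
    uint8_t len = static_cast<uint8_t>(prompt.size());
    boost::asio::write(socket, boost::asio::buffer(&len, sizeof(len)));
    boost::asio::write(socket, boost::asio::buffer(prompt));

    // Read one-byte progress updates until evaluation reaches 100
    uint8_t progress = 0;
    while (progress != 100) {
        boost::asio::read(socket, boost::asio::buffer(&progress, sizeof(progress)));
        std::cout << "\rEvaluating prompt: " << unsigned(progress) << "%" << std::flush;
    }
    std::cout << std::endl;

    // Read tokens as [one-byte length][bytes]; 0xFF marks the end of the response
    for (;;) {
        uint8_t token_len;
        boost::asio::read(socket, boost::asio::buffer(&token_len, sizeof(token_len)));
        if (token_len == 0xFF) break;
        std::string token(token_len, '\0');
        boost::asio::read(socket, boost::asio::buffer(token.data(), token_len));
        std::cout << token << std::flush;
    }
    std::cout << std::endl;

    // A zero prompt length tells the server to end the session
    len = 0;
    boost::asio::write(socket, boost::asio::buffer(&len, sizeof(len)));
}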