// llama_any_server/main.cpp

#include <cstdint>
#include <cstring>
#include <iostream>
#include <thread>
#include <justlm.hpp>
#include <boost/asio.hpp>
namespace Application {
using namespace boost::asio;
class Server {
    io_context service;
    ip::tcp::endpoint endpoint;
    ip::tcp::acceptor acceptor;
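
    // Shared inference parameters, constructed once on first use and reused for every connection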
    static inline
    const LM::Inference::Params& get_params() noexcept {
        static auto params = [] () {
            LM::Inference::Params params;
            params.n_batch = 8;
            params.n_ctx = 2048;
            params.n_repeat_last = 64;
            params.repeat_penalty = 1.3f;
            params.temp = 0.1f;
            params.top_k = 40;
            params.top_p = 0.95f;
            params.use_mlock = false;
            return params;
        }();
        return params;
    }
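
    // Serves one client connection. Wire protocol (as implemented below):
    //  * the client sends a 1-byte prompt length followed by the prompt bytes; a length of 0 ends the session
    //  * while the prompt is evaluated, the server streams 1-byte progress values, finishing with 100
    //  * each generated token is sent as a 1-byte length plus the token bytes, after which the server reads
    //    one control byte back: 0xAB aborts immediately, 0xCA cancels at the next end of sentence
    //  * unless aborted, a trailing 0xFF byte marks the end of the response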
    void client_run(ip::tcp::socket& socket) {
        uint8_t len;
        // Create inference instance
        LM::Inference inference("gpt4all-lora-unfiltered-quantized.bin"/*TODO: do not hardcode path*/, get_params());
        for (bool first_run = true; ; first_run = false) {
            // Receive prompt length
            std::cout << "Receiving prompt length..." << std::endl;
            socket.receive(mutable_buffer(&len, sizeof(len)));
            // Stop on zero length
            if (len == 0) break;
            // Receive prompt
            std::string prompt;
            prompt.resize(len);
            std::cout << "Receiving prompt of length " << unsigned(len) << "..." << std::endl;
            socket.receive(mutable_buffer(prompt.data(), len));
            // Append prompt
            std::cout << "Evaluating prompt..." << std::endl;
            uint8_t old_progress = 0;
            inference.append(std::string(first_run?"Below is an instruction that describes a task. Write a response that appropriately completes the request.":"")+"\n\n### Instruction:\n\n"+prompt+"\n\n### Response:\n\n", [&old_progress, &socket] (float progress) {
                uint8_t progress_i = progress;
                // Report new progress
                if (old_progress != progress_i) {
                    socket.send(const_buffer(&progress_i, sizeof(progress_i)));
                    // Set as old progress
                    old_progress = progress_i;
                }
                return true;
            });
            // Report completion if needed
            if (old_progress != 100) {
                old_progress = 100;
                socket.send(const_buffer(&old_progress, sizeof(old_progress)));
            }
            // Run inference
            std::cout << "Running inference...\n" << std::endl;
            bool aborted = false;
            auto result = inference.run("RUN UNTIL EOS..."/*presumably a sentinel that never appears in the output, so generation continues until EOS*/, [&socket, &aborted, stop_soon = false] (const char *token) mutable {
                uint8_t len;
                const auto token_len = std::strlen(token);
                std::cout << token << std::flush;
                // Skip empty tokens
                if (token_len == 0) return true;
                // Send result length
                len = token_len;
                socket.send(const_buffer(&len, sizeof(len)));
                // Send result
                socket.send(const_buffer(token, token_len));
                // Check for cancellation request
                socket.receive(mutable_buffer(&len, sizeof(len)));
                std::cout << unsigned(len) << std::flush;
                switch (len) {
                case 0xAB: { // Abort
                    // Stop immediately
                    std::cout << "... Aborted." << std::endl;
                    aborted = true;
                    return false;
                } break;
                case 0xCA: { // Cancel
                    // Stop at end of sentence
                    stop_soon = true;
                } break;
                }
                // Check for end of sentence
                if (stop_soon && (token[0] == '.' || token[token_len-1] == '.')) {
                    std::cout << "... Cancelled." << std::endl;
                    return false;
                }
                // Continue
                return true;
            });
            std::cout << std::endl;
            // Send end-of-response marker (0xFF) if not aborted
            if (!aborted) {
                len = 0xFF;
                socket.send(const_buffer(&len, sizeof(len)));
            }
        }
    }

public:
    Server() : endpoint(ip::tcp::v4(), (unsigned short)99181/*TODO: do not hardcode port; note that 99181 exceeds the 16-bit port range and truncates to 33645*/), acceptor(service, endpoint) {}
    void run() {
        std::cout << "Waiting for connection..." << std::endl;
        // Wait for connections infinitely
        for (;;) {
            // Accept connection immediately
            ip::tcp::socket socket(service);
            acceptor.accept(socket);
            // Start thread for new connection
            std::cout << "Accepted connection, starting connection thread..." << std::endl;
            std::thread([this, socket = std::move(socket)] () mutable {
                try {
                    client_run(socket);
                } catch (const std::exception& e) {
                    // A dropped connection throws from send()/receive(); without this catch the
                    // uncaught exception in the detached thread would take down the whole server
                    std::cout << "Connection closed: " << e.what() << std::endl;
                }
            }).detach();
        }
    }
};
}
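
// Entry point: construct the server (binding the listening socket) and accept clients forever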
int main() {
    Application::Server server;
    server.run();
}