From c016ccd672070c6333c0214ae8fc7c1c37034ab5 Mon Sep 17 00:00:00 2001
From: niansa <tuxifan@posteo.de>
Date: Fri, 5 May 2023 12:55:48 +0200
Subject: [PATCH] Minor reworks and fixed lambda lifetime in ScheduledThread

---
 include/scheduled_thread.hpp |  6 ++--
 include/scheduler.hpp        | 29 ++++++++++++-------
 scheduled_thread.cpp         |  7 ++++-
 scheduler.cpp                | 18 ++++++++++--
 test.cpp                     | 56 ++++++++++++++++++++++++------------
 5 files changed, 80 insertions(+), 36 deletions(-)

diff --git a/include/scheduled_thread.hpp b/include/scheduled_thread.hpp
index f12b5c9..08a2059 100644
--- a/include/scheduled_thread.hpp
+++ b/include/scheduled_thread.hpp
@@ -14,7 +14,7 @@ namespace CoSched {
 class ScheduledThread {
     struct QueueEntry {
         std::string task_name;
-        std::function<AwaitableTask<void> ()> task_fcn;
+        std::function<AwaitableTask<void> ()> start_fcn;
     };
 
     std::thread thread;
@@ -38,11 +38,11 @@ public:
     }
 
     // DO NOT call from within a task
-    void create_task(const std::string& task_name, const std::function<AwaitableTask<void> ()>& task_fcn) {
+    void create_task(const std::string& task_name, std::function<AwaitableTask<void> ()>&& task_fcn) {
         // Enqueue function
         {
             std::scoped_lock L(queue_mutex);
-            queue.emplace(QueueEntry{task_name, task_fcn});
+            queue.emplace(QueueEntry{task_name, std::move(task_fcn)});
         }
 
         // Notify thread
diff --git a/include/scheduler.hpp b/include/scheduler.hpp
index e9c70d4..76c6377 100644
--- a/include/scheduler.hpp
+++ b/include/scheduler.hpp
@@ -2,9 +2,11 @@
 #define _SCHEDULER_HPP
 #include <string>
 #include <vector>
+#include <unordered_map>
 #include <mutex>
 #include <memory>
 #include <chrono>
+#include <any>
 #include <AwaitableTask.hpp>
 #include <SingleEvent.hpp>
 
@@ -24,10 +26,10 @@ enum {
 };
 
 enum class TaskState {
-    running,
-    sleeping,
-    terminating,
-    dead
+    running, // Task is currently in a normal running state
+    sleeping, // Task is currently waiting to be scheduled again
+    terminating, // Task will start terminating soon
+    dead // Taks is currently terminating
 };
 
 
@@ -37,7 +39,7 @@ class Task {
     static thread_local class Task *current;
 
     class Scheduler *scheduler;
-    std::unique_ptr<SingleEvent<void>> resume_event;
+    std::unique_ptr<SingleEvent<void>> resume_event = nullptr;
 
     std::chrono::system_clock::time_point stopped_at;
 
@@ -54,11 +56,16 @@ public:
     Task(const Task&) = delete;
     Task(Task&&) = delete;
 
+    // Misc property storage, unused
+    std::unordered_map<std::string, std::any> properties;
+
+    // Returns the task that is currently being executed on this thread
     static inline
     Task& get_current() {
         return *current;
     }
 
+    // Sets a task name
     const std::string& get_name() const {
         return name;
     }
@@ -66,6 +73,7 @@ public:
         name = value;
     }
 
+    // Sets the task priority
     Priority get_priority() const {
         return priority;
     }
@@ -73,22 +81,22 @@ public:
         priority = value;
     }
 
+    // Returns the state of this task
     TaskState get_state() const {
         return state;
     }
 
+    // Returns the scheduler that is scheduling this task
     Scheduler& get_scheduler() const {
         return *scheduler;
     }
 
+    // Terminates the task as soon as possible
     void terminate() {
-        if (state == TaskState::running) {
-            state = TaskState::terminating;
-        } else {
-            state = TaskState::dead;
-        }
+        state = TaskState::terminating;
     }
 
+    // Suspends (pauses) the task as soon as possible
     void set_suspended(bool value = true) {
         suspended = value;
     }
@@ -96,6 +104,7 @@ public:
         return suspended;
     }
 
+    // Allows other tasks to execute
     AwaitableTask<bool> yield();
 };
 
diff --git a/scheduled_thread.cpp b/scheduled_thread.cpp
index 8fc5475..ca17b1d 100644
--- a/scheduled_thread.cpp
+++ b/scheduled_thread.cpp
@@ -11,10 +11,15 @@ void CoSched::ScheduledThread::main_loop() {
         {
             std::scoped_lock L(queue_mutex);
             while (!queue.empty()) {
+                // Get queue entry
                 auto e = std::move(queue.front());
                 queue.pop();
+                // Create task for it
                 sched.create_task(e.task_name);
-                e.task_fcn();
+                // Move start function somewhere else
+                auto& start_fcn = std::any_cast<decltype(e.start_fcn)&>(Task::get_current().properties.emplace("start_function", std::move(e.start_fcn)).first->second);
+                // Call start function
+                start_fcn();
             }
         }
         // Run once
diff --git a/scheduler.cpp b/scheduler.cpp
index 02df0a0..5955977 100644
--- a/scheduler.cpp
+++ b/scheduler.cpp
@@ -11,11 +11,16 @@ void CoSched::Task::kill() {
 }
 
 AwaitableTask<bool> Task::yield() {
+    // If it was terminating, it can finally be declared dead now
     if (state == TaskState::terminating) {
-        // If it was terminating, it can finally be declared dead now
         state = TaskState::dead;
         co_return false;
     }
+    // Dead tasks may not yield
+    if (state == TaskState::dead) {
+        co_return false;
+    }
+    if (this != current) co_return true;
     // It's just sleeping
     state = TaskState::sleeping;
     // Create event for resume
@@ -23,6 +28,13 @@ AwaitableTask<bool> Task::yield() {
     // Let's wait until we're back up!
     stopped_at = std::chrono::system_clock::now();
     co_await *resume_event;
+    // Delete resume event
+    resume_event = nullptr;
+    // If task was terminating during sleep, it can finally be declared dead now
+    if (state == TaskState::terminating) {
+        state = TaskState::dead;
+        co_return false;
+    }
     // Here we go, let's keep going...
     state = TaskState::running;
     co_return true;
@@ -48,8 +60,8 @@ Task *Scheduler::get_next_task() {
     std::vector<Task*> max_prio_tasks;
     Priority max_prio = std::numeric_limits<Priority>::min();
     for (auto& task : tasks) {
-        // Filter tasks that aren't sleeping
-        if (task->state != TaskState::sleeping) continue;
+        // Filter tasks can't currently be resumed
+        if (task->resume_event == nullptr) continue;
         // Filter tasks that are suspended
         if (task->suspended) continue;
         // Update max priority
diff --git a/test.cpp b/test.cpp
index beba8fd..f771d33 100644
--- a/test.cpp
+++ b/test.cpp
@@ -4,31 +4,49 @@
 #include <string>
 
 
-CoSched::AwaitableTask<std::string> get_value() {
-    std::string fres = CoSched::Task::get_current().get_name();
-    for (unsigned it = 0; it != 100; it++) {
-        fres += "Hello";
-        co_await CoSched::Task::get_current().yield();
-    }
-    fres.resize(1);
-    co_return fres;
-}
 
-CoSched::AwaitableTask<void> test_task() {
-    auto& task = CoSched::Task::get_current();
-    if (task.get_name() == "B" || task.get_name() == "D") {
-        task.set_priority(CoSched::PRIO_HIGH);
+class LifetimeTest {
+    std::string read_test_str = "Test value";
+
+public:
+    LifetimeTest() {
+        std::cout << this << ": Lifetime start" << std::endl;
     }
-    for (unsigned x = 100; co_await task.yield(); x--) {
-        std::cout << co_await get_value() << ": " << x << '\n';
-        if (x == 10) task.terminate();
+    LifetimeTest(const LifetimeTest&) {
+        std::cout << this << ": Lifetime copy" << std::endl;
+    };
+    LifetimeTest(LifetimeTest&&) {
+        std::cout << this << ": Lifetime move" << std::endl;
     }
-}
+    ~LifetimeTest() {
+        std::cout << this << ": Lifetime end" << std::endl;
+    }
+
+    void read_test() const {
+        std::cout << read_test_str << std::flush;
+        std::cout << '\r';
+        for (unsigned i = 0; i != read_test_str.size(); i++) {
+            std::cout << ' ';
+        }
+        std::cout << '\r' << std::flush;
+        std::cout << this << ": Lifetime read test success" << std::endl;
+    }
+};
 
 int main () {
     CoSched::ScheduledThread scheduler;
-    for (const auto& name : {"A", "B", "C", "D", "E", "F"}) {
-        scheduler.create_task(name, test_task);
+    for (const auto& name : {"A", "B", "C"}) {
+        scheduler.create_task(name, [lt = LifetimeTest()] () -> CoSched::AwaitableTask<void> {
+            auto& task = CoSched::Task::get_current();
+            std::cout << task.get_name() << "Scope start" << std::endl;
+            lt.read_test();
+            if (!co_await task.yield()) co_return;
+            std::cout << task.get_name() << "Scope middle" << std::endl;
+            lt.read_test();
+            if (!co_await task.yield()) co_return;
+            std::cout << task.get_name() << "Scope end" << std::endl;
+            lt.read_test();
+        });
     }
     scheduler.start();
     scheduler.wait();