From a77222f8d0b09497c8ce6a085c81f3960da9d5f4 Mon Sep 17 00:00:00 2001 From: Giacomo Travaglini Date: Thu, 14 Jun 2018 11:37:20 +0100 Subject: base: Add an asymmetrical Coroutine class This patch is providing gem5 a Coroutine class to be used for instantiating asymmetrical coroutines. Coroutines are built on top of gem5 fibers, which makes them ucontext based. Change-Id: I7bb673a954d4a456997afd45b696933534f3e239 Signed-off-by: Giacomo Travaglini Reviewed-on: https://gem5-review.googlesource.com/11195 Reviewed-by: Gabe Black Maintainer: Gabe Black --- src/base/SConscript | 1 + src/base/coroutine.hh | 266 ++++++++++++++++++++++++++++++++++++++++++++++ src/base/coroutinetest.cc | 262 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 529 insertions(+) create mode 100644 src/base/coroutine.hh create mode 100644 src/base/coroutinetest.cc (limited to 'src') diff --git a/src/base/SConscript b/src/base/SConscript index b3205a6bb..ea91f7011 100644 --- a/src/base/SConscript +++ b/src/base/SConscript @@ -48,6 +48,7 @@ if env['USE_PNG']: Source('pngwriter.cc') Source('fiber.cc') GTest('fibertest', 'fibertest.cc', 'fiber.cc') +GTest('coroutinetest', 'coroutinetest.cc', 'fiber.cc') Source('framebuffer.cc') Source('hostinfo.cc') Source('inet.cc') diff --git a/src/base/coroutine.hh b/src/base/coroutine.hh new file mode 100644 index 000000000..d28889296 --- /dev/null +++ b/src/base/coroutine.hh @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2018 ARM Limited + * All rights reserved + * + * The license below extends only to copyright in the software and shall + * not be construed as granting a license to any other intellectual + * property including but not limited to intellectual property relating + * to a hardware implementation of the functionality of the software + * licensed hereunder. You may use the software subject to the license + * terms below provided that you ensure that this notice is replicated + * unmodified and in its entirety in all distributions of the software, + * modified or unmodified, in source code or in binary form. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Authors: Giacomo Travaglini + */ + +#ifndef __BASE_COROUTINE_HH__ +#define __BASE_COROUTINE_HH__ + +#include +#include + +#include "base/fiber.hh" + +namespace m5 +{ + +/** + * This template defines a Coroutine wrapper type with a Boost-like + * interface. It is built on top of the gem5 fiber class. + * The two template parameters (Arg and Ret) are the coroutine + * argument and coroutine return types which are passed between + * the coroutine and the caller via operator() and get() method. + * This implementation doesn't support passing multiple values, + * so a tuple must be used in that scenario. + * + * Most methods are templatized since it is relevant to distinguish + * the cases where one or both of the template parameters are void + */ +template +class Coroutine : public Fiber +{ + + // This empty struct type is meant to replace coroutine channels + // in case the channel should be void (Coroutine template parameters + // are void. (See following ArgChannel, RetChannel typedef) + struct Empty {}; + using ArgChannel = typename std::conditional< + std::is_same::value, Empty, std::stack>::type; + + using RetChannel = typename std::conditional< + std::is_same::value, Empty, std::stack>::type; + + public: + /** + * CallerType: + * A reference to an object of this class will be passed + * to the coroutine task. This is the way it is possible + * for the coroutine to interface (e.g. switch back) + * to the coroutine caller. + */ + class CallerType + { + friend class Coroutine; + protected: + CallerType(Coroutine& _coro) : coro(_coro), callerFiber(nullptr) {} + + public: + /** + * operator() is the way we can jump outside the coroutine + * and return a value to the caller. + * + * This method is generated only if the coroutine returns + * a value (Ret != void) + */ + template + CallerType& + operator()(typename std::enable_if< + !std::is_same::value, T>::type param) + { + retChannel.push(param); + callerFiber->run(); + return *this; + } + + /** + * operator() is the way we can jump outside the coroutine + * + * This method is generated only if the coroutine doesn't + * return a value (Ret = void) + */ + template + typename std::enable_if::value, + CallerType>::type& + operator()() + { + callerFiber->run(); + return *this; + } + + /** + * get() is the way we can extrapolate arguments from the + * coroutine caller. + * The coroutine blocks, waiting for the value, unless it is already + * available; otherwise caller execution is resumed, + * and coroutine won't execute until a value is pushed + * from the caller. + * + * @return arg coroutine argument + */ + template + typename std::enable_if::value, T>::type + get() + { + auto& args_channel = coro.argsChannel; + while (args_channel.empty()) { + callerFiber->run(); + } + + auto ret = args_channel.top(); + args_channel.pop(); + return ret; + } + + private: + Coroutine& coro; + Fiber* callerFiber; + RetChannel retChannel; + }; + + Coroutine() = delete; + Coroutine(const Coroutine& rhs) = delete; + Coroutine& operator=(const Coroutine& rhs) = delete; + + /** + * Coroutine constructor. + * The only way to construct a coroutine is to pass it the routine + * it needs to run. The first argument of the function should be a + * reference to the Coroutine::caller_type which the + * routine will use as a way for yielding to the caller. + * + * @param f task run by the coroutine + */ + Coroutine(std::function f) + : Fiber(), task(f), caller(*this) + { + // Create and Run the Coroutine + this->call(); + } + + virtual ~Coroutine() {} + + public: + /** Coroutine interface */ + + /** + * operator() is the way we can jump inside the coroutine + * and passing arguments. + * + * This method is generated only if the coroutine takes + * arguments (Arg != void) + */ + template + Coroutine& + operator()(typename std::enable_if< + !std::is_same::value, T>::type param) + { + argsChannel.push(param); + this->call(); + return *this; + } + + /** + * operator() is the way we can jump inside the coroutine. + * + * This method is generated only if the coroutine takes + * no arguments. (Arg = void) + */ + template + typename std::enable_if::value, Coroutine>::type& + operator()() + { + this->call(); + return *this; + } + + /** + * get() is the way we can extrapolate return values + * (yielded) from the coroutine. + * The caller blocks, waiting for the value, unless it is already + * available; otherwise coroutine execution is resumed, + * and caller won't execute until a value is yielded back + * from the coroutine. + * + * @return ret yielded value + */ + template + typename std::enable_if::value, T>::type + get() + { + auto& ret_channel = caller.retChannel; + while (ret_channel.empty()) { + this->call(); + } + + auto ret = ret_channel.top(); + ret_channel.pop(); + return ret; + } + + /** Check if coroutine is still running */ + operator bool() const { return !this->finished(); } + + private: + /** + * Overriding base (Fiber) main. + * This method will be automatically called by the Fiber + * running engine and it is a simple wrapper for the task + * that the coroutine is supposed to run. + */ + void main() override { this->task(caller); } + + void + call() + { + caller.callerFiber = currentFiber(); + run(); + } + + private: + /** Arguments for the coroutine */ + ArgChannel argsChannel; + + /** Coroutine task */ + std::function task; + + /** Coroutine caller */ + CallerType caller; +}; + +} //namespace m5 + +#endif // __BASE_COROUTINE_HH__ diff --git a/src/base/coroutinetest.cc b/src/base/coroutinetest.cc new file mode 100644 index 000000000..655bc254a --- /dev/null +++ b/src/base/coroutinetest.cc @@ -0,0 +1,262 @@ +/* + * Copyright (c) 2018 ARM Limited + * All rights reserved + * + * The license below extends only to copyright in the software and shall + * not be construed as granting a license to any other intellectual + * property including but not limited to intellectual property relating + * to a hardware implementation of the functionality of the software + * licensed hereunder. You may use the software subject to the license + * terms below provided that you ensure that this notice is replicated + * unmodified and in its entirety in all distributions of the software, + * modified or unmodified, in source code or in binary form. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Authors: Giacomo Travaglini + */ + +#include + +#include "base/coroutine.hh" + +using namespace m5; + +/** + * This test is checking if the Coroutine, once it yields + * back to the caller, it is still marked as not finished. + */ +TEST(Coroutine, Unfinished) +{ + auto yielding_task = + [] (Coroutine::CallerType& yield) + { + yield(); + }; + + Coroutine coro(yielding_task); + ASSERT_TRUE(coro); +} + +/** + * This test is checking the parameter passing interface of a + * coroutine which takes an integer as an argument. + * Coroutine::operator() and CallerType::get() are the tested + * APIS. + */ +TEST(Coroutine, Passing) +{ + const std::vector input{ 1, 2, 3 }; + const std::vector expected_values = input; + + auto passing_task = + [&expected_values] (Coroutine::CallerType& yield) + { + int argument; + + for (const auto expected : expected_values) { + argument = yield.get(); + ASSERT_EQ(argument, expected); + } + }; + + Coroutine coro(passing_task); + ASSERT_TRUE(coro); + + for (const auto val : input) { + coro(val); + } +} + +/** + * This test is checking the yielding interface of a coroutine + * which takes no argument and returns integers. + * Coroutine::get() and CallerType::operator() are the tested + * APIS. + */ +TEST(Coroutine, Returning) +{ + const std::vector output{ 1, 2, 3 }; + const std::vector expected_values = output; + + auto returning_task = + [&output] (Coroutine::CallerType& yield) + { + for (const auto ret : output) { + yield(ret); + } + }; + + Coroutine coro(returning_task); + ASSERT_TRUE(coro); + + for (const auto expected : expected_values) { + int returned = coro.get(); + ASSERT_EQ(returned, expected); + } +} + +/** + * This test is still supposed to test the returning interface + * of the the Coroutine, proving how coroutine can be used + * for generators. + * The coroutine is computing the first #steps of the fibonacci + * sequence and it is yielding back results one number per time. + */ +TEST(Coroutine, Fibonacci) +{ + const std::vector expected_values{ + 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233 }; + + const int steps = expected_values.size(); + + auto fibonacci_task = + [steps] (Coroutine::CallerType& yield) + { + int prev = 0; + int current = 1; + + for (auto iter = 0; iter < steps; iter++) { + int sum = prev + current; + yield(sum); + + prev = current; + current = sum; + } + }; + + Coroutine coro(fibonacci_task); + ASSERT_TRUE(coro); + + for (const auto expected : expected_values) { + ASSERT_TRUE(coro); + int returned = coro.get(); + ASSERT_EQ(returned, expected); + } +} + +/** + * This test is using a bi-channel coroutine (accepting and + * yielding values) for testing a cooperative task. + * The caller and the coroutine have a string each; they are + * composing a new string by merging the strings together one + * character per time. + * The result string is hence passed back and forth between the + * coroutine and the caller. + */ +TEST(Coroutine, Cooperative) +{ + const std::string caller_str("HloWrd"); + const std::string coro_str("el ol!"); + const std::string expected("Hello World!"); + + auto cooperative_task = + [&coro_str] (Coroutine::CallerType& yield) + { + for (auto& appended_c : coro_str) { + auto old_str = yield.get(); + yield(old_str + appended_c); + } + }; + + Coroutine coro(cooperative_task); + + std::string result; + for (auto& c : caller_str) { + ASSERT_TRUE(coro); + result += c; + result = coro(result).get(); + } + + ASSERT_EQ(result, expected); +} + +/** + * This test is testing nested coroutines by using one inner and one + * outer coroutine. It basically ensures that yielding from the inner + * coroutine returns to the outer coroutine (mid-layer of execution) and + * not to the outer caller. + */ +TEST(Coroutine, Nested) +{ + const std::string wrong("Inner"); + const std::string expected("Inner + Outer"); + + auto inner_task = + [] (Coroutine::CallerType& yield) + { + std::string inner_string("Inner"); + yield(inner_string); + }; + + auto outer_task = + [&inner_task] (Coroutine::CallerType& yield) + { + Coroutine coro(inner_task); + std::string inner_string = coro.get(); + + std::string outer_string("Outer"); + yield(inner_string + " + " + outer_string); + }; + + + Coroutine coro(outer_task); + ASSERT_TRUE(coro); + + std::string result = coro.get(); + + ASSERT_NE(result, wrong); + ASSERT_EQ(result, expected); +} + +/** + * This test is stressing the scenario where two distinct fibers are + * calling the same coroutine. First the test instantiates (and runs) a + * coroutine, then spawns another one and it passes it a reference to + * the first coroutine. Once the new coroutine calls the first coroutine + * and the first coroutine yields, we are expecting execution flow to + * be yielded to the second caller (the second coroutine) and not the + * original caller (the test itself) + */ +TEST(Coroutine, TwoCallers) +{ + bool valid_return = false; + + Coroutine callee{[] + (Coroutine::CallerType& yield) + { + yield(); + yield(); + }}; + + Coroutine other_caller{[&callee, &valid_return] + (Coroutine::CallerType& yield) + { + callee(); + valid_return = true; + yield(); + }}; + + ASSERT_TRUE(valid_return); +} -- cgit v1.2.3