diff options
Diffstat (limited to 'src/base/coroutine.hh')
-rw-r--r-- | src/base/coroutine.hh | 266 |
1 files changed, 266 insertions, 0 deletions
diff --git a/src/base/coroutine.hh b/src/base/coroutine.hh new file mode 100644 index 000000000..d28889296 --- /dev/null +++ b/src/base/coroutine.hh @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2018 ARM Limited + * All rights reserved + * + * The license below extends only to copyright in the software and shall + * not be construed as granting a license to any other intellectual + * property including but not limited to intellectual property relating + * to a hardware implementation of the functionality of the software + * licensed hereunder. You may use the software subject to the license + * terms below provided that you ensure that this notice is replicated + * unmodified and in its entirety in all distributions of the software, + * modified or unmodified, in source code or in binary form. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer; + * redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution; + * neither the name of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Authors: Giacomo Travaglini + */ + +#ifndef __BASE_COROUTINE_HH__ +#define __BASE_COROUTINE_HH__ + +#include <functional> +#include <stack> + +#include "base/fiber.hh" + +namespace m5 +{ + +/** + * This template defines a Coroutine wrapper type with a Boost-like + * interface. It is built on top of the gem5 fiber class. + * The two template parameters (Arg and Ret) are the coroutine + * argument and coroutine return types which are passed between + * the coroutine and the caller via operator() and get() method. + * This implementation doesn't support passing multiple values, + * so a tuple must be used in that scenario. + * + * Most methods are templatized since it is relevant to distinguish + * the cases where one or both of the template parameters are void + */ +template <typename Arg, typename Ret> +class Coroutine : public Fiber +{ + + // This empty struct type is meant to replace coroutine channels + // in case the channel should be void (Coroutine template parameters + // are void. (See following ArgChannel, RetChannel typedef) + struct Empty {}; + using ArgChannel = typename std::conditional< + std::is_same<Arg, void>::value, Empty, std::stack<Arg>>::type; + + using RetChannel = typename std::conditional< + std::is_same<Ret, void>::value, Empty, std::stack<Ret>>::type; + + public: + /** + * CallerType: + * A reference to an object of this class will be passed + * to the coroutine task. This is the way it is possible + * for the coroutine to interface (e.g. switch back) + * to the coroutine caller. + */ + class CallerType + { + friend class Coroutine; + protected: + CallerType(Coroutine& _coro) : coro(_coro), callerFiber(nullptr) {} + + public: + /** + * operator() is the way we can jump outside the coroutine + * and return a value to the caller. + * + * This method is generated only if the coroutine returns + * a value (Ret != void) + */ + template <typename T = Ret> + CallerType& + operator()(typename std::enable_if< + !std::is_same<T, void>::value, T>::type param) + { + retChannel.push(param); + callerFiber->run(); + return *this; + } + + /** + * operator() is the way we can jump outside the coroutine + * + * This method is generated only if the coroutine doesn't + * return a value (Ret = void) + */ + template <typename T = Ret> + typename std::enable_if<std::is_same<T, void>::value, + CallerType>::type& + operator()() + { + callerFiber->run(); + return *this; + } + + /** + * get() is the way we can extrapolate arguments from the + * coroutine caller. + * The coroutine blocks, waiting for the value, unless it is already + * available; otherwise caller execution is resumed, + * and coroutine won't execute until a value is pushed + * from the caller. + * + * @return arg coroutine argument + */ + template <typename T = Arg> + typename std::enable_if<!std::is_same<T, void>::value, T>::type + get() + { + auto& args_channel = coro.argsChannel; + while (args_channel.empty()) { + callerFiber->run(); + } + + auto ret = args_channel.top(); + args_channel.pop(); + return ret; + } + + private: + Coroutine& coro; + Fiber* callerFiber; + RetChannel retChannel; + }; + + Coroutine() = delete; + Coroutine(const Coroutine& rhs) = delete; + Coroutine& operator=(const Coroutine& rhs) = delete; + + /** + * Coroutine constructor. + * The only way to construct a coroutine is to pass it the routine + * it needs to run. The first argument of the function should be a + * reference to the Coroutine<Arg,Ret>::caller_type which the + * routine will use as a way for yielding to the caller. + * + * @param f task run by the coroutine + */ + Coroutine(std::function<void(CallerType&)> f) + : Fiber(), task(f), caller(*this) + { + // Create and Run the Coroutine + this->call(); + } + + virtual ~Coroutine() {} + + public: + /** Coroutine interface */ + + /** + * operator() is the way we can jump inside the coroutine + * and passing arguments. + * + * This method is generated only if the coroutine takes + * arguments (Arg != void) + */ + template <typename T = Arg> + Coroutine& + operator()(typename std::enable_if< + !std::is_same<T, void>::value, T>::type param) + { + argsChannel.push(param); + this->call(); + return *this; + } + + /** + * operator() is the way we can jump inside the coroutine. + * + * This method is generated only if the coroutine takes + * no arguments. (Arg = void) + */ + template <typename T = Arg> + typename std::enable_if<std::is_same<T, void>::value, Coroutine>::type& + operator()() + { + this->call(); + return *this; + } + + /** + * get() is the way we can extrapolate return values + * (yielded) from the coroutine. + * The caller blocks, waiting for the value, unless it is already + * available; otherwise coroutine execution is resumed, + * and caller won't execute until a value is yielded back + * from the coroutine. + * + * @return ret yielded value + */ + template <typename T = Ret> + typename std::enable_if<!std::is_same<T, void>::value, T>::type + get() + { + auto& ret_channel = caller.retChannel; + while (ret_channel.empty()) { + this->call(); + } + + auto ret = ret_channel.top(); + ret_channel.pop(); + return ret; + } + + /** Check if coroutine is still running */ + operator bool() const { return !this->finished(); } + + private: + /** + * Overriding base (Fiber) main. + * This method will be automatically called by the Fiber + * running engine and it is a simple wrapper for the task + * that the coroutine is supposed to run. + */ + void main() override { this->task(caller); } + + void + call() + { + caller.callerFiber = currentFiber(); + run(); + } + + private: + /** Arguments for the coroutine */ + ArgChannel argsChannel; + + /** Coroutine task */ + std::function<void(CallerType&)> task; + + /** Coroutine caller */ + CallerType caller; +}; + +} //namespace m5 + +#endif // __BASE_COROUTINE_HH__ |