From efbae0bd4c5b01fddb488904719a515004dd0fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= <lolisa@marisa.moe> Date: Fri, 5 Oct 2018 13:53:17 -0700 Subject: [PATCH] [Relay] Add Let list, a helper datastructure to relay (#1827) --- src/relay/pass/let_list.h | 126 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 src/relay/pass/let_list.h diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h new file mode 100644 index 000000000..d13358fe0 --- /dev/null +++ b/src/relay/pass/let_list.h @@ -0,0 +1,126 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file let_list.h + * \brief LetList record let binding and insert let expression implicitly. + * using it, one can treat AST as value instead of expression, + * and pass them around freely without fear of AST explosion (or effect duplication). + * for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'. + * if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);', + * the AST will contain 2 'a', as b and c are now variables. + */ +#ifndef TVM_RELAY_PASS_LET_LIST_H_ +#define TVM_RELAY_PASS_LET_LIST_H_ + +#include <tvm/relay/expr.h> +#include <utility> +#include <vector> +#include <tuple> +#include "tvm/relay/type.h" + +namespace tvm { +namespace relay { + +/*! \brief LetList allow you to transform expression into variables, so you can copy them around. + * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get. + * additionally, there is the 'With' function, which automatically call Get. + */ +class LetList { + public: + /*! \brief insert a binding. + * + * \param pv the var of the binding. + * + * \param ty the type of the binding. + * + * \param expr the value of the binding. + * + * \return a Var that hold the inserted expr. + */ + Var Push(const Var& pv, const Type& ty, const Expr& expr) { + std::tuple<Var, Type, Expr> tuple(pv, ty, expr); + lets_.push_back(tuple); + return pv; + } + + /*! \brief insert a binding. + * + * \param ty the type of the binding. + * + * \param expr the value of the binding. + * + * \return a Var that hold the inserted expr. + */ + Var Push(const Type& ty, const Expr& expr) { + return Push(VarNode::make("x"), ty, expr); + } + + /*! \brief insert a binding. + * + * \param pv the var of the binding. + * + * \param expr the value of the binding. + * + * \return a Var that hold the inserted expr. + */ + Var Push(const Var& pv, const Expr& expr) { + return Push(pv, IncompleteTypeNode::make(TypeParamNode::kType), expr); + } + + /*! \brief insert a binding. + * + * \param expr the value of the binding. + * + * \return a Var that hold the inserted expr. + */ + Var Push(const Expr& expr) { + return Push(IncompleteTypeNode::make(TypeParamNode::kType), expr); + } + + /*! \brief wrap an expr around the LetList. + * + * \param body the Expression to be wrapped around. + * + * \return the wrapped expr. + */ + Expr Get(const Expr& body) const { + Expr ret = body; + for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { + ret = LetNode::make(std::get<0>(*rit), std::get<2>(*rit), ret, std::get<1>(*rit)); + } + return ret; + } + + /*! \brief generate an LetList and wrap the result automatically. + * + * \param f a function that generate the unwrapped Expr. + * + * \code + * // Example code that generate `16 * a` using 4 plus instead of 15 plus. + * Expr mult_sixteen(const Var& a) { + * Op plus = Op::Get("plus"); + * // Automatically call Get with LetList::With + * return LetList::With([&](LetList* ll) { + * // Turn a call to plus into a variable to avoid duplication of code + * Var b = ll->Push(CallNode::make(plus, {a, a})); + * Var c = ll->Push(CallNode::make(plus, {b, b})); + * Var d = ll->Push(CallNode::make(plus, {c, c})); + * return CallNode::make(plus, {d, d}); + * }); + * } + * \endcode + * + * \return the wrapped Expr. + */ + template<typename F> + static Expr With(F&& f) { + LetList ll; + return ll.Get(f(&ll)); + } + + private: + std::vector<std::tuple<Var, Type, Expr> > lets_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_LET_LIST_H_ -- GitLab