From 816419bec5d477b97404d250d4d7729dcefaf11f Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Sat, 22 Oct 2016 11:29:27 -0700 Subject: [PATCH] Check in basic schedule container --- include/tvm/array.h | 4 +- include/tvm/base.h | 2 +- include/tvm/codegen.h | 39 ++++++++++++++ include/tvm/schedule.h | 108 +++++++++++++++++++++++++++++++++++++ include/tvm/split.h | 84 +++++++++++++++++++++++++++++ src/c_api/c_api_registry.h | 3 ++ src/schedule/schedule.cc | 5 ++ 7 files changed, 242 insertions(+), 3 deletions(-) create mode 100644 include/tvm/codegen.h create mode 100644 include/tvm/schedule.h create mode 100644 include/tvm/split.h create mode 100644 src/schedule/schedule.cc diff --git a/include/tvm/array.h b/include/tvm/array.h index 5c9be4cf0..a695cf24a 100644 --- a/include/tvm/array.h +++ b/include/tvm/array.h @@ -147,7 +147,7 @@ class Array : public NodeRef { /*! * \brief set i-th element of the array. * \param i The index - * \param other The value to be setted. + * \param value The value to be setted. */ inline void Set(size_t i, const T& value) { this->CopyOnWrite(); @@ -161,7 +161,7 @@ class Array : public NodeRef { size_t index; /*! * \brief assign operator - * \param value The value to be assigned + * \param other The value to be assigned * \return reference to self. */ inline ArrayItemRef& operator=(const T& other) { diff --git a/include/tvm/base.h b/include/tvm/base.h index 608abe11d..a024e7086 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -12,7 +12,7 @@ #include <memory> #include <functional> #include <typeinfo> - +#include <type_traits> namespace tvm { diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h new file mode 100644 index 000000000..3baa284e9 --- /dev/null +++ b/include/tvm/codegen.h @@ -0,0 +1,39 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file codegen.h + * \brief Common data structure for codegen + */ +#ifndef TVM_CODEGEN_H_ +#define TVM_CODEGEN_H_ + +namespace tvm { + +// incomplete spec. +struct Assign : public Node { + Expr src; + Expr offset; + Var ptr; +}; + +struct Assign : public Node { + Expr src; + Expr offset; + Var ptr; +}; + +struct Loop : public Node { + Expr init; + Expr cond; + Stmt body; +}; + +struct IfThenElse : public Node { + Expr cond; + Expr then_; + Stmt else_; +}; + + +} // namespace tvm + +#endif // TVM_CODEGEN_H_ diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h new file mode 100644 index 000000000..b1d4227f6 --- /dev/null +++ b/include/tvm/schedule.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file schedule.h + * \brief Define a schedule. + */ +#ifndef TVM_SCHEDULE_H_ +#define TVM_SCHEDULE_H_ + +#include <string> +#include "./base.h" +#include "./split.h" +#include "./tensor.h" + +namespace tvm { + +// Node container for Schedule +class ScheduleNode; +// Node container for AttachSpec +class AttachSpecNode; + +/*! \brief the attachment type */ +enum AttachType : int { + kRoot = 0, + kInline = 1, + kSplit = 2 +}; + +/*! \brief schedule container */ +class Schedule : public NodeRef { + public: + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const ScheduleNode* operator->() const; +}; + +/*! \brief schedule container */ +class AttachSpec : public NodeRef { + public: + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const AttachSpecNode* operator->() const; +}; + +// defintion of node containers + +/*! \brief The attach specification of each subschedule */ +class AttachSpecNode : public Node { + public: + /*! \brief The attachment type */ + AttachType attach_type; + /*! + * \brief The split to be attached to, + * only valid when attach_type is kRoot + */ + Split attach_split; + /*! \brief the child schedule to be attached. */ + Schedule schedule; + const char* type_key() const override { + return "AttachSpecNode"; + } + void VisitAttrs(AttrVisitor* visitor) override { + visitor->Visit("attach_type", &attach_type); + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("attach_split", &attach_split); + fvisit("schedule", &schedule); + } +}; + +/*! \brief represents the schedule of the tensor */ +class ScheduleNode : public Node { + public: + /*! \brief Tensor to be scheduled */ + Tensor tensor; + /*! \brief The thread scope level of the schedule */ + std::string scope; + /*! \brief Splits over domains or rdomains */ + Array<Split> splits; + /*! \brief attach specifications */ + Array<AttachSpec> attachs; + const char* type_key() const override { + return "AttachSpecNode"; + } + void VisitAttrs(AttrVisitor* visitor) override { + visitor->Visit("scope", &scope); + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("tensor", &tensor); + fvisit("splits", &splits); + fvisit("attachs", &attachs); + } +}; + +// implementations +inline const ScheduleNode* Schedule::operator->() const { + return static_cast<const ScheduleNode*>(node_.get()); +} + +inline const AttachSpecNode* AttachSpec::operator->() const { + return static_cast<const AttachSpecNode*>(node_.get()); +} + +} // namespace tvm +#endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/split.h b/include/tvm/split.h new file mode 100644 index 000000000..cd03ab770 --- /dev/null +++ b/include/tvm/split.h @@ -0,0 +1,84 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file split.h + * \brief Define a split over Domain or RDomain + */ +#ifndef TVM_SPLIT_H_ +#define TVM_SPLIT_H_ + +#include "./base.h" +#include "./array.h" +#include "./domain.h" + +namespace tvm { + +// internal node container for split. +class SplitNode; + +/*! \brief Split over input domain */ +class Split : public NodeRef { + public: + /*! \brief default constructor */ + Split() {} + /*! \return Whether the split is over RDomain or not */ + inline bool is_over_rdom() const; + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const SplitNode* operator->() const; +}; + +/*! + * \brief base class of split node, + * specifies a split over domain + * split also defines how to generate + */ +class SplitNode : public Node { + public: + /*! \brief whether the split is over reduction domain*/ + int split_over_rdom{0}; + /*! + * \brief given the output domain, infer input domain + * \param split_index The index to be splitted on + * \param out_domain The outer domain + * \return The inferred inner domain. + */ + virtual Domain InferInnerDomain(Expr split_index, Domain out_domain) const = 0; +}; + +/*! \brief simple split node that splits over one dimension */ +class DimSplitNode : public SplitNode { + public: + /*! \brief The dimension to split on */ + int64_t dim_index; + /*! \brief The factor of the split */ + Expr factor; + /*! \brief constructor */ + DimSplitNode() {} + const char* type_key() const override { + return "DimSplitNode"; + } + void VisitAttrs(AttrVisitor* visitor) override { + visitor->Visit("split_over_rdom", &split_over_rdom); + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("factor", &factor); + } + Domain InferInnerDomain(Expr split_index, Domain out_domain) const override { + LOG(FATAL) << "not implemented"; + return Domain(); + } +}; + +// Implementations of inline functions +inline const SplitNode* Split::operator->() const { + return static_cast<const SplitNode*>(node_.get()); +} + +inline bool Split::is_over_rdom() const { + return (*this)->split_over_rdom != 0; +} + +} // namespace tvm +#endif // TVM_SPLIT_H_ diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 95c4a8494..ace2b27d0 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -10,6 +10,7 @@ #include <tvm/expr.h> #include <tvm/c_api.h> #include <memory> +#include <limits> #include <string> #include <vector> @@ -84,6 +85,8 @@ struct APIVariantValue { } inline operator int() const { CHECK_EQ(type_id, kLong); + CHECK_LE(v_union.v_long, + std::numeric_limits<int>::max()); return v_union.v_long; } inline operator std::string() const { diff --git a/src/schedule/schedule.cc b/src/schedule/schedule.cc new file mode 100644 index 000000000..b10d2542b --- /dev/null +++ b/src/schedule/schedule.cc @@ -0,0 +1,5 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file schedule.cc + */ +#include <tvm/schedule.h> -- GitLab