Skip to content
Snippets Groups Projects
Commit 51fe00fb authored by Jared Roesch's avatar Jared Roesch Committed by Tianqi Chen
Browse files

[High level OPT][RFC] NNVMv2 IR - Relay (#1672)

parent 543c4240
No related branches found
No related tags found
No related merge requests found
Showing
with 2067 additions and 1 deletion
......@@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS
src/schedule/*.cc
)
file(GLOB_RECURSE RELAY_SRCS
src/relay/*.cc
)
list(APPEND COMPILER_SRCS ${RELAY_SRCS})
if(NOT MSVC)
file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc)
list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS})
......
......@@ -33,7 +33,7 @@ sys.path.insert(0, os.path.join(curr_path, '../vta/python'))
# General information about the project.
project = u'tvm'
author = u'%s developers' % project
copyright = u'2017, %s' % author
copyright = u'2018, %s' % author
github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/'
# add markdown parser
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/base.h
* \brief Base classes for the Relay IR.
*/
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <string>
#include <vector>
namespace tvm {
/*!
* \brief Relay: a high level functional IR for TVM.
*
* This namespace contains the abstract syntax tree, and other
* essential data structures for the Relay IR.
*
* You can find more about Relay by reading the language reference.
*/
namespace relay {
/*!
* \brief we always used NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/
using NodeRef = tvm::NodeRef;
/*!
* \brief Content data type.
*/
using DataType = ::tvm::Type;
/*!
* \brief Symbolic expression for tensor shape.
*/
using ShapeExpr = ::tvm::Expr;
/*!
* \brief Hash function for nodes.
* e.g. std::unordered_map<Expr, Value, NodeHash, NodeEqual>
*/
using NodeHash = ::tvm::NodeHash;
/*!
* \brief Equality check function for nodes.
*/
using NodeEqual = ::tvm::NodeEqual;
/*!
* \brief Macro to make it easy to define node ref type given node
* \param TypeName The name of the reference type.
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase) \
class TypeName : public NodeRefBase { \
public: \
TypeName() {} \
explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(node_.get()); \
} \
operator bool() { return this->defined(); } \
using ContainerType = NodeName; \
};
/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
class SourceName;
/*!
* \brief The name of a source fragment.
*/
class SourceNameNode : public Node {
public:
/*! \brief The source name. */
std::string name;
// override attr visitor
void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
TVM_DLL static SourceName make(std::string name);
static constexpr const char* _type_key = "relay.SourceName";
TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
};
/*!
* \brief The source name of a file span.
* \sa SourceNameNode, Span
*/
class SourceName : public NodeRef {
public:
/*! \brief default constructor */
SourceName() {}
/*! \brief constructor from node pointer */
explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const;
/*!
* \brief Get an SourceName for a given operator name.
* Will raise an error if the source name has not been registered.
* \param name Name of the operator.
* \return Reference to a SourceName valid throughout program lifetime.
*/
TVM_DLL static const SourceName& Get(const std::string& name);
/*! \brief specify container node */
using ContainerType = SourceNameNode;
};
/*!
* \brief Span information for debugging purposes
*/
class Span;
/*!
* \brief Stores locations in frontend source that generated a node.
*/
class SpanNode : public Node {
public:
/*! \brief The source name */
SourceName source;
/*! \brief Line number */
int lineno;
/*! \brief column offset */
int col_offset;
// override attr visitor
void VisitAttrs(AttrVisitor* v) final {
v->Visit("source", &source);
v->Visit("lineno", &lineno);
v->Visit("col_offset", &col_offset);
}
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span";
TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node);
};
RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef);
/*!
* \brief This is the base node container of all relay structures.
*/
class RelayNode : public Node {
public:
/*! \brief The location of the program in a SourceFragment can be null,
* check with span.defined() */
mutable Span span;
static constexpr const char* _type_key = "relay.Node";
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template <typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) {
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(const_cast<NodeType*>(ptr)->shared_from_this());
}
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
template <typename T>
inline const T* As(const NodeRef& node) {
const Node* ptr = static_cast<const Node*>(node.get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template <typename SubRef, typename BaseRef>
SubRef Downcast(BaseRef ref) {
CHECK(ref->template is_type<typename SubRef::ContainerType>())
<< "Downcast from " << ref->type_key() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(ref.node_);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BASE_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/environment.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
*/
#ifndef TVM_RELAY_ENVIRONMENT_H_
#define TVM_RELAY_ENVIRONMENT_H_
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
struct Environment;
/*! \brief The global environment of Relay programs.
*
* The global environment contains the global
* information needed to compile a Relay program.
*
* It contains all global functions, and configuration
* options.
*
* Many operations require access to the global
* Environment. We pass the Environment by value
* in a functional style as an explicit argument,
* but we mutate the Environment while optimizing
* Relay programs.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* an Environment while auto-tuning.
* */
class EnvironmentNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
EnvironmentNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
}
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
/*! \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);
/*! \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);
/*! \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);
/*! \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);
/*! \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);
/*! \brief Combine with another Environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
private:
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
};
struct Environment : public NodeRef {
Environment() {}
explicit Environment(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
inline EnvironmentNode* operator->() const {
return static_cast<EnvironmentNode*>(node_.get());
}
using ContainerType = EnvironmentNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ENVIRONMENT_H_
/*!
* Copyright (c) 2018 by Contributors
* \file error.h
* \brief The set of errors raised by Relay.
*/
#ifndef TVM_RELAY_ERROR_H_
#define TVM_RELAY_ERROR_H_
#include <string>
#include "./base.h"
namespace tvm {
namespace relay {
struct Error : dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};
struct InternalError : Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
};
// TODO(@jroesch): we should change spanned errors to report
// errors against the Environment, inverting control to error definition.
struct FatalTypeError : dmlc::Error {
explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {}
};
struct TypecheckerError : public dmlc::Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ERROR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/expr.h
* \brief Relay expression language.
*/
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_
#include <tvm/attrs.h>
#include <string>
#include "./base.h"
#include "./type.h"
namespace tvm {
namespace relay {
/*!
* \brief A Relay expression.
*/
class Expr;
/*!
* \brief Base type of the Relay expression hiearchy.
*/
class ExprNode : public RelayNode {
public:
/*!
* \brief Stores the result of type inference(type checking).
*
* \note This can be undefined before type inference.
* This value is discarded during serialization.
*/
mutable Type checked_type_ = Type(nullptr);
/*!
* \return The checked_type
*/
const Type& checked_type() const {
CHECK(checked_type_.defined()) << "internal error: the type checker has "
"not populated the checked_type "
"field for this node";
return this->checked_type_;
}
static constexpr const char* _type_key = "relay.Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, RelayNode);
};
RELAY_DEFINE_NODE_REF(Expr, ExprNode, NodeRef);
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
*
* \note Scalar constants are represented by rank-0 const tensor.
* Constant folding are handled uniformly via Tensor types.
*/
class Constant;
/*!
* \brief Constant tensor type.
*/
class ConstantNode : public ExprNode {
public:
/*! \brief The data of the tensor */
runtime::NDArray data;
/*! \return The corresponding tensor type of the data */
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
bool is_scalar() const { return data->ndim == 0; }
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Constant make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_NODE_TYPE_INFO(ConstantNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Constant, ConstantNode, Expr);
/*! \brief Tuple of multiple Exprs */
class Tuple;
/*! \brief Tuple container */
class TupleNode : public ExprNode {
public:
/*! \brief the fields of the tuple */
tvm::Array<relay::Expr> fields;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("fields", &fields);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
/*!
* \brief Local variables used in the let expression.
*
* Its semantics are similar to tvm.Var node used in TVM's low level
* tensor expression language.
*
* \note Each Var is bind only once and is immutable/
*/
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Var make(std::string name_hint);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
/*!
* \brief Global variable that leaves in the top-level environment.
* This is used to enable recursive calls between function.
*
* \note A GlobalVar may only point to functions.
*/
class GlobalVar;
/*! \brief A GlobalId from the node's current type to target type. */
class GlobalVarNode : public ExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static GlobalVar make(std::string name_hint);
static constexpr const char* _type_key = "relay.GlobalVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalVarNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(GlobalVar, GlobalVarNode, Expr);
/*!
* \brief Function parameter declaration.
*/
class Param;
/*! \brief A parameter. */
class ParamNode : public ExprNode {
public:
/*! \brief The variable */
Var var;
/*! \brief The type of the parameter */
Type type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("type", &type);
v->Visit("span", &span);
}
TVM_DLL static Param make(Var var, Type type);
static constexpr const char* _type_key = "relay.Param";
TVM_DECLARE_NODE_TYPE_INFO(ParamNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Param, ParamNode, Expr);
/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
/*! \brief Function container */
class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
tvm::Array<Param> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeParam> type_params;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("ret_type", &ret_type);
v->Visit("body", &body);
v->Visit("type_params", &type_params);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Type fn_type() const;
TVM_DLL static Function make(tvm::Array<Param> params, Type ret_type,
Expr body, tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
*/
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
public:
/*!
* \brief The operator(function) being invoked
*
* - It can be relay::Op which corresponds to the primitive operators.
* - It can also be user defined functions (Function, GlobalVar, Var).
*/
Expr op;
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::Expr> args;
/*! \brief The additional attributes */
Attrs attrs;
/*!
* \brief The type arguments passed to polymorphic(template) function.
*
* This is the advance feature that is only used when the function is
* polymorphic. It is safe to be ignored in most cases. For example, in the
* following code, the type_args of addone call is [int].
*
* \code
*
* template<typename T>
* T addone(T a) { return a + 1; }
*
* void main() {
* int x = addone<int>(10);
* }
*
* \endcode
*/
tvm::Array<Type> type_args;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_NODE_TYPE_INFO(CallNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Call, CallNode, Expr);
/*!
* \brief Let binding that binds a local var and optionally a type annotation.
*
* \note Let is useful to transform the program to be A-normal form.
* where each of the expression corresponds to a let binding.
*
* For developers who are familar with the computational graph.
* Each of the let can be viewed as a operator node in the computational graph.
* Traversing the list of let bindings is similar to running
* PostDFS-order(topo-order) traversal on the computational graph.
*/
class Let;
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
public:
/*! \brief The variable we bind to */
Var var;
/*! \brief The value we bind var to */
Expr value;
/*! \brief The body of the let binding */
Expr body;
/*! \brief Type annotation of value, this can be null */
Type value_type;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("value_type", &value_type);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Let make(Var var, Expr value, Expr body, Type value_type);
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
/*!
* \brief Condition expression
*
* Unlike traditional statement `if`s, the if evalutes
* to the result of the branch taken.
*
* let x = if (true) { 1 } else { 0 }; // x is 1
* let y = if (false) { 1 } else { 0 }; // y is 0
*
* \note This is similar to C's ternary operator.
*/
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
public:
/*! \brief The condition */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
Expr true_branch;
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
IfNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_NODE_TYPE_INFO(IfNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/expr_functor.h
* \brief A more powerful visitor which enables defining arbitrary function
* signatures with type based dispatch on first argument.
*/
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./op.h"
namespace tvm {
namespace relay {
/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \sa tvm/ir_functor.h
*
* \tparam FType function signiture
* This type is only defined for FType with function signature R(const Expr&,
* Args...)
*/
template <typename FType>
class ExprFunctor;
// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); }
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
});
template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~ExprFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const ConstantNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const VarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GlobalVarNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FunctionNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IfNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw dmlc::Error(std::string("Do not have a default for ") + op->type_key());
}
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
return vtable;
}
};
/*! \brief A simple visitor wrapper around ExprFunctor.
*
* Exposes two visitors with default traversal strategies, one
* which doesn't compute a result but can mutate internal state,
* and another which functionally builds a new Expr.
*/
class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
public:
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const ParamNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
virtual void VisitType(const Type& t);
};
/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
public:
Expr Mutate(const Expr& expr);
Expr VisitExpr_(const VarNode* op, const Expr& e) override;
Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
Expr VisitExpr_(const LetNode* op, const Expr& e) override;
Expr VisitExpr_(const IfNode* op, const Expr& e) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
private:
/*! \brief Internal map used for memoization. */
tvm::Map<Expr, Expr> memo_;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/logging.h
* \brief A wrapper around dmlc-core/logging.h which adds the ability
* to toggle logging via an environment variable.
*/
#ifndef TVM_RELAY_LOGGING_H_
#define TVM_RELAY_LOGGING_H_
#include <dmlc/logging.h>
#include <string>
#include <cstdlib>
#include <iostream>
namespace tvm {
namespace relay {
static bool logging_enabled() {
if (auto var = std::getenv("RELAY_LOG")) {
std::string is_on(var);
return is_on == "1";
} else {
return false;
}
}
#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled())
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_LOGGING_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/op.h
* \brief Primitive operator definition.
*/
#ifndef TVM_RELAY_OP_H_
#define TVM_RELAY_OP_H_
#include <functional>
#include <limits>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>
#include "../attrs.h"
#include "./base.h"
#include "./expr.h"
#include "./type.h"
namespace tvm {
namespace relay {
// forward declare name.
template <typename ValueType>
class OpMap;
class GenericOpMap;
class OpRegistry;
/*!
* \brief Node container of operator structure.
*/
class OpNode : public relay::ExprNode {
public:
/*! \brief name of the operator */
std::string name;
/*! \brief the type of the operator */
mutable FuncType op_type;
/*!
* \brief detailed description of the operator
* This can be used to generate docstring automatically for the operator.
*/
std::string description;
/* \brief Information of input arguments to the operator */
Array<AttrFieldInfo> arguments;
/*!
* \brief The type key of the attribute field
* This can be empty, in which case it defaults to
*/
std::string attrs_type_key;
/*!
* \brief number of input arguments to the operator,
* -1 means it is variable length
*/
int32_t num_inputs = -1;
/*!
* \brief support level of the operator,
* The lower the more priority it contains.
* This is in analogies to BLAS levels.
*/
int32_t support_level = 10;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
v->Visit("arguments", &arguments);
v->Visit("attrs_type_key", &attrs_type_key);
v->Visit("num_inputs", &num_inputs);
v->Visit("support_level", &support_level);
}
static constexpr const char* _type_key = "relay.Op";
TVM_DECLARE_NODE_TYPE_INFO(OpNode, ExprNode);
private:
// friend class
friend class GenericOpMap;
friend class OpRegistry;
// Program internal unique index of operator.
// Used to help index the program.
uint32_t index_{0};
};
/*!
* \brief Operator reference class.
*/
class Op : public relay::Expr {
public:
/*! \brief default constructor */
Op() {}
/*! \brief constructor from node pointer */
explicit Op(std::shared_ptr<Node> n) : Expr(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const OpNode* operator->() const;
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
*/
template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
TVM_DLL static const Op& Get(const std::string& op_name);
/*! \brief specify container node */
using ContainerType = OpNode;
private:
/*!
* \brief Get generic attrmap given attr name
* \param key The attribute key
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
};
/*! \brief Helper structure to register operators */
class OpRegistry {
public:
/*! \return the operator */
const Op& op() const { return op_; }
/*!
* \brief setter function during registration
* Set the description of operator
* \param descr the description string.
* \return reference to self.
*/
inline OpRegistry& describe(const std::string& descr); // NOLINT(*)
/*!
* \brief Add argument information to the function.
* \param name Name of the argument.
* \param type Type of the argument.
* \param description Description of the argument.
* \return reference to self.
*/
inline OpRegistry& add_argument(const std::string& name,
const std::string& type,
const std::string& description);
/*!
* \brief Attach the type function corresponding to the return type.
* \param rel_name The type relation name to register.
* \param type_rel_func The backing relation function which can solve an arbitrary
* relation on variables.
* \return reference to self.
*/
inline OpRegistry& add_type_rel(
const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func);
/*!
* \brief Set the type key of attributes.
* \param type_key The type of of the attrs field.x
* \return reference to self.
*/
inline OpRegistry& set_attrs_type_key(const std::string& type_key);
/*!
* \brief Set the num_inputs
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*)
/*!
* \brief Set the support level of op.
* \param level The support level.
* \return reference to self.
*/
inline OpRegistry& set_support_level(int32_t level); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template <typename ValueType>
inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value, int plevel = 10);
// set the name of the op to be the same as registry
inline OpRegistry& set_name() { // NOLINT(*)
if (get()->name.length() == 0) {
get()->name = name;
}
return *this;
}
/*! \return The global single registry */
TVM_DLL static ::dmlc::Registry<OpRegistry>* Registry();
private:
friend class ::dmlc::Registry<OpRegistry>;
// the name
std::string name;
/*! \brief The operator */
Op op_;
// private constructor
OpRegistry();
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
int plevel);
};
/*!
* \brief Generic map to store additional information of Op.
*/
class GenericOpMap {
public:
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
*/
inline int count(const Op& op) const;
/*!
* \brief get the corresponding value element at op
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const TVMRetValue& operator[](const Op& op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template <typename ValueType>
inline ValueType get(const Op& op, ValueType def_value) const;
private:
friend class OpRegistry;
// the attribute field.
std::string attr_name_;
// internal data
std::vector<std::pair<TVMRetValue, int> > data_;
// The value
GenericOpMap() = default;
};
/*!
* \brief Map<Op,ValueType> used to store meta-information about Op.
* \tparam ValueType The type of the value stored in map.
*/
template <typename ValueType>
class OpMap {
public:
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
*/
inline int count(const Op& op) const;
/*!
* \brief get the corresponding value element at op
* \param op The key to the map
* \return the const reference to the content value.
*/
inline ValueType operator[](const Op& op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
inline ValueType get(const Op& op, ValueType def_value) const;
private:
friend class Op;
// constructor
explicit OpMap(const GenericOpMap& map) : map_(map) {}
/*! \brief The internal map field */
const GenericOpMap& map_;
};
// internal macros to make
#define RELAY_REGISTER_VAR_DEF \
static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::OpRegistry& __make_##RelayOp
/*!
* \def RELAY_REGISTER_OP
* \brief Register a new operator, or set attribute of the corresponding op.
*
* \param OpName The name of registry
*
* \code
*
* RELAY_REGISTER_OP("add")
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* \endcode
*/
#define RELAY_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(RELAY_REGISTER_VAR_DEF, __COUNTER__) = \
::tvm::relay::OpRegistry::Registry() \
->__REGISTER_OR_GET__(OpName) \
.set_name()
// implementations
inline const OpNode* Op::operator->() const {
return static_cast<const OpNode*>(node_.get());
}
template <typename ValueType>
inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return OpMap<ValueType>(Op::GetGenericAttr(key));
}
inline OpNode* OpRegistry::get() {
return const_cast<OpNode*>(op_.operator->());
}
inline OpRegistry& OpRegistry::describe(
const std::string& descr) { // NOLINT(*)
get()->description = descr;
return *this;
}
inline OpRegistry& OpRegistry::add_argument(const std::string& name,
const std::string& type,
const std::string& description) {
std::shared_ptr<AttrFieldInfoNode> n = std::make_shared<AttrFieldInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
get()->arguments.push_back(AttrFieldInfo(n));
return *this;
}
inline OpRegistry& OpRegistry::add_type_rel(
const std::string& rel_name,
std::function<Array<Type>(const Array<Type>&, int)> type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypedEnvFunc<Array<Type>(const Array<Type>&, int)> env_type_rel_func;
if (runtime::Registry::Get(func_name)) {
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
} else {
runtime::Registry::Register(func_name)
.set_body_typed<Array<Type>(const Array<Type>&, int)>(type_rel_func);
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
}
std::vector<TypeParam> type_params;
std::vector<Type> arg_types;
// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeParamNode::make(name, TypeParamNode::Kind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}
auto ty_call_args = Array<Type>(arg_types);
// Add output type.
auto out_param = TypeParamNode::make("out", TypeParamNode::Kind::kType);
type_params.push_back(out_param);
ty_call_args.push_back(out_param);
TypeConstraint type_rel =
TypeRelationNode::make(rel_name, env_type_rel_func, ty_call_args);
auto func_type =
FuncTypeNode::make(arg_types, out_param, type_params, {type_rel});
get()->op_type = func_type;
return *this;
}
inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*)
get()->num_inputs = n;
return *this;
}
inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*)
const std::string& type_key) {
get()->attrs_type_key = type_key;
return *this;
}
inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*)
get()->support_level = n;
return *this;
}
template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
TVMRetValue rv;
rv = value;
UpdateAttr(attr_name, rv, plevel);
return *this;
}
// member functions of OpMap
inline int GenericOpMap::count(const Op& op) const {
if (op.defined()) {
const uint32_t idx = op->index_;
return idx < data_.size() ? (data_[idx].second != 0) : 0;
} else {
return 0;
}
}
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0)
<< "Attribute " << attr_name_ << " has not been registered for Operator "
<< op->name;
return data_[idx].first;
}
template <typename ValueType>
inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second != 0) {
return data_[idx].first;
} else {
return value;
}
}
template <typename ValueType>
inline int OpMap<ValueType>::count(const Op& op) const {
return map_.count(op);
}
template <typename ValueType>
inline ValueType OpMap<ValueType>::operator[](const Op& op) const {
return map_[op];
}
template <typename ValueType>
inline ValueType OpMap<ValueType>::get(const Op& op,
ValueType def_value) const {
return map_.get<ValueType>(op, def_value);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
*/
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*! \brief Infer the type of an expression with the provided environment.
*
* The result of type checking is a new expression with unambigous
* type information filled in, as well as it's checked type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& v, const Function& e);
/*!
* \brief Check that types are well formed by applying "kinding rules".
*
* This pass ensures we do not do things that violate the design of the
* type system when writing down types.
*
* For example tensors are not allowed to contain functions in Relay.
*
* We check this by ensuring the `dtype` field of a Tensor always contains
* a data type such as `int`, `float`, `uint`.
*
* \param env The global environment.
* \param t The type to check.
* \return true if the rules are satisified otherwise false
*/
bool KindCheck(const Environment& env, const Type& t);
/*! \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `let x = 1 in x` is equal to `let y = 1 in y`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param e1 The left hand expression.
* \param e2 The right hand expression.
*
* \return true if equal, otherwise false
*/
bool AlphaEqual(const Expr& e1, const Expr& e2);
/*! \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand type.
* \param t2 The right hand type.
*
* \return true if equal, otherwise false
*/
bool AlphaEqual(const Type& t1, const Type& t2);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/type.h
* \brief Relay typed AST nodes.
*/
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <string>
#include "./base.h"
namespace tvm {
namespace relay {
/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
};
/*!
* \brief Type is the base type of relay type hiearchy.
*
* Relay's type system contains following two key concepts:
*
* - TensorType: type of certain Tensor values in the expression.
* - FunctionType: the type of the function.
*
* There are also advanced types to support generic(polymorphic types),
* which can be ignored when first reading the code base.
*/
class Type : public NodeRef {
public:
Type() {}
explicit Type(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
using ContainerType = TypeNode;
};
/*!
* \brief Base of all Tensor types
* This container can hold TensorType or GenericTensorType.
*/
class BaseTensorTypeNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.BaseTensorType";
TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorTypeNode The container class of TensorType.
*/
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by ShapeExpr(tvm::Expr).
*/
Array<ShapeExpr> shape;
/*! \brief The content data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
/*! \brief Construct an scalar containing elements of dtype. */
TVM_DLL static TensorType Scalar(DataType dtype);
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
};
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
*
* For example, in the following pesudo code,
* the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeParamNode The actual container class of TypeParam
*/
class TypeParam;
/*! \brief TypeParam container node */
class TypeParamNode : public TypeNode {
public:
/*! \brief possible kinds of TypeParam */
enum Kind : int {
/*! \brief template variable in shape expression */
kShapeVar = 0,
kShape = 1,
kBaseType = 2,
kType = 3
};
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar, otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static TypeParam make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.TypeParam";
TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
/*!
* \brief Potential Constraints in the type.
* \note This is reserved for future use.
*/
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);
class FuncType;
/*!
* \brief Function type in Relay.
*
* Relay support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa TypeParam, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
tvm::Array<Type> arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
tvm::Array<TypeParam> type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
tvm::Array<TypeConstraint> type_constraints;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type,
tvm::Array<TypeParam> type_params,
tvm::Array<TypeConstraint> type_constraints);
static constexpr const char* _type_key = "relay.FuncType";
TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
using TypeRelationFn =
TypedEnvFunc<Array<Type>(const Array<Type>&, int)>;
/*!
* \brief Opaque type relation, is an input-output relation on types.
*/
class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the environment.
*/
class TypeRelationNode : public TypeConstraintNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the environment.
*/
TypeRelationFn func_;
/*! \brief The type arguments to the type function. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
}
TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_, Array<Type> args);
static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
};
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*!
* \brief The type of tuple values.
*/
class TupleType;
/*!
* \brief TupleType container.
*/
class TupleTypeNode : public TypeNode {
public:
/*! \brief The type of each field in the tuple. */
tvm::Array<Type> fields;
TupleTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }
TVM_DLL static TupleType make(tvm::Array<Type> fields);
static constexpr const char* _type_key = "relay.TypeTuple";
TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TYPE_H_
# pylint: disable=wildcard-import
"""The Relay IR namespace containing the IR definition and compiler."""
from . import base
from . import ty
from . import expr
from . import env
from . import ir_pass
from . import ir_builder
# Operators
from .op import Op
from .op.tensor import *
# Span
Span = base.Span
# Type
Type = ty.Type
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
# Expr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
Param = expr.Param
Function = expr.Function
Call = expr.Call
Let = expr.Let
If = expr.If
Var = Var
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Environment exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._env", __name__)
from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ...
\ No newline at end of file
"""FFI exposing the Relay type inference and checking."""
from tvm._ffi.function import _init_api
_init_api("relay._ir_pass", __name__)
from .env import Environment
from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
"""
The constructors for all Relay AST nodes exposed from C++.
This module includes MyPy type signatures for all of the
exposed modules.
"""
from .._ffi.function import _init_api
_init_api("relay._make", __name__)
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
NodeBase = NodeBase
def register_relay_node(type_key=None):
"""register relay node type
Parameters
----------
type_key : str or cls
The type key of the node
"""
if not isinstance(type_key, str):
return _register_tvm_node(
"relay." + type_key.__name__)(type_key)
return _register_tvm_node(type_key)
@register_relay_node
class Span(NodeBase):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, NodeBase
from . import _make
from . import _env
@register_relay_node
class Environment(NodeBase):
"""The global Relay environment containing functions,
options and more.
"""
def __init__(self, funcs):
"""Construct an environment.
Parameters
------
funcs: list of relay.Function
Returns
------
env: A new environment containing :py:class:`~relay.env.Environment`.
"""
self.__init_handle_by_constructor__(_make.Environment, funcs)
def add(self, var, func):
"""Add a function to the environment.
Parameters
---------
var: GlobalVar
The global variable which names the function.
func: Function
The function.
"""
if isinstance(var, str):
var = _env.Environment_GetGlobalVar(self, var)
_env.Environment_Add(self, var, func)
def merge(self, other):
"""Merge two environments.
Parameters
----------
other: Environment
The environment to merge into the current Environment.
"""
return _env.Environment_Merge(self, other)
def global_var(self, name):
"""Get a global variable by name.
Parameters
----------
name: str
The name of the global variable.
Returns
-------
global_var: GlobalVar
The global variable mapped to :code:`name`.
"""
return _env.Environment_GetGlobalVar(self, name)
def __getitem__(self, var):
"""Lookup a global function by name or by variable.
Parameters
----------
var: str or GlobalVar
The name or global variable.
Returns
-------
func: Function
The function referenced by :code:`var`.
"""
if isinstance(var, str):
return _env.Environment_Lookup_str(self, var)
else:
return _env.Environment_Lookup(self, var)
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from ._ir_pass import _get_checked_type
from . import _make
from .. import convert
class Expr(NodeBase):
"""The base type for all Relay expressions."""
def checked_type(self):
return _get_checked_type(self)
def __call__(self, *args):
converted_args = []
for arg in args:
if isinstance(arg, Param):
converted_args.append(arg.var)
else:
converted_args.append(arg)
return Call(self, args, None, None)
@register_relay_node
class Constant(Expr):
"""A constant tensor in Relay, see tvm/relay/type.h for more details.
"""
def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data)
@register_relay_node
class Tuple(Expr):
"""A hetereogenous sequence of values.
see tvm/relay/type.h for more details.
"""
def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields)
@register_relay_node
class Var(Expr):
"""A local variable in Relay."""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.Var, name_hint)
@register_relay_node
class GlobalVar(Expr):
"""A global variable in Relay."""
def __init__(self, name_hint):
self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
@register_relay_node
class Param(Expr):
"""A function type in Relay, see tvm/relay/type.h for more details.
"""
def __init__(self, var, ty):
self.__init_handle_by_constructor__(_make.Param, var, ty)
@register_relay_node
class Function(Expr):
"""A function in Relay, see tvm/relay/expr.h for more details."""
def __init__(self,
params,
ret_type,
body,
type_params=None
):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
_make.Function, params, ret_type, body, type_params)
@register_relay_node
class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, op, args, attrs, ty_args=None):
if not ty_args:
ty_args = []
self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, ty_args)
@register_relay_node
class Let(Expr):
"""A variable bindings in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, var, value, body, value_type):
self.__init_handle_by_constructor__(
_make.Let, var, value, body, value_type)
@register_relay_node
class If(Expr):
"""A conditional expression in Relay, see tvm/relay/expr.h for more details."""
def __init__(self, cond, true_value, false_value):
self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment