Commit eb761f36 authored by Tianqi Chen, committed by GitHub
[Refactor] Introduce target generic dispatch system (#556)

* [TVM] Introduce target generic dispatch system

* fix target warning
parent c3cac464
Showing 600 additions and 104 deletions
......@@ -8,6 +8,7 @@ Python API
intrin
tensor
schedule
target
build
module
ndarray
......
tvm.target
----------
.. automodule:: tvm.target
.. autofunction:: tvm.target.generic_func
.. autoclass:: tvm.target.Target
:members:
.. autofunction:: tvm.target.cuda
.. autofunction:: tvm.target.rocm
.. autofunction:: tvm.target.rasp
.. autofunction:: tvm.target.create
......@@ -37,13 +37,11 @@ Index
.. autosummary::
topi.cuda.schedule_conv2d_nchw
topi.cuda.schedule_conv2d_hwcn
topi.cuda.schedule_depthwise_conv2d_nchw
topi.cuda.schedule_depthwise_conv2d_nhwc
topi.cuda.schedule_reduce
topi.cuda.schedule_broadcast
topi.cuda.schedule_injective
topi.generic.schedule_conv2d_nchw
topi.generic.schedule_depthwise_conv2d_nchw
topi.generic.schedule_reduce
topi.generic.schedule_broadcast
topi.generic.schedule_injective
topi
~~~~
......@@ -75,14 +73,12 @@ topi.nn
.. autofunction:: topi.nn.depthwise_conv2d_nhwc
topi.cuda
~~~~~~~~~
.. automodule:: topi.cuda
topi.generic
~~~~~~~~~~~~
.. automodule:: topi.generic
.. autofunction:: topi.cuda.schedule_conv2d_nchw
.. autofunction:: topi.cuda.schedule_conv2d_hwcn
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.cuda.schedule_depthwise_conv2d_nhwc
.. autofunction:: topi.cuda.schedule_reduce
.. autofunction:: topi.cuda.schedule_broadcast
.. autofunction:: topi.cuda.schedule_injective
.. autofunction:: topi.generic.schedule_conv2d_nchw
.. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw
.. autofunction:: topi.generic.schedule_reduce
.. autofunction:: topi.generic.schedule_broadcast
.. autofunction:: topi.generic.schedule_injective
......@@ -56,11 +56,7 @@ def context(dev_type, dev_id=0):
assert tvm.context("cuda", 0) == tvm.gpu(0)
"""
if isinstance(dev_type, string_types):
if dev_type not in TVMContext.STR2MASK:
if dev_type.find("nvptx") != -1:
dev_type = "cuda"
if dev_type.find("rocm") != -1:
dev_type = "rocm"
dev_type = dev_type.split()[0]
if dev_type not in TVMContext.STR2MASK:
raise ValueError("Unknown device type %s" % dev_type)
dev_type = TVMContext.STR2MASK[dev_type]
......
......@@ -100,9 +100,12 @@ class TVMContext(ctypes.Structure):
12: 'ext_dev',
}
STR2MASK = {
'llvm': 1,
'stackvm': 1,
'cpu': 1,
'gpu': 2,
'cuda': 2,
'nvptx': 2,
'cl': 4,
'opencl': 4,
'metal': 8,
......
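A minimal sketch of the effect of these two hunks, based only on the code shown: because only the first token of the string is used for the lookup, ``tvm.context`` now accepts a full target string, and the new aliases map straight to device types.

.. code-block:: python

    import tvm

    # Only the first token ("llvm") is used for the STR2MASK lookup,
    # so a full target string is accepted.
    assert tvm.context("llvm -device=rasp", 0) == tvm.cpu(0)
    # "nvptx" is now an alias for the CUDA/GPU device type.
    assert tvm.context("nvptx", 0) == tvm.gpu(0)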
......@@ -15,6 +15,7 @@ from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target
class BuildConfig(object):
"""Configuration scope to set a build config option.
......@@ -238,7 +239,7 @@ def lower(sch,
def build(sch,
args=None,
target="llvm",
target=None,
target_host=None,
name="default_function",
binds=None):
......@@ -252,36 +253,10 @@ def build(sch,
args : list of Buffer or Tensor or Var, optional
The argument lists to the function.
target : str, optional
target : str or :any:`tvm.target.Target`, optional
The target and option of the compilation.
When the target is llvm, you can set options like:
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is inferred from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self-registered functions at program startup. Users can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where a dynamic loading API like dlopen is banned.
The system lib will be available as long as the resulting code is linked by the program.
target_host : str, optional
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific programs such as CUDA,
we also need host (CPU) side code to interact with the driver
......@@ -301,6 +276,10 @@ def build(sch,
-------
f : Function, or pair of functions
The result function.
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(sch, schedule.Schedule):
if args is None:
......@@ -325,6 +304,9 @@ def build(sch,
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
target = _target.current_target() if target is None else target
target = _target.create(target) if target else _target.create("llvm")
fhost = []
fdevice = []
for func in flist:
......@@ -332,7 +314,7 @@ def build(sch,
if BuildConfig.current.detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
warp_size = 32 if target == "cuda" else 1
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(func)]
fhost.append(fsplits[0])
......@@ -345,29 +327,28 @@ def build(sch,
else:
raise ValueError("unknown function type %d" % func.func_type)
if not target.startswith("llvm") and target not in ("stackvm", "ext_dev") and not fdevice:
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)
device = "cpu" if target.startswith("llvm") or target == "stackvm" else target
device_type = ndarray.context(device, 0).device_type
device_type = ndarray.context(target.target_name, 0).device_type
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if not target_host:
if device == "cpu":
if device_type == ndarray.cpu(0).device_type:
target_host = target
assert not fdevice
else:
target_host = "llvm" if module.enabled("llvm") else "stackvm"
target_host = _target.create(target_host)
target_device = target
fdevice = [ir_pass.LowerIntrin(x, target_device) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target_device.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mhost = codegen.build_module(fhost, target_host)
mhost = codegen.build_module(fhost, str(target_host))
if fdevice:
mdev = codegen.build_module(fdevice, target_device)
mdev = codegen.build_module(fdevice, str(target_device))
mhost.import_module(mdev)
return mhost
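A short usage sketch of the new ``target=None`` default (illustrative only, assuming an LLVM-enabled build of TVM): an explicit target string still works, and inside a target scope ``build`` picks up the current target.

.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
    s = tvm.create_schedule(B.op)

    # Passing an explicit target string still works.
    f1 = tvm.build(s, [A, B], target="llvm")

    # With target=None, build() falls back to tvm.target.current_target().
    with tvm.target.create("llvm"):
        f2 = tvm.build(s, [A, B])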
"""Target management API of tvm"""
"""Target management API of TVM.
TVM's target string is in the format ``<target_name> [-option=value]...``.
Note
----
The list of options includes:
- **-device=<device name>**
The device name.
- **-mtriple=<target triple>** or **-target**
Specify the target triple, which is useful for cross
compilation.
- **-mcpu=<cpuname>**
Specify a specific chip in the current architecture to
generate code for. By default this is inferred from the
target triple and autodetected to the current architecture.
- **-mattr=a1,+a2,-a3,...**
Override or control specific attributes of the target,
such as whether SIMD operations are enabled or not. The
default set of attributes is set by the current CPU.
- **-system-lib**
Build TVM system library module. System lib is a global module that contains
self-registered functions at program startup. Users can get the module using
:any:`tvm.module.system_lib`.
It is useful in environments where a dynamic loading API like dlopen is banned.
The system lib will be available as long as the resulting code is linked by the program.
We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific functions in this module to create specific targets.
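For example (a minimal sketch; the option values are only illustrative):

.. code-block:: python

    import tvm

    # Parse the documented string form into a tvm.target.Target.
    target = tvm.target.create("llvm -mcpu=cortex-a53 -mattr=+neon")
    # str() round-trips to the same string format.
    assert str(target) == "llvm -mcpu=cortex-a53 -mattr=+neon"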
"""
from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
try:
from decorator import decorate
except ImportError as err_msg:
# Allow decorator to be missing in runtime
if _LIB_NAME != "libtvm_runtime.so":
raise err_msg
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
new_opts = new_opts.split()
if new_opts:
return opts + new_opts
return opts
class Target(object):
"""A Target describes the target type on which computation should be carried on"""
default_target = None
str2type = {'x86': 1, 'cuda': 2, 'rasp': 3}
type2str = {1: 'x86', 2: 'cuda', 3: 'rasp'}
def __init__(self, target_type):
"""Constructs a context."""
if isinstance(target_type, Target):
self.target_typeid = target_type.target_typeid
else:
self.target_typeid = Target.str2type[target_type]
"""Target device information, use through TVM API.
@property
def target_type(self):
"""Returns the target type of current target."""
return Target.type2str[self.target_typeid]
Parameters
----------
target_name : {"llvm", "cuda", "opencl", "metal", "rocm", "stackvm", "ext_dev"}
The major target name.
def __hash__(self):
"""Compute hash value of target for dictionary lookup"""
return hash(self.target_typeid)
options : list of str, optional
Additional arguments appended to the target.
def __eq__(self, other):
"""Compares two targets. Two targets are equal if they
have the same target type.
"""
return isinstance(other, Target) and \
self.target_typeid == other.target_typeid
Note
----
Do not use the class constructor; create a target using the following functions:
- :any:`tvm.target.create` create target from string
- :any:`tvm.target.rasp` create raspberry pi target
- :any:`tvm.target.cuda` create CUDA target
- :any:`tvm.target.rocm` create ROCM target
"""
current = None
def __init__(self,
target_name,
options=None):
self.target_name = target_name
self.options = _merge_opts([], options)
self.device_name = ""
# Parse device option
for item in self.options:
if item.startswith("-device="):
self.device_name = item.split("=")[1]
# Target query searches the device name first
if self.device_name:
self.keys = (self.device_name,)
else:
self.keys = ()
# Target configuration handling
self.thread_warp_size = 1
if target_name in ("llvm", ):
self.keys += ("cpu",)
elif target_name in ("cuda", "nvptx"):
self.keys += ("cuda", "gpu")
self.max_num_threads = 512
self.thread_warp_size = 32
elif target_name in ("rocm", "opencl"):
# For now assume rocm schedule for opencl
self.keys += ("rocm", "gpu")
self.max_num_threads = 256
elif target_name in ("metal",):
self.keys += ("gpu",)
self.max_num_threads = 256
elif target_name in ("stackvm", "ext_dev"):
# No specific configuration for stackvm or ext_dev
pass
else:
raise ValueError("Unknown target name %s" % target_name)
def __str__(self):
return '%s' % (self.target_type)
return " ".join([self.target_name] + self.options)
def __repr__(self):
return self.__str__()
def __enter__(self):
self._old_target = Target.default_target
Target.default_target = self
self._old_target = Target.current
if self._old_target is not None and str(self) != str(self._old_target):
warnings.warn(
"Override target '%s' with new target scope '%s'" % (
self._old_target, self))
Target.current = self
return self
def __exit__(self, ptype, value, trace):
Target.default_target = self._old_target
Target.current = self._old_target
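As a sketch of the configuration set up in ``__init__`` above (the ``-device=mali`` value is made up for illustration), the dispatch keys and thread limits can be inspected on a created target:

.. code-block:: python

    import tvm

    t = tvm.target.create("cuda")
    assert t.keys == ("cuda", "gpu")     # keys used for generic dispatch
    assert t.thread_warp_size == 32
    assert t.max_num_threads == 512

    # A -device= option is prepended to the dispatch keys.
    t = tvm.target.create("opencl -device=mali")
    assert t.keys == ("mali", "rocm", "gpu")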
def generic_func(fdefault):
"""Wrap a target generic function.
A generic function allows registration of further functions that
can be dispatched on the current target context.
If no registered dispatch matches, fdefault is called.
Parameters
----------
fdefault : function
The default function.
Returns
-------
fgeneric : function
A wrapped generic function.
Example
-------
.. code-block:: python
import tvm
# wrap function as target generic
@tvm.target.generic_func
def my_func(a):
return a + 1
# register specialization of my_func under target cuda
@my_func.register("cuda")
def my_func_cuda(a):
return a + 2
# displays 3, because my_func is called
print(my_func(2))
# displays 4, because my_func_cuda is called
with tvm.target.cuda():
print(my_func(2))
"""
dispatch_dict = {}
func_name = fdefault.__name__
def register(key, func=None, override=False):
"""Register function to be the dispatch function.
Parameters
----------
key : str or list of str
The key to be registered.
func : function
The function to be registered.
override : bool
Whether to override an existing registration.
Returns
-------
The registered function if ``func`` is given, otherwise a decorator
that performs the registration.
"""
def _do_reg(myf):
key_list = [key] if isinstance(key, str) else key
for k in key_list:
if k in dispatch_dict and not override:
raise ValueError(
"Key is already registered for %s" % func_name)
dispatch_dict[k] = myf
return myf
if func:
return _do_reg(func)
return _do_reg
def dispatch_func(func, *args, **kwargs):
"""The wrapped dispath function"""
target = current_target()
if target is None:
return func(*args, **kwargs)
for k in target.keys:
if k in dispatch_dict:
return dispatch_dict[k](*args, **kwargs)
return func(*args, **kwargs)
fdecorate = decorate(fdefault, dispatch_func)
fdecorate.register = register
return fdecorate
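A sketch of how the dispatch order interacts with the ``-device=`` option (the key ``mydev`` is made up for illustration): the device key is placed first in ``target.keys``, so it wins over the generic ``cpu`` key.

.. code-block:: python

    import tvm

    @tvm.target.generic_func
    def my_sched(x):
        return "default"

    @my_sched.register("cpu")
    def my_sched_cpu(x):
        return "cpu"

    @my_sched.register("mydev")          # hypothetical device key
    def my_sched_mydev(x):
        return "mydev"

    assert my_sched(0) == "default"      # no target scope -> fdefault
    with tvm.target.create("llvm -device=mydev"):
        # keys are ("mydev", "cpu"); the device key is searched first.
        assert my_sched(0) == "mydev"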
def cuda(options=None):
"""Returns a cuda target.
Parameters
----------
options : list of str
Additional options
"""
return Target("cuda", options)
def rocm(options=None):
"""Returns a ROCM target.
Parameters
----------
options : list of str
Additional options
"""
return Target("rocm", options)
def rasp(options=None):
"""Returns a rasp target.
Parameters
----------
options : list of str
Additional options
"""
opts = ["-device=rasp",
"-mtriple=armv7l-none-linux-gnueabihf",
"-mcpu=cortex-a53",
"-mattr=+neon"]
opts = _merge_opts(opts, options)
return Target("llvm", opts)
def create(target_str):
"""Get a target given target string.
Parameters
----------
target_str : str
The target string.
Returns
-------
target : Target
The target object
Target.default_target = Target('x86')
Note
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(target_str, Target):
return target_str
if not isinstance(target_str, str):
raise ValueError("target_str has to be string type")
arr = target_str.split()
# Parse device option
device_name = ""
for item in arr[1:]:
if item.startswith("-device="):
device_name = item.split("=")[1]
if device_name == "rasp":
return rasp(arr[1:])
return Target(arr[0], arr[1:])
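Two behaviours worth noting, sketched from the code above: an existing ``Target`` is returned unchanged, and a ``-device=rasp`` option routes through the :any:`tvm.target.rasp` helper, which appends the cross-compilation flags.

.. code-block:: python

    import tvm

    t = tvm.target.cuda()
    assert tvm.target.create(t) is t      # Target objects pass through

    t = tvm.target.create("llvm -device=rasp")
    assert "-mcpu=cortex-a53" in t.options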
def x86():
"""Returns a x86 target."""
return Target('x86')
def cuda():
"""Returns a cuda target."""
return Target('cuda')
def current_target(allow_none=True):
"""Returns the current target.
def rasp():
"""Returns a rasp target."""
return Target('rasp')
Parameters
----------
allow_none : bool
Whether to allow the current target to be None
def current_target():
"""Returns the current target."""
return Target.default_target
Raises
------
RuntimeError if the current target is not set.
"""
if Target.current:
return Target.current
if not allow_none:
raise RuntimeError(
"Requires a current target in generic function, but it is not set. "
"Please set it using `with TargetObject:`")
return Target.current
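A minimal sketch of the scope behaviour, based on the code above:

.. code-block:: python

    import tvm

    assert tvm.target.current_target() is None     # no scope entered yet

    with tvm.target.create("cuda"):
        assert str(tvm.target.current_target()) == "cuda"

    # Outside a scope, allow_none=False raises instead of returning None.
    try:
        tvm.target.current_target(allow_none=False)
    except RuntimeError:
        pass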
......@@ -82,6 +82,8 @@ GetLLVMTargetMachine(const std::string& target_str,
} else {
LOG(FATAL) << "invalid -mfloat-abi option " << value;
}
} else if (key == "-device") {
// pass
} else {
LOG(FATAL) << "unknown option " << key;
}
......
......@@ -68,7 +68,8 @@ def test_gemm():
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B, C], device)
with tvm.target.create(device):
f = tvm.build(s, [A, B, C])
ctx = tvm.context(device, 0)
# launch the kernel.
n = nn
......
import tvm
@tvm.target.generic_func
def mygeneric(data):
# default generic function
return data + 1
@mygeneric.register(["cuda", "gpu"])
def cuda_func(data):
return data + 2
@mygeneric.register("rocm")
def rocm_func(data):
return data + 3
@mygeneric.register("cpu")
def cpu_func(data):
return data + 10
def test_target_dispatch():
with tvm.target.cuda():
assert mygeneric(1) == 3
with tvm.target.rocm():
assert mygeneric(1) == 4
with tvm.target.create("cuda"):
assert mygeneric(1) == 3
with tvm.target.rasp():
assert mygeneric(1) == 11
with tvm.target.create("metal"):
assert mygeneric(1) == 3
try:
mygeneric(0)
raise RuntimeError("not reached")
except RuntimeError:
pass
if __name__ == "__main__":
test_target_dispatch()
......@@ -3,6 +3,7 @@
import tvm
from .. import util
from .. import tag
from .. import generic
def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L):
"""Schedule conv2d for specific feature_in_out_filter pattern"""
......@@ -483,6 +484,8 @@ def schedule_conv2d_small_batch(outs):
traverse(outs[0].op)
return s
@generic.schedule_conv2d_nchw.register(["cuda", "gpu"])
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw.
......
......@@ -3,7 +3,9 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
@generic.schedule_dense.register(["cuda", "gpu"])
def schedule_dense(outs):
"""Schedule for dense operator.
......
......@@ -3,7 +3,9 @@
import tvm
from ..util import get_const_tuple
from .. import tag
from .. import generic
@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
......
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
def _schedule_injective(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = 512
target = tvm.target.current_target()
target = target if target else tvm.target.cuda()
num_thread = target.max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
"""Schedule for injective op.
......
......@@ -2,7 +2,9 @@
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic
@generic.schedule_global_pool.register(["cuda", "gpu"])
def schedule_global_pool(outs):
"""Schedule for global_pool.
......@@ -63,6 +65,7 @@ def schedule_global_pool(outs):
return s
@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs):
"""Schedule for pool.
......
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
......@@ -62,6 +63,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
return sch
@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs):
"""Schedule for inject->reduce->bcast ops.
......
# pylint: disable=invalid-name, unused-variable, trailing-whitespace
"""Schedule for softmax operator"""
import tvm
from .. import generic
@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
"""Schedule for softmax op.
......
# pylint: disable=wildcard-import
"""Generic declaration and schedules.
This is a recommended way of using the TOPI API.
To use the generic schedule functions, users must set
the current target scope using a with block. See also :any:`tvm.target`.
Example
-------
.. code-block:: python
# create schedule that dispatches to topi.cuda.schedule_injective
with tvm.target.create("cuda"):
s = topi.generic.schedule_injective(outs)
"""
from __future__ import absolute_import as _abs
from .nn import *
from .injective import *
# pylint: disable=invalid-name
"""generic declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
@tvm.target.generic_func
def schedule_injective(outs):
"""Schedule for injective op.
Parameters
----------
outs: Array of Tensor
The computation graph description of injective ops in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_injective not registered for '%s'" % target)
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm
def _default_schedule(outs, auto_inline):
"""Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_pool not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for conv2d nchow
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise conv2d in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_reduce(outs):
"""Schedule for reduction
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, True)
@tvm.target.generic_func
def schedule_softmax(outs):
"""Schedule for softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of softmax in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_dense(outs):
"""Schedule for dense
Parameters
----------
outs: Array of Tensor
The computation graph description of dense in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_pool(outs):
"""Schedule for pool
Parameters
----------
outs: Array of Tensor
The computation graph description of pool in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_global_pool(outs):
"""Schedule for global pool
Parameters
----------
outs: Array of Tensor
The computation graph description of global pool in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)