Skip to content
Snippets Groups Projects
Commit cdee6a79 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by Tianqi Chen
Browse files

register depthwise conv2d as generic function (#1108)

parent 8eebf5f6
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ from .util import get_pad_tuple
from ..util import simplify
@tvm.target.generic_func
def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
"""Depthwise convolution nchw forward operator.
......@@ -63,6 +64,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, out_dtype='float32'):
return Output
@tvm.target.generic_func
def depthwise_conv2d_nhwc(Input, Filter, stride, padding):
"""Depthwise convolution nhwc forward operator.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment