Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
tvm
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
cld
ml
tvm
Commits
e8afa1b4
Commit
e8afa1b4
authored
7 years ago
by
xqdan
Committed by
Tianqi Chen
7 years ago
Browse files
Options
Downloads
Patches
Plain Diff
[PASS] Support buffer reuse for different types (#891)
[PASS] Support buffer reuse for different types
parent
61cdf903
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/pass/storage_rewrite.cc
+18
-10
18 additions, 10 deletions
src/pass/storage_rewrite.cc
tests/python/unittest/test_pass_storage_rewrite.py
+76
-4
76 additions, 4 deletions
tests/python/unittest/test_pass_storage_rewrite.py
with
94 additions
and
14 deletions
src/pass/storage_rewrite.cc
+
18
−
10
View file @
e8afa1b4
...
...
@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator {
}
// Remap the index
Expr
RemapIndex
(
Type
dtype
,
Expr
index
,
StorageEntry
*
e
)
{
CHECK_EQ
(
dtype
.
element_of
(),
e
->
elem_type
);
if
(
e
->
bits_offset
==
0
)
return
index
;
uint64_t
elem_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
CHECK_EQ
(
e
->
bits_offset
%
elem_bits
,
0U
);
...
...
@@ -564,17 +563,22 @@ class StoragePlanRewriter : public IRMutator {
Expr
combo_size
;
for
(
const
Allocate
*
op
:
e
->
allocs
)
{
Expr
sz
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
,
make_const
(
Int
(
32
),
1
));
if
(
alloc_type
.
lanes
()
!=
op
->
type
.
lanes
())
{
sz
=
(
sz
*
make_const
(
sz
.
type
(),
op
->
type
.
lanes
())
+
make_const
(
sz
.
type
(),
alloc_type
.
lanes
()
-
1
))
/
make_const
(
sz
.
type
(),
alloc_type
.
lanes
());
}
// transform to bits
auto
sz_nbits
=
sz
*
(
op
->
type
.
bits
()
*
op
->
type
.
lanes
());
if
(
combo_size
.
defined
())
{
combo_size
=
max
(
combo_size
,
sz
);
combo_size
=
max
(
combo_size
,
sz
_nbits
);
}
else
{
combo_size
=
sz
;
combo_size
=
sz
_nbits
;
}
}
// transform to alloc bytes
auto
type_bits
=
alloc_type
.
bits
()
*
alloc_type
.
lanes
();
bool
divided
=
can_prove
(
combo_size
%
type_bits
==
0
);
combo_size
=
combo_size
/
type_bits
;
// round up for can not divided
if
(
!
divided
)
{
combo_size
+=
make_const
(
Int
(
32
),
1
);
}
combo_size
=
ir
::
Simplify
(
combo_size
);
e
->
new_alloc
=
Allocate
::
make
(
e
->
alloc_var
,
alloc_type
,
{
combo_size
},
const_true
(),
...
...
@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable,
// compiler can do a better job with register allocation.
const
uint64_t
match_range
=
16
;
uint64_t
op_elem_bits
=
op
->
type
.
bits
()
*
op
->
type
.
lanes
();
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
op
->
constant_allocation_size
()
*
op
->
type
.
bits
()
*
op
->
type
.
lanes
()
);
op
->
constant_allocation_size
()
*
op
_elem_bits
);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if
(
scope
.
tag
.
length
()
==
0
)
{
...
...
@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator {
auto
begin
=
const_free_map_
.
lower_bound
(
const_nbits
/
match_range
);
auto
mid
=
const_free_map_
.
lower_bound
(
const_nbits
);
auto
end
=
const_free_map_
.
upper_bound
(
const_nbits
*
match_range
);
// start looking at the buffer that is bigger than the required size first
for
(
auto
it
=
mid
;
it
!=
end
;
++
it
)
{
StorageEntry
*
e
=
it
->
second
;
if
(
e
->
attach_scope_
!=
attach_scope
)
continue
;
if
(
e
->
scope
!=
scope
)
continue
;
if
(
e
->
elem_type
!=
op
->
type
.
element_of
())
continue
;
// when not divided, no reuse, eg, float4 vs float3
if
(
e
->
bits_offset
%
op_elem_bits
!=
0
)
continue
;
e
->
const_nbits
=
std
::
max
(
const_nbits
,
e
->
const_nbits
);
const_free_map_
.
erase
(
it
);
return
e
;
}
// then start looking at smaller buffers.
for
(
auto
it
=
mid
;
it
!=
begin
;)
{
--
it
;
StorageEntry
*
e
=
it
->
second
;
...
...
This diff is collapsed.
Click to expand it.
tests/python/unittest/test_pass_storage_rewrite.py
+
76
−
4
View file @
e8afa1b4
...
...
@@ -54,10 +54,27 @@ def test_alloc_different_dtypes():
ib
=
tvm
.
ir_builder
.
create
()
base_dtype
=
dtype_list
[
0
]
global_a
=
tvm
.
placeholder
((
length
,),
name
=
"
global_a
"
,
dtype
=
base_dtype
)
for
index
,
dtype
in
enumerate
(
dtype_list
):
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
A
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
A_
"
+
str
(
index
),
scope
=
"
local.L0A
"
)
A
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
assert
len
(
dtype_list
)
==
4
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
dtype
=
dtype_list
[
0
]
A
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
A
"
,
scope
=
"
local.L0A
"
)
A
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
dtype
=
dtype_list
[
1
]
B
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
B
"
,
scope
=
"
local.L0A
"
)
B
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
dtype
=
dtype_list
[
2
]
C
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
C
"
,
scope
=
"
local.L0A
"
)
C
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
dtype
=
dtype_list
[
3
]
D
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
D
"
,
scope
=
"
local.L0A
"
)
D
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"
j
"
)
as
j
:
dtype
=
"
int8
"
E
=
ib
.
allocate
(
dtype
,
length
,
name
=
"
E
"
,
scope
=
"
local.L0A
"
)
E
[
j
]
=
A
[
j
].
astype
(
dtype
)
+
B
[
j
].
astype
(
dtype
)
+
C
[
j
].
astype
(
dtype
)
+
D
[
j
].
astype
(
dtype
)
return
ib
.
get
()
def
dtype_bit_len
(
dtype
):
...
...
@@ -342,6 +359,58 @@ def test_inplace_rule3():
assert
n
.
extents
[
0
].
value
==
70
tvm
.
ir_pass
.
PostOrderVisit
(
stmt
,
verify
)
def
test_alloc_seq_type
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"
n
"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"
i
"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"
j
"
)
as
j
:
A
=
ib
.
allocate
(
"
float32
"
,
200
,
name
=
"
A
"
,
scope
=
"
local.L0A
"
)
A1
=
ib
.
allocate
(
"
float32
"
,
200
,
name
=
"
A1
"
,
scope
=
"
local.L0A
"
)
A
[
j
]
=
1.2
A1
[
j
]
=
1.3
B
=
ib
.
allocate
(
"
int16
"
,
200
,
name
=
"
B
"
,
scope
=
"
local.L0A
"
)
B
[
j
]
=
tvm
.
const
(
1
,
"
int16
"
)
C
=
ib
.
allocate
(
"
int16
"
,
200
,
name
=
"
C
"
,
scope
=
"
local.L0A
"
)
C
[
j
]
=
tvm
.
const
(
1
,
"
int16
"
)
D
=
ib
.
allocate
(
"
int16
"
,
200
,
name
=
"
D
"
,
scope
=
"
local.L0A
"
)
D
[
j
]
=
B
[
j
]
+
C
[
j
]
A2
=
ib
.
allocate
(
"
float32
"
,
200
,
name
=
"
A2
"
,
scope
=
"
local.L0A
"
)
A2
[
j
]
=
A
[
j
]
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
StorageRewrite
(
body
)
num_alloc
=
[
0
]
def
verify
(
n
):
if
isinstance
(
n
,
tvm
.
stmt
.
Allocate
):
num_alloc
[
0
]
+=
1
assert
n
.
extents
[
0
].
value
==
500
tvm
.
ir_pass
.
PostOrderVisit
(
body
,
verify
)
assert
num_alloc
[
0
]
==
1
def
test_alloc_seq_type2
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"
n
"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"
i
"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"
j
"
)
as
j
:
A
=
ib
.
allocate
(
"
float32
"
,
200
,
name
=
"
A
"
,
scope
=
"
local.L0A
"
)
A
[
j
]
=
1.2
with
ib
.
for_range
(
0
,
20
,
name
=
"
j
"
)
as
j
:
B
=
ib
.
allocate
(
"
int16
"
,
400
,
name
=
"
B
"
,
scope
=
"
local.L0A
"
)
B
[
j
]
=
tvm
.
const
(
1
,
"
int16
"
)
with
ib
.
for_range
(
0
,
10
,
name
=
"
j
"
)
as
j
:
C
=
ib
.
allocate
(
"
float32
"
,
200
,
name
=
"
C
"
,
scope
=
"
local.L0A
"
)
C
[
j
]
=
1.2
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
StorageRewrite
(
body
)
num_alloc
=
[
0
]
def
verify
(
n
):
if
isinstance
(
n
,
tvm
.
stmt
.
Allocate
):
num_alloc
[
0
]
+=
1
assert
n
.
extents
[
0
].
value
==
200
tvm
.
ir_pass
.
PostOrderVisit
(
body
,
verify
)
assert
num_alloc
[
0
]
==
1
if
__name__
==
"
__main__
"
:
test_alloc_seq
()
test_alloc_different_dtypes
()
...
...
@@ -352,3 +421,6 @@ if __name__ == "__main__":
test_storage_share_gpu
()
test_inplace_rule2
()
test_inplace_rule3
()
test_alloc_seq_type
()
test_alloc_seq_type2
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment