Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
tvm
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
cld
ml
tvm
Commits
52e55baa
Commit
52e55baa
authored
6 years ago
by
Josh Pollock
Committed by
Tianqi Chen
6 years ago
Browse files
Options
Downloads
Patches
Plain Diff
[Relay] Parser Tests (#2209)
parent
d3bc59d2
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/relay/ir/alpha_equal.cc
+6
-6
6 additions, 6 deletions
src/relay/ir/alpha_equal.cc
src/relay/ir/text_printer.cc
+2
-2
2 additions, 2 deletions
src/relay/ir/text_printer.cc
tests/python/relay/test_ir_parser.py
+562
-0
562 additions, 0 deletions
tests/python/relay/test_ir_parser.py
with
570 additions
and
8 deletions
src/relay/ir/alpha_equal.cc
+
6
−
6
View file @
52e55baa
...
...
@@ -26,7 +26,7 @@ class AlphaEqualHandler:
* Check equality of two nodes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
Equal
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -46,7 +46,7 @@ class AlphaEqualHandler:
* Check equality of two attributes.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
AttrEqual
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
return
AttrsEqualHandler
::
Equal
(
lhs
,
rhs
);
...
...
@@ -55,7 +55,7 @@ class AlphaEqualHandler:
* Check equality of two types.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
TypeEqual
(
const
Type
&
lhs
,
const
Type
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -72,7 +72,7 @@ class AlphaEqualHandler:
*
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
ExprEqual
(
const
Expr
&
lhs
,
const
Expr
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
@@ -94,7 +94,7 @@ class AlphaEqualHandler:
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
DataTypeEqual
(
const
DataType
&
lhs
,
const
DataType
&
rhs
)
{
return
lhs
==
rhs
;
...
...
@@ -104,7 +104,7 @@ class AlphaEqualHandler:
* if map_free_var_ is set to true, try to map via equal node.
* \param lhs The left hand operand.
* \param rhs The right hand operand.
* \return
t
he compare result.
* \return
T
he compare result.
*/
bool
LeafNodeEqual
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
...
...
This diff is collapsed.
Click to expand it.
src/relay/ir/text_printer.cc
+
2
−
2
View file @
52e55baa
...
...
@@ -38,7 +38,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every node
s
,
* Instead of trying to design readable text format for every node,
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
...
...
@@ -73,7 +73,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data sec
i
ton,
* Since it is stored in the index-0 in the meta-data sect
i
on,
* we print it as meta.Variable(0).
*
* The text parser can recover this object by loading from the corresponding
...
...
This diff is collapsed.
Click to expand it.
tests/python/relay/test_ir_parser.py
0 → 100644
+
562
−
0
View file @
52e55baa
import
tvm
from
tvm
import
relay
from
tvm.relay.parser
import
enabled
from
tvm.relay.ir_pass
import
alpha_equal
from
nose.tools
import
nottest
,
raises
from
numpy
import
isclose
from
typing
import
Union
from
functools
import
wraps
if
enabled
():
from
tvm.relay._parser
import
ParseError
raises_parse_error
=
raises
(
ParseError
)
else
:
raises_parse_error
=
lambda
x
:
x
BINARY_OPS
=
{
"
*
"
:
relay
.
multiply
,
"
/
"
:
relay
.
divide
,
"
+
"
:
relay
.
add
,
"
-
"
:
relay
.
subtract
,
"
<
"
:
relay
.
less
,
"
>
"
:
relay
.
greater
,
"
<=
"
:
relay
.
less_equal
,
"
>=
"
:
relay
.
greater_equal
,
"
==
"
:
relay
.
equal
,
"
!=
"
:
relay
.
not_equal
,
}
TYPES
=
{
"
int8
"
,
"
int16
"
,
"
int32
"
,
"
int64
"
,
"
uint8
"
,
"
uint16
"
,
"
uint32
"
,
"
uint64
"
,
"
float16
"
,
"
float32
"
,
"
float64
"
,
"
bool
"
,
"
int8x4
"
,
"
uint1x4
"
,
"
float16x4
"
,
}
def
get_scalar
(
x
):
# type: (relay.Constant) -> (Union[float, int, bool])
return
x
.
data
.
asnumpy
().
item
()
int32
=
relay
.
scalar_type
(
"
int32
"
)
_
=
relay
.
Var
(
"
_
"
)
X
=
relay
.
Var
(
"
x
"
)
Y
=
relay
.
Var
(
"
y
"
)
X_ANNO
=
relay
.
Var
(
"
x
"
,
int32
)
Y_ANNO
=
relay
.
Var
(
"
y
"
,
int32
)
UNIT
=
relay
.
Tuple
([])
# decorator to determine if parser is enabled
def
if_parser_enabled
(
func
):
# https://stackoverflow.com/q/7727678
@wraps
(
func
)
def
wrapper
():
if
not
enabled
():
return
func
()
return
wrapper
@if_parser_enabled
def
test_comments
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
// This is a line comment!
()
"""
),
UNIT
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
/* This is a block comment!
This is still a block comment!
*/
()
"""
),
UNIT
)
@if_parser_enabled
def
test_int_literal
():
assert
isinstance
(
relay
.
fromtext
(
"
1
"
),
relay
.
Constant
)
assert
isinstance
(
relay
.
fromtext
(
"
1
"
).
data
,
tvm
.
ndarray
.
NDArray
)
assert
get_scalar
(
relay
.
fromtext
(
"
1
"
))
==
1
assert
get_scalar
(
relay
.
fromtext
(
"
10
"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"
0
"
))
==
0
assert
get_scalar
(
relay
.
fromtext
(
"
-100
"
))
==
-
100
assert
get_scalar
(
relay
.
fromtext
(
"
-05
"
))
==
-
5
@if_parser_enabled
def
test_float_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"
1.0
"
))
==
1.0
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"
1.56667
"
)),
1.56667
)
assert
get_scalar
(
relay
.
fromtext
(
"
0.0
"
))
==
0.0
assert
get_scalar
(
relay
.
fromtext
(
"
-10.0
"
))
==
-
10.0
# scientific notation
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"
1e-1
"
)),
1e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"
1e+1
"
))
==
1e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"
1E-1
"
)),
1E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"
1E+1
"
))
==
1E+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"
1.0e-1
"
)),
1.0e-1
)
assert
get_scalar
(
relay
.
fromtext
(
"
1.0e+1
"
))
==
1.0e+1
assert
isclose
(
get_scalar
(
relay
.
fromtext
(
"
1.0E-1
"
)),
1.0E-1
)
assert
get_scalar
(
relay
.
fromtext
(
"
1.0E+1
"
))
==
1.0E+1
@if_parser_enabled
def
test_bool_literal
():
assert
get_scalar
(
relay
.
fromtext
(
"
True
"
))
==
True
assert
get_scalar
(
relay
.
fromtext
(
"
False
"
))
==
False
@if_parser_enabled
def
test_negative
():
assert
isinstance
(
relay
.
fromtext
(
"
let %x = 1; -%x
"
).
body
,
relay
.
Call
)
assert
get_scalar
(
relay
.
fromtext
(
"
--10
"
))
==
10
assert
get_scalar
(
relay
.
fromtext
(
"
---10
"
))
==
-
10
@if_parser_enabled
def
test_bin_op
():
for
bin_op
in
BINARY_OPS
.
keys
():
assert
alpha_equal
(
relay
.
fromtext
(
"
1 {} 1
"
.
format
(
bin_op
)),
BINARY_OPS
.
get
(
bin_op
)(
relay
.
const
(
1
),
relay
.
const
(
1
))
)
@if_parser_enabled
def
test_parens
():
assert
alpha_equal
(
relay
.
fromtext
(
"
1 * 1 + 1
"
),
relay
.
fromtext
(
"
(1 * 1) + 1
"
))
assert
not
alpha_equal
(
relay
.
fromtext
(
"
1 * 1 + 1
"
),
relay
.
fromtext
(
"
1 * (1 + 1)
"
))
@if_parser_enabled
def
test_op_assoc
():
assert
alpha_equal
(
relay
.
fromtext
(
"
1 * 1 + 1 < 1 == 1
"
),
relay
.
fromtext
(
"
(((1 * 1) + 1) < 1) == 1
"
))
assert
alpha_equal
(
relay
.
fromtext
(
"
1 == 1 < 1 + 1 * 1
"
),
relay
.
fromtext
(
"
1 == (1 < (1 + (1 * 1)))
"
))
@nottest
@if_parser_enabled
def
test_vars
():
# temp vars won't work b/c they start with a digit
# # temp var
# temp_var = relay.fromtext("%1")
# assert isinstance(temp_var, relay.Var)
# assert temp_var.name == "1"
# var
var
=
relay
.
fromtext
(
"
let %foo = (); %foo
"
)
assert
isinstance
(
var
.
body
,
relay
.
Var
)
assert
var
.
body
.
name_hint
==
"
foo
"
# global var
global_var
=
relay
.
fromtext
(
"
@foo
"
)
assert
isinstance
(
global_var
,
relay
.
GlobalVar
)
assert
global_var
.
name_hint
==
"
foo
"
# operator id
op
=
relay
.
fromtext
(
"
foo
"
)
assert
isinstance
(
op
,
relay
.
Op
)
assert
op
.
name
==
"
foo
"
@if_parser_enabled
def
test_let
():
assert
alpha_equal
(
relay
.
fromtext
(
"
let %x = 1; ()
"
),
relay
.
Let
(
X
,
relay
.
const
(
1
),
UNIT
)
)
@if_parser_enabled
def
test_seq
():
assert
alpha_equal
(
relay
.
fromtext
(
"
(); ()
"
),
relay
.
Let
(
_
,
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"
let %_ = { 1 }; ()
"
),
relay
.
Let
(
X
,
relay
.
const
(
1
),
UNIT
)
)
@raises_parse_error
@if_parser_enabled
def
test_let_global_var
():
relay
.
fromtext
(
"
let @x = 1; ()
"
)
@raises_parse_error
@if_parser_enabled
def
test_let_op
():
relay
.
fromtext
(
"
let x = 1; ()
"
)
@if_parser_enabled
def
test_tuple
():
assert
alpha_equal
(
relay
.
fromtext
(
"
()
"
),
relay
.
Tuple
([]))
assert
alpha_equal
(
relay
.
fromtext
(
"
(0,)
"
),
relay
.
Tuple
([
relay
.
const
(
0
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"
(0, 1)
"
),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]))
assert
alpha_equal
(
relay
.
fromtext
(
"
(0, 1, 2)
"
),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
),
relay
.
const
(
2
)]))
@if_parser_enabled
def
test_func
():
# 0 args
assert
alpha_equal
(
relay
.
fromtext
(
"
fn () { 0 }
"
),
relay
.
Function
(
[],
relay
.
const
(
0
),
None
,
[]
)
)
# 1 arg
assert
alpha_equal
(
relay
.
fromtext
(
"
fn (%x) { %x }
"
),
relay
.
Function
(
[
X
],
X
,
None
,
[]
)
)
# 2 args
assert
alpha_equal
(
relay
.
fromtext
(
"
fn (%x, %y) { %x + %y }
"
),
relay
.
Function
(
[
X
,
Y
],
relay
.
add
(
X
,
Y
),
None
,
[]
)
)
# annotations
assert
alpha_equal
(
relay
.
fromtext
(
"
fn (%x: int32) -> int32 { %x }
"
),
relay
.
Function
(
[
X_ANNO
],
X_ANNO
,
int32
,
[]
)
)
# TODO(@jmp): Crashes if %x isn't annnotated.
# @nottest
@if_parser_enabled
def
test_defn
():
id_defn
=
relay
.
fromtext
(
"""
def @id(%x: int32) -> int32 {
%x
}
"""
)
assert
isinstance
(
id_defn
,
relay
.
Module
)
@if_parser_enabled
def
test_ifelse
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
if (True) {
0
} else {
1
}
"""
),
relay
.
If
(
relay
.
const
(
True
),
relay
.
const
(
0
),
relay
.
const
(
1
)
)
)
@raises_parse_error
@if_parser_enabled
def
test_ifelse_scope
():
relay
.
fromtext
(
"""
if (True) {
let %x = ();
()
} else {
%x
}
"""
)
@if_parser_enabled
def
test_call
():
# 0 args
constant
=
relay
.
Var
(
"
constant
"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %constant = fn () { 0 };
%constant()
"""
),
relay
.
Let
(
constant
,
relay
.
Function
([],
relay
.
const
(
0
),
None
,
[]),
relay
.
Call
(
constant
,
[],
None
,
None
)
)
)
# 1 arg
id_var
=
relay
.
Var
(
"
id
"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %id = fn (%x) { %x };
%id(1)
"""
),
relay
.
Let
(
id_var
,
relay
.
Function
([
X
],
X
,
None
,
[]),
relay
.
Call
(
id_var
,
[
relay
.
const
(
1
)],
None
,
None
)
)
)
# 2 args
multiply
=
relay
.
Var
(
"
multiply
"
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %multiply = fn (%x, %y) { %x * %y };
%multiply(0, 0)
"""
),
relay
.
Let
(
multiply
,
relay
.
Function
(
[
X
,
Y
],
relay
.
multiply
(
X
,
Y
),
None
,
[]
),
relay
.
Call
(
multiply
,
[
relay
.
const
(
0
),
relay
.
const
(
0
)],
None
,
None
)
)
)
# anonymous function
assert
alpha_equal
(
relay
.
fromtext
(
"""
(fn (%x) { %x })(0)
"""
),
relay
.
Call
(
relay
.
Function
(
[
X
],
X
,
None
,
[]
),
[
relay
.
const
(
0
)],
None
,
None
)
)
# curried function
curried_mult
=
relay
.
Var
(
"
curried_mult
"
)
alpha_equal
(
relay
.
fromtext
(
"""
let %curried_mult =
fn (%x) {
fn (%y) {
%x * %y
}
};
%curried_mult(0);
%curried_mult(0)(0)
"""
),
relay
.
Let
(
curried_mult
,
relay
.
Function
(
[
X
],
relay
.
Function
(
[
Y
],
relay
.
multiply
(
X
,
Y
),
None
,
[]
),
None
,
[]
),
relay
.
Let
(
_
,
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
relay
.
Call
(
relay
.
Call
(
curried_mult
,
[
relay
.
const
(
0
)],
None
,
None
),
[
relay
.
const
(
0
)],
None
,
None
)
)
)
)
# op
alpha_equal
(
relay
.
fromtext
(
"
abs(1)
"
),
relay
.
Call
(
relay
.
op
.
get
(
"
abs
"
),
[
relay
.
const
(
1
)],
None
,
None
)
)
# Types
@if_parser_enabled
def
test_incomplete_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"
let %_ : _ = (); ()
"
),
relay
.
Let
(
_
,
UNIT
,
UNIT
)
)
@if_parser_enabled
def
test_builtin_types
():
for
builtin_type
in
TYPES
:
relay
.
fromtext
(
"
let %_ : {} = (); ()
"
.
format
(
builtin_type
))
@nottest
@if_parser_enabled
def
test_call_type
():
assert
False
@if_parser_enabled
def
test_tensor_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"
let %_ : Tensor[(), float32] = (); ()
"
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TensorType
((),
"
float32
"
)),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"
let %_ : Tensor[(1,), float32] = (); ()
"
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TensorType
((
1
,),
"
float32
"
)),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"
let %_ : Tensor[(1, 1), float32] = (); ()
"
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TensorType
((
1
,
1
),
"
float32
"
)),
UNIT
,
UNIT
)
)
@if_parser_enabled
def
test_function_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
FuncType
([],
int32
,
[],
[])),
relay
.
Function
([],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
FuncType
([
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"
x
"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
FuncType
([
int32
,
int32
],
int32
,
[],
[])),
relay
.
Function
([
relay
.
Var
(
"
x
"
,
int32
),
relay
.
Var
(
"
y
"
,
int32
)],
relay
.
const
(
0
),
int32
,
[]),
UNIT
)
)
@if_parser_enabled
def
test_tuple_type
():
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: () = (); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TupleType
([])),
UNIT
,
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: (int32,) = (0,); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TupleType
([
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
)]),
UNIT
)
)
assert
alpha_equal
(
relay
.
fromtext
(
"""
let %_: (int32, int32) = (0, 1); ()
"""
),
relay
.
Let
(
relay
.
Var
(
"
_
"
,
relay
.
TupleType
([
int32
,
int32
])),
relay
.
Tuple
([
relay
.
const
(
0
),
relay
.
const
(
1
)]),
UNIT
)
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment