# test_lang_basic.py -- basic tests for the tvm Python expression/statement API.
import tvm

def test_const():
    """A constant built from the Python int 1 is an int32 ``IntImm``."""
    one = tvm.const(1)
    print(one.dtype)
    assert one.dtype == tvm.int32
    assert isinstance(one, tvm.expr.IntImm)

def test_const_saveload_json():
    """Round-trip an expression through JSON and compare the serialized forms."""
    lhs = tvm.const(1)
    rhs = tvm.const(10)
    total = lhs + rhs
    doubled = total + total

    # Serialize, reload, and make sure re-serializing gives the same JSON.
    serialized = tvm.save_json(doubled)
    restored = tvm.load_json(serialized)
    assert tvm.save_json(restored) == tvm.save_json(doubled)

def test_make():
    """Add a ``tvm.const`` to an ``IntImm`` from ``tvm.make`` and print the sum."""
    const_one = tvm.const(1)
    made_one = tvm.make.IntImm('int32', 1)
    total = const_one + made_one
    print(total)

def test_ir():
    """``tvm.make.Evaluate`` on an expression yields a ``tvm.stmt.Evaluate``."""
    lhs = tvm.const(1)
    rhs = tvm.make.IntImm('int32', 1)
    total = lhs + rhs
    evaluate = tvm.make.Evaluate(total)
    assert isinstance(evaluate, tvm.stmt.Evaluate)

def test_ir2():
    """Build a ``Store`` on a handle variable; check its type and ``buffer_var``."""
    n = tvm.var("n")
    array = tvm.var("array", tvm.handle)
    store = tvm.make.Store(array, n + 1, 1)
    assert isinstance(store, tvm.stmt.Store)
    assert store.buffer_var == array

def test_let():
    """Construct a ``LetStmt`` and check the type of the resulting node.

    The statement is built with ``tvm.make.LetStmt`` from the variable ``x``,
    the value 10 and ``Evaluate(x + 1)`` as arguments. The result is checked
    to be a ``tvm.stmt.LetStmt``, mirroring the type checks in ``test_ir``
    and ``test_ir2``.
    """
    x = tvm.var('x')
    stmt = tvm.make.LetStmt(x, 10, tvm.make.Evaluate(x + 1))
    # Previously the statement was built but never checked.
    assert isinstance(stmt, tvm.stmt.LetStmt)

def test_attr():
    """Check attribute access on an ``AttrStmt`` and on a converted constant."""
    x_var = tvm.var('x')
    y_var = tvm.var('y')
    attr_stmt = tvm.make.AttrStmt(
        y_var, "stride", 10, tvm.make.Evaluate(x_var + 1))
    assert attr_stmt.node == y_var

    converted = tvm.convert(1)
    assert converted.value == 1

    # Reading a field that does not exist is expected to raise AttributeError.
    try:
        converted.no_field
    except AttributeError:
        pass
    else:
        assert False


def test_basic():
    """The sum of two variables prints as ``(<lhs name> + <rhs name>)``."""
    lhs = tvm.var('a')
    rhs = tvm.var('b')
    total = lhs + rhs
    rendered = str(total)
    assert rendered == '(%s + %s)' % (lhs.name, rhs.name)


def test_stmt():
    """Smoke test: building a ``For`` statement with ``For.Serial`` does not raise."""
    body = tvm.make.Evaluate(0)
    loop_var = tvm.var('i')
    tvm.make.For(loop_var, 0, 1, tvm.stmt.For.Serial, 0, body)

def test_dir():
    """Smoke test: ``dir()`` on a variable node runs without raising."""
    variable = tvm.var('x')
    dir(variable)

def test_dtype():
    """A default variable is ``int32``; comparing two of them gives ``uint1``."""
    lhs = tvm.var('x')
    assert lhs.dtype == 'int32'
    rhs = tvm.var('y')
    comparison = lhs > rhs
    assert comparison.dtype == 'uint1'


def test_any():
    """Exercise ``tvm.any`` and the rejection of Python ``or`` on expressions."""
    x_var = tvm.var('x')
    y_var = tvm.var('y')
    z_var = tvm.var('z')

    # Using Python's ``or`` on expressions is expected to raise ValueError.
    try:
        _ = x_var or x_var
    except ValueError:
        pass
    else:
        assert False

    # Calling ``tvm.any`` with no arguments is expected to raise ValueError.
    try:
        tvm.any()
    except ValueError:
        pass
    else:
        assert False

    # One, two and three conditions: check the printed form of the result.
    single = str(tvm.any(x_var < y_var))
    assert single == '(%s < %s)' % (x_var.name, y_var.name)

    pair = str(tvm.any(x_var < y_var, x_var > z_var))
    assert pair == '((%s < %s) || (%s > %s))' % (
        x_var.name, y_var.name, x_var.name, z_var.name)

    triple = str(tvm.any(x_var < y_var, y_var > z_var + 1, x_var < z_var * 2))
    assert triple == \
        '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
            x_var.name, y_var.name, y_var.name, z_var.name,
            x_var.name, z_var.name)


def test_all():
    """Exercise ``tvm.all`` and the rejection of Python ``and`` on expressions."""
    x_var = tvm.var('x')
    y_var = tvm.var('y')
    z_var = tvm.var('z')

    # Using Python's ``and`` on expressions is expected to raise ValueError.
    try:
        _ = x_var and x_var
    except ValueError:
        pass
    else:
        assert False

    # Calling ``tvm.all`` with no arguments is expected to raise ValueError.
    try:
        tvm.all()
    except ValueError:
        pass
    else:
        assert False

    # One, two and three conditions: check the printed form of the result.
    single = str(tvm.all(x_var < y_var))
    assert single == '(%s < %s)' % (x_var.name, y_var.name)

    pair = str(tvm.all(x_var < y_var, x_var > z_var))
    assert pair == '((%s < %s) && (%s > %s))' % (
        x_var.name, y_var.name, x_var.name, z_var.name)

    triple = str(tvm.all(x_var < y_var, y_var > z_var + 1, x_var < z_var * 2))
    assert triple == \
        '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
            x_var.name, y_var.name, y_var.name, z_var.name,
            x_var.name, z_var.name)

    
if __name__ == "__main__":
    # Script entry point: run each test function in this file directly.
    test_attr()
    test_const()
    test_const_saveload_json()
    test_make()
    test_ir()
    # test_ir2 was defined but previously never invoked in script mode.
    test_ir2()
    test_basic()
    test_stmt()
    test_let()
    test_dir()
    test_dtype()
    test_any()
    test_all()