diff --git a/tutorials/get_started.py b/tutorials/get_started.py
index 3dce21a649509015f4e6df546e9f95a2b2333451..2c6165940d057a688de5cbd1dc3ee57bfff43103 100644
--- a/tutorials/get_started.py
+++ b/tutorials/get_started.py
@@ -13,6 +13,12 @@ from __future__ import absolute_import, print_function
 import tvm
 import numpy as np
 
+# Global declarations of environment.
+
+tgt_host="llvm"
+# Change it to respective GPU if gpu is enabled Ex: cuda, opencl
+tgt="cuda"
+
 ######################################################################
 # Vector Add Example
 # ------------------
@@ -88,8 +94,9 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
 # compute grid. These are GPU specific constructs that allows us
 # to generate code that runs on GPU.
 #
-s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
-s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+if tgt == "cuda":
+  s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+  s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
 
 ######################################################################
 # Compilation
@@ -103,12 +110,12 @@ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
 # function(including the inputs and outputs) as well as target language
 # we want to compile to.
 #
-# The result of compilation fadd is a CUDA device function that can
-# as well as a host wrapper that calls into the CUDA function.
+# The result of compilation fadd is a GPU device function(if GPU is involved) 
+# that can as well as a host wrapper that calls into the GPU function.
 # fadd is the generated host wrapper function, it contains reference
 # to the generated device function internally.
 #
-fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
 
 ######################################################################
 # Run the Function
@@ -124,12 +131,13 @@ fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
 # - fadd runs the actual computation.
 # - asnumpy() copies the gpu array back to cpu and we can use this to verify correctness
 #
-ctx = tvm.gpu(0)
+ctx = tvm.context(tgt, 0)
+
 n = 1024
 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
 b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
 c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-fadd_cuda(a, b, c)
+fadd(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
 ######################################################################
@@ -137,13 +145,16 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # --------------------------
 # You can inspect the generated code in TVM. The result of tvm.build
 # is a tvm Module. fadd is the host module that contains the host wrapper,
-# it also contains a device module for the CUDA function.
+# it also contains a device module for the CUDA (GPU) function.
 #
 # The following code fetches the device module and prints the content code.
 #
-dev_module = fadd_cuda.imported_modules[0]
-print("-----CUDA code-----")
-print(dev_module.get_source())
+if tgt == "cuda":
+    dev_module = fadd.imported_modules[0]
+    print("-----GPU code-----")
+    print(dev_module.get_source())
+else:
+    print(fadd.get_source())
 
 ######################################################################
 # .. note:: Code Specialization
@@ -179,8 +190,9 @@ from tvm.contrib import cc
 from tvm.contrib import util
 
 temp = util.tempdir()
-fadd_cuda.save(temp.relpath("myadd.o"))
-fadd_cuda.imported_modules[0].save(temp.relpath("myadd.ptx"))
+fadd.save(temp.relpath("myadd.o"))
+if tgt == "cuda":
+    fadd.imported_modules[0].save(temp.relpath("myadd.ptx"))
 cc.create_shared(temp.relpath("myadd.so"), [temp.relpath("myadd.o")])
 print(temp.listdir())
 
@@ -201,8 +213,9 @@ print(temp.listdir())
 # re-link them together. We can verify that the newly loaded function works.
 #
 fadd1 = tvm.module.load(temp.relpath("myadd.so"))
-fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
-fadd1.import_module(fadd1_dev)
+if tgt == "cuda":
+    fadd1_dev = tvm.module.load(temp.relpath("myadd.ptx"))
+    fadd1.import_module(fadd1_dev)
 fadd1(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
@@ -215,7 +228,7 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # them together with the host code.
 # Currently we support packing of Metal, OpenCL and CUDA modules.
 #
-fadd_cuda.export_library(temp.relpath("myadd_pack.so"))
+fadd.export_library(temp.relpath("myadd_pack.so"))
 fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
 fadd2(a, b, c)
 np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
@@ -241,16 +254,17 @@ np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 # The following codeblocks generate opencl code, creates array on opencl
 # device, and verifies the correctness of the code.
 #
-fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
-print("------opencl code------")
-print(fadd_cl.imported_modules[0].get_source())
-ctx = tvm.cl(0)
-n = 1024
-a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
-b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
-c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-fadd_cl(a, b, c)
-np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+if tgt == "opencl":
+    fadd_cl = tvm.build(s, [A, B, C], "opencl", name="myadd")
+    print("------opencl code------")
+    print(fadd_cl.imported_modules[0].get_source())
+    ctx = tvm.cl(0)
+    n = 1024
+    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+    b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+    c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+    fadd_cl(a, b, c)
+    np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
 
 ######################################################################
 # Summary