diff --git a/tests/src/test/scala/runtime/actionContainers/Python3AiActionContainerTests.scala b/tests/src/test/scala/runtime/actionContainers/Python3AiActionContainerTests.scala index 4f2d69762a4ff055a81586c659577bcaf2df79c6..e5c4075c04f464cb3a99349fda1ae374b0aab07a 100644 --- a/tests/src/test/scala/runtime/actionContainers/Python3AiActionContainerTests.scala +++ b/tests/src/test/scala/runtime/actionContainers/Python3AiActionContainerTests.scala @@ -58,4 +58,37 @@ class Python3AiActionContainerTests extends PythonActionContainerTests with WskA runRes should be(Some(JsObject("response" -> List(5, 12, 21, 32).toJson))) } } + + it should "run pytorch" in { + val (out, err) = withActionContainer() { c => + val code = + """ + |import torch + |import torchvision + |import torch.nn as nn + |import numpy as np + |import torchvision.transforms as transforms + |def main(args): + | # Create a numpy array. + | x = np.array([1,2,3,4]) + | + | # Convert the numpy array to a torch tensor. + | y = torch.from_numpy(x) + | + | # Convert the torch tensor to a numpy array. + | z = y.numpy() + | return { "response": z.tolist()} + """.stripMargin + + val (initCode, res) = c.init(initPayload(code)) + initCode should be(200) + + val (runCode, runRes) = c.run(runPayload(JsObject())) + runCode should be(200) + + runRes shouldBe defined + runRes should be(Some(JsObject("response" -> List(1, 2, 3, 4).toJson))) + } + } + }