TorchScript

library(torch)

TorchScript is a statically typed subset of Python that can be interpreted by LibTorch without any Python dependency. The torch R package provides interfaces to create, serialize, load and execute TorchScript programs.

Advantages of using TorchScript include the ability to serialize programs and run them without any R or Python dependency (for example, from LibTorch in a C++ application), and interoperability with TorchScript programs created in PyTorch.

Creating TorchScript programs

Tracing

TorchScript programs can be created from R using tracing. When tracing, code is automatically converted into this subset of Python by recording only the operations performed on tensors; the surrounding R code is executed normally but discarded.

Currently tracing is the only supported way to create TorchScript programs from R code.

For example, let’s use the jit_trace function to create a TorchScript program. We pass a regular R function and example inputs.

fn <- function(x) {
  torch_relu(x)
}

traced_fn <- jit_trace(fn, torch_tensor(c(-1, 0, 1)))

The jit_trace function executed the R function with the example input and recorded all torch operations that occurred during execution to create a graph. Graph is the name of the intermediate representation of TorchScript programs, and it can be inspected with:

traced_fn$graph
#> graph(%0 : Float(3, strides=[1], requires_grad=0, device=cpu)):
#>   %1 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%0)
#>   return (%1)

The traced function can now be invoked as a regular R function:

traced_fn(torch_randn(3))
#> torch_tensor
#>  0.0000
#>  0.2149
#>  2.2419
#> [ CPUFloatType{3} ]

It’s also possible to trace nn_modules() defined in R, for example:

module <- nn_module(
  initialize = function() {
    self$linear1 <- nn_linear(10, 10)
    self$linear2 <- nn_linear(10, 1)
  },
  forward = function(x) {
    x %>% 
      self$linear1() %>% 
      nnf_relu() %>% 
      self$linear2()
  }
)
traced_module <- jit_trace(module(), torch_randn(10, 10))

When using jit_trace with a nn_module only the forward method is traced. You can use the jit_trace_module function to pass example inputs to other methods. Traced modules look like normal nn_modules(), and can be called the same way:

traced_module(torch_randn(3, 10))
#> torch_tensor
#>  0.1851
#>  0.3527
#>  0.2905
#> [ CPUFloatType{3,1} ][ grad_fn = <AddBackward0> ]
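The jit_trace_module interface mentioned above can be sketched as follows. Per the torch for R documentation, it takes named arguments mapping method names to lists of example inputs; tracing only forward, as here, is equivalent to using jit_trace directly:

```r
library(torch)

module <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
  },
  forward = function(x) {
    self$linear(x)
  }
)

# each named argument is a method name mapped to a list of example inputs;
# here only `forward` is traced
traced <- jit_trace_module(module(), forward = list(torch_randn(10, 10)))
traced(torch_randn(2, 10))
```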

Limitations of tracing

  1. Tracing will not record any control flow like if-statements or loops. When this control flow is constant across your module, this is fine and it often inlines the control flow decisions. But sometimes the control flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence. For example:
# fn sums over each slice along the first dimension of a tensor
fn <- function(x) {
  x %>% 
    torch_unbind(dim = 1) %>% 
    lapply(function(x) x$sum()) %>% 
    torch_stack(dim = 1)
}
# we trace fn using an example tensor of shape (10, 5, 5)
traced_fn <- jit_trace(fn, torch_randn(10, 5, 5))
# calling it with a tensor of a different shape raises an error
traced_fn(torch_randn(11, 5, 5))
#> Error in cpp_call_traced_fn(ptr, inputs): The following operation failed in the TorchScript interpreter.
#> Traceback of TorchScript (most recent call last):
#> RuntimeError: Expected 10 elements in a list but found 11
  2. In the returned ScriptModule, operations that behave differently in training and eval modes will always behave as they did in the mode the module was in during tracing, no matter which mode the ScriptModule is later set to. For example:
traced_dropout <- jit_trace(nn_dropout(), torch_ones(5,5))
traced_dropout(torch_ones(3,3))
#> torch_tensor
#>  0  0  0
#>  2  2  0
#>  2  2  2
#> [ CPUFloatType{3,3} ]
traced_dropout$eval()
# even after setting to eval mode, dropout is applied
traced_dropout(torch_ones(3,3))
#> torch_tensor
#>  2  0  2
#>  0  0  2
#>  0  2  2
#> [ CPUFloatType{3,3} ]
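One way to work around this (a sketch, not a general fix) is to put the module into the desired mode before tracing, so that the behavior recorded in the trace is the one you want:

```r
library(torch)

dropout <- nn_dropout()
dropout$eval()  # switch to eval mode *before* tracing

# the trace now records dropout's eval-mode (identity) behavior
traced_eval <- jit_trace(dropout, torch_ones(5, 5))
traced_eval(torch_ones(3, 3))  # dropout is not applied
```

If both behaviors are needed, a separate trace per mode is required.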
  3. Traced programs can only take tensors and (possibly nested) lists of tensors as inputs, and can only return tensors and lists of tensors. For example:
fn <- function(x, y) {
  x + y
}
jit_trace(fn, torch_tensor(1), 1)
#> Error in cpp_trace_function(tr_fn, list(...), .compilation_unit, strict, : Only tensors or (possibly nested) dict or tuples of tensors can be inputs to traced functions. Got float
#> Exception raised from addInput at ../torch/csrc/jit/frontend/tracer.cpp:408 (C++ stack trace omitted)
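A common workaround, sketched below, is to pass scalar arguments as tensors so that tracing can record them as inputs:

```r
library(torch)

fn <- function(x, y) {
  x + y
}

# pass the scalar as a tensor so tracing can record it as an input
traced_fn <- jit_trace(fn, torch_tensor(1), torch_tensor(1))
traced_fn(torch_tensor(2), torch_tensor(3))
```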

Compiling TorchScript

It’s also possible to create TorchScript programs by compiling TorchScript code directly. TorchScript code looks a lot like standard Python code. For example:

tr <- jit_compile("
def fn (x: Tensor):
  return torch.relu(x)

")
tr$fn(torch_tensor(c(-1, 0, 1)))
#> torch_tensor
#>  0
#>  0
#>  1
#> [ CPUFloatType{3} ]
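Because compiled TorchScript preserves control flow, the per-dimension sum from the tracing limitations above can be written so that it works for any first-dimension size. A sketch (the sum_dims function name is illustrative; the empty list literal relies on TorchScript's default List[Tensor] type):

```r
library(torch)

tr <- jit_compile("
def sum_dims(x: Tensor):
  out = []
  for t in torch.unbind(x, 0):
    out.append(t.sum())
  return torch.stack(out)
")

# the loop length now depends on the input, so any first-dimension size works
tr$sum_dims(torch_randn(10, 5, 5))
tr$sum_dims(torch_randn(11, 5, 5))
```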

Serializing and loading

TorchScript programs can be serialized using the jit_save function and loaded back from disk with jit_load.

For example:

fn <- function(x) {
  torch_relu(x)
}
tr_fn <- jit_trace(fn, torch_tensor(1))
jit_save(tr_fn, "path.pt")
loaded <- jit_load("path.pt")

Loaded programs can be executed as usual:

loaded(torch_tensor(c(-1, 0, 1)))
#> torch_tensor
#>  0
#>  0
#>  1
#> [ CPUFloatType{3} ]

Note You can load TorchScript programs that were created with libraries other than torch for R. E.g., a TorchScript program can be created in PyTorch with torch.jit.trace or torch.jit.script and then run from R.

R objects are automatically converted to their TorchScript counterparts following the Types table in this document. However, sometimes it’s necessary to add type annotations with jit_tuple() and jit_scalar() to disambiguate the conversion.
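For example, a compiled function that takes a float scalar needs jit_scalar() to mark an R numeric as a single scalar rather than a List[float]. A sketch (the add_one function is illustrative):

```r
library(torch)

tr <- jit_compile("
def add_one(x: float):
  return x + 1.0
")

# a bare R numeric would convert to a List[float]; jit_scalar() marks it
# as a single float scalar instead
tr$add_one(jit_scalar(2))
```

jit_tuple() plays the analogous role for R lists that should convert to a Tuple rather than a List.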

Types

The following table lists all TorchScript types and how to convert them to and from R.

| TorchScript Type | R Description |
|---|---|
| Tensor | A torch_tensor with any shape, dtype or backend. |
| Tuple[T0, T1, ..., TN] | A list() containing subtypes T0, T1, etc., wrapped with jit_tuple(). |
| bool | A scalar logical value created using jit_scalar(). |
| int | A scalar integer value created using jit_scalar(). |
| float | A scalar floating-point value created using jit_scalar(). |
| str | A string (i.e. a character vector of length 1) wrapped in jit_scalar(). |
| List[T] | An R list in which all elements have type T, or an atomic vector (numeric, logical, etc.). |
| Optional[T] | Not yet supported. |
| Dict[str, V] | A named list with values of type V. Only str keys are currently supported. |
| T | Not yet supported. |
| E | Not yet supported. |
| NamedTuple[T0, T1, ...] | A named list containing subtypes T0, T1, etc., wrapped in jit_tuple(). |