trace

dragon.vm.torch.jit.trace(
  func=None,
  example_inputs=None
)

Trace a function and return an executable.

Only tensor operations can be traced:

from dragon.vm import torch

def foo(x):
    return x + x  # only this tensor op is recorded

bar = torch.jit.trace(foo, example_inputs=[torch.rand(1)])
print(bar(torch.tensor([1, 2])))
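Why only tensor operations survive can be pictured with a toy tracer in plain Python. This is a sketch of the general idea, not Dragon's implementation: `Proxy` and `record_trace` are hypothetical names, and the proxy supports only `+` between two proxies. The function runs once on placeholder inputs, only the operations on those placeholders are recorded, and in this toy model the plain Python side effect in `foo` is not repeated when the recorded graph is replayed.

```python
class Proxy:
    """Hypothetical stand-in for a tensor that records ops into a graph."""

    def __init__(self, graph, name):
        self.graph = graph
        self.name = name

    def __add__(self, other):
        # Only Proxy + Proxy is supported in this sketch.
        out = Proxy(self.graph, f"t{len(self.graph)}")
        self.graph.append(("add", self.name, other.name, out.name))
        return out


def record_trace(func, num_inputs):
    """Run func once on proxies, then return a function that replays the graph."""
    graph = []
    proxies = [Proxy(graph, f"x{i}") for i in range(num_inputs)]
    output = func(*proxies)

    def executable(*values):
        env = {f"x{i}": v for i, v in enumerate(values)}
        for op, lhs, rhs, out in graph:
            if op == "add":
                env[out] = env[lhs] + env[rhs]
        return env[output.name]

    executable.graph = graph
    return executable


calls = []


def foo(x):
    calls.append("python side effect")  # plain Python: not recorded
    return x + x  # tensor-like op: recorded as ("add", "x0", "x0", "t0")


bar = record_trace(foo, num_inputs=1)
print(bar(3), bar(10))  # 6 20
print(len(calls))       # 1: foo ran only during tracing
print(bar.graph)        # [('add', 'x0', 'x0', 't0')]
```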

The above usage can be simplified with the decorator form:

@torch.jit.trace(example_inputs=[torch.rand(1)])
def foo(x):
    return x + x

print(foo(torch.tensor([1, 2])))
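The decorator form works because `func` defaults to `None`: when it is omitted, the call returns a decorator that later receives the function. The sketch below shows that common Python pattern with a hypothetical `trace_sketch`; it does not record anything, and it is an assumption that Dragon dispatches in exactly this way.

```python
import functools


def trace_sketch(func=None, example_inputs=None):
    """Hypothetical sketch of the func=None / decorator-factory pattern."""
    if func is None:
        # Called as @trace_sketch(example_inputs=...): return a decorator
        # that will receive the function as its only positional argument.
        return functools.partial(trace_sketch, example_inputs=example_inputs)

    # A real tracer would record func on example_inputs here; this sketch
    # only calls it once to show where the example inputs are consumed.
    if example_inputs is not None:
        func(*example_inputs)

    @functools.wraps(func)
    def executable(*args):
        return func(*args)

    return executable


# Direct call form.
bar = trace_sketch(lambda x: x + x, example_inputs=[1.0])


# Decorator form: func is None at first, so a decorator is returned.
@trace_sketch(example_inputs=[1.0])
def foo(x):
    return x + x


print(bar(2), foo(2), foo.__name__)  # 4 4 foo
```

`functools.wraps` keeps the decorated function's name and docstring on the returned callable.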

If an nn.Module is provided, its forward method is traced:

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + x

m = torch.jit.trace(MyModule(), example_inputs=[torch.rand(1)])
print(m(torch.tensor([1, 2])))
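One way to decide what to record when given a module is duck typing: use a callable `forward` attribute if present, otherwise the object itself. This is a swapped-in technique for illustration only; Dragon may instead check `isinstance(func, nn.Module)`. `resolve_target` and the plain-Python `MyModule` below are hypothetical and do not import Dragon.

```python
def resolve_target(obj):
    """Hypothetical helper: choose which callable a tracer should record."""
    forward = getattr(obj, "forward", None)
    if callable(forward):
        return forward  # module-like object: use its bound forward method
    if callable(obj):
        return obj  # plain function
    raise TypeError(f"cannot trace object of type {type(obj).__name__}")


class MyModule:
    """Plain-Python stand-in for torch.nn.Module (no Dragon import)."""

    def forward(self, x):
        return x + x


def double(x):
    return 2 * x


print(resolve_target(MyModule())(3))  # 6
print(resolve_target(double)(3))      # 6
```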
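
Parameters:
  func (callable or torch.nn.Module) – The function or module to trace. If omitted, a decorator is returned.
  example_inputs (sequence of torch.Tensor) – The example inputs used to trace func.
Returns:
  callable – A callable to execute the traced function.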