trace

dragon.vm.torch.jit.trace(func=None, example_inputs=None)

Trace a function and return an executable.
Only tensor operations can be traced:
```python
from dragon.vm import torch

def foo(x):
    return x + x

bar = torch.jit.trace(foo, example_inputs=[torch.rand(1)])
print(bar(torch.tensor([1, 2])))
```
The usage above can be simplified with the decorator form:
```python
from dragon.vm import torch

@torch.jit.trace(example_inputs=[torch.rand(1)])
def foo(x):
    return x + x

print(foo(torch.tensor([1, 2])))
```
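The `func=None` default in the signature is what allows one function to be called directly and also used as a decorator factory. The sketch below shows that common Python pattern with a hypothetical `trace_like` function. It is an assumption about how the dispatch works, not Dragon's code, and the actual tracing step is replaced by a plain wrapper built with the standard `functools.wraps`.

```python
import functools
from typing import Any, Callable, Optional, Sequence


def trace_like(func: Optional[Callable] = None,
               example_inputs: Optional[Sequence[Any]] = None):
    """Hypothetical stand-in for a tracer that supports both call styles.

    The real tracing is replaced by a plain wrapper; only the
    dispatch on `func is None` is illustrated here.
    """

    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def traced(*args: Any) -> Any:
            return f(*args)  # a real tracer would run the recorded graph here

        traced.example_inputs = list(example_inputs or [])
        return traced

    if func is None:
        # @trace_like(example_inputs=...) form: return the decorator itself.
        return decorator
    # trace_like(f, example_inputs=...) form: wrap `f` immediately.
    return decorator(func)


def foo(x):
    return x + x


bar = trace_like(foo, example_inputs=[1.0])


@trace_like(example_inputs=[1.0])
def baz(x):
    return x + x


print(bar(2), baz(2))  # 4 4
```

When `func` is omitted, the call returns `decorator`, which Python then applies to the function defined under the `@` line; when `func` is given, the wrapping happens immediately.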
If an nn.Module is provided, its forward method will be traced:

```python
from dragon.vm import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + x

m = torch.jit.trace(MyModule(), example_inputs=[torch.rand(1)])
print(m(torch.tensor([1, 2])))
```
- Parameters:
- func (Union[callable, dragon.vm.torch.nn.Module], required) – The function or module to be traced.
- example_inputs (Sequence[dragon.vm.torch.Tensor], required) – Example inputs that hint the input info.
- Returns:
callable – A callable to execute the traced function.
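As a rough illustration of why only tensor operations are captured, here is a hypothetical toy tracer in plain Python. It is not Dragon's implementation: `Graph`, `Proxy`, and `toy_trace` are invented names, scalars stand in for tensors, and only `+` and `*` are supported. The function runs once on proxy objects that record each operation, and the returned callable replays the recorded operations on new inputs. Any ordinary Python value the function reads, such as `scale` below, is frozen as a constant at trace time.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical toy tracer: illustrates record-then-replay tracing only.
# It is NOT Dragon's implementation; scalars stand in for tensors.

Operand = Tuple[str, Any]  # ("slot", index) or ("const", value)


class Graph:
    """Records each operation as (op_name, lhs, rhs, out_slot)."""

    def __init__(self, num_inputs: int):
        self.num_slots = num_inputs  # slots 0..num_inputs-1 hold the inputs
        self.ops: List[Tuple[str, Operand, Operand, int]] = []

    def operand(self, value: Any) -> Operand:
        # Proxies refer to a runtime slot; any other value is frozen as a constant.
        if isinstance(value, Proxy):
            return ("slot", value.slot)
        return ("const", value)

    def record(self, op: str, lhs: Any, rhs: Any) -> "Proxy":
        out = self.num_slots
        self.num_slots += 1
        self.ops.append((op, self.operand(lhs), self.operand(rhs), out))
        return Proxy(self, out)


class Proxy:
    """Stands in for a tensor during tracing; supports only + and *."""

    def __init__(self, graph: Graph, slot: int):
        self.graph = graph
        self.slot = slot

    def __add__(self, other: Any) -> "Proxy":
        return self.graph.record("add", self, other)

    def __mul__(self, other: Any) -> "Proxy":
        return self.graph.record("mul", self, other)


OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def toy_trace(func: Callable, num_inputs: int) -> Callable:
    """Run `func` once on proxies, then return a callable that replays its ops."""
    graph = Graph(num_inputs)
    result = func(*[Proxy(graph, i) for i in range(num_inputs)])
    out_slot = result.slot

    def run(*inputs: Any) -> Any:
        slots: List[Any] = list(inputs) + [None] * (graph.num_slots - num_inputs)

        def read(operand: Operand) -> Any:
            kind, value = operand
            return slots[value] if kind == "slot" else value

        for op, lhs, rhs, out in graph.ops:
            slots[out] = OPS[op](read(lhs), read(rhs))
        return slots[out_slot]

    return run


def foo(x):
    return x + x


bar = toy_trace(foo, num_inputs=1)
print(bar(3))  # 6

scale = 2


def baz(x):
    return x * scale  # `scale` is a plain Python value, not a proxy


traced_baz = toy_trace(baz, num_inputs=1)
scale = 10
print(traced_baz(3))  # 6, because `scale` was 2 at trace time
```

In this toy, rebinding `scale` after tracing has no effect on `traced_baz`, because only operations on proxies were recorded and everything else was baked in as a constant.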