Net

class dragon.vm.caffe.Net(*args)

An abstraction of caffe.Net.

This class accepts a proto-text file and, optionally, a file of serialized model weights. You can also pass a phase flag to indicate whether gradients are computed:

train_net = Net('train.prototxt', 'TRAIN')
test_net = Net('test.prototxt', 'my.caffemodel', 'TEST')

__init__

Net.__init__(*args)

Create a Net.

Parameters:
  • network_file (str) – The path to the net.prototxt file.
  • weights (str, optional) – The path to the weights file.
  • phase ({'TRAIN', 'TEST'}, optional) – The phase to run the net in.
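
The two constructor calls shown earlier differ in argument count: the two-argument form passes a phase and no weights, while the three-argument form passes weights before the phase. The sketch below shows one way such positional arguments could be told apart. `parse_net_args` is a hypothetical helper written for illustration, not part of dragon.vm.caffe, and the rejection of other argument counts is an assumption.

```python
def parse_net_args(*args):
    """Hypothetical helper: split Net(*args) into (network_file, weights, phase)."""
    valid_phases = ('TRAIN', 'TEST')
    if len(args) == 2:
        # Two-argument form, as in Net('train.prototxt', 'TRAIN'): no weights.
        network_file, phase = args
        weights = None
    elif len(args) == 3:
        # Three-argument form, as in Net('test.prototxt', 'my.caffemodel', 'TEST').
        network_file, weights, phase = args
    else:
        # Assumption: any other argument count is treated as an error here.
        raise ValueError('Expected 2 or 3 arguments, got %d.' % len(args))
    if phase not in valid_phases:
        raise ValueError('Unknown phase: %r' % (phase,))
    return network_file, weights, phase


train_args = parse_net_args('train.prototxt', 'TRAIN')
test_args = parse_net_args('test.prototxt', 'my.caffemodel', 'TEST')
```

Here `train_args` carries `None` in the weights slot, which makes the missing weights file explicit to whatever code consumes the tuple.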

Properties

blobs

Net.blobs

Return the blob dict.

Blobs stored in the dict will be:

for blob_name, blob in net.blobs.items():
    print(blob_name)
    print(blob.data)  # DataTensor
    print(blob.diff)  # GradTensor
Returns:
Dict – The blob dict.

inputs

Net.inputs

Return the input blob names.

Returns:
Sequence[str] – The input names.

outputs

Net.outputs

Return the output blob names.

Returns:
Sequence[str] – The output names.

params

Net.params

Return the parameter dict.

Parameters stored in the dict will be:

for layer_name, blobs in net.params.items():
    print(layer_name)
    for blob in blobs:
        print('  *', blob.data)  # DataTensor
        print('  *', blob.diff)  # GradTensor
Returns:
Dict – The parameter dict.
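
As the loop above suggests, the parameter dict maps each layer name to a sequence of blobs, each exposing `data` and `diff`. The sketch below walks that shape using plain Python objects, so it runs without the library. The layer names, the values, and the `count_values` helper are all made up for illustration, and plain lists stand in for the tensors.

```python
from types import SimpleNamespace

# Stand-in for net.params: layer name -> list of blobs.
# Plain lists play the role of the data and diff tensors (an assumption).
params = {
    'conv1': [
        SimpleNamespace(data=[0.1, 0.2], diff=[0.0, 0.0]),  # e.g. weights
        SimpleNamespace(data=[0.5], diff=[0.0]),            # e.g. bias
    ],
    'fc1': [
        SimpleNamespace(data=[1.0, 2.0, 3.0], diff=[0.0, 0.0, 0.0]),
    ],
}


def count_values(params):
    """Count the stored values per layer by walking layer -> blobs."""
    return {
        layer_name: sum(len(blob.data) for blob in blobs)
        for layer_name, blobs in params.items()
    }


counts = count_values(params)
```

The same two-level traversal applies to any per-layer bookkeeping, such as collecting names of layers that hold more than one blob.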

Methods

backward

Net.backward(**diffs)

Backward pass.

Parameters:
  • diffs (dict, optional) – The diffs to feed.

copy_from

classmethod Net.copy_from(weights)

Copy the weights from the binary proto file.

Parameters:
  • weights (str) – The path of the weights file.

forward

Net.forward(**inputs)

Forward pass.

Parameters:
  • inputs (dict, optional) – The blobs to feed.
Returns:
callable – The callable to return outputs.
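
Because the signature is `forward(**inputs)`, blobs are fed as keyword arguments keyed by blob name. The sketch below illustrates only that calling convention with a toy class. `ToyNet` is a hypothetical stand-in, not dragon.vm.caffe.Net; its check against the input names and its return value are assumptions, and it does not model the callable that the real method returns.

```python
class ToyNet:
    """Hypothetical stand-in that only records what was fed."""

    def __init__(self, inputs):
        self.inputs = list(inputs)  # input blob names
        self._fed = {}

    def forward(self, **inputs):
        # **inputs gathers the keyword arguments into a dict: blob name -> value.
        unknown = sorted(set(inputs) - set(self.inputs))
        if unknown:
            # Assumption: feeding a name that is not an input is an error.
            raise KeyError('Unknown input blobs: ' + ', '.join(unknown))
        self._fed.update(inputs)
        return dict(self._fed)  # a copy of everything fed so far


net = ToyNet(inputs=['data', 'label'])

# Feed by keyword ...
first = net.forward(data=[1.0, 2.0], label=[0])

# ... or unpack an existing dict of blob name -> value.
second = net.forward(**{'data': [3.0]})
```

Unpacking a dict with `**` is convenient when the blob names are only known at run time, for example when they are read from the `inputs` property.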

forward_backward

Net.forward_backward()

Forward pass followed by a backward pass.

This function is compiled into a computation graph the first time it is executed, and the inputs are fed implicitly.

save

Net.save(filename)

Save the parameters into a binary file.

Parameters:
  • filename (str) – The path to the model file.