tasptz/pytorch-stochastic-depth

Stochastic Depth with PyTorch Hooks

A simple, hook-based implementation of Deep Networks with Stochastic Depth (Huang et al., 2016) for torchvision ResNets.

Example

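```python
import torch
import torchvision.models as models
from stochdepth import uniform, resnet_linear

# Randomly initialised ResNet-152 (on torchvision < 0.13, use pretrained=False instead).
resnet = models.resnet152(weights=None)
# Stochastic depth drops blocks during training, so put the model in training mode.
resnet.train()

# Uniform schedule: the same probability p for every residual block.
hooks = uniform(resnet, p=0.2)

# Dummy batch of 8 RGB images at 224x224.
x = torch.zeros((8, 3, 224, 224), dtype=torch.float32)
y = resnet(x)

# Remove the hooks to restore the plain ResNet.
for h in hooks:
    h.remove()

# Linear schedule: the probability depends linearly on block depth.
hooks = resnet_linear(resnet)
y = resnet(x)

# Remove the hooks again.
for h in hooks:
    h.remove()
```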
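The name `resnet_linear` suggests the paper's linear decay rule. Under that rule, block l of L keeps its residual branch with probability p_l = 1 - (l / L)(1 - p_L). The README does not confirm this, so treat it as an assumption. The small helper below computes that schedule. Its name, `linear_decay_survival`, is hypothetical, and the default `p_last = 0.5` is the value used in the paper.

```python
def linear_decay_survival(num_blocks: int, p_last: float = 0.5) -> list[float]:
    """Hypothetical helper: survival probabilities p_1..p_L for L residual blocks.

    Implements the linear decay rule p_l = 1 - (l / L) * (1 - p_L), so shallow
    blocks are almost always kept and the last block is kept with probability p_L.
    """
    if num_blocks < 1:
        raise ValueError("num_blocks must be at least 1")
    if not 0.0 <= p_last <= 1.0:
        raise ValueError("p_last must be in [0, 1]")
    return [1.0 - (l / num_blocks) * (1.0 - p_last) for l in range(1, num_blocks + 1)]


# ResNet-152 stacks 3 + 8 + 36 + 3 = 50 bottleneck blocks.
probs = linear_decay_survival(3 + 8 + 36 + 3)
print(f"first block: {probs[0]:.2f}, last block: {probs[-1]:.2f}")
# first block: 0.99, last block: 0.50
```

With p_L = 0.5, the probabilities sum to (3L - 1) / 4, which is 37.25 for L = 50. So on average roughly three quarters of the blocks are active in each training step.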