# -*- coding: utf-8 -*-
|
|
import torch
|
|
from torch import nn
|
|
from torch.autograd import Variable
|
|
from torch.nn.parameter import Parameter
|
|
|
|
# Legacy tensor-type classes used for isinstance() checks when deciding
# whether a value should be cast.
# Single precision (fp32), CPU and CUDA variants.
FLOAT_TYPES = (
    torch.FloatTensor,
    torch.cuda.FloatTensor,
)
# Half precision (fp16), CPU and CUDA variants.
HALF_TYPES = (
    torch.HalfTensor,
    torch.cuda.HalfTensor,
)
|
|
|
|
|
|
def conversion_helper(val, conversion):
    """Apply `conversion` to `val`, recursing into nested tuples and lists.

    Every leaf (anything that is not a tuple or list) is passed to
    `conversion`, and the container structure is rebuilt around the converted
    leaves. Lists (including list subclasses) come back as plain lists.
    Tuples come back as plain tuples, except namedtuples, which keep their
    namedtuple type.

    Args:
        val: A leaf value, or an arbitrarily nested tuple/list of leaves.
        conversion: Callable applied to every leaf.

    Returns:
        The converted leaf, or a new container of the same shape holding the
        converted leaves.
    """
    if not isinstance(val, (tuple, list)):
        return conversion(val)

    rtn = [conversion_helper(v, conversion) for v in val]

    if isinstance(val, tuple):
        # namedtuples provide `_make(iterable)`; rebuilding them with tuple()
        # would silently discard the field names callers access by attribute.
        if hasattr(type(val), '_make'):
            return type(val)._make(rtn)
        return tuple(rtn)
    return rtn
|
|
|
|
|
|
def fp32_to_fp16(val):
    """Cast every fp32 tensor found in `val` to fp16.

    `val` may be a single value or a nested tuple/list; walking the structure
    is delegated to `conversion_helper`. A leaf is cast with `.half()` only
    when it matches one of `FLOAT_TYPES`; all other leaves are returned as-is.
    """
    def _to_half(leaf):
        # For Parameter/Variable leaves the legacy-type check is applied to
        # `.data`; the cast itself is still applied to the leaf.
        probe = leaf.data if isinstance(leaf, (Parameter, Variable)) else leaf
        return leaf.half() if isinstance(probe, FLOAT_TYPES) else leaf

    return conversion_helper(val, _to_half)
|
|
|
|
|
|
def fp16_to_fp32(val):
    """Cast every fp16 tensor found in `val` back to fp32.

    `val` may be a single value or a nested tuple/list; walking the structure
    is delegated to `conversion_helper`. A leaf is cast with `.float()` only
    when it matches one of `HALF_TYPES`; all other leaves are returned as-is.
    """
    def _to_float(leaf):
        # For Parameter/Variable leaves the legacy-type check is applied to
        # `.data`; the cast itself is still applied to the leaf.
        probe = leaf.data if isinstance(leaf, (Parameter, Variable)) else leaf
        return leaf.float() if isinstance(probe, HALF_TYPES) else leaf

    return conversion_helper(val, _to_float)
|
|
|
|
|
|
class FP16Module(nn.Module):
    """Run a wrapped module in fp16 while exposing fp32 inputs and outputs.

    The wrapped module is converted with `.half()` at construction. In
    `forward`, positional inputs are cast fp32 -> fp16 before the call and the
    (possibly nested tuple/list) outputs are cast fp16 -> fp32 afterwards.
    `state_dict` / `load_state_dict` delegate directly to the wrapped module,
    so checkpoint keys carry no extra ``module.`` prefix from this wrapper.
    """

    def __init__(self, module):
        super(FP16Module, self).__init__()
        # Register the fp16-converted module under the name 'module'.
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        """Cast positional inputs to fp16, run the module, cast outputs to fp32."""
        # NOTE(review): only positional inputs are cast to fp16; keyword
        # arguments reach the wrapped module unchanged, so fp32 tensors passed
        # by keyword stay fp32. Confirm this is intended.
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped module's state dict (no wrapper prefix added)."""
        # Forward by keyword: recent PyTorch deprecates positional arguments
        # to `nn.Module.state_dict` and emits a warning when they are used.
        return self.module.state_dict(destination=destination, prefix=prefix,
                                      keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load `state_dict` into the wrapped module.

        Returns:
            Whatever the wrapped module's `load_state_dict` returns (in
            PyTorch, a named tuple of `missing_keys` / `unexpected_keys`).
        """
        return self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        """Delegate to the wrapped module's `get_param`.

        `get_param` is not part of `nn.Module`; presumably the wrapped model
        class defines it — confirm against the callers.
        """
        return self.module.get_param(item)
|