You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

54 lines
1.6 KiB

# -*- coding: utf-8 -*-
import torch
def exists(val):
return val is not None
def is_empty(t):
return t.nelement() == 0
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""
Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def init_method_normal(std=0.02):
"""Init method based on normal distribution.
This is only used for embeddings. The transformer has its
own initializer.
"""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_