Source code for sparta.specializer.operators.sparse_linear

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Type

import torch

from sparta.specializer import kernels
from sparta.specializer.operators.operator_base import OperatorBase


class SparseLinear(OperatorBase):
    r'''Sparse linear operator.

    Exactly one of ``input_mask``, ``weight_mask`` and ``output_mask`` must be
    provided; the one that is set selects the kernel mode.

    Examples:

        .. code-block:: python

            # Create a dense linear layer
            dense_linear = torch.nn.Linear(1024, 2048)

            # Create a mask
            weight_mask = torch.rand((2048, 1024)) > 0.99

            # Create a sparse linear layer using the dense layer and the mask
            sparse_linear = sparta.nn.SparseLinear(dense_linear, weight_mask=weight_mask)

            # Tune the sparse linear layer
            sparta.tune(sparse_linear, sample_inputs=[torch.rand((512, 1024))])

    Args:
        raw_module (torch.nn.Linear): The corresponding dense linear operator.
        input_mask (torch.Tensor): The 2-D input mask tensor with shape (\*, in_features).
            The kernel mode will be "sparse x dense => dense" if the input mask is set.
        weight_mask (torch.Tensor): The weight mask tensor with shape (out_features, in_features).
            The kernel mode will be "dense x sparse => dense" if the weight mask is set.
        output_mask (torch.Tensor): The 2-D output mask tensor with shape (\*, out_features).
            The kernel mode will be "dense x dense => sparse" if the output mask is set.

    Raises:
        ValueError: If no mask or more than one mask is given, or if the given
            mask does not match the shape of the weight of ``raw_module``.
    '''

    __base_class__: Type[torch.nn.Module] = torch.nn.Linear

    def __init__(
        self,
        raw_module: torch.nn.Linear,
        input_mask: Optional[torch.Tensor] = None,
        weight_mask: Optional[torch.Tensor] = None,
        output_mask: Optional[torch.Tensor] = None
    ):
        super().__init__(raw_module)
        # torch.nn.Linear stores its weight as (out_features, in_features).
        N, K = raw_module.weight.shape
        # M stays None unless it can be read from an input / output mask;
        # _read_sample_inputs() overwrites it with the sample input's row count.
        M = None
        masks = (input_mask, weight_mask, output_mask)
        if sum(mask is not None for mask in masks) > 1:
            raise ValueError('linear operators with multiple sparse masks are not supported')
        if input_mask is not None:
            self._stype = 'sdd'
            self._compressed = False
            input_mask = input_mask.cpu().detach().numpy()
            # Check the rank first so that a 0-D / 1-D mask is reported with
            # the ValueError below instead of a bare IndexError.
            if input_mask.ndim == 2 and input_mask.shape[1] == K:
                M = input_mask.shape[0]
                self._mask = {'A': input_mask}
            else:
                raise ValueError(f'expected input mask shape (?, {K}), got {input_mask.shape}')
        elif weight_mask is not None:
            self._stype = 'dsd'
            self._compressed = True
            weight_mask = weight_mask.cpu().detach().numpy()
            if weight_mask.shape == (N, K):
                self._mask = {'B': weight_mask}
            else:
                raise ValueError(f'expected weight mask shape ({N}, {K}), got {weight_mask.shape}')
        elif output_mask is not None:
            self._stype = 'dds'
            self._compressed = False
            output_mask = output_mask.cpu().detach().numpy()
            if output_mask.ndim == 2 and output_mask.shape[1] == N:
                M = output_mask.shape[0]
                self._mask = {'C': output_mask}
            else:
                raise ValueError(f'expected output mask shape (?, {N}), got {output_mask.shape}')
        else:
            raise ValueError('expected a sparse mask on input / weight / output')
        self._shape = {'GLOBAL_M_VALUE': M, 'GLOBAL_N_VALUE': N, 'GLOBAL_K_VALUE': K}
        self._biased = raw_module.bias is not None
        self._transpose = True
        # Any integer weight dtype maps to 'int'; everything else to 'float'.
        self._dtype = 'int' if 'int' in str(raw_module.weight.dtype) else 'float'
        self._possible_implementations = {
            'sparta': kernels.SparTATemplateSparseMatMulKernel(
                self._stype, self._dtype, self._biased, self._transpose, self._compressed
            ),
            'openai': kernels.OpenAITemplateSparseMatMulKernel(
                self._stype, self._dtype, self._biased, self._transpose, self._compressed
            ),
        }

    def _load_compile_kernel(self, forward_kernel: kernels.MatMulKernelBase):
        '''Set PyTorch module parameters: weight and bias (if exists).

        Both parameters are created with ``requires_grad=False`` and live on
        the device of the raw module's weight.

        Args:
            forward_kernel (kernels.MatMulKernelBase): A matmul kernel object which provides
                the function to sparsify the weight tensor in "dense x sparse => dense" mode.
        '''
        device = self._raw_module.weight.device
        if self._biased:
            self.bias = torch.nn.Parameter(self._raw_module.bias.detach(), requires_grad=False)
        else:
            self.bias = None
        weight = self._raw_module.weight.cpu().detach().numpy().astype(f'{self._dtype}32')
        if self._stype == 'dsd':
            # Replace the dense weight with the 'val' entry of the kernel's
            # sparse representation of input 'B'.
            B_tensor = forward_kernel.get_input('B')
            B_tensor.set_data(weight)
            weight = B_tensor.sparse()['val']
        # Move the tensor to the target device BEFORE wrapping it. Calling
        # .to(device) on a Parameter returns a plain Tensor whenever a copy is
        # made, which would leave self.weight unregistered as a parameter.
        self.weight = torch.nn.Parameter(
            torch.from_numpy(weight).to(device),
            requires_grad=False,
        )

    def _sparse_forward(self, A: torch.Tensor):
        '''Calls the sparse forward kernel.

        Args:
            A (torch.Tensor): The input tensor.

        Returns:
            The result of ``self._forward_function`` applied to the input, the
            weight and, if the raw module has a bias, the bias.
        '''
        if self._biased:
            return self._forward_function(A, self.weight, self.bias)
        else:
            return self._forward_function(A, self.weight)

    def _read_sample_inputs(self, A: torch.Tensor):
        '''Read shape config and convert sample inputs to test inputs.

        The captured shape config will be passed to implements (kernels).

        Args:
            A (torch.Tensor): The sample input tensor. It must be 2-D and its
                second dimension must equal the in_features of the raw module.

        Returns:
            Tuple: The first value is the shape dict, the second value is the test input dict.
        '''
        M, K = A.shape
        expected_K = self._shape['GLOBAL_K_VALUE']
        assert expected_K == K, f'expected sample input shape (?, {expected_K}), got {tuple(A.shape)}'
        # The row count of the sample input overrides any M taken from a mask.
        self._shape['GLOBAL_M_VALUE'] = M
        for kern in self._possible_implementations.values():
            kern.set_parameters(self._shape)
        np_dtype = f'{self._dtype}32'
        inputs = {
            'A': A.cpu().detach().numpy().astype(np_dtype),
            'B': self._raw_module.weight.cpu().detach().numpy().astype(np_dtype),
        }
        if self._biased:
            inputs['bias'] = self._raw_module.bias.cpu().detach().numpy().astype(np_dtype)
        return self._shape, inputs