# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import logging
import warnings
import subprocess
from typing import Type, Tuple, List, Dict, Union
import torch
import numpy as np
from sparta.specializer import kernels
from sparta.common.tuning import TunableItemCfg, Tunable
from sparta.testing import test_latency
# Obtain the module logger through the logging registry so that it takes part
# in the logger hierarchy and honors application-level logging configuration.
# Instantiating `logging.Logger` directly would create a detached logger.
_logger = logging.getLogger(__name__)
class OperatorBase(torch.nn.Module):
    '''Base class of sparse operators.

    A sparse operator wraps the corresponding dense module. Until :meth:`build`
    has been called, :meth:`forward` falls back to the dense module (with a
    warning); afterwards it dispatches to :meth:`_sparse_forward`.

    Subclasses set ``__base_class__`` to the dense module type they replace and
    implement :meth:`_sparse_forward`, :meth:`_load_compile_kernel` and
    :meth:`_read_sample_inputs`. This base class only initializes
    ``self._mask`` and ``self._possible_implementations`` to empty values;
    presumably subclasses fill them in before :meth:`build` is called.

    Examples:

        .. code-block:: python

            # Create a dense softmax layer
            dense_softmax = torch.nn.Softmax(dim=-1)

            # Create a mask
            mask = torch.rand((2048, 1024)) > 0.99

            # Create a sparse softmax layer using the dense layer and the mask
            sparse_softmax = sparta.nn.SparseSoftmax(dense_softmax, mask=mask)

            # Tune the sparse softmax layer
            sparta.tune(sparse_softmax, sample_inputs=[torch.rand((2048, 1024))])

    Args:
        raw_module (torch.nn.Module): The corresponding dense operator.

    Raises:
        ValueError: If the type of ``raw_module`` is not exactly ``__base_class__``.
    '''

    # Dense module type accepted by this operator; overridden by subclasses.
    __base_class__: Type[torch.nn.Module] = None

    def __init__(self, raw_module: torch.nn.Module):
        # Exact type match: instances of subclasses of `__base_class__` are rejected.
        if type(raw_module) is not self.__base_class__:
            raise ValueError(f'expected a {self.__base_class__} module')
        super().__init__()
        self._raw_module = raw_module
        self._forward_function = None  # forward of the compiled kernel, set by build()
        self._mask = None
        self.ready = False  # True once build() has completed
        self._possible_implementations = {}  # kernel name -> kernel object
        self._search_space = None  # created lazily by get_search_space()

    def build(self, params: Dict, sample_inputs: List, jit: bool = True):
        '''Build the sparse kernel using the specified implementation and configs.

        Args:
            params (Dict): Building parameters. It should be a valid sample of the search space.
                ``params['_name']`` should be a valid kernel name in `self._possible_implementations`;
                the whole dictionary is passed on to that kernel's ``compile`` method.
            sample_inputs (List): Sample inputs for shape inference.
            jit (bool): Determine whether to build the kernel using JIT mode.

        Raises:
            KeyError: If ``params`` has no ``'_name'`` entry or the name is not a
                key of `self._possible_implementations`.
        '''
        if sample_inputs:
            # The returned (shape, inputs) pair is not used here; the call is
            # kept so subclass implementations still receive the sample inputs.
            self._read_sample_inputs(*sample_inputs)
        forward_kernel = self._possible_implementations[params['_name']]
        compiled_kernel = forward_kernel.compile(params, self._mask, jit)
        self._forward_function = compiled_kernel.forward
        self._load_compile_kernel(forward_kernel)
        self.ready = True

    def forward(self, *args):
        '''Forward function. Calls the corresponding dense operator if not built.'''
        if self.ready:
            return self._sparse_forward(*args)
        else:
            warnings.warn('the sparse module is not compiled, using the dense module to forward')
            # Calls the dense module's `forward` method directly.
            return self._raw_module.forward(*args)

    # NOTE: `abc.abstractmethod` is only enforced at instantiation time for
    # classes whose metaclass derives from `abc.ABCMeta`; here the decorators
    # mark the methods subclasses are expected to override.
    @abc.abstractmethod
    def _sparse_forward(self, *args):
        '''Calls the sparse forward kernel.'''

    @abc.abstractmethod
    def _load_compile_kernel(self, forward_kernel: kernels.KernelBase):
        '''Set PyTorch module parameters according to the dense operator.

        Args:
            forward_kernel (kernels.KernelBase): The forward kernel object
                which provides the sparsify function.
        '''

    @abc.abstractmethod
    def _read_sample_inputs(self, *args) -> Tuple[dict, dict]:
        '''Read shape config and convert sample inputs to test inputs.

        Returns:
            Tuple[dict, dict]: The shape config and the test inputs, in that order.
        '''

    def set_search_space(self, search_space: Union[TunableItemCfg, None] = None):
        '''Input a custom search space to override the default one before tuning.

        Examples:

            .. code-block:: python

                # Create a dense linear layer
                dense_linear = torch.nn.Linear(1024, 2048)

                # Create a mask
                weight_mask = torch.rand((2048, 1024)) > 0.99

                # Create a sparse linear layer using the dense layer and the mask
                sparse_linear = sparta.nn.SparseLinear(dense_linear, weight_mask=weight_mask)

                # Set custom search space
                search_space_cfg = TunableItemCfg('choice', {
                    'openai': {},
                    'sparta': {
                        'BLOCK_SIZE_M_VALUE': TunableItemCfg('choice', [32, 64]),
                        'BLOCK_SIZE_K_VALUE': TunableItemCfg('choice', [32, 64]),
                        'BLOCK_SIZE_N_VALUE': TunableItemCfg('choice', [32, 64]),
                        'THREAD_SIZE_M_VALUE': TunableItemCfg('choice', [4]),
                        'THREAD_SIZE_K_VALUE': TunableItemCfg('choice', [4]),
                        'THREAD_SIZE_N_VALUE': TunableItemCfg('choice', [4]),
                    },
                })
                sparse_linear.set_search_space(search_space_cfg)

                # Tune the sparse linear layer
                sparta.tune(sparse_linear, sample_inputs=[torch.rand((512, 1024))])

        Args:
            search_space (TunableItemCfg): The custom search space. As in the example
                above, a ``'choice'`` item whose keys are implementation names and whose
                values map tunable parameters to their candidate values. If ``None``
                (default), a nested ``'choice'`` item is built from the
                ``get_search_space()`` result of every kernel in
                `self._possible_implementations`.
        '''
        if search_space is None:
            # Default: one nested choice per registered kernel implementation.
            search_space = TunableItemCfg(
                'choice',
                _is_nested=True,
                _value={k: v.get_search_space() for k, v in self._possible_implementations.items()}
            )
        self._search_space = search_space

    def get_search_space(self) -> TunableItemCfg:
        '''Get the search space of the sparse operator.

        The default search space is created on first access if none has been set.

        Returns:
            TunableItemCfg: the search space of the sparse operator.
        '''
        if self._search_space is None:
            self.set_search_space()
        return self._search_space

    def tester(self, params: Dict, sample_inputs: List, jit: bool = False, weight_bk: float = 0.) -> float:
        '''Tester function for tuning. It builds the sparse kernel, runs the forward function and returns the measured time.

        Args:
            params (Dict): Building parameters. It should be a valid sample of the search space.
            sample_inputs (List): Sample inputs for shape inference.
            jit (bool): Determine whether to test the kernel using JIT mode.
            weight_bk (float): The weight of the backward time in the total time. If set to 0,
                the backward time is not counted. Positive values are not supported yet.

        Returns:
            float: The performance (running latency) of the kernel.

        Raises:
            NotImplementedError: If ``weight_bk > 0``.
        '''
        if weight_bk > 0:
            # TODO add backward time
            # Checked up front so that no kernel is compiled or benchmarked
            # for a request that cannot be fulfilled.
            raise NotImplementedError('measuring backward time is not supported yet')
        if jit:
            self.build(params, sample_inputs, jit)
            # how to get the latency of the compiled kernel?
            latency = test_latency(self.forward, sample_inputs, None)
        else:
            shape, inputs = self._read_sample_inputs(*sample_inputs)
            kernel = self._possible_implementations[params['_name']]
            # Shape entries and building parameters are merged into one config.
            latency = kernel.test(dict(shape, **params), mask=self._mask, inputs=inputs)
        return latency