Merge pull request #6922 from brkirch/cumsum-fix
Improve cumsum fix for MPS
This commit is contained in:
commit
aa60fc6660
1 changed file with 7 additions and 4 deletions
|
@@ -169,8 +169,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum
|
||||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||||
if input.device.type == 'mps':
|
if input.device.type == 'mps':
|
||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
|
if output_dtype == torch.int64:
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
|
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||||
|
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -181,9 +183,10 @@ if has_mps():
|
||||||
torch.nn.functional.layer_norm = layer_norm_fix
|
torch.nn.functional.layer_norm = layer_norm_fix
|
||||||
torch.Tensor.numpy = numpy_fix
|
torch.Tensor.numpy = numpy_fix
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
|
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||||
|
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||||
orig_narrow = torch.narrow
|
orig_narrow = torch.narrow
|
||||||
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue