Use narrow instead of dynamic_slice

This commit is contained in:
brkirch 2023-01-05 04:37:17 -05:00
parent 3bfe2bb549
commit b119815333

View file

@ -5,6 +5,7 @@
# credit: # credit:
# Amin Rezaei (original author) # Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of: # implementation of:
# Self-attention Does Not Need O(n2) Memory": # Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2 # https://arxiv.org/abs/2112.05682v2
@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
import math import math
from typing import Optional, NamedTuple, Protocol, List from typing import Optional, NamedTuple, Protocol, List
def dynamic_slice( def narrow_trunc(
x: Tensor, input: Tensor,
starts: List[int], dim: int,
sizes: List[int], start: int,
length: int
) -> Tensor: ) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
return x[slicing]
class AttnChunk(NamedTuple): class AttnChunk(NamedTuple):
exp_values: Tensor exp_values: Tensor
@ -76,15 +77,17 @@ def _query_chunk_attention(
_, _, v_channels_per_head = value.shape _, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk: def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = dynamic_slice( key_chunk = narrow_trunc(
key, key,
(0, chunk_idx, 0), 1,
(batch_x_heads, kv_chunk_size, k_channels_per_head) chunk_idx,
kv_chunk_size
) )
value_chunk = dynamic_slice( value_chunk = narrow_trunc(
value, value,
(0, chunk_idx, 0), 1,
(batch_x_heads, kv_chunk_size, v_channels_per_head) chunk_idx,
kv_chunk_size
) )
return summarize_chunk(query, key_chunk, value_chunk) return summarize_chunk(query, key_chunk, value_chunk)
@ -161,10 +164,11 @@ def efficient_dot_product_attention(
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor: def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice( return narrow_trunc(
query, query,
(0, chunk_idx, 0), 1,
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) chunk_idx,
min(query_chunk_size, q_tokens)
) )
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)