Use narrow instead of dynamic_slice
This commit is contained in:
parent
3bfe2bb549
commit
b119815333
1 changed files with 19 additions and 15 deletions
|
@ -5,6 +5,7 @@
|
||||||
# credit:
|
# credit:
|
||||||
# Amin Rezaei (original author)
|
# Amin Rezaei (original author)
|
||||||
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
||||||
|
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
|
||||||
# implementation of:
|
# implementation of:
|
||||||
# Self-attention Does Not Need O(n²) Memory
|
# Self-attention Does Not Need O(n²) Memory
|
||||||
# https://arxiv.org/abs/2112.05682v2
|
# https://arxiv.org/abs/2112.05682v2
|
||||||
|
@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
|
||||||
import math
|
import math
|
||||||
from typing import Optional, NamedTuple, Protocol, List
|
from typing import Optional, NamedTuple, Protocol, List
|
||||||
|
|
||||||
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Return a zero-copy view of `input` narrowed along dimension `dim`.

    Equivalent to slicing ``input[..., start:start + length, ...]`` on
    dimension `dim`, implemented via ``torch.narrow`` so no data is copied.

    Args:
        input: tensor to slice (any shape; `dim` must be a valid dimension).
        dim: dimension along which to narrow.
        start: starting index of the slice along `dim`.
        length: requested slice length; truncated if it would run past the
            end of the dimension (hence "trunc" in the name).

    Returns:
        A view of `input` of length ``min(length, input.shape[dim] - start)``
        along `dim`.
    """
    # Clamp the length so start + length never exceeds the dimension size;
    # torch.narrow raises on an out-of-range length, and callers chunk with
    # a fixed size that may overshoot on the final chunk.
    return torch.narrow(input, dim, start, min(length, input.shape[dim] - start))
|
|
||||||
|
|
||||||
class AttnChunk(NamedTuple):
|
class AttnChunk(NamedTuple):
|
||||||
exp_values: Tensor
|
exp_values: Tensor
|
||||||
|
@ -76,15 +77,17 @@ def _query_chunk_attention(
|
||||||
_, _, v_channels_per_head = value.shape
|
_, _, v_channels_per_head = value.shape
|
||||||
|
|
||||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||||
key_chunk = dynamic_slice(
|
key_chunk = narrow_trunc(
|
||||||
key,
|
key,
|
||||||
(0, chunk_idx, 0),
|
1,
|
||||||
(batch_x_heads, kv_chunk_size, k_channels_per_head)
|
chunk_idx,
|
||||||
|
kv_chunk_size
|
||||||
)
|
)
|
||||||
value_chunk = dynamic_slice(
|
value_chunk = narrow_trunc(
|
||||||
value,
|
value,
|
||||||
(0, chunk_idx, 0),
|
1,
|
||||||
(batch_x_heads, kv_chunk_size, v_channels_per_head)
|
chunk_idx,
|
||||||
|
kv_chunk_size
|
||||||
)
|
)
|
||||||
return summarize_chunk(query, key_chunk, value_chunk)
|
return summarize_chunk(query, key_chunk, value_chunk)
|
||||||
|
|
||||||
|
@ -161,10 +164,11 @@ def efficient_dot_product_attention(
|
||||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||||
|
|
||||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||||
return dynamic_slice(
|
return narrow_trunc(
|
||||||
query,
|
query,
|
||||||
(0, chunk_idx, 0),
|
1,
|
||||||
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
|
chunk_idx,
|
||||||
|
min(query_chunk_size, q_tokens)
|
||||||
)
|
)
|
||||||
|
|
||||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||||
|
|
Loading…
Add table
Reference in a new issue