Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code
add some whitespace between functions to be in line with other code in the repo
This commit is contained in:
parent
89c3663080
commit
cdfcbd9959
1 changed files with 11 additions and 8 deletions
|
@ -15,14 +15,9 @@ import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
import math
|
import math
|
||||||
|
|
||||||
try:
|
|
||||||
from typing import Protocol
|
|
||||||
except:
|
|
||||||
from typing_extensions import Protocol
|
|
||||||
|
|
||||||
from typing import Optional, NamedTuple, List
|
from typing import Optional, NamedTuple, List
|
||||||
|
|
||||||
|
|
||||||
def narrow_trunc(
|
def narrow_trunc(
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
dim: int,
|
dim: int,
|
||||||
|
@ -31,12 +26,14 @@ def narrow_trunc(
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
||||||
|
|
||||||
|
|
||||||
class AttnChunk(NamedTuple):
|
class AttnChunk(NamedTuple):
|
||||||
exp_values: Tensor
|
exp_values: Tensor
|
||||||
exp_weights_sum: Tensor
|
exp_weights_sum: Tensor
|
||||||
max_score: Tensor
|
max_score: Tensor
|
||||||
|
|
||||||
class SummarizeChunk(Protocol):
|
|
||||||
|
class SummarizeChunk:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __call__(
|
def __call__(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
|
@ -44,7 +41,8 @@ class SummarizeChunk(Protocol):
|
||||||
value: Tensor,
|
value: Tensor,
|
||||||
) -> AttnChunk: ...
|
) -> AttnChunk: ...
|
||||||
|
|
||||||
class ComputeQueryChunkAttn(Protocol):
|
|
||||||
|
class ComputeQueryChunkAttn:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __call__(
|
def __call__(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
|
@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
|
||||||
value: Tensor,
|
value: Tensor,
|
||||||
) -> Tensor: ...
|
) -> Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
def _summarize_chunk(
|
def _summarize_chunk(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
key: Tensor,
|
key: Tensor,
|
||||||
|
@ -72,6 +71,7 @@ def _summarize_chunk(
|
||||||
max_score = max_score.squeeze(-1)
|
max_score = max_score.squeeze(-1)
|
||||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||||
|
|
||||||
|
|
||||||
def _query_chunk_attention(
|
def _query_chunk_attention(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
key: Tensor,
|
key: Tensor,
|
||||||
|
@ -112,6 +112,7 @@ def _query_chunk_attention(
|
||||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||||
return all_values / all_weights
|
return all_values / all_weights
|
||||||
|
|
||||||
|
|
||||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||||
def _get_attention_scores_no_kv_chunking(
|
def _get_attention_scores_no_kv_chunking(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
|
@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
|
||||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||||
return hidden_states_slice
|
return hidden_states_slice
|
||||||
|
|
||||||
|
|
||||||
class ScannedChunk(NamedTuple):
|
class ScannedChunk(NamedTuple):
|
||||||
chunk_idx: int
|
chunk_idx: int
|
||||||
attn_chunk: AttnChunk
|
attn_chunk: AttnChunk
|
||||||
|
|
||||||
|
|
||||||
def efficient_dot_product_attention(
|
def efficient_dot_product_attention(
|
||||||
query: Tensor,
|
query: Tensor,
|
||||||
key: Tensor,
|
key: Tensor,
|
||||||
|
|
Loading…
Reference in a new issue