Remove the fallback for the Protocol import, remove the Protocol import itself, and remove all uses of Protocol in the code

Add some whitespace between functions to be in line with other code in the repo
This commit is contained in:
AUTOMATIC 2023-01-09 20:08:48 +03:00
parent 89c3663080
commit cdfcbd9959

View file

@ -15,14 +15,9 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
try:
from typing import Protocol
except:
from typing_extensions import Protocol
from typing import Optional, NamedTuple, List from typing import Optional, NamedTuple, List
def narrow_trunc( def narrow_trunc(
input: Tensor, input: Tensor,
dim: int, dim: int,
@ -31,12 +26,14 @@ def narrow_trunc(
) -> Tensor: ) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple): class AttnChunk(NamedTuple):
exp_values: Tensor exp_values: Tensor
exp_weights_sum: Tensor exp_weights_sum: Tensor
max_score: Tensor max_score: Tensor
class SummarizeChunk(Protocol):
class SummarizeChunk:
@staticmethod @staticmethod
def __call__( def __call__(
query: Tensor, query: Tensor,
@ -44,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor, value: Tensor,
) -> AttnChunk: ... ) -> AttnChunk: ...
class ComputeQueryChunkAttn(Protocol):
class ComputeQueryChunkAttn:
@staticmethod @staticmethod
def __call__( def __call__(
query: Tensor, query: Tensor,
@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor, value: Tensor,
) -> Tensor: ... ) -> Tensor: ...
def _summarize_chunk( def _summarize_chunk(
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,
@ -72,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1) max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention( def _query_chunk_attention(
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,
@ -112,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this # TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking( def _get_attention_scores_no_kv_chunking(
query: Tensor, query: Tensor,
@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value) hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice return hidden_states_slice
class ScannedChunk(NamedTuple): class ScannedChunk(NamedTuple):
chunk_idx: int chunk_idx: int
attn_chunk: AttnChunk attn_chunk: AttnChunk
def efficient_dot_product_attention( def efficient_dot_product_attention(
query: Tensor, query: Tensor,
key: Tensor, key: Tensor,