llama_flash_attn_monkey_patch
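The module name suggests that its `forward` function replaces an existing attention method at runtime. Its signature starts with `self` even though it is a module-level function, which fits that reading. The sketch below shows the general monkey-patching pattern using a hypothetical stand-in class, `ToyAttention`, and a hypothetical helper, `apply_patch`. It does not use the real target, which is presumably Hugging Face's `LlamaAttention`, though this page does not say so. No torch is required.

```python
from typing import Optional, Tuple


class ToyAttention:
    """Hypothetical stand-in for the attention class being patched."""

    def forward(self, hidden_states, attention_mask=None):
        return ("original", hidden_states, attention_mask)


def forward(self, hidden_states, attention_mask: Optional[list] = None) -> Tuple[str, object, object]:
    # Replacement with a compatible signature; a real patch would compute
    # attention with flash-attn kernels here instead of echoing inputs.
    return ("patched", hidden_states, attention_mask)


def apply_patch(cls) -> None:
    # Assigning a plain function to a class attribute turns it into a method:
    # existing and future instances both pick it up, and `self` is bound
    # automatically on attribute lookup.
    cls.forward = forward


layer = ToyAttention()
print(layer.forward([1, 2])[0])  # original
apply_patch(ToyAttention)
print(layer.forward([1, 2])[0])  # patched
```

Because the replacement is stored on the class and not on instances, the patch is usually applied once, before the model is used. The replacement must also accept every argument the original `forward` is called with.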

Module Contents

Functions

forward(self, hidden_states, ...) → Tuple[torch.Tensor, Optional[torch.Tensor], ...]

Input shape: Batch x Time x Channel

llama_flash_attn_monkey_patch.forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, past_key_value: Tuple[torch.Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False) → Tuple[torch.Tensor, torch.Tensor | None, Tuple[torch.Tensor] | None]

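Input shape: Batch x Time x Channel, i.e. hidden_states is [bsz, q_len, hidden_size].

attention_mask: [bsz, q_len], one entry per token of each sequence in the batch.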
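A `[bsz, q_len]` mask carries one value per token. It is a per-token padding mask, not an expanded 4-D attention-score mask. Variable-length flash attention kernels typically consume this information in two forms. The first is the flat indices of the real tokens. The second is the cumulative sequence lengths (`cu_seqlens`). The pure-Python sketch below derives both. It assumes 1 marks a real token and 0 marks padding. The function `unpad_indices` is hypothetical. It approximates what flash-attn's `unpad_input` helper computes, but it is not that helper's implementation.

```python
from typing import List, Tuple


def unpad_indices(attention_mask: List[List[int]]) -> Tuple[List[int], List[int], int]:
    """Turn a [bsz, q_len] padding mask (1 = token, 0 = pad) into flat token
    indices, cumulative sequence lengths, and the longest sequence length."""
    q_len = len(attention_mask[0]) if attention_mask else 0
    indices: List[int] = []
    cu_seqlens: List[int] = [0]
    max_seqlen = 0
    for b, row in enumerate(attention_mask):
        seqlen = 0
        for t, keep in enumerate(row):
            if keep:
                # Position of this token on the flattened (bsz * q_len) axis.
                indices.append(b * q_len + t)
                seqlen += 1
        # Each sequence's tokens occupy [cu_seqlens[b], cu_seqlens[b + 1]).
        cu_seqlens.append(cu_seqlens[-1] + seqlen)
        max_seqlen = max(max_seqlen, seqlen)
    return indices, cu_seqlens, max_seqlen


mask = [
    [1, 1, 1, 0],  # length-3 sequence, right-padded
    [1, 1, 1, 1],  # full-length sequence
]
print(unpad_indices(mask))  # ([0, 1, 2, 4, 5, 6, 7], [0, 3, 7], 4)
```

Gathering the rows of the flattened `hidden_states` at `indices` packs all real tokens into one `[total_tokens, ...]` tensor. `cu_seqlens` then tells the kernel where each sequence starts and ends, so no attention is computed over padding.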