llama_flash_attn_monkey_patch
Module Contents
Functions
forward: Input shape: Batch x Time x Channel
- llama_flash_attn_monkey_patch.forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, past_key_value: Tuple[torch.Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False) -> Tuple[torch.Tensor, torch.Tensor | None, Tuple[torch.Tensor] | None]
Input shape: Batch x Time x Channel
attention_mask shape: [bsz, q_len]
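The module's purpose is to swap a replacement forward onto the attention class at import time. The sketch below illustrates that monkey-patching pattern in isolation; the class name, the placeholder computation, and the helper replace_attn_with_flash_attn are hypothetical stand-ins (the real patch reassigns the forward method of the Hugging Face LLaMA attention module to a FlashAttention-backed implementation).

```python
class Attention:
    """Hypothetical stand-in for the attention module being patched."""

    def forward(self, hidden_states):
        # Original behavior: identity, as a placeholder.
        return hidden_states


def patched_forward(self, hidden_states):
    # Replacement forward; in the real patch this would call a
    # flash-attention kernel instead of the stock attention math.
    # Placeholder computation so the swap is observable.
    return [h * 2 for h in hidden_states]


def replace_attn_with_flash_attn():
    # Assign at the class level so every existing and future
    # instance picks up the new forward without re-instantiation.
    Attention.forward = patched_forward


replace_attn_with_flash_attn()
attn = Attention()
print(attn.forward([1, 2, 3]))  # → [2, 4, 6]
```

Patching the class attribute (rather than a single instance) is what lets the replacement take effect inside models that were constructed before the patch ran, as long as the patch is applied before any forward pass.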