16 changes: 8 additions & 8 deletions lit_llama/adapter.py
@@ -105,7 +105,7 @@ def forward(
 # instead of calculating `query`, `key` and `value` by separately multiplying input `x` with corresponding
 # weight matrices do it (for all heads) in a single multiplication with a matrix of 3x size (concatenated
 # weights for q, k, v) and then split the result along `embedding size` dimension
-q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)
+q, k, v = self.c_attn(x).split(C, dim=2)  # (B, T, 3 * C) --> 3 * (B, T, C)

 # in order to move head_size (hs) dimension right after batch (B) dimension, we need to first split
 # embedding size (C) dimension into num_heads (nh) and head_size (hs)
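Note on the change above: passing the local `C` (taken from `x.size()`) instead of `self.n_embd` keeps the fused QKV split consistent with the input shape; the two values are equal, so behaviour is unchanged. A minimal sketch of the fused projection and split, with illustrative shapes (the concrete numbers below are assumptions, not values from the model):

import torch
import torch.nn as nn

B, T, C = 2, 8, 32                 # batch, sequence length, embedding size (illustrative)
c_attn = nn.Linear(C, 3 * C)       # one weight matrix holding the q, k and v projections
x = torch.randn(B, T, C)

qkv = c_attn(x)                    # (B, T, 3 * C)
q, k, v = qkv.split(C, dim=2)      # three tensors, each (B, T, C)
assert q.shape == k.shape == v.shape == (B, T, C)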
@@ -150,9 +150,9 @@ def forward(
 if adapter_kv_cache is not None:
     ak, av = adapter_kv_cache  # 2 * (B, nh, aT, hs)
 else:
-    prefix = self.adapter_wte.weight.reshape(1, self.adapter_prompt_length, self.n_embd)
-    aT = prefix.size(1)
-    _, ak, av = self.c_attn(prefix).split(self.n_embd, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
+    aT = self.adapter_prompt_length
+    prefix = self.adapter_wte.weight.reshape(1, aT, C)
+    _, ak, av = self.c_attn(prefix).split(C, dim=2)  # (1, aT, 3 * C) --> 3 * (1, aT, C)
     ak = ak.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
     av = av.view(1, aT, self.n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
     adapter_kv_cache = (ak, av)
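The reordering above lets `prefix` be built directly from `aT` and the local `C`. Roughly, the adapter prompt embeddings are pushed through the same fused QKV projection, the query part is discarded, and the key/value tensors are broadcast across the batch. A standalone sketch under assumed shapes (the names mirror the hunk but are illustrative, not the module's exact attributes):

import torch
import torch.nn as nn

B, aT, C, n_head = 2, 10, 32, 4    # illustrative sizes
head_size = C // n_head

adapter_wte = nn.Embedding(aT, C)  # learnable adapter prompt: aT tokens of width C
c_attn = nn.Linear(C, 3 * C)       # shared fused q/k/v projection

prefix = adapter_wte.weight.reshape(1, aT, C)    # (1, aT, C)
_, ak, av = c_attn(prefix).split(C, dim=2)       # drop the q part; keep (1, aT, C) k and v
ak = ak.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
av = av.view(1, aT, n_head, head_size).repeat(B, 1, 1, 1).transpose(1, 2)  # (B, nh, aT, hs)
adapter_kv_cache = (ak, av)        # can be reused on later forward passes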
@@ -223,14 +223,14 @@ class LLaMA(llama.LLaMA):

 def __init__(self, config: LLaMAConfig) -> None:
     nn.Module.__init__(self)
-    assert config.vocab_size is not None
+    assert config.padded_vocab_size is not None
     assert config.block_size is not None
     self.config = config

-    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+    self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
     self.transformer = nn.ModuleDict(
         dict(
-            wte=nn.Embedding(config.vocab_size, config.n_embd),
+            wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
             h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
             ln_f=RMSNorm(config.n_embd),
         )
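Switching the output head and the embedding table to `config.padded_vocab_size` sizes both to the padded vocabulary. If I read the config right (an assumption about `LLaMAConfig`, not verified here), the padded size is the real vocabulary rounded up to a convenient multiple so the weight matrices have GPU-friendly dimensions; rows beyond the tokenizer's actual vocabulary are simply never indexed. A small illustration of that rounding:

def next_multiple(n: int, k: int) -> int:
    # round n up to the nearest multiple of k
    return n if n % k == 0 else n + k - (n % k)

vocab_size = 32000                                   # illustrative tokenizer size
padded_vocab_size = next_multiple(vocab_size, 64)    # 32000, already a multiple of 64
print(next_multiple(32003, 64))                      # 32064 for a slightly larger tokenizer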
@@ -297,7 +297,7 @@ def forward(

 x = self.transformer.ln_f(x)  # (B, T, n_embd)

-logits = self.lm_head(x)  # (B, T, vocab_size)
+logits = self.lm_head(x)  # (B, T, padded_vocab_size)

 return logits
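With the head sized to the padded vocabulary, the returned logits have `padded_vocab_size` entries per position. A hedged usage sketch (sizes and the stand-in tensor below are assumptions): if the padding extends past the tokenizer's real vocabulary, the extra columns can be ignored when picking the next token.

import torch

vocab_size, padded_vocab_size = 32000, 32064    # illustrative sizes
B, T = 2, 8
logits = torch.randn(B, T, padded_vocab_size)   # stand-in for the lm_head output
next_token = logits[:, -1, :vocab_size].argmax(dim=-1)   # greedy pick over the real vocab only
assert next_token.shape == (B,) and int(next_token.max()) < vocab_size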
