Skip to content

Commit fcca0a7

Browse files
authored
refact : fix convert script + zero out KV cache to avoid nans (#3523)
* refact : fix convert script + zero out KV cache to avoid nans * ggml : silu(-inf) should never happen * metal : assert various kernel requirements
1 parent dcc09d2 commit fcca0a7

File tree

6 files changed

+51
-91
lines changed

6 files changed

+51
-91
lines changed

convert-refact-hf-to-gguf.py

Lines changed: 8 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,6 @@
1717
sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
1818
import gguf
1919

20-
21-
def bytes_to_unicode():
22-
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
23-
"""
24-
Returns list of utf-8 byte and a corresponding list of unicode strings.
25-
The reversible bpe codes work on unicode strings.
26-
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
27-
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
28-
This is a significant percentage of your normal, say, 32K bpe vocab.
29-
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
30-
And avoids mapping to whitespace/control characters the bpe code barfs on.
31-
"""
32-
bs = (
33-
list(range(ord("!"), ord("~") + 1))
34-
+ list(range(ord("¡"), ord("¬") + 1))
35-
+ list(range(ord("®"), ord("ÿ") + 1))
36-
)
37-
cs = bs[:]
38-
n = 0
39-
for b in range(2**8):
40-
if b not in bs:
41-
bs.append(b)
42-
cs.append(2**8 + n)
43-
n += 1
44-
return dict(zip(bs, (chr(n) for n in cs)))
45-
46-
4720
def count_model_parts(dir_model: Path) -> int:
4821
num_parts = 0
4922
for filename in os.listdir(dir_model):
@@ -153,53 +126,25 @@ def parse_args() -> argparse.Namespace:
153126
scores: list[float] = []
154127
toktypes: list[int] = []
155128

156-
tokenizer_json_file = dir_model / "tokenizer.json"
157-
if not tokenizer_json_file.is_file():
158-
print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr)
159-
sys.exit(1)
160-
161129
# gpt2 tokenizer
162130
gguf_writer.add_tokenizer_model("gpt2")
163131

164-
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
165-
tokenizer_json = json.load(f)
166-
167132
print("gguf: get gpt2 tokenizer vocab")
168133

134+
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
135+
tokenizer = AutoTokenizer.from_pretrained(dir_model)
136+
169137
# The number of tokens in tokenizer.json can differ from the expected vocab size.
170138
# This causes downstream issues with mismatched tensor sizes when running the inference
171-
vocab_size = (
172-
hparams["vocab_size"]
173-
if "vocab_size" in hparams
174-
else len(tokenizer_json["model"]["vocab"])
175-
)
176-
177-
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
139+
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
140+
assert max(tokenizer.vocab.values()) < vocab_size
178141

179142
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
180-
byte_encoder = bytes_to_unicode()
181-
byte_decoder = {v: k for k, v in byte_encoder.items()}
182143

183144
for i in range(vocab_size):
184-
if i in reverse_vocab:
185-
text = reverse_vocab[i]
186-
try:
187-
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
188-
except KeyError:
189-
text = bytearray()
190-
for c in reverse_vocab[i]:
191-
if ord(c) < 256: # single byte character
192-
text.append(byte_decoder[ord(c)])
193-
else: # multibyte special token character
194-
text.extend(c.encode("utf-8"))
195-
else:
196-
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
197-
pad_token = f"[PAD{i}]".encode("utf8")
198-
text = bytearray(pad_token)
199-
200-
tokens.append(text)
201-
scores.append(0.0) # dymmy
202-
toktypes.append(gguf.TokenType.NORMAL) # dummy
145+
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
146+
scores.append(0.0) # dummy
147+
toktypes.append(gguf.TokenType.NORMAL)
203148

204149
gguf_writer.add_token_list(tokens)
205150
gguf_writer.add_token_scores(scores)

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
167167

168168
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
169169
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
170-
llama_batch batch = llama_batch_init(params.n_ctx, 0);
170+
llama_batch batch = llama_batch_init(n_ctx, 0);
171171

172172
int32_t n_total_prompt = 0;
173173
int32_t n_total_gen = 0;

ggml-metal.m

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,8 @@ void ggml_metal_graph_compute(
779779
} break;
780780
case GGML_OP_CONCAT:
781781
{
782+
const int64_t nb = ne00;
782783

783-
int64_t nb = ne00;
784784
[encoder setComputePipelineState:ctx->pipeline_concat];
785785
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
786786
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -812,6 +812,7 @@ void ggml_metal_graph_compute(
812812
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
813813

814814
const int nth = MIN(1024, ne0);
815+
815816
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
816817
} break;
817818
case GGML_OP_ADD:
@@ -909,9 +910,10 @@ void ggml_metal_graph_compute(
909910
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
910911
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
911912

912-
const int64_t n = ggml_nelements(dst)/4;
913+
const int64_t n = ggml_nelements(dst);
914+
GGML_ASSERT(n % 4 == 0);
913915

914-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
916+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
915917
} break;
916918
case GGML_OP_UNARY:
917919
switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -921,9 +923,10 @@ void ggml_metal_graph_compute(
921923
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
922924
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
923925

924-
const int64_t n = ggml_nelements(dst)/4;
926+
const int64_t n = ggml_nelements(dst);
927+
GGML_ASSERT(n % 4 == 0);
925928

926-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
929+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
927930
} break;
928931
case GGML_UNARY_OP_RELU:
929932
{
@@ -941,9 +944,10 @@ void ggml_metal_graph_compute(
941944
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
942945
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
943946

944-
const int64_t n = ggml_nelements(dst)/4;
947+
const int64_t n = ggml_nelements(dst);
948+
GGML_ASSERT(n % 4 == 0);
945949

946-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
950+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
947951
} break;
948952
default:
949953
{
@@ -1251,6 +1255,8 @@ void ggml_metal_graph_compute(
12511255
} break;
12521256
case GGML_OP_RMS_NORM:
12531257
{
1258+
GGML_ASSERT(ne00 % 4 == 0);
1259+
12541260
float eps;
12551261
memcpy(&eps, dst->op_params, sizeof(float));
12561262

ggml-metal.metal

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,11 @@ kernel void kernel_rms_norm(
345345
uint sgitg[[simdgroup_index_in_threadgroup]],
346346
uint tiisg[[thread_index_in_simdgroup]],
347347
uint ntg[[threads_per_threadgroup]]) {
348-
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
349-
device const float * x_scalar = (device const float *) x;
350-
float4 sumf=0;
351-
float all_sum=0;
348+
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
349+
device const float * x_scalar = (device const float *) x;
350+
351+
float4 sumf = 0;
352+
float all_sum = 0;
352353

353354
// parallel sum
354355
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -361,14 +362,17 @@ kernel void kernel_rms_norm(
361362
}
362363

363364
threadgroup_barrier(mem_flags::mem_threadgroup);
365+
364366
// broadcast, simd group number is ntg / 32
365367
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
366368
if (tpitg < i) {
367369
sum[tpitg] += sum[tpitg + i];
368370
}
369371
}
370372
if (tpitg == 0) {
371-
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
373+
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
374+
sum[0] += x_scalar[i];
375+
}
372376
sum[0] /= ne00;
373377
}
374378

@@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
383387
y[i00] = x[i00] * scale;
384388
}
385389
if (tpitg == 0) {
386-
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
390+
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
391+
y_scalar[i00] = x_scalar[i00] * scale;
392+
}
387393
}
388394
}
389395

ggml.c

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11233,7 +11233,7 @@ static void ggml_compute_forward_silu_f32(
1123311233

1123411234
#ifndef NDEBUG
1123511235
for (int k = 0; k < nc; k++) {
11236-
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
11236+
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
1123711237
UNUSED(x);
1123811238
assert(!isnan(x));
1123911239
assert(!isinf(x));
@@ -13066,17 +13066,17 @@ static void ggml_compute_forward_alibi_f32(
1306613066

1306713067
assert(n_past >= 0);
1306813068

13069-
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13070-
const int ne1 = src0->ne[1]; // seq_len_without_past
13071-
const int ne2 = src0->ne[2]; // n_head -> this is k
13072-
//const int ne3 = src0->ne[3]; // 1 -> bsz
13069+
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13070+
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13071+
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13072+
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
1307313073

13074-
const int n = ggml_nrows(src0);
13075-
const int ne2_ne3 = n/ne1; // ne2*ne3
13074+
const int64_t n = ggml_nrows(src0);
13075+
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
1307613076

13077-
const int nb0 = src0->nb[0];
13078-
const int nb1 = src0->nb[1];
13079-
const int nb2 = src0->nb[2];
13077+
const size_t nb0 = src0->nb[0];
13078+
const size_t nb1 = src0->nb[1];
13079+
const size_t nb2 = src0->nb[2];
1308013080
//const int nb3 = src0->nb[3];
1308113081

1308213082
GGML_ASSERT(nb0 == sizeof(float));
@@ -13088,9 +13088,9 @@ static void ggml_compute_forward_alibi_f32(
1308813088
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1308913089
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1309013090

13091-
for (int i = 0; i < ne0; i++) {
13092-
for (int j = 0; j < ne1; j++) {
13093-
for (int k = 0; k < ne2_ne3; k++) {
13091+
for (int64_t i = 0; i < ne0; i++) {
13092+
for (int64_t j = 0; j < ne1; j++) {
13093+
for (int64_t k = 0; k < ne2_ne3; k++) {
1309413094
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
1309513095
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
1309613096

@@ -13105,7 +13105,6 @@ static void ggml_compute_forward_alibi_f32(
1310513105
}
1310613106

1310713107
pdst[0] = i * m_k + src[0];
13108-
1310913108
}
1311013109
}
1311113110
}

llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,7 +1325,11 @@ static bool llama_kv_cache_init(
13251325
cache.cells.clear();
13261326
cache.cells.resize(n_ctx);
13271327

1328+
// TODO: this should be:
1329+
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
1330+
// change it and test that it works
13281331
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
1332+
memset(cache.buf.data, 0, cache.buf.size);
13291333

13301334
struct ggml_init_params params;
13311335
params.mem_size = cache.buf.size;

0 commit comments

Comments
 (0)