From d3f53e088955e6463f90916a1746603599a5375c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Dec 2022 06:15:33 +0100 Subject: [PATCH 1/7] example --- examples/gpt/.lightning | 2 + examples/gpt/.lightningignore | 3 + examples/gpt/LICENSE | 7 + examples/gpt/README.md | 3 + examples/gpt/bpe.py | 339 +++++++++++++++++++++++++++++ examples/gpt/config.py | 50 +++++ examples/gpt/data/download-data.sh | 1 + examples/gpt/model.py | 287 ++++++++++++++++++++++++ examples/gpt/requirements.txt | 0 examples/gpt/train.py | 186 ++++++++++++++++ examples/gpt/train_cloud.py | 20 ++ 11 files changed, 898 insertions(+) create mode 100644 examples/gpt/.lightning create mode 100644 examples/gpt/.lightningignore create mode 100644 examples/gpt/LICENSE create mode 100644 examples/gpt/README.md create mode 100644 examples/gpt/bpe.py create mode 100644 examples/gpt/config.py create mode 100644 examples/gpt/data/download-data.sh create mode 100644 examples/gpt/model.py create mode 100644 examples/gpt/requirements.txt create mode 100644 examples/gpt/train.py create mode 100644 examples/gpt/train_cloud.py diff --git a/examples/gpt/.lightning b/examples/gpt/.lightning new file mode 100644 index 0000000..f3e862b --- /dev/null +++ b/examples/gpt/.lightning @@ -0,0 +1,2 @@ +cluster_id: litng-ai-03 +name: modest-bardeen-5468 diff --git a/examples/gpt/.lightningignore b/examples/gpt/.lightningignore new file mode 100644 index 0000000..0671247 --- /dev/null +++ b/examples/gpt/.lightningignore @@ -0,0 +1,3 @@ +__pycache__/ +.git/ +data/*.txt \ No newline at end of file diff --git a/examples/gpt/LICENSE b/examples/gpt/LICENSE new file mode 100644 index 0000000..3d89960 --- /dev/null +++ b/examples/gpt/LICENSE @@ -0,0 +1,7 @@ +The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/examples/gpt/README.md b/examples/gpt/README.md new file mode 100644 index 0000000..7a9dde3 --- /dev/null +++ b/examples/gpt/README.md @@ -0,0 +1,3 @@ +# Character-level Language Model + +Code modified from Andrej Karpathy's minGPT repository. \ No newline at end of file diff --git a/examples/gpt/bpe.py b/examples/gpt/bpe.py new file mode 100644 index 0000000..d8b8d88 --- /dev/null +++ b/examples/gpt/bpe.py @@ -0,0 +1,339 @@ +""" +bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into +sequences of integers, where each integer represents small chunks of commonly +occuring characters. This implementation is based on openai's gpt2 encoder.py: +https://github.com/openai/gpt-2/blob/master/src/encoder.py +but was mildly modified because the original implementation is a bit confusing. +I also tried to add as many comments as possible, my own understanding of what's +going on. +""" + +import json +import os + +import regex as re +import requests +import torch + +# ----------------------------------------------------------------------------- + + +def bytes_to_unicode(): + """ + Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode + character that represents it visually. Some bytes have their appearance preserved + because they don't cause any trouble. These are defined in list bs. For example: + chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!". + However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these + bytes, into new characters in a range where chr() returns a single nice character. + So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8). + In particular, the space character is 32, which we can see by ord(' '). Instead, + this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'. + So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters + that "look nice", either in their original form, or a funny shifted character + like 'Ā', or 'Ġ', etc. + """ + # the 188 integers that render fine in their original form and need no shifting + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict + # now get the representations of the other 68 integers that do need shifting + # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop + n = 0 + for b in range(2**8): + if b not in bs: + # if this byte is "ugly" then map it to the next available "nice" character + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + d = dict(zip(bs, cs)) + return d + + +def get_pairs(word): + """ + Return all bigrams as a set of tuples, of consecutive elements in the iterable word. + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges): + # byte encoder/decoder + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + # bpe token encoder/decoder + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # the splitting pattern used for pre-tokenization + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment + """ + ok so what is this regex looking for, exactly? + python re reference: https://docs.python.org/3/library/re.html + - the vertical bars | is OR, so re.findall will chunkate text as the pieces match, from left to right + - '\'s' would split up things like Andrej's -> (Andrej, 's) + - ' ?\p{L}': optional space followed by 1+ unicode code points in the category "letter" + - ' ?\p{N}': optional space followed by 1+ unicode code points in the category "number" + - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number + - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace + so this will consume whitespace characters in a sequence but exclude the last whitespace in + that sequence. that last whitespace has the opportunity to then match the optional ' ?' in + earlier patterns. + - '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string + So TLDR: + - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens + - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces + """ + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + self.cache = {} + + def bpe(self, token): + """ + this function uses self.bpe_ranks to iteratively merge all the possible bpe tokens + up the tree. token is a string of one individual 'word' (after regex tokenization) + and after byte encoding, e.g. 'Ġthere'. + """ + # token is a string of one individual 'word', after byte encoding, e.g. 'Ġthere' + + # memoization, for efficiency + if token in self.cache: + return self.cache[token] + + word = tuple(token) # individual characters that make up the token, in a tuple + pairs = get_pairs(word) # get all bigrams + + if not pairs: + return token + + while True: + + # find the next lowest rank bigram that can be merged + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break # no more bigrams are eligible to be merged + first, second = bigram + + # we will now replace all occurences of (first, second) in the list of current + # words into one merged token first_second, in the output list new_words + new_word = [] + i = 0 + while i < len(word): + + # find the next occurence of first in the sequence of current words + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + # if this occurence is also followed by second, then merge them into one + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + + # all occurences of (first, second) have been merged to first_second + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + + # concat all words into a string, and use ' ' as the separator. Note that + # by now all characters have been byte encoded, guaranteeing that ' ' is + # not used in the actual data and is a 'special' delimiter character + word = " ".join(word) + + # cache the result and return + self.cache[token] = word + return word + + def encode(self, text): + """string goes in, list of integers comes out""" + bpe_idx = [] + # pre-tokenize the input text into string tokens (words, roughly speaking) + tokens = re.findall(self.pat, text) + # process each token into BPE integers + for token in tokens: + # encode the token as a bytes (b'') object + token_bytes = token.encode("utf-8") + # translate all bytes to their unicode string representation and flatten + token_translated = "".join(self.byte_encoder[b] for b in token_bytes) + # perform all the applicable bpe merges according to self.bpe_ranks + token_merged = self.bpe(token_translated).split(" ") + # translate all bpe tokens to integers + token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] + # extend our running list of all output integers + bpe_idx.extend(token_ix) + return bpe_idx + + def encode_and_show_work(self, text): + """debugging function, same as encode but returns all intermediate work""" + bpe_idx = [] + parts = [] + tokens = re.findall(self.pat, text) + for token in tokens: + token_bytes = token.encode("utf-8") + token_translated = "".join(self.byte_encoder[b] for b in token_bytes) + token_merged = self.bpe(token_translated).split(" ") + token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] + bpe_idx.extend(token_ix) + parts.append( + { + "token": token, + "token_bytes": token_bytes, + "token_translated": token_translated, + "token_merged": token_merged, + "token_ix": token_ix, + } + ) + out = { + "bpe_idx": bpe_idx, # the actual output sequence + "tokens": tokens, # result of pre-tokenization + "parts": parts, # intermediates for each token part + } + return out + + def decode(self, bpe_idx): + """list of integers comes in, string comes out""" + # inverse map the integers to get the tokens + tokens_merged = [self.decoder[token] for token in bpe_idx] + # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes + tokens_flat = "".join(tokens_merged) + tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) + # recover the full utf-8 string + text = tokens_bytes.decode("utf-8", errors="replace") + return text + + +def get_file(local_file, remote_file): + """downloads remote_file to local_file if necessary""" + if not os.path.isfile(local_file): + print(f"downloading {remote_file} to {local_file}") + response = requests.get(remote_file) + open(local_file, "wb").write(response.content) + + +def get_encoder(): + """ + Returns an instance of the GPT BPE Encoder/Decoder + and handles caching of "database" files. + """ + home_dir = os.path.expanduser("~") + cache_dir = os.path.join(home_dir, ".cache", "mingpt") + os.makedirs(cache_dir, exist_ok=True) + + # load encoder.json that has the raw mappings from token -> bpe index + encoder_local_file = os.path.join(cache_dir, "encoder.json") + encoder_remote_file = ( + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json" + ) + get_file(encoder_local_file, encoder_remote_file) + with open(encoder_local_file, "r") as f: + encoder = json.load(f) + assert ( + len(encoder) == 50257 + ) # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + + # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure + # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab + vocab_local_file = os.path.join(cache_dir, "vocab.bpe") + vocab_remote_file = ( + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe" + ) + get_file(vocab_local_file, vocab_remote_file) + with open(vocab_local_file, "r", encoding="utf-8") as f: + bpe_data = f.read() + # light postprocessing: strip the version on first line and the last line is a blank + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + assert len(bpe_merges) == 50000 # 50,000 merged tokens + + # construct the Encoder object and return + enc = Encoder(encoder, bpe_merges) + return enc + + +# ----------------------------------------------------------------------------- + + +class BPETokenizer: + """PyTorch-aware class that wraps the Encoder above""" + + def __init__(self): + self.encoder = get_encoder() + + def __call__(self, text, return_tensors="pt"): + # PyTorch only; here because we want to match huggingface/transformers interface + assert return_tensors == "pt" + # single string input for now, in the future potentially a list of strings + assert isinstance(text, str) + # encode and create a "batch dimension" of 1 + idx = [self.encoder.encode(text)] + # wrap into PyTorch tensor + out = torch.tensor(idx, dtype=torch.long) + return out + + def decode(self, idx): + # ensure a simple 1D tensor for now + assert idx.ndim == 1 + # decode indices to text + text = self.encoder.decode(idx.tolist()) + return text + + +if __name__ == "__main__": + + # here is an encoding example + text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" + e = get_encoder() + r = e.encode_and_show_work(text) + + print("Original text is:") + print(text) + print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") + print(r["tokens"]) + # ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗'] + print("Then we iterate over each chunk and process them in turn...") + for part in r["parts"]: + print(part) + # {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]} + # {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]} + # {'token': ' I', 'token_bytes': b' I', 'token_translated': 'ĠI', 'token_merged': ['ĠI'], 'token_ix': [314]} + # {'token': "'m", 'token_bytes': b"'m", 'token_translated': "'m", 'token_merged': ["'m"], 'token_ix': [1101]} + # {'token': ' Andrej', 'token_bytes': b' Andrej', 'token_translated': 'ĠAndrej', 'token_merged': ['ĠAndre', 'j'], 'token_ix': [10948, 73]} + # {'token': ' Karpathy', 'token_bytes': b' Karpathy', 'token_translated': 'ĠKarpathy', 'token_merged': ['ĠK', 'arp', 'athy'], 'token_ix': [509, 5117, 10036]} + # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} + # {'token': ' It', 'token_bytes': b' It', 'token_translated': 'ĠIt', 'token_merged': ['ĠIt'], 'token_ix': [632]} + # {'token': "'s", 'token_bytes': b"'s", 'token_translated': "'s", 'token_merged': ["'s"], 'token_ix': [338]} + # {'token': ' 2022', 'token_bytes': b' 2022', 'token_translated': 'Ġ2022', 'token_merged': ['Ġ2022'], 'token_ix': [33160]} + # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} + # {'token': ' w', 'token_bytes': b' w', 'token_translated': 'Ġw', 'token_merged': ['Ġw'], 'token_ix': [266]} + # {'token': '00', 'token_bytes': b'00', 'token_translated': '00', 'token_merged': ['00'], 'token_ix': [405]} + # {'token': 't', 'token_bytes': b't', 'token_translated': 't', 'token_merged': ['t'], 'token_ix': [83]} + # {'token': ' :', 'token_bytes': b' :', 'token_translated': 'Ġ:', 'token_merged': ['Ġ:'], 'token_ix': [1058]} + # {'token': 'D', 'token_bytes': b'D', 'token_translated': 'D', 'token_merged': ['D'], 'token_ix': [35]} + # {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'Ġð٤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]} + # (refer to the code inside Encoder.encode for what these intermediates are) + print("and the final outcome is concatenating and flattening all the token_ix:") + print(r["bpe_idx"]) + # [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245] + # this would then become the integer input sequence to the transformer + print("ready to feed into a Transformer!") diff --git a/examples/gpt/config.py b/examples/gpt/config.py new file mode 100644 index 0000000..2779cd8 --- /dev/null +++ b/examples/gpt/config.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + + +@dataclass +class GPTConfig: + model_type: str + vocab_size: int + block_size: int + embd_pdrop: float = 0.1 + resid_pdrop: float = 0.1 + attn_pdrop: float = 0.1 + n_layer: Optional[int] = None + n_head: Optional[int] = None + n_embd: Optional[int] = None + + def __post_init__(self): + type_given = self.model_type is not None + params_given = all( + (self.n_layer is not None, self.n_head is not None, self.n_embd is not None) + ) + assert type_given ^ params_given + if type_given: + # translate from model_type to detailed configuration + values = { + # names follow the huggingface naming conventions + # GPT-1 + "openai-gpt": dict(n_layer=12, n_head=12, n_embd=768), # 117M params + # GPT-2 configs + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params + "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[self.model_type] + self.n_layer = values["n_layer"] + self.n_head = values["n_head"] + self.n_embd = values["n_embd"] + + +@dataclass +class TrainerConfig: + block_size: int + num_workers: int + batch_size: int + learning_rate: float + betas: Tuple[int] + weight_decay: float + grad_norm_clip: float + seed: int = 1 + max_iters: int = -1 diff --git a/examples/gpt/data/download-data.sh b/examples/gpt/data/download-data.sh new file mode 100644 index 0000000..ee942db --- /dev/null +++ b/examples/gpt/data/download-data.sh @@ -0,0 +1 @@ +wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O tinyshakespeare.txt diff --git a/examples/gpt/model.py b/examples/gpt/model.py new file mode 100644 index 0000000..659ecb8 --- /dev/null +++ b/examples/gpt/model.py @@ -0,0 +1,287 @@ +""" +Full definition of a GPT Language Model, all of it in this single file. + +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py +""" + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + self.n_head = config.n_head + self.n_embd = config.n_embd + + def forward(self, x): + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class Block(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(config.n_embd, 4 * config.n_embd), + c_proj=nn.Linear(4 * config.n_embd, config.n_embd), + act=NewGELU(), + dropout=nn.Dropout(config.resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class GPT(nn.Module): + """GPT Language Model""" + + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.block_size = config.block_size + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.embd_pdrop), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embd), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + @property + def num_parameters(self): + # note: we don't count the decoder parameters in lm_head + return sum(p.numel() for p in self.transformer.parameters()) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + def configure_optimizers(self, train_config): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are recursive + # we will see the same tensors p many many times. but doing it this way + # allows us to know which parent module any tensor p belongs to... + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" % ( + str(param_dict.keys() - union_params), + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": train_config.weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW( + optim_groups, lr=train_config.learning_rate, betas=train_config.betas + ) + return optimizer + + def forward(self, idx, targets=None): + device = idx.device + b, t = idx.size() + assert ( + t <= self.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) + + return logits, loss + + @torch.no_grad() + def generate( + self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None + ): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = ( + idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :] + ) + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, top_k) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # either sample from the distribution or take the most likely element + if do_sample: + idx_next = torch.multinomial(probs, num_samples=1) + else: + _, idx_next = torch.topk(probs, k=1, dim=-1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/examples/gpt/train.py b/examples/gpt/train.py new file mode 100644 index 0000000..8beec70 --- /dev/null +++ b/examples/gpt/train.py @@ -0,0 +1,186 @@ +""" +Trains a character-level language model. +""" +import functools +import time + + +import torch +from lightning_lite import seed_everything +from lightning_lite.lite import LightningLite +from lightning_lite.strategies.fsdp import FSDPStrategy +from lightning_lite.strategies import STRATEGY_REGISTRY +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper) +from torch.distributed.fsdp import BackwardPrefetch, CPUOffload +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.utils.data import Dataset +from torch.utils.data.dataloader import DataLoader + +from model import GPT, Block +from config import GPTConfig, TrainerConfig + + +model_config = GPTConfig( + model_type="gpt2-xl", + vocab_size=None, + block_size=128, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, +) + + +trainer_config = TrainerConfig( + num_workers=4, + max_iters=100, + block_size=128, + batch_size=64, + learning_rate=3e-4, + betas=(0.9, 0.95), + weight_decay=0.1, # only applied on matmul weights + grad_norm_clip=1.0, +) + +auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls={Block} + ) +check_fn = lambda submodule: isinstance(submodule, Block) +wrapper = functools.partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) +STRATEGY_REGISTRY.register( + name="fsdp-gpt", + strategy=FSDPStrategy, + description="FSDP strategy with memory optimizations enabled for GPT large scale pretraining.", + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE +) + + +class CharDataset(Dataset): + def __init__(self, textfile, block_size): + self.data = open(textfile, "r").read() + self.block_size = block_size + chars = sorted(list(set(self.data))) + self.stoi = {ch: i for i, ch in enumerate(chars)} + self.itos = {i: ch for i, ch in enumerate(chars)} + self.vocab_size = len(chars) + + + def get_vocab_size(self): + return self.vocab_size + + def get_block_size(self): + return self.block_size + + def __len__(self): + return len(self.data) - self.block_size + + def __getitem__(self, idx): + # grab a chunk of (block_size + 1) characters from the data + chunk = self.data[idx : idx + self.block_size + 1] + # encode every character to an integer + dix = [self.stoi[s] for s in chunk] + # return as tensors + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + return x, y + + +def main(): + seed_everything(trainer_config.seed) + + # TODO: precision 16 and cpu offload hangs + lite = LightningLite( + accelerator="cuda", + devices=-1, + precision=16, + strategy="fsdp-gpt", + # num_nodes=2, + ) + # lite.launch() + + # construct the training dataset + train_dataset = CharDataset(textfile="data/tinyshakespeare.txt", block_size=model_config.block_size) + + # construct the model + model_config.vocab_size = train_dataset.get_vocab_size() + + lite.print(model_config) + lite.print(trainer_config) + + # setup the model and optimizer + with lite.sharded_model(): + model = GPT(model_config) + model = lite.setup_module(model) + + lite.print(f"Number of parameters per device: {model.num_parameters / 1e6:.1f} M") + lite.print(f"Total number of parameters: ~ {lite.world_size * model.num_parameters / 1e6:.1f} M") + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn + ) + + # TODO: support multiple param groups for FSDP + # optimizer = model.configure_optimizers(config.trainer) + optimizer = torch.optim.AdamW( + model.parameters(), lr=trainer_config.learning_rate, betas=trainer_config.betas + ) + optimizer = lite.setup_optimizers(optimizer) + + train_loader = DataLoader( + train_dataset, + # TODO: fix this in Lite + # sampler=torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10)), + shuffle=True, + pin_memory=True, + batch_size=trainer_config.batch_size, + num_workers=trainer_config.num_workers, + ) + train_loader = lite.setup_dataloaders(train_loader) + + model.train() + iteration = 0 + iter_dt = 0 + iter_time = time.time() + data_iter = iter(train_loader) + + while True: + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(train_loader) + batch = next(data_iter) + + x, y = batch + + _, loss = model(x, y) + model.zero_grad(set_to_none=True) + lite.backward(loss) + torch.nn.utils.clip_grad_norm_( + model.parameters(), trainer_config.grad_norm_clip + ) + optimizer.step() + + if iteration % 10 == 0: + lite.print( + f"iteration time {iter_dt * 1000:.2f}ms; iteration {iteration}; train loss {loss.item():.5f}" + ) + + iteration += 1 + tnow = time.time() + iter_dt = tnow - iter_time + iter_time = tnow + + if trainer_config.max_iters != -1 and iteration >= trainer_config.max_iters: + break + + # For optimal memory throughput, make sure the summary shows 0 cudaMalloc retries and otherwise try lowering the batch size. + lite.print(torch.cuda.memory_summary()) + + +if __name__ == "__main__": + main() diff --git a/examples/gpt/train_cloud.py b/examples/gpt/train_cloud.py new file mode 100644 index 0000000..5769fc1 --- /dev/null +++ b/examples/gpt/train_cloud.py @@ -0,0 +1,20 @@ +# ! PACKAGE_NAME=lite pip install git+https://github.com/Lightning-AI/lightning +# ! cd data && bash download-data.sh + +from lightning import LightningApp, LightningWork, CloudCompute +from lightning.app.components import LiteMultiNode +from train import main + + +class Work(LightningWork): + def run(self): + main() + + +app = LightningApp( + LiteMultiNode( + Work, + num_nodes=2, + cloud_compute=CloudCompute(name="gpu-fast-multi"), + ) +) \ No newline at end of file From 25967efc46fdd2722fa6a281b5dd0df8f17ae8ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Dec 2022 06:46:23 +0100 Subject: [PATCH 2/7] examples --- examples/gpt/LICENSE | 7 -- examples/gpt/gpt/__init__.py | 0 examples/gpt/{ => gpt}/bpe.py | 89 +++++++--------------- examples/gpt/{ => gpt}/config.py | 4 +- examples/gpt/gpt/dataset.py | 54 +++++++++++++ examples/gpt/{ => gpt}/model.py | 90 +++++++++------------- examples/gpt/train.py | 125 +++++++++++-------------------- examples/gpt/train_cloud.py | 37 ++++++++- pyproject.toml | 10 +++ 9 files changed, 204 insertions(+), 212 deletions(-) delete mode 100644 examples/gpt/LICENSE create mode 100644 examples/gpt/gpt/__init__.py rename examples/gpt/{ => gpt}/bpe.py (75%) rename examples/gpt/{ => gpt}/config.py (92%) create mode 100644 examples/gpt/gpt/dataset.py rename examples/gpt/{ => gpt}/model.py (81%) create mode 100644 pyproject.toml diff --git a/examples/gpt/LICENSE b/examples/gpt/LICENSE deleted file mode 100644 index 3d89960..0000000 --- a/examples/gpt/LICENSE +++ /dev/null @@ -1,7 +0,0 @@ -The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/examples/gpt/gpt/__init__.py b/examples/gpt/gpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/gpt/bpe.py b/examples/gpt/gpt/bpe.py similarity index 75% rename from examples/gpt/bpe.py rename to examples/gpt/gpt/bpe.py index d8b8d88..1570e58 100644 --- a/examples/gpt/bpe.py +++ b/examples/gpt/gpt/bpe.py @@ -1,3 +1,25 @@ +# MIT License +# +# Copyright (c) 2020 Andrej Karpathy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + """ bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into sequences of integers, where each integer represents small chunks of commonly @@ -15,8 +37,6 @@ import requests import torch -# ----------------------------------------------------------------------------- - def bytes_to_unicode(): """ @@ -34,11 +54,7 @@ def bytes_to_unicode(): like 'Ā', or 'Ġ', etc. """ # the 188 integers that render fine in their original form and need no shifting - bs = ( - list(range(ord("!"), ord("~") + 1)) - + list(range(ord("¡"), ord("¬") + 1)) - + list(range(ord("®"), ord("ÿ") + 1)) - ) + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict # now get the representations of the other 68 integers that do need shifting # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop @@ -95,9 +111,7 @@ def __init__(self, encoder, bpe_merges): - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces """ - self.pat = re.compile( - r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" - ) + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") self.cache = {} def bpe(self, token): @@ -243,22 +257,16 @@ def get_encoder(): # load encoder.json that has the raw mappings from token -> bpe index encoder_local_file = os.path.join(cache_dir, "encoder.json") - encoder_remote_file = ( - "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json" - ) + encoder_remote_file = "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json" get_file(encoder_local_file, encoder_remote_file) with open(encoder_local_file, "r") as f: encoder = json.load(f) - assert ( - len(encoder) == 50257 - ) # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab vocab_local_file = os.path.join(cache_dir, "vocab.bpe") - vocab_remote_file = ( - "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe" - ) + vocab_remote_file = "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe" get_file(vocab_local_file, vocab_remote_file) with open(vocab_local_file, "r", encoding="utf-8") as f: bpe_data = f.read() @@ -271,9 +279,6 @@ def get_encoder(): return enc -# ----------------------------------------------------------------------------- - - class BPETokenizer: """PyTorch-aware class that wraps the Encoder above""" @@ -297,43 +302,3 @@ def decode(self, idx): # decode indices to text text = self.encoder.decode(idx.tolist()) return text - - -if __name__ == "__main__": - - # here is an encoding example - text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" - e = get_encoder() - r = e.encode_and_show_work(text) - - print("Original text is:") - print(text) - print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") - print(r["tokens"]) - # ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗'] - print("Then we iterate over each chunk and process them in turn...") - for part in r["parts"]: - print(part) - # {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]} - # {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]} - # {'token': ' I', 'token_bytes': b' I', 'token_translated': 'ĠI', 'token_merged': ['ĠI'], 'token_ix': [314]} - # {'token': "'m", 'token_bytes': b"'m", 'token_translated': "'m", 'token_merged': ["'m"], 'token_ix': [1101]} - # {'token': ' Andrej', 'token_bytes': b' Andrej', 'token_translated': 'ĠAndrej', 'token_merged': ['ĠAndre', 'j'], 'token_ix': [10948, 73]} - # {'token': ' Karpathy', 'token_bytes': b' Karpathy', 'token_translated': 'ĠKarpathy', 'token_merged': ['ĠK', 'arp', 'athy'], 'token_ix': [509, 5117, 10036]} - # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} - # {'token': ' It', 'token_bytes': b' It', 'token_translated': 'ĠIt', 'token_merged': ['ĠIt'], 'token_ix': [632]} - # {'token': "'s", 'token_bytes': b"'s", 'token_translated': "'s", 'token_merged': ["'s"], 'token_ix': [338]} - # {'token': ' 2022', 'token_bytes': b' 2022', 'token_translated': 'Ġ2022', 'token_merged': ['Ġ2022'], 'token_ix': [33160]} - # {'token': '.', 'token_bytes': b'.', 'token_translated': '.', 'token_merged': ['.'], 'token_ix': [13]} - # {'token': ' w', 'token_bytes': b' w', 'token_translated': 'Ġw', 'token_merged': ['Ġw'], 'token_ix': [266]} - # {'token': '00', 'token_bytes': b'00', 'token_translated': '00', 'token_merged': ['00'], 'token_ix': [405]} - # {'token': 't', 'token_bytes': b't', 'token_translated': 't', 'token_merged': ['t'], 'token_ix': [83]} - # {'token': ' :', 'token_bytes': b' :', 'token_translated': 'Ġ:', 'token_merged': ['Ġ:'], 'token_ix': [1058]} - # {'token': 'D', 'token_bytes': b'D', 'token_translated': 'D', 'token_merged': ['D'], 'token_ix': [35]} - # {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'Ġð٤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]} - # (refer to the code inside Encoder.encode for what these intermediates are) - print("and the final outcome is concatenating and flattening all the token_ix:") - print(r["bpe_idx"]) - # [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245] - # this would then become the integer input sequence to the transformer - print("ready to feed into a Transformer!") diff --git a/examples/gpt/config.py b/examples/gpt/gpt/config.py similarity index 92% rename from examples/gpt/config.py rename to examples/gpt/gpt/config.py index 2779cd8..c706fb1 100644 --- a/examples/gpt/config.py +++ b/examples/gpt/gpt/config.py @@ -16,9 +16,7 @@ class GPTConfig: def __post_init__(self): type_given = self.model_type is not None - params_given = all( - (self.n_layer is not None, self.n_head is not None, self.n_embd is not None) - ) + params_given = all((self.n_layer is not None, self.n_head is not None, self.n_embd is not None)) assert type_given ^ params_given if type_given: # translate from model_type to detailed configuration diff --git a/examples/gpt/gpt/dataset.py b/examples/gpt/gpt/dataset.py new file mode 100644 index 0000000..05da89e --- /dev/null +++ b/examples/gpt/gpt/dataset.py @@ -0,0 +1,54 @@ +# MIT License +# +# Copyright (c) 2020 Andrej Karpathy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import torch +from torch.utils.data import Dataset + + +class CharDataset(Dataset): + def __init__(self, textfile, block_size): + self.data = open(textfile, "r").read() + self.block_size = block_size + chars = sorted(list(set(self.data))) + self.stoi = {ch: i for i, ch in enumerate(chars)} + self.itos = {i: ch for i, ch in enumerate(chars)} + self.vocab_size = len(chars) + + def get_vocab_size(self): + return self.vocab_size + + def get_block_size(self): + return self.block_size + + def __len__(self): + return len(self.data) - self.block_size + + def __getitem__(self, idx): + # grab a chunk of (block_size + 1) characters from the data + chunk = self.data[idx : idx + self.block_size + 1] + # encode every character to an integer + dix = [self.stoi[s] for s in chunk] + # return as tensors + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + return x, y diff --git a/examples/gpt/model.py b/examples/gpt/gpt/model.py similarity index 81% rename from examples/gpt/model.py rename to examples/gpt/gpt/model.py index 659ecb8..186db0f 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/gpt/model.py @@ -1,3 +1,25 @@ +# MIT License +# +# Copyright (c) 2020 Andrej Karpathy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + """ Full definition of a GPT Language Model, all of it in this single file. @@ -9,7 +31,6 @@ """ import math -from dataclasses import dataclass import torch import torch.nn as nn @@ -23,16 +44,7 @@ class NewGELU(nn.Module): """ def forward(self, x): - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) - ) - ) - ) + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) class CausalSelfAttention(nn.Module): @@ -71,15 +83,9 @@ def forward(self, x): # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose( - 1, 2 - ) # (B, nh, T, hs) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) @@ -87,9 +93,7 @@ def forward(self, x): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = ( - y.transpose(1, 2).contiguous().view(B, T, C) - ) # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -144,9 +148,7 @@ def __init__(self, config): self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) - ) + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) @property def num_parameters(self): @@ -197,14 +199,10 @@ def configure_optimizers(self, train_config): param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) assert ( len(param_dict.keys() - union_params) == 0 - ), "parameters %s were not separated into either decay/no_decay set!" % ( - str(param_dict.keys() - union_params), - ) + ), "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),) # create the pytorch optimizer object optim_groups = [ @@ -217,26 +215,18 @@ def configure_optimizers(self, train_config): "weight_decay": 0.0, }, ] - optimizer = torch.optim.AdamW( - optim_groups, lr=train_config.learning_rate, betas=train_config.betas - ) + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) return optimizer def forward(self, idx, targets=None): device = idx.device b, t = idx.size() - assert ( - t <= self.block_size - ), f"Cannot forward sequence of length {t}, block size is only {self.block_size}" - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( - 0 - ) # shape (1, t) + assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) # forward the GPT model itself tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe( - pos - ) # position embeddings of shape (1, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -246,16 +236,12 @@ def forward(self, idx, targets=None): # if we are given some desired targets also calculate the loss loss = None if targets is not None: - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 - ) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) return logits, loss @torch.no_grad() - def generate( - self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None - ): + def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): """ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete the sequence max_new_tokens times, feeding the predictions back into the model each time. @@ -263,9 +249,7 @@ def generate( """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at block_size - idx_cond = ( - idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :] - ) + idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :] # forward the model to get the logits for the index in the sequence logits, _ = self(idx_cond) # pluck the logits at the final step and scale by desired temperature diff --git a/examples/gpt/train.py b/examples/gpt/train.py index 8beec70..0752706 100644 --- a/examples/gpt/train.py +++ b/examples/gpt/train.py @@ -4,47 +4,24 @@ import functools import time - import torch +from gpt.config import GPTConfig, TrainerConfig +from gpt.dataset import CharDataset +from gpt.model import Block, GPT from lightning_lite import seed_everything from lightning_lite.lite import LightningLite -from lightning_lite.strategies.fsdp import FSDPStrategy from lightning_lite.strategies import STRATEGY_REGISTRY +from lightning_lite.strategies.fsdp import FSDPStrategy from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper) + apply_activation_checkpointing, + checkpoint_wrapper, + CheckpointImpl, +) from torch.distributed.fsdp import BackwardPrefetch, CPUOffload from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader -from model import GPT, Block -from config import GPTConfig, TrainerConfig - - -model_config = GPTConfig( - model_type="gpt2-xl", - vocab_size=None, - block_size=128, - embd_pdrop=0.1, - resid_pdrop=0.1, - attn_pdrop=0.1, -) - - -trainer_config = TrainerConfig( - num_workers=4, - max_iters=100, - block_size=128, - batch_size=64, - learning_rate=3e-4, - betas=(0.9, 0.95), - weight_decay=0.1, # only applied on matmul weights - grad_norm_clip=1.0, -) - -auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls={Block} - ) +auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) check_fn = lambda submodule: isinstance(submodule, Block) wrapper = functools.partial( checkpoint_wrapper, @@ -52,46 +29,33 @@ checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) STRATEGY_REGISTRY.register( - name="fsdp-gpt", - strategy=FSDPStrategy, - description="FSDP strategy with memory optimizations enabled for GPT large scale pretraining.", - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=BackwardPrefetch.BACKWARD_PRE + name="fsdp-gpt", + strategy=FSDPStrategy, + description="FSDP strategy with memory optimizations enabled for GPT large scale pretraining.", + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, ) -class CharDataset(Dataset): - def __init__(self, textfile, block_size): - self.data = open(textfile, "r").read() - self.block_size = block_size - chars = sorted(list(set(self.data))) - self.stoi = {ch: i for i, ch in enumerate(chars)} - self.itos = {i: ch for i, ch in enumerate(chars)} - self.vocab_size = len(chars) - - - def get_vocab_size(self): - return self.vocab_size - - def get_block_size(self): - return self.block_size - - def __len__(self): - return len(self.data) - self.block_size - - def __getitem__(self, idx): - # grab a chunk of (block_size + 1) characters from the data - chunk = self.data[idx : idx + self.block_size + 1] - # encode every character to an integer - dix = [self.stoi[s] for s in chunk] - # return as tensors - x = torch.tensor(dix[:-1], dtype=torch.long) - y = torch.tensor(dix[1:], dtype=torch.long) - return x, y - - def main(): - seed_everything(trainer_config.seed) + model_config = GPTConfig( + model_type="gpt2-xl", + vocab_size=None, + block_size=128, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) + trainer_config = TrainerConfig( + num_workers=4, + max_iters=100, + block_size=128, + batch_size=64, + learning_rate=3e-4, + betas=(0.9, 0.95), + weight_decay=0.1, # only applied on matmul weights + grad_norm_clip=1.0, + ) # TODO: precision 16 and cpu offload hangs lite = LightningLite( @@ -101,12 +65,15 @@ def main(): strategy="fsdp-gpt", # num_nodes=2, ) - # lite.launch() + lite.launch() + train(lite, model_config, trainer_config) + + +def train(lite, model_config, trainer_config): + seed_everything(trainer_config.seed) # construct the training dataset train_dataset = CharDataset(textfile="data/tinyshakespeare.txt", block_size=model_config.block_size) - - # construct the model model_config.vocab_size = train_dataset.get_vocab_size() lite.print(model_config) @@ -120,15 +87,11 @@ def main(): lite.print(f"Number of parameters per device: {model.num_parameters / 1e6:.1f} M") lite.print(f"Total number of parameters: ~ {lite.world_size * model.num_parameters / 1e6:.1f} M") - apply_activation_checkpointing( - model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn - ) + apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) # TODO: support multiple param groups for FSDP # optimizer = model.configure_optimizers(config.trainer) - optimizer = torch.optim.AdamW( - model.parameters(), lr=trainer_config.learning_rate, betas=trainer_config.betas - ) + optimizer = torch.optim.AdamW(model.parameters(), lr=trainer_config.learning_rate, betas=trainer_config.betas) optimizer = lite.setup_optimizers(optimizer) train_loader = DataLoader( @@ -160,15 +123,11 @@ def main(): _, loss = model(x, y) model.zero_grad(set_to_none=True) lite.backward(loss) - torch.nn.utils.clip_grad_norm_( - model.parameters(), trainer_config.grad_norm_clip - ) + torch.nn.utils.clip_grad_norm_(model.parameters(), trainer_config.grad_norm_clip) optimizer.step() if iteration % 10 == 0: - lite.print( - f"iteration time {iter_dt * 1000:.2f}ms; iteration {iteration}; train loss {loss.item():.5f}" - ) + lite.print(f"iteration time {iter_dt * 1000:.2f}ms; iteration {iteration}; train loss {loss.item():.5f}") iteration += 1 tnow = time.time() diff --git a/examples/gpt/train_cloud.py b/examples/gpt/train_cloud.py index 5769fc1..8df0fa1 100644 --- a/examples/gpt/train_cloud.py +++ b/examples/gpt/train_cloud.py @@ -1,14 +1,43 @@ # ! PACKAGE_NAME=lite pip install git+https://github.com/Lightning-AI/lightning # ! cd data && bash download-data.sh -from lightning import LightningApp, LightningWork, CloudCompute +from gpt.config import GPTConfig, TrainerConfig +from lightning import CloudCompute, LightningApp, LightningWork from lightning.app.components import LiteMultiNode -from train import main +from lightning_lite.lite import LightningLite +from train import train class Work(LightningWork): def run(self): - main() + model_config = GPTConfig( + model_type="gpt2-xl", + vocab_size=None, + block_size=128, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + ) + trainer_config = TrainerConfig( + num_workers=4, + max_iters=100, + block_size=128, + batch_size=64, + learning_rate=3e-4, + betas=(0.9, 0.95), + weight_decay=0.1, + grad_norm_clip=1.0, + ) + + lite = LightningLite( + accelerator="cuda", + devices=-1, + precision=16, + strategy="fsdp-gpt", + num_nodes=2, # TODO: Let MultiNode component set this value automatically + ) + lite.launch() + train(lite, model_config, trainer_config) app = LightningApp( @@ -17,4 +46,4 @@ def run(self): num_nodes=2, cloud_compute=CloudCompute(name="gpu-fast-multi"), ) -) \ No newline at end of file +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d002ae9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,10 @@ +[tool.isort] +profile = "black" +line_length = 120 +force_sort_within_sections = "False" +order_by_type = "False" + + +[tool.black] +line-length = 120 +exclude = '(_notebooks/.*)' From e9d5157babd755eb87d7fae18436132699adcdbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Dec 2022 06:48:34 +0100 Subject: [PATCH 3/7] examples --- examples/gpt/train.py | 1 - examples/gpt/train_cloud.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/examples/gpt/train.py b/examples/gpt/train.py index 0752706..09bed40 100644 --- a/examples/gpt/train.py +++ b/examples/gpt/train.py @@ -63,7 +63,6 @@ def main(): devices=-1, precision=16, strategy="fsdp-gpt", - # num_nodes=2, ) lite.launch() train(lite, model_config, trainer_config) diff --git a/examples/gpt/train_cloud.py b/examples/gpt/train_cloud.py index 8df0fa1..a4af303 100644 --- a/examples/gpt/train_cloud.py +++ b/examples/gpt/train_cloud.py @@ -28,7 +28,6 @@ def run(self): weight_decay=0.1, grad_norm_clip=1.0, ) - lite = LightningLite( accelerator="cuda", devices=-1, @@ -36,7 +35,6 @@ def run(self): strategy="fsdp-gpt", num_nodes=2, # TODO: Let MultiNode component set this value automatically ) - lite.launch() train(lite, model_config, trainer_config) From 4277b5037572ecb58a69560de92de956a3faac01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Dec 2022 06:53:01 +0100 Subject: [PATCH 4/7] update --- .gitignore | 3 +++ examples/gpt/.lightning | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) delete mode 100644 examples/gpt/.lightning diff --git a/.gitignore b/.gitignore index 7b24361..e2bfe36 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ __pycache__/ .env .pylintrc *.egg-info + + +.lightning \ No newline at end of file diff --git a/examples/gpt/.lightning b/examples/gpt/.lightning deleted file mode 100644 index f3e862b..0000000 --- a/examples/gpt/.lightning +++ /dev/null @@ -1,2 +0,0 @@ -cluster_id: litng-ai-03 -name: modest-bardeen-5468 From e3c07b4b44d4c045fa3c96184ee45d1b8f4986bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 5 Dec 2022 07:48:29 +0100 Subject: [PATCH 5/7] update --- examples/gpt/requirements.txt | 2 ++ pyproject.toml | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index e69de29..fae9a4f 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.13.0 +lightning diff --git a/pyproject.toml b/pyproject.toml index d002ae9..1b15ae7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,4 +7,3 @@ order_by_type = "False" [tool.black] line-length = 120 -exclude = '(_notebooks/.*)' From ff06be8dd35396d0c97a45ff508425251baa866e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Dec 2022 03:28:19 +0100 Subject: [PATCH 6/7] simplify --- examples/gpt/requirements.txt | 2 +- examples/gpt/train.py | 16 ++-------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index fae9a4f..0c56f0f 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,2 +1,2 @@ torch>=1.13.0 -lightning +https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip diff --git a/examples/gpt/train.py b/examples/gpt/train.py index 09bed40..57de7b0 100644 --- a/examples/gpt/train.py +++ b/examples/gpt/train.py @@ -12,27 +12,17 @@ from lightning_lite.lite import LightningLite from lightning_lite.strategies import STRATEGY_REGISTRY from lightning_lite.strategies.fsdp import FSDPStrategy -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - apply_activation_checkpointing, - checkpoint_wrapper, - CheckpointImpl, -) -from torch.distributed.fsdp import BackwardPrefetch, CPUOffload +from torch.distributed.fsdp import BackwardPrefetch from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data.dataloader import DataLoader auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) -check_fn = lambda submodule: isinstance(submodule, Block) -wrapper = functools.partial( - checkpoint_wrapper, - offload_to_cpu=False, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, -) STRATEGY_REGISTRY.register( name="fsdp-gpt", strategy=FSDPStrategy, description="FSDP strategy with memory optimizations enabled for GPT large scale pretraining.", auto_wrap_policy=auto_wrap_policy, + activation_checkpointing=[Block], backward_prefetch=BackwardPrefetch.BACKWARD_PRE, ) @@ -86,8 +76,6 @@ def train(lite, model_config, trainer_config): lite.print(f"Number of parameters per device: {model.num_parameters / 1e6:.1f} M") lite.print(f"Total number of parameters: ~ {lite.world_size * model.num_parameters / 1e6:.1f} M") - apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn) - # TODO: support multiple param groups for FSDP # optimizer = model.configure_optimizers(config.trainer) optimizer = torch.optim.AdamW(model.parameters(), lr=trainer_config.learning_rate, betas=trainer_config.betas) From e13eab6c6612a6c995aa3dc89e9d5fb141b2c2a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 7 Dec 2022 03:31:07 +0100 Subject: [PATCH 7/7] simplify --- .gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index e2bfe36..511fb31 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,4 @@ __pycache__/ .env .pylintrc *.egg-info - - -.lightning \ No newline at end of file +.lightning