Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions PaddleNLP/language_model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@

SEED = 123

class TimeCostAverage(object):
    """Running-average accumulator for per-batch time costs.

    Typical use in a training loop: call record() once per batch with the
    elapsed time, read get_average() when logging, then reset() to start a
    fresh averaging window.
    """

    def __init__(self):
        # Begin with an empty measurement window.
        self.reset()

    def reset(self):
        """Drop all recorded samples and start a new window."""
        self.cnt, self.total_time = 0, 0

    def record(self, usetime):
        """Add one elapsed-time sample to the current window."""
        self.cnt += 1
        self.total_time += usetime

    def get_average(self):
        """Return the mean of recorded samples, or 0 if none were recorded."""
        return self.total_time / self.cnt if self.cnt else 0

@contextlib.contextmanager
def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'):
Expand Down Expand Up @@ -293,8 +306,10 @@ def train_an_epoch(epoch_id, batch_times):

total_loss = 0
iters = 0
batch_cost_avg = TimeCostAverage()

init_hidden, init_cell = generate_init_data()
batch_start_time = time.time()
for batch_id, batch in enumerate(train_data_iter):
input_data_feed = prepare_input(
batch,
Expand All @@ -303,7 +318,6 @@ def train_an_epoch(epoch_id, batch_times):
epoch_id=epoch_id,
with_lr=True,
device_count=device_count)
batch_start_time = time.time()
fetch_outs = exe.run(train_program,
feed=input_data_feed,
fetch_list=[
Expand All @@ -313,6 +327,7 @@ def train_an_epoch(epoch_id, batch_times):
use_program_cache=True)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
batch_cost_avg.record(batch_time)

cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
Expand All @@ -324,13 +339,17 @@ def train_an_epoch(epoch_id, batch_times):
ppl = np.exp(total_loss / iters)
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()

# profiler tools for benchmark
if args.profile and batch_id == log_interval:
profiler.reset_profiler()
elif args.profile and batch_id == (log_interval + 5):
break

batch_start_time = time.time()

ppl = np.exp(total_loss / iters)
return ppl

Expand All @@ -342,6 +361,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):

total_loss = 0
iters = 0
batch_cost_avg = TimeCostAverage()

dataloader.start()
batch_id = 0
Expand All @@ -355,6 +375,7 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
batch_start_time = time.time()
batch_cost_avg.record(batch_time)

new_lr = generate_new_lr(epoch_id, device_count)
data_feeds['learning_rate'] = new_lr
Expand All @@ -381,7 +402,8 @@ def train_an_epoch_dataloader(epoch_id, batch_times):
ppl = np.exp(total_loss / iters)
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f"
% (epoch_id, batch_id, batch_time, ppl[0], lr[0]))
% (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0]))
batch_cost_avg.reset()

batch_id += 1
# profiler tools for benchmark
Expand Down