[go: up one dir, main page]

Skip to content

Commit

Permalink
move instantiation of DataLoader
Browse files: browse the repository at this point in the history
  • Loading branch information
waynemystir committed Nov 20, 2020
1 parent 4050db6 commit 8fcaafb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.ipynb_checkpoints/
__pycache__/
*.swp
29 changes: 20 additions & 9 deletions mingpt/trainer.py
Original file line number Diff line number Diff line change
@@ -61,13 +61,8 @@ def train(self):
raw_model = model.module if hasattr(self.model, "module") else model
optimizer = raw_model.configure_optimizers(config)

def run_epoch(split):
is_train = split == 'train'
def run_epoch(loader, is_train=True):
model.train(is_train)
data = self.train_dataset if is_train else self.test_dataset
loader = DataLoader(data, shuffle=True, pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers)

losses = []
pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
@@ -117,11 +112,27 @@ def run_epoch(split):

best_loss = float('inf')
self.tokens = 0 # counter used for learning rate decay
for epoch in range(config.max_epochs):

run_epoch('train')
train_loader = DataLoader(
self.train_dataset,
shuffle=True,
pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers
)
if self.test_dataset is not None:
test_loader = DataLoader(
self.test_dataset,
shuffle=True,
pin_memory=True,
batch_size=config.batch_size,
num_workers=config.num_workers
)

for epoch in range(config.max_epochs):
run_epoch(train_loader)
if self.test_dataset is not None:
test_loss = run_epoch('test')
test_loss = run_epoch(test_loader, is_train=False)

# supports early stopping based on the test loss, or just save always if no test set is provided
good_model = self.test_dataset is None or test_loss < best_loss
Expand Down

0 comments on commit 8fcaafb

Please sign in to comment.