diff --git a/python/tasks/jigsaw_toxic.py b/python/tasks/jigsaw_toxic.py index 56298c8..426fd93 100644 --- a/python/tasks/jigsaw_toxic.py +++ b/python/tasks/jigsaw_toxic.py @@ -231,7 +231,11 @@ def kernel_2(): def kernel_3( o_2, + nb_epochs=None, ): + if nb_epochs is None: + nb_epochs = 5 + # %% [markdown] # Writing a function for getting auc score for validation @@ -253,7 +257,7 @@ def kernel_3( o_2['model'].fit( o_2['xtrain_pad'], o_2['ytrain'], - nb_epoch=5, + nb_epoch=nb_epochs, batch_size=64*o_2['strategy'].num_replicas_in_sync ) #Multiplying by Strategy to run on TPU's o_2['model'].save_weights('model.h5')