diff --git a/python/tasks/jigsaw_toxic.py b/python/tasks/jigsaw_toxic.py index bf1f6b2..56298c8 100644 --- a/python/tasks/jigsaw_toxic.py +++ b/python/tasks/jigsaw_toxic.py @@ -247,12 +247,16 @@ def kernel_3( return roc_auc # %% [code] - o_2['model'].fit( - o_2['xtrain_pad'], - o_2['ytrain'], - nb_epoch=5, - batch_size=64*o_2['strategy'].num_replicas_in_sync - ) #Multiplying by Strategy to run on TPU's + if os.path.exists('model.h5'): + o_2['model'].load_weights('model.h5') + else: + o_2['model'].fit( + o_2['xtrain_pad'], + o_2['ytrain'], + nb_epoch=5, + batch_size=64*o_2['strategy'].num_replicas_in_sync + ) #Multiplying by Strategy to run on TPU's + o_2['model'].save_weights('model.h5') # %% [code] scores = o_2['model'].predict(o_2['xvalid_pad'])