[~] Refactor

This commit is contained in:
Siarhei Siniak 2021-07-15 09:54:18 +03:00
parent c40954472c
commit e7f8a022cd

@ -231,7 +231,11 @@ def kernel_2():
def kernel_3( def kernel_3(
o_2, o_2,
nb_epochs=None,
): ):
if nb_epochs is None:
nb_epochs = 5
# %% [markdown] # %% [markdown]
# Writing a function for getting auc score for validation # Writing a function for getting auc score for validation
@ -253,7 +257,7 @@ def kernel_3(
o_2['model'].fit( o_2['model'].fit(
o_2['xtrain_pad'], o_2['xtrain_pad'],
o_2['ytrain'], o_2['ytrain'],
nb_epoch=5, nb_epoch=nb_epochs,
batch_size=64*o_2['strategy'].num_replicas_in_sync batch_size=64*o_2['strategy'].num_replicas_in_sync
) #Multiplying by Strategy to run on TPU's ) #Multiplying by Strategy to run on TPU's
o_2['model'].save_weights('model.h5') o_2['model'].save_weights('model.h5')