[~] Refactor
This commit is contained in:
parent
e60b290a00
commit
c40954472c
@ -247,12 +247,16 @@ def kernel_3(
|
|||||||
return roc_auc
|
return roc_auc
|
||||||
|
|
||||||
# %% [code]
|
# %% [code]
|
||||||
|
if os.path.exists('model.h5'):
|
||||||
|
o_2['model'].load_weights('model.h5')
|
||||||
|
else:
|
||||||
o_2['model'].fit(
|
o_2['model'].fit(
|
||||||
o_2['xtrain_pad'],
|
o_2['xtrain_pad'],
|
||||||
o_2['ytrain'],
|
o_2['ytrain'],
|
||||||
nb_epoch=5,
|
nb_epoch=5,
|
||||||
batch_size=64*o_2['strategy'].num_replicas_in_sync
|
batch_size=64*o_2['strategy'].num_replicas_in_sync
|
||||||
) #Multiplying by Strategy to run on TPU's
|
) #Multiplying by Strategy to run on TPU's
|
||||||
|
o_2['model'].save_weights('model.h5')
|
||||||
|
|
||||||
# %% [code]
|
# %% [code]
|
||||||
scores = o_2['model'].predict(o_2['xvalid_pad'])
|
scores = o_2['model'].predict(o_2['xvalid_pad'])
|
||||||
|
Loading…
Reference in New Issue
Block a user