diff --git a/python/tasks/jigsaw_toxic.py b/python/tasks/jigsaw_toxic.py
index 56298c8..426fd93 100644
--- a/python/tasks/jigsaw_toxic.py
+++ b/python/tasks/jigsaw_toxic.py
@@ -231,7 +231,11 @@ def kernel_2():
 
 def kernel_3(
     o_2,
+    nb_epochs=None,
 ):
+    if nb_epochs is None:
+        nb_epochs = 5
+
     # %% [markdown]
     # Writing a function for getting auc score for validation
 
@@ -253,7 +257,7 @@ def kernel_3(
         o_2['model'].fit(
             o_2['xtrain_pad'],
             o_2['ytrain'],
-            nb_epoch=5,
+            nb_epoch=nb_epochs,
             batch_size=64*o_2['strategy'].num_replicas_in_sync
         ) #Multiplying by Strategy to run on TPU's
         o_2['model'].save_weights('model.h5')