
Training SegFormer model not working on a dataset I created: the notebook runs, but the loss becomes NaN (stuck for a week or so) #459

Open
realharryhero opened this issue Dec 30, 2023 · 5 comments

Comments

@realharryhero

realharryhero commented Dec 30, 2023

When I train a SegFormer model with this notebook after changing the variable ds to one of the contrails datasets I have been uploading to Hugging Face, such as this one, the model's loss becomes NaN (and, though I'm not sure, training sometimes seems to crash after the first epoch).

This does not happen when training on Segments.ai's sidewalk dataset. It may be related to differences in my segmentation bitmaps or to issues with the DuckDB files (the DuckDB files in the sidewalk dataset seem to be formatted differently from those in my contrails dataset).
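One way to test the bitmap hypothesis is to count which pixel values actually occur in each dataset's masks. A minimal NumPy sketch, assuming the masks can be loaded as 2-D integer arrays; `mask_value_histogram` and the toy arrays are hypothetical stand-ins, not part of the notebook:

```python
from collections import Counter

import numpy as np


def mask_value_histogram(masks):
    """Count how many pixels take each value across an iterable of 2-D masks."""
    counts = Counter()
    for mask in masks:
        values, freqs = np.unique(np.asarray(mask), return_counts=True)
        # Accumulate per-value pixel counts across all masks.
        counts.update(dict(zip(values.tolist(), freqs.tolist())))
    return dict(sorted(counts.items()))


# Hypothetical toy masks standing in for the two datasets' segmentation bitmaps.
sidewalk_like = [np.array([[0, 1], [2, 3]], dtype=np.uint8)]
contrails_like = [np.array([[0, 1], [255, 255]], dtype=np.uint8)]

print(mask_value_histogram(sidewalk_like))   # {0: 1, 1: 1, 2: 1, 3: 1}
print(mask_value_histogram(contrails_like))  # {0: 1, 1: 1, 255: 2}
```

A value far outside the expected class range (such as 255 here) stands out immediately in this kind of summary.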

Why does this occur?

(I obtained the contrails images from this competition's dataset.)

@realharryhero
Author

@sayakpaul

@sayakpaul
Member

Try lowering the learning rate.

@realharryhero
Author

The model's loss still becomes NaN even with a learning rate 10x (maybe 1000x?) lower than the notebook's original value. A few errors also occur; a screenshot and the error text are below.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[30], line 1
----> 1 model.fit(
      2     train_set,
      3     validation_data=val_set,
      4     callbacks=callbacks,
      5     epochs=epochs,
      6 )

File ~/jupyter/miniconda3/envs/tf3.10new/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File ~/jupyter/miniconda3/envs/tf3.10new/lib/python3.10/site-packages/transformers/keras_callbacks.py:256, in KerasMetricCallback.on_epoch_end(self, epoch, logs)
    253 all_preds = self._postprocess_predictions_or_labels(prediction_list)
    254 all_labels = self._postprocess_predictions_or_labels(label_list)
--> 256 metric_output = self.metric_fn((all_preds, all_labels))
    257 if not isinstance(metric_output, dict):
    258     raise TypeError(
    259         f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
    260     )

Cell In[27], line 29
     25 per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
     26 per_category_iou = metrics.pop("per_category_iou").tolist()
     28 metrics.update(
---> 29     {f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)}
     30 )
     31 metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
     32 return {"val_" + k: v for k, v in metrics.items()}

Cell In[27], line 29
     25 per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
     26 per_category_iou = metrics.pop("per_category_iou").tolist()
     28 metrics.update(
---> 29     {f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)}
     30 )
     31 metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
     32 return {"val_" + k: v for k, v in metrics.items()}

KeyError: 2
[Screenshot: SegFormer model not training]
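The `KeyError: 2` is raised inside the metric function's dict comprehension: `enumerate(per_category_accuracy)` reached index 2, but `id2label` has no key 2. That suggests the metric returned more per-category values than `id2label` has entries. A pure-Python reproduction, assuming a two-entry `id2label` (the actual mapping isn't shown in the traceback); `label_per_category` is a hypothetical helper that fails with a more descriptive message instead:

```python
def label_per_category(values, id2label, prefix):
    """Name per-category metric values by class, failing with a clear message
    if id2label does not cover every class index the metric returned."""
    missing = [i for i in range(len(values)) if i not in id2label]
    if missing:
        raise ValueError(
            f"id2label has no entry for class IDs {missing}; "
            f"the metric returned {len(values)} per-category values"
        )
    return {f"{prefix}_{id2label[i]}": v for i, v in enumerate(values)}


per_category_iou = [0.9, 0.5, 0.1]  # hypothetical values for class IDs 0, 1, 2

# Incomplete mapping: the same dict comprehension as in the traceback fails on index 2.
short_id2label = {0: "unlabeled", 1: "filler"}
try:
    {f"iou_{short_id2label[i]}": v for i, v in enumerate(per_category_iou)}
except KeyError as err:
    print("KeyError:", err)  # KeyError: 2

# A mapping that covers every class ID works.
full_id2label = {0: "unlabeled", 1: "filler", 2: "contrail"}
print(label_per_category(per_category_iou, full_id2label, "iou"))
```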

@realharryhero
Author

I think I figured it out. In the label files I used, pixel value 255 marked contrails, pixel value 1 marked another ("filler") class, and pixel value 0 marked unlabeled pixels. I believe contrails needed to be pixel value 2 instead, so that the class IDs are contiguous (0, 1, 2, 3, ...).
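A minimal NumPy sketch of that remapping, assuming the raw values are exactly 0, 1 and 255 as described; `RAW_TO_CLASS` and `remap_mask` are hypothetical names. It builds a 256-entry lookup table so every pixel is converted in one indexing step:

```python
import numpy as np

# Assumed raw encoding from the comment above: 0 = unlabeled, 1 = filler, 255 = contrail.
RAW_TO_CLASS = {0: 0, 1: 1, 255: 2}


def remap_mask(mask, mapping=RAW_TO_CLASS):
    """Convert raw pixel values to contiguous class IDs via a 256-entry lookup table.

    Raises ValueError on any value not in `mapping`, so unexpected labels surface early.
    """
    mask = np.asarray(mask)
    unknown = sorted(set(np.unique(mask).tolist()) - set(mapping))
    if unknown:
        raise ValueError(f"unexpected pixel values in mask: {unknown}")
    lut = np.zeros(256, dtype=np.uint8)
    for raw_value, class_id in mapping.items():
        lut[raw_value] = class_id
    # Fancy indexing applies the table to every pixel; the result keeps the mask's shape.
    return lut[mask]


raw = np.array([[0, 1], [255, 255]], dtype=np.uint8)
print(remap_mask(raw).tolist())  # [[0, 1], [2, 2]]
```

After remapping, `id2label` would also need an entry for class 2 (for example `{0: "unlabeled", 1: "filler", 2: "contrail"}`, names assumed), giving the model three labels.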

So this is sort of "closed," but this is a very dumb issue. Is there any way to prevent it in the future? It shouldn't take long to change a few bits of code, especially since I was stuck on this for a week and a half.
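For reference, TensorFlow's documentation for `tf.nn.sparse_softmax_cross_entropy_with_logits` says each label must be an index in `[0, num_classes)` and that other values raise an exception on CPU but return NaN for the corresponding loss rows on GPU, which would be consistent with a NaN loss appearing instead of an explicit error. A fail-fast check over the masks before `model.fit` could surface the problem immediately. A minimal sketch; `check_mask_labels` is a hypothetical helper, not part of the notebook or of `transformers`:

```python
import numpy as np


def check_mask_labels(mask, num_labels):
    """Raise a descriptive ValueError if `mask` contains class IDs outside [0, num_labels)."""
    values = np.unique(np.asarray(mask))
    bad = values[(values < 0) | (values >= num_labels)].tolist()
    if bad:
        raise ValueError(
            f"mask contains class IDs {bad}, but num_labels={num_labels}; "
            f"expected contiguous IDs 0..{num_labels - 1}"
        )


# The original contrails encoding fails fast; a contiguous encoding passes.
try:
    check_mask_labels(np.array([[0, 1], [255, 255]], dtype=np.uint8), num_labels=3)
except ValueError as err:
    print(err)
check_mask_labels(np.array([[0, 1], [2, 2]], dtype=np.uint8), num_labels=3)
```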

@realharryhero
Author

@sayakpaul
