Training and Running the SWIN Transformer Object Detection Method on a Custom Dataset
In computer vision, the SWIN Transformer has become a state-of-the-art object detection method. Many engineers and researchers who want to study it, or to run the code on a local machine or an online GPU, find the project difficult to set up because it requires matching versions of CUDA, MMCV, MMDetection, and other packages. This post is the step-by-step process I followed to set up the repository and run it successfully on our custom dataset.
Image Source (https://www.microsoft.com/en-us/research/blog/swin-transformer-supports-3-billion-parameter-vision-models-that-can-train-with-higher-resolution-images-for-greater-task-applicability/)
The SWIN Transformer Object Detection (paper link) (code link) method was published by the Microsoft Research team, and it still ranks among the top state-of-the-art object detection methods today. Below are the steps I followed to run their code on our custom dataset, including installation and editing the configuration files for object detection. Since many of us want to run the source code on custom datasets without reading all the details, I have compiled every issue I encountered while setting it up on my local computer, in the hope that it will be useful to others.
1. Create a new Conda environment
conda create --name torch110cu102 python=3.8 -y
conda activate torch110cu102
2. Install PyTorch, torchvision, and mmcv
pip install torch==1.10.0+cu102 torchvision==0.11.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.10.0/index.html
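The mmcv-full wheel must match both the CUDA and the PyTorch version, and the wheel index URL encodes both. As a quick sanity check, this small helper (my own sketch, not part of any package) rebuilds the index URL from a CUDA/PyTorch pair, mirroring the pattern in the command above:

```python
def mmcv_index_url(cuda_version: str, torch_version: str) -> str:
    """Build the OpenMMLab wheel index URL for a CUDA/PyTorch pair.

    Mirrors the URL pattern used in the pip command above, e.g.
    CUDA 10.2 + PyTorch 1.10.0 -> .../cu102/torch1.10.0/index.html
    """
    cuda_tag = "cu" + cuda_version.replace(".", "")
    return ("https://download.openmmlab.com/mmcv/dist/"
            f"{cuda_tag}/torch{torch_version}/index.html")

# -> https://download.openmmlab.com/mmcv/dist/cu102/torch1.10.0/index.html
print(mmcv_index_url("10.2", "1.10.0"))
```

If you use a different CUDA or PyTorch version, plug it into this pattern to find the matching mmcv-full wheel index.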
3. Clone code repository and install dependencies
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection.git STOD
cd STOD
pip install -v -e .
4. Install apex for mixed precision
git clone --recursive https://github.com/NVIDIA/apex
cd apex/
I ran into some issues after installing the apex library, so here is a fix: open
apex/apex/amp/utils.py and replace lines 93 to 97 with the following:
    if x in cache:
        cached_x = cache[x]
        next_functions_available = False
        if x.requires_grad and cached_x.requires_grad:
            if len(cached_x.grad_fn.next_functions) > 1:
                next_functions_available = True
            # Make sure x is actually cached_x's autograd parent.
            if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x:
Then make sure you are inside the apex root directory and run:
python setup.py install
5. Changing configuration file for object detection
There are a few changes you need to make in order to run this code on your own object detection dataset. The major changes are in the following files:
- DATASET CONFIGURATION
configs/_base_/datasets/coco_detection.py — you may create a copy of this file and edit the classes and other settings [I am adding my file for reference here]
- CONFIG FILE (I created a copy)
configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_256-256_giou_4conv1f_adamw_3x_ours.py [I am adding my file for reference here]
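For reference, here is a minimal sketch of the kinds of edits those two files need for a dataset like ours (5 classes, 256x256 images). The class names and data paths below are placeholders, not my actual values; also note that cascade R-CNN has three bbox heads, so num_classes must be changed in all three:

```python
# Sketch only -- hypothetical class names and data paths.
# In your copy of configs/_base_/datasets/coco_detection.py:
classes = ('class_1', 'class_2', 'class_3', 'class_4', 'class_5')
data = dict(
    train=dict(classes=classes,
               ann_file='data/ours/annotations/train.json',
               img_prefix='data/ours/images/train/'),
    val=dict(classes=classes,
             ann_file='data/ours/annotations/val.json',
             img_prefix='data/ours/images/val/'),
    test=dict(classes=classes,
              ann_file='data/ours/annotations/val.json',
              img_prefix='data/ours/images/val/'))

# In your copy of the swin config: edit num_classes in each of the
# three cascade bbox heads, keeping every other key of each head as-is
# (overriding a list in an mmdetection config replaces it wholesale):
#     bbox_head=[
#         dict(..., num_classes=5),
#         dict(..., num_classes=5),
#         dict(..., num_classes=5),
#     ]
```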
6. Finally, run your training configuration file
python tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_256-256_giou_4conv1f_adamw_3x_ours.py
Some issues and resolution links
- Comment out the configuration related to the mask
- Handle len(cached_x.grad_fn.next_functions) == 1 in cached_cast. (File "/opt/conda/lib/python3.9/site-packages/apex-0.1-py3.9.egg/apex/amp/utils.py", line 97, in cached_cast raises IndexError: tuple index out of range.) https://github.com/NVIDIA/apex/pull/1282/files/01802f623c9b54199566871b49f94b2d07c3f047
Please note that if you change the lines suggested by the link above, do not forget to rebuild apex by following step 4 again.
- RuntimeError: Default process group has not been initialized, please make sure to call init_process_group. This usually appears when a configuration written for distributed training (for example, one using SyncBN layers) is launched on a single GPU; replacing SyncBN with BN in the config, or launching through the distributed training script, resolves it.
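On the first point above: the cascade_mask_rcnn configs include instance-segmentation heads, so if your annotations contain only bounding boxes, the mask parts must be disabled. A hedged sketch of what that looks like in the copied config (the exact keys depend on the config you start from):

```python
# Sketch: box-only training with a cascade_mask_rcnn config.
model = dict(
    roi_head=dict(
        mask_roi_extractor=None,  # or comment these entries out
        mask_head=None))

# In the train pipeline, stop loading masks:
#     dict(type='LoadAnnotations', with_bbox=True, with_mask=False)
# and remove 'gt_masks' from the Collect step's keys.

# Evaluate boxes only:
evaluation = dict(metric='bbox')
```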
If your dataset is not in COCO format, which this repository expects it to be, please convert your dataset to MS COCO format first; a brief post on that is here.
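To illustrate the target format: an MS COCO detection annotation file is a single JSON document with images, annotations, and categories lists, where each box is stored as [x, y, width, height]. Below is a toy converter (my own sketch; the field names of the input records are hypothetical) from a flat list of box records to that structure:

```python
import json

def to_coco(records, class_names):
    """Convert flat box records into a minimal MS COCO detection dict.

    Each input record is assumed (hypothetically) to look like:
    {"file": "a.jpg", "width": 256, "height": 256,
     "label": "cat", "box": [x1, y1, x2, y2]}
    """
    cat_ids = {name: i + 1 for i, name in enumerate(class_names)}
    images, annotations, image_ids = [], [], {}
    for rec in records:
        if rec["file"] not in image_ids:
            image_ids[rec["file"]] = len(image_ids) + 1
            images.append({"id": image_ids[rec["file"]],
                           "file_name": rec["file"],
                           "width": rec["width"], "height": rec["height"]})
        x1, y1, x2, y2 = rec["box"]
        annotations.append({"id": len(annotations) + 1,
                            "image_id": image_ids[rec["file"]],
                            "category_id": cat_ids[rec["label"]],
                            "bbox": [x1, y1, x2 - x1, y2 - y1],  # x, y, w, h
                            "area": (x2 - x1) * (y2 - y1),
                            "iscrowd": 0})
    categories = [{"id": cid, "name": name} for name, cid in cat_ids.items()]
    return {"images": images, "annotations": annotations,
            "categories": categories}

coco = to_coco([{"file": "a.jpg", "width": 256, "height": 256,
                 "label": "cat", "box": [10, 10, 60, 40]}],
               ["cat", "dog"])
print(json.dumps(coco["annotations"][0]["bbox"]))  # -> [10, 10, 50, 30]
```

Write the returned dictionary out with json.dump and point ann_file in the dataset config at the resulting file.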
The model configuration file for single GPU training is also attached.
Some information about the dataset I have used: there are 5 classes, and each image is 256x256 pixels in height and width. The configuration files therefore contain attributes matching this, but you can easily customize them for your own dataset. Please do not hesitate to contact me if you have any further problems, questions, or suggestions. I will try to make the article better with your help.