본문 바로가기
딥러닝

TSM(Temporal Shift Module) train

by jennyiscoding 2024. 12. 10.

TSM은?

TSM (Temporal Shift Module)

 

프레임 수: 8-16 프레임

속도: 30-60 FPS (실시간 가능)

정확도: Kinetics-400에서 74-76%

PyTorch 구현: 공식 PyTorch 구현이 제공됨.

GitHub 링크: https://github.com/open-mmlab/mmaction2/blob/main/configs/recognition/tsn/README.md

ONNX 호환성: 모델 구조가 간단한 2D CNN 기반으로, ONNX로 손쉽게 변환 가능.

특징: 2D CNN 모델에 Temporal Shift Module을 추가하여 시간적 정보를 효율적으로 반영. 계산량이 적어 경량 디바이스에 적합.

Kinetics-400 모델

 

트레이닝 하기
python tools/train.py ${CONFIG_FILE} [optional arguments]
예:

python tools/train.py configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py  \
    --seed=0 --deterministic

 

내가 궁금한거:

트레인 목적은 뭔가?

config파일은 위에 표 위에서 골라서 넣는건가?

뭘 트레인 하는건가?

트레인하면 뭐가 나오는건가? 

 

1. 트레이닝의 주요 목적은:

  1. 행동 인식 모델 학습:
    비디오 데이터에서 사람이나 객체의 **행동(예: 걷기, 뛰기, 넘어짐)**을 인식할 수 있는 모델을 학습합니다.
  2. 모델 성능 최적화:
    모델이 학습 데이터를 기반으로 정확하게 행동을 분류하거나 예측할 수 있도록 가중치와 파라미터를 조정합니다.
  3. 일반화 성능 향상:
    새로운 비디오에서도 정확하게 행동을 인식할 수 있도록 모델을 학습합니다.

2. Config 파일은 어떻게 선택하는가?
Config 파일은 모델의 훈련 및 평가에 필요한 모든 설정을 담고 있습니다.

경로 예시:
configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py
Config 파일 선택 방법:
모델 유형에 따라 선택합니다:
TSN (Temporal Segment Network)
SlowFast
I3D (Inflated 3D ConvNet)
ST-GCN (Spatial Temporal Graph Convolutional Network)


데이터셋에 맞는 Config를 선택합니다:
Kinetics-400
UCF-101
HMDB-51


사전 훈련 여부:
ImageNet에서 사전 훈련된 모델을 사용하는지 여부에 따라 선택합니다.

 

3. 트레이닝 시 무엇을 학습하는가?
트레이닝을 통해 모델이 학습하는 것은:

입력 데이터의 특징:

비디오 프레임의 공간적 특징(예: 사람의 모양, 배경)
시간적 특징(예: 행동의 진행 상태)
행동 분류:

각 비디오 클립에서 어떤 행동이 일어나는지를 예측할 수 있게 됩니다.
모델 가중치:

신경망의 각 층에서 **최적의 가중치(Weights)와 편향(Bias)**을 학습합니다.

 

4. 트레이닝 결과로 무엇이 나오는가?

트레이닝이 완료되면 다음과 같은 결과물이 생성됩니다:

학습된 모델 체크포인트 (Checkpoint):

경로 예시:
work_dirs/tsn_imagenet-pretrained-r50/latest.pth
모델의 가중치와 파라미터가 저장된 파일입니다.
이 체크포인트를 사용해 추론(Inference) 또는 **추가 학습(Fine-Tuning)**을 할 수 있습니다.
로그 파일:

훈련 과정에서의 손실(loss), 정확도(accuracy), 학습률(learning rate) 등이 기록된 파일입니다.
모델 성능을 분석하고 디버깅하는 데 유용합니다.
시각화 결과 (TensorBoard 지원):

훈련 과정에서의 성능 지표를 그래프로 시각화할 수 있습니다.

 

 

내가 쓸 모델: ResNet101 Backbone + 1x1x8 Sampling Strategy

이유: chatgpt가 쓰라함

1x1x8 이것임. 맨 마지막꺼. 

 

동영상 추론 예제: 

python demo/demo_skeleton.py demo/fall.mp4 demo/output_fall.mp4 \ --config configs/recognition/tsn/tsn_imagenet-pretrained-r101_8xb32-1x1x8-100e_kinetics400-rgb.py \ --checkpoint checkpoints/tsn_imagenet-pretrained-r101_8xb32-1x1x8-100e_kinetics400-rgb.pth \ --device cuda:0

 

트레인(python tools/train.py configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py  \
    --seed=0 --deterministic) 돌렸는데, 에러가 떴음 --> 

FileNotFoundError: [Errno 2] No such file or directory: 'data/kinetics400/kinetics400_train_list_videos.txt'

 

이유: config파일에 kinetics400_train_list_videos.txt 넣는 부분이 있었기 때문이다. 

_base_ = [
    '../../_base_/models/tsn_r50.py', '../../_base_/schedules/sgd_100e.py',
    '../../_base_/default_runtime.py'
]

# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')

train_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=3),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.875, 0.75, 0.66),
        random_crop=False,
        max_wh_scale_gap=1),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
val_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
test_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='TenCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]

train_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=dict(video=data_root),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=dict(video=data_root_val),
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

default_hooks = dict(checkpoint=dict(interval=3, max_keep_ckpts=3))

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (32 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=256)

 

# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'

 

이거는 뭐 넣는걸까..?