카테고리 없음

[오늘의 코드] vit sam과 부분 전이학습 (gpt활용한 공부)

수닝이 2025. 2. 24. 19:18

오늘의 코드 : vit sam

[문제상황]

코드를 실행시키면 bbox인데도 예측값이 3개밖에 안나오는 상황 발생 -> 이를 해결하기 위해 열심히 gpt한테 물어물어 보았다.

그렇다 num_multimask_outputs 이 부분 3-> 4로 바꾸면 된단다.

 

그렇게 또 다른 에러를 만났다.

 

정확한 에러 메시지는 없지만 이유는 사전학습된 부분과 현재 num_multimask_outputs 부분이 맞지 않는다. 그래서 사용된 코드가 아래코드다.

# Mask 후보 수(num_multimask_outputs)를 10개로 조정한 모델 사용 시:
efficientvit_sam = create_efficientvit_sam_model(name="efficientvit-sam-xl1", pretrained=False).to(device)

# pretrained weight 로드 (mask_decoder 제외)
state_dict = torch.load('efficientvit_sam_xl1.pt')
filtered_state_dict = {k: v for k, v in state_dict.items() if "mask_decoder" not in k}
efficientvit_sam.load_state_dict(filtered_state_dict, strict=False)

efficientvit_sam.train()


# 6️⃣ 손실 함수
criterion_seg = nn.BCEWithLogitsLoss()  # Segmentation Loss
criterion_bbox = nn.SmoothL1Loss()      # Bounding Box Loss

# 7️⃣ 옵티마이저 설정 (새로 추가된 head는 더 큰 학습률 적용)
optimizer = optim.AdamW([
    {"params": efficientvit_sam.image_encoder.parameters(), "lr": 1e-5},
    {"params": efficientvit_sam.prompt_encoder.parameters(), "lr": 1e-5},
    {"params": efficientvit_sam.mask_decoder.parameters(), "lr": 1e-4},  # 랜덤 초기화된 레이어에 큰 학습률
], weight_decay=1e-4)

# 8️⃣ Learning Rate Scheduler (Warm-up Scheduler 추가)
from torch.optim.lr_scheduler import LambdaLR

def warmup_scheduler(epoch, warmup_epochs=5):
    if epoch < warmup_epochs:
        return float(epoch) / float(max(1, warmup_epochs))
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_scheduler(epoch))

이 부분부터 시작한다,

 

중요한 부분 1 :  Pretrained = False

# Mask 후보 수(num_multimask_outputs)를 10개로 조정한 모델 사용 시:
efficientvit_sam = create_efficientvit_sam_model(name="efficientvit-sam-xl1", pretrained=False).to(device)

# pretrained weight 로드 (mask_decoder 제외)
state_dict = torch.load('efficientvit_sam_xl1.pt')
filtered_state_dict = {k: v for k, v in state_dict.items() if "mask_decoder" not in k}
efficientvit_sam.load_state_dict(filtered_state_dict, strict=False)

efficientvit_sam.train()

 

이 부분이 부분 전이 학습이다.

사전학습된 틀과 맞지않아 마지막 부분은 제외하고 초반 중간 부분만 가져온 것이다. 단점은 epoch을 늘려 오랜 시간 학습시켜줘야한다는 것이다.

 

중요한 부분 2 : Learning Rate가 다른 이유

# 7️⃣ 옵티마이저 설정 (새로 추가된 head는 더 큰 학습률 적용)
optimizer = optim.AdamW([
    {"params": efficientvit_sam.image_encoder.parameters(), "lr": 1e-5},
    {"params": efficientvit_sam.prompt_encoder.parameters(), "lr": 1e-5},
    {"params": efficientvit_sam.mask_decoder.parameters(), "lr": 1e-4},  # 랜덤 초기화된 레이어에 큰 학습률
], weight_decay=1e-4)

학습률이 작은 부분

image_encoder

- 이미지의 특징을 추출하는 백본(backbone).

  • pretrained된 모델로부터 좋은 표현을 학습한 상태로 제공되기 때문에, 너무 큰 학습률로 학습하면 기존에 학습된 표현이 손상될 수 있다. 이러한 이유로 학습률은 작게한다.

prompt_encoder

- 입력으로 제공되는 포인트, 박스와 같은 프롬프트를 인코딩하여 모델이 이해할 수 있는 형태로 만드는 부분.

  • pretrained,즉 사전학습된 상태라 기존 학습된 것이 손상될 우려가 있다. 이러한 이유로 학습률이 작게 설정된것이다.

학습률이 큰 부분

mask_decoder

- 실제 마스크(mask)를 예측하는 모델의 최종 출력 부분이자 네가 후보 마스크 수(num_multimask_outputs)를 조정하면 새로 추가되거나 변경된 레이어들이 포함된 부분이다.

 

  • 처음 부터 새로 학습해야하므로 큰 학습률이 필요하다.

 

중요한 부분 3 : warmup_scheduler

# 8️⃣ Learning Rate Scheduler (Warm-up Scheduler 추가)
from torch.optim.lr_scheduler import LambdaLR

def warmup_scheduler(epoch, warmup_epochs=5):
    if epoch < warmup_epochs:
        return float(epoch) / float(max(1, warmup_epochs))
    return 1.0

scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_scheduler(epoch))

특정 epoch = n 일때 0~1사이를 n번 분할한다. 1/n, 2/n ... n/n -> 이렇게 된다. 1/n ~ n-1/n 은 0 ~ n-1까지의 epoch에 해당할때 학습률이다. 정확히 이야기하면, 미리 설정한 학습률에 곱해지는 비율이라고 보면 된다.

예를 들어 learning rate : 0.001이라면 0.001* 1/n , 0.001*2/n ...이렇게되는 것이다. n번째부턴 1이 곱해져 기존에 설정된 learning rate가 사용된다.

 

기록을 남기는 이유

요즘 chat gpt에 의존하는 경향이 높아졌다. "해결"을 gpt가 해주면 나는 해답을 "공부"하고 "숙지"해야한다는 생각이 들었다. 안 그러면 내 실력이 정체되고 발전이 없을것이라고 생각되었다.

특히 오늘 코드는 gpt를 활용하여 많은 공부를 하게되었다. 복습을 안 하면 소용이 없을 뿐더러 오래 기억하고 싶었기에 남긴다.

앞으로도 gpt를 많이 사용하고 공부가 많이되었다 생각되면 간단하게라도 기록을 남길 생각이다.

 

 

코드 출처

https://github.com/mit-han-lab/efficientvit

 

GitHub - mit-han-lab/efficientvit: Efficient vision foundation models for high-resolution generation and perception.

Efficient vision foundation models for high-resolution generation and perception. - mit-han-lab/efficientvit

github.com

 

관련논문

EFFICIENTVIT