군침이 싹 도는 코딩

Fine Tuning 본문

Python/Deep Learning

Fine Tuning

mugoori 2023. 1. 2. 17:11

Fine Tuning 이란 트랜스퍼 러닝을 한 결과를 더 향상 시키기위해 frozen 시켰던 학습 데이터를

다시 학습 가능 상태로 만든 뒤 원하는 만큼을 다시 frozen 해서 모델에 그대로 재 학습 시키는 것이다

 

 

 

 

base_model.trainable = True

# 학습을 다시 시키기위해 trainable 을 True로 바꿔준다

 

 

 

 

base_model.summary()
>>> 
                                                                                                  
 Conv_1 (Conv2D)                (None, 7, 7, 1280)   409600      ['block_16_project_BN[0][0]']    
                                                                                                  
 Conv_1_bn (BatchNormalization)  (None, 7, 7, 1280)  5120        ['Conv_1[0][0]']                 
                                                                                                  
 out_relu (ReLU)                (None, 7, 7, 1280)   0           ['Conv_1_bn[0][0]']              
                                                                                                  
==================================================================================================
Total params: 2,257,984
Trainable params: 2,223,872
Non-trainable params: 34,112
__________________________________________________________________________________________________

# 서머리를 찍어서 학습할 수 있는 데이터가 다시 돌아왔는지 확인해본다

 

 

 

 

len(base_model.layers)
>>> 154

# 레이어의 수를 확인한다

 

 

 

 

end_layer = 130

# 몇번째 레이어까지 frozen 시킬지를 결정한다

 

 

 

 

for layer in base_model.layers[0:end_layer+1] :
  layer.trainable = False

# for 문을 이용해 결정한 레이어 수까지 frozen을 해준다

 

 

 

 

base_model.summary()
>>>
 Conv_1 (Conv2D)                (None, 7, 7, 1280)   409600      ['block_16_project_BN[0][0]']    
                                                                                                  
 Conv_1_bn (BatchNormalization)  (None, 7, 7, 1280)  5120        ['Conv_1[0][0]']                 
                                                                                                  
 out_relu (ReLU)                (None, 7, 7, 1280)   0           ['Conv_1_bn[0][0]']              
                                                                                                  
==================================================================================================
Total params: 2,257,984
Trainable params: 1,360,000
Non-trainable params: 897,984
__________________________________________________________________________________________________

# 서머리를 찍어보면 학습가능한 데이터가 줄어든 것을 확인할 수 있다

 

 

 

 

model.compile(Adam(0.0001), 'categorical_crossentropy', ['accuracy'])

# 컴파일을 다시해준다

 

 

 

 

from keras import callbacks
epoch_history = model.fit(train_generator, epochs= 20, validation_data=(X_val,y_val), callbacks=[mcp, csv_logger], batch_size=64)
>>>
''''
Epoch 19/20
10/10 [==============================] - ETA: 0s - loss: 0.0123 - accuracy: 0.9950
Epoch 19: val_accuracy did not improve from 1.00000
10/10 [==============================] - 6s 565ms/step - loss: 0.0123 - accuracy: 0.9950 - val_loss: 0.0408 - val_accuracy: 0.9867
Epoch 20/20
10/10 [==============================] - ETA: 0s - loss: 0.0158 - accuracy: 0.9934
Epoch 20: val_accuracy did not improve from 1.00000
10/10 [==============================] - 6s 575ms/step - loss: 0.0158 - accuracy: 0.9934 - val_loss: 0.0227 - val_accuracy: 0.9867

# 학습을 시킨다 위 코드에서 사용된 콜백에

mcp와 csv_logger는 이글을 참고한다 : https://mugoori.tistory.com/151

 

Modelcheckpoint / CSVLogger 사용법

Modelcheckpoint는 에포크 시 마다 가장 좋은 모델을 저장한다 CSVLogger는 에포크 시 마다 기록을 남길 수 있다 if not os.path.exists(PROJECT_PATH + '/checkpoints/'+ model_type +'/') : os.makedirs(PROJECT_PATH + '/checkpoints/'+

mugoori.tistory.com