[TensorFlow 2.0] Optimizer and Training (Expert)
IT/AI | 2022. 9. 8. 10:18
This post walks through the "Expert" version of the training workflow described on the official TensorFlow website.
Load Packages
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import datasets
Reviewing the Training Process
Build Model
input_shape = (28, 28, 1)
num_classes = 10
inputs = layers.Input(input_shape, dtype=tf.float64)
net = layers.Conv2D(32, (3, 3), padding='SAME')(inputs)
net = layers.Activation('relu')(net)
net = layers.Conv2D(32, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.5)(net)
net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.5)(net)
net = layers.Flatten()(net)
net = layers.Dense(512)(net)
net = layers.Activation('relu')(net)
net = layers.Dropout(0.5)(net)
net = layers.Dense(num_classes)(net)
net = layers.Activation('softmax')(net)
model = tf.keras.Model(inputs=inputs, outputs=net, name='Basic_CNN')
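To see where the tensor sizes in this model come from, the arithmetic can be redone by hand. The following is a minimal plain-Python sketch, not TensorFlow code: the helper names (conv_same, max_pool, dense) are made up for illustration, and it assumes stride 1 for the 'SAME'-padded convolutions and that Activation and Dropout layers change neither the shape nor the parameter count.

```python
# Plain-Python bookkeeping of shapes and parameter counts for the CNN above.
# Hypothetical helpers for illustration only; nothing here calls TensorFlow.

def conv_same(shape, filters, kernel=3):
    """3x3 convolution, 'SAME' padding, stride 1: height and width stay the same."""
    h, w, c = shape
    params = kernel * kernel * c * filters + filters  # weights + biases
    return (h, w, filters), params

def max_pool(shape, pool=2):
    """2x2 max pooling halves height and width and has no parameters."""
    h, w, c = shape
    return (h // pool, w // pool, c)

def dense(n_in, n_out):
    """Fully connected layer: n_in * n_out weights plus n_out biases."""
    return n_out, n_in * n_out + n_out

shape = (28, 28, 1)
total_params = 0

for filters in (32, 32):            # first conv block
    shape, p = conv_same(shape, filters)
    total_params += p
shape = max_pool(shape)             # (14, 14, 32)

for filters in (64, 64):            # second conv block
    shape, p = conv_same(shape, filters)
    total_params += p
shape = max_pool(shape)             # (7, 7, 64)

conv_output_shape = shape
flat = shape[0] * shape[1] * shape[2]   # size after Flatten

units, p = dense(flat, 512)
total_params += p
units, p = dense(units, 10)
total_params += p

print(conv_output_shape, flat, total_params)
# (7, 7, 64) 3136 1676266
```

Almost all of the parameters sit in the first Dense layer (3136 x 512 weights), which is typical for this kind of small CNN.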
Preprocess
The data is prepared the same way as in the expert tutorial on the official TensorFlow website.
mnist = tf.keras.datasets.mnist
# Load Data from MNIST
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Add a channel dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# Normalize pixel values to the range [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
Using tf.data
- from_tensor_slices()
- shuffle()
- batch()
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(32)
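What the three steps above do can be mimicked with ordinary Python generators. This is only a rough sketch of the idea, not how tf.data is implemented: from_slices, buffered_shuffle, and batch are hypothetical helper names, and the toy data is made up.

```python
import random

def from_slices(features, labels):
    """Yield one (feature, label) pair per example, i.e. slice along the first axis."""
    for pair in zip(features, labels):
        yield pair

def buffered_shuffle(items, buffer_size, rng):
    """Shuffle with a fixed-size buffer: once the buffer is full, emit a random
    buffered element and put the next incoming element in its place."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)      # input exhausted: drain the rest in random order
    yield from buffer

def batch(items, batch_size):
    """Group consecutive elements into lists; the last batch may be smaller."""
    chunk = []
    for item in items:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

# Toy data standing in for (x_train, y_train)
features = list(range(10))
labels = [x % 2 for x in features]

rng = random.Random(0)
pipeline = batch(buffered_shuffle(from_slices(features, labels), 4, rng), 3)
batches = list(pipeline)

print([len(b) for b in batches])   # batch sizes: [3, 3, 3, 1]
```

Because the buffer only ever holds buffer_size elements, an element cannot move arbitrarily far from its original position. With shuffle(1000) on 60,000 MNIST examples the result is therefore only approximately shuffled; a buffer at least as large as the dataset is needed for a uniform shuffle.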
Visualize Data
Import matplotlib and visualize the data.
import matplotlib.pyplot as plt
%matplotlib inline
for image, label in train_ds.take(2):
    # Show the first image of each batch, with its label as the title
    plt.title(str(label[0].numpy()))
    plt.imshow(image[0, :, :, 0], 'gray')
    plt.show()
Training (Keras)
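Training with Keras works the same as before. Because train_ds is a tf.data.Dataset that already yields (images, labels) batches, it can be passed to fit() as-is.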
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(train_ds, epochs=1000)
Optimization
- Loss Function
- Optimizer
- Metrics
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
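To make the loss and the accuracy metric less of a black box, here is a minimal plain-Python sketch of what they compute when the model outputs softmax probabilities and the labels are integer class ids. The function names and the toy numbers are made up for illustration, and details such as the clipping of probabilities that Keras applies are left out.

```python
import math

def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(probability assigned to the true class) over the batch."""
    losses = [-math.log(row[label]) for label, row in zip(labels, probs)]
    return sum(losses) / len(losses)

def sparse_categorical_accuracy(labels, probs):
    """Fraction of examples whose highest-probability class equals the label."""
    hits = [row.index(max(row)) == label for label, row in zip(labels, probs)]
    return sum(hits) / len(hits)

# Two examples, three classes (made-up softmax outputs)
labels = [0, 2]
probs = [[0.7, 0.2, 0.1],    # true class 0 gets 0.7 -> correct
         [0.1, 0.6, 0.3]]    # true class 2 gets 0.3 -> predicted class 1, wrong

loss = sparse_categorical_crossentropy(labels, probs)
acc = sparse_categorical_accuracy(labels, probs)
print(loss, acc)
```

The loss is small when the true class gets a probability close to 1 and grows quickly as that probability approaches 0, which is what pushes the optimizer toward confident, correct predictions.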
Training
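@tf.function: as with opening a session in TensorFlow 1.x, the decorated function is not run line by line in eager mode. It is traced into a graph the first time it is called, and that graph is what runs during training.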
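@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        # training=True so that the Dropout layers are active
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    # Gradients of the loss with respect to every trainable variable
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the running training metrics
    train_loss(loss)
    train_accuracy(labels, predictions)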
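@tf.function
def test_step(images, labels):
    # training=False keeps Dropout disabled during evaluation
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    # Update the running test metrics
    test_loss(t_loss)
    test_accuracy(labels, predictions)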
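The pattern inside train_step (forward pass, loss, gradient, parameter update) can be shown on a model small enough to differentiate by hand. The sketch below is an assumption-laden toy, not the author's method: it fits a single weight w in y = w * x, uses a hand-derived gradient in place of tf.GradientTape, and uses plain gradient descent in place of Adam.

```python
def toy_train_step(w, xs, ys, lr=0.1):
    """One training step for the one-parameter model y = w * x."""
    # Forward pass and loss (what would be recorded inside the GradientTape)
    preds = [w * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # Gradient of the mean squared error with respect to w, derived by hand
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # Parameter update (plain gradient descent, not Adam)
    return w - lr * grad, loss

# Made-up data generated by y = 2 * x
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]

w = 0.0
losses = []
for _ in range(20):
    w, loss = toy_train_step(w, xs, ys)
    losses.append(loss)

print(w, losses[0], losses[-1])
```

In the real train_step the tape computes the gradient for every trainable variable automatically, and optimizer.apply_gradients performs the update, but the order of operations is the same.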
for epoch in range(2):
    # Note: the metrics are never reset, so each printed value is an average over all epochs so far
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))
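One detail of the loop above is easy to miss: the metric objects keep accumulating, so the numbers printed for epoch 2 are averages over both epochs rather than over epoch 2 alone. The following plain-Python sketch illustrates the difference. RunningMean is a hypothetical stand-in for a mean metric, and the per-batch losses are made up.

```python
class RunningMean:
    """Minimal stand-in for a mean metric: accumulates a total and a count."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0

# Made-up per-batch losses for two epochs
epoch_losses = [[0.8, 0.6], [0.2, 0.2]]

# Without a reset (as in the loop above): values accumulate across epochs
cumulative = RunningMean()
cumulative_results = []
for batch_losses in epoch_losses:
    for value in batch_losses:
        cumulative.update(value)
    cumulative_results.append(cumulative.result())

# With a reset at the start of each epoch: one average per epoch
per_epoch = RunningMean()
per_epoch_results = []
for batch_losses in epoch_losses:
    per_epoch.reset()
    for value in batch_losses:
        per_epoch.update(value)
    per_epoch_results.append(per_epoch.result())

print(cumulative_results)   # roughly [0.7, 0.45]
print(per_epoch_results)    # roughly [0.7, 0.2]
```

Keras metric objects provide a reset method for this purpose. Its exact name differs between TensorFlow versions, so check the documentation for the version you are using before adding it to the loop.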
Results
Start Training
Epoch 1, Loss: 0.04196552559733391, Accuracy: 98.74666595458984, Test Loss 0.043360475450754166, Test Accuracy: 98.72000122070312
Start Training
Epoch 2, Loss: 0.033374134451150894, Accuracy: 99.0050048828125, Test Loss 0.03336939960718155, Test Accuracy: 98.95500183105469
@고지니어스 :: 규니의 개발 블로그