mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# For each split: scale pixel values by 1/255 (true division produces a float
# array), append a trailing channels axis via `tf.newaxis`, then cast the
# result to float32.
x_train, x_test = (
    (split / 255.0)[..., tf.newaxis].astype("float32")
    for split in (x_train, x_test)
)
class MyModel(Model):
    """Small convolutional classifier: one Conv2D layer followed by two Dense layers.

    The last Dense layer has no activation, so ``call`` returns raw
    (unnormalized) logits with 10 entries per example. Presumably the loss is
    configured to accept logits (e.g. ``from_logits=True``) — confirm where
    ``loss_object`` is defined.
    """

    def __init__(self, **kwargs):
        """Create the model's layers.

        Args:
            **kwargs: Forwarded unchanged to ``Model.__init__`` (for example
                ``name=``); calling with no arguments behaves as before.
        """
        super().__init__(**kwargs)
        # 32 filters with a 3x3 kernel.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        # No activation: this layer emits logits.
        self.d2 = Dense(10)

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Batch of images; presumably shaped (batch, height, width, 1) to
                match the channels axis added during loading — confirm
                against the input pipeline.

        Returns:
            Tensor of logits, 10 per example.
        """
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)
# Single module-level model instance; train_step reads it as a global.
model = MyModel()
# Compile the step into a TensorFlow graph; the body only uses TF ops and the
# module-level model, loss, optimizer and metric objects.
@tf.function
def train_step(images, labels):
    """Run one optimization step on a single batch and update training metrics.

    The loss is computed under a GradientTape and the gradients are applied
    with the module-level ``optimizer``. The batch loss and accuracy are then
    recorded in ``train_loss`` and ``train_accuracy``.

    Args:
        images: Batch of input images passed to ``model``.
        labels: Matching labels for ``images`` (presumably integer class ids,
            as returned by ``mnist.load_data`` — confirm against the dataset).
    """
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    # Gradient computation and the update happen after the tape stops recording.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Record the batch loss; the epoch summary prints train_loss.result(), which
    # otherwise would never see any training data.
    train_loss(loss)
    train_accuracy(labels, predictions)
for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch so each printed
    # summary covers only the current epoch instead of accumulating.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Previously these f-strings were bare expressions and nothing was printed.
    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
        f'Accuracy: {train_accuracy.result() * 100}, '
        f'Test Loss: {test_loss.result()}, '
        f'Test Accuracy: {test_accuracy.result() * 100}'
    )
# Example console output from one run:
# 2022-01-24 14:49:39.070639: W tensorflow/core/platform/profile_utils/] Failed to get CPU frequency: 0 Hz
# 2022-01-24 14:49:39.070697: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
# 2022-01-24 14:49:53.550693: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
# 2022-01-24 14:49:55.053240: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
# Epoch 1, Loss: 0.13899236917495728, Accuracy: 95.84333038330078, Test Loss: 0.06439483910799026, Test Accuracy: 97.94000244140625
# Epoch 2, Loss: 0.04431848227977753, Accuracy: 98.625, Test Loss: 0.05371560901403427, Test Accuracy: 98.17000579833984
# Epoch 3, Loss: 0.02307759039103985, Accuracy: 99.23666381835938, Test Loss: 0.056773096323013306, Test Accuracy: 98.32000732421875
# Epoch 4, Loss: 0.013629536144435406, Accuracy: 99.54000091552734, Test Loss: 0.06244818866252899, Test Accuracy: 98.29000091552734
# Epoch 5, Loss: 0.010191570036113262, Accuracy: 99.63833618164062, Test Loss: 0.07394283264875412, Test Accuracy: 98.30000305175781