Загрузка данных
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
# 1. ЗАГРУЗКА ДАННЫХ
print("Загрузка данных MNIST...")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Нормализация данных (приводим значения пикселей к диапазону [0, 1])
x_train = x_train / 255.0
x_test = x_test / 255.0
# 2. ПРОСМОТР ДАННЫХ
print(f"Размер обучающей выборки: {x_train.shape}")
print(f"Размер тестовой выборки: {x_test.shape}")
print(f"Количество классов: {len(np.unique(y_train))}")
# Покажем несколько примеров изображений
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(x_train[i], cmap='gray')
plt.title(f'Цифра: {y_train[i]}')
plt.axis('off')
plt.tight_layout()
plt.show()
# 3. СОЗДАНИЕ МОДЕЛИ НЕЙРОННОЙ СЕТИ
model = keras.Sequential([
# Входной слой: преобразуем 28x28 в плоский вектор 784
keras.layers.Flatten(input_shape=(28, 28)),
# Скрытый слой с 128 нейронами и функцией активации ReLU
keras.layers.Dense(128, activation='relu'),
# Выходной слой с 10 нейронами (по числу цифр) и softmax
keras.layers.Dense(10, activation='softmax')
])
# 4. КОМПИЛЯЦИЯ МОДЕЛИ
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# Вывод структуры модели
print("\nСтруктура модели:")
model.summary()
# 5. ОБУЧЕНИЕ МОДЕЛИ
print("\nНачинаем обучение...")
history = model.fit(
x_train, y_train,
epochs=5, # 5 эпох обучения
batch_size=32,
validation_split=0.2, # 20% данных используем для валидации
verbose=1
)
# 6. ОЦЕНКА МОДЕЛИ
print("\nОценка модели на тестовых данных...")
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"Точность на тестовых данных: {test_accuracy:.4f} ({test_accuracy * 100:.2f}%)")
# 7. ВИЗУАЛИЗАЦИЯ РЕЗУЛЬТАТОВ ОБУЧЕНИЯ
plt.figure(figsize=(12, 4))
# График точности
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Точность на обучении')
plt.plot(history.history['val_accuracy'], label='Точность на валидации')
plt.title('Точность модели')
plt.xlabel('Эпоха')
plt.ylabel('Точность')
plt.legend()
# График потерь
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Потери на обучении')
plt.plot(history.history['val_loss'], label='Потери на валидации')
plt.title('Функция потерь модели')
plt.xlabel('Эпоха')
plt.ylabel('Потери')
plt.legend()
plt.tight_layout()
plt.show()
# 8. ПРЕДСКАЗАНИЕ НА ПРИМЕРАХ ИЗ ТЕСТОВОЙ ВЫБОРКИ
predictions = model.predict(x_test[:10])
# Покажем предсказания для первых 10 изображений
plt.figure(figsize=(12, 4))
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(x_test[i], cmap='gray')
pred_digit = np.argmax(predictions[i])
true_digit = y_test[i]
# Зеленый цвет, если угадано верно, красный - если ошибка
color = 'green' if pred_digit == true_digit else 'red'
plt.title(f'Пред: {pred_digit}\nРеал: {true_digit}', color=color)
plt.axis('off')
plt.tight_layout()
plt.show()
# 9. ПРИМЕР ПРЕДСКАЗАНИЯ ДЛЯ ОДНОГО ИЗОБРАЖЕНИЯ
sample_image = x_test[0]
sample_label = y_test[0]
# Делаем предсказание (добавляем размерность батча через reshape)
prediction = model.predict(sample_image.reshape(1, 28, 28))
predicted_digit = np.argmax(prediction)
print(f"\nПример предсказания:")
print(f"Реальная цифра: {sample_label}")
print(f"Предсказанная цифра: {predicted_digit}")
print(f"Вероятности для всех цифр: \n{prediction[0]}")
# 10. СОХРАНЕНИЕ МОДЕЛИ (опционально)
model.save('mnist_digit_classifier.h5')
print("\nМодель сохранена в файл 'mnist_digit_classifier.h5'")