Загрузка данных


import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# 1. ЗАГРУЗКА ДАННЫХ
print("Загрузка данных MNIST...")
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Нормализация данных (приводим значения пикселей к диапазону [0, 1])
x_train = x_train / 255.0
x_test = x_test / 255.0

# 2. ПРОСМОТР ДАННЫХ
print(f"Размер обучающей выборки: {x_train.shape}")
print(f"Размер тестовой выборки: {x_test.shape}")
print(f"Количество классов: {len(np.unique(y_train))}")

# Покажем несколько примеров изображений
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(x_train[i], cmap='gray')
    plt.title(f'Цифра: {y_train[i]}')
    plt.axis('off')
plt.tight_layout()
plt.show()

# 3. СОЗДАНИЕ МОДЕЛИ НЕЙРОННОЙ СЕТИ
model = keras.Sequential([
    # Входной слой: преобразуем 28x28 в плоский вектор 784
    keras.layers.Flatten(input_shape=(28, 28)),

    # Скрытый слой с 128 нейронами и функцией активации ReLU
    keras.layers.Dense(128, activation='relu'),

    # Выходной слой с 10 нейронами (по числу цифр) и softmax
    keras.layers.Dense(10, activation='softmax')
])

# 4. КОМПИЛЯЦИЯ МОДЕЛИ
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Вывод структуры модели
print("\nСтруктура модели:")
model.summary()

# 5. ОБУЧЕНИЕ МОДЕЛИ
print("\nНачинаем обучение...")
history = model.fit(
    x_train, y_train,
    epochs=5,  # 5 эпох обучения
    batch_size=32,
    validation_split=0.2,  # 20% данных используем для валидации
    verbose=1
)

# 6. ОЦЕНКА МОДЕЛИ
print("\nОценка модели на тестовых данных...")
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"Точность на тестовых данных: {test_accuracy:.4f} ({test_accuracy * 100:.2f}%)")

# 7. ВИЗУАЛИЗАЦИЯ РЕЗУЛЬТАТОВ ОБУЧЕНИЯ
plt.figure(figsize=(12, 4))

# График точности
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Точность на обучении')
plt.plot(history.history['val_accuracy'], label='Точность на валидации')
plt.title('Точность модели')
plt.xlabel('Эпоха')
plt.ylabel('Точность')
plt.legend()

# График потерь
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Потери на обучении')
plt.plot(history.history['val_loss'], label='Потери на валидации')
plt.title('Функция потерь модели')
plt.xlabel('Эпоха')
plt.ylabel('Потери')
plt.legend()

plt.tight_layout()
plt.show()

# 8. ПРЕДСКАЗАНИЕ НА ПРИМЕРАХ ИЗ ТЕСТОВОЙ ВЫБОРКИ
predictions = model.predict(x_test[:10])

# Покажем предсказания для первых 10 изображений
plt.figure(figsize=(12, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(x_test[i], cmap='gray')
    pred_digit = np.argmax(predictions[i])
    true_digit = y_test[i]

    # Зеленый цвет, если угадано верно, красный - если ошибка
    color = 'green' if pred_digit == true_digit else 'red'
    plt.title(f'Пред: {pred_digit}\nРеал: {true_digit}', color=color)
    plt.axis('off')
plt.tight_layout()
plt.show()

# 9. ПРИМЕР ПРЕДСКАЗАНИЯ ДЛЯ ОДНОГО ИЗОБРАЖЕНИЯ
sample_image = x_test[0]
sample_label = y_test[0]

# Делаем предсказание (добавляем размерность батча через reshape)
prediction = model.predict(sample_image.reshape(1, 28, 28))
predicted_digit = np.argmax(prediction)

print(f"\nПример предсказания:")
print(f"Реальная цифра: {sample_label}")
print(f"Предсказанная цифра: {predicted_digit}")
print(f"Вероятности для всех цифр: \n{prediction[0]}")

# 10. СОХРАНЕНИЕ МОДЕЛИ (опционально)
model.save('mnist_digit_classifier.h5')
print("\nМодель сохранена в файл 'mnist_digit_classifier.h5'")