Загрузка данных


#!/usr/bin/env python3
"""Baseline трека A («сколько товаров на фрагменте») — классификатор окон.

EfficientNet-B0 (torchvision) на окнах 768px, 5 классов-диапазонов.
Вырезает окна из изображений по координатам из windows_*_a.csv (на лету),
обучается, предсказывает на тесте и пишет сабмит pred_a.csv (window_id,class).

Это СТАРТОВЫЙ baseline для участников (Colab): простая модель, немного эпох.
Участники улучшают архитектуру/аугментации/обучение.

Зависимости: torch, torchvision, Pillow, numpy.
Запуск:
  python baseline_track_a.py --images-dir /data/SKU110K_fixed/images \
      --data-dir /data/polka_out --epochs 3 --out pred_a.csv
  # быстрый смоук: добавить --limit 2000 --epochs 1
"""
import argparse
import csv
import os

import torch
import torch.nn as nn
from PIL import Image, ImageFile
from torch.utils.data import DataLoader, Dataset

ImageFile.LOAD_TRUNCATED_IMAGES = True  # SKU-110K содержит обрезанные jpg
from torchvision import transforms
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0

NUM_CLASSES = 5
IMG_SIZE = 224


class WindowDataset(Dataset):
    """Окна по координатам из CSV; кроп из изображения на лету."""

    def __init__(self, csv_path, images_dir, train, tfm, limit=None):
        self.rows, self.images_dir, self.train, self.tfm = [], images_dir, train, tfm
        with open(csv_path, newline="") as f:
            for r in csv.DictReader(f):
                self.rows.append(r)
        if limit:
            self.rows = self.rows[:limit]

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        x, y, w, h = int(r["x"]), int(r["y"]), int(r["w"]), int(r["h"])
        try:
            with Image.open(os.path.join(self.images_dir, r["image_id"])) as im:
                crop = im.convert("RGB").crop((x, y, x + w, y + h))
        except Exception:  # noqa: BLE001 — битый файл не должен валить эпоху
            crop = Image.new("RGB", (max(w, 1), max(h, 1)))
        t = self.tfm(crop)
        if self.train:
            return t, int(r["class"])
        return t, r["window_id"]


def build_model(device):
    m = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
    m.classifier[1] = nn.Linear(m.classifier[1].in_features, NUM_CLASSES)
    return m.to(device)


def main(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--images-dir", required=True)
    p.add_argument("--data-dir", required=True, help="папка с windows_train_a.csv / windows_test_a.csv")
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch", type=int, default=64)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--limit", type=int, default=None, help="ограничить train (для смоук-теста)")
    p.add_argument("--test-limit", type=int, default=None, help="ограничить test (для смоук-теста)")
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--out", default="pred_a.csv")
    a = p.parse_args(argv)

    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"device={device}")

    train_tfm = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_tfm = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    train_ds = WindowDataset(os.path.join(a.data_dir, "windows_train_a.csv"), a.images_dir, True, train_tfm, a.limit)
    test_ds = WindowDataset(os.path.join(a.data_dir, "windows_test_a.csv"), a.images_dir, False, test_tfm, a.test_limit)
    train_dl = DataLoader(train_ds, batch_size=a.batch, shuffle=True, num_workers=a.workers)
    test_dl = DataLoader(test_ds, batch_size=a.batch, shuffle=False, num_workers=a.workers)
    print(f"train windows={len(train_ds)}  test windows={len(test_ds)}")

    model = build_model(device)
    opt = torch.optim.Adam(model.parameters(), lr=a.lr)
    crit = nn.CrossEntropyLoss()

    for ep in range(a.epochs):
        model.train()
        total, correct, loss_sum = 0, 0, 0.0
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            out = model(xb)
            loss = crit(out, yb)
            loss.backward()
            opt.step()
            loss_sum += loss.item() * len(yb)
            correct += (out.argmax(1) == yb).sum().item()
            total += len(yb)
        print(f"epoch {ep+1}/{a.epochs}  loss={loss_sum/total:.3f}  train_acc={correct/total:.3f}")

    model.eval()
    with open(a.out, "w", newline="") as f, torch.no_grad():
        w = csv.writer(f)
        w.writerow(["window_id", "class"])
        for xb, wids in test_dl:
            preds = model(xb.to(device)).argmax(1).cpu().tolist()
            for wid, c in zip(wids, preds):
                w.writerow([wid, c])
    print(f"DONE -> {a.out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())