Загрузка данных
#!/usr/bin/env python3
"""Baseline трека A («сколько товаров на фрагменте») — классификатор окон.
EfficientNet-B0 (torchvision) на окнах 768px, 5 классов-диапазонов.
Вырезает окна из изображений по координатам из windows_*_a.csv (на лету),
обучается, предсказывает на тесте и пишет сабмит pred_a.csv (window_id,class).
Это СТАРТОВЫЙ baseline для участников (Colab): простая модель, немного эпох.
Участники улучшают архитектуру/аугментации/обучение.
Зависимости: torch, torchvision, Pillow, numpy.
Запуск:
python baseline_track_a.py --images-dir /data/SKU110K_fixed/images \
--data-dir /data/polka_out --epochs 3 --out pred_a.csv
# быстрый смоук: добавить --limit 2000 --epochs 1
"""
import argparse
import csv
import os
import torch
import torch.nn as nn
from PIL import Image, ImageFile
from torch.utils.data import DataLoader, Dataset
ImageFile.LOAD_TRUNCATED_IMAGES = True # SKU-110K содержит обрезанные jpg
from torchvision import transforms
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0
NUM_CLASSES = 5
IMG_SIZE = 224
class WindowDataset(Dataset):
"""Окна по координатам из CSV; кроп из изображения на лету."""
def __init__(self, csv_path, images_dir, train, tfm, limit=None):
self.rows, self.images_dir, self.train, self.tfm = [], images_dir, train, tfm
with open(csv_path, newline="") as f:
for r in csv.DictReader(f):
self.rows.append(r)
if limit:
self.rows = self.rows[:limit]
def __len__(self):
return len(self.rows)
def __getitem__(self, i):
r = self.rows[i]
x, y, w, h = int(r["x"]), int(r["y"]), int(r["w"]), int(r["h"])
try:
with Image.open(os.path.join(self.images_dir, r["image_id"])) as im:
crop = im.convert("RGB").crop((x, y, x + w, y + h))
except Exception: # noqa: BLE001 — битый файл не должен валить эпоху
crop = Image.new("RGB", (max(w, 1), max(h, 1)))
t = self.tfm(crop)
if self.train:
return t, int(r["class"])
return t, r["window_id"]
def build_model(device):
m = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
m.classifier[1] = nn.Linear(m.classifier[1].in_features, NUM_CLASSES)
return m.to(device)
def main(argv=None):
p = argparse.ArgumentParser()
p.add_argument("--images-dir", required=True)
p.add_argument("--data-dir", required=True, help="папка с windows_train_a.csv / windows_test_a.csv")
p.add_argument("--epochs", type=int, default=3)
p.add_argument("--batch", type=int, default=64)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--limit", type=int, default=None, help="ограничить train (для смоук-теста)")
p.add_argument("--test-limit", type=int, default=None, help="ограничить test (для смоук-теста)")
p.add_argument("--workers", type=int, default=4)
p.add_argument("--out", default="pred_a.csv")
a = p.parse_args(argv)
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
print(f"device={device}")
train_tfm = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
test_tfm = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_ds = WindowDataset(os.path.join(a.data_dir, "windows_train_a.csv"), a.images_dir, True, train_tfm, a.limit)
test_ds = WindowDataset(os.path.join(a.data_dir, "windows_test_a.csv"), a.images_dir, False, test_tfm, a.test_limit)
train_dl = DataLoader(train_ds, batch_size=a.batch, shuffle=True, num_workers=a.workers)
test_dl = DataLoader(test_ds, batch_size=a.batch, shuffle=False, num_workers=a.workers)
print(f"train windows={len(train_ds)} test windows={len(test_ds)}")
model = build_model(device)
opt = torch.optim.Adam(model.parameters(), lr=a.lr)
crit = nn.CrossEntropyLoss()
for ep in range(a.epochs):
model.train()
total, correct, loss_sum = 0, 0, 0.0
for xb, yb in train_dl:
xb, yb = xb.to(device), yb.to(device)
opt.zero_grad()
out = model(xb)
loss = crit(out, yb)
loss.backward()
opt.step()
loss_sum += loss.item() * len(yb)
correct += (out.argmax(1) == yb).sum().item()
total += len(yb)
print(f"epoch {ep+1}/{a.epochs} loss={loss_sum/total:.3f} train_acc={correct/total:.3f}")
model.eval()
with open(a.out, "w", newline="") as f, torch.no_grad():
w = csv.writer(f)
w.writerow(["window_id", "class"])
for xb, wids in test_dl:
preds = model(xb.to(device)).argmax(1).cpu().tolist()
for wid, c in zip(wids, preds):
w.writerow([wid, c])
print(f"DONE -> {a.out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())