// Загрузка данных


// Evaluates a batch of game positions with a TorchScript model.
//
// For each node in `batch` returns a pair of:
//   - per-move prior probabilities, aligned with the order of GetMoves(state);
//   - a scalar value, negated when BLACK is to move (so the value is always
//     from the side-to-move perspective — NOTE(review): inferred from the
//     sign flip below; confirm the net's output convention).
//
// Two input layouts are supported, selected by model_type[GetModelId(model)]:
//   type 1: [N, 8, 15, 9, 7] — the full 8-plane history;
//   other:  [N, 13, 9, 7]    — only the newest plane, truncated to 13 rows.
// The model may return either a (log_policy, value) tuple or a bare value
// tensor; in the latter case a uniform policy over legal moves is used.
//
// Throws std::runtime_error if `model` is null or the output format is
// neither a tuple nor a tensor.
vector<pair<vector<float>, float>> EvaluatePositions(vector<Node*> batch, std::shared_ptr<torch::jit::script::Module> model) {
    torch::NoGradGuard no_grad;  // inference only — disable autograd bookkeeping

    if (!model) {
        throw std::runtime_error("Model not loaded. Call load_model() first.");
    }
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    c10::IValue output;
    // Raw per-node board encoding: 8 history planes of a 15x9x7 state.
    std::vector<std::array<std::array<std::array<std::array<char, 7>, 9>, 15>, 8>> state_array(batch.size());

    for (size_t i = 0; i < batch.size(); i++) {
        batch[i]->toTensor(reinterpret_cast<char(&)[8][15][9][7]>(state_array[i]));
    }

    if (model_type[GetModelId(model)] == 1)
    {
        // Build the input on CPU (accessor requires a CPU tensor), then move it.
        auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
        torch::Tensor input_tensor = torch::zeros({ (long long)batch.size(), 8, 15, 9, 7 }, options);

        auto accessor = input_tensor.accessor<float, 5>();
        for (size_t n = 0; n < batch.size(); n++)
            for (int h = 0; h < 8; h++)
                for (int i = 0; i < 15; i++)
                    for (int j = 0; j < 9; j++)
                        for (int k = 0; k < 7; k++)
                            accessor[n][h][i][j][k] = static_cast<float>(state_array[n][h][i][j][k]);

        input_tensor = input_tensor.to(device);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(input_tensor);
        output = model->forward(inputs);
    }
    else
    {
        // Flat layout: only the newest history plane, truncated to 13 of 15 rows.
        auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
        torch::Tensor input_tensor = torch::zeros({ (long long)batch.size(), 13, 9, 7 }, options);

        auto accessor = input_tensor.accessor<float, 4>();
        for (size_t n = 0; n < batch.size(); n++)
            for (int i = 0; i < 13; i++)
                for (int j = 0; j < 9; j++)
                    for (int k = 0; k < 7; k++)
                        accessor[n][i][j][k] = static_cast<float>(state_array[n][0][i][j][k]);

        input_tensor = input_tensor.to(device);

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(input_tensor);
        output = model->forward(inputs);
    }
    vector<pair<vector<float>, float>> result;
    result.reserve(batch.size());

    // Check what the model returned.
    if (output.isTuple()) {
        // Model returns a (log_policy, value) tuple.
        auto output_tuple = output.toTuple();

        torch::Tensor log_policy = output_tuple->elements()[0].toTensor().to(torch::kCPU);
        torch::Tensor values = output_tuple->elements()[1].toTensor().to(torch::kCPU);

        auto log_policy_accessor = log_policy.accessor<float, 4>();

        for (size_t i = 0; i < batch.size(); i++) {
            auto game_state = batch[i]->state;

            vector<Move> moves = GetMoves(game_state);

            // Gather exp(log-prob) for each legal move, then renormalize over
            // the legal subset only.
            vector<float> raw_probs;
            raw_probs.reserve(moves.size());
            float sum = 0.0f;

            for (Move& m : moves) {
                auto [a, b, c] = m.toTensor();
                const float log_prob = log_policy_accessor[i][a][b][c];
                const float prob = exp(log_prob);
                raw_probs.push_back(prob);
                sum += prob;
            }

            vector<float> normalized_probs;
            normalized_probs.reserve(raw_probs.size());
            if (sum > 0.0f) {
                for (float p : raw_probs) {
                    normalized_probs.push_back(p / sum);
                }
            } else if (!raw_probs.empty()) {
                // All priors underflowed to zero — fall back to a uniform
                // distribution instead of dividing by zero (NaN priors).
                normalized_probs.assign(raw_probs.size(), 1.0f / raw_probs.size());
            }

            float value = values[i].item<float>();
            if (game_state.order == BLACK) value *= -1;

            result.emplace_back(normalized_probs, value);
        }
    }
    else if (output.isTensor()) {
        // Model returns only the value estimate.
        torch::Tensor values = output.toTensor().to(torch::kCPU);

        for (size_t i = 0; i < batch.size(); i++) {
            auto game_state = batch[i]->state;

            // No policy head: use a uniform distribution over legal moves.
            // Guard against a terminal position with no moves (1/0 would be inf).
            vector<Move> moves = GetMoves(game_state);
            vector<float> uniform_probs;
            if (!moves.empty()) {
                uniform_probs.assign(moves.size(), 1.0f / moves.size());
            }

            float value = values[i].item<float>();
            if (game_state.order == BLACK) value *= -1;

            result.emplace_back(uniform_probs, value);
        }
    }
    else {
        throw std::runtime_error("Unexpected model output format");
    }

    return result;
}