import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Предполагаем, что df уже содержит столбец 'prediction' с прогнозом модели
# и столбцы: 'end_month_plan', 'renewed_fact_rate', 'portfolio_share', 'segment'
date_col = 'end_month_plan'
target = 'renewed_fact_rate'
weight_col = 'portfolio_share'
# Функция агрегации (взвешенное среднее)
def aggregate_weighted(df, group_col, value_col, weight_col):
return (df.groupby(group_col)
.apply(lambda g: pd.Series({
value_col: np.average(g[value_col], weights=g[weight_col]),
weight_col: g[weight_col].sum()
}))
.reset_index()
.sort_values(group_col))
# Сегменты для визуализации (те, для которых есть коэффициенты в модели)
segments = ['prem', 'mid']
# Создаём фигуру с тремя подграфиками
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Сравнение фактической и прогнозной доли пролонгаций', fontsize=16)
# 1. Общий график (уже был, но построим заново для единообразия)
fact_all = aggregate_weighted(df, date_col, target, weight_col)
pred_all = aggregate_weighted(df, date_col, 'prediction', weight_col)
comp_all = fact_all.merge(pred_all, on=date_col, suffixes=('_fact', '_pred'))
axes[0].plot(comp_all[date_col], comp_all[target], marker='o', label='Факт')
axes[0].plot(comp_all[date_col], comp_all['prediction'], marker='x', label='Прогноз')
axes[0].set_title('Все сегменты')
axes[0].legend()
axes[0].grid(True)
axes[0].tick_params(axis='x', rotation=45)
# 2. График для премиум-сегмента
for idx, seg in enumerate(segments, start=1):
df_seg = df[df['segment'] == seg].copy()
fact_seg = aggregate_weighted(df_seg, date_col, target, weight_col)
pred_seg = aggregate_weighted(df_seg, date_col, 'prediction', weight_col)
comp_seg = fact_seg.merge(pred_seg, on=date_col, suffixes=('_fact', '_pred'))
axes[idx].plot(comp_seg[date_col], comp_seg[target], marker='o', label='Факт')
axes[idx].plot(comp_seg[date_col], comp_seg['prediction'], marker='x', label='Прогноз')
axes[idx].set_title(f'Сегмент {seg.upper()}')
axes[idx].legend()
axes[idx].grid(True)
axes[idx].tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()