-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathepiautogp_forecast_utils.py
More file actions
345 lines (296 loc) · 10.2 KB
/
epiautogp_forecast_utils.py
File metadata and controls
345 lines (296 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""
Shared utilities for forecast pipeline scripts.
This module contains common functionality used across different forecast
pipelines (pyrenew, timeseries, epiautogp, etc.).
"""
import logging
import os
from dataclasses import dataclass
from datetime import date, timedelta
from pathlib import Path
from typing import Any
import polars as pl
from pipelines.common_utils import (
calculate_training_dates,
create_hubverse_table,
get_available_reports,
load_credentials,
load_nssp_data,
parse_and_validate_report_date,
plot_and_save_loc_forecast,
)
from pipelines.forecast_pyrenew import generate_epiweekly_data
from pipelines.prep_data import process_and_save_loc_data
from pipelines.prep_eval_data import save_eval_data
@dataclass
class ModelPaths:
    """
    Container for model output directory structure and file paths.
    This class holds all the computed output paths for a specific model run,
    making it easier to track where data and results are stored.
    """

    # Root output directory for this model run
    # (model_run_dir / model_name; see prepare_model_data).
    model_output_dir: Path
    # "data" subdirectory of model_output_dir holding prepared inputs.
    data_dir: Path
    # Daily training data TSV ("combined_training_data.tsv") in data_dir.
    daily_training_data: Path
    # Epiweekly training data TSV
    # ("epiweekly_combined_training_data.tsv") in data_dir.
    epiweekly_training_data: Path
@dataclass
class ForecastPipelineContext:
    """
    Container for common forecast pipeline data, input configurations and
    the logger.
    This class holds all the shared state that gets passed around during
    a forecast pipeline run, reducing the number of parameters that need
    to be passed between functions.
    """

    # Disease to model (e.g., "COVID-19", "Influenza", "RSV").
    disease: str
    # Two-letter USPS location abbreviation (e.g., "CA", "NY").
    loc: str
    # Forecast target identifier (carried through from setup; semantics
    # defined by the downstream pipeline).
    target: str
    # Temporal frequency of the forecast (carried through from setup).
    frequency: str
    # Whether the target is expressed as a percentage.
    use_percentage: bool
    # Model name; used for output directory naming in prepare_model_data.
    model_name: str
    # Path to an evaluation dataset; prepare_model_data raises ValueError
    # if this is None.
    eval_data_path: Path | None
    # Path to NHSN data (for local testing); may be None.
    nhsn_data_path: Path | None
    # Parsed/validated report date.
    report_date: date
    # First day of the training window.
    first_training_date: date
    # Last day of the training window.
    last_training_date: date
    # Number of days ahead to forecast.
    n_forecast_days: int
    # Number of recent days excluded from training.
    exclude_last_n_days: int
    # Batch-level output directory (one per disease/report/training window).
    model_batch_dir: Path
    # Location-specific run directory (model_batch_dir/model_runs/<loc>).
    model_run_dir: Path
    # Credentials loaded via load_credentials.
    credentials_dict: dict[str, Any]
    # Facility-level NSSP ED visit data (lazy; not yet collected).
    facility_level_nssp_data: pl.LazyFrame
    # Location-level NSSP ED visit data (lazy; not yet collected).
    loc_level_nssp_data: pl.LazyFrame
    # Logger shared by all pipeline stages.
    logger: logging.Logger
def setup_forecast_pipeline(
    disease: str,
    report_date: str,
    loc: str,
    target: str,
    frequency: str,
    use_percentage: bool,
    model_name: str,
    eval_data_path: Path | None,
    nhsn_data_path: Path | None,
    facility_level_nssp_data_dir: Path | str,
    state_level_nssp_data_dir: Path | str,
    output_dir: Path | str,
    n_training_days: int,
    n_forecast_days: int,
    exclude_last_n_days: int = 0,
    credentials_path: Path | None = None,
    logger: logging.Logger | None = None,
) -> ForecastPipelineContext:
    """
    Set up common forecast pipeline infrastructure.

    This function performs the initial setup steps that are common across
    all forecast pipelines:
    1. Load credentials
    2. Get available report dates
    3. Parse and validate the report date
    4. Calculate training dates
    5. Load NSSP data
    6. Create batch directory structure

    Parameters
    ----------
    disease : str
        Disease to model (e.g., "COVID-19", "Influenza", "RSV")
    report_date : str
        Report date in YYYY-MM-DD format or "latest"
    loc : str
        Two-letter USPS location abbreviation (e.g., "CA", "NY")
    target : str
        Forecast target identifier; stored on the returned context
    frequency : str
        Temporal frequency of the forecast; stored on the returned context
    use_percentage : bool
        Whether the target is expressed as a percentage; stored on the
        returned context
    model_name : str
        Name of the model; stored on the returned context and used for
        output directory naming downstream
    eval_data_path : Path | None
        Path to an evaluation dataset; stored on the returned context
    nhsn_data_path : Path | None
        Path to NHSN data (for local testing); stored on the returned context
    facility_level_nssp_data_dir : Path | str
        Directory containing facility-level NSSP ED visit data
    state_level_nssp_data_dir : Path | str
        Directory containing state-level NSSP ED visit data
    output_dir : Path | str
        Root directory for output
    n_training_days : int
        Number of days of training data
    n_forecast_days : int
        Number of days ahead to forecast
    exclude_last_n_days : int, default=0
        Number of recent days to exclude from training
    credentials_path : Path, optional
        Path to credentials file
    logger : logging.Logger, optional
        Logger instance. If None, creates a new logger

    Returns
    -------
    ForecastPipelineContext
        Context object containing all setup information
    """
    if logger is None:
        logger = logging.getLogger(__name__)
    logger.info(
        f"Setting up forecast pipeline for {disease}, "
        f"location {loc}, report date {report_date}"
    )

    # Load credentials
    credentials_dict = load_credentials(credentials_path, logger)

    # Get available reports
    available_facility_level_reports = get_available_reports(
        facility_level_nssp_data_dir
    )
    available_loc_level_reports = get_available_reports(state_level_nssp_data_dir)

    # Parse and validate report date. Facility- and location-level data may
    # have different latest available reports, hence the separate loc date.
    report_date_parsed, loc_report_date = parse_and_validate_report_date(
        report_date,
        available_facility_level_reports,
        available_loc_level_reports,
        logger,
    )

    # Calculate training dates
    first_training_date, last_training_date = calculate_training_dates(
        report_date_parsed,
        n_training_days,
        exclude_last_n_days,
        logger,
    )

    # Load NSSP data
    facility_level_nssp_data, loc_level_nssp_data = load_nssp_data(
        report_date_parsed,
        loc_report_date,
        available_facility_level_reports,
        available_loc_level_reports,
        facility_level_nssp_data_dir,
        state_level_nssp_data_dir,
        logger,
    )

    # Create model batch directory structure. The batch directory name
    # encodes disease, report date, and training window so that runs with
    # different configurations never collide.
    model_batch_dir_name = (
        f"{disease.lower()}_r_{report_date_parsed}_f_"
        f"{first_training_date}_t_{last_training_date}"
    )
    model_batch_dir = Path(output_dir, model_batch_dir_name)
    model_run_dir = Path(model_batch_dir, "model_runs", loc)
    logger.info(f"Model batch directory: {model_batch_dir}")
    logger.info(f"Model run directory: {model_run_dir}")

    return ForecastPipelineContext(
        disease=disease,
        loc=loc,
        target=target,
        frequency=frequency,
        use_percentage=use_percentage,
        model_name=model_name,
        eval_data_path=eval_data_path,
        nhsn_data_path=nhsn_data_path,
        report_date=report_date_parsed,
        first_training_date=first_training_date,
        last_training_date=last_training_date,
        n_forecast_days=n_forecast_days,
        exclude_last_n_days=exclude_last_n_days,
        model_batch_dir=model_batch_dir,
        model_run_dir=model_run_dir,
        credentials_dict=credentials_dict,
        facility_level_nssp_data=facility_level_nssp_data,
        loc_level_nssp_data=loc_level_nssp_data,
        logger=logger,
    )
def prepare_model_data(
    context: ForecastPipelineContext,
) -> ModelPaths:
    """
    Prepare training and evaluation data for a model.

    This function performs the data preparation steps that are common across
    all forecast pipelines:
    1. Create model output directory
    2. Process and save location data
    3. Save evaluation data
    4. Generate epiweekly datasets

    Parameters
    ----------
    context : ForecastPipelineContext
        Pipeline context with shared configuration. The model name, eval
        data path, and NHSN data path are all read from the context
        (``context.model_name``, ``context.eval_data_path``,
        ``context.nhsn_data_path``).

    Returns
    -------
    ModelPaths
        Object containing all model output directory and file paths

    Raises
    ------
    ValueError
        If ``context.eval_data_path`` is None
    """
    logger = context.logger

    # Create model output directory
    model_output_dir = Path(context.model_run_dir, context.model_name)
    data_dir = Path(model_output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    logger.info(f"Processing data for {context.loc}")

    # Process and save location data
    process_and_save_loc_data(
        loc_abb=context.loc,
        disease=context.disease,
        facility_level_nssp_data=context.facility_level_nssp_data,
        loc_level_nssp_data=context.loc_level_nssp_data,
        report_date=context.report_date,
        first_training_date=context.first_training_date,
        last_training_date=context.last_training_date,
        save_dir=data_dir,
        logger=logger,
        credentials_dict=context.credentials_dict,
        nhsn_data_path=context.nhsn_data_path,
    )

    # Save evaluation data. Fail fast before calling save_eval_data so the
    # error names the actual problem (no eval dataset configured).
    logger.info("Getting eval data...")
    if context.eval_data_path is None:
        raise ValueError("No path to an evaluation dataset provided.")
    save_eval_data(
        loc=context.loc,
        disease=context.disease,
        first_training_date=context.first_training_date,
        last_training_date=context.last_training_date,
        latest_comprehensive_path=context.eval_data_path,
        output_data_dir=data_dir,
        # Evaluate through the end of the forecast horizon.
        last_eval_date=context.report_date + timedelta(days=context.n_forecast_days),
        credentials_dict=context.credentials_dict,
        nhsn_data_path=context.nhsn_data_path,
    )
    logger.info("Done getting eval data.")

    # Generate epiweekly datasets
    logger.info("Generating epiweekly datasets from daily datasets...")
    generate_epiweekly_data(data_dir)
    logger.info("Data preparation complete.")

    # Return structured paths object
    return ModelPaths(
        model_output_dir=model_output_dir,
        data_dir=data_dir,
        daily_training_data=Path(data_dir, "combined_training_data.tsv"),
        epiweekly_training_data=Path(data_dir, "epiweekly_combined_training_data.tsv"),
    )
def postprocess_forecast(
    context: ForecastPipelineContext,
    model_name: str,
) -> None:
    """
    Perform standard postprocessing on forecast outputs.

    This function performs postprocessing steps that are common across
    all forecast pipelines:
    1. Plot and save location forecast
    2. Create hubverse-compatible table

    Parameters
    ----------
    context : ForecastPipelineContext
        Pipeline context with shared configuration
    model_name : str
        Name of the model (used for directory naming)
    """
    log = context.logger
    log.info("Performing postprocessing...")

    # The plot window extends past the last training date by the forecast
    # horizon plus any days that were held out of training.
    horizon_days = context.n_forecast_days + context.exclude_last_n_days
    plot_and_save_loc_forecast(
        context.model_run_dir,
        horizon_days,
        timeseries_model_name=model_name,
    )

    # Build the hubverse-compatible table from this model's run directory.
    create_hubverse_table(Path(context.model_run_dir, model_name))

    log.info("Postprocessing complete.")