Save batch data to file after batch evaluation #10243

Open · wants to merge 1 commit into base: main
54 changes: 36 additions & 18 deletions src/everest/everest_storage.py
@@ -46,7 +46,7 @@ class BatchStorageData:
 
     @property
     def existing_dataframes(self) -> dict[str, pl.DataFrame]:
-        return {
+        existing_dfs = {
             k: cast(pl.DataFrame, getattr(self, k))
             for k in [
                 "batch_objectives",
@@ -62,6 +62,8 @@ def existing_dataframes(self) -> dict[str, pl.DataFrame]:
             if getattr(self, k) is not None
         }
 
+        return {k: v for k, v in existing_dfs.items() if not v.is_empty()}
+
 
 @dataclass
 class OptimizationStorageData:
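
Both existing_dataframes properties now go through an intermediate dict so that zero-row frames can be dropped alongside unset attributes. A minimal sketch of the pattern, with illustrative keys and data rather than the real storage attributes:

import polars as pl

frames = {
    "batch_objectives": pl.DataFrame({"objective": [1.0, 2.0]}),
    "batch_constraints": pl.DataFrame(),  # set, but empty
    "realization_controls": None,  # never populated
}

# First drop unset attributes, then drop empty frames, mirroring the diff.
existing_dfs = {k: v for k, v in frames.items() if v is not None}
non_empty = {k: v for k, v in existing_dfs.items() if not v.is_empty()}

assert list(non_empty) == ["batch_objectives"]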
@@ -95,7 +97,7 @@ def simulation_to_geo_realization_map(self, batch_id: int) -> dict[int, int]:
 
     @property
     def existing_dataframes(self) -> dict[str, pl.DataFrame]:
-        return {
+        existing_dfs = {
             k: cast(pl.DataFrame, getattr(self, k))
             for k in [
                 "controls",
@@ -106,25 +108,39 @@ def existing_dataframes(self) -> dict[str, pl.DataFrame]:
             if getattr(self, k) is not None
         }
 
-    def write_to_experiment(self, experiment: _OptimizerOnlyExperiment) -> None:
+        return {k: v for k, v in existing_dfs.items() if not v.is_empty()}
+
+    def write_to_experiment(
+        self, experiment: _OptimizerOnlyExperiment, overwrite: bool | None = False
+    ) -> None:
         for df_name, df in self.existing_dataframes.items():
-            df.write_parquet(f"{experiment.optimizer_mount_point / df_name}.parquet")
+            df_path = experiment.optimizer_mount_point / f"{df_name}.parquet"
+
+            if not overwrite and df_path.exists():
+                continue
+
+            df.write_parquet(df_path)
 
         for batch_data in self.batches:
             ensemble = experiment.get_ensemble_by_name(f"batch_{batch_data.batch_id}")
-            with open(
-                ensemble.optimizer_mount_point / "batch.json", "w+", encoding="utf-8"
-            ) as f:
-                json.dump(
-                    {
-                        "batch_id": batch_data.batch_id,
-                        "is_improvement": batch_data.is_improvement,
-                    },
-                    f,
-                )
+            df_path = ensemble.optimizer_mount_point / "batch.json"
+
+            if overwrite or not df_path.exists():
+                with open(df_path, "w+", encoding="utf-8") as f:
+                    json.dump(
+                        {
+                            "batch_id": batch_data.batch_id,
+                            "is_improvement": batch_data.is_improvement,
+                        },
+                        f,
+                    )
 
             for df_key, df in batch_data.existing_dataframes.items():
-                df.write_parquet(ensemble.optimizer_mount_point / f"{df_key}.parquet")
+                df_path = ensemble.optimizer_mount_point / f"{df_key}.parquet"
+                if not overwrite and df_path.exists():
+                    continue
+
+                df.write_parquet(df_path)
 
     def read_from_experiment(self, experiment: _OptimizerOnlyExperiment) -> None:
         self.controls = pl.read_parquet(
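
The overwrite flag gives write_to_experiment two modes: incremental (overwrite=False), which only writes files that are missing on disk, and forced (overwrite=True), which rewrites everything. A standalone sketch of the rule; write_frames and target_dir are hypothetical names, not part of the Everest API:

from pathlib import Path

import polars as pl

def write_frames(
    frames: dict[str, pl.DataFrame], target_dir: Path, overwrite: bool = False
) -> None:
    for name, df in frames.items():
        path = target_dir / f"{name}.parquet"
        # Incremental mode leaves files from earlier batches untouched.
        if not overwrite and path.exists():
            continue
        df.write_parquet(path)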
@@ -364,11 +380,11 @@ def _ropt_to_df(
 
         return df
 
-    def write_to_output_dir(self) -> None:
+    def write_to_output_dir(self, overwrite: bool | None = False) -> None:
         exp = _OptimizerOnlyExperiment(self._output_dir)
 
         # csv writing mostly for dev/debugging/quick inspection
-        self.data.write_to_experiment(exp)
+        self.data.write_to_experiment(exp, overwrite)
 
     @staticmethod
     def check_for_deprecated_seba_storage(config_file: str) -> None:
@@ -709,6 +725,8 @@ def on_batch_evaluation_finished(
            )
        )
 
+        self.write_to_output_dir(overwrite=False)
+
     def on_optimization_finished(self) -> None:
         logger.debug("Storing final results Everest storage")
 
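With this hook, results are flushed to disk as soon as each batch finishes, so an interrupted run keeps everything evaluated up to that point. overwrite=False makes these per-batch writes purely incremental: files from earlier batches already exist and are skipped. The final write in on_optimization_finished (next hunk) passes overwrite=True instead, because the is_improvement flags are only decided there and the already-written batch.json files must be refreshed. Schematically (handler arguments elided):

storage.on_batch_evaluation_finished(...)  # after batch 0: write_to_output_dir(overwrite=False)
storage.on_batch_evaluation_finished(...)  # after batch 1: write_to_output_dir(overwrite=False)
storage.on_optimization_finished()         # once at the end: write_to_output_dir(overwrite=True)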
@@ -742,7 +760,7 @@ def on_optimization_finished(self) -> None:
                    b.is_improvement = True
                    max_total_objective = total_objective
 
-        self.write_to_output_dir()
+        self.write_to_output_dir(overwrite=True)
 
     def get_optimal_result(self) -> OptimalResult | None:
         # Only used in tests, but re-created to ensure
43 changes: 43 additions & 0 deletions tests/everest/test_everest_storage.py
@@ -0,0 +1,43 @@
+from ert.ensemble_evaluator import EvaluatorServerConfig
+from ert.run_models.everest_run_model import EverestRunModel
+from everest.config import EverestConfig
+from everest.everest_storage import EverestStorage
+
+
+def test_save_after_one_batch(copy_math_func_test_data_to_tmp):
+    num_batches = 1
+    config = EverestConfig.load_file("config_minimal.yml")
+    config_dict = config.to_dict()
+    config_dict["optimization"]["max_batch_num"] = num_batches
+
+    config = EverestConfig(**config_dict)
+
+    n_invocations = 0
+
+    original_write_to_output_dir = None
+
+    def write_to_output_dir_intercept(*args, **kwargs):
+        nonlocal n_invocations
+        assert original_write_to_output_dir is not None
+        result = original_write_to_output_dir(*args, **kwargs)
+        n_invocations += 1
+        return result
+
+    # We "catch" the everest storage through __setattr__,
+    # then assert that its .write_to_output_dir is invoked once
+    # per batch + one final write for adding merit values
+    class MockEverestRunModel(EverestRunModel):
+        def __setattr__(self, key, value):
+            nonlocal original_write_to_output_dir
+            if isinstance(value, EverestStorage):
+                ever_storage = value
+                original_write_to_output_dir = ever_storage.write_to_output_dir
+                ever_storage.write_to_output_dir = write_to_output_dir_intercept
+
+            object.__setattr__(self, key, value)
+
+    run_model = MockEverestRunModel.create(config)
+    run_model.run_experiment(EvaluatorServerConfig())
+
+    # Expect one write per batch, + one final write after the entire experiment is done
+    assert n_invocations == num_batches + 1
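
For comparison, a class-level spy could take the same count without subclassing the run model, assuming nothing else rebinds the method on the storage instance; a hedged sketch reusing config and num_batches from the test above, not what this PR uses:

from unittest import mock

# side_effect is captured before the patch is applied, so the real
# implementation still runs; spy.call_count tracks the invocations.
with mock.patch.object(
    EverestStorage,
    "write_to_output_dir",
    autospec=True,
    side_effect=EverestStorage.write_to_output_dir,
) as spy:
    run_model = EverestRunModel.create(config)
    run_model.run_experiment(EvaluatorServerConfig())

assert spy.call_count == num_batches + 1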