diff --git a/src/ert/callbacks.py b/src/ert/callbacks.py index 6d1c297162a..c88312e1747 100644 --- a/src/ert/callbacks.py +++ b/src/ert/callbacks.py @@ -5,7 +5,7 @@ import time from pathlib import Path -from ert.config import InvalidResponseFile +from ert.config import InvalidResponseFile, ParameterSource from ert.storage import Ensemble from ert.storage.realization_storage_state import RealizationStorageState @@ -24,7 +24,7 @@ async def _read_parameters( error_msg = "" parameter_configuration = ensemble.experiment.parameter_configuration.values() for config in parameter_configuration: - if not config.forward_init: + if config.source != ParameterSource.forward_init: continue try: start_time = time.perf_counter() diff --git a/src/ert/config/__init__.py b/src/ert/config/__init__.py index 85a1b8e2fb5..15fce124103 100644 --- a/src/ert/config/__init__.py +++ b/src/ert/config/__init__.py @@ -22,7 +22,7 @@ from .lint_file import lint_file from .model_config import ModelConfig from .observations import EnkfObs -from .parameter_config import ParameterConfig +from .parameter_config import ParameterConfig, ParameterSource from .parsing import ( ConfigValidationError, ConfigWarning, @@ -72,6 +72,7 @@ "ObservationGroups", "ObservationType", "ParameterConfig", + "ParameterSource", "PriorDict", "QueueConfig", "QueueSystem", diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index f446b36ec72..6de4a413146 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -8,7 +8,11 @@ import pandas as pd from pandas.api.types import is_integer_dtype -from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition +from ert.config.gen_kw_config import ( + GenKwConfig, + ParameterSource, + TransformFunctionDefinition, +) from ._option_dict import option_dict from .parsing import ConfigValidationError, ErrorInfo @@ -211,7 +215,7 @@ def read_design_matrix( ) parameter_configuration = GenKwConfig( name=DESIGN_MATRIX_GROUP, - forward_init=False, + source=ParameterSource.design_matrix, template_file=None, output_file=None, transform_function_definitions=transform_function_definitions, diff --git a/src/ert/config/ext_param_config.py b/src/ert/config/ext_param_config.py index cd9378f02c5..cd2b48fffe0 100644 --- a/src/ert/config/ext_param_config.py +++ b/src/ert/config/ext_param_config.py @@ -11,7 +11,7 @@ from ert.substitutions import substitute_runpath_name -from .parameter_config import ParameterConfig +from .parameter_config import ParameterConfig, ParameterSource if TYPE_CHECKING: import numpy.typing as npt @@ -33,7 +33,7 @@ class ExtParamConfig(ParameterConfig): """ input_keys: list[str] | dict[str, list[str]] = field(default_factory=list) - forward_init: bool = False + source: ParameterSource = ParameterSource.sampled output_file: str = "" forward_init_file: str = "" update: bool = False diff --git a/src/ert/config/field.py b/src/ert/config/field.py index 0e2a1bbd525..026126232d3 100644 --- a/src/ert/config/field.py +++ b/src/ert/config/field.py @@ -16,7 +16,7 @@ from ._option_dict import option_dict from ._str_to_bool import str_to_bool -from .parameter_config import ParameterConfig +from .parameter_config import ParameterConfig, ParameterSource from .parsing import ConfigValidationError, ConfigWarning if TYPE_CHECKING: @@ -125,7 +125,11 @@ def from_config_list( input_transformation=init_transform, truncation_max=float(max_) if max_ is not None else None, truncation_min=float(min_) if min_ is not None else None, - forward_init=forward_init, + source=( + ParameterSource.forward_init + if forward_init + else ParameterSource.sampled + ), forward_init_file=init_files, output_file=out_file, grid_file=os.path.abspath(grid_file_path), diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index a1c68ea0883..8af38cd3bff 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -19,7 +19,7 @@ from ert.substitutions import substitute_runpath_name from ._str_to_bool import str_to_bool -from .parameter_config import ParameterConfig, parse_config +from .parameter_config import ParameterConfig, ParameterSource, parse_config from .parsing import ConfigValidationError, ConfigWarning, ErrorInfo if TYPE_CHECKING: @@ -194,12 +194,16 @@ def from_config_list(cls, gen_kw: list[str]) -> Self: ) return cls( name=gen_kw_key, - forward_init=forward_init, template_file=template_file, output_file=output_file, forward_init_file=init_file, transform_function_definitions=transform_function_definitions, update=update_parameter, + source=( + ParameterSource.forward_init + if forward_init + else ParameterSource.sampled + ), ) def _validate(self) -> None: diff --git a/src/ert/config/parameter_config.py b/src/ert/config/parameter_config.py index cb5369f2d00..999ebc73bb8 100644 --- a/src/ert/config/parameter_config.py +++ b/src/ert/config/parameter_config.py @@ -2,6 +2,7 @@ import dataclasses from abc import ABC, abstractmethod +from enum import StrEnum, auto from pathlib import Path from typing import TYPE_CHECKING, Any @@ -46,11 +47,17 @@ def parse_config( return args, kwargs +class ParameterSource(StrEnum): + forward_init = auto() + sampled = auto() + design_matrix = auto() + + @dataclasses.dataclass class ParameterConfig(ABC): name: str - forward_init: bool update: bool + source: ParameterSource def sample_or_load( self, diff --git a/src/ert/config/surface_config.py b/src/ert/config/surface_config.py index 13bd0c266c2..f0efc23a0d8 100644 --- a/src/ert/config/surface_config.py +++ b/src/ert/config/surface_config.py @@ -12,7 +12,7 @@ from ._option_dict import option_dict from ._str_to_bool import str_to_bool -from .parameter_config import ParameterConfig +from .parameter_config import ParameterConfig, ParameterSource from .parsing import ConfigValidationError, ErrorInfo if TYPE_CHECKING: @@ -88,7 +88,11 @@ def from_config_list(cls, surface: list[str]) -> Self: rotation=surf.rotation, yflip=surf.yflip, name=name, - forward_init=forward_init, + source=( + ParameterSource.forward_init + if forward_init + else ParameterSource.sampled + ), forward_init_file=init_file, output_file=Path(out_file), base_surface_path=base_surface, diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py index 263d664d56c..dc01085c4ce 100644 --- a/src/ert/enkf_main.py +++ b/src/ert/enkf_main.py @@ -19,7 +19,14 @@ from ert.config.model_config import ModelConfig from ert.substitutions import Substitutions, substitute_runpath_name -from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig +from .config import ( + ExtParamConfig, + Field, + GenKwConfig, + ParameterConfig, + ParameterSource, + SurfaceConfig, +) from .config.design_matrix import DESIGN_MATRIX_GROUP from .run_arg import RunArg from .runpaths import Runpaths @@ -106,7 +113,7 @@ def _generate_parameter_files( # For the first iteration we do not write the parameter # to run path, as we expect to read if after the forward # model has completed. - if node.forward_init and iteration == 0: + if node.source == ParameterSource.forward_init and iteration == 0: continue export_values = node.write_to_runpath(Path(run_path), iens, fs) if export_values: @@ -125,7 +132,10 @@ def _manifest_to_json(ensemble: Ensemble, iens: int, iter: int) -> dict[str, Any param_config, ExtParamConfig | GenKwConfig | Field | SurfaceConfig, ) - if param_config.forward_init and ensemble.iteration == 0: + if ( + param_config.source == ParameterSource.forward_init + and ensemble.iteration == 0 + ): assert param_config.forward_init_file is not None file_path = substitute_runpath_name( param_config.forward_init_file, iens, iter @@ -198,8 +208,12 @@ def sample_prior( parameters = list(parameter_configs.keys()) for parameter in parameters: config_node = parameter_configs[parameter] - if config_node.forward_init: + if config_node.source in { + ParameterSource.forward_init, + ParameterSource.design_matrix, + }: continue + logger.info( f"Sampling parameter {config_node.name} for realizations {active_realizations}" ) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index 6d319b1c1d2..076809787ac 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -16,7 +16,7 @@ from pydantic import BaseModel from typing_extensions import deprecated -from ert.config.gen_kw_config import GenKwConfig +from ert.config.gen_kw_config import GenKwConfig, ParameterSource from ert.storage.mode import BaseMode, Mode, require_write from .realization_storage_state import RealizationStorageState @@ -277,7 +277,7 @@ def is_initalized(self) -> list[int]: / (_escape_filename(parameter.name) + ".nc") ).exists() for parameter in self.experiment.parameter_configuration.values() - if not parameter.forward_init + if parameter.source != ParameterSource.forward_init ) ] diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index 9ef4d3348bb..ba435d31a36 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -456,6 +456,7 @@ def _migrate(self, version: int) -> None: to7, to8, to9, + to10, ) try: @@ -500,7 +501,7 @@ def _migrate(self, version: int) -> None: elif version < _LOCAL_STORAGE_VERSION: migrations = list( - enumerate([to2, to3, to4, to5, to6, to7, to8, to9], start=1) + enumerate([to2, to3, to4, to5, to6, to7, to8, to9, to10], start=1) ) for from_version, migration in migrations[version - 1 :]: print(f"* Updating storage to version: {from_version + 1}") diff --git a/src/ert/storage/migration/to10.py b/src/ert/storage/migration/to10.py new file mode 100644 index 00000000000..cdfcee7d758 --- /dev/null +++ b/src/ert/storage/migration/to10.py @@ -0,0 +1,19 @@ +import json +from pathlib import Path + +info = "Add design field into GenKwConfig" + + +def migrate(path: Path) -> None: + for experiment in path.glob("experiments/*"): + with open(experiment / "parameter.json", encoding="utf-8") as fin: + parameters_json = json.load(fin) + + with open(experiment / "parameter.json", "w", encoding="utf-8") as fout: + for param in parameters_json.values(): + if param.get("forward_init") == True: + param["source"] = "forward_init" + else: + param["source"] = "sampled" + del param["forward_init"] + fout.write(json.dumps(parameters_json, indent=4)) diff --git a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters index 74f07f09cf8..8fcdacb8780 100644 --- a/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters +++ b/tests/ert/unit_tests/storage/snapshots/test_storage_migration/test_that_storage_matches/parameters @@ -1 +1 @@ -{'BPR': GenKwConfig(name='BPR', forward_init=False, update=True, template_file='/home/eivind/Projects/ert/test-data/all_data_types/template.txt', output_file='params.txt', transform_function_definitions=[{'name': 'BPR', 'param_name': 'NORMAL', 'values': ['0', '1']}], forward_init_file=None), 'PORO': Field(name='PORO', forward_init=False, update=True, nx=2, ny=3, nz=4, file_format=, output_transformation=None, input_transformation=None, truncation_min=None, truncation_max=None, forward_init_file='data/poro%d.grdecl', output_file=PosixPath('poro.grdecl'), grid_file='/home/eivind/Projects/ert/test-data/all_data_types/refcase/CASE.EGRID', mask_file=''), 'TOP': SurfaceConfig(name='TOP', forward_init=False, update=True, ncol=2, nrow=3, xori=0.0, yori=0.0, xinc=1.0, yinc=1.0, rotation=0.0, yflip=1, forward_init_file='data/surf%d.irap', output_file='surf.irap', base_surface_path='data/basesurf.irap')} +{'BPR': GenKwConfig(name='BPR', forward_init=False, update=True, template_file='/home/eivind/Projects/ert/test-data/all_data_types/template.txt', output_file='params.txt', transform_function_definitions=[{'name': 'BPR', 'param_name': 'NORMAL', 'values': ['0', '1']}], forward_init_file=None, design=False), 'PORO': Field(name='PORO', forward_init=False, update=True, nx=2, ny=3, nz=4, file_format=, output_transformation=None, input_transformation=None, truncation_min=None, truncation_max=None, forward_init_file='data/poro%d.grdecl', output_file=PosixPath('poro.grdecl'), grid_file='/home/eivind/Projects/ert/test-data/all_data_types/refcase/CASE.EGRID', mask_file=''), 'TOP': SurfaceConfig(name='TOP', forward_init=False, update=True, ncol=2, nrow=3, xori=0.0, yori=0.0, xinc=1.0, yinc=1.0, rotation=0.0, yflip=1, forward_init_file='data/surf%d.irap', output_file='surf.irap', base_surface_path='data/basesurf.irap')}