Skip to content

Commit 32006fb

Browse files
feat/unknown enumeration branches (#134)
Why === We need to be able to add new members to existing enumerations What changed ============ Adding unknown enumeration fallbacks Test plan ========= _Describe what you did to test this change to a level of detail that allows your reviewer to test it_
1 parent ecd1b17 commit 32006fb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+614
-75
lines changed

Makefile

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
lint:
2-
uv run ruff format --check .
3-
uv run ruff check .
4-
uv run mypy .
5-
uv run pyright-python .
6-
uv run deptry .
2+
uv run ruff format --check src tests
3+
uv run ruff check src tests
4+
uv run mypy src tests
5+
uv run pyright-python src tests
6+
uv run deptry src tests
77

88
format:
9-
uv run ruff format .
10-
uv run ruff check . --fix
9+
uv run ruff format src tests
10+
uv run ruff check src tests --fix

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,8 @@ ignore_missing_imports = True
1212
[mypy-pyd.*]
1313
ignore_missing_imports = True
1414

15+
[mypy-pytest_snapshot.*]
16+
ignore_missing_imports = True
17+
1518
[mypy-tyd.*]
1619
ignore_missing_imports = True

pyproject.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ dependencies = [
2424
"aiochannel>=1.2.1",
2525
"grpcio-tools>=1.59.3",
2626
"grpcio>=1.59.3",
27-
"msgpack-types>=0.3.0",
2827
"msgpack>=1.0.7",
2928
"nanoid>=2.0.0",
3029
"protobuf>=5.28.3",
@@ -37,6 +36,7 @@ dependencies = [
3736
[tool.uv]
3837
dev-dependencies = [
3938
"deptry>=0.14.0",
39+
"msgpack-types>=0.3.0",
4040
"mypy>=1.4.0",
4141
"mypy-protobuf>=3.5.0",
4242
"pytest>=7.4.0",
@@ -48,11 +48,12 @@ dev-dependencies = [
4848
"types-protobuf>=4.24.0.20240311",
4949
"types-nanoid>=2.0.0.20240601",
5050
"pyright>=1.1.389",
51+
"pytest-snapshot>=0.9.0",
5152
]
5253

5354
[tool.ruff]
5455
lint.select = ["F", "E", "W", "I001"]
55-
exclude = ["*/generated/*"]
56+
exclude = ["*/generated/*", "*/snapshots/*"]
5657

5758
# Should be kept in sync with mypy.ini in the project root.
5859
# The VSCode mypy extension can only read /mypy.ini.
@@ -91,3 +92,6 @@ ignore_missing_imports = true
9192
[build-system]
9293
requires = ["hatchling"]
9394
build-backend = "hatchling.build"
95+
96+
[tool.hatch.build.targets.wheel]
97+
packages = ["src/replit_river"]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

replit_river/codegen/client.py src/replit_river/codegen/client.py

+68-18
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from textwrap import dedent
66
from typing import (
77
Any,
8+
Callable,
89
Dict,
910
List,
1011
Literal,
1112
Optional,
1213
OrderedDict,
1314
Sequence,
1415
Set,
16+
TextIO,
1517
Tuple,
1618
Union,
1719
cast,
@@ -32,6 +34,7 @@
3234
TypeExpression,
3335
TypeName,
3436
UnionTypeExpr,
37+
UnknownTypeExpr,
3538
ensure_literal_type,
3639
extract_inner_type,
3740
render_type_expr,
@@ -80,6 +83,7 @@
8083
Literal,
8184
Optional,
8285
Mapping,
86+
NewType,
8387
NotRequired,
8488
Union,
8589
Tuple,
@@ -160,6 +164,7 @@ def encode_type(
160164
prefix: TypeName,
161165
base_model: str,
162166
in_module: list[ModuleName],
167+
permit_unknown_members: bool,
163168
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
164169
encoder_name: Optional[str] = None # defining this up here to placate mypy
165170
chunks: List[FileContents] = []
@@ -256,6 +261,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
256261
TypeName(f"{pfx}{i}"),
257262
base_model,
258263
in_module,
264+
permit_unknown_members=permit_unknown_members,
259265
)
260266
one_of.append(type_name)
261267
chunks.extend(contents)
@@ -283,7 +289,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
283289
else:
284290
oneof_t = oneof_ts[0]
285291
type_name, _, contents, _ = encode_type(
286-
oneof_t, TypeName(pfx), base_model, in_module
292+
oneof_t,
293+
TypeName(pfx),
294+
base_model,
295+
in_module,
296+
permit_unknown_members=permit_unknown_members,
287297
)
288298
one_of.append(type_name)
289299
chunks.extend(contents)
@@ -301,6 +311,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
301311
else
302312
""",
303313
)
314+
if permit_unknown_members:
315+
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
316+
chunks.append(
317+
FileContents(
318+
f"{unknown_name} = NewType({repr(unknown_name)}, object)"
319+
)
320+
)
321+
one_of.append(UnknownTypeExpr(unknown_name))
304322
chunks.append(
305323
FileContents(
306324
f"{prefix} = {render_type_expr(UnionTypeExpr(one_of))}"
@@ -336,7 +354,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
336354
typeddict_encoder = []
337355
for i, t in enumerate(type.anyOf):
338356
type_name, _, contents, _ = encode_type(
339-
t, TypeName(f"{prefix}AnyOf_{i}"), base_model, in_module
357+
t,
358+
TypeName(f"{prefix}AnyOf_{i}"),
359+
base_model,
360+
in_module,
361+
permit_unknown_members=permit_unknown_members,
340362
)
341363
any_of.append(type_name)
342364
chunks.extend(contents)
@@ -363,6 +385,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
363385
typeddict_encoder.append(
364386
f"encode_{ensure_literal_type(other)}(x)"
365387
)
388+
if permit_unknown_members:
389+
unknown_name = TypeName(f"{prefix}AnyOf__Unknown")
390+
chunks.append(
391+
FileContents(f"{unknown_name} = NewType({repr(unknown_name)}, object)")
392+
)
393+
any_of.append(UnknownTypeExpr(unknown_name))
366394
if is_literal(type):
367395
typeddict_encoder = ["x"]
368396
chunks.append(
@@ -404,6 +432,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
404432
prefix,
405433
base_model,
406434
in_module,
435+
permit_unknown_members=permit_unknown_members,
407436
)
408437
elif isinstance(type, RiverConcreteType):
409438
typeddict_encoder = list[str]()
@@ -446,7 +475,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
446475
return (TypeName("datetime.datetime"), [], [], set())
447476
elif type.type == "array" and type.items:
448477
type_name, module_info, type_chunks, encoder_names = encode_type(
449-
type.items, prefix, base_model, in_module
478+
type.items,
479+
prefix,
480+
base_model,
481+
in_module,
482+
permit_unknown_members=permit_unknown_members,
450483
)
451484
typeddict_encoder.append("TODO: dstewart")
452485
return (ListTypeExpr(type_name), module_info, type_chunks, encoder_names)
@@ -460,6 +493,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
460493
prefix,
461494
base_model,
462495
in_module,
496+
permit_unknown_members=permit_unknown_members,
463497
)
464498
# TODO(dstewart): This structure changed since we were incorrectly leaking
465499
# ListTypeExprs into codegen. This generated code is
@@ -494,7 +528,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
494528
) in sorted(list(type.properties.items()), key=lambda xs: xs[0]):
495529
typeddict_encoder.append(f"{repr(name)}:")
496530
type_name, _, contents, _ = encode_type(
497-
prop, TypeName(prefix + name.title()), base_model, in_module
531+
prop,
532+
TypeName(prefix + name.title()),
533+
base_model,
534+
in_module,
535+
permit_unknown_members=permit_unknown_members,
498536
)
499537
encoder_name = None
500538
chunks.extend(contents)
@@ -685,7 +723,7 @@ def generate_common_client(
685723
chunks.extend(
686724
[
687725
f"from .{model_name} import {class_name}"
688-
for model_name, class_name in modules
726+
for model_name, class_name in sorted(modules, key=lambda kv: kv[1])
689727
]
690728
)
691729
chunks.extend(handshake_chunks)
@@ -732,6 +770,7 @@ def __init__(self, client: river.Client[Any]):
732770
TypeName(f"{name.title()}Init"),
733771
input_base_class,
734772
module_names,
773+
permit_unknown_members=False,
735774
)
736775
serdes.append(
737776
(
@@ -745,6 +784,7 @@ def __init__(self, client: river.Client[Any]):
745784
TypeName(f"{name.title()}Input"),
746785
input_base_class,
747786
module_names,
787+
permit_unknown_members=False,
748788
)
749789
serdes.append(
750790
(
@@ -758,6 +798,7 @@ def __init__(self, client: river.Client[Any]):
758798
TypeName(f"{name.title()}Output"),
759799
"BaseModel",
760800
module_names,
801+
permit_unknown_members=True,
761802
)
762803
serdes.append(
763804
(
@@ -772,6 +813,7 @@ def __init__(self, client: river.Client[Any]):
772813
TypeName(f"{name.title()}Errors"),
773814
"RiverError",
774815
module_names,
816+
permit_unknown_members=True,
775817
)
776818
if error_type == "None":
777819
error_type = TypeName("RiverError")
@@ -822,9 +864,9 @@ def __init__(self, client: river.Client[Any]):
822864
.validate_python
823865
"""
824866

825-
assert (
826-
init_type is None or render_init_method
827-
), f"Unable to derive the init encoder from: {input_type}"
867+
assert init_type is None or render_init_method, (
868+
f"Unable to derive the init encoder from: {input_type}"
869+
)
828870

829871
# Input renderer
830872
render_input_method: Optional[str] = None
@@ -862,9 +904,9 @@ def __init__(self, client: river.Client[Any]):
862904
):
863905
render_input_method = "lambda x: x"
864906

865-
assert (
866-
render_input_method
867-
), f"Unable to derive the input encoder from: {input_type}"
907+
assert render_input_method, (
908+
f"Unable to derive the input encoder from: {input_type}"
909+
)
868910

869911
if output_type == "None":
870912
parse_output_method = "lambda x: None"
@@ -1038,7 +1080,7 @@ async def {name}(
10381080
emitted_files[file_path] = FileContents("\n".join([existing] + contents))
10391081

10401082
rendered_imports = [
1041-
f"from .{dotted_modules} import {', '.join(names)}"
1083+
f"from .{dotted_modules} import {', '.join(sorted(names))}"
10421084
for dotted_modules, names in imports.items()
10431085
]
10441086

@@ -1063,7 +1105,11 @@ def generate_river_client_module(
10631105
handshake_chunks: list[str] = []
10641106
if schema_root.handshakeSchema is not None:
10651107
_handshake_type, _, contents, _ = encode_type(
1066-
schema_root.handshakeSchema, TypeName("HandshakeSchema"), "BaseModel", []
1108+
schema_root.handshakeSchema,
1109+
TypeName("HandshakeSchema"),
1110+
"BaseModel",
1111+
[],
1112+
permit_unknown_members=False,
10671113
)
10681114
handshake_chunks.extend(contents)
10691115
handshake_type = HandshakeType(render_type_expr(_handshake_type))
@@ -1090,25 +1136,29 @@ def generate_river_client_module(
10901136

10911137

10921138
def schema_to_river_client_codegen(
1093-
schema_path: str,
1139+
read_schema: Callable[[], TextIO],
10941140
target_path: str,
10951141
client_name: str,
10961142
typed_dict_inputs: bool,
1143+
file_opener: Callable[[Path], TextIO],
10971144
) -> None:
10981145
"""Generates the lines of a River module."""
1099-
with open(schema_path) as f:
1146+
with read_schema() as f:
11001147
schemas = RiverSchemaFile(json.load(f))
11011148
for subpath, contents in generate_river_client_module(
11021149
client_name, schemas.root, typed_dict_inputs
11031150
).items():
11041151
module_path = Path(target_path).joinpath(subpath)
11051152
module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
1106-
with open(module_path, "w") as f:
1153+
with file_opener(module_path) as f:
11071154
try:
11081155
popen = subprocess.Popen(
1109-
["ruff", "format", "-"], stdin=subprocess.PIPE, stdout=f
1156+
["ruff", "format", "-"],
1157+
stdin=subprocess.PIPE,
1158+
stdout=subprocess.PIPE,
11101159
)
1111-
popen.communicate(contents.encode())
1160+
stdout, _ = popen.communicate(contents.encode())
1161+
f.write(stdout.decode("utf-8"))
11121162
except:
11131163
f.write(contents)
11141164
raise
File renamed without changes.

replit_river/codegen/run.py src/replit_river/codegen/run.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import argparse
22
import os.path
3+
from pathlib import Path
4+
from typing import TextIO
35

46
from .client import schema_to_river_client_codegen
57
from .schema import proto_to_river_schema_codegen
@@ -50,8 +52,16 @@ def main() -> None:
5052
elif args.command == "client":
5153
schema_path = os.path.abspath(args.schema)
5254
target_path = os.path.abspath(args.output)
55+
56+
def file_opener(path: Path) -> TextIO:
57+
return open(path, "w")
58+
5359
schema_to_river_client_codegen(
54-
schema_path, target_path, args.client_name, args.typed_dict_inputs
60+
lambda: open(schema_path),
61+
target_path,
62+
args.client_name,
63+
args.typed_dict_inputs,
64+
file_opener,
5565
)
5666
else:
5767
raise NotImplementedError(f"Unknown command {args.command}")
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)