5
5
from textwrap import dedent
6
6
from typing import (
7
7
Any ,
8
+ Callable ,
8
9
Dict ,
9
10
List ,
10
11
Literal ,
11
12
Optional ,
12
13
OrderedDict ,
13
14
Sequence ,
14
15
Set ,
16
+ TextIO ,
15
17
Tuple ,
16
18
Union ,
17
19
cast ,
32
34
TypeExpression ,
33
35
TypeName ,
34
36
UnionTypeExpr ,
37
+ UnknownTypeExpr ,
35
38
ensure_literal_type ,
36
39
extract_inner_type ,
37
40
render_type_expr ,
80
83
Literal,
81
84
Optional,
82
85
Mapping,
86
+ NewType,
83
87
NotRequired,
84
88
Union,
85
89
Tuple,
@@ -160,6 +164,7 @@ def encode_type(
160
164
prefix : TypeName ,
161
165
base_model : str ,
162
166
in_module : list [ModuleName ],
167
+ permit_unknown_members : bool ,
163
168
) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
164
169
encoder_name : Optional [str ] = None # defining this up here to placate mypy
165
170
chunks : List [FileContents ] = []
@@ -256,6 +261,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
256
261
TypeName (f"{ pfx } { i } " ),
257
262
base_model ,
258
263
in_module ,
264
+ permit_unknown_members = permit_unknown_members ,
259
265
)
260
266
one_of .append (type_name )
261
267
chunks .extend (contents )
@@ -283,7 +289,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
283
289
else :
284
290
oneof_t = oneof_ts [0 ]
285
291
type_name , _ , contents , _ = encode_type (
286
- oneof_t , TypeName (pfx ), base_model , in_module
292
+ oneof_t ,
293
+ TypeName (pfx ),
294
+ base_model ,
295
+ in_module ,
296
+ permit_unknown_members = permit_unknown_members ,
287
297
)
288
298
one_of .append (type_name )
289
299
chunks .extend (contents )
@@ -301,6 +311,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
301
311
else
302
312
""" ,
303
313
)
314
+ if permit_unknown_members :
315
+ unknown_name = TypeName (f"{ prefix } AnyOf__Unknown" )
316
+ chunks .append (
317
+ FileContents (
318
+ f"{ unknown_name } = NewType({ repr (unknown_name )} , object)"
319
+ )
320
+ )
321
+ one_of .append (UnknownTypeExpr (unknown_name ))
304
322
chunks .append (
305
323
FileContents (
306
324
f"{ prefix } = { render_type_expr (UnionTypeExpr (one_of ))} "
@@ -336,7 +354,11 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
336
354
typeddict_encoder = []
337
355
for i , t in enumerate (type .anyOf ):
338
356
type_name , _ , contents , _ = encode_type (
339
- t , TypeName (f"{ prefix } AnyOf_{ i } " ), base_model , in_module
357
+ t ,
358
+ TypeName (f"{ prefix } AnyOf_{ i } " ),
359
+ base_model ,
360
+ in_module ,
361
+ permit_unknown_members = permit_unknown_members ,
340
362
)
341
363
any_of .append (type_name )
342
364
chunks .extend (contents )
@@ -363,6 +385,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
363
385
typeddict_encoder .append (
364
386
f"encode_{ ensure_literal_type (other )} (x)"
365
387
)
388
+ if permit_unknown_members :
389
+ unknown_name = TypeName (f"{ prefix } AnyOf__Unknown" )
390
+ chunks .append (
391
+ FileContents (f"{ unknown_name } = NewType({ repr (unknown_name )} , object)" )
392
+ )
393
+ any_of .append (UnknownTypeExpr (unknown_name ))
366
394
if is_literal (type ):
367
395
typeddict_encoder = ["x" ]
368
396
chunks .append (
@@ -404,6 +432,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
404
432
prefix ,
405
433
base_model ,
406
434
in_module ,
435
+ permit_unknown_members = permit_unknown_members ,
407
436
)
408
437
elif isinstance (type , RiverConcreteType ):
409
438
typeddict_encoder = list [str ]()
@@ -446,7 +475,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
446
475
return (TypeName ("datetime.datetime" ), [], [], set ())
447
476
elif type .type == "array" and type .items :
448
477
type_name , module_info , type_chunks , encoder_names = encode_type (
449
- type .items , prefix , base_model , in_module
478
+ type .items ,
479
+ prefix ,
480
+ base_model ,
481
+ in_module ,
482
+ permit_unknown_members = permit_unknown_members ,
450
483
)
451
484
typeddict_encoder .append ("TODO: dstewart" )
452
485
return (ListTypeExpr (type_name ), module_info , type_chunks , encoder_names )
@@ -460,6 +493,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
460
493
prefix ,
461
494
base_model ,
462
495
in_module ,
496
+ permit_unknown_members = permit_unknown_members ,
463
497
)
464
498
# TODO(dstewart): This structure changed since we were incorrectly leaking
465
499
# ListTypeExprs into codegen. This generated code is
@@ -494,7 +528,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
494
528
) in sorted (list (type .properties .items ()), key = lambda xs : xs [0 ]):
495
529
typeddict_encoder .append (f"{ repr (name )} :" )
496
530
type_name , _ , contents , _ = encode_type (
497
- prop , TypeName (prefix + name .title ()), base_model , in_module
531
+ prop ,
532
+ TypeName (prefix + name .title ()),
533
+ base_model ,
534
+ in_module ,
535
+ permit_unknown_members = permit_unknown_members ,
498
536
)
499
537
encoder_name = None
500
538
chunks .extend (contents )
@@ -685,7 +723,7 @@ def generate_common_client(
685
723
chunks .extend (
686
724
[
687
725
f"from .{ model_name } import { class_name } "
688
- for model_name , class_name in modules
726
+ for model_name , class_name in sorted ( modules , key = lambda kv : kv [ 1 ])
689
727
]
690
728
)
691
729
chunks .extend (handshake_chunks )
@@ -732,6 +770,7 @@ def __init__(self, client: river.Client[Any]):
732
770
TypeName (f"{ name .title ()} Init" ),
733
771
input_base_class ,
734
772
module_names ,
773
+ permit_unknown_members = False ,
735
774
)
736
775
serdes .append (
737
776
(
@@ -745,6 +784,7 @@ def __init__(self, client: river.Client[Any]):
745
784
TypeName (f"{ name .title ()} Input" ),
746
785
input_base_class ,
747
786
module_names ,
787
+ permit_unknown_members = False ,
748
788
)
749
789
serdes .append (
750
790
(
@@ -758,6 +798,7 @@ def __init__(self, client: river.Client[Any]):
758
798
TypeName (f"{ name .title ()} Output" ),
759
799
"BaseModel" ,
760
800
module_names ,
801
+ permit_unknown_members = True ,
761
802
)
762
803
serdes .append (
763
804
(
@@ -772,6 +813,7 @@ def __init__(self, client: river.Client[Any]):
772
813
TypeName (f"{ name .title ()} Errors" ),
773
814
"RiverError" ,
774
815
module_names ,
816
+ permit_unknown_members = True ,
775
817
)
776
818
if error_type == "None" :
777
819
error_type = TypeName ("RiverError" )
@@ -822,9 +864,9 @@ def __init__(self, client: river.Client[Any]):
822
864
.validate_python
823
865
"""
824
866
825
- assert (
826
- init_type is None or render_init_method
827
- ), f"Unable to derive the init encoder from: { input_type } "
867
+ assert init_type is None or render_init_method , (
868
+ f"Unable to derive the init encoder from: { input_type } "
869
+ )
828
870
829
871
# Input renderer
830
872
render_input_method : Optional [str ] = None
@@ -862,9 +904,9 @@ def __init__(self, client: river.Client[Any]):
862
904
):
863
905
render_input_method = "lambda x: x"
864
906
865
- assert (
866
- render_input_method
867
- ), f"Unable to derive the input encoder from: { input_type } "
907
+ assert render_input_method , (
908
+ f"Unable to derive the input encoder from: { input_type } "
909
+ )
868
910
869
911
if output_type == "None" :
870
912
parse_output_method = "lambda x: None"
@@ -1038,7 +1080,7 @@ async def {name}(
1038
1080
emitted_files [file_path ] = FileContents ("\n " .join ([existing ] + contents ))
1039
1081
1040
1082
rendered_imports = [
1041
- f"from .{ dotted_modules } import { ', ' .join (names )} "
1083
+ f"from .{ dotted_modules } import { ', ' .join (sorted ( names ) )} "
1042
1084
for dotted_modules , names in imports .items ()
1043
1085
]
1044
1086
@@ -1063,7 +1105,11 @@ def generate_river_client_module(
1063
1105
handshake_chunks : list [str ] = []
1064
1106
if schema_root .handshakeSchema is not None :
1065
1107
_handshake_type , _ , contents , _ = encode_type (
1066
- schema_root .handshakeSchema , TypeName ("HandshakeSchema" ), "BaseModel" , []
1108
+ schema_root .handshakeSchema ,
1109
+ TypeName ("HandshakeSchema" ),
1110
+ "BaseModel" ,
1111
+ [],
1112
+ permit_unknown_members = False ,
1067
1113
)
1068
1114
handshake_chunks .extend (contents )
1069
1115
handshake_type = HandshakeType (render_type_expr (_handshake_type ))
@@ -1090,25 +1136,29 @@ def generate_river_client_module(
1090
1136
1091
1137
1092
1138
def schema_to_river_client_codegen(
    read_schema: Callable[[], TextIO],
    target_path: str,
    client_name: str,
    typed_dict_inputs: bool,
    file_opener: Callable[[Path], TextIO],
) -> None:
    """Generate the River client modules for a schema and write them out.

    The schema is read as JSON from the handle returned by ``read_schema``,
    wrapped in ``RiverSchemaFile`` and passed to
    ``generate_river_client_module``. Each ``(subpath, contents)`` pair it
    yields is piped through ``ruff format -`` and written via ``file_opener``.

    Args:
        read_schema: Zero-argument callable returning a readable text handle
            positioned at the JSON schema. Used as a context manager.
        target_path: Directory under which generated modules are placed;
            missing parent directories are created with mode ``0o755``.
        client_name: Forwarded to ``generate_river_client_module``.
        typed_dict_inputs: Forwarded to ``generate_river_client_module``.
        file_opener: Callable that receives the destination ``Path`` and
            returns a writable text handle. Used as a context manager.

    Raises:
        subprocess.CalledProcessError: ``ruff format`` exited non-zero.
        OSError: ``ruff`` could not be launched (e.g. not installed).

    Whenever formatting or writing fails, the unformatted contents are written
    to the destination handle before the exception is re-raised, so the
    generated (if unformatted) source is still available for inspection.
    """
    with read_schema() as schema_file:
        schemas = RiverSchemaFile(json.load(schema_file))

    generated = generate_river_client_module(
        client_name, schemas.root, typed_dict_inputs
    )
    for subpath, contents in generated.items():
        module_path = Path(target_path).joinpath(subpath)
        module_path.parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        with file_opener(module_path) as out:
            try:
                # check=True: a failing formatter must not be mistaken for
                # success, otherwise its (typically empty) stdout would be
                # written out as the module. stderr is left attached to the
                # parent process so ruff's diagnostics stay visible.
                completed = subprocess.run(
                    ["ruff", "format", "-"],
                    input=contents.encode(),
                    stdout=subprocess.PIPE,
                    check=True,
                )
                out.write(completed.stdout.decode("utf-8"))
            except BaseException:
                # Same reach as the previous bare ``except``: always leave the
                # raw generated source behind, then propagate the failure.
                out.write(contents)
                raise
0 commit comments