Skip to content

Commit 801d9db

Browse files
Handle recursively serializing a dataclasses as a dictionary. (mobilityhouse#547)
see issue mobilityhouse#255
1 parent 2ab67ee commit 801d9db

File tree

3 files changed

+133
-9
lines changed

3 files changed

+133
-9
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Change log
22

3+
- [#547](https://github.com/mobilityhouse/ocpp/pull/547) Feat: Handle recursively serializing a dataclasses as a dictionary Thanks [@MacDue](https://github.com/MacDue)
34
- [#601](https://github.com/mobilityhouse/ocpp/issues/601) Fix case conversion for soc in non "State of Charge" context
45
- [#523](https://github.com/mobilityhouse/ocpp/issues/523) The serialisation of soc to SoC should not occur in camel case if it is existing at the beginning of a field
56
- [#515](https://github.com/mobilityhouse/ocpp/issues/515) Update Readthedocs configuration

ocpp/charge_point.py

+69-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import re
55
import time
66
import uuid
7-
from dataclasses import asdict
8-
from typing import Dict, List, Union
7+
from dataclasses import Field, asdict, is_dataclass
8+
from typing import Any, Dict, List, Union, get_args, get_origin
99

1010
from ocpp.exceptions import NotImplementedError, NotSupportedError, OCPPError
1111
from ocpp.messages import Call, MessageType, unpack, validate_payload
@@ -73,6 +73,71 @@ def snake_to_camel_case(data):
7373
return data
7474

7575

76+
def _is_dataclass_instance(input: Any) -> bool:
77+
"""Verify if given `input` is a dataclass."""
78+
return is_dataclass(input) and not isinstance(input, type)
79+
80+
81+
def _is_optional_field(field: Field) -> bool:
82+
"""Verify if given `field` allows `None` as value.
83+
84+
The fields `schema` and `host` on the following class would return `False`.
85+
While the fields `post` and `query` return `True`.
86+
87+
@dataclass
88+
class URL:
89+
schema: str,
90+
host: str,
91+
post: Optional[str],
92+
query: Union[None, str]
93+
94+
"""
95+
return get_origin(field.type) is Union and type(None) in get_args(field.type)
96+
97+
98+
def serialize_as_dict(dataclass):
99+
"""Serialize the given `dataclass` as a `dict` recursively.
100+
101+
@dataclass
102+
class StatusInfoType:
103+
reason_code: str
104+
additional_info: Optional[str] = None
105+
106+
with_additional_info = StatusInfoType(
107+
reason="Unknown",
108+
additional_info="More details"
109+
)
110+
111+
assert serialize_as_dict(with_additional_info) == {
112+
'reason': 'Unknown',
113+
'additional_info': 'More details',
114+
}
115+
116+
without_additional_info = StatusInfoType(reason="Unknown")
117+
118+
assert serialize_as_dict(with_additional_info) == {
119+
'reason': 'Unknown',
120+
'additional_info': None,
121+
}
122+
123+
"""
124+
serialized = asdict(dataclass)
125+
126+
for field in dataclass.__dataclass_fields__.values():
127+
128+
value = getattr(dataclass, field.name)
129+
if _is_dataclass_instance(value):
130+
serialized[field.name] = serialize_as_dict(value)
131+
continue
132+
133+
if isinstance(value, list):
134+
for item in value:
135+
if _is_dataclass_instance(item):
136+
serialized[field.name] = [serialize_as_dict(item)]
137+
138+
return serialized
139+
140+
76141
def remove_nones(data: Union[List, Dict]) -> Union[List, Dict]:
77142
if isinstance(data, dict):
78143
return {k: remove_nones(v) for k, v in data.items() if v is not None}
@@ -246,7 +311,7 @@ async def _handle_call(self, msg):
246311

247312
return
248313

249-
temp_response_payload = asdict(response)
314+
temp_response_payload = serialize_as_dict(response)
250315

251316
# Remove nones ensures that we strip out optional arguments
252317
# which were not set and have a default value of None
@@ -308,7 +373,7 @@ async def call(self, payload, suppress=True, unique_id=None):
308373
CallError.
309374
310375
"""
311-
camel_case_payload = snake_to_camel_case(asdict(payload))
376+
camel_case_payload = snake_to_camel_case(serialize_as_dict(payload))
312377

313378
unique_id = (
314379
unique_id if unique_id is not None else str(self._unique_id_generator())

tests/test_charge_point.py

+63-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import pytest
44

5-
from ocpp.charge_point import camel_to_snake_case, remove_nones, snake_to_camel_case
5+
from ocpp.charge_point import (
6+
camel_to_snake_case,
7+
remove_nones,
8+
serialize_as_dict,
9+
snake_to_camel_case,
10+
)
611
from ocpp.messages import Call
712
from ocpp.routing import after, create_route_map, on
813
from ocpp.v16 import ChargePoint as cp_16
@@ -11,8 +16,15 @@
1116
from ocpp.v16.datatypes import MeterValue, SampledValue
1217
from ocpp.v16.enums import Action, RegistrationStatus
1318
from ocpp.v201 import ChargePoint as cp_201
14-
from ocpp.v201.call import SetNetworkProfile
15-
from ocpp.v201.datatypes import NetworkConnectionProfileType
19+
from ocpp.v201.call import GetVariables as v201GetVariables
20+
from ocpp.v201.call import SetNetworkProfile as v201SetNetworkProfile
21+
from ocpp.v201.datatypes import (
22+
ComponentType,
23+
EVSEType,
24+
GetVariableDataType,
25+
NetworkConnectionProfileType,
26+
VariableType,
27+
)
1628
from ocpp.v201.enums import OCPPInterfaceType, OCPPTransportType, OCPPVersionType
1729

1830

@@ -72,7 +84,7 @@ def test_camel_to_snake_case(test_input, expected):
7284
[
7385
({"transaction_id": "74563478"}, {"transactionId": "74563478"}),
7486
({"full_soc": 100}, {"fullSoC": 100}),
75-
({"soc_limit_reached": 200}, {"SOCLimitReached": 200}),
87+
({"soc_limit_reached": 200}, {"SoCLimitReached": 200}),
7688
({"ev_min_v2x_energy_request": 200}, {"evMinV2XEnergyRequest": 200}),
7789
({"v2x_charging_ctrlr": 200}, {"v2xChargingCtrlr": 200}),
7890
({"web_socket_ping_interval": 200}, {"webSocketPingInterval": 200}),
@@ -125,7 +137,9 @@ def test_nested_remove_nones():
125137
apn=None,
126138
)
127139

128-
payload = SetNetworkProfile(configuration_slot=1, connection_data=connection_data)
140+
payload = v201SetNetworkProfile(
141+
configuration_slot=1, connection_data=connection_data
142+
)
129143
payload = asdict(payload)
130144

131145
assert expected_payload == remove_nones(payload)
@@ -246,6 +260,50 @@ def test_remove_nones_with_list_of_strings():
246260
}
247261

248262

263+
def test_serialize_as_dict():
264+
"""
265+
Test recursively serializing a dataclasses as a dictionary.
266+
"""
267+
# Setup
268+
expected = camel_to_snake_case(
269+
{
270+
"getVariableData": [
271+
{
272+
"component": {
273+
"name": "Component",
274+
"instance": None,
275+
"evse": {
276+
"id": 1,
277+
"connectorId": None,
278+
},
279+
},
280+
"variable": {
281+
"name": "Variable",
282+
"instance": None,
283+
},
284+
"attributeType": None,
285+
}
286+
],
287+
"customData": None,
288+
}
289+
)
290+
291+
payload = v201GetVariables(
292+
get_variable_data=[
293+
GetVariableDataType(
294+
component=ComponentType(
295+
name="Component",
296+
evse=EVSEType(id=1),
297+
),
298+
variable=VariableType(name="Variable"),
299+
)
300+
]
301+
)
302+
303+
# Execute / Assert
304+
assert serialize_as_dict(payload) == expected
305+
306+
249307
@pytest.mark.asyncio
250308
async def test_call_unique_id_added_to_handler_args_correctly(connection):
251309
"""

0 commit comments

Comments
 (0)