From b3270d1818fe68e4f4f414f17ce0b8fa81fd0ba6 Mon Sep 17 00:00:00 2001 From: ghostplant <12099308+ghostplant@users.noreply.github.com> Date: Sun, 16 Feb 2025 14:55:58 +0000 Subject: [PATCH] add deepseek-r1 gating & mla for AMD MI300 --- README.md | 119 +++++++++------ doc/DeepSeekR1-tutel-accel.png | Bin 0 -> 188854 bytes setup.py | 21 ++- tutel/custom/antares_ops.h | 260 +++++++++++++++++++++++++++++++++ tutel/custom/backend.hpp | 224 ++++++++++++++++++++++++++++ tutel/custom/custom_kernel.cpp | 110 ++++++++++++++ tutel/system.py | 1 + 7 files changed, 683 insertions(+), 52 deletions(-) create mode 100644 doc/DeepSeekR1-tutel-accel.png create mode 100644 tutel/custom/antares_ops.h create mode 100644 tutel/custom/backend.hpp diff --git a/README.md b/README.md index 9f75b95..70e80b8 100644 --- a/README.md +++ b/README.md @@ -6,19 +6,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall - Supported GPUs: CUDA(fp64/fp32/fp16/bfp16), ROCm(fp64/fp32/fp16) - Supported CPU: fp64/fp32 +### ***Support Full Precision Inference of MoE-based Deepseek R1 671B on AMD MI300:*** -### What's New: +We compare three solutions that support Full-Precision Inference (PPL = 0) of Deepseek R1 671B. PPL = 0 means any quantization or unofficial sparsity techniques that may lower the scores of model, are prohibited. + +![benchmarking](doc/DeepSeekR1-tutel-accel.png) + +----------- + +## What's New: + +- Tutel v0.4.0: Accelerating Deepseek R1 Full-precision-Chat for AMD MI300x8 (more platform support will be added in later versions): +```sh + >> Example: + + # Step-1: Download Deepseek R1 671B Model + huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir ./deepseek-ai/DeepSeek-R1 + + # Step-2: Using 8 MI300 GPUs to Run Deepseek R1 Chat with Full Precision (PPL = 0) + docker run -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --privileged \ + -v /:/host -w /host$(pwd) tutelgroup/deepseek-671b:mi300x8-fp16xfp8 \ + --model_path ./deepseek-ai/DeepSeek-R1 \ + --prompt "Calculate the result of: 1 / (sqrt(5) - sqrt(3))" + +``` - Tutel v0.3.3: Add all-to-all benchmark: -```py +```sh >> Example: python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.bandwidth_test --size_mb=256 ``` - Tutel v0.3.2: Add tensorcore option for extra benchmarks / Extend the example for custom experts / Allow NCCL timeout settings: -```py - >> Example for using tensorcore: +```sh + >> Example of using tensorcore: python3 -m tutel.examples.helloworld --dtype=float32 python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore @@ -26,16 +48,16 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall python3 -m tutel.examples.helloworld --dtype=float16 python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore - >> Example for custom gates/experts: + >> Example of custom gates/experts: python3 -m tutel.examples.helloworld_custom_gate_expert --batch_size=16 - >> Example for NCCL timeout settings: + >> Example of NCCL timeout settings: TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore ``` - Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers: -```py +```sh >> Example: # All_to_All_v: python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_to_all_v @@ -48,8 +70,8 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall ``` - Tutel v0.3: Add Megablocks solution to improve decoder inference on single-GPU with num_local_expert >= 2: -```py - >> Example (capacity_factor=0 for dropless-MoE): +```sh + >> Example (capacity_factor=0 required by dropless-MoE): # Using BatchMatmul: python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0 # Using Megablocks with block_size = 1: @@ -62,7 +84,7 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall ``` - Tutel v0.2: Allow most configurations to be dynamic switchable with free cost: -```py +```sh >> Example: python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16 @@ -74,35 +96,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall ``` - Tutel v0.1: Optimize the Einsum Complexity of Data Dispatch Encoding and Decoding, add 2DH option to deal with All-to-All at scale: -```py +```sh >> Example (suggest enabling 2DH only at scale, note that the value of --nproc_per_node MUST equal to total physical GPU counts per node, e.g. 8 for A100x8): python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16 --use_2dh ``` +----------- +## Getting Started -### How to setup Tutel MoE for Pytorch 2 and [run examples](tutel/examples), or [enable fairseq with MoE](tutel/examples/fairseq_moe): +### 1. Prepare Pytorch (if applicable): ``` -* Prepare Recommended Pytorch >= 2.0.0 (minimal version == 1.8.0): +* Prepare Recommended Pytorch >= 2.0.0: # Windows/Linux Pytorch for NVIDIA CUDA >= 11.7: python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 - # Linux Pytorch for AMD ROCm == 5.4.2: - python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2 + # Linux Pytorch for AMD ROCm >= 6.2.2: + python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.2 # Windows/Linux Pytorch for CPU: python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +``` -* Install Tutel Online: +### 2. Tutel Installation: +``` +* Option-1: Install Tutel Online: $ python3 -m pip uninstall tutel -y - $ python3 -m pip install setuptools wheel $ python3 -m pip install -v -U --no-build-isolation git+https://github.com/microsoft/tutel@main -* Build Tutel from Source: +* Option-2: Build Tutel from Source: $ git clone https://github.com/microsoft/tutel --branch main - $ python3 -m pip uninstall tutel -y $ python3 ./tutel/setup.py install --user +``` +### 3. Quick Test for Single Device / CPU: +``` * Quick Test on Single-GPU: $ python3 -m tutel.examples.helloworld --batch_size=16 # Test Tutel-optimized MoE + manual distribution @@ -117,32 +145,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation, also the first parall (If building from source, the following method also works:) $ python3 ./tutel/examples/helloworld.py --batch_size=16 .. +``` + +### 4. Quick Test for 8 GPUs within 1 Machine: +``` + $ python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --batch_size=16 +``` +### 5. Quick Test for Multiple GPUs across Machines: +``` * Run Tutel MoE in Distributed Mode: - (Method A - Torch launcher for `Multi-Node x Multi-GPU`:) + (Option A - Torch launcher for `Multi-Node x Multi-GPU`:) $ ssh python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr= -m tutel.examples.helloworld --batch_size=16 $ ssh python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr= -m tutel.examples.helloworld --batch_size=16 - (Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:) + (Option B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:) # << Single Node >> $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld_ddp_tutel --batch_size=16 $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_mnist $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.moe_cifar10 ... - # << Cross Nodes >> + # << MPI-based launch for GPU backend>> $ mpiexec -bind-to none -host ,,.. -x MASTER_ADDR= -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16 - # << For CPU-based Launch>> + # << MPI-based Launch for CPU backend>> $ mpiexec -bind-to none -host localhost -x LOCAL_SIZE=1 -x OMP_NUM_THREADS=1024 python3 -m tutel.launcher.run -m tutel.examples.helloworld --batch_size=16 --device cpu - ``` -### How to convert checkpoint files that adapt to different distributed world sizes: -Documentation has been moved [here](doc/CHECKPOINT.md). +----------- + +### Advance: Convert Checkpoint Files for Different World Sizes: +Documentation for checkpoint conversion has been moved [here](doc/CHECKPOINT.md). -### How to import Tutel-optimized MoE in Pytorch: +### Examples: How to import Tutel-optimized MoE in Pytorch: ``` # Input Example: import torch @@ -177,6 +214,20 @@ y = moe_layer(x) print(y) ``` +### Reference +You can consult this [paper](https://arxiv.org/pdf/2206.03382.pdf) below to get to know more technical details about Tutel: +``` +@article {tutel, +author = {Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong}, +title = {Tutel: Adaptive Mixture-of-Experts at Scale}, +year = {2022}, +month = jun, +journal = {CoRR}, +volume= {abs/2206.03382}, +url = {https://arxiv.org/pdf/2206.03382.pdf}, +} +``` + ### Usage of MOELayer: ``` * Usage of MOELayer Args: @@ -205,20 +256,6 @@ print(y) has_fc2_bias : If set to False, the expert bias parameters `batched_fc2_bias` is disabled. Default: True ``` -### Reference -You can consult this [paper](https://arxiv.org/pdf/2206.03382.pdf) below to get to know more technical details about Tutel: -``` -@article {tutel, -author = {Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong}, -title = {Tutel: Adaptive Mixture-of-Experts at Scale}, -year = {2022}, -month = jun, -journal = {CoRR}, -volume= {abs/2206.03382}, -url = {https://arxiv.org/pdf/2206.03382.pdf}, -} -``` - ### Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a diff --git a/doc/DeepSeekR1-tutel-accel.png b/doc/DeepSeekR1-tutel-accel.png new file mode 100644 index 0000000000000000000000000000000000000000..bd2c8e83bfb0c36e1669faf841b9d76fc9a234cb GIT binary patch literal 188854 zcmeFYcUV*T+BJ*?8>6TwNDGdPq7;!Ly$lvmDM~ZaRm9L+fIvdAHy}}w9zdlRiF61A zR7!+EB7}fI5|9890*MeHl2G0a&dfRQ^Pcnl`~CB9>E#vnM&0-Nt#z-v?59_(%y#dR z+9f0;wENO0~@mr2VLPBe~7tfj4g}YAERLJ9|)r(6@o}xi54@?ryTF9uK zPfJNTpwrTl^Tg!Ek@Hdq)YQCgUccGX@*zh)6Z+-D(vB^LK?B+D5i=@Y@04r{lUKGk z=|nm6tS%Nt+Io9Tt^LHT@iA-MrZqv+8fR@4__Z7OZ~Vuvv&+M4f*9Z~*4odFTbb>~ zoB!jTLPFQg))w5>7S_BOj~~bUmwSXhtge>NHm~vB)E&4YZ+yBe8JNMtuZuu|Q zKA3#{x#_=L8#(N|{9i5!1@6@X@cPeRIy3Fg{$p7FS6cp4SpNT!mJhirGXXJz8BWZaz`t>Mj_}xh zvknoG6}317c8gjVf(=LUhIqpX+~lYv^}a`*xmhi;v3ou@7# zcoau6UR(ZM#bB^rE{lwd!*9PMj;Ov7=a4z1K{ETYg!jS9w_X0T%MLQ%$)6W4xk(4L_AP!{5krKTJD(T;m8d% zXD!i8d_cBVD+UJU^!4-$I~CL4wx0b)f>*i5Yuw-Ocm}Mk@|t2|BKY5$9}l|~dH1Mi zO5BN_Kf)LaXj+M~HZMW#QF&?oqj{(%z`^;ECEN={xUwzkEOFmTZOe6cSEJ2=alXx! zfj4)J^lsj{pn<9V47-)dF=5JKCXt7XHw_)o%x6-wqWP0=PS!@RE-{CVqL+UPG(rqO zLJdWjrd9srkulvap3k=f+y_GMk+9Q3r52ofH}zlbBX2@}s@Pq6iMxs^KN0qDNcAaQ zqHW%EayGLx)C8`Cb8{IAk1+RH<$Z200>_5VyuiP*RTVvb^{JQa#j@*W{z(E9W zA)d{4jiDe+`OX%}xdvfr;}Tn1yw@gJ_Ti{_{hLc1OX&=p&6z@Rl}Fc#}b^@Fcc>5WJGUY8q2)(D+H_0iAmSdC|j$ z*{Er((Q5d#bVgCymAb+Q4ufMS{_KYHI*HG?MNL1JGCy}ZQ}vGD;Auq5nUo8vNdrd@ z?GusUh4w`|yBU0m{aB&;$~L`fu_w5FpQ@8tx5&rKSw@D=(8Re5sv?L*5r_)MKAS*F zB`l?~3$?T(9+q{c&QU+MiLwqEe!<{@kk_ZC+g3Z2TI{XDqsj(MfzEejDL z*Ml5yA?i+PW>a5jkzZa;>{)2Hm%^6F`(Rr^eJm+O&)R+C@faaTrjkDArSZXp` z*|}2}VO%N^37Nq!e6p_pf`<|h&6bC1R?!Syr%cH=`5@I@Bv*QcX;M~2GFO9cn>e7q zJtR|&q=eknA{{>~_Cq25_XD*u1cB#Z$Y{ag^FliJ1oj49Onjx#2<1oVwhyvpvs+4^ zL8Q@LfCN?jJk7o;wWSfkW7~-Y@Lv1G^}FWs^3`bj;6MP-5kLTHEHm`IO;(P>o`lXM z;$M;RAvM#I*>4)Ztq6FQ(!##F&9|ol=nZ8?apk5Fd5sDn){b9H$eeHoZMf=WhLt)p zPe`yK8q1rOP3}HG;AFy%7pP^*yNazNpBU||8bT7k-m^av$0YQ1A8_vFA5HxYH$o{J zomDEeq1VJI8$(o?kj8xY>J>S0($W+W&6G6T{w(g(&z(Q-cg%yNZL6+ap$@cF39-%6 z=>@w>o=sMM>Hl<7I_1q_>%ogcm^gs)>}T<<;A}3a1FHSsC_Q zGb4~a9?_UDjk7MBE5=8zSc^=V)d!dOkbr~VH+?W$P4T2>aLcyi+r_$ zUOXko+9*hBWneF7cMPU3+l%vs4njo!K<~Y>e8_#zQagHooT4%KI1|#8-*La*Z#8hyV0ccEbU!ZZj**++P;6TjJA9jvI#tcd zBLCq%19!dyO#1AWKgHKia;pBY40)3&V_+L(r25S}whpSWOPvvlMShiO&frFv$_+o6 zO=y4?GVlkaDd*Gr~gm-K1>x8!a+zYw`B+?1@XG5es>{P1j}fF@s|wy}c0K z3jIo^VE_fcuF4JC^N*mRa@s%d6TA1}Ao}$4z%zgTeyz3^Q@p@{c#6Tt95~7EvI{xi zXOA9876dopar*c7jpv)Wx{Mp}!fLwYN}n@*N)O45UtfLz&R26+9)@I}-_mKh>1(Y1 z?byl6pFd|AXZe!5JW*boJ}8a2?7mYEl~BQl-023W+D%=oNCtxQqf2|(z&YD|_OaRu zt`2HsUU^6d9S40_;uW)zuZjW7eVEe)g)N zTQxE)>2Qd`*X3v9E2%gFDY0Y8o?Y0O8JPUKws>r5zgF}${Lh^*0O7}v%cHLT>+u35 zY$~^%-}1qJ7T_SQP9)h~?}I0B5-<1<4Gd*y#eAu1VL`oqpZOjc#EeX;m2)nbT3B^( z9O|b>=yW8~73PWV9bT@OL-9?3rD)1ONH*?p9e3mOq8f-b0mDXU+snULhwH*?{iC)9 z_i@Lz-Yu@RS&^+mXcHf=Kah?7DISn`jGYN|sw&@r%<^D)?sC?(-OI0saX`eNE_(u^ z60%h1sN2+IXdmcmu%?8y$uw^MrLR-O{X1N`rIxn~k!szmQ>;V`jlH*2hR3PInF%!< zsi7Hd2F_(I6|$JJ=advlOXnH=ep_&8j=t~e>q8#_7c?x|-#1!}X@u(#GZ(|dz9Ql< z!CZoBlYvY>MK8b641^3$yw-0xmlB*tmNz{c5%dW4@f?oO)DzqmZdiDEuvo56umD($ zxKnaN|AI~#z5$)!VJqEqh91w|VBH`JV1>8 z2^|Cy=?_Jq$pCJv=wfvrvgD3CmV~z;grmH0HWQ{1YpW~Gb*@hv4M3Q}L@f#mKail$ zthmilTdf(676f*Rub%6RY1G)+y zc5s7s|HC}eEzZ*$BXNy@b9;;(`69AK_~dutLf1%^aOE+dg{B3YbE>)s-1Aapc5Br8 zz#um;8*i^Qx-g$}2PBM{_imnmH~7lV__`&g9)JIPChMhAGOsRkoH;eV?9H2(Tlc+!Yr1z3radqRN zV}u1IX&QQ(W?`p0J=SiVzBw$2^Y~EWcPF$|0@1nM7F(+98bBHxH5wvYQB*V!n zH+L;n&aHn{S?dC6%V0xxsL>h6j%XC`nYxP)DqL}CGCvDMqy%jkxkFw?nSU8MGYfSq z%@YL~E$)LSk0t+5oPVpGw5CjDNP-}x-7fDaK85{v@5 z^Y377$LSONUmEdg59$wJP1m(+dkiPBL9A!)AE&l)5*s9SF)ZTJExTcp0IztBYdz~{ zXgVm%?PwpRi}72>q+ShrN)}x0=KNOYNvA-|ucg%TnBt7@{44|iI79>90PZczTByN7 z-g!YrGy8sJRnv5o$3_>{nfU!;8>F(}yx73(p*=I2(&X*zJdei{pxBR^-E!ovn`L9H zb)`m516~(XwdOxrnMAC`H7Z=n}^rVGv8(7g3tRuhxy#;RJs9pi|#pJB}Ao?iW@=fUZ&9e2r|K4Pv}lm zRU8Ci(6pX?FR;|`1EM`4Q+hWFZHFG{TyfsqrE zBKBsm`tVOHZ=5lwg<*Vxf1M#eqolviP_}|(vZqbq=~`FKR93@pTS{|sOAUs)Yc6@U zgj8HpUV>9c9)a#+FeBfiop~qnT;BZoG39nr>F;^kmoUpi>C)5k@Ego|$s%sz^AhrqjRmh+scJjH$NFS$_TVXA?b{EJQ|QG(_ZytH+;ZK?P-xh^An^Y)#d8&cD!boi ztg;etTSH0c>nIYu5u^3!CMW-w%EC|4&o7~83m@6EgEDEW>vNM#*>xaN$nxP&Df zTxh41wV^AXwvDm-f&Tuf!AKdvSgiL6sO{RYmq7$PN=^=GpL>y(!akmMh}bW4gdP2vcR6Ro{@3Xax?rj9`j-I3V8qrepSrz)SAG8EO=XK4awRME86x%ry*A zi{>{UY{SYQAGyQ}gYTgkVg@#SLdY# zJZYzphOgp~AEgszC9b%JBIP1l3jd7`Cuc1NYZdj;HU`lGy?;M?@iKTCVo%J&N^r9H zO_P}(M5h?(p_g$fd)OBjq7{9LB+BCVCVnhzEI64FRT6p|?29@#ehbT+63ORg6kt+U zqrBv|LSA4x+$>?3SD;G^zBVP)ReNRYgZUgQcSN_VeF0flRx*Q9Lr1orKvif z{y74sm(_3ExqUsC)*J7|zL?;z2hpy*q}r>wl!xdpzt46p4T~b0Te`SeDD)75X5ex$ z_2=C5!B}(E9H*8$Wy+S7_)2GX&dL*s&tRnrdvLu`zSX4#VUh+LIngnsv8~wv`Y5_# zeB72mu&LQI|3zbM@vYtglD_%Jt7cD$?UE3W*|DJH1n zNm=JWoe5pVFIprXcA)Yt0}t)!(a76e<9PphxogDo{4RHzg>+tnU5xWiM>4nQ-WD^J zOaAtC)t=|BSQC=5rovt}ZmWR{{7+HK@?BM{3ae3xe6w6k9X2H_an1>PGflzwgt3*; zVtJ^sBgxuBEwAb^Sd!tv?xzfWJZcr~;+HvW(o`1HWc{%|EflPTIp}Rn$vnOmnmJ3Y z=nX42@ab+%9lXUKFm|cRq8zFXUb92mi}!2I1g=vV%p5zLSK`qUqOmBG#S8CC8-!Qs zBB~aw1c&|Ux_H8?7sb_PQw%#R^s0@Ac~iy&j&{CLe%MmS??KxEwG734P@5*PL@yIf zx3}YZv_P{2{*|916bzZZ?`^s_Oh;u^X>lvoKZsjD#$l(wcVZa`HQ)N<*4k^XlO8n z1$E8AsA$^D2DySRYnIX7H_bRNCk>PoyJp!cLm=~EK~qkO=%iF*0f zA#{z)cSAc03OI3II4OldYdes;BVy3;w=o<2Iv#gI37O^URZ`#g=!Bt4T+rsqiAk1g zTO}zcw`6iDl7!F;5+Ek)o^vFp2w9UZb&Y{?YuA2Dn=EE<%I#22c(PC4G7{evI9HaO zE7(`Z{C*PI0k2Y{53XS0`n9d?nBp!j_)Ea_aDWaqI*-X!<9N6A%vOCU$Tf6$>myV3 zl)9fk(3A&jjjlFZA#Z*Uk^Zw5dQUB*7&>?3U5WiVw1={M5*`^2now#)<%~i&S?EO* zY6BEmbJeu;E68j14Y3ZQYwgH3edjH~&}Fwynq^hh*;~cApyJVH1(iT@_~z&|G^JQp>HyH%+DcA$+-~U*x&2GZ$-MD`YI4HVwc_#@ zEHpyZNlT*pr%GPU_=d0tK3F;g)KWkZGaiS&Rr2LiZFuLiK9{@wo}(QrX9APvcm+FG ziruofN20to+o@(VhOwHuLrNDu?*2MZPX~=pd4v%t7*b(EPkH#2(5|TV^|D5QJD#{o$qn^sBHBkexBGwpMP7=7f!f{xg!o!hCMtiui)8CyTSl0!e=cD5&4@~ zNj9!QOrh2xsf~&&r1E6&s@4vHS|RP>xwB*?WN30DHAkp8Nr-(R^fZW}Ol`baDRelm zMxsWJ?xUduAX@6|<~nO+Bwg?I@!}$=ZP?%Xb+PmxlYwnp0nW}%M;kCQ-A`ZFX4&Jz zVZPfd(94eSiSG_!a3UA-C1CjmDLPx2#ze&NhA4#7N=0j zwX^%gc$qUFyBLbRv=jL;xx8kzjJQnFZDM$?dS8#T^K#VHaUCe!3hCuxuZ_ZOl-e>L zMvwk-Eb`PVFo_=P-1yroMX4z!JbJ|-tnUNbz zIFbs&8z<<$8aNHBT^D1*r}GN$O%Bx31bkjtQZ#ktQ^!7O_@)P8aW-nHnZhCz)A2fC#me-=@I5}g z639-*)6)DNQT?T_<_&V@RA|T+%l-!bp@7UbKzVm5G?i+RnjS_oL=lV9?;P7=&-LCsw!NTsN#7ygh`rB8 zhuJ6{s&%c{2wE^x=)2byr=8Hpd$@sSe`O^N(Tdfh0{zL=_O7$}Xm3}&yo7G9VUk$C zjiis4eEv^CNs`(OJRbI!zndp^hXy}YqOI_SmVko zwPGzja1b%}EY344Kd2=FHL~r%@h2zq9ySHH?ciHlD6lqY**OC-i}qSAp_X|9wp~0i zzvJfX%tOY8d3CMt?!O}jWAcW5=juC#p!PEfAR_Chvqz^x%YZo{ zuR^RQDv(Pwq|XbKs?8idW9(vj1dfine&>87lc(xG)Oyh!bQeyXyj9xh6UvHC(r1?F z=Kmb7w$rXyTO10I>$%48FCW$;aVBnHR>QjnZo0Vm1M`Hnv57afRvGSv=(Z3?9mbsw zye;+6aRp;0)$(w$v@&g@&5Tr+YyPX@1sT2Us1RR{PQICtX<|#c$eXelpzwuCj6}u3Y-q#|F3BVQ#{~fdB{n^qq*= z`*YqTCy_SC6q%e#NaBJ%?w&U=Km5_5&SX=fL|-?Ju=K(&Wys(1E)mliW$Ls0xdWq* zlnFn1tyb{!X&GCNgECbn?p+YYmDX&p+5MuMR9+;f=Vlf*%M+g>7*2`sW=VqYTaED0 za8v1hx*v0KlYDi?CT%&yXJjr7I=$rFeuyK=pP>wyJWh85r!n2mtec%~0427)t9;@| zdF`l_`RUk+?~t|E!>)B%veu@jiP7AN_!R>=U=)Ck{EYlq10+F5Jk!jdf9h$v3%G9P zemaajJ3j;fMd1ody*hU32kp?R+Ci4G`b_0JeNvQGBYHQNP+JnN{-SQuT$yvGgOGzZuDsMRXt3{^Hs*@Mg2iY#eBm;x!C>}Z`96N z_&lvYag~J+I1nR7?kt+;{OZJJcj8rlHCdR`32|v7v3TNCDAFIaPz~PujaEvEdQm!T z?7q84UNq)TX z_ZmPgaX^jYI(|Bs*}M1!dVlMitd)dl1SPsUl$*LM{`XM!Y!jq~+o%We*VNxOyr10V zDVc?mvGwz++XGMxH3_=hgmXmX@N80f*AV)IGhWtm(xd=$RM&?h-W&Xi+#9{*?YSC#k=beXxo*%hoqXYJ)f zLmjG&?ma8rg4e82tCgv3!8+6p$=r+zG zu_+TRG5OXY*6Z74%Int`Sp?#*l3jZwFH{~?A9yc4_w=6=lEM$Z$GXm5GxL7O3+|s4 zbEg^C#U6D=Y%jiD<$Gnp+wauE_+@aZDiv>BI2_15l6iJ)5h7NlDwXdUmt6)8wKmCN zQKuAWk<-8HCxwx^U1Uz<4{rY@wf3KPV0v1r!yywifDYCWDqj6U7JH?xRe@C5f{LoU<_#*ps`i9uI5BET zxA5eILva6kZ2>x3@4lcxG=B-vj3`dVI>nljMoK-x)-SSdjo0cul8(&z^JNM>5&cFE z#wn|8c<@YQl#=PydF6yY`%pFnFf1qG-jfNgcU=we217Ox4!E(6##G7CT9*Z-Jl-?S z#wQU8-BO3Safp)$n=jTrA4-O5t`0uxW+=2_@%m?_C;S`aA^hNBdVwmQW$jjriWx^< zJPP0#M>)@|JaFNifEn%OS}mT1UHW=9OnD~L9o4E8H;e4gmh}zIU7k|QKhrb|UMJja zX&ba{PEYryNt7w7@xpfqsy{nxhO!v9m$Wj4)_Wy;Kj7))?OFymsqtEmFFhdpCah8SBT*lQ?06OzjVZ56Vqc)ZuB7h%qhBZLQ{QHm zMnU6Q=b_FV$2>=E#x+}9Y5P2aDZkf=R6^k!pk|+}C3W)~zjQUbt^}Ey8xQVh4aWW( zc!srdSGpg8)!rRd;fPmxd01dNGS)Wce~zSIJw&9YT+^>?hgPys$1;z@@5+z~MOKGM zeo+}_Yk3lBvoOgHL2k1BK^U$P=bGd5(zf4F__fzud3nP8%c@9cM#CQt*@v|(i&TXh zvOc;HrpHSAnSJe-o=+$>JqQ$h&RnStn?+7SN`&2R9UnQvyJB+DKk%906K);V=zULU zUz6)pi)pVv)B(hlA3O!yl+|`g#jh_z=VMoo)L}8NxbFSOm9ZE-%!$1YCg&-yJ0aEL z2>&|9dPbZLAu*Ov6V-y?hnO!5XQFgnQ9$iKt&rJ_);CU?(Kr|~sk}@b+{%6uiN1yJ zVe3{KZC7j4e+-(rUW}V#`Er-V`0>%By3<`Vu}OfE*3Rh#F?%NR!c132D@PpqbD?oD zPThRxbH<20?^a60c`cg7ny!=;+a){Qj1Er}Lb8?H^e=G@Y-4(x)2h$R^g23~#MstB zdxjnNVq5&=&<#Gi^nH$l^7**WZ2^U7opi}c-&@@Y8drk}HQYU9!#+=n+P6Xdf++#l zaCj#cI>yyR^ULNl9r_(2jv=A0v%?yFZ2LzWN%A+I%+aB%>~_Po2&wfA8M7qvRd_Up znp~HQScwzNO2=!vJMQU{ZL8EMXO#SE1iYAFNc?I9Kp$Q-o#_Cfvc}n;Try5&dY*YV zm%B;f)Sr$L5c(1y-gu#ZVM;2=_4{NKEBBpP$||#fgp^oa8RzIFSzUq+TDo;PCBOHb4T?n2C zHY(pu)_5rBadd3+g)xu{j$3XWslkcGZ|NI6nJ7`ie$wWdcqj3@|2l0Wu(zx1vFr%R zp0&P;{EDC9nbpuWcrfS1Edq~X?KVer254+Fp5FEwxAnoS%tX;$L3!fsF{DGkcvVqc zfs50q|5b5v#BIOdG?{cCSWxc+HCKA-&_@t_qJMz7G^rg)iFQ{dQmwafyg}rF9vj!( zPJIrS?;7wMSKGc{McUM7=^A20^&`hOYqIDia^vXgnO+x^ICjNAe zx@@74-6L8ob9Ly9wj5A+wAUvwRCL`DIS@ihU((d=EmNE#k1=JDy@|kBCV5K7)Z;_l zn##(|aVt=1m5EgWXxKyi(?bB z2D>{7lOM|yynyx?74=(7yl!VWTks^>+C9Y5<5R{$n z-R257wAn$~I=>j85J5X>pAS8=39Qxy>D4N&ituqP zsiQ#Jh6++2_dwDcO1<`h+=lVt_dJEQD2xT zl@XujE+y$N7#VXR%)ZsAF|(2kLNb&ebXG$VP@+rk;%&Zj;^lnfYmRnNTEuOdJ`78U z{LU(CwG0rvYe|REcTKby>Piy#!8f&gN^+m3KelqCe3BKbFDo^yi;y9|K^x`OdxYq| z>sk2{UOSj)HD4c&s`2z!1_bm-tk@I(yZ5_4eI2 z9R!2T6qx~3&DLgP!~lG&!l{e!v0+zg>kMfSM$uRZ@OKfb3fR>XowRyTr=MQ_)m;>F zlIbs<7hGB$`LqL_DM!7Nrknd>XF1}iRbWNsw$Giqp|%Hf_K^qQj)>JX4M~xd>xR~4 zTQ3pQF&k=8V*$dcy?_SXI+qn9-kTf>E(GxPlgeHUH!Uz)(}6~hDu4BKIJ_}fQlTPd zP9pci(7`-4I@NvrTGfBjy1X=RO0upR!s~sReI($X5TMZqhM_=#3LH{YtW(X3LHnIHf&TT`JtKhIP`X2$hpy-!i)Wn2y|o3Sq9XodPYac&m}J(l9NX~F{q zb16u+$8x$#*~zO}Fan5&a_RyWE5M))3#JdS?#*Rh+O`pW9N=^x&EDzK5w47SIJVaniZj<$-z|kEMnvcD&L7rMpi>NBxpR&#S2Vk2(#Z5DjZ?DZ zq0u-sD)xeLyzh3}3Aq9g*@=PNviN0g-$O&4@u?kCh$88z+w>bcGrT^12{V$e zoI+CYGx8N9@B}NlZDE$9uFms`-^j|!%vB3hLxxHa0I$52(G_5OQTh!m^2!*o3yj|^ z#*;j5*6{eTv9qjeH-C{>>mL82!jt)f zIv;}o*QF8ExK%-gp+NBgc!6JnNg1p4x#Ts3YB$>m4v-1o4{jZM#!R3TujuGc5z3V) zK>b>tB={M(`89=SuhXwgthj1EG`>Du`^DRvaW$p6t8xC@5nk<>J_z>(cX?i#;MlQZ z9S-Jxzcfxs6?v6ZX9Zz6FhxGhSFjuh(_gHV%=NnbIqjMS3>EN+DRhc1G#fWD=)RtF z8?Yme=y*BPpkJI-0DGd)W#toj?s^Us%a@w|PrPK>+cAc<<4Ze;o9(K))G|J&4h>}c zv2>Z)aJQN4{4VRF^h_1M&k}1Vvpt9wRD)@X9Hfr(-Gc+ESMdkoLmmV>YGBbsT~dG~ z7Rc4lD^8?(+5*&Qo-{s|s`h1LWpCsHw@l>M;!s`3AKBc^Q8pRBVc2hr+S@_;D*7hc zR;RjY7|i&2evinVlB7H7eFlB8`xSO%CmY`1#c{pOiH`EmaCWRMfKN2edq%tnx$19R z8k|#}{B2u z&rvBr!zgIL#yC`%7+MpTPg`H2c?j(qMamvxsMI^^e~nFd3(k#*=Er-0XS4yJ@|6<7 z84wHkvplQ^l5FT7>Mp_`@BmLiXiICL*w@9e?1`I>VCE}U66AheivJ@GBN$&T-vTvS zEow!tlHn-Myr|e`S2dhQe%qEpPjg*(r@)hJlBiLZJI{H}IPKeRq;EG=W7v`41iGo0 z;4QmPrI+&P>1>m#b;AV%_8+T$zqFKBo8^O6Yi^Mg<&8}QSwPO{c1gQfF&Dr7G$ADd zgDFaLnIe_j;k)bAH{{G8WaIzDw(6~7c46Y7oB56ka0Mc}kdom9pfPG?T8a}Y2$TR< zTLs6urs#ikGF;fB*s#H{r%npVu6uM;x1M*~^3;oVcL*hCGUMy>+B3jE-4HJGYJ6Mt z6kX@8f_W8cN@Yu&kZDzRygA*PMnUKAbi<5__!q0|nrxrF(KePrx@CN={Tba`Xl zX&tG!{0yu-M!l`xM*nNVMn@0U5seW5T|HGgW%(PxE85npB4X~2_HfiY=F+ncgL)|E zL9|wYoiVG;TxUtLP{v-I-ql)Jx{(t5K^rNN{pg3`6d6SQWn>pVH|?9T4z+INzfLiI z$X!*>3)(i-l-NJ(%*%mgBr+2!3Kq2$-3NFT!0Qz&&72xf12M8!H-czJ0#V9bL$}0L zEuH|NtD-&qi)HTGg}_=KFdIoBJ_t_JUS}EO$x&~@X@CEL<-&!A(sH#=|A}9SOqxLBJ^WO>!I%|0 zBCH*&fqa54OQa)u->qZY@yzTx4n$b#I6;OWa0f3z<1E*IgFzGv2 zeCY+zeGVmxg9b0Vy$FqZyz$lW;jc#aD`K5$bE@9fLJq>)xq4PjHZ?1Ga9PEWE z8qcm;)@RItVl8-Y1=IeeFLPU~acy_9y~+>`nJS50MQGRuZyun{Xmu5}3|*LVdRbN7 z&1FHHcfD9$#MeQPeI4_JA^(9myAB^RaXeo)n=py4-(+GCvtVM#XfEPzN^qU2(XUSE z44?Zoh0Iy{`ta8j(%Iu>w@b;9#P&eQ%Q_~3B^R?i`}@jiiYgwUVM|6_$*Y9Z#N@~4 zKI*Bx6CVtvN0bdGh+9-!jSb^E=_~vka7ORj6TG9g7Az}jbmW1=)Jq=(o)Fi)oWT3x zsIa``yX<7by!v9v-CKb7gakC7V-#jL6$NWq1#({{tG!dKG2j(tS>350eB3b_P)AsI zo8_P-R_|lb!k6)h#Q7br-nxzE_&8?S1Ah7RrG=fy83{=IszPWWc^ z;oQ*`6=JfUqT3|U;(MwuIl$G+NUp@WiCcxZm~A^))Oq5)S|+%UMwq5yF2gAo6|S5= zAhVY=xsKFU;YS&U1uZJzY`lpB2dp80Aw7`lLI9K%WT`{Wi zyG`n&aT)M!wfZJ``|+op*y_>2Y;eOPaj+Ci4=otV24SfUFF}5Rgjtuw^Lu)eJq$~s zc2VDqPO>60%mOj`c?o~H3!kuih8b9Hhr0p3Anbj#Qi6IQvCXhmkaOT2CZhiQC|`n8pp#@*a8_(HP&< zUQn(7;+HY30$zJ4htKa+p&B43&%Gl1t<#BYRB8=phkKvQOQ$esZ)@H4+nr;UE&J?C z*SEC(Z(|X5H=g<2KJ_=4YxzRtM@GXF3SlmV2rMtJt6ySpdFzt;6;|?|Y8gqpf@GHe zpWpsMsQ2Ow5wvD_Hni;M^RZn3NUHp~w9UH3ZVbI=w%cz3^2hQWe{+O;COW`d6&>cV zsRlwH8|wr%l|8QB?SXN?>gbW%zyx?-KC&p_*<>yC@H_7>%`R=1i7W8dvQk1_<1bQu zcXWr_tKebw9U2QB1BvUl!T}`+|AU7G{1{tDH;?=YH!)RfhQ@M>+zpb+MmjUsFPlyD z-{JtyewrNMEXVOSKqP?;x8R3wdzwHkSaE2CnXwb&6$$8zFELY0A8CC@deu>NKc)~V zx=UYtN)7YUt>2mV2x6(1%oonh%Q-Hcj;G$QesUZ#dft9dCq&zZKoy65@>LTAKH-KhH@W4SM@C{Vy3THYQ3PWMQ{! zqullOjfV_<57z5PFdjeUPb4kfO>ynB*D`(5)Q;^~-;$T{G79^<8J>D8gk-Czx{nc{ zF@MidjBV(t=3o(H?y*^ARf}<{Ux9rqhp3u29^7c~hgyaqP=8GDBvOg}9gd!mXvii9 z3kAj{53OIfz6-TcT}6L5hf@x@e=W=UifLz)p`BY;)jSGu#ijq`3jC=oX(0~%VUMKY zo(?^yhyokgC;gf)?}d%(dK7fxD~ao1KJ!>qG+99eZ%jH&Z>C8aipO219u|0C83CYF zTZscpv#J9JdFxpA$xT|LXb8$%^=ed8*Xzcy_##!c_}|!!)3yDwKxrXAxfr5}smr{(z$ zRg%Ep7Ug0DJpSJsQ<%C%VE9|D61VvC2=83%7kj*!U%^bZs6V|9Uw;WPrhJR2+~M+e zZa>*6TkWg`*O5@g%@HJPoDYN3emb$MadH$>chmR`lMks9N`ol&k;WP?q0a^dD*OF5e^#Yia~k zankCfF{3f(a}ve%KnQr{%o!0`rOf8E`s9Inx8a;BE3Ai|oAgs=&l#o|s_gMpY--YK zPyjp#tfHE&yXlHL;1sLCx)c#;qNVWYhbvCjbj;=vq z9B*aK7OJvJZ2;rQf>ZJI%FnjUT!Sla?}Fb7jJly$8jwZdoquIJI?MZgz7wp-14cO5 z7OrAic^@a6>Xu;0T%wi7RC?3&o1+>w(<#&PPOcfkB{sB_ksWLtVZudxWjY&RXV{Uv z7jK4JfrY^TE9(^?NebS^XslB%_OWlqfE^I6~ zBVV5%A|mj(1vUdwVr8%i3`W1~Pb@JyTLP50`LeAo#u^!J7Sg{W+76}33!YT`eb&X-7$i-I}!A`{uG9k}fqU9wO^ zcD6a`kmHZMhsUrO`KD?~(ow<;X`|kDkww&n=EKn{5III`WzqjY6#QKq*;1d@SUygf zfmOuDA;;+Y{DHK6DVSu)k@TPU8W5`hJPT;Q*P zeo(#;X-~E6&49rMznsX^NNd|+D1KTgiz#ilm`1*$t2-4ny}yQD%b7w0Yc%P5>!RQm zkc+5lrzJ76%l?@3ZpE+GJf8rNQ`c|5q#z?a+R#qhD>zf( zsJ_Dc)|u=-G#cyw9zFh>+WMyc=xTi5bIxvg=HM-bvp^Xq22(F*p38B zu3f<3>bndYt3P>A)Amht`^pp%G-J7LX8fi5g8d!uWLsL^OtUu>|J^O|m#y)a?yq+V zYHL@7s%4yp9Wt)tu{^Weu!fl{=tW!=s5-bKzp_#QwD_t%Z-dyl^||B%JB?T%ZWvgtcWlv;2AFo;D`0S0Ig=sYbs8vV z(1%)F(;@l`W;gZB?ZO@-T5y* zH*V#WA3`m3{)4Lal{fHc2md>_>=#<&hUhXTJk9!_eh2=dD3z~eeaB*4cN+TIsvzryxOpba@MGFqBQts|HMZgyK0*x>%h^KZI>R=YIcuIn#`VW%( zE0z3d5xHtW)HUA9PS1Q^POA-7gYNMkFdn4wE2=hoEBuchn5H8o#v0koAbxcCD3591@P<-mtwd+=X-iz*a(ISp3+b?QZ3vH*F0*tbYP7F7lldiKu5$cg&?}c-1Wz3S zT&WE$&+W9oXIjp#p1tq-cJlnVl)@pX`GW8?of@pQ-fQWrnM{LIWSL3^P&J%k71{^0 zJUEm}s2h5;V4$rXb1rt5I~?=jAB{!C&1^kKvO^Pxk$J>M1kj@1CZMi_Cam2Hj(- z)4c)IulGLz&EUr?jk;J`RySa|GT`I+Y#b}f4;VBRtxQXH6e9L)Y!tt7%uEKBsfrlf zqr@Z2SK*E1JcU{3BB1ZGzPV^P#d@;n2~0pJhW5h;$Z2Gm*KNmVYa|EiYH_u(4Y?lV zqY$wz-S=pHy8WQaJ6}^t)2j*h6HvNGEh!byA1uZy^U{L+zFD56f4xS~uWIo~ zU#RMOwb%!|JOK2~oGShQ+<47MGpif-q^avOai_1>Q;E(v%U`dsnWdpz{(8&k;4#Un z!=#Mw?<@j!RVPV@5%GW{YJZ{Hqu`I5ir9nra`RM_i+pB{Ja_XF+hY-6tXE+Io3I=A z=ge%?!@491*@M|y(FOC24s(|#cmcs+9M8%AUpYiq3}tOQ{a9L|`+(xu0Ql(%JW#Wx6mnLw+O|EsC6W(x=DR;-yt-_00XcN zKzY8dvfdG{fXV`qYYnI{~NhH~Z@s43fq5vclsG z?f5D|L+MG;&L&wZ(tjSmp3nT$BW@-1s&Z_Yj=oc1Sekz{jkW~3 zWnqQ)ZkeSAHOK_IR~gu&IY7Q#fy+6Crb=7DQK$Ui`(Xfh66$wP6ug}#>NR!=f^ z));v8eaa0-DixA3ZESyjMr0ELb9Y#{a$3%J^1UIs*{dyR+7=-WLx`J&SX&ep%8b*S zP$kWEvS&@{2>@GA<2_@bj_#n9%y`9GO#bP;I_B>`Vrk|1N0(D+jNqNe=7ImMzWeM> zeRyL^lLaP*UAejHKqZZGH^lR`9D-a<1idQ<&qOxk^(A;i!^?{GORMg{5}yw3d3w#L|hoS%Za7J;*?W#IrhAAxTOP7h?r>nbqb5!`H>CAAbpWBjMP%B1;uQLQjg z=L44X5R_kvGWfwGoWX&?CHik;ddX=p%~VGm)oNYh$X>aEw;DFF%3TYK?7ak8Acl0V z^9*t))H#a^P_*-SfyD$3`@OCWjQFuW{eBM|Nd{S7oepPQh|m)v{7?-VVYb&0;^RR?Z_7{rIak<#l{zvnoUQ(e+6+zi77?a zA3M!#t9sVociKGq^%CFK6~6Ef9r>1eKAboxACE)f9OW(;{c4J)^+fivBaqSwky|{O zhXaOIQkCV0^Wo_0R!Cu>iOXNBv89AmsOuv)w^ZIdk$?fVl8u!kpU;eeJiu3RA*Y7x zQ4jvEK$vZntTfx3&F!1If~da`mA8QBuGb)KzBI>mD22KhNzS7&v{{>j--nXsagKgH z0=ohPSip(oZGB_$(SBF$C?T&pPHNeV*?!m*xngj@g^MoVF+|mSKb+V))3T^m0TUz&GmqCYCWu z;a&0gs(Vzq)Kg^()5~(AsT21qc)QxZ*4OifcE6@-v7CMw(^(WFze#HF4cOKv} zS>4W;H3Tldphq8H4nTw-ic!cr&Tv|jpVIU9>Vkp&J@xPVZtIy?6ZRv_6z+)saMiFs zs|7|){VF&&c)VWGd_NkpaM|~}>^U6jJn7`^%qdQs>=LYMcV*;o>bYeoXAY~y)*b4{ z1gVo4o)1jdn5b9u7_y+P4Bolk`b3~n7kJzFZa$3z&40gID?rdlh^68AqTgo5hyFX^ zjdbA}>y4X16dPLx5zllSTXD7UzKbu9l-gec(sQF6!##zFS=2(670W+B284@4bqL(A z@u&}do~~*VNp4=liEciq?ZmAM@v1Cs74)(My^&vcAnC$D_fYV>lU34dI=5A9!|~xN zFPft_NAzHTb&=CRdXwy92lXhZ;p5j!XG%0er!6dpU%=uIhSwH~8V+(Z+@FzB{ZC1; zO)P69HBIy2?|v$tJZ3oTA!FWfd+{tV;B#X0zlukTKZ~n0$* zNS6~;8eMriK~h+I!p7_m%v?FQrRno>>XHBhxF9fMAWSR{@p-#Ai>bD+`m^Z)!JlY2 zunUqHI2}B>@9i>Zax1lb6pAS;-5eVhKG1Wqw1y2eeWpvMf*w&+k`VT=wzLAotSkIU zkWA{-P?fdq@zA`K0b*GnDi!2zj$)*HT_orA;Ly36q1!i3D)4j-#?UNtRj_W&$0-*a zZh0Xvp|tybqVf78>BMi5lWhN)P948QmBh%_BZI8=qh&;^FOmv09>+?x4>;E+-%`X}N_v6fMhklfA2I*Y{aTheoPDr#JNHcJkzO zjDNw3Z$@^x@Z{Z?JCSDcD&B6nHcCU%$qxw&a@+@@^pq)|=#mO_kLSpl4`C#`A$;aL z0EH+t?(rCpjV09#t$M+E?k?;ZcgRQC{!LubT^PoqZ_h^ka&m-UO-w7m$wm0K)#|D7 zB~3hQvjS!^*V)VIJ5VOtGjdKu*ScF=T9V6H?+G`NsM^B#7uAhZ9>P>M9F(3`#m-nX ze-7cR*2*gpTPC2Qz)a$l$FVE?Ofy2ns#`a77d^#R()mfB+idPRgwb3^!!ftHrh!>J zV)yNP%bK&e{U7E)QlLtHO*|B&x<3WljM7uxBCSF0M`3-lrc$x4S;|n`i7R@q`PLs5 zkr!n}PvSYc3|V~=I+D18mk&`cq`vMTtSd<`$*GH)Jwc~hp7Z;+k`)w43Z%YX{CB0I zPj>sytt&}HO0eT|TIqfOGh;X9)bUOO1;PW`vI{}+eM<5)V=!v^C>_{RT%F%PNj85j zMy(cokHd9FCAq9^?mirp*?%TY3za-3+fpp8SYM0Tx09;eoF5Gor*qX2Y^Qsabft(W zHpoIg?kcfMTNbzD<1{V3oU>sX>WHRh zN-ln8ic|s6{<2f*gT*R9wW_m5X@F6R+R4u_`4siuW+$*6PssN_;y1vP#5>VvHRzPgjbM|4019Wh; zvCJ%|fh&iI&2*wCw?BqeDb20ag{aY3h$Vz2{vO%k?P?x z>LJsGj+ioO((yDB7*u0aYr9^rUGIm-`;5QyS`S8!n%r#*Yl6#CgTn8oC^<1oO99?` z2o{mxARbP%`aji*V^WN21-c6vHGp%EKV_r;s2u(inSgq2d_VHb{*+s?H(&e znGBd-9=E?8)qi`4TfSJ1#oClFHA_p%N%1Xl|wQFZAY;x_Yo zxgHMsa^k?W=m4}I`PAn*XFwO2kU^TI>-kL-mkRS!@egW_o|hg!Ou&L*AR*-_Ynn!Q z@B2_erv@MBQd5}*RH7KsZ0nNO3{(wNl_+`kAeT)|_lK6o+`NMm#@Qp-aZ;Nn|0SwR zq)CySK=vU9+At4cppCdXP4%INLHc4p>kyJ;lPHg^7uqY+5^XiuD`O_DQc@HXP62RS zU(4J;DI#gz0LFz!nhx5*9?|Ev{lqkWsHsVO0mCt)RSj=uWd8i=x55UJUa(z)O|Cf~ z@p^4YQTD-X%eNV1#ZOfSZmFZbHy((Rb+3n*!|}|qB5}uKMVg{Y51S6O-w#qDZD=#4LjYQfYt>;0IbnWg2|xUw8ZxI zo1PQN)5}Le!B#&LG;>;lGILcyHJRRZLZsXH@>IGpNkJw(6ku?$!v~oY^ctFHtog<51;sJ1a~qp`8|I^jxMI zd69Gx)&_kO*=A5nZgCup1O2$@2$vE_ zD7f_z%Aidwa*9t%B<^=S@q@_U1B)zo?L5~(G{aw{j$ z!u#Pr5)OR0HuDC%2aMTT!7klDn$kr%Ajh+LuTZzM87IywTR zqsQAStgEC_b;lTzyR6zd>!FT^_m(vj7MJ!lpB8oi-7CNzUa&`Ndnh?Kfd&?-!WQ&m z?qn>!>|EMN;IfN$QQ4_5#lki$>w3Rz+LqL2Hj>Ml`>@82T`XH`H_ztjYMJU#pRe;O zxgW>w^mD!`1KRmKBPv)1MY#Z@g3Oyg9M?XnCX%B^Z=;ZG3&+-{!STh*YIte`>6_4w zh+hT(;yO85I~GVggHe|hw@QQiNxuVR+U8;Z$PFi%*(^ec!Q5V zLc#x=o~7TZ7s9saoV_P&s(9hsh0ep&NYm#Pe>^Uvm)coB9Lj!aErn-)*Z=HX+uvT4 z)29wqy=-XwT6~U?GLb{fem8UsG?gnoT(`X2yO!NoSxQxYck3NvEhWm?RNaY? z!B~vkkmbhr$^NWRmjaL_QHIW?Sc;wKFYwXpcOj0 zEYtQGm50+3O!|_r96h*lB<6jsqr=M^AToU7YR!R25>Vp>1f8O~NlnW)mtpe}f`t z5;b^bF7Hmbn37AAf*=(e+Z0xR5L>qJD$e6{N#%~iQi6lz@>SZLw#$E@j;wduX;7CD z^Be@RK+*hj^xtk;CCTpz-*m`7c@v8$fG89IsM~T8@ka#&PA_ zK{sKbYg9#4Xg*(CpiX1eAI@MtecetC^NjO=J~V<+rv-h`CB8-{Yq@>ix}?RW;krX* zEDb_DU(D}SXyjSzi=nx21lJj|D`v>Mu(QjUA)6i9Bg)O-Qg3$1GdJ-%j5=?p-(z3~B?`m`mIaO3BCX$ndBo-P;q&cy7 zl@Z9AvEm;GN`B6t0fcWk=+`UibAcpQsY6@7t|FM4pqZncgO^S$``(+^c*4zgp;oGN z^Y)H2(;1BW+WI9P(0ajMxXicG>j8ALS*jxCLU<6n0()0Y7qgB%J9ZE!3VHY3K~xpw!j$8HHGiJvU*p3&Fcg97-Hr23cjW>Xob zJ1j6@4@(E*n47v^d-ORy$ z=q;VeHW_5JmQ@zmTRQVUI){=uK(Fk^lz7Wg-x&B&QuTS*-1#%gwS`=yK7PZ7nzlae zq5PTRF&LM;)Wowq`w##_GX1foPam-MG!O53ZO{D%?uT;;Nnm_h!#hJzR>(88_{(j9WeExhF} z$>r&O-mV6rE01;LvxsJHes6jgKNc)D|C#FILfY+6^g$#iJ6tiPXR&N@+d0^vXXTQX zNdU1jrPJ-jpqvglmqPFM%lCR2t57&%?E3%h=xB&QcW6geD^M>rQ{>sd6re2;xg zv%mJ!OnSy3$?e>*D^l9oa0hx3fB_$Mn=&{T-V+ZIu?u35MaeMFHhVdW+Vv}~t@?0XC$;^^Z#a)~)12n0&K;9PO37y))!rVPY&NDaubW67267&m4T0#`DD9Fx zn9=r>_(K%5l}#4VxTrzsWA-1*(X#dKQhb9orudEMSg&{=Fac^D>QpTO;xsU6QXW$^ zbzfi>d`2j)cDY!xZ?ac$y+0Q8t!qE7ZPm9n=|0_6AzcD-1`0U1HnUoZsx}%~lseo4 zy31&<+Rl`CrB9m7(gSVF5ol%$_4KUqvw|0m$tl6{l2mWeb-6N+9??|%c&|}L)u3y` z?H!k1#DG2(Ci%V{X_AZv^1`#s`9o#e!Vv61BR7YeFF&Gyd8 zayRn+TBI64IBk9t$NftiUtHuob!m3rh3x6Y^UmX*kR(ffKuqZ090vsXEcViarh|uE zX5Z*E+c*D-nKISebAA#}lR^xH{+trejM;s&0c?7;*Fat7(va8_E4s@AYzH%bBx~_YB40SfhMx{nr9mSwX`h%n zQ_8?`SoTGG4lpm&*eG9LJ{Y?)R925Ide`t`U~frfz>J6LgXm?$Qe`g>=Q_GRb`#`Y zR(x%muD&>f8-l%z)KgRw{KCO6j973{H5IEz=PHA7iDN6HlYQS0fvL|lp-zX`JqtEY zxmu=mCL8aGFCu%{hG-!9A>*$CVs4<57V?9g8dA7yQfdovMiS}oW>)e2 z--j+6u@Vik@~vN&&fmPgU>NEawvt|M;J;-`_Ib(Lb-nn?Rj6PZsU-uOV?Xz26E1vJ z1}6LMFCr)ITvkd>qhrHLPx0sZW^-xSahzDv0O^)JPx-_q;QwLFf$hjtAg!m$Lmg+= ze{|1zD6yn=|D7T`jJ=k4M3kV8mqR12x7B#Zye73ai0&vSVE*n-PB3k=)BN$`;woMt zDtm^6F~SS;U$RPF7gc}z-u_2m#o0wL!C7@Xq3c$<=|F8^gw=@wtWr-faIfKahNXI1 za}JJ3F(;KieTY=|!Xl%>WXyoSED)9~YR5w&n*6i&N0e@slJgw{NzYi{Ur_I#sR8H) zy5u076Nq=oiIUTN>S}KitK-KbhX15P!IApr+9nxu3T>ldW;PXpg@7b)(y&SW!U2q> zh`TCX$e>L(X%1@$oF#Pjp@uy)@Ew0SzsTkVRJhyHB1f>)qxEAm%}Ij2puV2OL)CP~ zip(Ezduq{~%5tJB%aNT+N8*Rk`0#AgeZfaTRAW$LFXx`CX~mQk;K&+N5yf82dN}(t z8Py$PY8%)jEg!{sZ8h`2u7rSLLEzT}kn$kk(o+d%YOlS#HDz)<5e=-Y!Sjy>&`uIxPcOE04%wzj6TgQvKKj=M7>4#1@za|o z=y!sU?le>A2EpaALcdtnXC%tqCZB_@V!&$1;u9>35$5g0c)>2Xl1P?x6Ym!u@AnZD`!-{W`sKqCy9GM%BH0|VZwtBeCJVIn_B6|rE^{K z={9X`fOH{q#jPD1a&VHIi9#Q$WXm4`#49(0;8oleC^DaIb1zY5GAhtouk`S-TexqA zl^;Ca5(@BW6R(Q_rwPy*&^}Y*2Q6Krv=|WA_?J0D0gO6sI2j)g(gO)0moSO5aD&m5 zkzBj+53&wxT8^_^0-#37Gv7Mp-8(dQe_1HI@};%9PE~hWQMcbE9>Ot0wyG?ith=Eu z1~_ROdBaZcRB1VV8b{aHbROkc@kWEQQ^hQ8*BNs!nxS7tAArJG3!>pe@uQ3^=s=* z{_re{l&d;Hj3Ra>qHP3|vE6rCZ>@9oK53F4TYuiB+64+hvLGRx9Bn(*3#I5#*bWj4 zk!f^ii;f(aReVRj?)Tzi+wyt)>4}i7iY@GgRYleuk)f9rrU|t3Uai+Ma_Vxr{w}_l zV8s6N5^uu%$C!>;Rjs5VPiD$*8i(tpbq6h?bl81>8wiyVJ}icRc3B`sBn@y5K^IIty$=y7Hz}BDd|tWM8yc8{IZ+Wh$5PK6AT% z9d-r7N>%MfcI`eqQLt>t)#WtuxA%%B@C45CgD74_qO9GL?2z)1AJ4?KmqefaPJx*u}J`+HTS?hrBzfb>o^+DClU+Yb3=Aw zq1)ueBmaKSaUdOXfD_b3G@BGsN!|kEmu~zHOf)%?$eXB#nr>$ zPdR|C-vb>uyB@sgu1}CL@`#fm7-z00_TGIA1cX5qN+>TEKxalTE#emxCa#+j4Wt11 z#xp%675dI@iurJ11Uf6BDIuxxbNcw#P`3md4QqFUZ~GN;_!W~ z{;R~Of?!2dgX~Xbiwd>=N7=41yB;p-RoCU^glU$ z-&^uxDZR*NxS5;Y{={uvtN*O6%_MRNkfF42_}#z8^MYvLvcqX3MHO&bFNe;6aSY?X z3vUH>t(=j#XSU@$0T9%Nb|_6!6N{9%65XYr;gVoft&Y7IkM7|KR4=d`#P;>Td|{^v zo4TRvNRayZ+=}TywJ`LjI|&~O8cNAlmoZ3KWm?W*^NWC~ttSa9m6)Il=4r08<9PQUvA*6`D9GC+X4WRGC$zOhWr&r@7rj^VAQg<%T4$ z4<%LVS|k5Wm4Zg9ErfQi~oKQv}iX&=xL z<3TGlAk_L$wJ|nBk4gM_s0jPziz1{5Ji$YMx<8dFW?XDrK`mcb`}nuWa0043fp73e z1Q_{?BPIaq!wKY6jcB>`8EXh5;1p%24DfMtq?#Ro{`CUOuYO;n4MWFGTDElXg&MWwCDRZ1R zWrvLTv#0G(GU7J5eu=afe_|zr`mTO9ShB`67Rs0apJU>2Re(3stbQn8`44wtC}~kwnaKlC6deEmJUntgk_CO*G&J>M~yc8Zl5U0awR`=F!)&y+XZR=D=Xlpt+) z{zPcl2hWZ^&OlK~Wk^+Tu@g-ll(U8V)JF2OE!ol9F^|YHnC!A%b z0Mk1nv3f#%v;GnUr-zGBJ~!g%l#bT85sh$dRP zKQleDeCScM%*`iy$z;Ty^P(iJ#}V`wNcW$>cP@n`y*3M)s+n&CC5kLR1hhrcDYx+m#kvf89oY%u`hT#fMURg?4C@pYe=Plm%)mC#;|13(DbYMNx zQ8s}MbJbh=rkQh39;w#Db*O@8CDl_p6Y$HLjEJQRQ`14EVTE=&_`Y$LXV-Z)Zrs6- zLBzQ!?0aFO;vqWz`}X-XMq^Xzrp$fd_j?%VZ+>FyQe_qgM&HB4={s}?-sKMlE)!Pj zr;#X1K{*LJo6xkSQ;o;GZ$MVDPHe}|`;q1=gQP8UQa1bU6B~GaJZ~faoOntjIsPiU zqR`13(wQ0tiLF}W9wEv3?PQOty4Nvs0ovkMZ(?Y>~&}u1xy_K+=Ky~&Pt(ZJU zC-Mm&W$Q7Q7+D^N-|C;Q0}{guc}~(xC6cD$v~eqpYekw^FzgoKfqw$UwhtAAFcQQK zE%$~NHeH?$lkT?-iQDcCJIMTgq3jhgv#os5&kFy+^&$^iV}?5JQHTno^Qko~3&cL? zKTi25LBZK?55H})38hUYR>`L7Bf{zNyBek^p3oTS&~}6WtnleuMuFRB>15MG!n{?I6t1LS_5I*(n z`|#Uutk}XK!E~4A<7RbLsHh<)3Y{x$GG+A)GotodfuY%G@b?u0)Qn9DHFh#IB4vra zt64I$H7j7Vg(KapQ?*x9W>av~3s?0qWo}xPDuOL2V?Z`+u;EFBE=kfuXFF4&={QC4 z=}RQ-(OW%G?CyTNF@oNxWJB%*ka@LbP=UV3aTve{|&;o*{Ff%Pv%uwUoJe#X=( zL&qgpsIHUu`vm?KRF^9Y_0bpbgoycy3!E-nY+r8S!6oriCb0GCx;ue1w537%cRB^{ zcTtV0*gAcgWbdSyxR1_J7vfv!w;o8k!>nAaCxK6Q&G zbinf6Pr_^?esp(CIOZW~u`UBzy1B|De3;*6KUh~$R3?HAI({P~C!Qh(`=Bt6t@@Un z|Nbbi(%$F_7gM67bmYK})56}J&>rPNr-WxsO*}3#(c+$xgws#gwcP4~V$G?#6e$nBhjpfiEt<#qAWn&QLeQgq%7^2I*xa)g^QHhcO85d%l~*%6n)-n*;l%I^6I8+v-o428fj?c;x< z+J|~RL{{`jgx~vk8~$WA*7P>KllqmZ0+y@yI16y)+sDlz)KQ3uy&qo8G;y$dDB8&W5FYqQ(Ob{M!2x4i;yr92wb^GbL)N z!B3QWD0m#7YJ8U-d5FfR{)(27&jl-tnpgLtn;h#p6~Z{%xax&dgGfX%{O?g|(!M#7yj zJt5h_{&kmGMS?4|<@D+~ZRET49q}Y4HiG7-vzufMU)ih-yU*j zs0c!0n=PmJXfkR}0Opgpn=i}=UQ3uO#^Ka9?>$zNQ3;hg2<@6NjnAvvSr~_k4E17w z{n5J1pOkk=qYhq!2n5#1MNh?U-4!-XV}|khni`sn(JgOO1UC9=&z}@$FY*g*?5iH1 zF3O7gFT>U^CmwTWsGyzs0@>exLgBzt;hW!-YyEsBE6ZDNLq z-^?8@{=1W#xdZIM`=k1pm`c0%PnQ{awR^UdceB*3xlqjX-5>XrBIU05ux(#6iHV0} zR)$9zD1S6M)Sq$sOp@Q?cYOd&V~n2%W-Vm;QgLD;v#;vm8wnkWU*ZFA(6`F@n_E~t zH3!Qh89!9<1m2z1tc5BXX-v{+YYpckm%Ov_ z{0TY+#{w2O(qDdc_h0l~tMs4JE0b&B@e{H;=r#ND4c$y%d-qB`%K@wl1Mk3S<^SJ( z2A(sphCFIGePQakI6cPRuD0+{Uits^2LP=R-~$ATTL(-RuG5z%g%nIwEL5Ry|6J$( z|8ane^aX!G*piJ|unKpcjFb1d-At)tHKo9zA z`cdOT#7!JQBzH zf4MV^6<776uw_E1yj-lz&L5dJc;^0JP6rnw;owaJYXo`g$6!l9Oxs(s$8p>{e~U4f zGXra4!7ra`4jf?Kd6H?gJd#Mi18vjvb$|1dEMTUu%OSyMYwe5F&gC90Y9R*w)XmfX zsHQPHd4v959&re=|3dZlRtZjyHaBLxbNPRuCdo?w^Qqw(3jigI0?d~>fbb(msMwp?Mr#RBa?V{J9NAiV#rE^? zm`?w^1?AVow@DQX-XqClX$jt_Z=cT1f}n51OY8;M6}=gs8qYeOxAQkotmGPYr10xK zcK*AuO(&Q8=wXyx{i1s3;F;(P4gd31zkeG&yK}g|pZM2<|F@X_Ux)F}!~eIK{-3k> z*Q5WpnEqdf@%!QbuZ!uQM38|a7@;GG+;w1czATWwgA~fzJPp{AlDu;R8lNS?9Z zeay{GB$0&^5oAg?xBf`!rVfv%@Fbx<-wu=6zVG_@d>wIDAN3f0if%e@4m+$djN7p% zETd z@>yJ^&-T_>iQD`4>2xnLh9aK4dh^Gm)bwyIt?1U_mxP6mh}}octAyA3s<$XV&=2-D zO(VJ^G9g=w4ZOTdsK&TK(s8eE0p`m~V2w9nEH=c+*)%qcE&A@^!&RFrgCpc=fH>9@ zwo`*$#vM$-8jK~+8$oPu)H$hBJ{sH5t$hldlaQ}_9vr0*26C5MRs-34{5u)ukp(#^mGXgXepoBF4ksaHMuk+;j``5Re z$AXpI=On*#5Cz3yX&5(aO^3=}*z)zTRWw)V=oIqd4QTEYC+T=LKS9as1syuh#A5U< zUh(nM*q{TxxB(eq=bXhXPt}kO!-SdK5DemnFaVdYG;o)`8dFRHQy!r(;~|y?UTNaH zyElUy?k86FIIL_42>v(}S#p-;$@T8g*@{lF!7RsmGC+c}3lx!5<_A&QeJeA~q>qbV z#0E~Qc$>{ZcRXxqVb_yFRHP7 zm}TqWCP~6)8m>7Cg|)gTg#W>Xm1pZF9J-WdZe8QxdahH<&Avi;`5|9ON3-UQAGo;WJn<81Ja`?pK1A1 zC9?-p_6KicuqUkF^l5z-X;q(mR|Lw`j0bm)t$4(R;4c-ohf`tWLnoC~gJp?om83{4 zy6q^Y+tq|UjGQz0)e&wOaLVtmcn^9Y@31fmbjPQ(_xj$xc#R7RfwU6v*?Bbr=UjM2 zV}(_Rc1<{QL3kWYB)xx2B=+0^?|gkztK(j>z2Vrli6{VSI1 zpIw8*tO#|a`l0p(H|+TsHf{l(kek{0oJ!ER+BLczsw|rOo*f2|F(6AQ(o=b z0(mAzlq52uIcC4{$}4SK1?jkR_AV_%`Cl#_Z#X%7)-V>?clLb?1t+!3K7+A25OqjKLP1CqPgs zaRF;4aJKo`!8PL!JBpBNFSlL8z>fg-7B8RbJAGs+sA}tk z51PH_fuL84RO(an<6K@>gwo|rEj#QNnEJEc$A-M#+}$58Q!>Q2{_%LjREqSKLG6Xg z_J|eHr~NoEw+{njeeeXcumXi~n_S~?<#?GGc_yhIlt%;OsFi*nb>)DVrzQ8-td)mI ztVFg?Q84eV`5}5d<(P{(J6UH1KJXieDh9RlVLy4b66Jm$b$He4@jW-RghIvz>xwrulfsO;vVz^)nFjgP@jGhGpcN8ZvTyld#x7bUKw2;0%615h%3l_krx z{<3TYEI+1(Qb@hdJ)dd07C?SX-v5B}mnroErnExFZr~4c<;*4xG)3=r&!un_5+5sJ zO2PSpY;iSU8`WnaqG1Eh=O({cK{(|tXqhrT-8<)7nYxDGzZ1JKJuzYc%|G@o1+m%H zj||@xA5{rl8FvtQ-ZZQAKpNbX&%Wz~yu~l#P19TXgrLa2L$`H}a-NtE@mnAK8(Ohr zK;Lk`4Q}TX)?q=}KItSFho|RS?|L9Rd0@ctd_2y?I8O*(qqZ@BY2ek>c$O#CO~rz4 z@pLm|*&RwYw_vTlhI~k$yy2{8 ziFiklwGur}-a+BNZj6l__4kbl&o|@oex_J+$W!wku`^MEIn;~XD#23y#e2EvF9=hN z@B_0$seF>_#`o{}$clS^Q<54Qi2%OuRRmnldt>#@)AXGyG@*9mM~=sbyrBA5A|}eS z!1jzEV1C*_2zDYeEoX&U4gtn{-}I3HBJ-$5i_xq7pU%uQtX(bfxI0jIZnDhW`Yj8S zPY3(XtbwdR6ZohK0Z+8S+9E7=xe)q%POR;q?@C{ahSyJ1iSX8%9eWh_?-5t6o zymEGiWQ3814gdI>C4us^Xv%XOfm)>lf;=U9Nb|*GsHUeWziJNf zjoh3M4EZ2#cKFtEweA-92Xi|X#yxhq*Z#DpT!Ce+aP7?(YxA9{gCbr`p_#Zg+|*iu z%W-V19l<}xW4B#m))s+YiCWE3C6Brdy+^8>r?}526@~ZDr8c(|F~0E=nJ65Zk^BU| zmd4Hb_R%Bev64?bg`;kQ3U-$xLx9RO77K&x=LI|t{rW9m?~ja7EnJn|+R_c$Q5%~- zJRjen9XQjF4jnqvT)I*!9M)kFK-IsxgP`)wyYE~-P*vhkz3GZX=(C{l>~=rc zLJJhB%7+IY86rWXubaX9MrK7(U!XZmPP4E^S#%j`V(wa@5bxua}&!`fv%Oh9hB#crE^E5YhQX`*E zwGw52+q9QeQ_-8_fmM6Cd{~xs);|Bxy2?#=AO^BufoLrDv`>Tt+KF9Cqc!h~ZxSqN z{b_%->}!wuE?kJmnW9G6gBJ-R#X7#VMG9xPyj!`epQBUqjScXLatJs%G`({lB$x_J zQ`^h3R&@Ktqav7uj(w;BA(<^IY*}&5or}pw*UQy3srXu|nWl%hsSU0V;6#`B(Lmv< zOIZF$ZKJd;NU8^WziXu^yX)U`8^hehA%;Xg!J}~=zi(OC+#`LvVuRaVjX%wW!q@kArX^SH&S(m`GbHkeF9v5jqlJ8DN!FCy;NELzg_KHzr4^>hlEAgng3bv-qg}cK%+^1ZHoqA z_O>odcvJ7~yrZ<$MZ$LM>@Bl~_%@FP;m5kWrIj5vt`nK&lvdjq+ ztH^SOwZR!It?HKQyta}cfA~fHK6HpRs`K+%6F*Wx8Gx1=e!Su~xOzipU*VK|YMe-ZH3tMOtW+(Nk?A@00miwPAqB^GYo7o3n?~D487h(4Yo^<@+3CBk62%_&)YR^; za8lSnqRLuA*Md&!(1UIjxsv;tZo1|2Dp`gSI{TJoAP1)Vb9I-%@|%}4o-5_fxE?vY zH3DfV;t*Km6nV6JNCkG6A}wgQwu6W=@EX&{_l}&8*U(FCjhyu`6PQn!GLCc2I^lUm zb@O47n$?&Im=iDOsPFt-eTPyRD%2c?bZPRnXmEym**k=lzbD48(mbRM@ti)a8FE&0 zQfrn}GM=LsVDB!CyvTQir%SSb!7Ib}_!L4UBVhh9l|b}Ne|+1xsJ zMlFA{&yb#f1V6@go>jl?TaJYvV&%Yj-S=;3E>W52 zJ%6al(CGbltsjZLU4l1SRr4(oOG2T4awcCuzk_Qs-53946HF?qYUz==s71*Tl`YW8eyG@W zwsL08WH%&h7d5xY>HfiJQ&on<=G%WuETcFMTSsne|0@`bI#miwd#m}}ZO#)59OFM@ zRzDKRc2KYWbK7|q2tH^!yRq;>58Bhhpso0LIm=>+L71G_maA}B-lqbLr;X6Pp{uu2 zkDf&ZQ3d<{XwG5Vo+>l98w)bJf?2n9j>yZZ9nr@Q;9J?;K1_&Z24dN08Uh}ygS({& z_^v3d;bwZWSJ+Y=eiEraUyW`rd)zeT$>%=@MT>{nYI?OzovV0I;bgDs@~}CcEC!k9 zHiw&~;4R9`e^l6SoIQhkJWyt3hmYR;`ef=`4yUv((Y|-aXT_k32>qtI;$>J2*-5cW z7W_-y2GjS}Q;t|^UNu;F)`?U-Tqi29%j<+sb?O_}BPX>LbVHfG$JeAtr`aeJ+_Y4@ z(IV|OF=pr(WwFOD#m%vhjqkBvt$_0aCG4AfrC|Cy+L|wlc}tfq;EHB6{D#++B6Bo; zRtj?QWrK-UBato7uBj9HO|beMNw4B#${Qa9VgIHpL6C9es|unmD7iE_r#&`|gE#kN1(g7h*9BLe@vW$5h=v)ir;Y+I}!uT0qi}Snk%Z7Z(f`HL8>) zu>Fx4L^)zHk6rVt3Mist<1PG@1lREA(Q?Q`bd zj?c?36u2Zc=#z-~=G6Fat7FH_XT1d`Qa_hfH5J5bKhNqjye!7iQ+CQ=P~pb3ZhTTK zT2#OP3Eb|>p~etRR8Wud*WteS80(#5=XxmD_d7li&%a@XW%IbK0NBF1{E)+8Pvl-y zn`z%cOH)-G4Z^D&yHC7-XOiBda^q3hgO}bv~fv?%g z^?34n_Ud4^t=z1~^&Ru%Y#KIVae3ah?4#fr%hg{2C_JIIzpW7$3@FCw(DkM0+x~o7}Aq`-Q5(E{!+=@ohm(A2Ua*eYrn$_qWh;Z|sCMcp_64@`V-@Bwf@4Z30uvh{M z8#?*ra>t8M7 zY)^W2RU`*0WhTMJ%>^9~{HQtG_ev@A9U9`)QB=1*tO~#S!Gitf%DaPVIp2|iDuW-M z*)CGiNcMK!C5;JZ9DG(&Jyh$zoVvmNpD0^^s@=Q0OgVrkmX_zcMD8yICJ9SY75~Fi z>vA)ng+$(4wZY7)V$qdcp>LUtGq*J7D^atM=vW8%?TZ6di{r4>uF_mD4D)@%t86&o zB&VY@hYJI;wI;8y${e+4r?U7fTkL-mrZ^94H6zeQaX49|z4E@?0 z(64=HF5ZV4&dri;+(3)=s&Z;3iS-bJcz)k|C6D|4u|O8O{U>WKe+h$!sYPY1gGwN8 zNP%KZK)v(DFj?*XkZPQkSKu_H1Bu?)G%n`uS0Sq9FMPZv!RoV9HFdl|E<3hUe`aDT zd3E&hYaTP|5K@F14Hz&}en|lSuqOV7*P*|3?&-@mT<;mJvwoVyR+mwE`JU!evz{l# zHVVU2qf;u54~5HF%6S)!+tRg2f)mX>ZyXGlS-S=km(rb5#qZf=Y?Tz;v&iDnSt?T$ z573<_O)+b|wW$ye+t3~?hS7wKIK5(ZEp0m<<#6e(M;1Bv>R<};KmH0_NojM(_nIBY zDWANR^M17mp-+)U4@ER3QpEO^>hIpPgde%LCY`mu1n|a=yT0NIhUaFY&ZZVv*xk6} z5SI63#!oE1xyg2&qBe1+o85w)UtM%!DHE6?9=cS{OxkZ{8fu_4OJFw{(wOra;=4BOkq68dbJRn+s%!ca z&vK!RzMKr4rQah+1>N40sfcwp7NCfGI}YD#&Ors{TC zTGQjB@V@sO=|Ok10$DSut6%f87cE`2|JWQFSp6O~C-$sj?dhzo`O}%*9W&Q078db( z3~M#ikPECpCeFYuwjxSL{B!?Jm$t`RtuffLK7>+tC<=V}aLk0zxRk3Mx&3&<)Z%5khaV0TN22h9oqB5JL%4 z1BCKE0na&}d++xzct81p0b`J-thLvgbIrZ?+ObrNhtX@kz!}}`I6thjPMKT*AxG7I z?nC{7@FJR>OYXjEr@p+XgbN2mNo8QkgEfTu186V7$dVbo3ZX9gJ1uA`I$z4Lv_P`N z-^cE@;Mj4Lj~58&Gkc9JnW2N( zaZ?lhE{n6Tdj+Kh{%X(RR`Clm9rv%&-uyv4Dx0Pug5EK5WXGijgoXa@lx9?1QPI*3 z=phf#dohFV=qxM4Cra~W<_TGbW9h|->*w=nsO8Eb+}*3`lYVg|uHq&^Q8u&~a=AQj z$M@_okt9rZ@qs<@A^0ToOF>JXM%-#IW{e+Wrd366=&jeda}}|8^pe`7x&%d&qA)X% z^{V!pI?;b<5pPc37n6x}UX3eITT?aPFFL&K+<9{m{}e-R?q5<0VKu7e$skS|zAn=> zKH+AaGmI#4l?s*lcE1`TRA?xK&YX6QQGj2c7ZkFL7G2UaRWMi*$;3@x2^O^3`WDw7 zCwbT?F1J6x4P5)fjZ1qjq^!H%a*6!oSUuoU?Hqq~VWKE)T?`?g!qz*Ezghf2n+~3r z&uLg*?+Q0=iX-Ti2)uUcS4MOKR9R20x;RNO$Il2e41f|-XO$%=Li2U3#!O!7wEi-n zlu?n%jpow3HV?)$$e_Wk$odFKO>{c_qrC_%d92lD+!*aH)Af1m1lz2i$nkyGN@R($ zVF1rMN(-ubPvVl@w^*BTMN!`*Iz5N38rmW<>^pxT2k2-|G609SDh+G`)oZCF`rn|n zm%K%PM`+!Eu~9818%wVnkZeRbVv8P?@_pmk@Ok1BkE2J6kIrH5ei$SDfwlwsiqixO z#J`3X(G$7Kfz4jg(Oc>v@6%cX82uXL%fP<=;JTw zyA=ddcJGeY=#|EuuGlscDVs`|p;{@Jw;exq?1h9l&2auF`4c`Gfh{JQNaer_w(DMk zKQ^Rn%N7&keRB(#HqoY`2v%2q`sP~MhqF^8k2ZLa^89s8p%&`}r8jBXonhJjo$K8J zB*onP-Z7Fue!}jo)mDidH|}BW4flG6qIc}8Y%uq)AY5MQJ1n1MUL3N+W3_|S8}mc- z9a@TAO?bv?{|n-yI~xWBF? z=Q;>r=LSs5x`9!ojB>Oj@{HXq2dr6GCZXIsPAawa?hzRw7_*M^i7XREKr z>Go!nsz%YUd~+j0i{FG%oDDI07=luSmb>Q{$#pxw={tM3sQZ>vGe6+hbfF|eaL!lf zgXp&g?W_84ztyWu59D-$o#EZ>BJAEbjb}DO`4KLji1=F)dhUGkR`|Pc$QA<{%!KSA z`Z%iVQ7s2kiqBP=#9==e{(%uD-)(T-;sAPKDq>wvHPewhe&|HLUz0>{`)bop>zg8v zUzwF34+s!;slaa6HV)TwJR*!PGTZ&B^Yn*BHmGY^?s&zty4;`{ys7(B#mh0hQt=2# zfWEl;5gI&0ARm=|OL}q*9paDxx&(d=xg=vEPLGq8KK_>*U#+3mXzhYSij&DeZPY5t z-=m7;?!@j8BR_XR@X_Z^;Lrb5^%W$7k7kW%(6n0+02&4sC0Z_m+hV=e(tYRsmj^`7 zku~xo?fE?>S_JkQVq!$0x;Zw+EWT^dR!eVpntiHM5~^#%dk}LY#R2C^v^3sC@CDPY zR3~^1e98K|GLLNs#&FioJ+LBErYYK{V#9}u>Pv*~nz3GOXzh9Pr0dE!fV%f?-s;07PceQ|w5Ba>DQnk$Lh4n;~2Z%5i4P zx8dMApLHc=RAY0MO|bC$N_uL10jvN>_$`qh;) zfu6rPf5OuSM-hAN#oiTa&RgW})es3DPlSD&wXDKlB$F8X+kvN(?x$%@ER)3 z>9X|(o$31Eki9({B0N}rc>CCPupRzO6xWvA?b$}#jL(tZDO;LDk36dA-|BolJDN=} z@u24f+el zkrzx8-#L>J_ddSx;{z5ch_Dcn(}@eWHwcT54%%j(h@RFm#!d$S8^Y9du1uoqYFZV% z^joG>z90y2}64uRkyuFa4}FgEM;nkIM60MGF_I+u*ii^L=$>4i@u-TNP8dRn&}FEM~f9LxLoDQ zOCVJCP#6M(`g*%>L~JN`$KI75yS5xMBau-ejC)^Rbh1^E?tIZ1{Ge99zjx4KM?pg| zM(=~g3fvfa=BdtWww?&BcHXRrr7>#FRy70P(HnA=5nOV{gl2 zYfQ){#2uNlXi19pb<;dUvCH1BTLpap%!ZtqaRg?XvwMj(Ul0^^fadqcINhzP2a%-` z)U6dfc=C0QW??Kxq{M%Im3Zj1+U__nICTnR5RH-4OvDpL(3ov~g=qw2Y2vFHP)1>z z5105bvtojG1qh9Ry0_tm(kypBe0B6acZYez?y!QW7{dK$THyCiqOp1L zfYO5>J-S7(!*nQ3E!*$=H%o8u)%4y=)@^Td{o+K}b(x1m_+>aT(oQVitu6BLy|mH; zZ{74nOtyb*L=AOR3)=i<(j)~kn8%@kUo$!L=I&~&!$ije8x&B4cDZOPTYPSQ)*E9) z)-ZMoXIX7gFsW(qeL*)&FLZ?Fq0@Ec7A!tQ7#-i6rCO|vH6F8uHk(xhMo%~QqLV_c zpc}T&fdfY(xB&2Z%>GoKq5R|QC_JRUxOfM4!C8GZE7L!NlbNhJ@S#pSu7>NWEn+Ni z_QIg>WrRTF#+qtbX^~00QkQFJ-*f1tIK9cf6P`F@7mUjFoT>b|Q!vl>Gu4Ap-w7oI z-1uaA<#8RmwTz6SS(6@ILKu^9B3Hfa69tjZAEOw9S}x$0z&1w?wXVkWnDY|4u{t z`!$`oz$OC`mLqR0Cq_z-0Z3jQ+N=n9arl)fAn5=I*Vs39U2akaIO5vxO-H}| zIHFX-zT*g;{7xT31O|PwzN$1))E>`-Ea_X*J9lR*Z-b6&7SmUnvld*!DosV@wfRr^ zaTKk9s%u1Z?Sugn@bL$P2Q)J%!^bx(N=5I|;f@y~32?h*_r_6EXrdL;OZw_~eJVm~ z5|k?My(z@~fT}k9EyT@;>^(8$DuCXSbmNNr9;Novsj*DKzvXG|_jf_wGDmJ3EkR5d zM3oTa4>`^!5wvBTV=DuuDfJd~6V(jG7g(Ae1XHrytyQ88pZK$)^0xtp_~>b=bqI@P zch=}U_;CYW)S-V<6xr!6*_!1(d6TN#xN}SJktlIriGRp46NKpy%psA7YYhn>Z}4># z@2oK4Mmn3CZ`aq`9PP>U3JkyMwK4OO7p<0t-Ji1xA35NHV@Kni)udXH>f%I8@jR8reWy+~#&FCT$v z3Pv6{Aqpa%Kr-9_V40kngjAPqLJLK67n>JFKuWNbSq$`f0wE~YO z#LnnW<}gDk+NE&YLdlYll`Q^9|im!w^ z7BouyJdWgwgqZ6y4kTrV$jLP1q|w8UFX&g0fV%rxMS@w}8b_X+5*4wIB@H!mf}+~d zaj%FCA(gU#$rAdN=~r?g<}nNdlHoG7+eWeYq9Udt0%C4qdhM+^Pk!;oYYz+Y4%6;B zsVOu2M#4&ftbe?#zeP_TMzfn#%r5#Kj!{p7x6VQM)eOOBkN8`h%$U|GRcq=@hWL|K z!|?$b?SJYR-YOe=Rv!w;=-RG+Qkf6U@ZpY1MN`jf*VouV1Dgtv#Yexezf}d={Jzr0 zobAQ-=X;icRXqt?!A5^6(^a4;9Lq%dV;g!f)9L-R;72p`I2MDoeVTsl%3UQP6A-Ug zuEcp>$&3NWN7%GNTb8@zo)^K61!S2!?l<@D^AOX!SVrixps6+#%ZNb1oY)Y1qQ@2p z2uRZdib>&&vNW1iq8^R%wiNk`y0&{Y8nczlYxDupo?4%9)9NiWU<+BXN^jzNG5tAd(%hegzKQ&iXDL}23&jk_gG-~|0)&# zmBeJs&Zf+MJ6##OkO)@JnQJ)`dI&gy2K`)@ICtV4Fl_@=l=*G`!FcR9-mj*;7c-TC zimRAUjyu+HOyi;3(ht~cf@etayz|*wwPY)#_-;i$CUOFimaCD2$WV@74~V1gFt-K`;@ zK5BQ_&E33n_N(#5uC&7=(KUbWdDyyBnW3%38Ga;dt}x?oE;_?vrGKm!&DWkjGAN(Z zdrO2AC%z6ldLrxOI>F5x=rv*3+R_q9Ajz1A;bQb?#qFwZ#ly{v3Th}|BJ}Aa5XVKk z&47aE$zNG7(|P{6z0pS1!PlwQcE`%5hH$?eSc((1oWf8*+bA0xhiRJe@^-)g)VUrF zweR?|PZHh4IWy#xO5b&OO7gVdO5f&otg;wAo5xz|H*LpPe!$9s6$7H7mY_9MQ*LXY zVUx%B>D94V{Z&Vu;kXT^g_fd`$8HCk!`G52-J!)@0W4#3;uR4hkU_)A z?i1h2GuvI}eYkU!x4m^k*%|uEC%dcxWuP?OD31Mt|GUXD*9OIvCXMG5o1}zdZTPf{ z5^P*)_ZE&<7HJXBkL3;{T^!#YiJb7C6!^7O>muRzEC=0Yd-UkB04gLnZT|V7cOGU= zE^30OGX-)v2^e76X?=&qV!1}&IN#2G5>m7^WlwCM+3R<6B3Hc7JB|TaZ|8NBr&`x3 zl`&)=>%wCp=TZr)Y3^^ut+5@hq77pp%ymtgav7m5!m00D%uXA=>wHExMJFSDC0ldc zC%a0K1&3wdI)27vN3(u+%f#{R_{`19+~yVcm2#VjF&?{g={56|x?RB04U6eSecgXH&7s6rr$Uxv(4zdt>!R()25Ye3yG0 z#dw@z6Avp|y$^CVqdXzRy64`sY7y-|aBdO@urY<8iWVE&DX1hRQ4I6>aHcYB+Z*5% zT|)gyBZ>~Nl2{iz?L|BOBV-n6$5*8_Z>Gt|2^V1*Zi8B5uyRJWV||C(xk~obPLDY7vcgJfn$b zYtfW)S=O4x*07-^Tf-UK+!Ny2-JvSmj0Aw5E9w63Q+1|B1%25Gk&tg58C<>@hn>aH zm;3$>jl_I-V-82I)VEwoy*u5FozAazQMmOqaij~iRIBb z(+l$C!aGo_%fy_HDJzcYFAX<#kKRC_D~n}$B5QbN*F0Zf()2lRO-r@6Rtr2$d z_gX_L%jV^gJj3L>Z^uk-ibOpG#RA2_4bId-#6#_Ox*&z}!^N3{2L24ij6%)FEqfgF z4;B#&=2>x)6B?UA!2IrDZk?PjCecJFK7>|=U0_8LJdLODwXNm$3Ft$_3#k$3(BOLo ze2Kkh9u#J7q^AnjP@V|MDn9{1eq`KOe2+L$0nE3lhAL=EqS+VhVdh1f6qskxgfJaJ z;PI@>Pdy_3HP?@fD4;(%mkmsV5{kIr8_#GD`*^v<6AYj_C33Bk_d4&~)jysDAxl+jmo(hG`0{v~8c4nl(@3~$48aVH&G8?)fo{pXhH#Jy&6yB}s!%R&xa859 zf0f_+#_mXq+|b-A)`fj*d!gKoAu2^Q+&1wGEQ+(Y|7r3hRpa9u^2vO`jFVNbfGBA4 zmW%aB%m08U3>D?+Tn0KLs%2Z))7DJ>P)TY-W;=@zJ$fr3Z(q^TdX14z2^WEI$e&RH ziWM`*q+VY)V>Arg*>awLmslKlBF+hk{aw8$Q=;=WGZHHQdcPAOKr#7tWHrc&Z16`} zT=Ds9bp8$xqbaxWm4Y0Jt9Q&^zcE@GbAp+E<>jDwy**h;ZJ94;_4^HT&nzFNIV04V zLz1T|*}W}4+>^&SY`V1SFLqp|LRChvmcOU@cLN}9m!HoMv`@&-ilXWDD!u}V@K5ss z`Af%`cKmD20?;+BnP^D30Whp=tEL-g#_iI34O_M0Zh-vncqEEY{c=MpM33iZjf7?$ z2z7LBv>I<0P9)RmW18A$j0lkm!$IdSVHP>NgHgX`rg!|SQdzCkm3vT(u00IYn*&j` z*U*JW#pZPWjomlo6wEvX$kLrxgXf0|WdPI`AMSd7oRF(m%DvV0>9C0O)=;Ig#VU|E z4W(bv>b&HI=Pf|!J?iiN)J9MQPV@#;kx{L*yFIbp@%`3P&T=`4#*}j`t3)FjB^N$x z0e*c)8r~z*b$8`=sK}CNkS%xzNbyYV^u9b?Xsl1BuWc(WMze2z%}G3+r1$Yhc3 zH%q5}K+@51O1QH{&5#hw+cugA1PrjZbox}Ey)5$uezT@5f3kGq?^)}nCJAZ1Ew+GO zX?8w;*yKrJw2t+Umkr(-e11nZlNnINEA;FRXFN!CgybJ7d~%?#H&n9d`_N z3-_6q_M6!L!FL_NkMx$?AMOD(xohzH4kwe^2$KqV;9N=cTp##~$X1$o1NrktGDoPVQ^V%PS2nEiq60zx?h`l=}cSFZZH(l+-cO%ZSv4M`hQJ0~q?&9d+ zn);}}qNTgKGN-C5+=*f@2~ep~TPh|uJ4F5Z*!+ptmJ_0UM;triV>NR*nai)x#99Qv z?TT1tJfXxw3y75QV6fdQM=`yc+m{5_*@rjyR1PU2xln^vaU&6R9kHK0{3TSgi2uQf zCDlu*6e~rj2w}nVFCo3a>30RC=%b*@AKE_Oo`}8;<~C?RG!rC@b2k?X00sTWC+NYu z@(N~0j~eej4On*`*tIPfW~5hvI(I|dJmA`uk!jZTC14xC6tcrW_=P_pW{t)8%6yEd z44U?W83yJ)E<;Nnq_9{XMs1nI*o7z@77oLg>9o|2m z8ozLiqpMH>0he_f9&mm6=l9Q_IifuEvw@y!7X*dhFfb4i@>70PzF$&dA#DRDXRkAxyXMR~+IN4O^Etn0_){DA5TP#|E%2TfSZ$YJT=eGY;un0!%J^(K& zB`M9PGHR{`A03_&K&v)7y3VIl2ty$u6Q`Sig)+*OVkXO17+u$mW&`M7?7wEl$8R<3DU#zLf1n>5zNr6yTG0}>nnCpzk z;pPWa&J%|`F&f%6hg1$2QDU5{hO$EDi3(~Q8OwLR@&ej@c;qog)J92V1E+^9rlKKn zM@l5gEnib6gx|V_yp6k_gNdR1L9;godK$N9ytq^gbj+1#%b1@2wSLUxU;K$Y*oH&V z)VGFoedu-mX(_iTydl$0EU8A;;Ptr#_4~2nuiH{C3s^^U_g&ZpTjt;4YxIW%01{K% zubZ?_72f_i1_)SfCT{^|A`QhJUiq=}EA$D^xs8OA3;&b|-usFnPL_G!E5(6@aL%L! z8sdAL(@iceH~iA_#>Y=jHlv~*$=&GElQ!+JfV^sB7W#O& zP$Y-t?O|FFmRx*SF5~+h-+~tlOrGnZJJR9@$S7ORkoPB$F~0=!R0}s?yl0rpYa5yK zby!5S@MMqZg8H%|Onl3}n*fc)^>i=!))FU70T<+@2b8W=1v&Nc=hG^3rEB#RT@<^c z2{PY+*>lYjF9g;~K3(6Xmp%$nwcnUHAF`T?APCzYp(~vuFXynNx&|CY8)DnnemqHpA#mpSS&(Yga=$#q1 z$@=(oPaDCOGhbw5id&Jtin)JD=|_)pZ@YwH9D83HGiLHiX0P2pI(om!lc%_?dW5Qy z=cZ_lUB0;%_#-#WE5i5K2IB+d&?`nmb;^6lRLj67<2w>JlA!fDQCcx2G z>ka+u)tb?7k*QPGZs%eyDLxPml0j18Ra7;qQF@NPT{fXebW*>hW9kjDZR=yXZ(HTk zDn7l)eQb3j z)i*7b>r>{)+XvFy=@nmD>1l0hrJ)+G3|=@|;+hCt*_AzBTKK`;g6GTX;kd!K;d091g|Kv8JZg^*i-|%!J0DD9q5a~ z@PJsz$e~#eikDdSqQ1w> z4;csxFs^zuJ-FNk)oMU3(POh@K(n@(@xHW(hD4fj_7$y3(`YX z578~Qx26sE+pSnoYc=7_aA9;{KyiLYh9kd^6xg-LuyNk)&_;)vcLyQ|+;(kj&B!5@ zR@jkIUDERhxt2gJo~QZ(;v)*_j@)97`(3Cf{0FX9g0XND=A~|}wny>IfFy@k;!7I! zi=YaRkcPC8>I1knAAQEBIB>9*HvblTcr5D2U{D0?li1#!^$@yd0&5Tq3ax+#2vm^XOqm^06*3S{o4RHPRDB0G)2;$eEwSD7bQrzkAr>vgS34>LJ%gXe) zGuaV!3imK}KO+#5UsCco=36lv9--fAM7ganhD(i>fxfa5nGFw;+MgfGZa?`N$)nPn zIbwB^=>;j6jD$>?YI7e%D+eK&_QL4?b#Vi@q0dA~e^sEzdO_D*pDqKbH)ATtIAs-C z7(uebz^*i|u6}b$(C~SINFF({i?P-b6;k6&C;78d;8xUvUa!1JyKFK9>)AB7q=Joj z(CHm|Kdc_9hK*p!Mq$F!=j-zZ@X=YT>^Sd%_SGqA;=D_{4JN?`M0c;g&RB!jRotF2 z8^}y^;TJ^~+{4f>anpOmk`E-GeK8>cr*|#-V!`VJw$X#Ji>Olq_VLKi$&ht5gTw^R zh!e!Y1vt%y-a8B9hYxZZ+iWqEj}fVI|1IY(~{DIL0QVpOTCc>yPt%d|Os1L-$ihu_B&*+r2+%kv8cO0)}wA1cF-^Nc*ZY#Oku!Ah#>q%SNukITk0-QQ>M?HzYs0>M>Drmr7l<=KFU_b}`@!5Tgwyvc z1j!RKZaF5^+M)#_;T1|f)(=lFr`xnTMZ?@05x`cFfmj*PK!A=&jnLZ+UH_w`fzpo> zIMns-Q3YWQZ>W-{dHU467~LS}tn>fy9Cm~>G729Pag>NluaFcNmVom(M#;`S_gHs? z)MGMDAGEc!!TUdJtWbz4Q7i7AVuy8DuBm60LEOBp(%~s_5(48AaN?9|R3gg;)Y4)D z!kRcBYARgsU3QAzS)m{g;Wo=)m+7_NQ5a9;_5G^3@NLvbHhZ#HRPj~}V_EqE2A06` z>9*@ZazZhzaPEYd>23btYY?nlvhBSqJW2OB8M0V0^IQvc`#?-PU6FXA0tgoYYT43> zw}zD`Gi^^ze*b>2bluX$oY&wmqbXV`ZE8JGS3tA?=Gg$f?t z1nBqc3_sC0*xzo`WB99wvII2Nd{pthoemIaU?Tuy<)fx0PSvb$d@Di@K=XZbHJFS#?#rr_|+^h$= zHiV(YhrqXO?F&rhbdmxUs0it408>bYuZ#41h5;bjBhIKUG zX~gb|0Utnz^gW!u<9Oue??G6~b+zpxFSR};$TZ;_ck<+PSjnm!1E^ao(?bK&@!Ktw zP%)5Jwidj#rLbLdC}Nc1G2jTn-jhNxuJ&*OlYUL>h=&NDmNuWA4Yd?;k?)}PRcT5s zF&Rj!8RN5Q5!Ttk8tsg=1NkErMBJm?v%Ml;>`fAcmlSTUvs(7+C%E&Vs=D_gnZK$xA zF;(Ek`ZEfnt)6~@KNbiEH_e74a%?|(F;rlVgBp#9SUO7B2mzntWm>9NuEQa<))(@9 zA@x>)Q!bEzL#q9BaMDYg^x$R)4+{MY1E*kEIHdJ^GBG@&Fov<>a;;0re4l)Y&IJj%E@VZ*+);Cy5 z8MYHhKA=oE(YFxe3Ar|_e}aCeI>t>5)B^Lm0PhaAiA%pyY5VxD!>n1513?sS^_E5o ztkQSy&m4-GXTlakg(ph)sy0Blo1;T4+qWN|LO6a#-e_>NiPi`N?Tyh-N1SP8vH65; z!lET!i!7}GB3sWNEcL1AO$0$0oiZxA=war4 zx%XBYbpRS|jm)E2NSLOnUce0av-THerOXM+MPO#M3Zx+%xIJ zItDW*Mk|tl?67K0WhY8FrMBB(C;R2n*aP)m^42CkePO;j`{BedAgsfxs@9lX>3)6R2EVkuc(0I7*842tlGdCBE1C zJ}t+KA#dXRGMZBtrtB*#`}; z9fo{aa<|U67O|R*fRSX=S(e$RoZbh_V&8S|7RbOrp~mz05`UNAJ@K3Y@Z<9c>%sZtfI|3|ul^p~4tGV6`SfGq$67+&oIBz(nTP3` zN_PB(7th827(TI&CAUAiR-zCt)*oMb@>IPTa=m_7qH@XQ4G{3uYr|wXK1~^hbG2YY zT65d&8o~ITynE{)3C<1eg_QHFb(Dh*@`^%~*5K)3bgkq)<+VjWLCM50HOV+wvd0^T zMr5T9iC);g!!~)6K9mxr%qnI5;S~o()memV9CMR+w#X{CE32`N&-f>+M}f7x+;rek zJDJL9b3IJ`WmsP1eax2=;fhD-n%hmAWxOa7PGY-Z4zm@vi^?-R+^&cN^d*2%SbQvL zb-KTSw@~4HYxW|`TTOc^Q&>Dj5S}_>H3%-uy8nii2X7@CZsdYdupCKK9nV2#uM3LI zQWi`O<#l9e_57+>@Z^-)iQ}Z8cG$gY7o6D0FQ9HFp6UdH=+m$1n@EjD-E0vRZ1PH2 zTJwsd%~lywtz(4}podugkkMN&e8)XPwBS7RVSNiK1Im!nR_;jl?~16N+# zJMLh93fM+>%t@*@4;TvUYunh1q+CeB)<}v(6tAd%_;5MXNg-r(2+6P1Aw?B|@5E~- zAP+5D%~j{&U#}h{(QhvYwc>E|}gqSh={txsvDNmu_=0(wA#@ z-h_n10l+hpF!Kv5wY;P6Zs~}?8y-G;i1JJ&w98%MdW&J@#idig_PEdf5atf%5mL{z z3T)X6j_$`rt8&OQQ{QJmSv3rr6}nvOqBg=QaPZm;mCoDvs^batYF_I>rUX0;nU_1>XDQg!tAp-r4bDrTuQL5!d^`YGv$mIV9lZJiQ8`)DW_f=-f?XKI zjn(!E!JXgzI-XX=*>>>=Dz&5%G|2h&?!lC}Gw`3~x+^o#O4sam4a3Nk?@rYYcHN`7 z5%uVZz}nYLHQmOZS=qg+*AO1EajnbCl2>Kj(6YS}DcQxk2# zVvylXuXZbr+Kd=z)2!|XVUYkk{#IvyhbO#nSGw-cVCUte2IajPJmM;)jzg4+7dz}` z9oT~PSED@!s-m9#S8_zVx^Uir3w^E3`F z53FO!Gsrip);NCn z?{Eon2;1NGQ>jT(QbM0vXWeF~Xdem*IBZLDj}lVo|Y%MZ1a( z6EA`AB4|N#^TNTB5jST0L#tck*9a?YLwqXTHxW8gRC0M8rG~8KSXtX%3xWs(80}B-yQ3f z6fxz&y;KgEe2ZnmU&Q6&E)d?M>*19S*=cV}`EJcfec8s02I=2OOZ*OPzD_&Gu;oeD z0P51GAT4;fVthh}zjiS^Z2xrs!`1lacAcJ}h*{OvzMy3}9MTHwum_dJlIOdiJ?iY^ z4B{?4SYZnozb;E-`_1A*Nseu4af0gnAUZ?ZD#@PfoK^$!28dWQr(C7W7@91tZy_Ei zd4D%AnD|~o(8h{0@?Dmbx&X`$jt~TPbBei#j^Cl$%|_JePm>P>?jSTfbn5(0Quyxf*TmNPk z|9BvjQ})f$0qB^j)3z$p=jZ1q59=Ob?Jeu#m)oJ%xx7Iq+53=y?QLsYeTh-=bc?q% zIiO6peBqlIp(u?-strT%;}y582WWNkHf==O$jDjWMC!Fg%}PnY0he`!@#C9i`94jm z9jz39?O~FG74k)$R9RfQl+y3T-OuENRP~s>Li*GFrH;>DW%M9Xa#_WV@_b%ZKDrQ8 zMWriG(;+Rb`{?0#YHEEv*K8o zk`cvbgj$CtP<(&ep!@!UY%lq;mn2&ZI|F>b^-HgoP?cdLOomwUZe$I0qol=g={&U3dD4$04i6 zSpd3U3$CjNaESZt`3lv*Acy574WVG#6c8gX%S&7Nd1YI~_5v95axo8vd&dNW00$}G zY;+5PD#oZ=tV0TrxK|CaC4uJ8B77fT?3d)u6}x|^`!v&38>7y{+RCA0l^i+-R~~^~ zEX^C}>jm|fY4n=+3y%`~lR)raRhcS(pfOc>bV*NoL3yf49b;WfO)MBCPEn&0ai^T3 zJVG=?Rn}dqZsmzDa=Zm;eTyvcf^)%L`qOPu%Gdf%(7%ceBn*IWO&QEs2wk<98baLf zr`QVO*4C3uRR<`qDi&T3Jb&&zsXb*)c{L8FBD|%qaX|Z>X5@h9rY^ZN6+TN*mo2FK zBhJCuYy_KmE%LTlMp%9H+%(0B9Uqd;zGVPB9u}#Ev(ZB6i8aG!QFZKtp(b80>qiyp zk$lI7rP_xRgj>rZNT6nlGB@42x?olrPyOO|!SSXgHxOR3MYovzM6Yg=TW@)+AhCMX zvX_T`kMho<{X>!i?Xz8;J!{GkwO-Ju{fWFg_-OTG5loAKm(Ra9bTleho=X-sK_ zX1GeHWNIH!a1ebly7BgU#)+ASB0Q=78jZ7YEK1)7)=LGq3omlKS$F%1Q2qV3?M#_`8mm&Dg3QpVOMFgd?Kc%^tf-xY^v=9R2k5i z$u}pH1|lJ~2!0_OrTUaQ`ig-7|EA$H7Ch7 zm86bTNI;s?cgek&%_4pXDm1lzYWlelsH%?LlVR8pMaNzQCAz}ZY=$iE1R^+4bXPEP z!0VSE({*t$?<~!D2jgIkYqKj)AhO=q4dzbA8~$EJ-D#ibDOFVJPxTm}UhB|jun#s| zY&uCsB+sbx z@GVKS{=rlG#24?06-Z6*eFs-t*B=H>l^pGs1quU!q?*|^cONROzNgqHyBo0tb_?Jq zy(sk@fe)Q{CB>|iA5(9vxqO3lKu@M*^#tmLwaF%|HRp^{vPOML06z_G02FHTLpLvu zD+*w3nqPV}Yn=Zjh#deVO)ZrSPK*}4tY z%m%QTcB#rw>FZz0oV~9%B$zhmX<0uQl%;5c`ocajK>IT8Y3>b4q68k*8S|;Fvj~na zm|Rb))a|8fs3=ac4B^FG5-$;F`gk}&3jWQxa75c_K8>j(Tz+kSM_c@+y(9O&Ij3>P zbH!hFhq!ra`#}vEy9CWLZ&y2iwfp!|u2maL*&zvttVdmiK!fsRZ&(Jf>&-!PDrI*b zc$%+&VOiJeNF}N-Wm;0uxA%b>cY^{-(*7+^xK?#pBx$`<8k`}4n?@?aAFC&(GEm`H zdSNJ!etS0H-c;=F0kvG&z-b%+omdJg01{;{RF4yG>ujPj#!plX>pI$lrr*gh6Y zeB4O7w}RxwsymlC(`Sov^MTqXI&un3mw!fdjSil&=lL5PT5RCS)Y z^g3(RCZSwtV)@eR(yhnboPX&{aVxdI;OKl?EC%e?D~x(`wlAh38J_yB*jmvgbuK0< z#Z5$D+as(FTW-N&&#kgQA2m3=`FLQ97={vvTy#0Ig}1W&fYy{8v)%I~XKBp`_1J$d zIt{ah^Xso3{x*CBuN;f{v&_`+;#W==dFP3$FWAvE_O=IS;^jNWM|*%S@nz1wTfCbA z9aBuQglqO^pYWESyep&f_(rCj%Pi3QSdO~kav(1wf5kKXR=~1B_o7>Z1D_TP+C#zc zw{<2mF9m=8{292x%J%~SUl;xy)U-Imxx)s0{`Vb%##n;m!eDP@OAdCU`}r2~{&0Xlgx0uEmvCS2x$!Ho2l^6PyDcH86pVl}SF6s)~w@3&5ps z{@nIqMvwY_S4V5IRxNe=7N=wP_A~tJ;cN8ul_c_}fEd5h|I5KQI2Xn3d`(?hCgEEb zpyiY`5v$lTQ$3RoBY1ER_mir9MJSz!_z<)Zt$f)GhXhbFGAxYMw1<~)_`L-;zQ+H zm`oS&{tq@s_olYPM&Kw`)c_Q5f;=$B=5E11t8>QKMO*k(@)-;9&3s}PIpF_maE;63 zmvp;#xhFWFbXm~Yl*Wn?(pdfdqygaQgH(hFw;tTtIix-~pg!O_07vH)?UlopKKnhy<6}^(fNuRAxw+Bz_IYA3@XWQ`^n5GTO1#o9U#7;bA_lHdQhdr7xYJ%Yo5ndSkl+*60D?fW^;3D@8!NhHuA#2mHCnNtRhg z&y>)pF#d@kN+DX|9=S?egB`I0XRiGc<`+nwOXmDq5@x=iNS|h zSXqUwv1v^)srZ53-sT1WT||ruEG#TcGkg^|y~?ZMNGp+eJwUMG2z4`m{yu=hwb36B zh>D7~^!HmK+}x&1?mlN<`}wK6+Zfz``q~x^zPSfUJ!{Ef960;^a|hM6L+#q&XBrL4 zopO6sjs;a0K78bVb5KAj7 z@c8&R*{0v#?4ZVW$O!=fX!2nk_GHzLiHQl@;Sjisiyd`^z@XYu7MA&JwPJk zLJtA{Et>pK<8EAzN7BkRdDr)`u^ffDdlLW%^GjCyElE=sz#Nsc)`m)@k{3F9=T{1_ z*o2aj5`gso`*ZhBGDg zlm0`~z@aXYp~v>`-~ZoFbkCGB7{qy^d%M1`tSD1UJDXx8ac-Xh*O$V_0KZcuoB1>g z@%W_BW38>NMezT8rHu=8baaUAJLT-^YEO-(^z`tCCcS$9(u1nSaohAW;HMoF3ZmAA zR@LtVya6z?e}6ta5QF|y-!6D9_A1{qj``r!^j(2B!Y=@Eq(m+Kvb2N*+x7JIwUlmd zP59^Afi@R5{e18lXD|WZGr?Lx8;&lpSU9*g$E-kKjHFEJe~ypOg;5_5T)%$Zbl24$ zGP$vIvK_d`L!Emtm?&^RAlATk3P2U{zwgH}=kt@#y1KfYsS=i!me7Xc7sJ47e+5=l zDy998dwG030E59yw_JVk?YRe4qR-QUL?V$()0upi{vZDij0DrgD+-q`H35${ODE!N z_VDoVkV|nl8U)O(t%?8rVD}hLguwFhxSvmdj&{kiYw`vV6KRK9$<+bba8+PbxCkJ#t? z=~#spf`Rhcp;}^=6}x5PceflVeul69d8F34O_bi=UQOYm&3-1iz3<-bM<&BGfRU0%U&P6-N;GY-8NGQjBYCK4M^cv<>$ zZg=)54e0N`lCdA&7a6eTDag_ZGM>@HL~7R$1qSAU(OghlTRT?-Uu-hXawT9seKap2 znP zm#{K<`57GO*EJVLoRgq&>5xCzM7_qhBO?Ksj4Mr)HZJr~>A~F(_a z@ij-__|dA^&@{VkTV+|<_DqKZ3@V;cvLNg@<)q+&NF;I9@K+cOKP7g@x>2N=^|x5IzEpgz1) zFtnxvj+reS=E7qa;BYQ=&s%$XGy~74<>ke=%~uOjZO}uWtM0klLSM~aw!U)1<;CyZ z1iyZF_kQZJgRt8{eZN-1UpE|nL>br*X78%(1=qMB@YV+olow(PyK7Q2+&Va*$UQoR zutk|}p_sEL>mHn}vp^>G3OJADb(foEe5jm)ANz%>9EC!ubqqdp<_w%!d3m|3K>o>$ z&dkNgXy_Dy`ODNQiV240m?u0ms(Fk~uwl=q-hHm{_=yw0fB_5-4|lDzSqMAODs>@s z7htQM>L~ElwQiO`Ab^D;=a@O&MVG5=XIA$U5{Lnhb8?O`KDu@HGFd;aP&-~M@9D1(8C&9I{~~*VR7*mrC@o4Zrn`o z-@otFdI7oGi2Eri&hRxg)uk@W)9SZkq$KI-zqm2jOH|=`uhs7})}+BV#=KDx(jJx6 zrM*W}iIu~zCgT^Vs5pc71rY-83l>|?U5Zh57PsCJgh-D|??OaG#)DC*$kBQDy&ODl z&Gv4u8Jh6T3&W2^|Coap?V;4=R3i8nDvy9rZ{0ahWP){y!j!esd-o`?FZL<7SX=Ge zm);9ES!)Ic2G$PM&w2MPT;&inFL5wuGMb3-A@1**E52pe`zUpt13(7(jnj`V+LWD} z`y<+r-nzBx*O!${Q+N#jf9KGzM-RvL{BkE9H!Zu|Y74h{KaieRFm=|uZmb=)Y~Gz&|53-I@}!jdyeYH3ms zN4m7n$|?rsK_w5sj=j|DB>CDLI56MPXZ`cHpmOL8dzD)ltq5a6Vu@8BJzqA_!wKaP z)CvlI4Nl`oUO_>@ZtKm#8X(N+0ElDfy^=pH7|>y5WMo_l*IhP#ftu@Kf`PVu{t*Ce z*8m=k`tDjsj<-yoXE_lcENA|m$t)$+@oKM1f7Cd=@D-+RR$;-tX+ZN|HI z?|wwdxzeNG426lRbaLum$e|PgH%#|&BLiJo(j$T=4?EPhlV;KD}yZ9Nf@2F7j9Jx??bsf$ukO-1~0<-4`4A#YTRykze5L=VIyq6K&-G zt#Z7|3+)Yo4OBtLKu6N#(nIk|Q@CmKiPEGIvkZycIvLs$!nH zeEf_~>r1e}Yp6~rXmAP)6f|N&TU%RQ>HED0jm8;MM>Dw7NH^AhCih~Xhw9iUi#l&* zrRJ&eP!dD=Nxe7L3rq;&+k%Fr7bI}j^F3an;-iCGZU7Z?!$YPNiAoGA-<%r9ySmu? zAeoGx@6rR6liNC_ebdM&ojCOCPAk^em;o(m(lw`PN|Nr_S=7gl>5m>og-klXbK1Ns zi_M!*erc^ER8?w+64ivw{r#)v2f79ZZm7GXWJiMy&5~_Hq4M4pyYVZOWr8ByP)Ek( zsi26?j_2XH*1%rW8f9GZsc>;|sU2GLp`SQD$myaq_xvwbt?7Z^V3(zmqCAiYVnBgI zCQn7u>3F`tkE2#cQh~V}juip*)uhFl*{Is%mUnN&j|y#_qv_^Xod0}+ZEC)Rzs8yP#Ogb>#G8nV@xbeK1@@hvGHIR0n8jp%*Vn=DVJWRec+UU*O zCr>tJidJ=YBN6yn77;jj7ggpD1|5e|uiJd3+3G8`_|4VHR<%Q6Ct$k`?fpa}-kH#l z{~sP1iYn;lLGsOV(lb8cFw>PQ?~4F?>bCnINZd5cV96%Si!<3%LK zhKvuEoAF&1?f`76YfK*y$D`co>fnrOTSMe_f0t!9nMjv?OP@ip(+ie0X?!p${@0o9 z-+SXRQnpi*54pK%)wwA2O(n^B?&yT8bI;1YT2-i;|5k`81S4N?YFgNI#VbmM@}*bY z@u_QRkfrZ=P}1UM)NseFykB+&P(&5k)0mRX%uFc3z%#Qo-Y&9raA3}NH&?sR3&rQa z>)#`ZAQ5eD%dW4g);2ark>*Sq7mUoZpiUF1meMO7=XJLQo6X@Job*UP+MLE+?v>fF z9F()C5w3<{f}uL%Jj!heo2iJbX~|W^i0HnmH^2+vBCbJ2o~9|gXqD`?@KKognbnws z_r7Hxe0gd-;=8Q<6_j{)s0qv)fI<|c`7QOJFC(?c#O4ib!^mY>)Zm6Q9E-H#g^`!C zZ|?Yq8lOV-zk7Rw*Fu%;LLs(q>Z!Ny-pwx>Q9H!*ziWI7svhBm3ts6cFWxj;&7c`oXPDuGbA{&RGD8OQ z1fUbOg9AIj$4S{hb$-7YpL#1n5X4tSVTnTg)Yj5aOz6p&2ei|;9|OVvVrBBbKc0k& zu+|7xAqCLUFK)~dlPOc5dICPtlVYw!wfCx^zqq*goGAMC>;3i>>|6;ek6v}2dlrmInB2q?h%4HQL-K7Paz`T zNlSVFC)P!O_{RL8msiF*0|QqerECL*(hbN5B$2PH#$$k7J{6;G2wMzG zYLm&Gy>8<3*B*F}_X!BdmG@}Oe^#ObD7`!rS@tIeKsow(A@{AcU~2;K-#ACMGr6Lq z{vrzs-#vbLN>3(tdWNHFdmUa8y`fA2g!zP1<9!c+O#}vAMyr%LHH3X8`&${`c-(*- zyyHC7$B{|?u5l?wJz6FkGAy#w}ATwT6BofNyU>TQM#BfS@Au@l}x zGfcLPxT|H}1ln4+!WJN+)Y4?U8^vCTswa-G9+KJhJPwCwKvVOeE0JheN4kdiJC5hK zSexXE?y1$QE|jRMxLnng9d!oH2j}&+uCBaf&ITZPW>Sk0abf5!I2vi;ylSgK4Pd-> zKp~Yt>Oi$_x2ii^z(u4@mr1gYJN|rXyh5&g^R&*jjtcKC_U(aEuX62`OKN^<4R^xl z@Sq?(pViZIkH5xGJ&YhTT5Kx0Ibjl_8+k4%ZcP-_FCD+QZ$58eV0B$;cx7d!>%H$4 z#25RjC|i&_kPX?_Tc5mP9K0G)ejInMU$^dPoMGVQ-oO~Bghv!AztUM^=irc~qU{Y3 zB|I9L_Mz@~H&NtyKo-=5aR=w}7|PUoZ`^k=9Z}>Y@YZX?WYYnC)*YD=TX5SdKD01JRXCD?ypm$P>XdF}1(A6Tz?^W-Bmz8i;8 zv6ty^fPptkXZn|r5(q&54_0un2i@H61FzB*A`B$WzKEGR!sQE2;@-)`S8*7TQnK5% z$;0Jmelk8gpyTo&H}@UN@SqUFvA+HN`&+dQ4GRtFCpB-&{79gsYOX%?vyy_8+Sd+P ze12XyK)mxqW~J}lKaG89aj=pK4fI&mVnfV?)K$%^)e*L^~ zT=Q%W19Bz18kUOt5Vn++Wn_vIrU40j(zhVh=T9pu9RfcRSRnu=@M^@$^$g$g8_;}z z%ybCFM;M2mq&}FVUt-2<)sS`Hj{Est3HfGo+NNB)Xen-}dJ9<+AbPc|yCt|#%d1LrT*%FXfK~>VyDdbQK&$&Ny7*u+Z=N$w2qdo?4wLOE5v|q z{c6gTDYcny(%emmG>$#tDC-NOT&sDo6d`*bLVW==+Vcx8agIQ{Dl3_a=*O8dGW$Ta z>P{;GGcKbyHw~!13r$hLm#L+oitW5|$jeJtR@UiE62zp>3IX=oS)pLe@heUTFfOIr zKNrOwZgU%Dp&CN_1yaoEh?rfMqJb2g6O1UFb+drMcww%Zyjp{dMA(@mKe5nX>FjZo z!WccrrBV-)`4p_To->o-eRYP6B*h}PQ*uj}F0J*kL$7(kNwXpIHHdaz&2Hm)K$2 zG`~ATvXDSuKfwYE*G9guxw$!y&IA%3+J85%7%rFV?BS6Le8+N?JK>P)a;@NGcV=dw z)tAitdBrd8g14X93Iiq9xN6gmH5il86eW5U2=QiY3SAj0=Ybk0#3|trz3#?r9bU@(?}u3JK8T8$N$mNIK|H&}MqS)!C2re>AHRIfy66o5IJ+sByWdOQ95-<$ zX)Q$HrxDHa8U*q^4AU(Kk@Xq6;+vzvSG}Oq;G={ji396slwCKA?14b!I);L(3gJZb#Cm_pO&BTnBqIF2!DjcZxnppLO+#DTyw|q0zHOx z6!WyM+l15rY1L6k;{d7!o2Myv2)x6OiUSw>F73CoTNt8yOJIf~mm$-yNY*N!vEOR~ zQiop#wr**>bPC%V(u`Y;C*bd`#{03(v`=B(UyTnXlRHdq%=_9)zrFtmDYd?WljSA0 zlY0u_$o89HJm$6PCQ z!f`j0Mzb5y>-nmM@)mzIPPom7l$1?cA>_I3Gex%$1snGSEBXvK`tJ8CBP(YT_#5X? z4D)CB@f&~pky{YO{^5Nd&fq9*zGPw-`$=Ve!Gi$ZJ4?r_Hij6VFT|CMlP|juY3l)ggva6wQdC zl7{n#pGDJM-_V6l+*WN#Ao>!~p21Fy`G!hbrz)Or?f$iwk3odvuG0Do>r}i9+X)q- z?8#3bElT;y(=GJaW?PMxYYH_NUat8zF9nl!ejDM~~f{oJNT`PH|Xs>>G&-_gRh5!FUbrTAofc zT7^o;#{3zlLUZs_gWmFsKU`ZmlN@@iIFF6`iK^v==J8;uG!LqxP@?WaFEGLDC{o|} zyanZ~4zP=QKUx7VMrnjic-SN;brADn9#hB-3$&)=v=k-9l4=7{@tfZ+`HDSmtk>uI z8c-eGnUJlMd%6~Lc!8L0%b#G}25+ohIPXOM!= zMG-KP!P(7i2As@?7k8+nkw+s06jo1$|HIGv&-Q^E zPwxASLFd5?>S~woFcwI48xTf>}gMt&n~y_6LvUU0X%&VPZ|`gWfR&XsPR4#&RIHH%Uv zdF~87W@7tI%YKEL^VvS1eSV#%OJ0P2%xiAbGVFbqka=9H? zJ`1LDl)m~LlCVbSHKh{UC`lU@PrQy=dz=`DAKDyI5V=Tx7PbCOBK~=^C|^*`Ad!J1(t$ zH0Z>aNwGEvDp1~?>vo;HNI6NBW9t}S)t7@m>M;;E}+5Mr&%wrQSPN?}6 z!7uMPq1WNmyYLP!SpJIP7&e7^;UC9-2G?>8rYjuG`Mt#ic2t@_QMBmv$yg6rxJ^pZ zGT|+hxe^4UN_5?2X>Y=G-lW|I`#$5R%3W9So6odUSlie|lu90+pFP}HcBy>>E@h2m z;H67tmKFIG;n;!^_WKU~cJ?*`Y&rYyzdDnDf`v(Lgf z6zg5TdlL94uuRi7%Ar0j06b8e)$@tLP>>U$%Ab-43=DY;=)ytK&F31f%6r_{MS-j} zcmhiVYyVy40u!JKt`ZM#Qi^v6CA-uf1M%uWN2Puo>7}r!>e&lQRtxh~Zb2-ARfvmo zKR110mFq$_RUecqCC@FzhD6PovizJl06+Dk~BkII04j$u=C z=Uf{*5)E3cxBr&No~sSQ+b96F6WkaU3vhZ&!NEtveH7|yiSXr}%B*cquvC2}VI%`I zj?_qKz4YxsGfa_kz*GyCBU>!*Wz+vOeAJ9)Ep2eV{JAb3Q%z{WI|?`NZL8$%+ZV&S zx0zi80)haO&YMB71#uatBs)5lG2BEyuT%K8gI z(d?6stN&q~K@-7SX{L^5d`#yG)3UOnC#jwO zG^f(PBJw?s(dTF>Nyu(uNzgbxlhjXUR`i;Qy;s*bK92x({AlyH*_zN)go}jZ4%+;s zqNtpNI&)0qr1S_k82Ix;Yj`j6)=Xzak??gz9o3|3dE-}XT5v4aQ2#Ta|FUJUk?_cN zcG0_UWT1i=L!MKZ;iO*PL6?)5|zQ`enMOGBD-`656Lkc~%#a0)c z`R|!z)A_=*%*;0k?#9bGDoj_XyCrAe-G@{ z-BBV4H*p3@4tZDg!dp;FbB25}8WXW;k($9F7neK!V7)S#Qj;4|R|4o%@6CJ9g$ow^ zFizg{Jc?B4X1uwm7^)JM`k6eN^}ei0c*5J=Y?E|r!;$>~UfgVv@^O+b`GdU2> zEg-asl6P9gy<}qh*IvtF9JIzS>R(q6qJ${9f?ga3vQ7#%JHB5 z1Pl3-43w@{y0Q`tU_&zz;2ONm)@V)x%-PtrYge$5;%A}^>wR{U@%2jlmG-x2$k>Os z&-YM{o{%weVs-RZuOAKQuP~t-3UU4GG(kDO8+G+oAgpTlnjxdDJH=RcDW*w#s%327AgU_eUU9hG<^(cw)({gn+?UQ60vw6gCDidE`4+!`gqu}MX+Jo9?Xb`_d&LN71Wm38ux3Z&J~42V#*XH!g4J??U%4 zM8NW-x+Y3YNUOtYDF%QrAzbOUq^h(VK3!pm-MqBYcN?HyyRXcG0u2kU>f%VNu9@ho z%2UJJ-vA7Mmhtj>SdMRvuz)y|sYixAe!X0Jl>fr(OS`vzDrrvb5gKP8ELG`3s`%mX z&^Ek^mbv6P&gij6plyq$3mKPBWP@o%Ee@ zIJuV4+R2C^pWo67aFG)LM@xJ|=qNJV8nI?7agO#6WRJw*W`RcmZ`|?dEbh1M6>BkS z?XQlAxAIK@=jR(;gB2W}Sz>$LZRRxa9yYSnn-1D|&lSptmC%pShTj8<+DCpTtWxJU zSGDKiNUx+31B7G^nhVIvLu?XPyLHR^#|z!t%hpb8aTb$SU)!UTb@&2IH(jUfrQO~{ zK$2~IU#2Qwc~VCc1?%{dU)3InVpNF;8-vUD{d$CL*<`W+cl^$kZ&>3O8O^DTzhbI$ z=dgS60+`-tEaQ&G6xZrc`z>!O$?px}80aRRxFD#vOji2P^cG^!O6wNc`mk z&CbnhkH@68vjRaWVcbc>0@8dT+Z`Q`-#r<33EUPuGNEZoSjizHt#)?JX`IqpWVrV( zmQA7dYtAFzmHzs;ibI%pqlo-p*xi~MVI}eI?JF_8vRSDp7{{pZH7CeNwwEfq>3KJ1 z-E`v(vYB0{mup{FE)YvHym%Gfb~e{;NVfVutr)QW`sG%$z0QgKre)aOi@vhBxcU2w z!4;b=;{(mVzB_Zr#Od=s935J%cI5Gmac3q^pR@Yli=c56;G5y9=TB)(i{l?{?3d3s za7isH^;65Z<%Km~u6#Q6$NS1bq?TQQLor^(*ju`fRMDt~^M{SHq{B9fE}@Yi2W`{W z87D6YH*t5K^W}%X?8X<__@XIaOu-k!_c%>f{xjB>2q@p=-Uw-(@ZhVoA zFPiej6nrs^U);o}o#v`wypUaFn&rVvi33oDF}in!p@pUsK4$-bMp4d*cAfcR`L@=J zF>@#1pk`OaiP~*B6FPUe$-z%AO8D}_Uv}e*Y<$s_FQ(v&Vf^AI{>5oVszTcy@vpe{ zfOPV6^D{I^dgY4R1ve}&l&I?KPB+GT-Z2zZ=Cn+Cb@kmJLGYI!{!hC>!MYwj`fJ_W zw|B+s_4Q@VtaG0!<0HxbCr+Mxjmp#M>FE}R5*oD9<-wqFA7GF-pj(K_piJ!xmwLFl zxmi*ZsWRss%xL`h$dym~Lvp5IRrWHNuM@UKsbEx{xrf0T9|%}tmi#^D-2y{ zk(?fqth8Hb-snjn*cJ%vU^KQj^o3UE5IyYQD9*x+ZZYCWx>j~0!|HSH< zIad#1Mt<7RTkG+B=>ehnIYl5DC;bwV+?d6WCy0f@`~cQYARa9ze>nQ12^udwx0v8A zomhKzDVhQPIQs{$pndD+QR*vJVnk14hkEooE@_c~fR|J@`zCK0NnK)r+O7gz_LeP$ z{>L}Ti=ay{`@sJFNi-$A><99_5+1m|dlRSD_)}XbK$7SpdL{N8f{UG&d#X#+I$kDh z%b?@%x^N!mL@koL?2FYPsXNVtu@#o_h2HM_5e;Ww!LdF*pMPIccv~m}vY{02z2RKHm~W6OGl(Ww6k{V)43oAio`b0fU8%G6m!ZM%$&I>M zG@8AkWC_sIt2eZl)%?Wp`VYs24rwK@J(9+UEH#Te3H?UZ4%+S(LUyh%bn5S83~`{x zz00n+wKO8xC1!sQ^<8LY!V&4@e_tfQTql}KtzpRDQeQv?_aVEo4W>JX^+~Pb%Jj@ugqH1p+ zw}mjYq$SX!-;ktvTTj+EPC*@&$;-E?s~8ci4L4+kwxA0Ny12#Ake}q@aZ~ZTr@qS4 zEyd;isrmnwQ@*nt^|y5*D|#Y0D@vR2TMN3JVjrXNDTm`Z*tdye;wH0g`t=;Tx+D2; zFgS-SrM2i;{eAqNoss?L_ozIT@u(a=?3w@m#c$um5fHs+I z5sb`mZMQn<%4OO2N1c?NqKF!LU<|{raRnDzC`V-mH~!pWY+hcT<=jmV3JUBD1?$h@ zZac}mn>)d$Gn+tblBPG?l`ARYcbG0fFa77AAXEHFW@?HLGgbV8m^EU;>+OnYID>nW z>p>D8<5in=H;Aa>Uk`|nCgf!dIpikNU40ycgOWaXkz2Ar=plI{@^7d6kVoq1-tdLo zE9Q}haoQ0j`arh zBA&1axSN(=ZF=zF!EQr=UYI%wE;IVo$bs;og%TqSxbwMR#Q>qZ0fwbymZx(GhUcl8q)_}ta9|?Nu2K%a5;v9j9H@|^J zlQ0@QT*Rilx@yTQQM&-Xct{4A4l3aiDqdVu>p&HWSQ2i}aDPt!aDQAst-rBf-0xjh zqv4Q65cLx)XjEmUgUE0~>XlK1a}K8jP_XmeEOhfnEi;V4djOjX6GTAc;UcAE7mpSM zcYJqmU)hp@%d-|rKNp^-V+1>P(O3yvQCl7gp^d_fgeQZM>d*wK-dizkV<@F)ke z@))%2IuWVUo|m^lqrsz(P(?h!tLwAk?aL$QOB{LS+BLZjgyHBuWepIJ=5o4^#9Zvv z=FuSO&yo!F9}(h7g4AfTsGR;hPJHn3mCGC+7!lbTqo(jqxYv%lx z0|6=zuu%iZq|M-S{Q_TbSR;++Y6|WMv8qOb6)>VO*;fhq-2dQicFx8}z6!A>tvn0; z-Cco!RXkqq%}`odO&CiP8By6i4IsB;`6>Knxap@l1C7=h?8#G~gyqC~j2IdD$ZROB zgvY~cX0LFC;DDh7Ey2|pL@g~XZ@4S3dP5ixx4ZmhcFwSvd8vR?C2BBz;AQ|#UE&xD z;+vCWQ;@9gB+s!~u~XJOZ2qrOFK5$Pc+-$9&R@Py9i}ZLyGR!l7!zo>ZWQGS{6+?j zB5XFu*Nm3gr6Pf$oB#wqTSF16VsJN(h(=9z86ru^lEzR=em5@g=*WXlJ7fowK#Zy} zkSsxf3Q0}MVPK>Oh zhGzMsjuw|UH8shp#f*kr?__~kz;Y-v`UCIyvKsG?=I%Aea^4`5m0Y4Bfu9!8A%HH{ z6x!rRGP`aKQg-{W?}Mx?s{)>#ovrOj>HhtP0295NFmJzu0}Hs`$vmFti4&^?1j5o( znK6u13dmo>B~=YY6f80#mJS{3p~a*Dt~$(m_;7E*e+iwp?G~5_=$H^V@L_jW zRu=D7UteEOYBYqBCglI1>c;E8{O}(L(MtGbH~zxQUv}f~(ZiSB_q3(Z&5Ct^~jIi zv~iHIHStA!CPBF7(x$)> znJg=wiE0Quz32i+|LQ}<$C94ntE1hmj{X^QFao}@S}(zhg) z8y{ssfiy{%AZPh8j>)nrb*CAe8m>_p{*J}B3UnpR4{HBN436R2gL<}KeGKU%OSv@l zwg!mbs%u2@BaudmbHi|O+X9-=-c^&Xo@jpfE_oLw?Z_uL1*pJv9(3X}NfD%9FlUvy z!?7EDlm1M}EXqjdc|+kV3ZwZ>z6}d=vKvrt;!*0Q`}{-D^qrTTeV4Q_Y-pgfs_WPU zm=@vn{pM#lZz+aYO}k!~1`i z@tN+Q-;xw3JWY_?Ronnm>-JK^Vc!CxzKeutkXZ`D#aI@>$*X4L4gC9WjffL02r%?; zKXE82y)~<#K=fu-D5N}VS;vG1n1?(5QKw4bpE~)e_9 zOn#P?rERTiDXY?d7)~XOXHSdqSmCjoB}UYnKPPzxi7P; zHP!=3t^OI_*EE#D+`MdGC@U@y6w{BVq49m|cNg4zGi8o*epjwW{|Pmts!7k|Y{R@8aQQ%;muGZHNE&{zsL*5_rW zsRcp?ktW&F*IkfuBLNla;-&;CeN>7L6Z*b&H2-uE8RUx6AQL1lC=WJ7GysVhnO6+u zy0Snh3fkx84c%@Do?eWF5`8t#4&|#Ew1}G7Cl$jq$Bzrbe*eGnPi~-RF*c1sV@!vp z^ea(nn~}~kOHqbc?bJ=h8Bn=;|04*25za8+HC*m&#n@6ChTE>NO-Ku~s8^s8aj6;~ z)Zy;qq>gbnw#cmFO#!JIE7Z44;MD5d*xUD^6ewM+cQp(L6F`z!{_fIr7csIhg1@F+ z5s$eTBC~bB3-?g(TRh$i=@u8aslA1wI!qjYaH`C*V|+hy!_!%(a*z=i#%_$?h)282 zV$_2(90a5C-T!VX6u-wZD#CiOR=%E|yQAb3A!!54IoCoVwXiB}hh64SCk5^MND_dQ zFdKjI`6rTRK~`ldwWzYHwo-c*DuEku-(jIyX^Ro263KlGy8Wqwgy9x{TEK(_RJY>5(qL~t*SuHcGBh*!ZI zrF|`Dy!>O?Ak*)SVT|`!ZJ!0=O={<~mS!9d0Z){)deOUvhWm}{CKh8oGU;>&5v>B| zZy=*kF7v^z^oBg6rH(N!QSmTTLIarQiunA-81YqmNbVu_hzsV1)5ohEvqkkvu71=+ z+YYe&b@lc48g-%ESfi0B#VD+2^{6z#z!-P`hsCihgg_j42Uj)uLe{uC@a%1VbAZPP zahXKeGk*NVpFh?BrtTQsKaXD*fREBtP+u&M-f9Ddef#v8h)1I((1I&)9v@Iv6_)oA zGWd*3zLFcGgs*^lab=?dQxLz5#Lq#+V_tI-%6RMk7X`iT?(grvuYP3~OnQWgo8I7k z;-WMVVq84F#vk&MgMHZVIYk;C8bm}#02g!hw||g##?lkKFU-qmfSY5aYB%+~<;olU z=2vywabQP#B;zhZbJYLx^_c6%*hNe#fRB@?u9)+?Ov|0aMH*q)oQoehcuAnk7;Vfq zhB6CJ!EB%|-LRnt%6J1vM2~noBo2%VTJ|wYkmQZ^kfs0_+GJV#xE_GnpztQKq8QDA zt4UWSngd{I5u$l@3i9l&E=rw_S|mT1bT`;tu)T8vX3w`Dfn162 zi!uD`<8GvV#S^?>DtR?GY~TS@qZZOo;h+8{88&I9<;QwN#@-)8+w<+0`G&Sl@PfBQ zkJ5ohi3%*b~&Jm!9Eyj9zgop=yD4M*(Dh7T%F2E zg2`iDUYM1OKVI3G8vfhO$d!?+g1cQ^UH3AyH!ZUcT|CjNt-GW{xzkY{!IZ2hBLk8! zdt|{Uj$zDpfplv-yZhV6wFAgX>QLSY6{J04|G2ogee}#^RGT5r#N3wM#n>Q(T#Om@ zc406))rNRK2PQms4R5)2f>Q)H)A#~-6Xmu;jFt=9hIZf|XEXK=e5*9EBI=i5evcYr z@HianYihayJ9~S-(7bGQNp;`(DE23UekoLfu|3mQbe&?3ys8qvcb?SK{+sX39MJnNX1S!4-YE`$36gNXIR(6ZG2+AVAUNddq_;?zBmq|X?S7S4?cQ9=Ab0&#k?^r z@oz_fwYslQzI@_M3up7Gi{Ly{v`p2vOm1ThE+eQT?SPy^mmv{mtNMU>Ey(W_)H0c= zh`dNV$1L(LOI=0^?}m%&_EE#jSe+=Es7)?oIWEhbvuU{C9P=B@!>uyoC7Y;Or!5xA zn@(!H9XTui{*|WnZN~2V8 zkMc4fJ@QWIr}#b!eMPxh-6B<_40tM&lks!Q6pWv*>~_!{fE)sgGsVh@`FVtJzPAT@ zVmcn#7NR)Gg0x$lFOQ$UjD!jD1)A(f6sob5>Q_H-_P#zo`xc6sd;P8aU4-jd3Gmz& zw~!gg!?~nzGkMc@0_|t6%mCov_A(XgWhyXw``4+Mk>yk^7GHjYla?}GG4X>mh+;@t zoetaJsreiJcQ-{QLdle~x5m)PJ&*G9>p_Srw@)kMNb2+7Hg!2@Fd|*7AYkEXz5=A5 z_27Zd6`&iHQNMzD=BS8mqwo|)N)J}VM5fNAN9dr~{=W&zC%Lv`g&SXf_+M$CU$*Y= zr0xG->_(Y9iCd|!?Szw&@HvLHf;{Xe*rETg|8SGzZl!0Z>RW6lteh>g0l&Ru!8x+H ztnKE8i!ZBfQrh!J)`7Q{{nuM8UluX2CFPPCJ51Gs%`u~4RMIe&zA)Tm*T2qZN)L0w zT3skAchD|>|M;n`!$eNfL^bjC|B?|-&oyADOvTx&oF%^$d1k&avJ!hN%Cge6a~8Jc zmyQor#^buE^Z!H2!gwpymn;^RwTzS(Qic!ll^bv#vrQXtWt#*Y=r1z4Q}1F)RIOHq zK@2Ogd@fmst8Ka{MbS8tsxb3AO#jC*;FzO0F78w(-kpj|iQGW&yd$`$nZeAhBJL9l zD(n`;DsvAw4;Q;KOjVs1%hxCwtCKOmZI%^87%%beEXrp$W3`oPT%54lQ&o&&ZTgJt zmiJW(#bJRNAzq7q(q7@hd)42z!vQkPHsf_P$XAr5AA*pIWknOTT`B)-+m%iJ?SbYu z7CcJJmgb^+J%^4C7r*3-`gmFN9un_WnTBsryIpLpvX`xMfL@(;*rF6AW6es^67(+H zq#x5P+s8cJGU(q~)Z3SFb3qVYJ&tAFJ=hLn;mASUgf^@l-+oTn^ysJ|+@r;gWM5QTXNO;#XK?`1ma z`+*IoM`HV`65INC@aaQguPALEd1jaCblbh;-2?3MK)eq|K9_``iT!&oRuvzqLq8;k^`DF~bZHqDf2pk63E%9AFbKURU=dly1M`}A7sCezL zCT`on0nm9ZZ?bwKHe=GDKwIfGy^loXJxP8=Ne;m}*@#Fs)>Of?i_gP`W23Wdmj(IR zh>!E0^te1}+x<*b%LwhyxGAhq$8Tt-EN*i-S2QbVNM@LNq}IE<)3m`Q3&d}QmM{H4 zE>kpw&cQhm%d}f@mnl$HG|FGQK-=_#u{+nFW`Wu5EKV0T=hZS!_v^+BUDfg5lrhfr z>lPjQ%2tb(SH`@L7L@cj;T9;MyYsrb6 zyEDJ#-F)h_55u`BY;KImsMW?#DEs|``L&0h-9KU79|G@t_MiFdJ<51ZZEasf=xl6k ztbmji4;pizKK7x_p{}tp&QKXIP(@_~W>wW4GY2t*_mCYM_!J{@|Ip!eA}%8je;JMn zY%(poin#v3E;>_`sfe!^-4^^XTy|naM9p=m$4=|gj(&aWwS5J*yq_XCmT{HrZ*6f7 z*GJM7)S8pFMp9?DxojP--Ay_hNnJeC6!Y%mR2MXQSP9LyFMzC!bI4RO)Vc#T+NNP~ zM#6)V$54~Vh1Mk$pN_n%S;mN(722P}U(x9LiTdekSD#yk4m;@8qGqmkF2lW{4@RW| z$G=xQkJ&h!BpAHMZx#SQp8SF2&jKNcp4MgTlR9-#ipu54zwh%O2b>`6$<`!g8?0D{ zKafP#$DIFXRz5C?k0vU^g`>&a6&trKvxS>|lNv=vs7NnuvmD$_GHk1v&3xAV3Ep(j zSxac#`kN>&FUC6f(rJPlg7=NQ27IF+UeG4pWxDMy3FT^e+kk`}P046GxriDHGoQ*G1gofyBNH?d_bRf|_mUnF68VQRuZ!W| z$vlRnwe9fwi*@9YfYxfyXBElNOO$a#IBGWjRNWMnXj>)x+7;uD{Pb~RxZPk2FaeH+ zZuu(g_{oz4_7V)V%GnlhiPB8inRlP4(j2Mw4(qgnKCr_hQ4k(qG14Z|089b8Y^ zCtSHZG%YaCkXco z>&IGaBEQ+4C(uq&!V`FJiU*R5v651rMDM_^DgAt3i4e~;7qYpnzx3-KvWF`FK3kP& zbrR-CEb_A5={b<=Df+7)<0S0X#?WFpD@Y1St0^5R`|a)HX0Ocu6$7Kg?~l|&6zmVh zu!1}4aC1-(HQGQboaaQI(*D*Xe} z9?scn6)1}coej_)T5v2y(d21ID%={oO2aRds*M^lp~V%Is2g`j9uDh~6pdW`1a-kU zSt;v^JubY4(iUxRHE?`*SCc5xCaFDn+qlh!MHfslw-MT_Zb9cY=ggneZ1QVc&|Mr! z^6}Evhn_st!rW*RX#t%t-l*Y}fo=|C3Grw3qrO9hi-0)OPo7ZNfgPz4CEZ2+)_iuq zb%BuAzw6+}6s3(%$A_+shBGlzLV<=-bmhYEqB(yxO;;-6Q6kt$t{7!Vdh+~)e9yt> zz|6@C7+TMWGfXx*U=T2HHy|?tdUW@)mP>nDKG{F9|50}Fz{n8;mCFuR{FQhMhi;N5 z`M#hnY?4pYZSks@#M&K|9U4l;`L~7fafd+b_Iv|U@}p;M{mZrIpZq1juACy&28OQ1 zvp3|Wpe}Qs5-Lh$AR5$Q&S>{K5A~Qq8wy?ABWWEm|IC5V;|C6rE*fbBT6dHDC3lob zJz|d7%qmhc-!H>TcaQUjHiH87y*c|aoZJ1BfU^suN2>@)#ZF$^T31f~6UzBtD#k(` z&r)F%YIZjiq}W?Md~t{0?478<5)dIfYRS?p?16e_qMUu(?+GT2wwhbk(tpoC30_XE_PO_d`Hyzb#TL$G_>h zjgYMxwJO9q%^c-v8Cy$BORGE1p{hr%FXzL{1n50p799G?mge7uNNCwv1b5t#Xg7>B zh}ph2eD4hg!cHc&4VyuIl9{HUfJw((yrIL z!ZR1BF<}f^vc&=grp^XB za8=O2iarg0!FvLZzfQ8scuB?tfPAC{<}!@)J8d9AIjNqgg^)mEW*W@aYXS82yz z!A}UHHEf|;oQtUZq@Qk4s_IVQ(Cx&cepgS)7C=3?Tn8;OQxuS+d57E}8rsrXbSH7( zkk4dvbrM-l@)7aP-?x=PFfQs^^9fy#6z&+lMux@@*4jat!Xg(R1GaqK+$LETW< zw>Czd=|%93=g0rm4f@wWX{qj$wzcBD>A7GYpY{7f90nlRn@$H&?R<5DSF7y1BD%`q zvSEFfc(udcDfX==_O&mThvRqA+vAj3c0S5RL)XeVwW3K-nFemZ;*-i>1q=W`_KoGH zewRQR+yevi0vL5tCufEF(%t(Db~TJ80fL3$lom=I01gtbQ`t8U1r}oVuv2GIav`N( zG%nTB`8}Msa>m@{8GD_RD)wE~DaK+d5d6p$a=C3(4paMMTXe030u@A5+!-M_pn}VW zFsz*rZui||D0uLxi7ZmvO;6xh2-DsX(}ixlm^@fea;@ixi)nduSZCIDj>)>iRP8$q zI=(3Y9Vb1@S$B0zb7EVqYt&qgjD601C(h2sV+fz#AUv!HweJQ3Cjp8ny9+wp6!8xo zKM`d~{hH`Pi32qK77s)hW37~`OgTkV_#cGt?n5l7=6LTs0^!jcSSU$kWCVm<@qx??BH>`-XgvH7GMh_XR=$xl8dysYNjsx9N=&~sy>DZyhM^* zv4ke{5|KxmnZg{E$%jUrhX(1p_9thiAZj;&ENNdgqX1f?fs|dHq3;@&N!N2yc^JI$-b#fuy}$lUF(r*{AG-=p)Ha88^eD9^$sO0kuXJ8FrbfdoRswpD7Nufo{S`FCd{ zmfG_^#^C=tpNMpRywD_U(lp}s1!c?#>N3b^B)c$+fps(Zyai2x8WYF;0Sn=-3vfU9 znm_(~bTm~AG`}xK_mM^Q*R*D^5+avBy&2L1!cd*zB}EO|h%gJE1YN-}IskD?KQ&PY z0I0i8!|zcJRciVE)WzTDHlZ;J&`WF?WFFa1;r9sq?0!Ns5C(fj*=B0AEaX39i~le7 z-UO=2t7{t$rHTU@tO}?=)LKPF5Uqj&iHc|)ai{_^1VyEmNkAY#NT^mrjf&QR+9;`7 zr=kL)%p_>UR1xD0$PfualqrOedHVMW+G?MtExOkCe&4@dvewgTAIN>~efHVYxvsrY zW#+%O6hJe!5e=9nmq+|{%9+MB{Z!`YAwcZB`m0rA=|E0eX3J7+*$-$_G>{TvIct0ltX7<^6 zFh{X?%6!|U8{NseiH0r?yunt2oQE*ROj=-QOO%_6c=rAV@cDmjEMUmq-$smY4~po$ z2r}Y|!jPk{`+gNnX@}kx6A=j@&H|-bIS$;(;F!Suj~sW;z?s+i8;VOOrV^d-I0Rl`GGKJ>Z$K45pigT9gbTGf z;GO+1#Jb+>l^q2dF0~*-4z-YF!;lDh-QpKFc6ESc6M`dSW1kg;HUOGWrO{Y7g+O^^ z?Vt&FGrwS+sY2Ks;h!DK{_IoKyXV@X5PZS+8B3K5_^0;(N-!K^$m?gZh7-jrSvff@ zg~@brb4vu@7Vv7Kvb5j=O|8`_jn@Gf4ItCC(l=VNY z(uvmChynVBIx48z)}fTOlETPKih%w(qK(_OZIh8FGa(ae>+x;tl$6Fwf$;C*?tjsw zLd7XVr@CS0uK^U9vu2MXhJS9v+#@$r6fF5_%4Kous=tf-{^bRgaK^K?110F_1iJ;3 z!}FKC=J+Yw&_SqYnXG<;2T$=S%1u^h~`i^0!x zL??;>sGquq z!?El;7yoF2`wScATAyL-^GyH$3@wzeJ*kDtW<~_-u-n@p1K{!IJu|@4`U0Z9favdH z?>@udXZTn5r6m8$a>}&}6gf<7n@e?1kMh2p+fbb7%eniCCVfSd|E$mT#R&h4V}!}$ zG%CQdN~(jI(|VdFUnyh;>(e$vY4G5ozSN>GwdhMNlK1r$%MF(;eYYL|>#wWUE+Dl& zyDZw=(@eSVcFl0P*LSC&0CwI6fQzRiIGz3pnDh3)2O3WcT2qpI_uZ{?NQo=kj_l&n2af7b&Gus()`Lx3HN zi$LuGZeHn`{2N%tM*t99pO#^psGWfJt^m-MV8(8981cjaAWazXUCl*0GZ@hI*s4X}X)b;WfbBdp2x z6lXj>!0%Y+DI_Amz0sL5!<)KNy3$p3qQ)W>IqpeI7C1Cr-CdIx z;Xv6+2WVSGd76WT=n|hA8!J1BE1Yn;I8%dLr_CWf{+Q*v6E&;sx7emldSWrbZ>q&) zpuL55@Y!tWw^{`w7RmBzm$;75j=G|}c#T3Vh|WyZX_RA_A?xs7?AGUW&;Wn$)zovg zd?^T4d^`k54!-0)P8A;kV8H`jLR9E?N|OVy#IoS}bAE#&BO|j(-0Q&YfWM#hYgxr^ zeE$0NYiz)T)jI5XzoDzCvgbOTPB4=k@exH8$wuhsn>P1}#dtrLIdgnm_3n4|!C^(3 zQ+_vB5%l&o_F=YZt2GbKVk_^;RRAE<9CB=YQXwg~ELRjA6?K(#EaIbf>tz`>FT8Gp-}kHh{QT_F zOD+OJ+P>cR$YAT1V16@@62;X5&>MQN?-9CPOWa5YBFoRUJny59LYIJsbi!jm8SK8G znlaba)vq`Y+8Fz~`W8hC?+c>KXntR11AeUl+HDu-MSI4j9G2a4h+K;xQQ*o{C+;Ogr_AP~t{ z@{z!*8Tl}xN^r8Qy1H86^D&arzXZz%D(0893T6`9r(gE>_n(IM0;2)S@B`W|7kZil z(f;EW@7ZCPu{h1%+ZI{?7n5W$c@g{o^0|EI%IKQ77HM9we1q0b@)5{9YLo>5>8!n% zZAX1c5qvjH*`@PFz%)2IbeD!!dxAQTg6*d%vb9bh0gO{lf~l}Ts1a=1Er4SHT^Vyo z(8&UJnta8evTA$)i(Ej8TWx>h0#p0sV@{uLi9guwZ5JHY+XA=sDk(R&T|hg=%O!zz z7S!28^W~2*Q3Pml4OC`}v?(LJyu27lCDSvlt*xiwYk(0OmbRw0Br#6hc0T-LjIF0Y zKtt{e;LL(GyMchPS*)pn#<~vpqT4+O!Bu_nW8w$R;e5r$nzMrdx}Ryk`vQ|X#)7af z=wp`n8hRUazj_EENC?shVP18tVUD8(kA<}$T>F^Cky~ISxc{?NN2u$OOR1Maf5@^N z5eJx6A7crbGS~-;S;e9)Lu%^fcu@kp>$QwODO%GH_%a;Ws&6G*l&zODortxn>aD`o zx(XpdUDkY_Ut6ys$qC`Q1i5IozCozpCP`y!Dw_E)?~*}aYq1Q1{wcLyQ$5v$LC~T< zz0MpuTVF2H)x_<3O;FVoW$GwOVN*dDS?Q-F$9qEm$B1~Ua!!-*DASihQQBm(cp^5x z%f*;~cq2@(&Eyy8s6`Te67^IKUwbCn?b3ohMkIng#qvS!veEdLr}x0nb-Q}$I8jDe zkpi&v+d;eT|4Fkd&Hn@RioqzQZxN`a>PRss5PusG@mjUqLZp5W#|p10!=g?)v%g3(BAq~IZJtC-%26_W34G1dN_MS= z#?-bqJFt{R6-<%td<9dhsRWUhML1Er=q?O`vWXC?Q&{M;l#N}LL|wSVBy-w+eM6>B z6rtBS=o+3zH@!bgHTo`ZDj`p!>2eE5V)Ci%rnvsBNq{we11Qt09x3$VJbh-5h`$jE zXg7}NJ>yl~uLY034@kq0K!K^;{9O}QhN9kTPyZy~XuWkWdX0o<@eZVi8BWn0W`Pm$9c+qz%s^^nXFFft?3DS}*{aO7lmKNxlQ2yrZc%G= z>S^qmoo;_tfK_^Fo?cT|pyO*>*v^%lDS&M5DSk-0VXE*l91O{@u$b6g+&SO2wg<`YrLdYz9}^Kl03i3OpyFeXGID5IBc8CZ$jlH#5^p|Z(Tn&T9smbHV! zvq@We@0*V?J=b96<)1+xU5m6Y{Ggp_0J+({gshAbNAHE^K4n0uBa!kU+JNEj9~W(; z_ry?_)25FwxT14tq)Zua3hLnt^RUk>PBj&7o zh8zj=brZ7lGwIw^oC8D!VHQh~#s9dF;A@aS_@7p2${@m7S=r$quT?6gylDn!n)?*8 z(2T)TTjI={ex$2DD)#e?2KsuUJU8IrcE~&Kelbu~_$$gA=|bNl1TXECR>E4X9slQ# zNpZ{GgnIk7w}Xi>cvie58^uj5cSNu$%bAv2B7zwO^kC()kICaWMASy6IDbR~O3G%V z_=wybG#BDCa;gAG8#xdglXU*j&Dug#5Q}sv{o}%ZNYa3V65yAzy&byxNSERELX!_e z;;tTnl0R3yeUo6X%GyJzmgan1S_XluZQ4{+eux}*^E&ixXC<7|)ouXVOO|3~<;s;S z&3WvKA{`l-G@-Koqr`>6ryb62ZvK9IwzrqjEf^5=r#3&k>`NDYJ(VvO2sVl9t{fGC zO0yMyk;3~YaT>`}X2Xh)DL-K2#2ny4TBMD;bkR?@ih4e8vV}7ok!g5+A@rJ#K&NgF z9$yaxy0)ihhe~H=7C|@PT2@k`_#?A=4tl*M7aSIPuWc>EmZ4x&BFgX^A9QG6GoJA!RiLaBfbk@x=PIY)0P4f173%bRrKkWKmvp2$EqcPKV%RsMi8TQyP&n zDbYYl`P|@m1^>)PsowgW22NY2RcG3}r$EajwqoUFus=@FpQzd|Ai&<9MDj;*i@ad- z&NtTWwGbS-**_D2?#+IThc}-Iv$5l2F1n*suL6kp&CoEei0wt-RlMQpd__Qm$VvZ5 z(%G7OJoX*tOLsIpKINZjFIySg-hoB*J)}=ue?-H7rGWKm*B_+0zH`!dPWsM?fr$DS zw$cBjKRIGXfi4OB$$F7$6nMH9jlZemWz;38J>jT-cIJKL(tqMHzY`Gjv4aid6qNBSC%^mUHu^Gto7>Hn^0 z`k3qazU=Ie8|l8@O?}O``nsq6AL*X-he|d^;{&KHFU1~oBnC-TU(l;dD_|I zhI@UVo|~^_|JeGj;z02d+Z;Hn2d`I$f%EO9c1Wuca~|4!OXM! z$98v4Yn{MePOcbXfKDzo+@X2@;~iS~xp3%mx8~=`h6Xf8VYD*_ll_k^?&_A;q5gJX zt{q^MEVi<;Dt?MUFFqK>kq?wYJ@?L?JI!rvp_WbRCw#{YfFxJ2wKeRLrwCHZGcG;9 z{KHZSK+ z`M6J$M3aCIf7mX+NP-5Y5rBGKA0zLmt3X)J8Q&YghSoU-Y#%RY z|62hwuR^DLkBEKisa(@~DLY7OhspZ2w+G1wy5$FEC$BzIJ{IQbTJVCm14y8ZDfE9U zJ;jld)|)$RrEykuxzx&oQl2qpSZ{T8?5;PUicLmmV5EVMD|5bbrU4EljJL2LU`iQq z?D1RdDT?U^FpqwNb`&OXNpFwz3{qvpi`tqRUsvW0gmR`HJMGs$yZy((?|wCj3it|x zPRcW%!f`&<+ry(q1mp2E& zo`rE2`9>JKf4qArI$nsnT5daM_DSjQcOwL~u}f|m5dXoQ-LD^)G0)nTPio@%rbohj z=4%gmMv5bK0MY>BP1bMF$UPkRALdCh-RYDaJFZC+*~_#xAFwBjWC%$9sO&m|dr8i~ z_tgt{tKo%!e?gX?c=J-2)&wt42MhNnyv0-901GChi=O{R!F_< zrN5dDz#s-pFo(AT9t)+>2XtWG7(iV-lUzyHA*`EZ*xk6|M0rOuK)G7ulKfxuYYG!q zvvC;jUYn=YTh!A|6J7#fNG!Cy8(}kzJ~@-n3dj4KGfd;sg!-qCv49riWnL!JREZ_! zN~N+{)?VZ5LRzgi{SX-*X~;{dzlAWN)B^_86YM^TQuFACw@3cc`EGIH`2$ZL=f#Nf zCi9g5_{qfQUEXCO1<+&fHGJV2z4=3iE!*vyU*{_Q9Ml}%KHVXA+R1!{5vKlEc&E8b zEizqNzHuvQ63`y;JXfw{6?LDgdUk`Z3dUfT_14Qx|M{&Qs6K7kgx(D{52*#4NJtEO z_r1=*d!0G(>@oF~VTMmh;fZWrZSK&3C%jMo{c-rZGHhH&Z8gH?J>ezUZCTw_Uy_(f zX$H1M#_pRTTP|d4v|YUNXld{^Cvc^KU&4$9Z@UeCV_;3WFlebkz*}{MKV6l(<#0&N zv6XL4;ETqo{>2^VEALVxE-+KC0z|!wzvwQt=pvJyi)5bDiv#Z&BX;+x2SAjcY5yyZ zR^EHefrO^|_J=H8b;Md$u#Xk8W5slw44OwIA<$6Zgr=z9=yZBn^;yf;usmDgrLf?f ztNYdv^o3!|CcSk`o#LbT0Quw7-SR{B$5|}JDKZ~1(CUi7mKRs_T5=Tz7R>Xx8bvnE zxIE&=`rNM-?stWEBc5L2-zh?|jpdR#X-|YlLq6AMBxfY_7sb)Ci*(=<=L8 z6dtTTlCwe?c1k+utYv#%itZrzy)tHliuds9_gBMS-vYpO=R`R2_$y)h_RoDiscB)> z#TS_U{{2L_YC&^E^>qkjwZ|w#DolV;X<~tpoPyXZycg{#9^7|YbMAqP92ItrKVD^z z6b*sjbf#qhwnR|A+_S^qIvxcH!A48I4D!^EtC>q$O&I*|COiH)z_%l=7TzfvEEN)- z6q$dH&}3fj#{mwW$qG@1_w+Sm^0mvLfuNrdaDx%Y1`#{B5zuE*jjd@JMb~XwbLE} zr!mt${sKW#7&Y!zE-#HutS3@|D2FuUh8^3Z8v$jt^s6JH!IhOsYC?VIA~#dKgaCtK zf{vxvZ}K&9XbCn-{^H+)YJjo27nnA@di70fM80aARE@utyZdI;YYt*EDf0@zls_xS z-Bq^=;~yjC7#8Xhv_&=TbW_)5m6|v;I8&vJtIpjacpH7;xVh>0#-@k8-4A6a0CkJv zBBB!#3N2-c47FV3pbHQ+l+IXH(0+3dLwD?Qn8jE3ob#(|(Br?EsWE@(n-?DDKf*vU zRN}6-`@cP<5BNn1s(luVBYJxl00-N7w6d;IO= zY}p<~6t6>D?##IR$K#{MxofrbsAI*iM_sOp?H}`2@Vc$9UOQ#?)>NHmD^B^HSVd7k zl}e-pcyxZ)i>@SxZ*AEpGbmm#FCdq6OoTQ(Zphx6;4ij#~(c4p-9iI`#DMwl|H>Tz6f}|w&%s8j6AgUr!~24s%i#I&AS2YIz(4| z^Xd!cu%85{hK#ajPO(^v>?JF4wY_NulYQ~sRV*$j&Fc82^6qXksbpp~iWhu(6##8F z2Ko~aJP=C`NS}xa6W7D`P(DLyJ^} zJqT%9lcG`{@3jQ%Dc|fsHp?Zkq{z=|H9IFCQb5G?1ckfQ4xFCIj{IW~nAw2!Wyz)0 z7c0+>d^bqo9Hm>g=#F1JPrcAvw*xu-{@ft}$G*Fzs+O+?Wk=9qpJu}Y0nAOmg}*Tt zgK7AA$-?SLujl?G@z3p3pXIYGaI z+EX1&pM(S~vO?H1|DF+kvjb6iUFcT>y$NbOqML~9J8BDZNXH65=Ry!$^@Ni_^ac{6 zXmxj?nJH!lWDqYnp+!3ie}C{eaXZ)?vgNPV=^kBND`Vtgf#1oS;RRvO zikkaY9((xXTflBx(jjZcozpn&DI^A{zqa2K+v{%+gE?@3ycf5r5m&LeXuFtGgKU;} zIYNB(zL4V-`h`E1&xCBtBeVF}#q;tPyUHurDWP26>8ZAcMED7wPnV7_)ng6b!W;OBA!-9@WV*E=$qy;zgj?4r%v#5d>oTZcqI+k}e4PM^N$ZXNER2dt2M8V^1~_Jl1&9)>RLo;RF<0I#Y^KI z<)z7w$dAg8$}kp4RqIPxH81Xl9L!fRW1vz%nf_}D`Z!YhA#yeGen=!P;SlA+w?keEK;0WJ&JUR@6+wQQuC82 z{i$IcXnr6a_TmD6d;*l!6!zC4l`9jDe>WSN4MUIU2jZulKA%=Ts+DdFj@$`VL0#3^ zy^fk;0`PY;y1m(X`7y1}pcz7yY)kuCtdP9pbjksc9>V+q-B{Flqp`IcSQ7PbI%GES%3OQM9dqEMQiRjP1FaP_#lRH_gSNh@x1lDi0;Nb=ktlh>J6igU1CzxP_4;j?+nuodls|;TN9KTxN_*) zl2jCy_LYT!Gs@2G;?IP0%vS_7oo+!uTcXoGleVu6zSMZsHS>ZOtiALE^s!&eAWxe71N@&xodSHO-`HZkkdF` zd6XKUpLjW>X-0~{Ly%Aj<;@ELq3>VsVYfkV4mf0Rlh|hR)MBz<1`{qmf?4HoyC?`Q z?f~=0KvxMenJmDKXirCdxD4G@lInJuV9g2i54^;E4EqrGEq}@$ftGk8HuBfRFrHpz)A~n;leBHFYg_)7XjVDur+EPlqQ>f4C ze{4R|d}4svSIPN_Wh_g}69Y_5N0oaIqrtyTM_F3Vm_B{_=0Nyd66SlZb=t81Q1r2c zEg3Dxt8;z}KEP9aGOp=D%h99La`JSK2Vs&g|Hy

e8e;6-Eh9!#5VLXxxt7WSN{0 z6`dS#y|uFD%c1hIdyFpRov-z4pbc+ITqzdmbs`!;``R-&Gk%E`j_SoSuWX-;WEhz7 z646Pw6Ai!Y`&0O1=h|JC{eMZXwfe=YT_v5ppW~LY5!2E>*@?frfo6NKlTOrZTGy0# zUwp#^XNA{xKjuLBI#wZJ>dA_7qAtp=%S>;j>!cB^C=v$WDxjDYDa(6OvlcaKC|!RC zUC6&+y==yme497D?p$QPiKFQ}w}LtIGM`IZ!$#8*%{UnN>2TKZo7nB&dd^27 z+)%53WEgka1$H)D19fId6&D?aQ`a* z!pY>-x}x2$-@b`KYFCI6Z(g)MJt}%PuIg8}e&gVmLXT>YN_8v?ANn98cJSl33p$LR z?9Z-$v)&T7rshF}tR;zqkd>{PqC|JzwP^;KnV);R0h0_*rj)$hpTxTrBPoH0Wg`#> zAZuM0BGTqN7+!ta^0zPHX?U_W;q5!hjxeTbN+*8sZh0xbTk-rug?^>em~rj=SY#GH zM9EBMZ~LF!sLESOHdR+TgBglDB%}Hc@RcVp4=N zI~em|<_C+-v%H3;SIU-rqAPsC$uc~kWO%wYYs?SI8s1R;ek!GXPS(Ilkz*ps``UxC zVy&eplRodnh~yOYuKClG!w&V{X7#93=Z+B`xch_fCq~4P78ngk_CE+#>)Thm9D4cuDJx7!8qtW7QSTY?A z8=!ghFg11jq5i5_w^~}OlJ+&Fw9*!s@u`*r5_Kg(8gA0q4wJnEOz2^aRba+DM}uai zyROie#+LjbOc7OP<8SF!VFcp(?Gzs7mtf=vk4861EP7+ci)1PRL^MQmorP+xA?h$%9#-Pv#}sCbHHsn7X#D0ZwwyL5tVatS%LQ-Cg&c?&PhoTVlst zrrRxfP%+#76UW-Y0-MOT$FD1mLVrMBC}vXLKB{e)R19Y}K6xA`zb;T^mTZK?>+>W*(qXSc-Qt`E(_eGAF?s=4MrMtLqbc7u#xH)W$v!b5#N*i zKNWauHrg6qq35_azv!LnsJ{%IIH&rn&)AluJ#YtpkY#Wes!eE_CHe-&Y3%uaxbd$k zBY%*;5SE^Ez1^`Yzv)-ZtyTe6L?sntT}}6u%&o@>@Z!#lGZC{-RE2)>zyY*o{R9u( zcrWU`%GBMX&s4dO2r(ECyGOE$N<692?U#`AVwELl7>}Dz@gE?#hPjoKUp(T0n1EI%lEW$I$>#LT=gkh^pSYgs_q75gUh(t-x16cQ&{n#=K4iQ z41%Jbu)={@4aaP9|7P%iD7r>p)wm{<2!9biX+yv?|Cb4NM$YfOViI>{F-c z^Rc(IV><>DM_VOM>d&p8<=p?l9CTKX3~xB^fva~Rs~x@c$XZXj3>y%XqfirE$?aM! zu2}K#3=UT|!RbNfxb;MCN08lOKf;Ke3LR(sQLt;bskJ6HgB+)T9rrL;AEiRmn{c%h z?>4bIg00=^ey51D^+zu*(PH23=7pEKwgQ>8w zqTj@d^qnN4v+nsfrUy#Y^>-Ex-xY*$JIpi=Ee&F-%7-_AJC7GCTTf&8WL*izDV;Ea ztazSCAlp76_1>7}26_;&zl5L{+TYaJJ7ZL|qxS2r=Cp-uwX7(0%+dfW>XPY{b)6yH zn^Z((}u`5lu<2ydrHR8>Bj`6sT;oSWbBiQP*M9b@UKK;cL zykh;)!!IKC&Sa$zpx%3YW^Bzro3;F}n!LXkbGamf`6jYx1sJxSM;%56UHhhA;NTt= z>*}3~`OlrIF=J0q)i>1e0<2g{yV*xHFE-A+Q=WQYFpu5U5=^7UOAl_=%g@jzsP!7< zdG3TURk?cEF1^;Q((#~qKvL8@Di=vjQ7^1@5U@^l8qd>1q*_uw@dhrcJ1FwDFx3MG zj_}mHpyhiA<_XT`-;$*{ij_$Z3#$B?s;cdaGjHNb|CEcy770BoklOoqB z8h+X=Jj^eRP1+{D`;c)H8;Nrc|MG{MSni66j`PeJ%TIT57)DoVH?b-TrUW9j?o0Q0 z{5fz4QShkII1kjdM1H2pc5Vb?aP5Wr%NFi}{kyoE!b=yuZ~qhr592H@7aQ!~Fgaba zbDxTgaqE3%7)ZVSS=-~ASa^j!H?fBTvAc2h%PGb+F)UJ3cBi}P!K90w9FuJgCdQjF zNswrF+uh;uf6!>k8W_9)YnHfPjwKSzlN4gk1O>?7sULXx)WrlyN5&oV90YNpqYa&) zQIK#c>iM-60-c$BVEmlYSmfSj9ZTc=V}|`lD9H2PX}ud)PawC}^(@tXLe=@he0}9& zJg0GK|3XbOEOm`zkFn!+ET>}j%LqhD&Om?CoFK+L^&ciqNMGOJ`7&=F_ zEnO#}`B_(GaX+!=eqxqdh6~VjTY~pK-)|KA+V-ZCznAlmrWx;8iPF#Pcu9P(EU9+X zvMk(SLUFzKnzTTsCqo4?#FA^gDki00u{jcjwP79st%&)WHmG@U86x! zE=vt(gR9XzZ5VTii2@E^hm~CtFM^AvJY)P~>dEuQ45?*#CY?~MdI_n7q{~**sl%qq zcjpssC5`W}h~%fJ7uPO+*@V%e&HpyEo?L&k+eZ>^evY%?@2Vj^nMD@C>$A3Xg9-Z_ zzBMtQ64YuZyMK}^;)M8Hjh#yMd4d@Q{zQneb_suC6D5fo%1=_)apLOZDGFJU!%%(# zvWdv!y>FJfB!yO(SJjN{tB;Fa1$Y6W!#I+VqFzu7+XH%sXEQ$#!sEI07Yk_S67gHNYXowf3dBR>=ZQ|8H^ScL%ywukCudI;on(x5UT>p zBeUSYyi*=Y)^0Os1~bCmVKBb9KFdb$YPZ?MeD1VQBL$A@xqq&_y9F7SXbFm?$Lz)h z*szqEL|W$J2ATz7E?HW(q=6PhNf$PQiJ9y~NLZ^F_g4v@Xl{nn+^xnWd=8bDr$0-Y zo9hCT{!JI!KWmfR5k*BS8AD!048og^c}j>^G~MJbwEC&B!D1e6Hr4I+dLN9yF<_Zc z9FlArn2HSW=<*8B!Gwk3lr@TRQIb=goZYonRaYdn8n7WRn~762o%1&+=EGr{gwoL4 zQIho^vYNpjxIHUgLlTG`T#3_kcV==u+IA`Q^Ey@(FVqwtILqBM<$S-f_9LmzsYI|- zi4)*_J@pQ^`O{_)(vL9oM4Xnu9k@Mkj4oBL)vp&eM-nUt!SNA{bMjIc^|&;65#I`i zFey;+hmNXo+;HcO-Ee40jxAE7t(lbg*PM&LRp z*bNXDViu0d`MeBwzp)ZCs?-MZLa=`ff;lCnR^zSMxGfxC)|pW1x@Gy(IjQ`;wXTB* z#rL5+hN=4GPl1g@tI9gqyH#!G{J!lW&CC5Et?GpCF~^JOW*$!xL-KV7M^HT!Cw(_~ z{K^3Cy~SMTxwju@&^|F67R>!AX_ieMeRi7VSmP=yC2Z7{m(3aG6ic1#;VQeK-skpR z&Z%;VvteU1V;r6IYZ^=nAl^f!Jmw20;+GKt5A&Qv{RjU5_0hktDwgW5Ez9CfV&j|+ zT3HkmtOygHL#XTkDPpm->r@q|65My^bveX}+Uqzx32Qm)*6%_07HSq+S<~kglWj4> zB4vdNPHVnanYJm;l``a@QQIL6IPMT(bC8YxTeR@%6d7jj9@IZ z*)&1WWa$J|0n5e1JhA=8g3nu6%TiT4>Tdsajb{&6ymDvjtK`&sb>H=033UtSJ7sW` zrot*e|KL?hZep=ZWITvDq=&&}5b)tE^(r+EKZ2YBB@=d>;@QItPG0)%v{H{u8$pv& zen2c)^{K#wn?Ki_N>Yj)>&JIot+Gz#XIIUd*FP%;Ld1Z}CmE?0NtUzEva2uX?zrE$ z&9%b()Of1%wPX68swso#op{-J{(hrZ_|UAyD4e}~oY7jtQT=lPI*|JOiz*0+Chf)@ zz|TR6X#Iq02Z|w(b`4*nZ! zm?~(CX#=S?-aahZKk*&wv%$s5b1&YM@hG*-#wJQ;LU1!rP5fL(I0!=%h7mG@UT6Pw zcbielxc=O|1hBBRR$#Qi;cS@cRGCMU_D)~*m5rNZhxbp&9-wRT1bM?z3UL^JM|Y!P z48)u)Ysgv|)qEGS@T8C%i>dS|URdioBAH~XGm9K49vNvaHjj+6({$!HOb~R9qlDQT zk(MIExnYhID(g&EbQYF!%9&O=X%m;0xu}EkxAx;B1YbFcP1SW~oH`T;?gJgI#uB8x znnbRAt~u;<4jU4r2$5@o+CT7K(7uJmP|F*C;6V16G91LI;3^xZ2|5%RXWYjfyJN8V zSi{n1I_aYtXq6tg)6R|AuiQgd>R&fxWZKYOixsbaaxaz&xT%yKWa5jRDTgZ^B@Z1B z91wT6oWk64$LsF2ZrZn{c8Uhmc@5c*>s!OGL||Ko!(GSUUo9AXXXOftzD?0e%Qn-s zLLS72m{2Jn%BRQ#ohRe?iC!TWLaRy=S$Yoh8b%xzJS>7s zv`*4aJriv(!_SZzhG*-}z2Z=*u9R1egb~4H$k(5h(>7s=iQ+RsMg_lrVWa81kyv_R z+#;KxiGoRjSKglrn%s?2QgSLGY~;4{cps@%trOXhe)2xQL5YD_N`zxc?n^jve2T5A z^V|#fYgo&XV(=Tcinxj&ocEIU=I>RkQ5Yq!MQ7Bl`xH*4yc{h-W=f}J7w?5~L*ipD zRs0YNw~dQru)mSr+Mt!<3OP$}s~hDkH~V~;12;cjWWt3nOMu$9HykP1Q{DlwlX+3& z$Jn33J0e?+8DlABhi29-Z%f>mCu~Vunfp%d$`O>5{cY|)x?RIMJ3-02DoJQb;3|@2 z0?RksfE8zS8keq}gdFwH1y&FKr*1ePmHPF7w zB3e+Kmt>vkP8RC$Up^<0e#1Cwgco7dzcK3S_{SAmv!s-5xA#*0(mWx!Da* zVeM&zIS?@}hH_epPKzDNC)uew+ig|bYwzjm8(c)Lgk0l%d%+iiW^_2kw3LjL%#@sz zyp$sNT`#P7Or+9`jh6Vuhb?g;8&YNm%ws4eCd8Dpv^*?`>M@7Ch>wTxM&Hm=N3VdI z534f!6({Ss7AnWoMB&34(I-JRK@$X9khkBluMWfop`5-L14bMIHeck0RN0*zIo6QN?L=#qwnJ^GoOjmc;r1e@VAgKD$b(#Lj}dn`QpO(` zpFaM|c-9CC=EV`D=-8xJG>UFoRCH3}H_PC{_aE*`W#|2OhINezU!nPA`hFBG<#p76 zFG$5gh4^7Zt7tJ_r{)f6;m6Aary}%k&aa;G4Q_jfidqA17dYsnIk-Z{QdFu zHUh5(ljQS;^w?Jj=k39E(Oki3Z$qn{R0 zhCY~@lvrgi9$fk*+WprxoZJsw(^9@o?lQC=XfMHHh5dVK7J8V!Jw#xUybbx^M|rgN z&)AsdD7!n#atmdA;0|=RZ!9t#PjJ9zBd|3S`(D3q_W>L7UlAI8%HMbP`%GV-C5>3?WX75B)l)-{h!JAjWFL_`lH-~DaF64;lx5A1XBg| zkE+>@rWj036zW%tHZrNs`+y%Ui?OFS5ZQU3x3$uB5Im_QtqDQlppjtM3HS(!QyH!pke zIz?m9A({h>_w6L=rNoPDy%uO)ng%R)fmu@src`!)S=Jl>(_{OfTKnd!q1K5BRR2k_ za$PZq@Fg19ZeVY%e{t$IQP-8mq2VYb+0U3P|Mt(`lwL6y+5DcmZI6gh(HB$2iv5G3 zuCGcUSzhK+cSjVzq85x=_YoM*sGzro7_R+L7oY5|cdS zg5j@0BOUF4>y~_fU3cr$XvrPM?WX1j^E^zwnwt*?d4~@zwVte*&Y8p{N3GA=*r^mU zy+(#BkZQ^Dc0sQTy|arACil$v8&|9*WH2;3AIGm`T8BwhB3Bnpjh+nG#PtVaPfq;f zvM+_{z!k~qMpv2B_MwYDOxDnXIiW{ufdeO75I)p*B3g^vH))K@u+eD@_hr+fC54v9 z7{(7wcH_1x9y4`2)?XS&hH+WzJaGM)NwVtdknfRF9u6?Us2+tMycFK0jy%{__IfP9 z#N*OSNHf1OA1G^`kofkVC2__)8MU0rlPV)2R>ojho+ z(H3*_Jd=_T>xv0=FPaZc)-ZkAMjO z=^OTy9|xWILE(J+P7yX+Ez1eA@^JO=^B{Yi@F@1E^&dSp!*DB#$vQy8ce@ja;X@hY z(H@>80ja!O($t(a2XXtVzN6_R?l`JXctGAZzwn_eX5Fi3-2|IB4+lE>N)CeZIkX3D zBGcOU^TZO+NnSAblhWC)6Ru+K$gUdJyy>9bZ)Xsx#DF7~bRQ@?Ee@t>T zHZp2o+NmimMI!6Q$$}|@jM(H$_6Q+_DV_VWIm=u)6<7bT+W4V@S_p zLC+WlJ+lTqBl6J>H;-;^u0-lSj{uT-!NC$-r#_fIVjD>;D61A3^*b>sN^v)XX$_bXxTW4d|Ek5I&H5Jzy5ze4t0XJiL?F5B+ETg> zZ~=QSh*;EpXM)x}5e_SY%XObyzPJ2u5pQk06?Ch~U8BpK&?Cn7-zzq@!ZjQw3+_Ct zj1g@v)TgLcLw=lJYZZRhBb;Kx0V%7_c5*&tQdF;fBKLGwg2`#CnoQBF`OK273Jy-U zJ=PA>omn@w8W5Zl>Qk4<9yX4xKDe2!5I2N~w4K=tS~&BnvzNBySYUt9MV6Z8Z%#Ry z!boAJoJlzkOH8S=AW()EInImS$9Ld%kTj#m8e@`c4ly{#0C$E7#KK+!k_q;@Ft$Lc zJQrCor*=U`t;TgvQ!}P}Sp#k0X)|}96Q?}B+e6f=DuXjefvsnsupW%jhB$o{i&rMi z_)2zmf;yO%$pWSi(3J>Zf0d-?1kJ!5wPKGl-m*9oaIfY8pIDS6q@*>LR=ImTNjd-L zt+`z+w=vEx_7A-aGw2;2wwP(qyUYIJ6c#!Hc@HN#5mFnn6@XnpHe{j#rnd9;&ivk! zpzQ{cArF^RX^I@kJG+a=O;D$KZQ1=K^}rJH*V{D-rbn!k^czdtN<@Uc)c*I~ZvD>r zrAm(Lf7@p2X&V!60J*#Aw+orKgL=?%m{PQfMTSkp!Y0l(Y$EaPCO!khLP*^R)<(NV zv2iC`+a3{IHtHr@uVyo?U7n&-nNv9gIdT_sjOtnV!8U+ER*bDi_?fS+0ha$fjSBbb zw9V=!AdvtoYGYQL7MnWK8V4wvNX!uDnLT(uVC`rXmR^84R00ztX-+tmGJ9UEjltl( zh=m>WF-UD7NbP4Nwb)$HiFM(!W0vkq zsg8+YK;BWQa|V?%fd$1IR4V@um0Eb_9hF*&qiV!HtM|;0Z?ghx%j5+(X;hVL{5wJ_ z-?f=I+Lk{5YU8|W-cWr)_o``9;^*0f^(&IBoQ{)HExnay=lj!oKaX8tLD&ZEU9wm!EN3XZ2nD_yhMhg`^bZc#|6q zDAk^G=V>^1Ky0cSV>ov!4ChV}&j+{J%3!=z?8%x{lc#EW0<)?)r%TL?AxbDFPEJMLWNQgnU1d*D0yw+;{+S`>rnv*W60D1y0!Gq6X zUngrYnCrLu-$}0+jT{W#^JeDM;faAi&p6j|f-8DrbT!xzs37HBHfn8*0UnCMne}V=|A0nlvm7sgI{CAG5)pF81tlX`d6|J_}?ys16H)W zdC8&OuU#n;qpSJ2TS*Ph^JG_$l6V97WM3J8IWZWdp{~xU8<9k77rm7z?o}l`x5DUE z+D+`yt2eO_mlE$EIN2D-R@YUB&};{+wi%2`D(`;>an)jCo`8|JzJOa40xAQueOY^y zFy9M+qGF_#eygfI1xWUL;w_BV4jSDz*t8EbDeJb zW69q6SD84S`XOiD@qTt$SVV!~)ADuoywpe=G2ZlG86px#5>wRfNDjl~`}3*f_~g)+ z6zIaS55zDDsrYt4XZ->^7fkY)5Y(~(qZU*3uhKvT3(s}Bp9NM1U$2wr1PK+5I6VpM zgcDm?QJsx>P?9tMY8)0Z?iy-DEN60Syy=e+l3tfnNu*YUVM8u?DsT`LceCfV@UZ$W zZ#VI&1ugDJzYZ+737RwIG{tP)YAR))JmpW4;vbAL!>*abwqzRQXOAR_8VkH=Hv9>c zg$c@H@Vlo?cXd_G(LqkcED2XgGpR`+!1q5{5OIH<0NwGy@ZUhWVyxzCFhKnK0Zx!b zO@iq4{hx=l7UCj^iKu!xmEaQl8}byvnPX5>kpKGS%MIc1$gKlFI1!+`0u=itR&=6Z z8EpCcpFl%&h@~nx@B}kG%QbfIU&5RBf4&XaXT!;=*XNOcM?aYpbY7tqQK_<@$LW&d z|KnwGkY02)&~WJ^g9jK=MMDt4dInwQB4s}}q=6ZR4SfGoO0m^UU%mdP6BYF<|9lL^C!W-jU0trG`_YYj;r6l)2zwyybip%l|et{e^o9S6h->sHJzNE zqxH$qXe~+t6S40-1=2~dw%)HAQwewtXtmo^!Spgn-J#dPUw-bzz%`0f-v#s1t-;gr zPmbJiK@7xV#Zs$(y?knKl26zUa?sXSd$D)cdoz2|wl zqKW$E)cpe7+~i3O$TevozHyWE7{)7SX|tD0U!P~DUw692c1R5y=d!o~)dH%;Au81K zG-vkrkdinBb919wB}tyy)9vszOkMYRgZbi6wUyNUsv5rTO)o9vm=tnwoH>=-8Xm=X zO})bN?tGT;T>vjEWZIP<6jwYi$Fy^=j}SIXcmy^{ySt~*iK@Z5wgVy)-MFRZHP51% zyQF1Fmx_5CE)4NpD?2fuDLdr#sUd5(C^x)3(3CyRZ|>r}`>8oR)>i)bMBerLayJj+ zbiD6N?jhr}Id{fJrHvy5WQ@Iaw)M)h%j*2Km!9T68!xw0hh49v&?R4bD0OGe+pF)& zZ_iC&?a6(f+vS~pfC#}~Fwdg#lkH<7i3hAEHBYdYDjV5lj*+3770S7JX8mjg%I+dC z@+78hjd!j*6}*R|lV2<4Gs)Ga9J@@HEBTqg%=EqsA}mpBHK#^ve6qxGcxvv7OlFT z8*nUuvld#d7f$0?GKyPfC+P0KdAsZaEvl3UbF(ABym~I!k#)Dsmb3FnSJn0I2YV)O zdT~0m_2Mk^QJaOf8t94U*W|`sIO!r;KYzzhcN=~4xk}b`%S$P{H>H_h_m`kI?}YIm zvsV0M?;jVvaLk!i^XND7(rcb=2#uOjQ`41rz;Z)^c6Yu>ZNizl7dv*^O^l)$HJJ^b z>KpOdtNA>~tg>9`b0O*O_;A{M-9|^LS!HDAwmEaG|L1ozWBbeuz{IZbCorTx3xbcc7?){>RPy?GyVu2dm2+2IR^z z6Qg?#NT(*u5BFTs?N%ow6Bi$s8xY48_0D)>QS7#1d>YFo*`*eDI;;U_HntITCn;Jc zwK&-V-6##8bx_8ZRXr$S1ll3Wv`Fl-xvvj{^+~=z9yul$RexP(&Am|e@5>_9mb&&3 zX|UP?EB^@1@Qa80L!w8X9y%AYa2X|R0CC|bbp2}&mwZAx?VoYXUr2b5K$nya5mE1} zO&1pmLYvN`wiH)|KNq-am}s}G6~&16&#eKK#7Zq5$aE4>8ATC-mH^T%g1j5NW;ql4 zh*IN#qkx~WEeGW5y1#r4*IW)yB+LPeAk)d_#Q6$OtKx=T>v{^lul;_&Bj+g%q7%2W zN$X!+;d9^5#;<#d(Ck#hc)Kx5_bMG~2mykCzQVR<|(y`g4(k{uyBQ z$8_l;3{4ji1<0v}(uwy-y>``@hOW)zJEX*}30bKXZnks*h?NPpiiAXV_G6P+|IH z+`?aBwK_`S?wb?1K|e+e03AF^ErAFkM`N&_XZw*_6C{8MQ$Zj4U0d`C?FFy%Cwo! zTi9m9dx^rz!y;k-{Gr4akC5wf{)=xMcRJVj1~*SIhyW<&c`cdT@4!Rw`$UYM1jp!~ zjbo}UNiCBf3HJr@TAW?j0{%b$zpqr&1X3n=X7Dx;PK|N*~*L)02FD=3Fz8Y4@uCkgN;08AHKV&p=7ZhlVeiy+p>;B2C|_kGs{3kvE}# zU2G&NJ~a;nS-iOsgI=V9NtCqqj_afM7SOX3YPn{;!Z5+fDRD7@uI1i3(7kUK?NeNt z2VB}X2K;8cZMSX4o|=Xb)2lTRBg&&r^Y~a^_0fBkL;cx}Bd9v3wm+pt_(Q*ROoHEh zqNlR%ARluLRChXFK6JFcBrc`IEbr~s+b#pQaJhZlH>AX9d?x}c-|-O9eXV)O9+@k5 zjCS`4Yfoi3@{#wCUKRL{uMleZZ0AH4(rI4JG5~njeMI0*zw3NfVl)Hw)O&f&BEq|x z4nn7evkBI0gTK*TV5Ef_#0XD^nRgI_?IiQyQ9d(_goJ%-FzZ;y*ygx*6}J> z`%FaHLUyEvJ9KW+vogQv5?lI~NL;#jW;XO|Aybq|1?UcCP*(oHK6N}oOmEGJ1)-zm1Qa^;8{Fj`UR?ekxCMcmE3`O6Ms4&WD=P0?!V1YfpgJ zn~Vk2a3S6w2{tw^2%B66ee0LxOmT74v7PXPlcByyT~z9bNRM>bx~i42l)q%_cxN&G z5m#4$D*j;ud&0zjgRZCZ=o<68?3;MnHBjPG^0koM(nG(INdKJ(KGv>^vSw;dUS31U z1>m)Q`3$^MLM%RKoYn@4Ne8VjKcKz6MOJbZ*GJpRZSQJo2dL%3R8zCqN_$Pgz}hlX z`$%9zm-(;}0;KIbB2C%tP47`P?j@Vw)5a<5EZj~ZvB(g;GR4`jSwbu2%^xU+(}6dK z-piTmjb&3ijK?P62?}J$pUkQgd|8#Y^oil%c)zj!$U_h4WBH__TO-OhzP2}z)2RR#X9G3HT;Wjw)e(1-^AP%~R8aN{BwKv}>^8J2n zM+s$0)ojMb$<*=}GT<>9FgHHjHGd_fmzXNJ#Ov`cG9tRQQ_}EDH1`V@^>=0^iHr)J zMJdI5MvYjs@sOLYdBYE9B&XJ`+rYVpT>SOK?2_yzEbfMZlT{B@!J{oA2m0@ho$5?f z&+5rFKKK>_n1`!7Hr{qqhDP_rSM(>+676Cpmc!F;>MYm1c@s0Xr>0CwOY4>%zxaeiiZD&%b1{IzQ7|;2C~| z{?3GX@|kAHDD8c3ltx-{CDqDI`*WFg#3nfF+A1ES-sdYXB0@>|hXu87-dM;MKD%RH zX?r74{&p?T6e6JwYePT$yVKRdI0d*3!7@8$2{juEFERSkOLNoP zas-LJHzeZ)(N#jk#Dkk>dn@}hl0lc_AZImjWWn}jg73sbiPNpM2hfik_2sBGeFH{~ z#e(!no%sg^`%L0$s{`#-Z9pn2KMw%`E(DV}Z~V;z3V1u*1nm`a{m{5e-f*H9;l0s>$>2QqklR6X+K;Lz zNu^{lgEpW>dLFh+#AY1XsV-hKG$5d%v-5ev%;k&WRw&NGKcdt3c|mwDuXm|yX=6{` zPv!tabUZ?Do+4AM+fcEAh=X$gvTXZ)L836*Sm~TxXsNf&w?eKD7ZuS{9FPkwuyWLhIuNl_KcZO+$M zG6i!)8Ks6wyyE{cP5ff#JkO$H8aLAkHQhnzlKLUbIWl>&6|S5+>k4aoHr{BTZ#TxB z(iADN%J252vRoAv?YWloDUV(Pq;qnrp&3nRf8R$G6-E(98G|zAkf2$DM?P~Lyt4A8 zhwIiGxhgNK3)gVl?c+ecWPxB3Jawg;d1r#)tX3AAo;V%^h|lGCQE*0OntRpLuo47{ z+oR0d^7LSua`2f}Th*4>A2nS9wI@9@#m zC>^0UT1DNYWVA(|YHpaqa@{U#l(-tD8>j7R=5DmtvGmgt3uKlvo!``apL!zQw@46U zmNY*C`7hi0n*+bsDZJg>PpTc7`-1;AYlc6JMKLty{XI23eKr3Wz4YVP;O*4*4Iv(M63)B0#wN z=iE^2Aw4&AMuDs8**pf}gU!tB1@gSI%v^we6wmXlv}X%C<%uZv)QRktLc%K;*Chl) zh7o!%;B_r=4yO(us(s+(XN1j*0jupKZ-8 zS;!(yhsj&e=tgXsQ7t}RF3uu|n9%YbeoenZt-h3C7qM6Ryiv88tKMa)XlHGvZ{fYK zHXZM_vBRFxH3VJkDRM)t9r&okU8FInc}#akpl^PcbS|D3=N&MMh>RZ9naCQnz>QrS zpoRlqxBf`;ZDDCfg!$hKUr;$UYm!0v(#+1ev4&AR4TDyLjyu5w;54doc!BrPqEjj5Xonc#R+LSNY@H^6JUsIFy{SJQ<^RF zi`QSt7Kas5b3JVaE)EJlH#;stysJlf;?2z{bbYn-@;zQ*)SgqbYF`WMMWBCbW`tWE ztf+%n&T`F6tnK0&xO|gwCD)?;+g0{AcPo}$TCyDn>Sz+8C%u3hp3PV+fPLNY<$R;* zDYA)OMwyEbvOEIs_&#-`IE+2Yj%jY4nzhW?k#@zC}%h6l+mv)OX{C zv=v0bP{zpNHm32%tFQc_%FoS=VAzL67#_113GyGrRJZs&vgeG)0{nZXr}4U+$PC1B>*JEI$l z%UH}BzxEfQL!HpWeY#e_IQPn;PEA#}Kot6Sl{MSzPiq$FpJA+Um)+y(M<+*zN&x11 z&jY*mxS7*9A(?3M%$udp?w70wkTh8>oa_#DKCEmUfclX4-jEw}cIt%)J*-2IAJHe| z(z1e>bLS!%!Znn8v@mj^`k8ws`qPW3&qd}n4pqMN;}{<|E9b+#mH`_36tTXP5B^D@ zq~hv_RBVS|JoM|wV~$4>gWGuK!epZ~aIruFu1lG>g>XL~soguCX5ThGrd9bFuupL0 z4kAqWm%pG8bkrDs37!Y)jih|!J<2$PVIS#?y4R_OOTB18L$B!g1?)o_<;}fyj zIcAVwWclCgYOI20`6tCRZnck*_O4kMi1r)W{8CB8qPUcuXKxudL$}5jN>d!%ae`?7|a zb?mrg<;jGOvOIC~V3ZOD_`=}ZCrJShh7MHWrU5cSDL$%m&!xCMzL_EM79IL{ID^r3 zLYKBs^_3*5q@BB5`skJ8z>ce9Kq=i1_+?i3+^so0o!YA8RS>na9B)MGl=R_xh-&i~ z{~Q}t$*~;~CJuU@!F?RAAg}EzE}1Hsf7v(erdAnozuWD}a04%3JJGcOo2`4(l^LQ4FfI03$G@+d2F)M>7| zAtVv+syEQA?XrQtul0kEe>*S0Fx?DdNIwU%X*)Ys2*9b39v_;D4VZx%yaQT)16Gm@Z1V z9R?XL@n1Z%^mx#KCimzRQTV3WFF0wsCC&L-*&8Ow?(O@oje?6cXB?Az8wBxvJqTyBvJ{sBkmcE# z?x))8skH)p4+jGILzO*HsHkMoNQ%X|n~6*AGhzwM^+d!B?vL>Q4zk;Aj(G!yN=;4D z#|{LxWS4br2qyVP`($2+EKS}t-Uc?ePC=a%KG}LR+@gB7R!T> zRXBOjDVIDH(gJeqFmX>=Cy7+fY5?U@I^kd~9CcdfN&VA0LWR`-V$Y<8P7 z3NYWp$NsFr^aqWdd98xcbJN9)>{>hN@vqjo-Nw5f;`9iQ`SnZ6zPo9_vPc`8{GT!u z*G5vf@$6k$yP6|35i_7MLJcy;ZK-kY4p=Tsq^2%NpChfJMD(LS)&R)wy9Vr6IR=gQ zH7_BM@)4C5u$r=>+VF2v8t|pE6R{0YJ|S?VKPhQ&+eBXqYv$$3X;l_)%72i{SWLsT zdNe!!C@lK!5Ie7C4(2;wZF3w0I~6_XmLEY)K8}3Kb48~O+2^<7ya;s0d;R*zcIr^MarX^sd(ELm*QvRt5ayIFdlHvj83#ky$|Q5u}P zRn|#!hltbIIR`6fwBoN&4hE)l3*;7bQ3KiRe(U^+c?|iotv>IGYi;QVbPor!Gk|J| zfnK|$OA6p4^0nC6hYYXPd`?;O;X^E@+WAmoUsjRAC5=EmQEvStX=k)8AD z-$9;P2ZaLfT=WX2D3oZ1Fw{vy1=v~9HRonSBMTQBq~mp&+$(Pg^ieTexpern^xquQX)qGHCrJjta(aq;p>aV=aX*Q`#VB5P%B_Mm>4$Ny(u^&5%iDd?>=$N<+6D2oaFMpT3noz~_z=^%wVzXU?9d9|Cu| z?)(R8jnaZlx08S~2=h>1LdsN;IR%9rUv~01R#Iae%?05^()4$|LpArh=n&t*^bI6W_n{d_n z8!85RfQ>FO2UR~1IhKObm|evUr$hYI^=2)p*1a&3q+Udcz-gF$G zYDzsUJ~-Zc%xKX0LZVWNFKU=wEsYdC6|5FvK>IT{_Yz)3PZ~l5T z`07iM<6b@2H#g)yaB^z;`%8hLwFbxKhU}u4268yom?ySiAx97q{9> zsqWD{yi3I1p5xnmHauf#d)9UR%D)UbZfWvs^+OWsi%uC9I}z7nxx8n+qkB>{w?U3S ziEpab_jpoElVx@C2+yv%Z5Y1xUxb?G5sPkFa+-l6rpl;X4XF@gT;FSLg_@nXcy3>d z8*&;HdQ1R{A2ho$Vi4u3DEP!%lXa0?{6g&|G>&~@6HFM*H)XB74wBv?D zH%p_6GgKz6TpU%QcZ+^4Z`gI*p)q7^G9W2uSz4LtBbnWukJ-c?x{eqyIrT0PV<}8N z<_%-dr~gP0l8zSGPMF~2yu2q&3%;M4DSCu13v{W!f_ACH0r?6HZAUfw^PS_pX=T(O zD6Y@{kBrB!H+fZHEIuTsw-2o$zEK?edbtJxpn0|eCN%Jn#YoFfK$fgNWe2jG~dgj8`HjYX*oHhHVI2|+aKiPDXECH95WnGjaPLM8<1~c z`WlkQ^(wajWOW5Kdp5m7r)pvT!#+z?AG2%XudTupk0yDa^WR%M4zr00@K-yK;fLBLD?0^raQRbqvVwHmUT zrIV1q5c3(-Z@Dz?ZS5>xdT4D`!76Y^9o)Yxbic*-+WYPh!#&h3)AVQB+gZKq82m-n zdShPMe?FYPs`>#1)Z$A@&2aC$6j)TwQulzPzS69)Yd(0F6)|XwwQ_0HHS(OYMfNz6 zy3QFE{_d3nJ@_yBb+5p_D4;(O*jYn}ZvchIxfsCa;?TK2DICK|H)xfoZSQnG`n=XJl}kgl_!Zqgi2+8nh6 zMC38ycCAn1bbzA_+HS$RUn)NdZ+I_(3Qo*qTUi{Os7@T51rse_J=+b`LR%Me3Q_8? zIMCl3S!KCCWO12w#Gc9oMZPceE^=pH_X`0Uu+AJ9xN^pHsngC5_G zUzHBNExfs?Ol~4RWY)xg`A}k(kHpe+!BpxnHJ20R+CvYlEZ5WhZ3od@vy@UBTvEtR zg##7~x$oef*u)hEo~^hsGAEFPyJ(o)?6BGb@5#YhX1h1_dd(tkeDi+!n2i71$t}os z=DC$qgr&M-@$BOQ7-Ha$vXuBIIm!nci`k7&o4Lpy$M9aUGnR8L!<6Kt2H$F+Sq=n z{L^%rmM1ONwA77|9_pqWFMomct?hPzB)iN?N&}$J(z2BG0x4EZlR{1wk;; z{XKblZ*kKm8>i#H%@ZvS{!|oNDA|PG26Zux7Chu78a`^(4%`|rdNyI&+ur0?Pq`9A zSo;Tn=C!h}Z$zfYwG~$a&!5Hf%Dgza8!!_YR9@!XE4>1ar9GV7N754Xq*#NWh9$o1 z;_7(H9iv7?+S>x*GTc|48y`RNFfYor6n};j^tX|N<16}2rSE&PDmUXqmWZaYvpANb zA6Cz8vrMhF6{4L&UB}OZHfT~s3m+m{W47BH1r-r6VPCJ9z zo|WS-R+#%HS?tWu>>Q)u>z#4)*@NQyI=MGuB)beihtdaa`HPj-#Sm@vAcQzfe;btN zxO7=r)>7PoOjx&>N&0Bf{>c^b&=4F1)YWEiW*YFj5gs@ zBQ&20Oa=q;9mQ%S)F|BpU0l55Zd|(VvbNrC9J{zj2X#rt0jhn{ot|=E*kU=*CIe;b z(Ek;+KarP#Ild>$yjeL%fA^u5_Dj%VN^Z6orITOe<`B3rDDBuO@AC~w5$V_)ap4Y&?3>O^N!fB(XG4}XJUOgWk8f| zK1lC_g%zXaP6}&T{(y1W&3ke2W%WT1clYOEZ5QEN4;XVwGAnMV@QDAB;*rPRsn|{Y zFm?g9P7^s>g)^vz?a%x)Log86UHRcV`8Z1w zn&1CL#VP@(x){54*&Jd3emUZoLS4{4UCC5F#NsvZO0GVP&Y{#g)$zi*JKcum(9X5~ zc2x1L3s(EZL2cD;4M8CznK6v4@-6+PEby-WWtMY&UhS(4#g=++##FbzhH!B|{kg@q zx{mxqq~t!cx)cTnHUXT`+Lotg2!dy|^cB+wsNNx%`)Kt9D&?)#Isuio0T$aL_)of` zZTSoDGMgc9lG-fZlncCs5wL#K%4=Qi@+TDZ`>cLK*kw7ZucSq|o^s1%Y3K@vrJHDW zwKCygi`=P!2NR|fm`4kH$5j0dVivq54<+m2lbjW6C3MV~?>N`#Jbk!g&qm zntqbvvX#1mn%>AkPqFWYiO(GmhV8Ah`g@tn?6L7k3Z)iC<-XdDZ>pJ8DRne<_;<%b zi=AT4Af({AwPv_D7vt(aamfe}mo#IMFbK1Xc&E(+hWy5g@6#Qc+VG$2 zgL)zb^L<}WI>keJ#6$E_^1}1UPFiFZO7ZaE8e$O$x5312dIfBz708j}UMuxn>-v}O z$6uM#yKp^2YOjtXo^*MImM`hsp8jOFC$2B;eb>A9CjxV`jkn?Kfv!KGoU3k_&S_ZM zHg|~ZCZvW}qQ~3_H4+h)XZ+Xypx?Nz&^J_K}( zDW5SQXuN|T<&>${uPJ7r7GWON9(C1By8<=jTV)@SD?$y~itHhlpG}Ofdr&C%XQ^m` zmmXq{EM4+3hwEG53!F>&2BA#oE1j-%?Dj*fDiU%_> zcDi#)T+Y#YM6{me`(8#{k=1VQw@%0>=2_(u@Ztx|$18_eV;t&FU^B1P(d4`M0{skc z_4%T0Fn;JW*LXt0?eU-16fGaKbluid9nr{z9WC=Jh&ue#5`fJ;l?+H zI=p7+i*m>bvcyR(q$3%{JqppF+{Ku;tUU$9ROJt-Vll#>Vl;@p51j}ZAt{ywk(br6 z?3niYV1D+l;mcY3?2(=MKV)!seY|T|ucCgu*j4?bqfb0KxEK(%7BZhrkG%*B451^R z*0WEOkq0zB74*_+9)xaM5!5Zo!W%ut=cbfE)1@|0Pv7NRKa;!X2|FYxd-lB&9fM-bXaOmCEZqY43@*4+ksZ(18F=O$2 zQlA@gU<)4TPnt^Up%ak?TMAWC8T)&}C@*db3o+(X2!X+q9Z5O!q2I4qw|dCct$)1S z>J(>j4KlWKw71#$onyO-Z}n49OKyQVtUudZxHKWrPh>lXwkzyATSHGGB;^Qoa6x?I zs_cl`1L{4Foq}S@5YEu73|As3BHw5RD^bTW?X&IOI|#&Xv7tx92@xik-K`|6)!}*u zl6*|2y(bX3ag}+asI9WcP{GGt#<+w2X&#HM((*#) zoK{CJ>4@^<>6QXgnRHxK#Z*Gc>VHVMq}q%k!xKb%57he`zUp?KLZBy;tK44|7a!}1 z8{p68WcKKTmDzyQNT3;a4dv%a1KwxzdNJLv?^VF=DPd^nw8@XK|!Nt`koC|!s&v9M}5c@r{A{+ zLA4x3__Po1LW)_QD0`e+V$IJ^oMEE+TEMhQ$Kd*Dk76;G-OvB|IOjRV4ncoA)cwah z$9)zvCQitq&A!>+M#9Q2({CAizWo7cp$cmZpn6X+GJJqW#;pBWuNhZ1&Spw!>&t4E zsOP_AHvEXBG^K4_n?<8&sMA)OAJh6%4uuIK;N z+}lLF40&x8B-IteKRKh{LjLpm4Yk!9Yw71qkV#pfg_pqUh_e%p%Q#6$>a;Lz1{7*4 zAbSv!KkpQ6S><25)WtQJmri>+fFq$X{A%EZEl{t2k|D;l%%O&L{oDI^IRDsMqg6aW0tOl1JOX=(DX~<`U)YYWZ{-ng$6AeT z1)^jiTFM+nz=`nULd+>Sr{cc;1#g@h81+Y5T-$`RZjrnN<1SFt<-7{(Q>LV?=nVa) zIpLr`9C(sYR2WK@cJbo|jg$y=A_=hk!)K4Ry1>9Q&gDmIvZfc2i*r&({QK;xS2R7p zn25A#6#mdcQoQ%n(QE3n0vE-FjuYT`Dzc4f)KXWhzw+(M*Lrz^XiXy z?0%r5m zk4h>sr=PLsx)H0G_2D5HZA*LlR-KwOv`DAj5AsW9-&s5tkkvq+xn4WW998=4_q|I< zz3pV^Lzsu0N&e;qw`BP-1IBo~ylHXuU&`gV|H*2mhdQ(v^s;@Eik{AGV$Sm-x1DlZ zn^C7A=aWE7s%DnrAcIll)ZH4SPR~(CX?9JXq(W|x0{8qh`dssJsZJrFEglm{d5Zpg zG5wf5U@Vc45q^X7n~UQ|T#|EgF_)ZrHHFZm_}$qtRZ?a33tZqgfST}A=Su8}1xT(N zQH(9m(luL#iIms#*Syxsy2NBE`AxK3dVQAf&bH~De6CRb!v1j+#NthbX5UHr@z5p< zzXVEP+zz{H2Lu?YN179-2 zh4_S!cf%!d76lvN#W`@NTJ*Q-L-R`ot%;-ek5VZI6FKnRuj=^#@QY<}KWoxM~b3 zm87Gaz`fl-fl!m%&ERbnoDb}4pxwbHvh%>p{({`e&9!AKk^>0&TtXeldy&SBD}hV+ zF+ee!T zjJXDJIi)}^r80GVkF#3#@c9`RgZMVg?IX@=zgYqulK?~Jc#OG6k&0_6ze)_jXAE{k z{BES81RP#lHavbAuXQ_8qEWGKg|{(1%JtWMvI{6D*GERDb;nfTJYqK?8{8r1%00H; z#2jJn*(&*;2aIBlJ*2VOlzH5!Afc)K(u{dixT-O-%MwzUH0@H_0D@V%4Fe}iQci(k z(7ie`kOl&D)Zb~K$<^l!Gq5-h0#hsY?)$t+CI64_Qzq{eS^ztsSBTdYc)#{o*xLgb z7VD$&3q0*|#Gn{^eucL&J>C1RCDH}W$v>(ORI*qWZ_EP|Fx%Ra{qIp4?(=&$-gd#f zRcm1vKUyGrZi|+LKORiAh{1dFmtGH$g8Wdp<1}2uUeb+0!9x`p2|wyVXv=gCEaIJa zULN!ZQ)eGb%ZbdILfS*NfZ~hm5yw?1Dg1}VV1oqo6UbszY-|VvUs&iH~m%Mt^*06o^3>Pg?-%iAs%|?P;T?s zjmY5k-1dow6r=8U#8}Jwt9|$MGnKx3N!Fe<)pV*nC~kV=$yiw|&@O4Bm-Dfasmyhj z`xWGpc};hf&bP{Xq~$u)`GAmKt(7de*r6!(u$!(P|E(pbB*7bj-Ew&e3)G1vs+DPM z2GD1(M30*}W%(?J+sp`cX(5Q2sz#$>OQ<++Nd}nb)A114y)#NZbUKnfS`uK;I`BvO8q|IZYW z`EbhyWV-Q7<7*1oL!M8%xa(ck^S{8_?te;`>~hIGs)@AO-t49mzwD-=Nd3(m-IBb; z*-um%B{A=-bb)*1-tfho*`Di?7K!qOFypJz4}Dh2GmCW z6{G7{@DU?dXLifBuC*`3B9&x+z(mIFmu5}K9@f74)86?j*=G0C={+BkEfqfe&pE!p zKYbg!H(QdH>H|3+*VNvTCptePL1t=ti@eSQvtw;5`!DpcjmfoDpIGg&*i%kvJ#CRn z>Vcdkp3-mgG{`A%$gf|@n@M4D%0+-kVP8KV=B?OxN?$YAWGSSb*>TxM`uSVhbp$Xr z6nX)R%G9GH2l@Q|Bu18MDL31=cw^Nz+@Gk1F`owa4oAadK}!cbvd*I?i#3XjltcYc zxrNr6_@=OnoDGHibA}MQxlA?L;#Atu`HbiiY4u6$wlOrbMx3|LBw93+YH8p4a zBCb0Ak#`V%HX>$tg9nO+>{w-JL|N=DCh6AGK)s=4Nr+*d!ZGcvCx=A$fjd@f4S>6!ndLugcnK#TB20!EA6TWaNC>_xDXuUWv8FTCIjQxnZXJI`6nD|WVvG0>Hl z7PElH2=Bx_bWvwF51n5nX{_URlfC#cgB>E%;!?DTerC4&Ml&tbWgHgMb{|s>YYTXc ztt#%jTm3*1Vg#8qa7Y&WhDG)Ms>+8anQMmZ8o3+uRB*1O(BPvTds zjFrbPVL%4?!RG@b9D&w$d|MuB^VIpDPI6gTIiNf)LvCL(JGPUR8lmnJfqj~i^Bks* zbWifcNNb9AK^<(zJ4^O?V4XVeiAMnr47dwauO-y+P#+BQx=q{Zr4+kRKvknK7sHzR za{+F~JeSch-8toR24J<70Te#UO%o>}qx`xi@kBN=Ek-+xi4#Sk(obyb{&zOvqM_8uB7k#_hgm z5Nln@4|a_oRye(MRHRDUL2|rPTTYYoVnfyr@PY$YCWhFvS^_l3fSi%0qK-$+ofwq`ormO7?pj*n`|`5%zdBIS$9`g$x@6)Z6HCisRMFWGXEKFSorkk?s=0zs z)LGUCgay|_{0e=i=Y)gJSoe=Gh<5KcV{dM&qp!a%3LAsF2EVxOWsWZ;W2&v&j-4XH1r;~n7fM1d%^0A4j}EjpQ<=_^^dG*gF%Wn2I!z)oP+_IL^kh%EcPwB)FGZ9?!rN#5ot|nw zIhY4n7q>9)!+h{eQ!>sB%7<3a_Z$(IZ&rMU%AuS~*4H`*6oKLN4*)iv7R1YNo`L}9 z{r@pji2bw5wakPSlY5WR2^K#O<;o(*p z((xnc6^$(4+yI4pFis8p!C#a6E(y=1P;UM;<80ezFisYsZ@$9nCgxS^(MzUX5IDa=Dsm1iTZBFT?zigh@C*Sr-ovW=!^rmD7H$U>x>i52ezd!8ps~%Q)rVy zZ;`k=#P^mSv7q&y7uEbVtz3IZc_Ywy)#dui-%mwdfVp(eRNRwd^&6sj6aCGh0ZKg& ztZX-*X0*XLP_T(6>*4qU3_ET%bQ#%f4uUVaXp^$wL+&vDMdpu>JjYw{!OM+{#5Yej z0RL-sKnj>SPX5u}f)Y7C*G!{@+}fr+_fL(N^HlkKC>dyiIQ+`jG~l~cf`QxqR;R7$ zInU8e=qj!wZQ(0dfw|N}+mr+#H}FfSIZ)}z-8;F;vlTSywLtV#FCE$}lgFVBPu*Sg zRF~vC3@73p>SuY!2X@t!+)EuV-2v!c<(|M2OO}PTO~Wr|pkiopuB(3B z8MXCg!%5}8toXQnDXf(t_v~h6W!vf?Z`ntvx)Q*#WZMoum6m%#K|_Q(@`qyVNDCL5 znv=^Q#7d@t!NAKclR$0tR;yRnJNIzq7$O$!{q1B5APZQ$H3?d?6JL}?w=PHo&n&(w z1HKyxNYsZUmc&Q<61m~C#@csE-Ds+qTiJj8AKVtq1RO!@ZBf-F?I!jL|8O>E?oZbCdcq>WM|qOF_FDB6v- zNTf2V9VR#Mrz;g zax3#bk}yOGjLNL4reSQCQrMX}CCH)UNYbm_rqf5i3V$H)!|%)07>hB{{%(?EVqFe*oAWIci; z_LbT-L?p@%19|NR%8zf#0vtN6M(8la{1=|}A+a^@c&Fj%a|=29DmC>Of}G~{3kCWT z^HljYZ5g}P02gEb8z&d?enHar*S{A_c`)jq5jgEZD(=dpd6QH~@jho8-W&R?iUr_u z;ZHw4@tp&ykRQOk>|95SeMRr2Lz#1hI99fui>dvD0Fh38bTGl7r=yz;9R7q2uzs_a z-4`gwao@ZW%eh|GS_OJ{s>*;Kht*l`8ETE4tyM8oE%p>Op6(AyWv#)=Tkm1cJ4 z35U;r&QRu#Cq1-Y-q*8Wn7+M-A8<^ow#ibE9YVhB;5^w4&Y<18pwZJw3y!C}(DSS0 z&LzV6vvHKNJ4M%v(ZXbIcjtxH5d((ZzT(pHBlQHMA#n?I2~GIpM?FqOFmbkz+=a$DDLxH6({dhxn zkGqwA*2lu?TI%kfJ!YAe=Wiqc{OW_baeVVHHaXFWLe}41Z6m`#Fz0oA&-~SoM&O(m zKqzNyIs=L-M8Dcdk_KoC&-^=TB z8{6T$=K2ePz_hMnH!S={DPW;pqVMhU{Zcv58FDzWItn)G^o;|s>{6tbP{C=;Fq+b? z#kj#?D_CIm{r&G4_RIl@)0rMW&+`ke$qlZt^I?uC|8sQ4k#=~9-fL2H1I8a$s6A9V zZ*A+E(KP0dODVDNil=;_)$g+HHF=JmSga3H9X0TrU&ySjFm6yRLeO5D$o_%NjkBM-mcV`QCV~3l;WJe*B zvxoD!vbk1nB6H9mtx{>iiHq_GIoIY$AYLn*YvqICPs+-aqJD3N+%H^A)zfqnL#^_| z#y_}kwu)^6T9I|7tVYGsHCv51NAnaAON5uz4yEiMkXuQri#Sa>S7ZySVQckg6R~QR zFbGLm;2*mr+RmTd41AxZvWuA4?{owkp{Ss;+06DAuk%)8D5CENLB;D<^!_j5u^h13CVAH4vZZ)|YubK!)lnYaV zFnLGuBYb-oFF{!9^B1N3g?WwlKk^~bk_Bq`qk(2yZ1h(I_IN-c=6d)Tt~SMXlTFTP zGk~A_=3F=Rp82nRGP1~Cb+J&&Q**kyRHz0f$N>b>-HRROm_2->p9uG zQ?#IVTBqf4zaVY3QGHeMj-H3q(1}{Y#z;uj z`G7g*!A?>R@t-kIH)dT{6D#VCjMsr&DkAy<+r=k;B@|A8&^+461*^-cL1^PT!_@9cBL@svhM zcOryq{xGyd97G=kKJxJYwfCN3O{H7dFzOL?6zhmeRZtw1pnwDr0#*L3Lq_^g-$}}-RPV-=e*bV^ZWVkU$krUWIy*_ z>t1W!Yd!qa6`piLcQ6DMy+tq@+d5+bp<1g#fZ+--5fd1TaDXJbV-zSef4%2%@BFvF z25(j6Hi3^DuaA^YN*Io1xxDIn=j*KsYfjs5iwBfK1j#TQ!sC@bWeGvdQFxXQIaXx59rTSA%wty9N_i&3PGxqw+36{-*-`r!|K0aN$A)yOrIk(yxxcX7 zB&a&AsRsfiLyemr1GKcNn_H{$@EWy{v)C%tt@jvsGU|En;ik3Uf|pNRGu-fR&7L=t zLhap(+0+sjW+(|fX$3bZ>8i7*Qkqg?Rakz65g>|PZJrp8J=1UcU@KhX>61q9gW41S zV12*FFK|V2MC+W_GtR6>O{Q4SrPO#Qftk^j9in#0y zosyJOMRD$#3Z1S< zfpbQ1)3V|wxT}t@%CPDpZkPoG<@F|!UMScPX(eoW9u!Q?`AoTNX=Oixl!UFcQHt^o zey@)ATRSOS$8N^-Gv3ijE<77=NbfaYs=X=Wg(v}qjF`S`pBdEq{lmqH z9-q$ffPJuve^=`s;fw=Xt=NWe{E#5Deqv&}W^;TiaH8@!_0>QV|wSn^vIJNxOTs&gC`hv{Uj%T)>HdTbaqv{?GTRucXp z3@1bEDm=c5~T?`_}23MsR9)mP&X(&-Jpbm&t z$SE9&Q4Q2IrVw3U5!S=x#WQa+)H(5};3$RgQlIt~4WKow1V;fB5N=-ml5y#qqi457 zKya|Va$5>9pb@qm{_Xf^kHj^=u*u__>*Vq4b_R^~Y~Rneq$?{@g50Fy6RhDkua*lQ zZk_}m_*Q&&A~j(;v9WC|mTTjS>#BM{($}csrj!*fd4UxqEFS82xPVl#7Khq>HC^aJUyT$b z@b-Y}wiPc`Sr)N00@vDvDI^@-*DkKNm~KXT|NLW2_h7g=A*pJ?wm=np){E*YeaCBT zf4OxyJx@Oq@*)_QQd5T=Ah}Hq?zA^HfgL zAb914)6m2FjWIh(-oVrV3uE|+Z&uu%^0kZEyFDL#+XAp>dZ z!mZ|p*miUN3x3_>_iZ!7!OuRXjM?Oz7Gr3nF!jX^0~PlBlyBF68C(yf27V#1p|HAVpFPk!?{Tt}Gw-60h}e1RRDnX^^wm}yhEM19eGggd zH(FeZR9=3j!Dh*5A~kfmJnYPldd8Q&C54fq8orAl|5xw*!4JRrEpNbHuYK~kE~lwy zIzROBR!97r`oy$RFlq7E5_mZ|sOJC19*L0j@O2-v_y%L+Dz7e2sXhFH!-YE(xK z*~BS0Cf`Xi5on_F)y`@9P4XFxKG5TOz;rHk3vwGqXf+3DffC6}rttGzDYWc&yjQU6 zdbxUxX;lRAW3Y&EyD9VyF6cok(6D_?{AbzRkk@}C z0Rxpjs(@m|Baeb#+I}Dboez7$0d;=el;Ayo)k?{ZEaGC zmOpax7bH>X$G-!-wAo&^Yrp5dKy5(OQtkYRdh>Tv=l(^U6LL#tbI#y3TJAW+Hk46w zzm)XH+4HvT?pJTI1$Cmbh3n__aSIYVn7*h(#}3Uq#p#EY_BnYwHK|k+AG$IM|Aa1X z#7+{w#!}M^<-QTkr1p9^{KGg;rM`hS$2&MTx3zbz`G1s_im1=kQX@P+V7e*_0ca+n z)*NabpKj;(UW?zoMSqLRGS~LKuN{o*YcoagtUq{c{;=r_)%OM6kqoVrOMQg!YQ+KR z)LSfOb0LlR@#qd~kF7muqDv^8Fm9<+i`sidS|+g}>@C^ zRrJ<{bZCmq)C6N$x-K)99ac;=lS*vF%qJE;hY`3nUG$dMr?%v|$3xc;q#mVjKBx%S z?^gXz<5u-&09o|hisWj-^dop{%`bdTYAVE}%8Eyn`%YUG8zOBZ=vM61-dxmD0u>6R zH3HoY8eX3#riD?#NRa%$LJcUvuDOA5D4tg?SK*4^lFt;q#Rg?uBFqL8D3tngZry6Cn4;$}(hDKUXPmPn^H)5871sal7N+lyp_DS=^?Q zuj)h(8n!oht@?mk2aVkMY;x6omJQ6#+G?FZ$8;4MDr_ot)wx={Bh{7xQuf&nPLId@ z1ZD3Zk+nw0r>-+xMFf@M{exnKG*hTBV~Lv$dyFlb)z$$Qr;&($W7?z_^1;ROgpjKb z+vzr$(sbDVE2#eqehg=Ujr*=0*LqIOwf_9$(hb9%y+zK2KJBlI1`n&kC$93>`3C_I zYAWR&j@#jM^N#Q#gF_JpatStK5xh?&!A_}uUppMyo8-E3_!&>&|4dLj6={gZxiuB5 z1AK7~m`4mu2~IK~RISBRk8HgP-^xHNrH2~rsf}_p_hadRI=$DFN`FK0YV4_txJZ+-y<51BEf}W&2_H=u#Kz$v6*4R5w_#1o%bFU^fE^<|it6XsnBM8-M zHPUC~n98CV-;P`vBHMof7&^gXzpPVPRokN7RRA}8?D?Q@5}8(K zAsRZ+r@{JVH_kZMI%HY_gGOu_i$}z9O@QTK`T1h$z}(8hT$zXoX3-*0Pg*2iGgNJ& z?-N*Q?gqjo+yBpO)5<>3>KSA_Zx2u?HFJWI>!&Ez11&vw?4?lpV7bam^{rr8Qj3EM zV_wNSrE_tYsf*IVX-g?_baX+Cd_QW*7EpfOL^abT0p{W}7M|%|hi$6fv*nhnU{3Kw z0&J+mzm^G4IS-_nGR{mv80xy;1fcFjpxW&UOhJSKzS(&KV-pRDIx6c<3AqP?gt+d2 z55?jZhX_g|TMwH7{%69#ktJx zN=t8uBcj#E<0*oklf0x!RJn`CgP(!7b8E-v4wOUaDe;CfK7(L*A9o%ZZ76fW@Hiry zd;B|T@n(>ogsYPF%aAYGpVv_hqFRcQW{twI4-UYqb(3_QOrHO(_1`(AH1h3EnnVtf|_r7-{*QjqmB{4VO zjTGe)X1_Jmdd?cmG1T4Gl-2BqN9SZIHWtR^<4^QY z^n4i>Q$ie_x5hbfV<(p_rPM5cqOh3yh}U<~RraR*Y7o+C$gG=N=4<$1nI@L{%)P5_ z!s<}$RbEQCs&l_r^G`mthvC;v_yh1;<5r8j=JCcZaaO(>xv(z)E$E;d{I!(Y)OTp(VKZD25Ay8^ z33p!DfFP0<5?Gk-nmW#ok0zSLK6q;{HK-G2ZjATqGQtCgOzB&5BvuswyS$rPf zde(MjXAN<6yLGg(%lr*M(KqH;VdIo~1C6PeR7t3J0ou4?H=KNf4Fg)Y&SP*`VJ%^X z0*(`?60=8!)|AM>G0!JqI-3M4;9*-=Br6hv$}jNV{wt}o$+MS z8ko_*hK~RM{|N%zuzTl|63l z3b_EXb2ko{WNvYR6-Co-BibMJ*1L!kAyfGPM~x?W>34I-0Fh}_27O^ikLcf)j5zhs z0hrBHCw$IE`$xTS(>3^Da(bp{n7pYAJbzu`y!POR` zFpS`DU;e|@YtJXUe=3g7)Km*S7y!kEPxk5E1CRI+Z2azA=tFt>?`DYU6 zJ}^Zt{p`y^;DG;pi2h(%!zPlu@DG|s2GjBh$gaV_*CkFLa<3CFR;snwOHg4PAB%52Vy*L_l|638 zUsq&gB?jmGEG%A7i#2`=A-g-MRCRV%M+BXnTbL_7h!mn|IS>Lk*}-d9e;F@_FO|Lr z#8QtWtuTt`$t4bl2|X1S`?M-7?KP!g$%pY?Q^kvC4w)KEX(D@n@t$h|{o)EBq0n>U zPLU!ffkC$8JWWj+j1>r{n<|4Jh@)GaoT-0_B9AfHtrQIEqPedKz(ThSm;rL&Kg}lv zs=k}WA02eaZIXg;UvmFGFXps+@REFzyGfYc8~Qy+^2;V zorEYsl$|DJ3|MbA+?GHLu?3Susa&avKYtC9`!nL|f{noQE3Plp1R1`H#B;1Tr#VdK=Gwp}i|YVc1pzhRv(kzRivAVFG{5txe>YN|Wo2MqwE@wY|NTD{tZwC@ z!Mw!ZIpp;BDUAabQ@t+~+6%1@UaGY^E7Nu79FUcZx|#|jkbvORV<3p^>o#Deup%Gh za4vTEql%UbL36cmbCcIw5PPYfq4+&GKzRH1?K%e!Ub=Yk;KgGHfotoFix3jH6jn6w^QGocuGi)Xc%ig=3V7SD?M6 zENXN#Qng*v@y*I_#%((WeWUwiauj9^%BoyVct(^x!##vg%jV{7Q9W0fP>IWdUYBPi z%pr&1J;iLvaQomRnX4Gu@LbBxzi1FK$$C|6V-yh{;aA$9X^u#Wq`;qQBp<}C)I;k| zVILoZ_Y^=hJW&gE$Zk%)hH-_laRl`ceC;f4Fb#zD`gRCD=Wcs%`b#bFS0b;(4LTeU z)1i}+6?jodbe@i~MNi0dMWpnXUTj%W7g6K)LXm1qRq3J!+6q1)=5j8!kDRs6W=fZ*A^rFjHn7-l`AJTvgY7S-|QioQ;915gMW_HYo z8oxG3({NRDYo<4zaEfsCXR4qUyGkLd_73{yWT`Pf!G4w3+2w&z?rHa6!TbI4IntRT z6_(~C&#^>O;LtI>pXcJp_UN9|ZISD6OLIPj#sxkd^9`9U7n(O~0aQ(9(lH}*gY=gO<`D(OZlCG5Uq z>($@y(yp3Y>-fF#tW%BlJZxW3p4k2Z!ipPI9*cTj+vifSS9)8?6|csClFWxA* zPxbbb!1@loeC~_KPS~2BwalPohT0k`6qZH%hpz^|DBGCLGd!(CH|i-bdS4#i{({1c z8zSi;!fIYNuH^RS=GJmwf;ZxTxpbnmfqh<&wKB#0`~_Z5AAYqrWLufQ^Vk)mADXDL zeTSKvSH|s{_LQ=0AyqDg_uANqk347#O+>@YvAL2NyKpsd z3As)Ndt?Wq_3B>5NUVLgr*;Q0HO-xdGCSyuG4kbu4_Z}5;hwK?PO3Dm;kU4l;J}Yj zP#-v~18Lk;!mp4KREw~fsSX9GgO;8vI+`Mg*{sqV#BF~pI$%)ZAqA1DW{4$;WzlvO z2rrou=+4H7;qj*@4r6^MF_m_>F>#-hFx9Dxj+KWC|7_LK75Bxv_qO$U3HsP=7Ru2zx7^lch$HTr5_B6&j~Pq1M>NL^2hNx2jrU zU(&tuLQt*H-6ZVgg`%#2!oh(0I`lb3wXHc zyJw9aw7adY$Km#lVYxmx`)uciOL^_6>-rrCaZkK^$EDaRu5b$}PuN0cFPpTl&n#lh z5oY7r8Z2Pe-fc%8vQ+#d&|B&t8F^><{0?#0L{V-^>O)5Cu1w1iBfilsW32xaB>Q*& zDlFzDc+z?*@+}?lZj`E-cn;yE)3gkx%qx}{5i2L# zTo1yey#SsP68;nyJXN+9RAP9S!VBWb+?*Qk zo!p)5=3j8O)}`aTj*J_xVaW&+iuZ^D1)S%tb*<|gyuCbntCAErsFJFs>S-KBw=^Xt zlZUHs*%vZvGr>*^gO|NUMAL97p^;k#tRjG#Qi<&^Ion1N;i86A#}f7w`2=@l+2oe+ z4FwdE33Cy}9M)U3pGhcbHwcs0qw^tw7%e%i*7-ZID|GE<6|(@9;5=Y&`n(FkUK1m@ zMy6LopPxr-hIw)^nP<})S9oG5$!y5eAb93B?ScbT*NnHC&ApT;&?Cxn(~@7115Iq@ zkgLw5EIOuay01g`Ve={|12X=K>LKH)5a6vjW$t6ukryC>sr}m`k`!NR47RB+2p0~1 z6p^KE&&n1y_K9hG%}qO}I3@&ub=tc+Q_M=;JJw3|-4Xq1Jz=i$Jrj|cZ$wRyMC=Rwv(z{7GB7k|q7gXK1ltA-LeCQW2MT^uKT# z^#^n{9`_$>nO&;^J5z1J+gKPPbGQp?*q!>&dMb8qwqa?&#@OikV4WDPv5F z*j|FgIde#;b6<{_to?jC?ZLZ-b%Ri?{v*rj->O{=Rv%~eRSu>s3c<|AxVKhX zgRylsaxjh+)*EKb`rtn8VK=Ew4`_>1qzSKw~(b2;(M2tILW^ zFJ_hA;na;B?x&0}ZG($Q8~V%%+On)eH(?mq@ws2FyOpgQqdLDYu52aHP?DgF?_(3B zc28OSGO1X6@nWqJa&ZGNsl0DE9?eo7s!Q7}Q`2L9zDt9Z@xhafK!N5i_~QNm1L>mm zH4F1l=jNmt+s%i@m-CrPvW0#z=Zx)=wfo?O-~jw(vvVY^=0nRWBe(G_jr}CF2JQIG zxfXO4!Y9djgXCkUOEd~{ic_e@Iv6f1zR8I$>5?vw)%H;**dne z+8xI@lEJvawiH!rnwhkxw)VUQ-Q1!2t%5yW_FTQl=baPbG4^3rd%4e}vvliTRwD*w3z;}S$B8{#50825%!Ivq zD;f0|i|%l^>;ltth|9udPxn>~yiaj{z3=ja`(4ahbXon0o&w<~j&SYd^RoPX-=}7W zwN83|R3cwIy+hN$rQIZ%$md-L^RwfRb+cc79ax;Vk8|pRIjtLtO$1=8w^?fA zPkbygt0<@l%&d2w9-d^gJkO3*AaI>WGO{;jfu&_v0FIm2=*5zSu;z5Ivq3$_Y!P(qXUQSVEUQdT7vXzHldp&jEHpHC0 z^y(Ygikf^+&?or54)(Ds_MGI!6gmg+ZAdPLg zO(|BIme#Ed&WWrwx3AQEu*J8OW9ByoJ2jull7`8R?5k?916jAy2q!+Suc6f?Tk$kb zVh>S?ao8&SXv_74cIDXm5pjYP@p5ilSK_jx4TjY$RSf)dZN6o{< zYEzUe=~YRQO=0o}nA6e3MG38|s+-5*JroY}IgDf5c_E*fkK-33v07=Rh3UIS^Sd{l zZ5<7Cs?Y7Blfp93?KSPk!8;?nCJYvttFq(0Wsk`Z?)bZf1i(TA3LOrixSwP4I_lpp zRMRwDXq<|ewvvE?aqUR22{4 zeJ0vVv!eB?gU{=!g*zs~q%aNnu(>jKA^C2_AGRH-S$)dZdn+>lUP1b(gI(uU4E;$V zo*LnyYZakpd&itS9m~q~{1##gw7Dh~-1I3wA4kFPtqNDq%$frFhS!~p{!X1fmIc1U zuH-(R50IV&Wh@1qxi|C|yTFvxv&u7bqc%@3SBS9OS7(vvrie9^Xk4YyAFeH4=gxK* zKh&9bhd~Ht+R;*-Za}a_fw7qj0cNYLpi7Ata+R8VZBK5duc@jjep1!33|wZH3VEoJ zZw16^o6Y%(-XjHn$k(R+g^S1YTpdRI_)>jR%abOy#iwKt=v2<1%%z;PrvAhx?5*UE zkPzIsz4QJ&TwHXXC_Gn^zQI2|GQ;-CrgT^6m}eBO2LZvEfWKp?;qw6{vfJiMcELC! z(xjJqDB^S+i8&ob93$U8pRjP0^ieObO2-Yntg{~N8~H(5AlGDT!0HYlF3g3O&8BJ4 zw)Z5l-ejPXs20e=G94s+9(&sXHg92;YT?+N$`#~w%rKS-Tz#(qbI6coW2!!aQj%6Z z(vO`;^~hf}4IySAh#4B!eGL6{iv8W$SbrQ&t-E~na{tx2yXFUOjDPQn!^1zXZLZ<4 z4zm3kAEFM?uEt?5OlSQhGlyt(!%5>~bl=D0f;Al3Q_;nHdKxuo+q$!|@K6vjN8~d3 zE~vkv#OsoanNM($-2S*SUz4ULi6u2VR)=soFazk}Sy}EUE_2qv=?l{5I|i78sdZ3~ z^I3qMqNZ$cZ@GyRz0{F~BzU4ZXQcrGJusaVF@#=ACKSJhbp&9Lszz_;uxe^b`X?A6 z52Duk*!P)@C|V_a?uYG(QWu`jLWgBbZ>Obv3>K!ZQN%})QVgVl5FR40h$|~O+__|P zLD~@Zr2FHtG3MjkU08ml4$58!v)0n5>Is+h+5(0d(R6U!*VY(sr#XHRhN<5oL>KAD z#oRnQzuI}_fyFTpD^o+~7I+(esy$D)P1M{LEmk@GxH=JBhmg#A=<@xPpkTdEMZ&u4 zPsjxLLhsL4P3_a8x@&{P@R-O?m*fd`m?euNR}zNDh&!$nnd#cz1?y-WsB~iMqP`HT z!iFnw%Rc(!bkK5u5n?;%^;P>H+Lp8q+BMcJ0D2^SyX9 zb33qr0&3)In|&J)WA5gZkwk-NuM!aWpt-_#SAC1<07G_CsP9AWJl}*V=? zb1JAUWJi&R`&O`=zJ*FUbBAqr-uUeKIkiZWEr!E?S@!iH=@wMUtFU_)xw^u649CV1 zkqL~)V0u5*Co4F4O?*2)0qP9XkYBWsjj~T@uT-p7tcDvnPT~3 z;*Ci6ubBK}F)Cbl>L)*iP_qr&b4Z!$gvaWVb=6&hww{i+;MNrNk^bDGC#t*|Q0nwq zn!vM+RQB}+GJ{m-7FBkM{WD$BVWT42ZG(kD_5GmxPuskvGUiuihh6!FUgEfQ_zI>Y zdSWc$jO8$5YG8}+lwy@Gk;ew2)w8_*^s%R6K8=gK{Kh_c*C%GGr^lq$MR8-=w>nWQ z3$^Q+J#VbX^X;~%G)3mdlqjtj(@eG8;^*Tasb-`EEjhr4^WaruLnVc8ScCF_ESy3& zy)2t^i9gGI>KDm+Xk6Ou_x9B=E~4~Y*)&x6F?SfnoQ?tQZFnL;0x2>cq94eeY!8^K zsXT;L+t)_R1~?axuK-o1M1)p$u1}@ESa+v7=2$R?XpWx_<&YngR9DR z_q+Y*TB%=5JU7?ou}Vj`YYI?RP$;#m+sR*$MK{$K*9!vyBk+=D5~~gS*{R8-CV@w9 zkF|b4XGrL25#s%)J|xc#_V{R8Pl0O0yO`rjRR5b$5NE|i{%efXgEyCyTF;sz%X!ll z*uan}G4ujwiWK5H9zQPBaH!3;w1+i0V>G7bB!XhD6#3c{)@@z7O7-SiUVc?`%r`Es z?M~bl*@Op(DZAh^1UK6(L1%%Xn z${>(e_)+iS*Uxe0^D#hbYJAimwX&~MJ(R|Qha#guE}Q70rS|o8%6uo|6FbC`YVOn7 zQI2beKV686i9UNL8aUs`zKievO>4=5*3)p5*JWv*0iUSsOPQ7cjV(8*s?H}?k3Das|$cLcvR||IUSG0`5wTUsm$wof}Xfz=%!T=93UM( zjo;7bYRdG2P<%H=m3IBgj6E`UdgyuNv!qv^QAXSFbaz{g1u|^A`Pb}nJ=CuKF?oZ^ z^sxZ)s$Kd9q*fzLiRa31s>9sfp&6%Og)f9hj9XJjxCL7oGV?gFY$x=5db0gl?#uX{ zuC*F_)$~C7469h0@=BJvH{@!W)JXnMJ|PaYEv*=txJp7IISt#}Fe+g2?Nq;CC_52e zrrpUHv8M?~e=87?{)MvI&_d9R^_1ZBW?#qyd`U5NR653LSjSJA`NmlNN_h~V5JF`Cr?OkniysA4- zX;%LNF3soGg`hL}t9lofBxoE&_MHUBZLUmWYneQR zMWBnHg5B4&b4IXgm0(_$hqU3q{w+ zY5gdOk3iKQ!LF)mWE!?(2WOtS^1DjSXT{*gCl(U2jcnXATQoB5BkTjq97PU9UQ)}T z1L$jLZYIcY_}U?RM5oiiy&4svy35(lTn!}oo0>P3;W8A{+O z(8PXzqGGgudq}Mx<$UrCJ!w_T+vH08Xy4NUZFs7i1f}O;RHk%1Xqc^;WGwuKwf#r0Zb^BU>A~>O>Gln0dGN4%= zYVz+OPV)o#`%FWs%ikQ!i(|&CkhiSLf8!_ zY9kR<#LLd_MJ)5jmsRP1<`0%|G*Pqq%-%TB4oeL5kMxyK>n*z}H``_RU%M>3o^waP zT&%q5xS&tc2q7-4aaPLGFU> z(R4R`aK4|hVLz#nwR_6!XX@eEaC#&W-`Z@RZo9ab^$2=;vyE8qG@WEOPKM0C)!YBx z-4H%Ie)4s}cn&sL3N<}#GbW1&-Y;AKtTV==S|_?iUurf!O!@=eSI66d)4{OsEMHzw z)fFV`ajD;m$l5_hn9tT11^|G6fSOU+tNXZfib#=9eii-rjOH4B;S(MC1$|F0n@k}@J)L6<#MrmO)HX#= zJk~jH|Bs7~9)%w-b>Cs-{QmmXfb9zGOk3tIL;7td6^ZZjjo*`RTW=jf$x;5W-m3zx zA!5>Y^(lE~LI!O5t>@f@VR4}>5$O}BoBEL%dfYKfit@aB^W!;tRvizCr?L6gKA1m=evU_!+eyDTTO8Qm49EB5mloL~7t*YD=hhuX_g44XBGHs}66uDX& zUV7V{iYR+waWpr3`i>Y6^gx@L*rv*_D;)Y5SmY;ZfSfdl&gl=*rk>PYkne7gSsYVy z80vokg#{|zsB&tO)$qUXK9N<~Ty3Y(m!NJqWRVOt)asxnub3^P=9WxW9EKv-E+-c1 za%0BH(ViiUA$>3F^ULxhx5hU~2XvC_*4pw`{XE11pKE%Jpe`yWF<6nm`Q_+YF5A z?0S*fw3!utcNy>AKMgO=_VBERDSO+yDvFTrQ~6n?9&by~8~#m9lCgI3+_f)JZNPZj ziT{{!>*+RMc9vS5e{FAtu)Qgl3L9-`CR|D6FnEUF?(d=YF~JBWYoir6rkpENRmdK9EM6g3MOX!nB_rI#8IG2M4`3GO093F)}JWI3fZBRR2)$}5o zPxii*!B?DVGkGy0!f}o7tO*WGH64Kq_2B*TV#{oPq$ktEsb|+WNE8`FXb50GgnoDf zqST@AwtAh-W$J-}saUnc#MubE|I$~r{-F8$gX-dStVcB{{bUx}YYr?_v$(Ow)d_u= z$nbCO9OiA~J$RcDy0Ls+{YUV^eyOJIbg+cbf*B%FY*DubQN7Op>Q|o}-XPddn32=^ zkxsnecdJ@^#7}&`zclT>IdH$p#HvtebxOg`K6Ac$);<)dQ{usn>1lF*$m(}{CGzI7 z#?m&SC(SPR-b3!3B)lmQL4|+IJXYP%AG$n}ct3)UQHcA>lbKQ2z3d+DP~+MOjZfWq zN@FxzWUo-oGE-!6_UKAlWq9PalDG!_;*|q4Rk3)L{olJG0TqGbqC$azsr&W#PJz3I zZ}wTxUr3H1LRVSg{l$+C`vb8&3%q$iXKm(`#Eku%OgM=W>s0}Kx=)4``j>U1yo7Ae z%!i(!27#LA{!)i+46HBmRg+npYXEf+ssYVCXdvzI=A;Ia#?yycO}gRzp_}W~yRJI_ zDTRqeyAgfH5JYNaL$T-b5W#^_Q>+?cwY^!E+ka=HtL3k)eJEht=3#J z@Z&b)cpnOO1v~(RM1t6_(yYp0j(FF$r8H*wX=Q%U!np zBP8z4$Aqp8UDx37Ff$s%BOv%j)ra7J2|0ZJi9E_j0Izn=JOn~(QKjf(wB2yHWAymzs-p z9^O|T0KmHtU>cZ8MlQXHx+mm8gI%cGg6XMA8%9am>^NU@2smqT?O;er-%^qCT#!C``@Cy_ELif`8F5d^9rsnA4(aM%>z!M%YtZr@4#&Qn4{f(gTJ#=URp_Vpfw zyQjr(F)!Avj!9l}wiS8;2@F47$#)4^8VEm-ksQZ&^%=d+!7^uD8xh3n&64XA+s@%H z_%8vc3@dC8H~ygJ>yf!jD3T`5^%wZm)X?8}_2jda@dq*@%GU9%pAQHLvB!O8Kc9|6 z!%Hj97>)g59?@n`S$%?#3*THCI)IpaW-K>i1+}+6xLs&QqIzwvIh>y&qRWOsC|b}VLR&mmI0Dct^hxphm~rR|qW{dU{bZju8|#Lo9nw|neAa>fP}T769YV(UfNiy% z2;Ja@YDZeC?>j718e10pWotvVu7SywO+xPC!cT%ta!4Z{;=)2tTpzOFEA>JWi@kpR z!FdiIzi<1Y5O9P8$i@;kr3T-qxom6z=dC@{@5x<6c`x07`f(mb07k9Gx&wh=$&U2{ zh<{R4wYjmnX}-y+W?d8#a#bQ5s{=;(f3F{MTOaa&uhLgrHspUV&%rI9_dnN@)c<$8 z|M%?v2NJ+H|9=C;L<;c3|9jQ=Y*0ft<})@~8;Ab)H$qS5M}CA?6GQ*?fBpQ&`ebm|*6;Lh8r>+ocJHtM13+}Ah5!Hn literal 0 HcmV?d00001 diff --git a/setup.py b/setup.py index bf17899..8f3fa72 100755 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ from typing import List, Tuple from setuptools import setup, find_packages, Command +import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension try: @@ -57,7 +58,7 @@ def run(self): def install(use_cuda, use_nccl): ext_libs = [] if pf.system() == 'Linux': - ext_args = ['-Wno-sign-compare', '-Wno-unused-but-set-variable', '-Wno-terminate', '-Wno-unused-function', '-Wno-strict-aliasing'] + ext_args = ['-w'] elif pf.system() == 'Darwin': ext_args = ['-mmacosx-version-min=10.13'] else: @@ -80,7 +81,7 @@ def install(use_cuda, use_nccl): setup( name='tutel', - version='0.3', + version='0.4', description='An Optimized Mixture-of-Experts Implementation.', url='https://github.com/microsoft/Tutel', author='Microsoft', @@ -138,16 +139,14 @@ def install(use_cuda, use_nccl): }, ) -if int(os.environ.get('NO_CUDA', 0)) == 1: - print('Installing without CUDA extension..') - install(use_cuda=False, use_nccl=False) -else: +if (torch.version.cuda or torch.version.hip) and int(os.environ.get('NO_CUDA', 0)) == 0: try: + print('Try installing with NCCL extension..') install(use_cuda=True, use_nccl=True) except: print('Try installing without NCCL extension..') - try: - install(use_cuda=True, use_nccl=False) - except: - print('Try installing without CUDA extension..') - install(use_cuda=False, use_nccl=False) + install(use_cuda=True, use_nccl=False) +else: + print('Installing without CUDA extension..') + install(use_cuda=False, use_nccl=False) + diff --git a/tutel/custom/antares_ops.h b/tutel/custom/antares_ops.h new file mode 100644 index 0000000..4b92021 --- /dev/null +++ b/tutel/custom/antares_ops.h @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Plugin reference: https://github.com/microsoft/antares + +#include +#include + +#include +#include +#include + +#undef AT_ASSERTM +#define AT_ASSERTM(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + ::c10::detail::torchInternalAssertFail(__func__, "torch_ext.hpp", \ + static_cast(__LINE__), #cond " INTERNAL ASSERT FAILED at " C10_STRINGIZE("torch_ext.hpp") ":" C10_STRINGIZE( \ + __LINE__) ", please report a bug to PyTorch. ", c10::str(__VA_ARGS__)); \ + } + +#undef TORCH_WARN +#define TORCH_WARN(...) \ + ::c10::warn(::c10::Warning( \ + ::c10::UserWarning(), \ + {__func__, "torch_ext.hpp", static_cast(__LINE__)}, \ + WARNING_MESSAGE_STRING(__VA_ARGS__), \ + false)); + + +#if defined(__linux__) +typedef ssize_t llong; +#else +typedef long long llong; +#define ssize_t llong +#endif + +#define __RUNTIME_MODE__ +#include "backend.hpp" + +#include +#include + +#if !defined(Antares) +#define Antares CUDA +#endif + +#define ANTARES_DEV c10::DeviceType::Antares + +static c10::Device get_device() { static c10::Device dev = c10::Device(ANTARES_DEV, getenv("LOCAL_RANK") ? std::atoi(getenv("LOCAL_RANK")) : 0); return dev; } +static bool is_verbose = false; + +#define DEBUG_FUNC(x) // printf("[DEBUG] ::%s\n", x) + +std::string read_file(const std::string &path) { + std::ifstream t(path, std::ios::binary); + if (t.fail()) + return ""; + std::string _((std::istreambuf_iterator(t)), std::istreambuf_iterator()); + return _; +} + +#define OP_LOADER "OP_LOADER" + +std::string get_ops_root() { + static std::string ops_root; + if (ops_root.size() == 0) { + auto root_path = getenv(OP_LOADER); + AT_ASSERTM(root_path != nullptr && *root_path != 0, OP_LOADER " is not set, please configure this environment variable correctly."); + ops_root = root_path; + } + return ops_root; +} + +const char* get_backend_type() { +#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_AMD__) + return "c-cuda"; +#else + return "c-rocm"; +#endif +} + + +namespace antares { +namespace ops { + +at::Tensor call(const void *key, const std::vector &ts, const std::vector &ps, bool allow_non_contiguous = false, size_t key_length = 0, int output_index = -1) { + DEBUG_FUNC((const char*)key); + + struct kernel_object { + std::vector symbol; + std::vector args; + + int output_exist; + torch::Dtype output_dtype; + std::vector output_shape; + + std::string entry_name, name; + }; + + auto key_id = key; + + static std::unordered_map kernel_dict[16]; + const auto &curr_dev = ts.size() > 0 ? ts[0].device() : get_device(); + auto &kernels = kernel_dict[curr_dev.index()]; + + if (is_verbose) { + std::string fname = key_length ? std::string((const char*)key, (const char*)key + key_length) : std::string((const char*)key); + TORCH_WARN("AutoRT on executing new function: `", fname, "`"); + } + + auto it = kernels.find(key_id); + if (it == kernels.end()) { + // load function module + + auto ssplit = [](const std::string &str, const std::string &sub, bool allow_empty = false) -> std::vector { + std::vector ret; + int it = 0, next; + while (next = str.find(sub, it), next >= 0) { + if (next > it || allow_empty) + ret.push_back(str.substr(it, next - it)); + it = next + sub.size(); + } + if (it < str.size() || allow_empty) + ret.push_back(str.substr(it)); + return ret; + }; + std::string fname = key_length ? std::string((const char*)key, (const char*)key + key_length) : std::string((const char*)key); // key.toStringRef(); + // TORCH_WARN("AutoRT on registering new function: `", fname, "`"); + + auto image = read_file(get_ops_root() + "/" + fname + ".mod"); + AT_ASSERTM(!image.empty(), "Failed to load operator module: ", fname.c_str()); + + int pos = image.find("||"); + AT_ASSERTM(pos >= 0, " "); + auto meta = image.substr(0, pos); + auto data = image.substr(pos + 2); + + // load symbol + std::unordered_map th; + std::unordered_map fn; + auto metas = ssplit(meta, "|"); + + if (metas.size() >= 5) { + AT_ASSERTM(!strcmp(metas[4].data(), get_backend_type()), "External operator module `", fname.c_str(), "` is not designed for current backend."); + } + + auto fentry = metas[0]; + for (auto sect: ssplit(metas[1], ";")) { + auto kvs = ssplit(sect, "="); + th[kvs[0]] = std::atoll(kvs[1].c_str()); + } + for (auto sect: ssplit(metas[2], ";")) { + auto kvs = ssplit(sect, "="); + fn[kvs[0]] = kvs[1]; + } + + kernels[key_id] = {}, it = kernels.find(key_id); + it->second.symbol = ab::moduleGetFunction(ab::moduleLoad(data), fentry, th); + it->second.entry_name = fentry; + it->second.name = fname; + + // load argument config + for (int i = 0; ; ++i) { + auto jt = fn.find("arg_" + std::to_string(i)); + if (jt == fn.end()) + break; + auto options = ssplit(jt->second, ":", true); + llong input_id = (options[0] == "") ? ~0 : std::atoll(options[0].c_str()); + llong second_ref = std::atoll(options[1].c_str()); + llong use_fp32 = (options[2] == "float32"); + + llong comb = (use_fp32 << 63) | (second_ref << 32) | ((unsigned int)input_id); + it->second.args.push_back(comb); + } + + // load output config + auto o_type = ssplit(fn["o_type"], ":", true); + if (o_type[0] == "infer") { + it->second.output_exist = -1; + + static std::unordered_map key_to_dtype = { + {"int8", torch::kInt8}, {"int16", torch::kInt16}, {"int32", torch::kInt32}, {"int64", torch::kInt64}, + {"bfloat8", at::kFloat8_e5m2}, {"float8", at::kFloat8_e4m3fn}, {"bfloat16", torch::kBFloat16}, {"float16", torch::kFloat16}, {"float32", torch::kFloat32}, {"float64", torch::kFloat64}, + {"bfloat2x16", at::kComplexHalf}, {"float2x16", at::kComplexHalf}, {"float2x32", at::kComplexFloat}, + }; + + auto dtype_it = key_to_dtype.find(o_type[1]); + if (dtype_it != key_to_dtype.end()) + it->second.output_dtype = dtype_it->second; + else + it->second.output_dtype = at::kComplexDouble; + + for (auto dim: ssplit(o_type[2], ",")) { + if (dim[0] == '#') + it->second.output_shape.push_back(~std::atoll(dim.c_str() + 1)); + else + it->second.output_shape.push_back(std::atoll(dim.c_str())); + } + } else { + AT_ASSERTM(o_type[0] == "exist", "`o_type` is not recognized: ", o_type); + it->second.output_exist = std::atoll(o_type[1].c_str()); + } + } + + auto &prop = it->second; + + std::vector krnl_args; + for (int i = 0; i < ts.size(); ++i) { + if (ts[i].device().type() != ANTARES_DEV) { + std::string error_msg = "\nThe " + std::to_string(i + 1) + "-th argument of `antares.ops." + prop.name + "(...)`is not a CUDA tensor."; + AT_ASSERTM(0, error_msg); + } + if (!allow_non_contiguous) + AT_ASSERTM(ts[i].is_contiguous(), "Not contiguous tensor for custom kernel"); + krnl_args.push_back((void*)ts[i].data_ptr()); + } + + int output_exist = prop.output_exist; + if (output_index >= 0) + output_exist = output_index; + if (output_exist == -1) + krnl_args.push_back(nullptr); // placeholder for output + + size_t param_offset = krnl_args.size(); + + // construct argument values + for (auto it: prop.args) { + llong use_fp32 = (it & (1LL << 63)); + it = (it & ~(1LL << 63)); + auto *ids = (unsigned int*)⁢ + if (ids[0] != ~0) + krnl_args.push_back((void*)ts[ids[0]].size(ids[1])); + else if (!use_fp32) + krnl_args.push_back(*(void**)ps[ids[1]].data_ptr()); + else { + float fp32val[2]; + fp32val[0] = (float)*(double*)ps[ids[1]].data_ptr(); + krnl_args.push_back(*(void**)fp32val); + } + } + krnl_args.push_back(nullptr); + + std::vector shape = prop.output_shape; + for (int i = 0; i < shape.size(); ++i) + if (shape[i] < 0) + shape[i] = (ssize_t)krnl_args[param_offset + (~shape[i])]; + + torch::Tensor output; + if (output_exist == -1) { + output = torch::empty(shape, torch::TensorOptions().dtype(prop.output_dtype).device(curr_dev)); + krnl_args[param_offset - 1] = (void*)output.data_ptr(); + } else + output = ts[output_exist]; + + ab::launchKernel(prop.symbol, krnl_args, nullptr); + return output; +} + +} // namespace ops +} // namespace antares diff --git a/tutel/custom/backend.hpp b/tutel/custom/backend.hpp new file mode 100644 index 0000000..146947f --- /dev/null +++ b/tutel/custom/backend.hpp @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Plugin reference: https://github.com/microsoft/antares + +#if defined(_WIN64) +#include +#include +#endif + +#if !defined(CHECK_OK) +#define CHECK_OK(x) ((x) ? 1 : (fprintf(stderr, "[CheckFail] %s:%d\n", __FILE__, __LINE__), exit(1), 0)) +#endif + +#if !defined(__RUNTIME_MODE__) +#define GET_STREAM() ((CUstream)stream) +#else +#include +#include +#include +#define GET_STREAM() at::cuda::getCurrentCUDAStream().stream() +#endif + +#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_AMD__) +#include +#include +#else +#include +#include +#define cuInit hipInit +#define cuMemAlloc hipMalloc +#define cuMemFree hipFree +#define cuModuleLoad hipModuleLoad +#define cuModuleLoadData hipModuleLoadData +#define cuModuleUnload hipModuleUnload +#define cuModuleGetFunction hipModuleGetFunction +#define cuLaunchKernel hipModuleLaunchKernel +#define cuMemAllocHost hipHostMalloc +#define cuMemFreeHost hipHostFree +#define cuStreamSynchronize hipStreamSynchronize +#define cuCtxSynchronize hipDeviceSynchronize +#define cuMemcpyHtoDAsync hipMemcpyHtoDAsync +#define cuMemcpyDtoDAsync hipMemcpyDtoDAsync +#define cuMemcpyDtoHAsync hipMemcpyDtoHAsync +#define CUdeviceptr hipDeviceptr_t +#define CUmodule hipModule_t +#define CUfunction hipFunction_t +#define CUevent hipEvent_t +#define cuEventElapsedTime hipEventElapsedTime +#define cuEventCreate hipEventCreateWithFlags +#define cuEventDestroy hipEventDestroy +#define cuEventRecord hipEventRecord +#define CUcontext long +#define cuDevicePrimaryCtxRetain(x, y) (*(x) = (CUcontext)((long)(y)), 0) +#define cuCtxSetCurrent(x) hipSetDevice((int)(x)) +#define cuCtxGetCurrent(x) hipGetDevice((int*)(x)) +#define cuCtxGetDevice(x) hipGetDevice((int*)(x)) +#define CUstream hipStream_t +#define nvrtcGetCUBIN hiprtcGetCode +#define nvrtcGetCUBINSize hiprtcGetCodeSize +#endif + + +namespace ab { + + static int _current_device; + static std::unordered_map> _cached_memory; + + int init(int dev) { + static bool _retained = false; + if (_retained) + return _current_device; + _retained = true; + + CUcontext ctx; + if (dev < 0) { + if (0 == cuCtxGetDevice(&_current_device)) { + cuDevicePrimaryCtxRetain(&ctx, _current_device); + cuCtxSetCurrent(ctx); + return _current_device; + } + dev = getenv("LOCAL_RANK") ? std::atoi(getenv("LOCAL_RANK")) : 0; + } +#if !defined(__RUNTIME_MODE__) + setenv("CUDA_VISIBLE_DEVICES", std::to_string(dev).c_str(), 1); +#else + _current_device = dev; +#endif + if (0 != cuInit(0) || 0 != cuDevicePrimaryCtxRetain(&ctx, _current_device) || 0 != cuCtxSetCurrent(ctx)) + throw std::runtime_error("GPU device is not found.\n"); + return _current_device; + } + + void finalize() { + } + + inline size_t compute_slotsize(size_t value) { + if (value >= (1LL << 30)) + return value; + value -= 1; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + value += 1; + return value; + } + + void* alloc(size_t byteSize, const std::vector &shape, const std::string &dtype, const std::string &name) { + init(-1); + + byteSize = compute_slotsize(byteSize); + auto &it = _cached_memory[byteSize]; + if (it.size()) { + auto dptr = it.back(); + it.pop_back(); + return dptr; + } + void *dptr = nullptr; + if (byteSize) + CHECK_OK(0 == cuMemAlloc((CUdeviceptr*)&dptr, byteSize)); + else + dptr = (void*)1LU; + return dptr; + } + + void release(void *dptr, size_t byteSize) { + byteSize = compute_slotsize(byteSize); + auto &it = _cached_memory[byteSize]; + it.push_back(dptr); + } + + void* moduleLoad(const std::string &binary) { + init(-1); + const char* data = binary.data(); + CUmodule hmod = nullptr; + CHECK_OK(0 == cuModuleLoadData(&hmod, data)); + return hmod; + } + + std::vector moduleGetFunction(const void *hModule, const std::string &fname, const std::unordered_map &threads) { + auto query = [&](const std::string &axis, long defval = 1) -> void* { + auto it = threads.find(axis); + if (it == threads.end()) + return (void*)defval; + return (void*)(long)it->second; + }; + + CUfunction hfunc = nullptr; + CHECK_OK(0 == cuModuleGetFunction(&hfunc, (CUmodule)hModule, fname.c_str())); + std::vector fdata = { hfunc, query("blockIdx.x"), query("blockIdx.y"), query("blockIdx.z"), query("threadIdx.x"), query("threadIdx.y"), query("threadIdx.z") }; + + void *item = query("$", 0); + if (item) { + fdata.push_back(item); + fdata.push_back(query("$$", 1)); + + for (int i = 0; ; ++i) { + void *item = query("$" + std::to_string(i), 0); + if (!item) + break; + fdata.push_back(item); + } + } + return fdata; + } + + void launchKernel(std::vector &hFunc, const std::vector &krnl_args, void *stream) { + std::vector pargs(krnl_args.size()); + for (int i = 0; i < krnl_args.size(); ++i) + pargs[i] = (void*)&krnl_args[i]; + + if (hFunc.size() > 7) { + long attrs = (long)hFunc[8]; + for (int i = 9; i < hFunc.size(); ++i) { + long val = (long)hFunc[i]; + if (val == -1) continue; + + auto ptr = (int*)pargs[i - 9 + (long)hFunc[7]]; + attrs *= (val > 0) ? ((*ptr + val - 1) / val) : (*ptr * (-val)); + } + hFunc[1] = (void*)attrs; + if (!hFunc[1]) return; + } + + CHECK_OK(0 == cuLaunchKernel((CUfunction)hFunc[0], (long)hFunc[1], (long)hFunc[2], (long)hFunc[3], (long)hFunc[4], (long)hFunc[5], (long)hFunc[6], + 0, GET_STREAM(), (void**)pargs.data(), nullptr)); + } + + void memcpyHtoD(void *dptr, void *hptr, size_t byteSize, void *stream) { + CHECK_OK(0 == cuMemcpyHtoDAsync((CUdeviceptr)dptr, hptr, byteSize, (CUstream)stream)); + } + + void memcpyDtoD(void *dptr, void *dptr0, size_t byteSize, void *stream) { + CHECK_OK(0 == cuMemcpyDtoDAsync((CUdeviceptr)dptr, (CUdeviceptr)dptr0, byteSize, (CUstream)stream)); + } + + void memcpyDtoH(void *hptr, void *dptr, size_t byteSize, void *stream) { + CHECK_OK(0 == cuMemcpyDtoHAsync(hptr, (CUdeviceptr)dptr, byteSize, (CUstream)stream)); + } + + void synchronize(void *stream) { + CHECK_OK(0 == cuStreamSynchronize((CUstream)stream)); + } + + void* recordTime(void *stream) { + CUevent hEvent; + CHECK_OK(0 == cuEventCreate(&hEvent, 0)); + CHECK_OK(0 == cuEventRecord(hEvent, (CUstream)stream)); + return hEvent; + } + + double convertToElapsedTime(void *hStart, void *hStop) { + CHECK_OK(0 == cuCtxSynchronize()); + + float ms; + CHECK_OK(0 == cuEventElapsedTime(&ms, (CUevent)hStart, (CUevent)hStop)); + CHECK_OK(0 == cuEventDestroy((CUevent)hStart)); + CHECK_OK(0 == cuEventDestroy((CUevent)hStop)); + return ms * 1e-3; + } +} diff --git a/tutel/custom/custom_kernel.cpp b/tutel/custom/custom_kernel.cpp index 617580a..32960b3 100644 --- a/tutel/custom/custom_kernel.cpp +++ b/tutel/custom/custom_kernel.cpp @@ -42,6 +42,8 @@ #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #if defined(USE_GPU) +#include "antares_ops.h" + namespace jit { inline static std::string file_read(const char *path) { @@ -460,6 +462,13 @@ static torch::Tensor& nccl_stream_acquire(torch::Tensor &tensor, int idx) { return tensor; } +static torch::Tensor warp_x_add_allreduce_y_f16(const torch::Tensor &x, const torch::Tensor &t) { + AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL"); + auto stream = at::cuda::getCurrentCUDAStream(); + ncclAllReduce(t.data_ptr(), t.data_ptr(), t.numel(), ncclFloat16, ncclSum, (ncclComm_t)shared_nccl_comm, stream); + return x + t; +} + static void batch_all_to_all_v(const std::vector &ins, const std::vector &outs, const torch::Tensor &in_sizes_, const torch::Tensor &out_sizes_) { AT_ASSERTM(shared_world_size > 0, "Failed to initialize Shared NCCL"); @@ -888,8 +897,109 @@ torch::Tensor warp_sparse_bmm_infer(const torch::Tensor &x, const torch::Tensor return y; } +torch::Tensor warp_gemv_nt_fp16xfp8_block_scal(const torch::Tensor &x, const torch::Tensor &w, const torch::Tensor &scal) { + CHECK_CUDA(x); + return antares::ops::call("gemv_nt_fp16xfp8_block_scal", {x.view({-1}).view(torch::kInt32), w.view(torch::kInt16), scal}, {}); +} + +torch::Tensor warp_deepseek_r1_static_gating_f16( + const torch::Tensor &x, + const torch::Tensor &gate_moe, + const torch::Tensor &gate_bias, + const ::std::optional &top_v_out_, + const ::std::optional &top_k_out_) { + CHECK_CUDA(x); + AT_ASSERTM(gate_moe.size(0) == 256, "Deepseek R1 requires 256 experts for gating."); + AT_ASSERTM(x.numel() == x.size(-1), "Batch size > 1 not enabled in this version."); + + auto top_v_out = top_v_out_.has_value() ? top_v_out_.value().view({1, -1}) : torch::empty({1, 8}, torch::TensorOptions().dtype(torch::kFloat32).device(x.device())); + auto top_k_out = top_k_out_.has_value() ? top_k_out_.value().view({1, -1}) : torch::empty({1, 8}, torch::TensorOptions().dtype(torch::kInt64).device(x.device())); + AT_ASSERTM(top_v_out.dtype() == torch::kFloat32 && top_k_out.dtype() == torch::kInt64, "Output tensor space should be float32 for top_scores and int64 for top_ids."); + + auto gate_scores = antares::ops::call("sigmoid_gemm_out", {x.to(torch::kFloat16).view(torch::kInt32).view({-1}), gate_moe.to(torch::kFloat16).view(torch::kInt32)}, {}); + antares::ops::call("deepseek_r1_top_k_f32", {gate_scores.view({1, -1}), gate_bias.to(torch::kFloat32), top_v_out, top_k_out}, {}, false, 0, 3); + return top_k_out; +} + +torch::Tensor warp_deepseek_r1_latent_attn_f16( + const torch::Tensor &data, + const torch::Tensor &key_cache, + const torch::Tensor &val_cache, + const torch::Tensor &rms_att_w, + const torch::Tensor &qkv_a_proj, + const torch::Tensor &qkv_a_proj_scal, + const torch::Tensor &q_a_norm, + const torch::Tensor &kv_a_norm, + const torch::Tensor &q_b_proj, + const torch::Tensor &q_b_proj_scal, + const torch::Tensor &kv_b_proj, + const torch::Tensor &kv_b_proj_scal, + const torch::Tensor &o_proj, + const torch::Tensor &o_proj_scal, + int64_t pos +) { + CHECK_CUDA(data); + auto x = data; + auto xb = antares::ops::call("lnorm_f16", {x, rms_att_w}, {1e-6f}).view({1, 1, -1}); + auto qkv = warp_gemv_nt_fp16xfp8_block_scal(xb, qkv_a_proj, qkv_a_proj_scal).view({1, 1, -1}); + auto q = qkv.narrow(-1, 0, 1536), kv = qkv.narrow(-1, 1536, 512), k_pe = qkv.narrow(-1, 2048, 64); + auto k_pe_out = torch::empty_like(k_pe); + antares::ops::call("rotary_emb_f16", {k_pe.view({-1, 32, 2}), k_pe_out.view({-1, 2, 32})}, {9.210340372f, pos}, false, 0, 1); + q = antares::ops::call("lnorm_f16", {q, q_a_norm}, {1e-6f}); + kv = antares::ops::call("lnorm_f16", {kv, kv_a_norm}, {1e-6f}); + q = warp_gemv_nt_fp16xfp8_block_scal(q, q_b_proj, q_b_proj_scal); + kv = warp_gemv_nt_fp16xfp8_block_scal(kv, kv_b_proj, kv_b_proj_scal); + auto query_states = q.view({1, 1, -1, 192}); + auto q_pe = query_states.narrow(-1, 128, 64).contiguous(); + auto q_pe_out = torch::empty_like(q_pe); + antares::ops::call("rotary_emb_f16", {q_pe.view({-1, 32, 2}), q_pe_out.view({-1, 2, 32})}, {9.210340372f, pos}, false, 0, 1); + antares::ops::call("cache_fill_f16", {q_pe_out, k_pe_out, query_states, key_cache.select(0, pos)}, {128}, false, 0, 3); + antares::ops::call("cache_move_f16", {kv.view({-1, 2, 128}), key_cache.select(0, pos), val_cache.select(0, pos)}, {}, false, 0, 2); + auto key_states = key_cache.narrow(0, 0, pos + 1).unsqueeze(0); + auto value_states = val_cache.narrow(0, 0, pos + 1).unsqueeze(0); + + int n_heads = query_states.size(2); + auto lm = torch::empty({2, n_heads, 64}, torch::TensorOptions().dtype(torch::kFloat16).device(query_states.device())); + + if (pos >= 63) { + auto attn_output = antares::ops::call("self_attn_infer_f16", {query_states.squeeze(0).squeeze(0), key_states.squeeze(0), value_states.squeeze(0), lm}, {0.1352337788608801}); + lm = torch::matmul(antares::ops::call("self_attn_reduce_f16", {lm}, {}).unsqueeze(1), attn_output).view({-1}); + } else { + lm = std::get<0>(at::native::_scaled_dot_product_attention_math(query_states.permute({0, 2, 1, 3}), key_states.permute({0, 2, 1, 3}), value_states.permute({0, 2, 1, 3}), {}, 0, false, {}, 0.1352337788608801)); + } + return warp_gemv_nt_fp16xfp8_block_scal(lm, o_proj, o_proj_scal).view({1, 1, -1}); +} + +torch::Tensor warp_glu_expert_f16xf8_block_scal( + const torch::Tensor &x, + const torch::Tensor &expert_ids, + const torch::Tensor &expert_weight, + const torch::Tensor &moe_gate_up_w, + const torch::Tensor &moe_gate_up_s, + const torch::Tensor &moe_down_w, + const torch::Tensor &moe_down_s) { + + CHECK_CUDA(x); + auto xb = antares::ops::call("half2_gemm_silu_left_mul_right", {x.view({-1}), expert_ids, moe_gate_up_w.view(torch::kInt16), moe_gate_up_s.view(torch::kInt64)}, {}); + return antares::ops::call("half2_gemm_mul_sum", {xb.view(torch::kInt32), expert_weight, expert_ids, moe_down_w.view(torch::kInt16), moe_down_s}, {}).view({-1, 1, moe_down_w.size(1)}); +} + +torch::Tensor warp_lnorm_f16(const torch::Tensor &x, const torch::Tensor &rms_ffn_w, double eps) { + CHECK_CUDA(x); + return antares::ops::call("lnorm_f16", {x, rms_ffn_w}, {eps}).view(x.sizes()); +} + + TORCH_LIBRARY(tutel_ops, m) { m.def("cumsum", warp_cumsum); m.def("sparse_bmm_infer", warp_sparse_bmm_infer); + m.def("lnorm_infer_f16", warp_lnorm_f16); + + m.def("gemv_nt_fp16xfp8_block_scal", warp_gemv_nt_fp16xfp8_block_scal); + m.def("glu_expert_f16xf8_block_scal", warp_glu_expert_f16xf8_block_scal); + + m.def("x_add_allreduce_y_f16", &warp_x_add_allreduce_y_f16); + m.def("deepseek_r1_static_gating_f16", warp_deepseek_r1_static_gating_f16); + m.def("deepseek_r1_attn_f16xf8_block_scal", warp_deepseek_r1_latent_attn_f16); } #endif diff --git a/tutel/system.py b/tutel/system.py index 6c237a2..a7a3c9b 100644 --- a/tutel/system.py +++ b/tutel/system.py @@ -40,6 +40,7 @@ def on_quit(): dist.destroy_process_group() except: pass + os._exit(0) import atexit atexit.register(lambda *args: on_quit())