diff --git a/Huggingface2CoreML.ipynb b/Huggingface2CoreML.ipynb
new file mode 100644
index 0000000..a349f9d
--- /dev/null
+++ b/Huggingface2CoreML.ipynb
@@ -0,0 +1,424 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "801db364",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "scikit-learn version 1.2.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import coremltools as ct\n",
+    "import clip\n",
+    "import numpy as np\n",
+    "from PIL import Image"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "26f7dcff",
+   "metadata": {},
+   "source": [
+    "# 1. Export TextEncoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "8f89976b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:286: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n",
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:294: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n",
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:326: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast\n",
+    "\n",
+    "\n",
+    "model_id = \"openai/clip-vit-base-patch32\"\n",
+    "model = CLIPTextModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
+    "tokenizer = CLIPTokenizerFast.from_pretrained(model_id)\n",
+    "model.eval()\n",
+    "\n",
+    "example_input = tokenizer(\"a photo of a cat\", return_tensors=\"pt\")\n",
+    "example_input = example_input.data['input_ids']\n",
+    "\n",
+    "traced_model = torch.jit.trace(model, example_input)"
+   ]
+  },
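+  {
+   "cell_type": "markdown",
+   "id": "traced-output-note",
+   "metadata": {},
+   "source": [
+    "The traced text model returns a tuple `(text_embeds, last_hidden_state)`; these become the two CoreML outputs (`embOutput`, `embOutput2`) in the conversion below. A minimal sketch to confirm the shapes, assuming `traced_model` and `example_input` from the previous cell:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "traced-output-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (sketch): inspect the traced model's two outputs.\n",
+    "with torch.no_grad():\n",
+    "    text_embeds, last_hidden_state = traced_model(example_input)\n",
+    "print(text_embeds.shape)        # projected text embedding, (1, 512)\n",
+    "print(last_hidden_state.shape)  # per-token hidden states, (1, seq_len, 512)"
+   ]
+  },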
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "c87abd71",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Tuple detected at graph output. This will be flattened in the converted model.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Converting PyTorch Frontend ==> MIL Ops:   0%|          | 0/830 [00:00<?, ? ops/s]Converting PyTorch Frontend ==> MIL Ops:  98%|█████████▊| 811/830 [00:00<00:00, 2660.33 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!\n",
+      "Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 828/830 [00:00<00:00, 2589.67 ops/s]\n",
+      "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 104.58 passes/s]\n",
+      "Running MIL default pipeline: 100%|██████████| 66/66 [00:09<00:00, 7.24 passes/s]\n",
+      "Running MIL backend_mlprogram pipeline: 100%|██████████| 11/11 [00:00<00:00, 147.47 passes/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_seq_length = 76  # with the original model's max_seq_length of 77 the validation fails (see the end of the notebook); 76 works fine with the app\n",
+    "text_encoder_model = ct.convert(\n",
+    "    traced_model,\n",
+    "    convert_to=\"mlprogram\",\n",
+    "    minimum_deployment_target=ct.target.iOS16,\n",
+    "    inputs=[ct.TensorType(name=\"prompt\",\n",
+    "                          shape=[1,max_seq_length],\n",
+    "                          dtype=np.int32)],\n",
+    "    outputs=[ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
+    "             ct.TensorType(name=\"embOutput2\", dtype=np.float32)],\n",
+    "    )\n",
+    "text_encoder_model.save(\"TextEncoder_float32_test.mlpackage\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "617e4e6b",
+   "metadata": {},
+   "source": [
+    "## Validate export precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "fd6af02a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Load the model\n",
+    "model = ct.models.MLModel('TextEncoder_float32_test.mlpackage')\n",
+    "\n",
+    "# Choose a tokenizer; here we use the CLIP tokenizer\n",
+    "text = clip.tokenize(\"a photo of a cat\")\n",
+    "text = text[:,:max_seq_length]\n",
+    "\n",
+    "# # Or use CLIPTokenizerFast\n",
+    "# text = tokenizer(\"a photo of a cat\", return_tensors=\"pt\", padding=\"max_length\", max_length=max_seq_length)\n",
+    "# text = text.data['input_ids'].to(torch.int32)\n",
+    "\n",
+    "predictions = model.predict({'prompt': text})\n",
+    "out = traced_model(text)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "c29d0a98",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "PyTorch TextEncoder ckpt out for \"a photo of a cat\":\n",
+      ">>> tensor([ 0.1555,  0.0733, -0.2448, -0.2212, -0.1934,  0.2052, -0.3175, -0.7824,\n",
+      "        -0.1816,  0.1943], grad_fn=<SliceBackward0>)\n",
+      "\n",
+      "CoreML TextEncoder ckpt out for \"a photo of a cat\":\n",
+      ">>> [ 0.15560171  0.0732335  -0.24512495 -0.22117633 -0.19336982  0.20523793\n",
+      " -0.3182205  -0.78206545 -0.18144566  0.19457956]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"PyTorch TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", out[0][0, :10])\n",
+    "print(\"\\nCoreML TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", predictions['embOutput'][0, :10])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3c0d9c70",
+   "metadata": {},
+   "source": [
+    "There is some loss of precision, but it is still acceptable."
+   ]
+  },
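+  {
+   "cell_type": "markdown",
+   "id": "precision-gap-note",
+   "metadata": {},
+   "source": [
+    "A quick way to quantify that gap; a minimal sketch using `out` and `predictions` from the cells above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "precision-gap-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compare the PyTorch and CoreML text embeddings numerically (sketch).\n",
+    "pt_emb = out[0].detach().numpy().ravel()\n",
+    "cm_emb = predictions['embOutput'].ravel()\n",
+    "print(\"max abs diff:\", np.abs(pt_emb - cm_emb).max())\n",
+    "print(\"cosine similarity:\", float(pt_emb @ cm_emb / (np.linalg.norm(pt_emb) * np.linalg.norm(cm_emb))))"
+   ]
+  },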
+  {
+   "cell_type": "markdown",
+   "id": "ca182b4a",
+   "metadata": {},
+   "source": [
+    "# 2. Export ImageEncoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "68521589",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:286: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n",
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:326: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import CLIPVisionModelWithProjection, CLIPProcessor\n",
+    "\n",
+    "model_id = \"openai/clip-vit-base-patch32\"\n",
+    "model = CLIPVisionModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
+    "processor = CLIPProcessor.from_pretrained(model_id)\n",
+    "model.eval()\n",
+    "\n",
+    "img = Image.open(\"love-letters-and-hearts.jpg\")\n",
+    "example_input = processor(images=img, return_tensors=\"pt\")\n",
+    "example_input = example_input['pixel_values']\n",
+    "traced_model = torch.jit.trace(model, example_input)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "304ae7b0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Tuple detected at graph output. This will be flattened in the converted model.\n",
+      "Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 698/700 [00:00<00:00, 1354.45 ops/s]\n",
+      "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 74.55 passes/s]\n",
+      "Running MIL default pipeline: 100%|██████████| 66/66 [00:14<00:00, 4.54 passes/s]\n",
+      "Running MIL backend_mlprogram pipeline: 100%|██████████| 11/11 [00:00<00:00, 220.59 passes/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# CoreML's ImageType applies y = scale * x + bias (per channel) to the 0-255 input,\n",
+    "# so scale = 1 / (255 * std) and bias = -mean / std fold the CLIP normalization into\n",
+    "# the model; scale is a single scalar, so std[0] stands in for all three channels.\n",
+    "bias = [-processor.image_processor.image_mean[i]/processor.image_processor.image_std[i] for i in range(3)]\n",
+    "scale = 1.0 / (processor.image_processor.image_std[0] * 255.0)\n",
+    "\n",
+    "image_input_scale = ct.ImageType(name=\"colorImage\",\n",
+    "                                 color_layout=ct.colorlayout.RGB,\n",
+    "                                 shape=example_input.shape,\n",
+    "                                 scale=scale, bias=bias,\n",
+    "                                 channel_first=True,)\n",
+    "\n",
+    "image_encoder_model = ct.convert(\n",
+    "    traced_model,\n",
+    "    convert_to=\"mlprogram\",\n",
+    "    minimum_deployment_target=ct.target.iOS16,\n",
+    "    inputs=[image_input_scale],\n",
+    "    outputs=[ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
+    "             ct.TensorType(name=\"embOutput2\", dtype=np.float32)],\n",
+    "    )\n",
+    "\n",
+    "image_encoder_model.save(\"ImageEncoder_float32.mlpackage\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f3c5008e",
+   "metadata": {},
+   "source": [
+    "## Validate export"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "759bb57d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "PyTorch ImageEncoder ckpt out for jpg:\n",
+      ">>> tensor([ 0.2489, -0.0874, -0.2821, -0.1859, -0.4129, -0.4474,  0.3093,  0.3759,\n",
+      "         1.0730,  0.1773], grad_fn=<SliceBackward0>)\n",
+      "\n",
+      "CoreML ImageEncoder ckpt out for jpg:\n",
+      ">>> [ 0.24853516 -0.07476807 -0.30615234 -0.20996094 -0.4177246  -0.42163086\n",
+      "  0.34545898  0.47338867  1.1083984   0.1899414 ]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torchvision.transforms as transforms\n",
+    "\n",
+    "image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')\n",
+    "imgPIL = Image.open(\"love-letters-and-hearts.jpg\")\n",
+    "imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)\n",
+    "\n",
+    "img_np = np.asarray(imgPIL).astype(np.float32)  # (224, 224, 3)\n",
+    "img_np = img_np[np.newaxis, :, :, :]  # (1, 224, 224, 3)\n",
+    "img_np = np.transpose(img_np, [0, 3, 1, 2])  # (1, 3, 224, 224)\n",
+    "img_np = img_np / 255.0\n",
+    "torch_tensor_input = torch.from_numpy(img_np)\n",
+    "transform_model = torch.nn.Sequential(\n",
+    "    transforms.Normalize(mean=processor.image_processor.image_mean,\n",
+    "                         std=processor.image_processor.image_std),\n",
+    ")\n",
+    "\n",
+    "predictions = image_encoder.predict({'colorImage': imgPIL})\n",
+    "out = traced_model(transform_model(torch_tensor_input))\n",
+    "print(\"PyTorch ImageEncoder ckpt out for jpg:\\n>>>\", out[0][0, :10])\n",
+    "print(\"\\nCoreML ImageEncoder ckpt out for jpg:\\n>>>\", predictions['embOutput'][0, :10])"
+   ]
+  },
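+  {
+   "cell_type": "markdown",
+   "id": "similarity-note",
+   "metadata": {},
+   "source": [
+    "Optionally, the two exported encoders can be scored against each other the same way they are used at query time; a minimal sketch, assuming the `.mlpackage` files saved above and `predictions` from the previous cell:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "similarity-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Score the CoreML image embedding against a few text prompts (sketch).\n",
+    "text_encoder = ct.models.MLModel('TextEncoder_float32_test.mlpackage')\n",
+    "image_emb = predictions['embOutput'].ravel()\n",
+    "image_emb = image_emb / np.linalg.norm(image_emb)\n",
+    "\n",
+    "for prompt in [\"love letters and hearts\", \"a photo of a cat\", \"a city skyline at night\"]:\n",
+    "    tokens = clip.tokenize(prompt)[:, :max_seq_length].numpy().astype(np.int32)\n",
+    "    text_emb = text_encoder.predict({'prompt': tokens})['embOutput'].ravel()\n",
+    "    text_emb = text_emb / np.linalg.norm(text_emb)\n",
+    "    print(f\"{prompt}: {float(image_emb @ text_emb):.4f}\")"
+   ]
+  },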
+  {
+   "cell_type": "markdown",
+   "id": "aac310d4",
+   "metadata": {},
+   "source": [
+    "## Test result for max_seq_length = 77"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "f24ec713",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/jing/anaconda3/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:294: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
+      "  if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n",
+      "Tuple detected at graph output. This will be flattened in the converted model.\n",
+      "Converting PyTorch Frontend ==> MIL Ops:   0%|          | 0/830 [00:00<?, ? ops/s]Converting PyTorch Frontend ==> MIL Ops:  88%|████████▊ | 732/830 [00:00<00:00, 1993.68 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!\n",
+      "Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 828/830 [00:00<00:00, 2065.24 ops/s]\n",
+      "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 129.31 passes/s]\n",
+      "Running MIL default pipeline: 100%|██████████| 66/66 [00:10<00:00, 6.39 passes/s]\n",
+      "Running MIL backend_mlprogram pipeline: 100%|██████████| 11/11 [00:00<00:00, 182.44 passes/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "PyTorch TextEncoder ckpt out for \"a photo of a cat\":\n",
+      ">>> tensor([ 0.1555,  0.0733, -0.2448, -0.2212, -0.1934,  0.2052, -0.3175, -0.7824,\n",
+      "        -0.1816,  0.1943], grad_fn=<SliceBackward0>)\n",
+      "\n",
+      "CoreML TextEncoder ckpt out for \"a photo of a cat\":\n",
+      ">>> [-0.066312    0.17878246  0.40718645 -0.08806399  0.26841    -0.22685118\n",
+      "  0.2679821  -1.7103907  -0.33836532  0.28941655]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast\n",
+    "\n",
+    "\n",
+    "model_id = \"openai/clip-vit-base-patch32\"\n",
+    "model = CLIPTextModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
+    "tokenizer = CLIPTokenizerFast.from_pretrained(model_id)\n",
+    "model.eval()\n",
+    "\n",
+    "example_input = tokenizer(\"a photo of a cat\", return_tensors=\"pt\")\n",
+    "example_input = example_input.data['input_ids']\n",
+    "\n",
+    "traced_model = torch.jit.trace(model, example_input)\n",
+    "\n",
+    "max_seq_length = 77  # the original model's full sequence length; as shown below, the CoreML output no longer matches the PyTorch output\n",
+    "text_encoder_model = ct.convert(\n",
+    "    traced_model,\n",
+    "    convert_to=\"mlprogram\",\n",
+    "    minimum_deployment_target=ct.target.iOS16,\n",
+    "    inputs=[ct.TensorType(name=\"prompt\",\n",
+    "                          shape=[1,max_seq_length],\n",
+    "                          dtype=np.int32)],\n",
+    "    outputs=[ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
+    "             ct.TensorType(name=\"embOutput2\", dtype=np.float32)],\n",
+    "    )\n",
+    "\n",
+    "# Choose a tokenizer; here we use the CLIP tokenizer\n",
+    "text = clip.tokenize(\"a photo of a cat\")\n",
+    "text = text[:,:max_seq_length]\n",
+    "\n",
+    "predictions = text_encoder_model.predict({'prompt': text})\n",
+    "out = traced_model(text)\n",
+    "\n",
+    "print(\"PyTorch TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", out[0][0, :10])\n",
+    "print(\"\\nCoreML TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", predictions['embOutput'][0, :10])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}