
Update torch.export load with new api #2906

Merged: 10 commits from issues/update_torch_export_aot_compile into master on Jan 29, 2024

Conversation

agunapal (Collaborator)

Description

This PR integrates the new torch._export.aot_load API and removes the existing implementation from TorchServe.
It also updates the pytest tests to support batch sizes greater than 15.
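
A hedged sketch of the new load path (file name, input shape, and device handling here are illustrative, not copied from this PR):

```python
import torch
import torch._export  # private module that hosts aot_load in recent builds

# Path to a model previously compiled with torch._export.aot_compile
# (the file name is purely illustrative).
model_so_path = "resnet18_pt2.so"
device = "cuda" if torch.cuda.is_available() else "cpu"

# aot_load returns a callable wrapping the AOTInductor-compiled model,
# replacing the hand-rolled loader TorchServe previously carried.
model = torch._export.aot_load(model_so_path, device)

with torch.inference_mode():
    output = model(torch.randn(4, 3, 224, 224, device=device))
```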

Relevant PRs:
  • pytorch/pytorch#116152
  • pytorch/pytorch#117610
  • pytorch/pytorch#117948

Fixes #(issue)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Feature/Issue validation/testing

  • pytest
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.13, pytest-7.3.1, pluggy-1.0.0 -- /home/ubuntu/anaconda3/envs/ts_aot/bin/python
cachedir: .pytest_cache
rootdir: /home/ubuntu/serve
plugins: cov-4.1.0, mock-3.12.0
collected 2 items                                                                                                                                                                        

test_torch_export.py::test_torch_export_aot_compile PASSED                                                                                                                         [ 50%]
test_torch_export.py::test_torch_export_aot_compile_dynamic_batching PASSED                                                                                                        [100%]

==================================================================================== warnings summary ====================================================================================
test_torch_export.py:4
  /home/ubuntu/serve/test/pytest/test_torch_export.py:4: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import packaging

../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
  /home/ubuntu/anaconda3/envs/ts_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('ruamel')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:242
  /home/ubuntu/anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:242: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.
    interpolation: int = Image.BILINEAR,

../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:288
  /home/ubuntu/anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:288: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
    interpolation: int = Image.NEAREST,

../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:304
  /home/ubuntu/anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:304: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.
    interpolation: int = Image.NEAREST,

../../../anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:321
  /home/ubuntu/anaconda3/envs/ts_aot/lib/python3.10/site-packages/torchvision/transforms/_functional_pil.py:321: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.
    interpolation: int = Image.BICUBIC,

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================= 2 passed, 7 warnings in 53.21s ==========================================================================
  • TorchServe logs
curl http://127.0.0.1:8080/predictions/res18-pt2 -T ../../image_classifier/kitten.jpg
2024-01-25T01:56:51,903 [INFO ] epollEventLoopGroup-3-2 TS_METRICS - ts_inference_requests_total.Count:1.0|#model_name:res18-pt2,model_version:default|#hostname:ip-172-31-60-100,timestamp:1706147811
2024-01-25T01:56:51,904 [DEBUG] W-9010-res18-pt2_1.0 org.pytorch.serve.wlm.WorkerThread - Flushing req.cmd PREDICT repeats 1 to backend at: 1706147811904
2024-01-25T01:56:51,904 [INFO ] W-9010-res18-pt2_1.0 org.pytorch.serve.wlm.WorkerThread - Looping backend response at: 1706147811904
2024-01-25T01:56:51,905 [INFO ] W-9010-res18-pt2_1.0-stdout MODEL_LOG - Backend received inference at: 1706147811
2024-01-25T01:56:51,959 [INFO ] W-9010-res18-pt2_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]HandlerTime.Milliseconds:54.22|#ModelName:res18-pt2,Level:Model|#type:GAUGE|#hostname:ip-172-31-60-100,1706147811,c29e56d9-5bda-48f0-8feb-c927ddaa7753, pattern=[METRICS]
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0-stdout MODEL_METRICS - HandlerTime.ms:54.22|#ModelName:res18-pt2,Level:Model|#hostname:ip-172-31-60-100,requestID:c29e56d9-5bda-48f0-8feb-c927ddaa7753,timestamp:1706147811
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 ACCESS_LOG - /127.0.0.1:43686 "PUT /predictions/res18-pt2 HTTP/1.1" 200 57
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 TS_METRICS - Requests2XX.Count:1.0|#Level:Host|#hostname:ip-172-31-60-100,timestamp:1706147811
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]PredictionTime.Milliseconds:54.42|#ModelName:res18-pt2,Level:Model|#type:GAUGE|#hostname:ip-172-31-60-100,1706147811,c29e56d9-5bda-48f0-8feb-c927ddaa7753, pattern=[METRICS]
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0-stdout MODEL_METRICS - PredictionTime.ms:54.42|#ModelName:res18-pt2,Level:Model|#hostname:ip-172-31-60-100,requestID:c29e56d9-5bda-48f0-8feb-c927ddaa7753,timestamp:1706147811
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 TS_METRICS - ts_inference_latency_microseconds.Microseconds:56087.913|#model_name:res18-pt2,model_version:default|#hostname:ip-172-31-60-100,timestamp:1706147811
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 TS_METRICS - ts_queue_latency_microseconds.Microseconds:107.492|#model_name:res18-pt2,model_version:default|#hostname:ip-172-31-60-100,timestamp:1706147811
2024-01-25T01:56:51,960 [DEBUG] W-9010-res18-pt2_1.0 org.pytorch.serve.job.RestJob - Waiting time ns: 107492, Backend time ns: 56705835
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 TS_METRICS - QueueTime.Milliseconds:0.0|#Level:Host|#hostname:ip-172-31-60-100,timestamp:1706147811
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 56
2024-01-25T01:56:51,960 [INFO ] W-9010-res18-pt2_1.0 TS_METRICS - WorkerThreadTime.Milliseconds:0.0|#Level:Host|#hostname:ip-172-31-60-100,timestamp:1706147811
{
  "tabby": 0.4096629321575165,
  "tiger_cat": 0.34670528769493103,
  "Egyptian_cat": 0.13002873957157135,
  "lynx": 0.023919515311717987,
  "bucket": 0.01153216976672411

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

agunapal requested review from mreso and lxning on January 25, 2024
Comment on lines 16 to 17
# The below config is needed for max batch_size = 16
# https://github.com/pytorch/pytorch/pull/116152
Collaborator:

Can we check the batch_size input to reflect the max batch_size limitation?

Collaborator Author (agunapal):

Thanks. Updated the comments. There is no limitation on batch_size.
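
For context, a minimal sketch (assuming torch.export.Dim and torch._export.aot_compile as in recent PyTorch builds; the model, bounds, and device are illustrative) of how an upper-bounded dynamic batch dimension is declared when the model is AOT-compiled:

```python
import torch
import torch._export
import torchvision.models as models

model = models.resnet18(weights=None).eval().to("cuda")
example_inputs = (torch.randn(2, 3, 224, 224, device="cuda"),)

# Declare the batch dimension as dynamic with an explicit upper bound so a
# single compiled artifact can serve any batch size up to that bound.
batch_dim = torch.export.Dim("batch", min=2, max=32)

# aot_compile lowers the exported graph with AOTInductor and returns the
# path of the generated shared library (.so).
so_path = torch._export.aot_compile(
    model,
    example_inputs,
    dynamic_shapes={"x": {0: batch_dim}},
)
```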

mreso (Collaborator) left a comment:

Overall LGTM, left some minor comments.

 True
-if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1")
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2")
Collaborator:

This would set PT_230_AVAILABLE to true if a 2.2.3 is released. Is that correct?

Contributor:

@ankithagunapal Is this check needed due to API changes between PT 2.2.0 and the latest nightlies, which are 2.3.x?

Collaborator Author (agunapal):

Yes, this needs the latest nightlies. The new API got merged only in the last week or so. Based on the current release cadence, it's 2.3.0 after 2.2.2 (two patch releases after every major release), but I will keep this in mind in case that changes.
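
A hedged sketch of the version gate under discussion (the flag name follows the thread; the exact TorchServe code may differ):

```python
import torch
from pkg_resources import packaging  # same import style as the test file

# True for 2.3.0 and for nightlies such as "2.3.0.dev20240126", which also
# compare greater than 2.2.2. As raised in review, a hypothetical 2.2.3 patch
# release would satisfy this check as well.
PT_230_AVAILABLE = packaging.version.parse(torch.__version__) > packaging.version.parse(
    "2.2.2"
)
```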

@@ -123,6 +123,6 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
data["body"] = byte_array_type

# Send a batch of 16 elements
Collaborator:

nit: outdated comment

Collaborator Author (agunapal):

Thanks. Updated.

@@ -123,6 +123,6 @@ def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
data["body"] = byte_array_type

# Send a batch of 16 elements
-result = handler.handle([data for i in range(15)], ctx)
+result = handler.handle([data for i in range(32)], ctx)
Collaborator:

How about parameterizing the test and testing both?

Collaborator Author (agunapal):

Thanks for catching this. I changed it.
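
A minimal, self-contained illustration of the parameterization suggested above (not TorchServe's actual test; the request payload and assertion are placeholders):

```python
import pytest


# Run the same test body once per batch size so both values are exercised.
@pytest.mark.parametrize("batch_size", [15, 32])
def test_dynamic_batching(batch_size):
    # Stand-in for handler.handle([data for i in range(batch_size)], ctx):
    # a real test would send batch_size copies of the request to the handler
    # and check that one prediction comes back per request.
    batch = [{"body": b"dummy-image-bytes"} for _ in range(batch_size)]
    results = [len(item["body"]) for item in batch]
    assert len(results) == batch_size
```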

@@ -53,10 +53,10 @@
)
PT2_AVAILABLE = False

-if packaging.version.parse(torch.__version__) > packaging.version.parse("2.1.1"):
-    PT220_AVAILABLE = True
+if packaging.version.parse(torch.__version__) > packaging.version.parse("2.2.2"):
Collaborator:

ditto

chauhang (Contributor) left a comment:

@agunapal Have you also tested with torch.export save/load? It should speed things up for the actual inference part. If that works, we can provide a separate preparation script for saving the exported model.

agunapal (Collaborator Author):

> @agunapal Have you also tested with torch.export save/load? It should speed things up for the actual inference part. If that works, we can provide a separate preparation script for saving the exported model.

Hi @chauhang, I tried that first in #2812.
Currently, torch.export by itself is not useful for inference since it generates an FX graph.
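
For illustration, a hedged sketch of the distinction being drawn (API names as in recent PyTorch releases; the model and file names are illustrative):

```python
import torch
import torchvision.models as models

model = models.resnet18(weights=None).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# torch.export produces an ExportedProgram wrapping an FX graph. It can be
# saved and loaded, but it is a graph representation rather than a compiled
# inference artifact.
exported = torch.export.export(model, example_inputs)
torch.export.save(exported, "resnet18_exported.pt2")
loaded = torch.export.load("resnet18_exported.pt2")
out = loaded.module()(*example_inputs)

# The path taken in this PR instead compiles the exported graph ahead of time
# with AOTInductor (torch._export.aot_compile) and loads the resulting shared
# library with torch._export.aot_load for serving.
```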

agunapal added this pull request to the merge queue on Jan 29, 2024
Merged via the queue into master with commit cfb4285 on Jan 29, 2024; 13 checks passed
agunapal deleted the issues/update_torch_export_aot_compile branch on January 29, 2024 at 21:48
chauhang added this to the v0.10.0 milestone on Feb 27, 2024
Labels: none · Projects: none · 4 participants