Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile ImageClassifier example #2915

Merged
merged 13 commits into from
Feb 6, 2024
Merged

Conversation

agunapal
Copy link
Collaborator

@agunapal agunapal commented Jan 31, 2024

Description

This PR shows how to use torch.compile with densenet161.

It also adds a pytest for the same

Fixes #(issue)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Feature/Issue validation/testing

Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

-[ ] 3x speedup with torch.compile

2024-02-03T00:54:31,136 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:6.118656158447266|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:18.77564811706543|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:54:31,155 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.16630400717258453|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921671,c02b3170-c8fc-4396-857d-6c6266bf94a9, pattern=[METRICS]
2024-02-03T00:56:14,808 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_preprocess.Milliseconds:5.9771199226379395|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_inference.Milliseconds:5.8818559646606445|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
2024-02-03T00:56:14,814 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]ts_handler_postprocess.Milliseconds:0.19392000138759613|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706921774,d38601be-6312-46b4-b455-0322150509e5, pattern=[METRICS]
  • pytest
pytest -v test_example_torch_compile.py 
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.13, pytest-7.3.1, pluggy-1.0.0 -- /home/ubuntu/anaconda3/envs/ts_export_aot/bin/python
cachedir: .pytest_cache
rootdir: /home/ubuntu/serve
plugins: cov-4.1.0, mock-3.12.0
collected 1 item                                                                                                                                                                         

test_example_torch_compile.py::test_torch_compile_inference PASSED                                                                                                                 [100%]

==================================================================================== warnings summary ====================================================================================
test_example_torch_compile.py:4
  /home/ubuntu/serve/test/pytest/test_example_torch_compile.py:4: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import packaging

../../../anaconda3/envs/ts_export_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
../../../anaconda3/envs/ts_export_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868
  /home/ubuntu/anaconda3/envs/ts_export_aot/lib/python3.10/site-packages/pkg_resources/__init__.py:2868: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('ruamel')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../../anaconda3/envs/ts_export_aot/lib/python3.10/site-packages/transformers/utils/generic.py:441
  /home/ubuntu/anaconda3/envs/ts_export_aot/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
    _torch_pytree._register_pytree_node(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================= 1 passed, 4 warnings in 38.66s ============================================================================
  • Inference logs
curl http://127.0.0.1:8080/predictions/densenet161 -T ../../image_classifier/kitten.jpg
2024-01-31T22:20:53,900 [INFO ] epollEventLoopGroup-3-1 TS_METRICS - ts_inference_requests_total.Count:1.0|#model_name:densenet161,model_version:default|#hostname:ip-172-31-11-40,timestamp:1706739653
2024-01-31T22:20:53,902 [DEBUG] W-9000-densenet161_1.0 org.pytorch.serve.wlm.WorkerThread - Flushing req.cmd PREDICT repeats 1 to backend at: 1706739653902
2024-01-31T22:20:53,902 [INFO ] W-9000-densenet161_1.0 org.pytorch.serve.wlm.WorkerThread - Looping backend response at: 1706739653902
2024-01-31T22:20:53,904 [INFO ] W-9000-densenet161_1.0-stdout MODEL_LOG - Backend received inference at: 1706739653
2024-01-31T22:21:25,858 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]HandlerTime.Milliseconds:31953.46|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706739685,f88bb520-8532-4d55-8723-c4d3b6435c73, pattern=[METRICS]
2024-01-31T22:21:25,858 [INFO ] W-9000-densenet161_1.0-stdout MODEL_METRICS - HandlerTime.ms:31953.46|#ModelName:densenet161,Level:Model|#hostname:ip-172-31-11-40,requestID:f88bb520-8532-4d55-8723-c4d3b6435c73,timestamp:1706739685
2024-01-31T22:21:25,859 [INFO ] W-9000-densenet161_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - result=[METRICS]PredictionTime.Milliseconds:31953.75|#ModelName:densenet161,Level:Model|#type:GAUGE|#hostname:ip-172-31-11-40,1706739685,f88bb520-8532-4d55-8723-c4d3b6435c73, pattern=[METRICS]
2024-01-31T22:21:25,859 [INFO ] W-9000-densenet161_1.0-stdout MODEL_METRICS - PredictionTime.ms:31953.75|#ModelName:densenet161,Level:Model|#hostname:ip-172-31-11-40,requestID:f88bb520-8532-4d55-8723-c4d3b6435c73,timestamp:1706739685
2024-01-31T22:21:25,859 [INFO ] W-9000-densenet161_1.0 ACCESS_LOG - /127.0.0.1:54538 "PUT /predictions/densenet161 HTTP/1.1" 200 31960
2024-01-31T22:21:25,859 [INFO ] W-9000-densenet161_1.0 TS_METRICS - Requests2XX.Count:1.0|#Level:Host|#hostname:ip-172-31-11-40,timestamp:1706739685
2024-01-31T22:21:25,860 [INFO ] W-9000-densenet161_1.0 TS_METRICS - ts_inference_latency_microseconds.Microseconds:3.1956773879E7|#model_name:densenet161,model_version:default|#hostname:ip-172-31-11-40,timestamp:1706739685
2024-01-31T22:21:25,860 [INFO ] W-9000-densenet161_1.0 TS_METRICS - ts_queue_latency_microseconds.Microseconds:130.342|#model_name:densenet161,model_version:default|#hostname:ip-172-31-11-40,timestamp:1706739685
2024-01-31T22:21:25,860 [DEBUG] W-9000-densenet161_1.0 org.pytorch.serve.job.RestJob - Waiting time ns: 130342, Backend time ns: 31958446968
2024-01-31T22:21:25,861 [INFO ] W-9000-densenet161_1.0 TS_METRICS - QueueTime.Milliseconds:0.0|#Level:Host|#hostname:ip-172-31-11-40,timestamp:1706739685
2024-01-31T22:21:25,861 [INFO ] W-9000-densenet161_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 31955
2024-01-31T22:21:25,861 [INFO ] W-9000-densenet161_1.0 TS_METRICS - WorkerThreadTime.Milliseconds:4.0|#Level:Host|#hostname:ip-172-31-11-40,timestamp:1706739685


{
  "confectionery": 0.01269740890711546,
  "African_chameleon": 0.011943917721509933,
  "sturgeon": 0.011671576648950577,
  "Bedlington_terrier": 0.009965566918253899,
  "entertainment_center": 0.009582704864442348
}

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?


In this example , we use the following config

```yaml
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

echo "pt2 : {backend: inductor, mode: reduce-overhead}" > model-config.yaml

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"lynx": 0.0012969186063855886,
"plastic_bag": 0.00022856894065625966
}
```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so people would be integrating torch.compile to see speedups
I also believe the label values change a tiny bit so you should show the expected speedups

I mention this because I'm not sure if densenet speedups were there

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great feedback. I added a section on perf measurement. 3x speedup with torch.compile. Makes it very compelling now!

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looks good, just a few minor nits please fix before merge

@agunapal agunapal added this pull request to the merge queue Feb 6, 2024
Merged via the queue into master with commit 88eca54 Feb 6, 2024
13 checks passed
@chauhang chauhang added this to the v0.10.0 milestone Feb 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants