[autoparallel] handled illegal sharding strategy #1728

Merged

Conversation

@FrankLeeeee (Contributor) commented Oct 18, 2022

What is the problem?

There is currently no systematic way to filter out illegal sharding strategies generated by the StrategyGenerator. Illegal sharding strategies can be produced when:

  1. a tensor cannot be sharded at all, e.g. a tensor of shape [1, 4] cannot be sharded along dimension 0
  2. the logical sharding spec cannot be applied to the physical shape, e.g. given logical shape [4, 4], physical shape [1, 2, 2, 4], and logical sharding spec [S, S], we cannot have a physical sharding spec [S, R, R, S]

These sharding strategies are allowed by default in the current implementation and will therefore lead to wrong results.
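For intuition, here is a minimal sketch of the kind of divisibility check that makes the first case illegal. The names (`check_sharding`, `dim_partition`, `mesh_shape`) are hypothetical and simplified, not the actual ColossalAI API:

```python
from typing import Dict, List


class ShardingSpecException(Exception):
    """Raised when a sharding spec cannot be applied to a tensor shape."""


def check_sharding(shape: List[int], dim_partition: Dict[int, List[int]],
                   mesh_shape: List[int]) -> None:
    """Reject specs that split a tensor dimension into more pieces than it can hold.

    dim_partition maps a tensor dimension to the device-mesh axes it is sharded over,
    e.g. {0: [0]} means "shard dim 0 across mesh axis 0".
    """
    for tensor_dim, mesh_axes in dim_partition.items():
        num_shards = 1
        for axis in mesh_axes:
            num_shards *= mesh_shape[axis]
        if shape[tensor_dim] % num_shards != 0:
            raise ShardingSpecException(
                f"dim {tensor_dim} of size {shape[tensor_dim]} cannot be split "
                f"into {num_shards} shards")


# A tensor of shape [1, 4] cannot be sharded along dimension 0 on a 2x2 mesh.
try:
    check_sharding([1, 4], {0: [0]}, mesh_shape=[2, 2])
except ShardingSpecException as e:
    print(e)  # dim 0 of size 1 cannot be split into 2 shards
```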

What does this PR do?

This PR designs and implements a systematic, hierarchical way of handling illegal sharding strategies. Illegal strategies are captured in three layers:

  1. ShardingSpec: when a ShardingSpec is instantiated, it automatically checks whether the spec is valid. If not, it throws a ShardingSpecException.
  2. StrategyGenerator: in a StrategyGenerator, each strategy is implemented by one method. This method must be decorated with ignore_sharding_exception, which catches the ShardingSpecException and returns None upon exception. These None values are removed automatically later.
  3. NodeHandler: during logical-to-physical sharding spec conversion in the post_process method, the developer needs to catch the ShardingSpecException manually. The NodeHandler will also check the validity of each sharding strategy in the register_strategy method (not implemented in this PR because it would cause many tests to fail; I will put up a separate PR to deal with this).

In this way, we can ensure that every node only keeps valid sharding strategies.
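A minimal sketch of the decorator pattern described in layer 2, assuming the exception and decorator names from this PR but with simplified signatures and a hypothetical `split_input_and_weight` strategy method:

```python
import functools
from typing import Callable, Optional


class ShardingSpecException(Exception):
    """Raised when an illegal sharding spec is constructed."""


def ignore_sharding_exception(func: Callable) -> Callable:
    """Turn a strategy method that hits an illegal sharding spec into a None result."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> Optional[object]:
        try:
            return func(*args, **kwargs)
        except ShardingSpecException:
            # The strategy is illegal for this tensor/mesh combination; drop it.
            return None

    return wrapper


class LinearStrategyGenerator:
    @ignore_sharding_exception
    def split_input_and_weight(self):
        # Building the ShardingSpec here may raise ShardingSpecException,
        # in which case this strategy becomes None and is filtered out later.
        raise ShardingSpecException("dim 0 of size 1 cannot be sharded")


assert LinearStrategyGenerator().split_input_and_weight() is None
```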

A summary of the code change

  1. ShardingSpecException is defined in sharding_spec.py so that every exception we throw has clearer semantics.

  2. The API of the StrategyGenerator is refactored by introducing an additional collate_strategies method. This method is introduced so that we do not have to manually remove illegal sharding strategies and update costs in every generator; that work is now taken over by the generate method, removing redundant code. Therefore, in every child StrategyGenerator, the former generate method becomes collate_strategies, and generate now exists only in the parent class (see the sketch after this list).

(Screenshot: Screen Shot 2022-10-18 at 15 14 00)

  3. Every strategy method in the child StrategyGenerator implementations is decorated with ignore_sharding_exception. Two changes are included for this decorator:

    • The decorator is renamed from exception_handler to ignore_sharding_exception because it no longer catches general exceptions, only the sharding exception.
    • The decorator now returns None explicitly if an exception occurs.
  4. The validate method is now called inside the __init__ method of the StrategyGenerator class. validate was defined previously but never called anywhere, so it is now invoked to do first-hand checking.

  5. As an example, I added stricter test code in test_linear_node_handler.py to ensure all generated sharding strategies are valid.
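The sketch below illustrates the refactored division of labor described in items 2 and 4: the parent class's generate method filters out the None values produced by ignore_sharding_exception and updates costs, while __init__ invokes validate. The class and method bodies are simplified placeholders, not the actual ColossalAI implementation:

```python
from typing import List, Optional


class StrategyGenerator:
    """Simplified base class: generate() owns filtering and cost updates."""

    def __init__(self, operation_data, device_mesh):
        self.op_data = operation_data
        self.device_mesh = device_mesh
        # validate() was previously defined but never invoked; it now runs here.
        self.validate()

    def validate(self) -> None:
        """First-hand checks on the operation data; subclasses may override."""

    def collate_strategies(self) -> List[Optional[object]]:
        """Child classes produce candidate strategies; illegal ones come back as None."""
        raise NotImplementedError

    def update_communication_cost(self, strategy):
        return strategy

    def update_compute_cost(self, strategy):
        return strategy

    def update_memory_cost(self, strategy):
        return strategy

    def generate(self) -> List[object]:
        # The None entries returned by ignore_sharding_exception are dropped here,
        # so child classes no longer repeat this filtering and cost bookkeeping.
        strategies = [s for s in self.collate_strategies() if s is not None]
        for strategy in strategies:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)
        return strategies
```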

@FrankLeeeee force-pushed the hotfix/complete-node-handler-testing branch from 841ebc6 to 48b52e9 on October 19, 2022 03:44
@FrankLeeeee merged commit eee8490 into hpcaitech:main on Oct 19, 2022
@FrankLeeeee deleted the hotfix/complete-node-handler-testing branch on January 26, 2023 07:44