Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TensorFlow Validation #3

Merged
merged 9 commits into from
Jul 20, 2021
Merged

TensorFlow Validation #3

merged 9 commits into from
Jul 20, 2021

Conversation

jorshi
Copy link
Collaborator

@jorshi jorshi commented Jul 9, 2021

Adding support for TensorFlow model validation.

@jorshi jorshi changed the base branch from main to validator July 9, 2021 23:09
# device was specified, then reload the model on the correct device.
# This is a bit awkward but I wanted to avoid all the tensorflow initialization
# stuff before loading a module, which could potentially be a PyTorch module.
# So we only get into the tensorflow device stuff if we found a tf module.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole device argument piece makes things a bit awkward when working with both tensorflow and pytorch models. One option would be to add another positional arg for the framework: "torch" or "tf". And then only import the framework that is required. That could simplify things. That being said, this is working fine for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the current self.model_type thing documented / called by the user?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can explain to me better about device differences in pytorch and tf

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model_type isn't called by the user -- it is inferred based on the type of object returned by load_model. This will probably need to be updated to accept other objects like a kera.Model, etc... Alternative would be to change this to a command line arg and only import the required framework.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#4

Base automatically changed from validator to main July 19, 2021 19:23
@jorshi jorshi merged commit 7d28d35 into main Jul 20, 2021
@jorshi jorshi deleted the tf2 branch July 20, 2021 17:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants