TensorFlow Validation #3
Conversation
# device was specified, then reload the model on the correct device.
# This is a bit awkward but I wanted to avoid all the tensorflow initialization
# stuff before loading a module, which could potentially be a PyTorch module.
# So we only get into the tensorflow device stuff if we found a tf module.
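A minimal sketch of the flow that comment describes; `load_with_device` and its zero-argument `load_model` callable are illustrative stand-ins, not the repository's actual API:

```python
def load_with_device(load_model, device=None):
    """Load the model first, then handle device placement per framework.

    `load_model` is a zero-argument callable supplied by the caller (a
    stand-in for the real loader). TensorFlow is imported only after a TF
    object has been seen, so loading a pure PyTorch module never triggers
    TensorFlow initialization.
    """
    model = load_model()

    if type(model).__module__.startswith(("tensorflow", "keras")):
        # Only now touch TensorFlow's device machinery.
        import tensorflow as tf
        if device is not None:
            # Reload under tf.device so variables land on the requested device.
            with tf.device(device):
                model = load_model()
    elif device is not None:
        # For PyTorch, moving the already-loaded module is enough.
        import torch
        model = model.to(torch.device(device))

    return model
```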
This whole device-argument piece makes things a bit awkward when working with both TensorFlow and PyTorch models. One option would be to add another positional arg for the framework ("torch" or "tf") and then only import the framework that is required; that could simplify things. That said, this is working fine for now.
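A rough sketch of that option, assuming an argparse CLI; the argument names and the `import_framework` helper are hypothetical, not the existing interface:

```python
import argparse
import importlib

def build_parser():
    parser = argparse.ArgumentParser(description="Model validation (sketch).")
    # Hypothetical positional arg selecting the framework, as suggested above.
    parser.add_argument("framework", choices=["torch", "tf"],
                        help="Which framework the model to validate uses.")
    parser.add_argument("--device", default=None,
                        help="Optional device string, e.g. 'cuda:0' or '/GPU:0'.")
    return parser

def import_framework(framework):
    """Import only the framework the user asked for."""
    module_name = "tensorflow" if framework == "tf" else "torch"
    return importlib.import_module(module_name)

if __name__ == "__main__":
    args = build_parser().parse_args()
    fw = import_framework(args.framework)
    print(f"Using {fw.__name__} {fw.__version__} on device {args.device or 'default'}")
```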
Where is the current self.model_type thing documented / called by the user?
Maybe you can explain the device-handling differences between PyTorch and TensorFlow to me in more detail.
The model_type isn't set by the user -- it is inferred from the type of object returned by load_model. This will probably need to be updated to accept other objects like a keras.Model, etc. An alternative would be to change this to a command-line arg and only import the required framework.
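A minimal sketch of that kind of type-based inference, assuming a hypothetical `infer_model_type` helper and "tf"/"torch" labels; this is not the repository's actual implementation:

```python
def infer_model_type(model):
    """Infer the framework from the object returned by load_model.

    The check walks the class's MRO and looks at module names, so neither
    framework needs to be imported here, and subclasses such as keras.Model
    (which lives under the keras/tensorflow packages) are caught as well.
    """
    mro_modules = [cls.__module__ for cls in type(model).__mro__]
    if any(m.startswith(("tensorflow", "keras")) for m in mro_modules):
        return "tf"
    if any(m.startswith("torch") for m in mro_modules):
        return "torch"
    raise TypeError(f"Could not infer framework for {type(model)!r}")
```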
Adding support for TensorFlow model validation.