
refactor: update evaluate_dataset to take in a dataset instead of dataset config #232

Merged: 2 commits merged into aws:main on Mar 27, 2024

Conversation

@danielezhu (Contributor) commented Mar 26, 2024

Description of changes:
This PR refactors the evaluate_dataset method to consume a dataset directly, instead of a dataset config. This will allow evaluate_dataset to be compatible with more eval algorithms.
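
A rough sketch of the shape of this change, using hypothetical stand-ins for the fmeval types and helpers (DataConfig, Dataset, EvalScore, load_dataset, and run_transforms are all simplified placeholders; the real signatures carry more parameters):

from typing import List


class DataConfig:
    """Hypothetical stand-in for a dataset config."""


class Dataset:
    """Hypothetical stand-in for a loaded dataset."""


class EvalScore:
    """Hypothetical stand-in for an eval score record."""


def load_dataset(config: DataConfig) -> Dataset:
    """Hypothetical loader, shown only to illustrate the old flow."""
    return Dataset()


def run_transforms(dataset: Dataset) -> List[EvalScore]:
    """Hypothetical scoring pipeline shared by both versions."""
    return [EvalScore()]


# Before: evaluate_dataset resolved the dataset from a config internally,
# tying it to a single loading path.
def evaluate_dataset_old(dataset_config: DataConfig) -> List[EvalScore]:
    dataset = load_dataset(dataset_config)
    return run_transforms(dataset)


# After: the caller loads (and can pre-process) the dataset first, so
# algorithms with custom preparation steps can still reuse this helper.
def evaluate_dataset_new(dataset: Dataset) -> List[EvalScore]:
    return run_transforms(dataset)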

The verify_model_determinism function has been updated in preparation for its use in Summarization Accuracy Semantic Robustness (SASR). By taking in a prompt template and a model input column, it can verify model determinism before the prompt-generation and model-invocation transforms execute, which allows SASR's evaluate method to follow the same overall template as all the other algorithms.
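
A minimal sketch of that idea, assuming a hypothetical ModelRunner-like object with a predict(prompt) method and a $model_input placeholder convention; the names and signatures here are illustrative, not the exact fmeval API:

from typing import List


def verify_model_determinism(
    model,                    # hypothetical ModelRunner-like object
    records: List[dict],      # a few sample records from the dataset
    prompt_template: str,     # e.g. "Summarize the following: $model_input"
    model_input_column: str,  # column holding the raw model input
) -> bool:
    """Return True if the model gives identical output for a repeated prompt."""
    for record in records:
        # Compose the prompt from the raw input column, before any
        # prompt-generation transform has run on the dataset.
        prompt = prompt_template.replace("$model_input", record[model_input_column])
        if model.predict(prompt) != model.predict(prompt):
            return False
    return True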

This PR additionally removes the BertscoreHelperModel class, as we now use BertscoreModel.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

else:
    try:
        validate_dataset(dataset, [DatasetColumns.MODEL_OUTPUT.value.name])
    except EvalAlgorithmClientError:
@danielezhu (author):

The try-except is just for providing a more specific error message than the generic one raised by validate_dataset.
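
A self-contained sketch of that pattern, with hypothetical stand-ins for validate_dataset, the column name, and the error message (none of these are copied from the PR):

class EvalAlgorithmClientError(Exception):
    pass


def validate_dataset(dataset: dict, required_columns: list) -> None:
    """Hypothetical validator that raises a generic error message."""
    for column in required_columns:
        if column not in dataset:
            raise EvalAlgorithmClientError(f"Missing required column: {column}")


def require_model_output(dataset: dict) -> None:
    # Catch the generic validation error and re-raise with a message that
    # explains why the column is required in this particular code path.
    try:
        validate_dataset(dataset, ["model_output"])
    except EvalAlgorithmClientError:
        raise EvalAlgorithmClientError(
            "A model output column is required when no ModelRunner is provided."
        )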

@@ -140,14 +147,10 @@ def get_helper_scores(self, text_input: List[str]) -> Dict[str, List[float]]: #
        """
        inputs = self._tokenizer(text_input, return_tensors="pt", truncation=True, padding=True).to(self._model.device)
        scores = torch.sigmoid(self._model(**inputs)[0]).cpu().detach().numpy()
        results = {}
@danielezhu (author):

Since text_input is always a List and never a str, we don't need the isinstance(text_input, str) branch. The function's type annotation indicates that it is always a List, and everywhere DetoxifyHelperModel.get_helper_scores is called, we pass a list: see DetoxifyHelperModel.__call__ and evaluate_sample in toxicity.py.

@danielezhu (author):

Not to mention that if we kept the old code, the output type annotation of this function would be wrong: for a str input, the returned type would be Dict[str, float] rather than Dict[str, List[float]]. I have already verified this manually.
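
A toy reconstruction of the issue (the real scoring uses the Detoxify model; the score names here are an illustrative subset and the values are placeholders):

from typing import Dict, List, Union

SCORE_NAMES = ["toxicity", "severe_toxicity"]  # illustrative subset


def get_helper_scores_old(text_input: Union[str, List[str]]) -> Dict[str, List[float]]:
    single_input = isinstance(text_input, str)
    if single_input:
        text_input = [text_input]
    scores = {name: [0.1] * len(text_input) for name in SCORE_NAMES}
    if single_input:
        # This unwrapping is what made the annotation wrong: for a str
        # input the function actually returned Dict[str, float].
        return {name: values[0] for name, values in scores.items()}
    return scores


def get_helper_scores_new(text_input: List[str]) -> Dict[str, List[float]]:
    # Callers always pass a list, so the str branch is gone and the
    # annotation matches the actual return type.
    return {name: [0.1] * len(text_input) for name in SCORE_NAMES}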

@danielezhu merged commit cb3b30e into aws:main on Mar 27, 2024
2 of 3 checks passed
@danielezhu deleted the refactor_evaluate_dataset branch on March 27, 2024 at 01:37