-
Notifications
You must be signed in to change notification settings - Fork 218
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ml processor for offline batch inference
Signed-off-by: Xun Zhang <[email protected]>
- Loading branch information
1 parent
b5b45bc
commit 26b5f58
Showing
21 changed files
with
1,081 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
|
||
# ml Processor | ||
|
||
This plugin enables you to send data from your Data Prepper pipeline directly to ml-commons for machine learning related activities. | ||
|
||
## Usage | ||
```aidl | ||
lambda-pipeline: | ||
... | ||
processor: | ||
- ml: | ||
host: "https://search-xunzh-ml-tests-ihx7htldf7nvo2gdg25m6ehthq.us-east-1.es.amazonaws.com" | ||
aws_sigv4: true | ||
action_type: "batch_predict" | ||
service_name: "bedrock" | ||
model_id: "6ifdTZUBEBlFHJzvGSxO" | ||
output_path: "s3://offlinebatch/bedrock-multisource/output-multisource/" | ||
aws: | ||
region: "us-east-1" | ||
ml_when: /bucket == "offlinebatch" | ||
``` | ||
`model_id` as the model id that is registered in the OpenSearch ml-commons plugin. | ||
`service_name` as the remote AI service platform to process then batch job. | ||
`output_path` as the batch job output location of the S3 Uri | ||
|
||
# Metrics | ||
|
||
### Counter | ||
- `mlProcessorSuccessRequests`: measures total number of requests received and processed successfully by ml-processor. | ||
- `mlProcessorFailedRequests`: measures total number of requests failed by ml-processor. | ||
- `numberOfBatchJobsCreationSucceeded`: measures total number of batch jobs successfully created (200 response status code) by OpenSearch ml-commons API. | ||
- `numberOfBatchJobsCreationFailed`: measures total number of batch jobs failed in creation by OpenSearch ml-commons API. | ||
|
||
## Developer Guide | ||
|
||
The integration tests for this plugin do not run as part of the Data Prepper build. | ||
The following command runs the integration tests: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
dependencies { | ||
implementation project(':data-prepper-api') | ||
implementation project(path: ':data-prepper-plugins:common') | ||
implementation project(':data-prepper-plugins:aws-plugin-api') | ||
implementation project(':data-prepper-plugins:failures-common') | ||
implementation project(':data-prepper-plugins:parse-json-processor') | ||
implementation 'io.micrometer:micrometer-core' | ||
implementation 'com.fasterxml.jackson.core:jackson-core' | ||
implementation 'com.fasterxml.jackson.core:jackson-databind' | ||
implementation 'software.amazon.awssdk:sdk-core:2.x.x' | ||
implementation 'software.amazon.awssdk:sts' | ||
implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final' | ||
implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml' | ||
implementation 'org.json:json' | ||
implementation libs.commons.lang3 | ||
implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' | ||
implementation 'org.projectlombok:lombok:1.18.22' | ||
implementation 'software.amazon.awssdk:s3' | ||
compileOnly 'org.projectlombok:lombok:1.18.20' | ||
annotationProcessor 'org.projectlombok:lombok:1.18.20' | ||
testCompileOnly 'org.projectlombok:lombok:1.18.20' | ||
testAnnotationProcessor 'org.projectlombok:lombok:1.18.20' | ||
testImplementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' | ||
testImplementation project(':data-prepper-test-common') | ||
testImplementation testLibs.slf4j.simple | ||
testImplementation 'org.mockito:mockito-core:4.6.1' | ||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2' | ||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.2' | ||
} | ||
|
||
test { | ||
useJUnitPlatform() | ||
} | ||
|
||
sourceSets { | ||
integrationTest { | ||
java { | ||
compileClasspath += main.output + test.output | ||
runtimeClasspath += main.output + test.output | ||
srcDir file('src/integrationTest/java') | ||
} | ||
resources.srcDir file('src/integrationTest/resources') | ||
} | ||
} | ||
|
||
configurations { | ||
integrationTestImplementation.extendsFrom testImplementation | ||
integrationTestRuntime.extendsFrom testRuntime | ||
} | ||
|
||
task integrationTest(type: Test) { | ||
group = 'verification' | ||
testClassesDirs = sourceSets.integrationTest.output.classesDirs | ||
|
||
useJUnitPlatform() | ||
|
||
classpath = sourceSets.integrationTest.runtimeClasspath | ||
|
||
systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties' | ||
|
||
filter { | ||
includeTestsMatching '*IT' | ||
} | ||
} |
96 changes: 96 additions & 0 deletions
96
...-processor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/MLProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins.ml.processor; | ||
|
||
import io.micrometer.core.instrument.Counter; | ||
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; | ||
import org.opensearch.dataprepper.expression.ExpressionEvaluator; | ||
import org.opensearch.dataprepper.metrics.PluginMetrics; | ||
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; | ||
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; | ||
import org.opensearch.dataprepper.model.event.Event; | ||
import org.opensearch.dataprepper.model.processor.AbstractProcessor; | ||
import org.opensearch.dataprepper.model.processor.Processor; | ||
import org.opensearch.dataprepper.model.record.Record; | ||
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreator; | ||
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreatorFactory; | ||
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.List; | ||
|
||
import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY; | ||
|
||
@DataPrepperPlugin(name = "ml", pluginType = Processor.class, pluginConfigurationType = MLProcessorConfig.class) | ||
public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>> { | ||
public static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class); | ||
public static final String NUMBER_OF_ML_PROCESSOR_SUCCESS = "mlProcessorSuccessfullyCreated"; | ||
public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "mlProcessorFailedToCreated"; | ||
|
||
private final String whenCondition; | ||
private MLBatchJobCreator mlBatchJobCreator; | ||
private final Counter numberOfMLProcessorSuccessCounter; | ||
private final Counter numberOfMLProcessorFailedCounter; | ||
private final ExpressionEvaluator expressionEvaluator; | ||
|
||
@DataPrepperPluginConstructor | ||
public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) { | ||
super(pluginMetrics); | ||
this.whenCondition = mlProcessorConfig.getWhenCondition(); | ||
ServiceName serviceName = mlProcessorConfig.getServiceName(); | ||
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter( | ||
NUMBER_OF_ML_PROCESSOR_SUCCESS); | ||
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter( | ||
NUMBER_OF_ML_PROCESSOR_FAILED); | ||
this.expressionEvaluator = expressionEvaluator; | ||
|
||
// Use factory to get the appropriate job creator | ||
mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics); | ||
} | ||
|
||
@Override | ||
public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) { | ||
// reads from input - S3 input | ||
if (records.size() == 0) | ||
return records; | ||
System.out.println("Received .... " + records.size() + " records"); | ||
|
||
List<Record<Event>> recordsToMlCommons = new ArrayList<>(); | ||
for (Record<Event> record : records) { | ||
final Event event = record.getData(); | ||
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition, | ||
event)) { | ||
continue; | ||
} | ||
recordsToMlCommons.add(record); | ||
} | ||
|
||
try { | ||
mlBatchJobCreator.createMLBatchJob(recordsToMlCommons); | ||
numberOfMLProcessorSuccessCounter.increment(); | ||
} catch (Exception e) { | ||
LOG.error(NOISY, e.getMessage(), e); | ||
numberOfMLProcessorFailedCounter.increment(); | ||
} | ||
return records; | ||
} | ||
|
||
@Override | ||
public void prepareForShutdown() { | ||
} | ||
|
||
@Override | ||
public boolean isReadyForShutdown() { | ||
return true; | ||
} | ||
|
||
@Override | ||
public void shutdown() { | ||
} | ||
} |
72 changes: 72 additions & 0 deletions
72
...ssor/src/main/java/org/opensearch/dataprepper/plugins/ml/processor/MLProcessorConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins.ml.processor; | ||
|
||
import com.fasterxml.jackson.annotation.JsonClassDescription; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import com.fasterxml.jackson.annotation.JsonPropertyDescription; | ||
import com.fasterxml.jackson.annotation.JsonPropertyOrder; | ||
import jakarta.validation.Valid; | ||
import jakarta.validation.constraints.NotNull; | ||
import lombok.Getter; | ||
import org.opensearch.dataprepper.model.annotations.ExampleValues; | ||
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ActionType; | ||
import org.opensearch.dataprepper.plugins.ml.processor.configuration.AwsAuthenticationOptions; | ||
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName; | ||
|
||
@Getter | ||
@JsonPropertyOrder | ||
@JsonClassDescription("The <code>ml</code> processor enables invocation of the ml-commons plugin in OpenSearch service within your pipeline in order to process events. " + | ||
"It supports both synchronous and asynchronous invocations based on your use case.") | ||
public class MLProcessorConfig { | ||
|
||
@JsonProperty("aws") | ||
@NotNull | ||
@Valid | ||
private AwsAuthenticationOptions awsAuthenticationOptions; | ||
|
||
@JsonPropertyDescription("action type defines the way we want to invoke ml-commons in the predict API") | ||
@JsonProperty("action_type") | ||
private ActionType actionType = ActionType.BATCH_PREDICT; | ||
|
||
@JsonPropertyDescription("AI service hosting the remote model for ML Commons predictions") | ||
@JsonProperty("service_name") | ||
private ServiceName serviceName = ServiceName.SAGEMAKER; | ||
|
||
@JsonPropertyDescription("defines the OpenSearch host url to be invoked") | ||
@JsonProperty("host") | ||
private String hostUrl; | ||
|
||
@JsonPropertyDescription("defines the model id to be invoked in ml-commons") | ||
@JsonProperty("model_id") | ||
private String modelId; | ||
|
||
@JsonPropertyDescription("defines the S3 location to write the offline model responses to") | ||
@JsonProperty("output_path") | ||
private String outputPath; | ||
|
||
@JsonProperty("aws_sigv4") | ||
private boolean awsSigv4; | ||
|
||
@JsonPropertyDescription("Defines a condition for event to use this processor.") | ||
@ExampleValues({ | ||
@ExampleValues.Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.") | ||
}) | ||
@JsonProperty("ml_when") | ||
private String whenCondition; | ||
|
||
public ActionType getActionType() { | ||
return actionType; | ||
} | ||
|
||
public String getModelId() { return modelId; } | ||
|
||
public String getHostUrl() { return hostUrl; } | ||
|
||
public String getWhenCondition() { | ||
return whenCondition; | ||
} | ||
} |
9 changes: 9 additions & 0 deletions
9
...src/main/java/org/opensearch/dataprepper/plugins/ml/processor/client/S3ClientFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins.ml.processor.client; | ||
|
||
public class S3ClientFactory { | ||
} |
50 changes: 50 additions & 0 deletions
50
.../java/org/opensearch/dataprepper/plugins/ml/processor/common/AbstractBatchJobCreator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.dataprepper.plugins.ml.processor.common; | ||
|
||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; | ||
import io.micrometer.core.instrument.Counter; | ||
import org.opensearch.dataprepper.metrics.PluginMetrics; | ||
import org.opensearch.dataprepper.model.event.Event; | ||
import org.opensearch.dataprepper.model.record.Record; | ||
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig; | ||
|
||
import java.util.Collection; | ||
|
||
public abstract class AbstractBatchJobCreator implements MLBatchJobCreator { | ||
public static final String NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationSucceeded"; | ||
public static final String NUMBER_OF_FAILED_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationFailed"; | ||
|
||
protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); | ||
protected final MLProcessorConfig mlProcessorConfig; | ||
protected final AwsCredentialsSupplier awsCredentialsSupplier; | ||
protected final Counter numberOfBatchJobsSuccessCounter; | ||
protected final Counter numberOfBatchJobsFailedCounter; | ||
|
||
// Constructor | ||
public AbstractBatchJobCreator(MLProcessorConfig mlProcessorConfig, | ||
AwsCredentialsSupplier awsCredentialsSupplier, | ||
final PluginMetrics pluginMetrics) { | ||
this.mlProcessorConfig = mlProcessorConfig; | ||
this.awsCredentialsSupplier = awsCredentialsSupplier; | ||
this.numberOfBatchJobsSuccessCounter = pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION); | ||
this.numberOfBatchJobsFailedCounter = pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION); | ||
} | ||
|
||
// Add common logic here that both subclasses can share | ||
public void incrementSuccessCounter() { | ||
numberOfBatchJobsSuccessCounter.increment(); | ||
} | ||
|
||
public void incrementFailureCounter() { | ||
numberOfBatchJobsFailedCounter.increment(); | ||
} | ||
|
||
// Abstract methods for batch job creation, specific to the implementations | ||
public abstract void createMLBatchJob(Collection<Record<Event>> records); | ||
|
||
} |
Oops, something went wrong.