Skip to content

Commit

Permalink
add ml processor for offline batch inference
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Mar 6, 2025
1 parent b5b45bc commit 26b5f58
Show file tree
Hide file tree
Showing 21 changed files with 1,081 additions and 4 deletions.
38 changes: 38 additions & 0 deletions data-prepper-plugins/ml-processor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

# ml Processor

This plugin enables you to send data from your Data Prepper pipeline directly to ml-commons for machine learning related activities.

## Usage
```yaml
lambda-pipeline:
...
  processor:
    - ml:
        host: "https://search-xunzh-ml-tests-ihx7htldf7nvo2gdg25m6ehthq.us-east-1.es.amazonaws.com"
        aws_sigv4: true
        action_type: "batch_predict"
        service_name: "bedrock"
        model_id: "6ifdTZUBEBlFHJzvGSxO"
        output_path: "s3://offlinebatch/bedrock-multisource/output-multisource/"
        aws:
          region: "us-east-1"
        ml_when: /bucket == "offlinebatch"
```
- `model_id`: the ID of the model registered in the OpenSearch ml-commons plugin.
- `service_name`: the remote AI service platform that processes the batch job.
- `output_path`: the S3 URI where the batch job output is written.

# Metrics

### Counter
- `mlProcessorSuccessfullyCreated`: measures the total number of requests received and processed successfully by the ml processor.
- `mlProcessorFailedToCreated`: measures the total number of requests that failed in the ml processor.
- `numberOfBatchJobsCreationSucceeded`: measures total number of batch jobs successfully created (200 response status code) by OpenSearch ml-commons API.
- `numberOfBatchJobsCreationFailed`: measures total number of batch jobs failed in creation by OpenSearch ml-commons API.

## Developer Guide

The integration tests for this plugin do not run as part of the Data Prepper build.
The following command runs the integration tests:

```
./gradlew :data-prepper-plugins:ml-processor:integrationTest
```
69 changes: 69 additions & 0 deletions data-prepper-plugins/ml-processor/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

dependencies {
    implementation project(':data-prepper-api')
    implementation project(path: ':data-prepper-plugins:common')
    implementation project(':data-prepper-plugins:aws-plugin-api')
    implementation project(':data-prepper-plugins:failures-common')
    implementation project(':data-prepper-plugins:parse-json-processor')
    implementation 'io.micrometer:micrometer-core'
    implementation 'com.fasterxml.jackson.core:jackson-core'
    implementation 'com.fasterxml.jackson.core:jackson-databind'
    // No explicit version: "2.x.x" was a placeholder, not a resolvable version. Like the
    // sibling sts and s3 artifacts below, sdk-core relies on the build's dependency management.
    implementation 'software.amazon.awssdk:sdk-core'
    implementation 'software.amazon.awssdk:sts'
    implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
    implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
    implementation 'org.json:json'
    implementation libs.commons.lang3
    implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
    implementation 'software.amazon.awssdk:s3'
    // Lombok is a compile-time annotation processor only; it must not be an implementation
    // (runtime) dependency, and a single version is used for every configuration.
    compileOnly 'org.projectlombok:lombok:1.18.20'
    annotationProcessor 'org.projectlombok:lombok:1.18.20'
    testCompileOnly 'org.projectlombok:lombok:1.18.20'
    testAnnotationProcessor 'org.projectlombok:lombok:1.18.20'
    // jackson-datatype-jsr310 is already visible to tests because testImplementation
    // extends implementation, so it is not repeated here.
    testImplementation project(':data-prepper-test-common')
    testImplementation testLibs.slf4j.simple
    testImplementation 'org.mockito:mockito-core:4.6.1'
    testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
    testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
}

// Run the unit tests on the JUnit Platform (the JUnit Jupiter engine declared above).
test {
useJUnitPlatform()
}

// Declares a separate "integrationTest" source set rooted at src/integrationTest.
sourceSets {
integrationTest {
java {
// Integration tests compile and run against the main and unit-test outputs.
compileClasspath += main.output + test.output
runtimeClasspath += main.output + test.output
srcDir file('src/integrationTest/java')
}
resources.srcDir file('src/integrationTest/resources')
}
}

// Let the integrationTest source set inherit the unit-test dependencies.
configurations {
    integrationTestImplementation.extendsFrom testImplementation
    // The legacy "runtime"/"testRuntime" configurations are gone in current Gradle versions;
    // the runtime-only dependencies (e.g. the JUnit Jupiter engine) live in testRuntimeOnly,
    // so that is the configuration the integration tests must extend.
    integrationTestRuntimeOnly.extendsFrom testRuntimeOnly
}

// Test task that runs the classes compiled from the integrationTest source set.
// It appears under the "verification" task group.
task integrationTest(type: Test) {
group = 'verification'
testClassesDirs = sourceSets.integrationTest.output.classesDirs

useJUnitPlatform()

classpath = sourceSets.integrationTest.runtimeClasspath

// Reuses the unit tests' log4j2 configuration file.
systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties'

// Only test classes whose name ends in "IT" are executed.
filter {
includeTestsMatching '*IT'
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor;

import io.micrometer.core.instrument.Counter;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.processor.AbstractProcessor;
import org.opensearch.dataprepper.model.processor.Processor;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreator;
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreatorFactory;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;

@DataPrepperPlugin(name = "ml", pluginType = Processor.class, pluginConfigurationType = MLProcessorConfig.class)
public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
public static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class);
public static final String NUMBER_OF_ML_PROCESSOR_SUCCESS = "mlProcessorSuccessfullyCreated";
public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "mlProcessorFailedToCreated";

private final String whenCondition;
private MLBatchJobCreator mlBatchJobCreator;
private final Counter numberOfMLProcessorSuccessCounter;
private final Counter numberOfMLProcessorFailedCounter;
private final ExpressionEvaluator expressionEvaluator;

@DataPrepperPluginConstructor
public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) {
super(pluginMetrics);
this.whenCondition = mlProcessorConfig.getWhenCondition();
ServiceName serviceName = mlProcessorConfig.getServiceName();
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_SUCCESS);
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_FAILED);
this.expressionEvaluator = expressionEvaluator;

// Use factory to get the appropriate job creator
mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics);
}

@Override
public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
// reads from input - S3 input
if (records.size() == 0)
return records;
System.out.println("Received .... " + records.size() + " records");

List<Record<Event>> recordsToMlCommons = new ArrayList<>();
for (Record<Event> record : records) {
final Event event = record.getData();
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition,
event)) {
continue;
}
recordsToMlCommons.add(record);
}

try {
mlBatchJobCreator.createMLBatchJob(recordsToMlCommons);
numberOfMLProcessorSuccessCounter.increment();
} catch (Exception e) {
LOG.error(NOISY, e.getMessage(), e);
numberOfMLProcessorFailedCounter.increment();
}
return records;
}

@Override
public void prepareForShutdown() {
}

@Override
public boolean isReadyForShutdown() {
return true;
}

@Override
public void shutdown() {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import lombok.Getter;
import org.opensearch.dataprepper.model.annotations.ExampleValues;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ActionType;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;

/**
 * Configuration for the {@code ml} processor.
 *
 * <p>All accessors are generated by Lombok's class-level {@code @Getter}
 * ({@code getX()} for object fields, {@code isAwsSigv4()} for the boolean field), so no
 * hand-written getters are declared.
 */
@Getter
@JsonPropertyOrder
@JsonClassDescription("The <code>ml</code> processor enables invocation of the ml-commons plugin in OpenSearch service within your pipeline in order to process events. " +
        "It supports both synchronous and asynchronous invocations based on your use case.")
public class MLProcessorConfig {

    /** AWS authentication options; required and validated. */
    @JsonProperty("aws")
    @NotNull
    @Valid
    private AwsAuthenticationOptions awsAuthenticationOptions;

    /** How ml-commons is invoked; defaults to {@link ActionType#BATCH_PREDICT}. */
    @JsonPropertyDescription("action type defines the way we want to invoke ml-commons in the predict API")
    @JsonProperty("action_type")
    private ActionType actionType = ActionType.BATCH_PREDICT;

    /** Remote AI service hosting the model; defaults to {@link ServiceName#SAGEMAKER}. */
    @JsonPropertyDescription("AI service hosting the remote model for ML Commons predictions")
    @JsonProperty("service_name")
    private ServiceName serviceName = ServiceName.SAGEMAKER;

    /** OpenSearch host URL, read from the {@code host} property. */
    @JsonPropertyDescription("defines the OpenSearch host url to be invoked")
    @JsonProperty("host")
    private String hostUrl;

    /** Identifier of the ml-commons model to invoke. */
    @JsonPropertyDescription("defines the model id to be invoked in ml-commons")
    @JsonProperty("model_id")
    private String modelId;

    /** S3 location that receives the offline model responses. */
    @JsonPropertyDescription("defines the S3 location to write the offline model responses to")
    @JsonProperty("output_path")
    private String outputPath;

    /** Value of the {@code aws_sigv4} property; {@code false} when omitted. */
    @JsonProperty("aws_sigv4")
    private boolean awsSigv4;

    /** Optional condition an event must satisfy to be handled by this processor. */
    @JsonPropertyDescription("Defines a condition for event to use this processor.")
    @ExampleValues({
        @ExampleValues.Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.")
    })
    @JsonProperty("ml_when")
    private String whenCondition;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor.client;

/**
 * Placeholder class: it currently declares no fields or methods.
 * NOTE(review): presumably intended to build S3 clients for the ml processor - confirm before use.
 */
public class S3ClientFactory {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor.common;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import io.micrometer.core.instrument.Counter;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig;

import java.util.Collection;

/**
 * Base class for {@link MLBatchJobCreator} implementations.
 *
 * <p>It keeps the processor configuration and the AWS credentials supplier, and owns the two
 * counters that track successful and failed batch job creations.
 */
public abstract class AbstractBatchJobCreator implements MLBatchJobCreator {
    public static final String NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationSucceeded";
    public static final String NUMBER_OF_FAILED_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationFailed";

    protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    protected final MLProcessorConfig mlProcessorConfig;
    protected final AwsCredentialsSupplier awsCredentialsSupplier;
    protected final Counter numberOfBatchJobsSuccessCounter;
    protected final Counter numberOfBatchJobsFailedCounter;

    /**
     * Stores the shared collaborators and registers the success/failure counters.
     *
     * @param mlProcessorConfig      configuration of the ml processor
     * @param awsCredentialsSupplier supplier of AWS credentials for subclasses
     * @param pluginMetrics          metrics registry from which the counters are obtained
     */
    public AbstractBatchJobCreator(MLProcessorConfig mlProcessorConfig,
                                   AwsCredentialsSupplier awsCredentialsSupplier,
                                   final PluginMetrics pluginMetrics) {
        this.mlProcessorConfig = mlProcessorConfig;
        this.awsCredentialsSupplier = awsCredentialsSupplier;

        final Counter succeeded = pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION);
        final Counter failed = pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION);
        this.numberOfBatchJobsSuccessCounter = succeeded;
        this.numberOfBatchJobsFailedCounter = failed;
    }

    /** Adds one to the count of successfully created batch jobs. */
    public void incrementSuccessCounter() {
        this.numberOfBatchJobsSuccessCounter.increment();
    }

    /** Adds one to the count of batch jobs that failed to be created. */
    public void incrementFailureCounter() {
        this.numberOfBatchJobsFailedCounter.increment();
    }

    /**
     * Creates the batch job for the given records; each subclass supplies its own implementation.
     *
     * @param records the records to include in the batch job
     */
    public abstract void createMLBatchJob(Collection<Record<Event>> records);

}
Loading

0 comments on commit 26b5f58

Please sign in to comment.