Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ml processor for offline batch inference #5507

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions data-prepper-plugins/ml-processor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

# ml Processor

This plugin enables you to send data from your Data Prepper pipeline directly to ml-commons for machine learning related activities.

## Usage
```yaml
lambda-pipeline:
...
processor:
- ml:
host: "https://search-xunzh-ml-tests-ihx7htldf7nvo2gdg25m6ehthq.us-east-1.es.amazonaws.com"
aws_sigv4: true
action_type: "batch_predict"
service_name: "bedrock"
model_id: "6ifdTZUBEBlFHJzvGSxO"
output_path: "s3://offlinebatch/bedrock-multisource/output-multisource/"
aws:
region: "us-east-1"
ml_when: /bucket == "offlinebatch"

```
- `model_id`: the ID of the model that is registered in the OpenSearch ml-commons plugin.
- `service_name`: the remote AI service platform that processes the batch job.
- `output_path`: the S3 URI of the location where the batch job output is written.

# Metrics

### Counter
- `mlProcessorSuccessfullyCreated`: measures total number of requests received and processed successfully by ml-processor.
- `mlProcessorFailedToCreated`: measures total number of requests failed by ml-processor.
- `numberOfBatchJobsCreationSucceeded`: measures total number of batch jobs successfully created (200 response status code) by OpenSearch ml-commons API.
- `numberOfBatchJobsCreationFailed`: measures total number of batch jobs failed in creation by OpenSearch ml-commons API.

## Developer Guide

The integration tests for this plugin do not run as part of the Data Prepper build.
The following command runs the integration tests:

```
./gradlew :data-prepper-plugins:ml-processor:integrationTest
```
69 changes: 69 additions & 0 deletions data-prepper-plugins/ml-processor/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

dependencies {
    implementation project(':data-prepper-api')
    implementation project(path: ':data-prepper-plugins:common')
    implementation project(':data-prepper-plugins:aws-plugin-api')
    implementation project(':data-prepper-plugins:failures-common')
    implementation project(':data-prepper-plugins:parse-json-processor')
    implementation 'io.micrometer:micrometer-core'
    implementation 'com.fasterxml.jackson.core:jackson-core'
    implementation 'com.fasterxml.jackson.core:jackson-databind'
    // No explicit version: "2.x.x" was a placeholder, not a resolvable version. The version is
    // left to external dependency management, exactly like the sts and s3 modules below.
    implementation 'software.amazon.awssdk:sdk-core'
    implementation 'software.amazon.awssdk:sts'
    implementation 'org.hibernate.validator:hibernate-validator:8.0.1.Final'
    implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
    implementation 'org.json:json'
    implementation libs.commons.lang3
    implementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
    implementation 'software.amazon.awssdk:s3'
    // Lombok is only needed at compile time (annotation processing), so it is declared as
    // compileOnly/annotationProcessor with a single version instead of also being shipped
    // as an implementation dependency with a different version.
    compileOnly 'org.projectlombok:lombok:1.18.20'
    annotationProcessor 'org.projectlombok:lombok:1.18.20'
    testCompileOnly 'org.projectlombok:lombok:1.18.20'
    testAnnotationProcessor 'org.projectlombok:lombok:1.18.20'
    testImplementation 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
    testImplementation project(':data-prepper-test-common')
    testImplementation testLibs.slf4j.simple
    testImplementation 'org.mockito:mockito-core:4.6.1'
    testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
    testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
}

// Run the unit tests on the JUnit Platform (JUnit 5).
test.useJUnitPlatform()

// Declares a separate "integrationTest" source set rooted at src/integrationTest.
sourceSets {
integrationTest {
java {
// Integration tests compile and run against the main classes and the unit-test classes.
compileClasspath += main.output + test.output
runtimeClasspath += main.output + test.output
srcDir file('src/integrationTest/java')
}
resources.srcDir file('src/integrationTest/resources')
}
}

// Let the integrationTest source set reuse the unit-test dependencies.
configurations {
    integrationTestImplementation.extendsFrom testImplementation
    // Use the *RuntimeOnly configurations: the legacy "testRuntime" configuration was removed
    // in Gradle 7, and runtime-only test dependencies (for example the JUnit engine) are
    // declared on testRuntimeOnly.
    integrationTestRuntimeOnly.extendsFrom testRuntimeOnly
}

// Verification task that runs only the classes of the integrationTest source set whose
// names match '*IT', on the JUnit Platform.
tasks.register('integrationTest', Test) {
    group = 'verification'

    testClassesDirs = sourceSets.integrationTest.output.classesDirs
    classpath = sourceSets.integrationTest.runtimeClasspath

    useJUnitPlatform()

    systemProperty 'log4j.configurationFile', 'src/test/resources/log4j2.properties'

    filter {
        includeTestsMatching '*IT'
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor;

import io.micrometer.core.instrument.Counter;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.processor.AbstractProcessor;
import org.opensearch.dataprepper.model.processor.Processor;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreator;
import org.opensearch.dataprepper.plugins.ml.processor.common.MLBatchJobCreatorFactory;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;

@DataPrepperPlugin(name = "ml", pluginType = Processor.class, pluginConfigurationType = MLProcessorConfig.class)
public class MLProcessor extends AbstractProcessor<Record<Event>, Record<Event>> {
public static final Logger LOG = LoggerFactory.getLogger(MLProcessor.class);
public static final String NUMBER_OF_ML_PROCESSOR_SUCCESS = "mlProcessorSuccessfullyCreated";
public static final String NUMBER_OF_ML_PROCESSOR_FAILED = "mlProcessorFailedToCreated";

private final String whenCondition;
private MLBatchJobCreator mlBatchJobCreator;
private final Counter numberOfMLProcessorSuccessCounter;
private final Counter numberOfMLProcessorFailedCounter;
private final ExpressionEvaluator expressionEvaluator;

@DataPrepperPluginConstructor
public MLProcessor(final MLProcessorConfig mlProcessorConfig, final PluginMetrics pluginMetrics, final AwsCredentialsSupplier awsCredentialsSupplier, final ExpressionEvaluator expressionEvaluator) {
super(pluginMetrics);
this.whenCondition = mlProcessorConfig.getWhenCondition();
ServiceName serviceName = mlProcessorConfig.getServiceName();
this.numberOfMLProcessorSuccessCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_SUCCESS);
this.numberOfMLProcessorFailedCounter = pluginMetrics.counter(
NUMBER_OF_ML_PROCESSOR_FAILED);
this.expressionEvaluator = expressionEvaluator;

// Use factory to get the appropriate job creator
mlBatchJobCreator = MLBatchJobCreatorFactory.getJobCreator(serviceName, mlProcessorConfig, awsCredentialsSupplier, pluginMetrics);
}

@Override
public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
// reads from input - S3 input
if (records.size() == 0)
return records;

List<Record<Event>> recordsToMlCommons = new ArrayList<>();
for (Record<Event> record : records) {
final Event event = record.getData();
if (whenCondition != null && !expressionEvaluator.evaluateConditional(whenCondition,
event)) {
continue;
}
recordsToMlCommons.add(record);
}

try {
mlBatchJobCreator.createMLBatchJob(recordsToMlCommons);
numberOfMLProcessorSuccessCounter.increment();
} catch (Exception e) {
LOG.error(NOISY, e.getMessage(), e);
numberOfMLProcessorFailedCounter.increment();
}
return records;
}

@Override
public void prepareForShutdown() {
}

@Override
public boolean isReadyForShutdown() {
return true;
}

@Override
public void shutdown() {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor;

import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import lombok.Getter;
import org.opensearch.dataprepper.model.annotations.ExampleValues;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ActionType;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.ml.processor.configuration.ServiceName;

/**
 * Configuration for the {@code ml} processor, bound from the pipeline YAML through the
 * {@link JsonProperty} names on each field.
 *
 * <p>All accessors are generated by Lombok's class-level {@link Getter}; no hand-written
 * getters are needed.
 */
@Getter
@JsonPropertyOrder
@JsonClassDescription("The <code>ml</code> processor enables invocation of the ml-commons plugin in OpenSearch service within your pipeline in order to process events. " +
        "It supports both synchronous and asynchronous invocations based on your use case.")
public class MLProcessorConfig {

    /** AWS authentication settings ({@code aws}); required and validated. */
    @JsonProperty("aws")
    @NotNull
    @Valid
    private AwsAuthenticationOptions awsAuthenticationOptions;

    /** How ml-commons is invoked ({@code action_type}); defaults to {@link ActionType#BATCH_PREDICT}. */
    @JsonPropertyDescription("action type defines the way we want to invoke ml-commons in the predict API")
    @JsonProperty("action_type")
    private ActionType actionType = ActionType.BATCH_PREDICT;

    /** Remote AI service hosting the model ({@code service_name}); defaults to {@link ServiceName#SAGEMAKER}. */
    @JsonPropertyDescription("AI service hosting the remote model for ML Commons predictions")
    @JsonProperty("service_name")
    private ServiceName serviceName = ServiceName.SAGEMAKER;

    /** OpenSearch host URL ({@code host}). */
    @JsonPropertyDescription("defines the OpenSearch host url to be invoked")
    @JsonProperty("host")
    private String hostUrl;

    /** Model id in ml-commons ({@code model_id}). */
    @JsonPropertyDescription("defines the model id to be invoked in ml-commons")
    @JsonProperty("model_id")
    private String modelId;

    /** S3 location for the offline model responses ({@code output_path}). */
    @JsonPropertyDescription("defines the S3 location to write the offline model responses to")
    @JsonProperty("output_path")
    private String outputPath;

    /** Value of the {@code aws_sigv4} option; {@code false} when not set. */
    @JsonProperty("aws_sigv4")
    private boolean awsSigv4;

    /** Optional condition ({@code ml_when}); {@code null} when not configured. */
    @JsonPropertyDescription("Defines a condition for event to use this processor.")
    @ExampleValues({
            @ExampleValues.Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.")
    })
    @JsonProperty("ml_when")
    private String whenCondition;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor.client;

/**
 * Placeholder class: it currently declares no fields or methods.
 *
 * NOTE(review): presumably intended to create S3 clients for the ml processor, going by the
 * name only - confirm the intent, or remove the class until it is needed.
 */
public class S3ClientFactory {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.ml.processor.common;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import io.micrometer.core.instrument.Counter;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml.processor.MLProcessorConfig;

import java.util.Collection;

/**
 * Base class for {@link MLBatchJobCreator} implementations.
 *
 * <p>Holds the processor configuration, the AWS credentials supplier, a shared
 * {@link ObjectMapper}, and the two counters that track batch job creation results.
 */
public abstract class AbstractBatchJobCreator implements MLBatchJobCreator {
    public static final String NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationSucceeded";
    public static final String NUMBER_OF_FAILED_BATCH_JOBS_CREATION = "numberOfBatchJobsCreationFailed";

    protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    protected final MLProcessorConfig mlProcessorConfig;
    protected final AwsCredentialsSupplier awsCredentialsSupplier;
    protected final Counter numberOfBatchJobsSuccessCounter;
    protected final Counter numberOfBatchJobsFailedCounter;

    /**
     * Stores the collaborators and registers the success/failure counters.
     *
     * @param mlProcessorConfig      processor configuration made available to subclasses
     * @param awsCredentialsSupplier credentials supplier made available to subclasses
     * @param pluginMetrics          metrics registry the two counters are created from
     */
    public AbstractBatchJobCreator(final MLProcessorConfig mlProcessorConfig,
                                   final AwsCredentialsSupplier awsCredentialsSupplier,
                                   final PluginMetrics pluginMetrics) {
        this.mlProcessorConfig = mlProcessorConfig;
        this.awsCredentialsSupplier = awsCredentialsSupplier;
        this.numberOfBatchJobsSuccessCounter =
                pluginMetrics.counter(NUMBER_OF_SUCCESSFUL_BATCH_JOBS_CREATION);
        this.numberOfBatchJobsFailedCounter =
                pluginMetrics.counter(NUMBER_OF_FAILED_BATCH_JOBS_CREATION);
    }

    /** Adds one to the counter of successfully created batch jobs. */
    public void incrementSuccessCounter() {
        this.numberOfBatchJobsSuccessCounter.increment();
    }

    /** Adds one to the counter of batch jobs that failed to be created. */
    public void incrementFailureCounter() {
        this.numberOfBatchJobsFailedCounter.increment();
    }

    /**
     * Creates a batch job for the given records; the behavior is defined by each subclass.
     *
     * @param records records to include in the batch job
     */
    public abstract void createMLBatchJob(Collection<Record<Event>> records);
}
Loading
Loading