Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for WASM based tSNE #207

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
180,008 changes: 180,008 additions & 0 deletions examples/data.json

Large diffs are not rendered by default.

392 changes: 392 additions & 0 deletions package-lock.json

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
"unist-util-visit": "^5.0.0",
"vite": "^4.3.3",
"vite-plugin-svgr": "^2.4.0",
"vite-plugin-top-level-await": "^1.4.4",
"vite-plugin-wasm": "^3.3.0",
"wasm-dist-bhtsne": "^1.1.0",
"wasm-feature-detect": "^1.6.2",
"web-vitals": "^2.1.4"
},
"scripts": {
Expand Down Expand Up @@ -91,4 +95,4 @@
"vite-plugin-eslint": "^1.8.1",
"vitest": "^0.30.1"
}
}
}
293 changes: 182 additions & 111 deletions src/components/FilterEditorWindow/config/RequestFromCode.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,53 @@
import axios from 'axios';
import { bigIntJSON } from '../../../common/bigIntJSON';
import { tsneConfig } from '../../VisualizeChart/tsneConfig';

// function parseDataToRequest(reqBody) {
// // Validate color_by
// if (reqBody.color_by) {
// const colorBy = reqBody.color_by;

// if (typeof colorBy === 'string') {
// // Parse into payload variant
// reqBody.color_by = {
// payload: colorBy,
// };
// } else {
// // Check we only have one of the options: payload, or discover_score
// const options = [colorBy.payload, colorBy.discover_score];
// const optionsCount = options.filter((option) => option).length;
// if (optionsCount !== 1) {
// return {
// reqBody: reqBody,
// error: '`color_by`: Only one of `payload`, or `discover_score` can be used',
// };
// }

// // Put search arguments in main request body
// if (colorBy.discover_score) {
// reqBody = {
// ...reqBody,
// ...colorBy.discover_score,
// };
// }
// }
// }

// // Set with_vector name
// if (reqBody.vector_name) {
// reqBody.with_vector = [reqBody.vector_name];
// return {
// reqBody: reqBody,
// error: null,
// };
// } else if (!reqBody.vector_name) {
// reqBody.with_vector = true;
// return {
// reqBody: reqBody,
// error: null,
// };
// }
// }

function parseDataToRequest(reqBody) {
// Validate color_by
Expand All @@ -11,69 +59,92 @@ function parseDataToRequest(reqBody) {
reqBody.color_by = {
payload: colorBy,
};
} else {
// Check we only have one of the options: payload, or discover_score
const options = [colorBy.payload, colorBy.discover_score];
const optionsCount = options.filter((option) => option).length;
if (optionsCount !== 1) {
return {
reqBody: reqBody,
error: '`color_by`: Only one of `payload`, or `discover_score` can be used',
};
}

// Put search arguments in main request body
if (colorBy.discover_score) {
reqBody = {
...reqBody,
...colorBy.discover_score,
};
}
}
}

// Set with_vector name
if (reqBody.vector_name) {
reqBody.with_vector = [reqBody.vector_name];
return {
reqBody: reqBody,
error: null,
};
} else if (!reqBody.vector_name) {
reqBody.with_vector = true;
return {
reqBody: reqBody,
error: null,
};
}
reqBody.limit = tsneConfig.number_of_neighbors;
return {
reqBody: reqBody,
error: null,
};
}

// export async function requestFromCode(dataRaw, collectionName) {
// const data = parseDataToRequest(dataRaw);
// // Sending request
// const colorBy = data.reqBody.color_by;
// if (colorBy?.payload) {
// return await actionFromCode(collectionName, data, 'scroll');
// }
// if (colorBy?.discover_score) {
// return discoverFromCode(collectionName, data);
// }
// return await actionFromCode(collectionName, data, 'scroll');
// }

export async function requestFromCode(dataRaw, collectionName) {
const data = parseDataToRequest(dataRaw);
// Sending request
const colorBy = data.reqBody.color_by;
if (colorBy?.payload) {
return await actionFromCode(collectionName, data, 'scroll');
}
if (colorBy?.discover_score) {
return discoverFromCode(collectionName, data);
}
return await actionFromCode(collectionName, data, 'scroll');
return await actionFromCode(collectionName, data);
}

// async function actionFromCode(collectionName, data, action) {
// try {
// const response = await axios({
// method: 'POST',
// url: `collections/${collectionName}/points/${action || 'scroll'}`,
// data: data.reqBody,
// });
// response.data.color_by = data.reqBody.color_by;
// response.data.vector_name = data.reqBody.vector_name;
// response.data.result.points = response.data.result.points.filter((point) => Object.keys(point.vector).length > 0);
// return {
// data: response.data,
// error: null,
// };
// } catch (err) {
// return {
// data: null,
// error: err.response?.data?.status ? err.response?.data?.status : err,
// };
// }
// }

async function getDistanceType(collectionName) {
const response = await axios({
method: 'GET',
url: `collections/${collectionName}`,
});
return response.data.result.config.params.vectors.distance ?? "";
}

async function actionFromCode(collectionName, data, action) {
async function actionFromCode(collectionName, data) {
try {
const response = await axios({
method: 'POST',
url: `collections/${collectionName}/points/${action || 'scroll'}`,
url: `collections/${collectionName}/points/search/matrix/offsets`,
data: data.reqBody,
});
// fetch points with given ids
const vecResponse = await axios({
method: 'POST',
url: `collections/${collectionName}/points`,
data: {
ids: response.data.result.ids,
with_payload: true,
}
});
// get distance type
response.data.distance_type = await getDistanceType(collectionName);

response.data.color_by = data.reqBody.color_by;
response.data.vector_name = data.reqBody.vector_name;
response.data.result.points = response.data.result.points.filter((point) => Object.keys(point.vector).length > 0);
response.data.result.points = vecResponse.data.result;

return {
data: response.data,
error: null,
};

} catch (err) {
return {
data: null,
Expand All @@ -82,73 +153,73 @@ async function actionFromCode(collectionName, data, action) {
}
}

async function discoverFromCode(collectionName, data) {
// Do 20/80 split. 20% of the points will be returned with the query
// and 80 % will be returned with random sampling
const queryLimit = Math.floor(data.reqBody.limit * 0.2);
const randomLimit = data.reqBody.limit - queryLimit;
data.reqBody.limit = queryLimit;
data.reqBody.with_payload = true;

const queryResponse = await actionFromCode(collectionName, data, 'discover');
if (queryResponse.error) {
return {
data: null,
error: queryResponse.error,
};
}

// Add tag to know which points were returned by the query
queryResponse.data.result = queryResponse.data.result.map((point) => ({
...point,
from_query: true,
}));

// Get "random" points ids.
// There is no sampling endpoint in Qdrant yet, so for now we just scroll excluding the previous results
const idsToExclude = queryResponse.data.result.map((point) => point.id);

const originalFilter = data.reqBody.filter;
const mustNotFilter = [{ has_id: idsToExclude }];
data.reqBody.filter = originalFilter || {};
data.reqBody.filter.must_not = mustNotFilter.concat(data.reqBody.filter.must_not ?? []);

data.reqBody.limit = randomLimit;
const randomResponse = await actionFromCode(collectionName, data, 'scroll');
if (randomResponse.error) {
return {
data: null,
error: randomResponse.error,
};
}

// Then score these random points
const idsToInclude = randomResponse.data.result.points.map((point) => point.id);
const mustFilter = [{ has_id: idsToInclude }];
data.reqBody.filter = originalFilter || {};
data.reqBody.filter.must = mustFilter.concat(data.reqBody.filter.must || []);

const scoredRandomResponse = await actionFromCode(collectionName, data, 'discover');
if (scoredRandomResponse.error) {
return {
data: null,
error: scoredRandomResponse.error,
};
}

// Concat both results
const points = queryResponse.data.result.concat(scoredRandomResponse.data.result);

return {
data: {
...queryResponse.data,
result: {
points: points,
},
},
error: null,
};
}
// async function discoverFromCode(collectionName, data) {
// // Do 20/80 split. 20% of the points will be returned with the query
// // and 80 % will be returned with random sampling
// const queryLimit = Math.floor(data.reqBody.limit * 0.2);
// const randomLimit = data.reqBody.limit - queryLimit;
// data.reqBody.limit = queryLimit;
// data.reqBody.with_payload = true;

// const queryResponse = await actionFromCode(collectionName, data, 'discover');
// if (queryResponse.error) {
// return {
// data: null,
// error: queryResponse.error,
// };
// }

// // Add tag to know which points were returned by the query
// queryResponse.data.result = queryResponse.data.result.map((point) => ({
// ...point,
// from_query: true,
// }));

// // Get "random" points ids.
// // There is no sampling endpoint in Qdrant yet, so for now we just scroll excluding the previous results
// const idsToExclude = queryResponse.data.result.map((point) => point.id);

// const originalFilter = data.reqBody.filter;
// const mustNotFilter = [{ has_id: idsToExclude }];
// data.reqBody.filter = originalFilter || {};
// data.reqBody.filter.must_not = mustNotFilter.concat(data.reqBody.filter.must_not ?? []);

// data.reqBody.limit = randomLimit;
// const randomResponse = await actionFromCode(collectionName, data, 'scroll');
// if (randomResponse.error) {
// return {
// data: null,
// error: randomResponse.error,
// };
// }

// // Then score these random points
// const idsToInclude = randomResponse.data.result.points.map((point) => point.id);
// const mustFilter = [{ has_id: idsToInclude }];
// data.reqBody.filter = originalFilter || {};
// data.reqBody.filter.must = mustFilter.concat(data.reqBody.filter.must || []);

// const scoredRandomResponse = await actionFromCode(collectionName, data, 'discover');
// if (scoredRandomResponse.error) {
// return {
// data: null,
// error: scoredRandomResponse.error,
// };
// }

// // Concat both results
// const points = queryResponse.data.result.concat(scoredRandomResponse.data.result);

// return {
// data: {
// ...queryResponse.data,
// result: {
// points: points,
// },
// },
// error: null,
// };
// }

export function codeParse(codeText) {
// Parse JSON
Expand Down
Loading