Fuse quantizelinear for skip layers using multioutput fusions #3791

Open · 2 tasks · shivadbhavsar opened this issue Jan 31, 2025 · 1 comment

shivadbhavsar (Contributor) commented Jan 31, 2025

Follow-up from PR #3782.
Example ResNet quantized graph after the above PR:

NEW:

q -> conv -> dq -> add -> relu -> q .......... -> q -> conv -> dq -> add -> relu -> q
                           |                                                        |
                           -> step -> q ------------------------------------------> concat -> conv -> ...

From some experimental work, it turns out that we get a slight perf boost by moving the skip-connection quantize op before the step op and fusing it into the previous conv-pointwise kernel. This should probably be done as two steps:

  1. Write a pass for multi-output fusion that can fuse to mlir_quant_convolution_dequantizelinear_dequantizelinear_add_add_relu_quantizelinear_quantizelinear
  2. Write a pass to move q before step when this fusion is possible

These should be done in this order, since swapping the quantize op and the step op is generally not preferable unless it enables this fusion.
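
To see why the swap in step 2 is valid for scalar quantization parameters, here is a minimal standalone sketch (the quantize and step helpers are hypothetical stand-ins, not MIGraphX APIs): a scalar-qparam quantizelinear is purely elementwise, so it commutes with step, which only selects elements at a stride.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for quantizelinear with a scalar scale and zero-point.
std::vector<int8_t> quantize(const std::vector<float>& x, float scale, int8_t zp)
{
    std::vector<int8_t> out(x.size());
    std::transform(x.begin(), x.end(), out.begin(), [&](float v) {
        float q = std::nearbyint(v / scale) + zp;
        return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
    });
    return out;
}

// Hypothetical stand-in for step: a 1-D strided slice.
template <class T>
std::vector<T> step(const std::vector<T>& x, std::size_t stride)
{
    std::vector<T> out;
    for(std::size_t i = 0; i < x.size(); i += stride)
        out.push_back(x[i]);
    return out;
}

int main()
{
    std::vector<float> x = {0.1f, -1.4f, 2.3f, 0.7f, -3.2f, 1.1f};
    // q(step(x)) == step(q(x)) when the qparams are scalar.
    assert(quantize(step(x, 2), 0.05f, 0) == step(quantize(x, 0.05f, 0), 2));
}

For per-channel scales or zero-points the swap would also have to slice the qparams along the stepped axis, which is presumably why the experimental pass below bails out on non-scalar qparams.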

shivadbhavsar (Contributor, Author) commented Jan 31, 2025

Here is some example code used for the experiment.

  1. MLIR multi-output fusion (in the fuse_mlir pass). This needs to be refactored to account for incoming changes: Fuse reshapes on pointwise inputs for mlir output fusion #3569 and Fuse multiple outputs for pointwise and reductions #3752 (or similar).
struct find_mlir_multi_pointwise
{
    mlir_mode conv_mode = mlir_mode::none;
    mlir_mode dot_mode  = mlir_mode::none;
    auto matcher() const
    {
        return match::name("gpu::mlir_op")(match::all_of[match::outputs()](mlir_pointwise));
    }

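    // Fuse each pointwise consumer of the matched mlir_op into a single new
    // module that returns all fused results as a tuple.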
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto* mlir_mod = ins->module_inputs().front();
        auto pw_inss   = ins->outputs();

        std::string module_name = mlir_mod->name();
        std::transform(
            pw_inss.begin(),
            pw_inss.end(),
            join_back_inserter(module_name),
            [](instruction_ref pw) { return ":" + pw->module_inputs().front()->name(); });
        module_ref mm = mpm.create_module(module_name);

        std::unordered_map<instruction_ref, instruction_ref> map_main_to_mm;
        mm->add_params(ins->inputs(), &map_main_to_mm);
        std::unordered_map<instruction_ref, instruction_ref> map_mlir_mod_to_mm(map_main_to_mm);
        auto original_return = mm->fuse(*mlir_mod, ins->inputs(), &map_mlir_mod_to_mm).front();
        map_main_to_mm[ins]  = original_return;

        // single pointwise output should already be fused in
        assert(pw_inss.size() > 1);
        std::vector<instruction_ref> new_returns;
        for(auto pw_ins : pw_inss)
        {
            auto* pm = pw_ins->module_inputs().front();
            std::unordered_map<instruction_ref, instruction_ref> lit_map =
                create_param_map_with_literals(mm, pm, pw_ins->get_shape());

            mm->add_params(pw_ins->inputs(), &map_main_to_mm);
            map_main_to_mm.insert(lit_map.begin(), lit_map.end());
            std::unordered_map<instruction_ref, instruction_ref> map_pm_to_mm(map_main_to_mm);
            auto fused_pw_out = mm->fuse(*pm, pw_ins->inputs(), &map_pm_to_mm).front();

            map_main_to_mm[pw_ins] = fused_pw_out;
            new_returns.push_back(fused_pw_out);
        }

        mm->add_return(new_returns);
        auto map_mm_to_main = invert_map_ins(map_main_to_mm);
        auto new_inputs     = mm->get_inputs(map_mm_to_main);

        mm->set_bypass();
        auto fused_ins = mpm.get_module().insert_instruction(
            ins, ins->get_operator(), mlir_contiguous(mpm, new_inputs), {mm});
        mpm.get_module().debug_print(fused_ins); // debug output used during the experiment

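        // Replace each original pointwise instruction with the corresponding
        // tuple element of the fused instruction.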
        size_t out_idx = 0;
        for(auto ret : new_returns)
        {
            auto original_ins = map_mm_to_main[ret];
            mpm.get_module().replace_instruction(
                original_ins, migraphx::make_op("get_tuple_elem", {{"index", out_idx}}), fused_ins);
            out_idx++;
        }
    }
};
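In graph terms, the rewrite turns

mlir_op -> pointwise0 -> ...
     |
     -> pointwise1 -> ...

(pointwise0/pointwise1 illustrative) into a single gpu::mlir_op whose module returns a tuple of both fused results, with each original pointwise instruction replaced by a get_tuple_elem of the fused instruction.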
  2. Moving q before step (in the simplify_qdq pass). Currently only written for scalar scales and zero-points; can be generalized.
struct match_step_qlinear
{
    auto matcher() const
    {
        auto any_pointwise_input = match::any_of[match::inputs()](match::pointwise().bind("pw"));
        return match::name("quantizelinear")(match::arg(0)(
            match::name("step")(match::used_once(), any_pointwise_input).bind("step")));
    }

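    // Walk back up the broadcast chain to recover the pre-broadcast (scalar)
    // quantization parameter.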
    auto get_prebroadcast_qparam(instruction_ref i, size_t channels) const
    {
        instruction_ref top = i;
        while(top->get_shape().elements() != channels)
        {
            assert(top->inputs().size() == 1);
            top = top->inputs()[0];
        }
        return top;
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto step_ins = r.instructions["step"];
        auto pw_ins   = r.instructions["pw"];

        assert(ins->inputs().size() == 3);
        auto scale = ins->inputs()[1];
        auto zp    = ins->inputs()[2];

        if(not(scale->get_shape().scalar() and zp->get_shape().scalar()))
            return;

        auto sscale = get_prebroadcast_qparam(scale, 1);
        auto szp    = get_prebroadcast_qparam(zp, 1);

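        // Re-broadcast the scalar qparams to the pointwise output shape so the
        // new quantizelinear can sit directly on the pointwise output.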
        auto scale_mb = m.insert_instruction(
            step_ins,
            make_op("multibroadcast", {{"out_lens", pw_ins->get_shape().lens()}}),
            {sscale});
        auto zp_mb = m.insert_instruction(
            step_ins, make_op("multibroadcast", {{"out_lens", pw_ins->get_shape().lens()}}), {szp});

        auto new_q = m.insert_instruction(step_ins, ins->get_operator(), {pw_ins, scale_mb, zp_mb});

        m.replace_instruction(ins, step_ins->get_operator(), {new_q});
    }
};
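With this rewrite the skip path reads relu -> q -> step -> concat instead of relu -> step -> q -> concat, so the multi-output fusion from step 1 can absorb the extra quantizelinear into the conv-pointwise kernel.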
