Skip to content

Commit

Permalink
Merge pull request #725 from danielgatis/remove-explicity-providers
Browse files Browse the repository at this point in the history
refactor: remove unused providers parameter from session constructors
  • Loading branch information
danielgatis authored Feb 21, 2025
2 parents 30450d5 + 172404b commit 9079508
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 48 deletions.
7 changes: 2 additions & 5 deletions rembg/session_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from .sessions.u2net import U2netSession


def new_session(
model_name: str = "u2net", providers=None, *args, **kwargs
) -> BaseSession:
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
"""
Create a new session object based on the specified model name.
Expand All @@ -21,7 +19,6 @@ def new_session(
Parameters:
model_name (str): The name of the model.
providers: The providers for the session.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Expand All @@ -41,4 +38,4 @@ def new_session(
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])

return session_class(model_name, sess_opts, providers, *args, **kwargs)
return session_class(model_name, sess_opts, *args, **kwargs)
21 changes: 1 addition & 20 deletions rembg/sessions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,11 @@
class BaseSession:
"""This is a base class for managing a session with a machine learning model."""

def __init__(
self,
model_name: str,
sess_opts: ort.SessionOptions,
providers=None,
*args,
**kwargs
):
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
"""Initialize an instance of the BaseSession class."""
self.model_name = model_name

self.providers = []

_providers = ort.get_available_providers()
if providers:
for provider in providers:
if provider in _providers:
self.providers.append(provider)
else:
self.providers.extend(_providers)

self.inner_session = ort.InferenceSession(
str(self.__class__.download_models(*args, **kwargs)),
providers=self.providers,
sess_options=sess_opts,
)

Expand Down
13 changes: 0 additions & 13 deletions rembg/sessions/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def __init__(
self,
model_name: str,
sess_opts: ort.SessionOptions,
providers=None,
*args,
**kwargs,
):
Expand All @@ -102,25 +101,13 @@ def __init__(
"""
self.model_name = model_name

valid_providers = []
available_providers = ort.get_available_providers()

if providers:
for provider in providers or []:
if provider in available_providers:
valid_providers.append(provider)
else:
valid_providers.extend(available_providers)

paths = self.__class__.download_models(*args, **kwargs)
self.encoder = ort.InferenceSession(
str(paths[0]),
providers=valid_providers,
sess_options=sess_opts,
)
self.decoder = ort.InferenceSession(
str(paths[1]),
providers=valid_providers,
sess_options=sess_opts,
)

Expand Down
12 changes: 2 additions & 10 deletions rembg/sessions/u2net_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,13 @@
class U2netCustomSession(BaseSession):
"""This is a class representing a custom session for the U2net model."""

def __init__(
self,
model_name: str,
sess_opts: ort.SessionOptions,
providers=None,
*args,
**kwargs
):
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
"""
Initialize a new U2netCustomSession object.
Parameters:
model_name (str): The name of the model.
sess_opts (ort.SessionOptions): The session options.
providers: The providers.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Expand All @@ -38,7 +30,7 @@ def __init__(
if model_path is None:
raise ValueError("model_path is required")

super().__init__(model_name, sess_opts, providers, *args, **kwargs)
super().__init__(model_name, sess_opts, *args, **kwargs)

def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
"""
Expand Down

0 comments on commit 9079508

Please sign in to comment.