[Refactor] Fix ruff rule E721: type-comparison #49919

Open · wants to merge 1 commit into base: master
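For context, ruff's E721 rule flags `type(x) == T` comparisons: for ordinary classes the result is the same as `type(x) is T`, but `==` can be intercepted by a metaclass and it obscures whether an exact-type check or a subclass-aware check was intended. A minimal sketch of the three idioms (the classes here are illustrative, not from the Ray codebase):

```python
class Base:
    pass


class Child(Base):
    pass


obj = Child()

# Exact-type checks: true only for Child itself, never for Base or other subclasses.
assert type(obj) is Child
assert type(obj) is not Base

# Subclass-aware check: true for Child because Child derives from Base.
assert isinstance(obj, Base)

# The form E721 flags; equivalent to `is` here, but the intent is ambiguous.
assert type(obj) == Child  # noqa: E721
```

Most replacements in this PR keep the original exact-type semantics by switching `==`/`!=` to `is`/`is not` rather than moving to `isinstance`.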
1 change: 0 additions & 1 deletion pyproject.toml
@@ -30,7 +30,6 @@ ignore = [
# TODO(MortalHappiness): Remove the following rules from the ignore list
# The above are rules ignored originally in flake8
# The following are rules ignored in ruff
"E721",
"F841",
"B018",
"B023",
2 changes: 1 addition & 1 deletion python/ray/_private/services.py
@@ -1627,7 +1627,7 @@ def start_raylet(
Returns:
ProcessInfo for the process that was started.
"""
assert node_manager_port is not None and type(node_manager_port) == int
assert node_manager_port is not None and type(node_manager_port) is int

if use_valgrind and use_profiler:
raise ValueError("Cannot use valgrind and profiler at the same time.")
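One behavioural detail the exact-type assert above preserves: `type(port) is int` rejects `bool` values (since `bool` subclasses `int`), while an `isinstance` check would let `True`/`False` through. A small illustrative sketch, not taken from the Ray sources:

```python
def looks_like_port(value):
    # Mirrors the assert in start_raylet: must be a real int, not a bool or str.
    return value is not None and type(value) is int


assert looks_like_port(6379)
assert not looks_like_port(True)     # bool is a subclass of int but not an exact int
assert not looks_like_port("6379")
assert not looks_like_port(None)
assert isinstance(True, int)         # the subclass-aware check would have accepted it
```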
2 changes: 1 addition & 1 deletion python/ray/_private/test_utils.py
@@ -868,7 +868,7 @@ def wait_until_succeeded_without_exception(
Return:
Whether exception occurs within a timeout.
"""
if type(exceptions) != tuple:
if type(exceptions) is not tuple:
raise Exception("exceptions arguments should be given as a tuple")

time_elapsed = 0
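The guard above keeps the original behaviour of accepting only a literal `tuple` of exception classes. A standalone sketch of the check with a hypothetical caller (the function name is illustrative):

```python
def validate_exceptions_arg(exceptions):
    # Same guard as in wait_until_succeeded_without_exception: require a real tuple,
    # not a single exception class or a list.
    if type(exceptions) is not tuple:
        raise Exception("exceptions arguments should be given as a tuple")


validate_exceptions_arg((ValueError, KeyError))   # OK
for bad in (ValueError, [ValueError], "ValueError"):
    try:
        validate_exceptions_arg(bad)
    except Exception as err:
        print(f"rejected {bad!r}: {err}")
```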
12 changes: 6 additions & 6 deletions python/ray/dag/compiled_dag_node.py
@@ -1065,7 +1065,7 @@ def _preprocess(self) -> None:
"the driver cannot participate in the NCCL group"
)

if type(dag_node.type_hint) == ChannelOutputType:
if type(dag_node.type_hint) is ChannelOutputType:
# No type hint specified by the user. Replace
# with the default type hint for this DAG.
dag_node.with_type_hint(self._default_type_hint)
@@ -2593,16 +2593,16 @@ def get_channel_details(
if channel in self._channel_dict and self._channel_dict[channel] != channel:
channel = self._channel_dict[channel]
channel_details += f"\n{type(channel).__name__}"
if type(channel) == CachedChannel:
if type(channel) is CachedChannel:
channel_details += f", {channel._channel_id[:6]}..."
# get inner channel
if (
type(channel) == CompositeChannel
type(channel) is CompositeChannel
and downstream_actor_id in channel._channel_dict
):
inner_channel = channel._channel_dict[downstream_actor_id]
channel_details += f"\n{type(inner_channel).__name__}"
if type(inner_channel) == IntraProcessChannel:
if type(inner_channel) is IntraProcessChannel:
channel_details += f", {inner_channel._channel_id[:6]}..."
return channel_details

@@ -2766,7 +2766,7 @@ def visualize(
task.output_channels[0],
(
downstream_node._get_actor_handle()._actor_id.hex()
if type(downstream_node) == ClassMethodNode
if type(downstream_node) is ClassMethodNode
else self._proxy_actor._actor_id.hex()
),
)
@@ -2784,7 +2784,7 @@
task.dag_node._get_actor_handle()._actor_id.hex(),
)
dot.edge(str(idx), str(downstream_idx), label=edge_label)
if type(task.dag_node) == InputAttributeNode:
if type(task.dag_node) is InputAttributeNode:
# Add an edge from the InputAttributeNode to the InputNode
dot.edge(str(self.input_task_idx), str(idx))
dot.render(filename, view=view)
18 changes: 9 additions & 9 deletions python/ray/dashboard/tests/test_dashboard.py
@@ -787,7 +787,7 @@ def test_immutable_types():
d["list"][0] = {str(i): i for i in range(1000)}
d["dict"] = {str(i): i for i in range(1000)}
immutable_dict = dashboard_utils.make_immutable(d)
assert type(immutable_dict) == dashboard_utils.ImmutableDict
assert type(immutable_dict) is dashboard_utils.ImmutableDict
assert immutable_dict == dashboard_utils.ImmutableDict(d)
assert immutable_dict == d
assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict
@@ -799,14 +799,14 @@ def test_immutable_types():
assert "512" in d["dict"]

# Test type conversion
assert type(dict(immutable_dict)["list"]) == dashboard_utils.ImmutableList
assert type(list(immutable_dict["list"])[0]) == dashboard_utils.ImmutableDict
assert type(dict(immutable_dict)["list"]) is dashboard_utils.ImmutableList
assert type(list(immutable_dict["list"])[0]) is dashboard_utils.ImmutableDict

# Test json dumps / loads
json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder)
deserialized_immutable_dict = json.loads(json_str)
assert type(deserialized_immutable_dict) == dict
assert type(deserialized_immutable_dict["list"]) == list
assert type(deserialized_immutable_dict) is dict
assert type(deserialized_immutable_dict["list"]) is list
assert immutable_dict.mutable() == deserialized_immutable_dict
dashboard_optional_utils.rest_response(True, "OK", data=immutable_dict)
dashboard_optional_utils.rest_response(True, "OK", **immutable_dict)
@@ -819,12 +819,12 @@

# Test get default immutable
immutable_default_value = immutable_dict.get("not exist list", [1, 2])
assert type(immutable_default_value) == dashboard_utils.ImmutableList
assert type(immutable_default_value) is dashboard_utils.ImmutableList

# Test recursive immutable
assert type(immutable_dict["list"]) == dashboard_utils.ImmutableList
assert type(immutable_dict["dict"]) == dashboard_utils.ImmutableDict
assert type(immutable_dict["list"][0]) == dashboard_utils.ImmutableDict
assert type(immutable_dict["list"]) is dashboard_utils.ImmutableList
assert type(immutable_dict["dict"]) is dashboard_utils.ImmutableDict
assert type(immutable_dict["list"][0]) is dashboard_utils.ImmutableDict

# Test exception
with pytest.raises(TypeError):
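The json assertions in the test above lean on the fact that `json.loads` always materialises built-in `dict` and `list` objects, so the exact-type form is a faithful check after a round trip. A minimal standalone sketch:

```python
import json

payload = {"list": [{"a": 1}], "dict": {"b": 2}}
round_tripped = json.loads(json.dumps(payload))

# json.loads never returns Immutable* wrappers, only plain built-ins.
assert type(round_tripped) is dict
assert type(round_tripped["list"]) is list
```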
@@ -355,7 +355,7 @@ def _cast_large_list_to_list(batch: pyarrow.Table):

for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
if type(field_type) == pyarrow.lib.LargeListType:
if type(field_type) is pyarrow.lib.LargeListType:
value_type = field_type.value_type

if value_type == pyarrow.large_binary():
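For the schema walk above, pyarrow also ships type predicates that express the same intent; a hedged sketch assuming a reasonably recent pyarrow (the helpers below are standard pyarrow API, not Ray code):

```python
import pyarrow as pa

field_type = pa.large_list(pa.large_binary())

# Exact-type comparison, as in the diff above.
assert type(field_type) is pa.lib.LargeListType

# Equivalent predicate form provided by pyarrow.
assert pa.types.is_large_list(field_type)
assert field_type.value_type == pa.large_binary()
```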
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_map.py
@@ -1442,7 +1442,7 @@ def empty_pandas(batch):
block_refs = _ref_bundles_iterator_to_block_refs_list(bundles)

assert len(block_refs) == 1
assert type(ray.get(block_refs[0])) == pd.DataFrame
assert type(ray.get(block_refs[0])) is pd.DataFrame


def test_map_with_objects_and_tensors(ray_start_regular_shared):
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_numpy_support.py
@@ -24,7 +24,7 @@ def do_map_batches(data):


def assert_structure_equals(a, b):
assert type(a) == type(b), (type(a), type(b))
assert type(a) is type(b), (type(a), type(b))
assert type(a[0]) == type(b[0]), (type(a[0]), type(b[0])) # noqa: E721
assert a.dtype == b.dtype
assert a.shape == b.shape
2 changes: 1 addition & 1 deletion python/ray/serve/multiplex.py
@@ -173,7 +173,7 @@ async def load_model(self, model_id: str) -> Any:
The user-constructed model object.
"""

if type(model_id) != str:
if type(model_id) is not str:
raise TypeError("The model ID must be a string.")

if not model_id:
2 changes: 1 addition & 1 deletion python/ray/tests/gcp/test_gcp_tpu_command_runner.py
@@ -242,7 +242,7 @@ def test_max_active_connections_env_var():
cmd_runner = TPUCommandRunner(**args)
os.environ[ray_constants.RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR] = "1"
num_connections = cmd_runner.num_connections
assert type(num_connections) == int
assert type(num_connections) is int
assert num_connections == 1


2 changes: 1 addition & 1 deletion python/ray/tests/modin/modin_test_utils.py
@@ -105,7 +105,7 @@ def df_equals(df1, df2):
if isinstance(df1, pandas.DataFrame) and isinstance(df2, pandas.DataFrame):
if (df1.empty and not df2.empty) or (df2.empty and not df1.empty):
assert False, "One of the passed frames is empty, when other isn't"
elif df1.empty and df2.empty and type(df1) != type(df2):
elif df1.empty and df2.empty and type(df1) is not type(df2):
assert (
False
), f"Empty frames have different types: {type(df1)} != {type(df2)}"
6 changes: 3 additions & 3 deletions python/ray/tests/test_basic.py
@@ -866,9 +866,9 @@ def temp():
assert ray.get(f.remote(s)) == s

# Test types.
assert ray.get(f.remote(int)) == int
assert ray.get(f.remote(float)) == float
assert ray.get(f.remote(str)) == str
assert ray.get(f.remote(int)) is int
assert ray.get(f.remote(float)) is float
assert ray.get(f.remote(str)) is str

class Foo:
def __init__(self):
2 changes: 1 addition & 1 deletion python/ray/tests/test_client_builder.py
@@ -51,7 +51,7 @@ def test_client(address):
if address in ("local", None):
assert isinstance(builder, client_builder._LocalClientBuilder)
else:
assert type(builder) == client_builder.ClientBuilder
assert type(builder) is client_builder.ClientBuilder
assert builder.address == address.replace("ray://", "")


2 changes: 1 addition & 1 deletion python/ray/tests/test_joblib.py
@@ -36,7 +36,7 @@ def test_ray_backend(shutdown_only):
from ray.util.joblib.ray_backend import RayBackend

with joblib.parallel_backend("ray"):
assert type(joblib.parallel.get_active_backend()[0]) == RayBackend
assert type(joblib.parallel.get_active_backend()[0]) is RayBackend


def test_svm_single_node(shutdown_only):
8 changes: 4 additions & 4 deletions python/ray/tests/test_serialization.py
@@ -23,12 +23,12 @@
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
if len(b) != 1 or b[0] is not tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
return all(type(n) is str for n in f)


@pytest.mark.parametrize(
@@ -95,8 +95,8 @@ def f(x):
# TODO(rkn): The numpy dtypes currently come back as regular integers
# or floats.
if type(obj).__module__ != "numpy":
assert type(obj) == type(new_obj_1)
assert type(obj) == type(new_obj_2)
assert type(obj) is type(new_obj_1)
assert type(obj) is type(new_obj_2)


@pytest.mark.parametrize(
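As a usage note for the `is_named_tuple` helper shown above, the identity checks make it accept real namedtuple classes and reject both `tuple` itself and unrelated classes; a quick sketch reproducing the helper as it appears in the diff:

```python
from collections import namedtuple


def is_named_tuple(cls):
    """Return True if cls is a namedtuple and False otherwise."""
    b = cls.__bases__
    if len(b) != 1 or b[0] is not tuple:
        return False
    f = getattr(cls, "_fields", None)
    if not isinstance(f, tuple):
        return False
    return all(type(n) is str for n in f)


Point = namedtuple("Point", ["x", "y"])

assert is_named_tuple(Point)
assert not is_named_tuple(tuple)   # its only base is object, and it has no _fields
assert not is_named_tuple(dict)
```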
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_sample.py
@@ -29,7 +29,7 @@ def assertDictAlmostEqual(a, b):
assert k in b, f"Key {k} not found in {b}"
w = b[k]

assert type(v) == type(w), f"Type {type(v)} is not {type(w)}"
assert type(v) is type(w), f"Type {type(v)} is not {type(w)}"

if isinstance(v, dict):
assert assertDictAlmostEqual(v, w), f"Subdict {v} != {w}"
2 changes: 1 addition & 1 deletion rllib/connectors/action/pipeline.py
@@ -45,7 +45,7 @@ def to_state(self):
@staticmethod
def from_state(ctx: ConnectorContext, params: Any):
assert (
type(params) == list
type(params) is list
Comment on lines 47 to 48 (Collaborator): shouldn't line 46 just declare params as type List? @sven1977 ?
), "ActionConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
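On the review question above: the agent-side `AgentConnectorPipeline.from_state` later in this diff already annotates `params: List[Any]` while keeping the runtime assert, so the two are complementary rather than redundant. A hedged sketch of that combination with simplified stand-ins for the Ray classes:

```python
from typing import Any, List


def from_state(ctx: Any, params: List[Any]) -> list:
    # The annotation documents the expected type for readers and static checkers;
    # the assert still catches callers that pass a single connector state at runtime.
    assert type(params) is list, "ActionConnectorPipeline takes a list of connector params."
    return list(params)


print(from_state(None, [{"name": "clip_reward", "limit": 1.0}]))
```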
2 changes: 1 addition & 1 deletion rllib/connectors/agent/clip_reward.py
@@ -25,7 +25,7 @@ def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
type(d) is dict
), "Single agent data must be of type Dict[str, TensorStructType]"

if SampleBatch.REWARDS not in d:
2 changes: 1 addition & 1 deletion rllib/connectors/agent/mean_std_filter.py
@@ -52,7 +52,7 @@ def __init__(
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert (
type(d) == dict
type(d) is dict
), "Single agent data must be of type Dict[str, TensorStructType]"
if SampleBatch.OBS in d:
d[SampleBatch.OBS] = self.filter(
2 changes: 1 addition & 1 deletion rllib/connectors/agent/obs_preproc.py
@@ -44,7 +44,7 @@ def is_identity(self):

def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
d = ac_data.data
assert type(d) == dict, (
assert type(d) is dict, (
"Single agent data must be of type Dict[str, TensorStructType] but is of "
"type {}".format(type(d))
)
2 changes: 1 addition & 1 deletion rllib/connectors/agent/pipeline.py
@@ -56,7 +56,7 @@ def to_state(self):
@staticmethod
def from_state(ctx: ConnectorContext, params: List[Any]):
assert (
type(params) == list
type(params) is list
), "AgentConnectorPipeline takes a list of connector params."
connectors = []
for state in params:
2 changes: 1 addition & 1 deletion rllib/connectors/util.py
@@ -54,7 +54,7 @@ def get_agent_connectors_from_config(
clip_rewards = __clip_rewards(config)
if clip_rewards is True:
connectors.append(ClipRewardAgentConnector(ctx, sign=True))
elif type(clip_rewards) == float:
elif type(clip_rewards) is float:
connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))

if __preprocessing_enabled(config):
4 changes: 2 additions & 2 deletions rllib/env/wrappers/dm_control_wrapper.py
@@ -47,10 +47,10 @@ def _spec_to_box(spec):
def extract_min_max(s):
assert s.dtype == np.float64 or s.dtype == np.float32
dim = np.int_(np.prod(s.shape))
if type(s) == specs.Array:
if type(s) is specs.Array:
bound = np.inf * np.ones(dim, dtype=np.float32)
return -bound, bound
elif type(s) == specs.BoundedArray:
elif type(s) is specs.BoundedArray:
zeros = np.zeros(dim, dtype=np.float32)
return s.minimum + zeros, s.maximum + zeros

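The exact-type branches above matter because, in `dm_env.specs`, `BoundedArray` derives from `Array` (to the best of my reading), so an `isinstance(s, specs.Array)` check would also match bounded specs and skip their minimum/maximum handling. A sketch of the dispatch pattern with stand-in classes rather than the real dm_env types:

```python
import numpy as np


class Array:                # stand-in for dm_env.specs.Array
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype


class BoundedArray(Array):  # stand-in for dm_env.specs.BoundedArray
    def __init__(self, shape, dtype, minimum, maximum):
        super().__init__(shape, dtype)
        self.minimum, self.maximum = minimum, maximum


def extract_min_max(s):
    dim = int(np.prod(s.shape))
    if type(s) is Array:
        # Unbounded spec: use +/- infinity.
        bound = np.inf * np.ones(dim, dtype=np.float32)
        return -bound, bound
    elif type(s) is BoundedArray:
        # Bounded spec: broadcast the declared minimum/maximum.
        zeros = np.zeros(dim, dtype=np.float32)
        return s.minimum + zeros, s.maximum + zeros


print(extract_min_max(BoundedArray((2,), np.float32, -1.0, 1.0)))
print(extract_min_max(Array((2,), np.float32)))
```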
@@ -63,7 +63,7 @@ def test_mixin_sampling_episodes(self):
for _ in range(20):
buffer.add(batch)
sample = buffer.sample(2)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
# One sample in the episode does not belong the the episode on thus
# gets dropped. Full episodes are of length two.
@@ -88,7 +88,7 @@ def test_mixin_sampling_sequences(self):
for _ in range(400):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1)

@@ -113,7 +113,7 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
buffer.add(batch)
sample = buffer.sample(3)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2)

Expand All @@ -125,7 +125,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(5)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2)

@@ -142,7 +142,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(10)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2)

@@ -156,12 +156,12 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 batch to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# Expect exactly 0 sample to be returned (nothing new to be returned;
# no replay allowed (replay_ratio=0.0)).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
assert len(sample.policy_batches) == 0
# If we insert and replay n times, expect roughly return batches of
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
@@ -170,7 +170,7 @@ def test_mixin_sampling_timesteps(self):
for _ in range(100):
buffer.add(batch)
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2)

@@ -187,19 +187,19 @@ def test_mixin_sampling_timesteps(self):
buffer.add(batch)
# Expect exactly 1 sample to be returned (the new batch).
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# Another replay -> Expect exactly 1 sample to be returned.
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
self.assertTrue(len(sample) == 1)
# If we replay n times, expect roughly return batches of
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
# on average in each returned value).
results = []
for _ in range(100):
sample = buffer.sample(1)
assert type(sample) == MultiAgentBatch
assert type(sample) is MultiAgentBatch
results.append(len(sample.policy_batches[DEFAULT_POLICY_ID]))
self.assertAlmostEqual(np.mean(results), 1.0)
