Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,26 @@ def execute(self, *args, **kwargs):
# If HOSTNAME is not defined, use the execution name as a fallback
id = os.environ.get("HOSTNAME", ctx.user_space_params.execution_id.name)

# Handle FlyteDirectory for dir parameter
# If dir is specified as a parameter name (string), look it up in kwargs
init_kwargs = self.init_kwargs.copy()
if "dir" in init_kwargs:
dir_value = init_kwargs["dir"]
# If dir is a string, treat it as a parameter name and look it up
if isinstance(dir_value, str) and dir_value in kwargs:
dir_value = kwargs[dir_value]
# Check if it's a FlyteDirectory by checking for download method
if hasattr(dir_value, "download"):
# Download the FlyteDirectory and replace with local path
init_kwargs["dir"] = dir_value.download()
else:
init_kwargs["dir"] = dir_value

run = wandb.init(
project=self.project,
entity=self.entity,
id=id,
**self.init_kwargs,
**init_kwargs,
)

# If FLYTE_EXECUTION_URL is defined, inject it into wandb to link back to the execution.
Expand Down
Loading