Skip to content

Commit 965b3bd

Browse files
committed
Add trace-specific color sequence support from template.data
- Check template.data.<trace_type> for marker.color or line.color before falling back to template.layout.colorway - Handle timeline special case (maps to bar trace type) - Use marker colors first, fall back to line colors if no markers found - Fixes issue #5416
1 parent f083977 commit 965b3bd

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

plotly/express/_core.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def one_group(x):
10031003
return ""
10041004

10051005

1006-
def apply_default_cascade(args):
1006+
def apply_default_cascade(args, constructor=None):
10071007
# first we apply px.defaults to unspecified args
10081008

10091009
for param in defaults.__slots__:
@@ -1038,6 +1038,25 @@ def apply_default_cascade(args):
10381038
args["color_continuous_scale"] = sequential.Viridis
10391039

10401040
if "color_discrete_sequence" in args:
1041+
if args["color_discrete_sequence"] is None and constructor is not None:
1042+
if constructor == "timeline":
1043+
trace_type = "bar"
1044+
else:
1045+
trace_type = constructor().type
1046+
if trace_data_list := getattr(args["template"].data, trace_type, None):
1047+
collected_colors = [
1048+
trace_data.marker.color
1049+
for trace_data in trace_data_list
1050+
if hasattr(trace_data, "marker")
1051+
]
1052+
if not collected_colors:
1053+
collected_colors = [
1054+
trace_data.line.color
1055+
for trace_data in trace_data_list
1056+
if hasattr(trace_data, "line")
1057+
]
1058+
if collected_colors:
1059+
args["color_discrete_sequence"] = collected_colors
10411060
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
10421061
args["color_discrete_sequence"] = args["template"].layout.colorway
10431062
if args["color_discrete_sequence"] is None:
@@ -2486,7 +2505,7 @@ def get_groups_and_orders(args, grouper):
24862505
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
24872506
trace_patch = trace_patch or {}
24882507
layout_patch = layout_patch or {}
2489-
apply_default_cascade(args)
2508+
apply_default_cascade(args, constructor=constructor)
24902509

24912510
args = build_dataframe(args, constructor)
24922511
if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:

0 commit comments

Comments
 (0)