Files
training.python.datascience/source/plotting/charts/plotly_figure_sunburst.py
2025-10-24 22:29:15 +02:00

30 lines
1.4 KiB
Python

import pandas as pd
from plotly.graph_objs import Figure, Sunburst
def _path_ids(frame: pd.DataFrame, keys: list[str], sep: str) -> pd.Series:
    """Join the string form of the *keys* columns row-wise into one id per row.

    An empty *keys* list yields ``""`` for every row, which plotly's sunburst
    treats as "no parent" (a node in the innermost ring).
    """
    if not keys:
        return pd.Series("", index=frame.index, dtype=object)
    text = frame[keys].astype(str)
    ids = text[keys[0]]
    for key in keys[1:]:
        ids = ids + sep + text[key]
    return ids


def build_sunburst_frame(df: pd.DataFrame, path: list[str], value: str, sep: str = "/") -> pd.DataFrame:
    """Flatten a hierarchical table into the node rows a plotly sunburst expects.

    For every depth of *path*, the *value* column is summed per distinct key
    prefix. The result has one row per node with these columns:

    - ``id``: the node's full key path joined with *sep* (e.g.
      ``"Europe/France"``). The same label may therefore appear at different
      depths or under different parents without the nodes colliding.
    - ``parent``: the parent node's ``id``, or ``""`` for innermost-ring nodes.
    - ``label``: the node's own key value, i.e. the text shown on the chart.
    - *value*: the summed value. Each parent equals the sum of its children,
      which is what ``branchvalues="total"`` requires.

    Rows are ordered innermost ring first and sorted by key within each depth.
    Rows with a missing value in any *path* column are ignored at every depth,
    so totals stay consistent. Ids can still collide if key values themselves
    contain *sep*; pass a different *sep* in that case.

    Raises:
        ValueError: if *path* is empty.
    """
    path = list(path)  # a tuple would be read by groupby as a single key
    if not path:
        raise ValueError("path must name at least one column")
    # Drop incomplete rows once so every depth aggregates the same records.
    data = df.dropna(subset=path)
    frames = []
    for depth in range(1, len(path) + 1):
        keys = path[:depth]
        level = data.groupby(keys, as_index=False)[value].sum()
        frames.append(pd.DataFrame({
            "id": _path_ids(level, keys, sep),
            "parent": _path_ids(level, keys[:-1], sep),
            "label": level[keys[-1]],
            value: level[value],
        }))
    return pd.concat(frames, ignore_index=True)


if __name__ == '__main__':
    df = pd.DataFrame(data={
        "continent": ["Europe", "Europe", "Europe", "Europe", "Europe", "Europe", "Europe", "Asia", "Asia"],
        "country": ["France", "France", "Spain", "Spain", "England", "England", "England", "China", "China"],
        "city": ["Montpellier", "Bordeaux", "Madrid", "Valencia", "London", "Manchester", "Bristol", "Beijing", "Shanghai"],
        "sales": [150_000, 127_000, 97_200, 137_250, 200_000, 180_000, 150_000, 120_000, 140_000]
    })
    # Hierarchy levels, from the innermost ring to the outermost one.
    PATH: list[str] = ["continent", "country", "city"]
    nodes = build_sunburst_frame(df, PATH, "sales")
    sb = Sunburst(
        # Path-based ids keep nodes distinct even when labels repeat.
        ids=nodes["id"],
        labels=nodes["label"],
        parents=nodes["parent"],
        values=nodes["sales"],
        # texttemplate overrides textinfo in plotly, so only the template is set.
        texttemplate="%{label}<br>%{value:,.2f}",
        branchvalues="total",
    )
    fig: Figure = Figure(sb, layout={
        "font": {"family": "Cabin", "size": 13},
        "margin": {"l": 0, "r": 0, "b": 0, "t": 0}
    })
    fig.show(renderer="browser")