
MDIC Disaggregation


This page is synchronized from trase/data/brazil/trade/mdic/disaggregated/MDIC_Disaggregation.ipynb. Last modified on 2025-12-13 00:30 CET by Trase Admin. Please view or edit the original file there; changes are reflected here after the nightly build (midnight CET) or after manually triggering the sync GitHub Action.

import polars as pl
from tqdm import tqdm
from trase.tools.sps import get_pandas_df_once
import s3fs
from trase.tools import sps
from trase.tools.sei_pcs.plotting import sankey
from trase.models.brazil.customs_2019.constants import BUCKETS
from more_itertools import one


# ------------------------------------------------------------ #
# Load disaggregated MDIC data
# ------------------------------------------------------------ #

# 0. Load data from S3: a mixture of CSV and Parquet files
old_dfs = [
    get_pandas_df_once(
        f"brazil/trade/mdic/disaggregated/brazil_mdic_disaggregated_{year}_beef.csv",
        dtype=str,
        na_filter=False,
    ).assign(year=year)
    for year in tqdm([2019, 2020], desc="Reading pre-2021 model results")
]
for df in old_dfs:
    df["success"] = df["success"].map({"True": True, "False": False})
    df["vol"] = df["vol"].astype(float)

new_dfs = [
    pl.read_parquet(
        f"s3://trase-storage/brazil/trade/mdic/disaggregated/brazil_mdic_disaggregated_{year}_beef.parquet"
    ).with_columns(year=year)
    for year in tqdm([2021, 2022, 2023], desc="Reading post-2021 model results")
]


# 1. Assert all DataFrames in both lists have the same columns
reference_columns = old_dfs[0].columns.tolist()
assert all(df.columns.tolist() == reference_columns for df in old_dfs), "Mismatch in old_dfs columns"
assert all(df.columns == reference_columns for df in new_dfs), "Mismatch in new_dfs columns"


# 2. Convert each old Pandas dataframe to Polars
converted_old_dfs = [
    pl.from_pandas(df).with_columns(pl.col("year").cast(pl.Int32))
    for df in old_dfs
]

# 3. Concatenate everything into a single Polars DataFrame
df_solved = pl.concat(converted_old_dfs + new_dfs)
df_solved = df_solved.with_columns(pl.col("month").cast(pl.Int32))

# ------------------------------------------------------------ #
# Load original MDIC data
# ------------------------------------------------------------ #
years = [2019, 2020, 2021, 2022, 2023, 2024]

# Read municipality data
df_municipality_list = []
for year in tqdm(years, desc="Reading MDIC (Municipality)"):
    if year < 2020:
        df = pl.read_csv(
            f"s3://trase-storage/brazil/trade/mdic/municipality/brazil_mdic_municipality_{year}.csv",
            separator=";",
            infer_schema=False
        ).with_columns([
            pl.col("year").cast(pl.Int64),
            pl.col("month").cast(pl.Int64),
            pl.col("vol").cast(pl.Int64),
            pl.col("fob").cast(pl.Int64),
        ])
    else:
        df = pl.read_parquet(
            f"s3://trase-storage/brazil/trade/mdic/municipality/brazil_mdic_municipality_{year}.parquet"
        )
    df_municipality_list.append(df)
df_municipality = pl.concat(df_municipality_list)

# Read port data
df_port_list = []
for year in tqdm(years, desc="Reading MDIC (Port)"):
    if year < 2020:
        df = pl.read_csv(
            f"s3://trase-storage/brazil/trade/mdic/port/brazil_mdic_port_{year}.csv",
            separator=";",
            infer_schema=False
        ).with_columns([
            pl.col("year").cast(pl.Int64),
            pl.col("month").cast(pl.Int64),
            pl.col("vol").cast(pl.Int64),
            pl.col("fob").cast(pl.Int64),
        ])
    else:
        df = pl.read_parquet(
            f"s3://trase-storage/brazil/trade/mdic/port/brazil_mdic_port_{year}.parquet"
        )
    df_port_list.append(df)
df_port = pl.concat(df_port_list)

# filter to HS4 codes that we solved for
hs4_per_year = df_solved.select(["year", "hs4"]).unique()
df_port = df_port.join(hs4_per_year, on=["year", "hs4"], coalesce=True)
df_municipality = df_municipality.join(hs4_per_year, on=["year", "hs4"], coalesce=True)

# ------------------------------------------------------------ #
# Load bills of lading
# ------------------------------------------------------------ #
def read_csv_bol(year):
    # 2023 is read separately further below, as it is a Parquet file in a different location
    return get_pandas_df_once(
        f"brazil/trade/bol/{year}/BRAZIL_BOL_{year}.csv",
        dtype=str,
        na_filter=False,
    )
with tqdm(desc="Reading bills of lading", total=5) as progress:
    df_bol_2019 = read_csv_bol(2019)
    progress.update()

    df_bol_2020 = read_csv_bol(2020)
    progress.update()

    df_bol_2021 = read_csv_bol(2021)
    progress.update()

    df_bol_2022 = read_csv_bol(2022)
    progress.update()

    df_bol_2023 = pl.read_parquet(
        "s3://trase-storage/brazil/trade/bol/2023/gold/brazil_bol_2023_gold.parquet"
    )
    progress.update()


def process_csv_data(df):
    """
    Takes the raw DataFrame for a given year and processes it:
    - Renames columns for consistency.
    - Converts data types for relevant columns.
    - Consolidates data by summing the 'vol' column.
    """
    df = (
        # Select relevant columns
        df[
            [
                "year",
                "hs4",
                "hs6",
                "country_of_destination.name",
                "vol",
                "port_of_export.name",
                "exporter.municipality.trase_id",
            ]
        ]
        # rename for consistency
        .rename(columns={"port_of_export.name": "port.name"})
        # cast strings (from CSV file) to specific types
        .astype({"year": int, "vol": float})
    )

    # sum over categorical columns
    df = sps.consolidate(df, ["vol"])

    # convert to Polars
    return pl.from_pandas(df).with_columns(pl.col("year").cast(pl.Int32))


def process_parquet_data(df):
    return (
        df
        # Select relevant columns
        .select(
            [
                "hs4",
                "hs6",
                "country_of_destination_name",
                "net_weight_kg",
                "port_of_export_name",
                "exporter_municipality_trase_id",
            ]
        )
        # rename for consistency
        .rename(
            {
                "country_of_destination_name": "country_of_destination.name",
                "port_of_export_name": "port.name",
                "net_weight_kg": "vol",
                "exporter_municipality_trase_id": "exporter.municipality.trase_id"
            }
        )
        # sum over categorical columns
        .group_by(["hs4", "hs6", "country_of_destination.name", "exporter.municipality.trase_id", "port.name"]).agg(
            pl.col("vol").sum()
        )
    )


df_bol_2021 = process_csv_data(df_bol_2021)
df_bol_2022 = process_csv_data(df_bol_2022)
df_bol_2023 = process_parquet_data(df_bol_2023).with_columns(year=2023)

# concatenate all years into a single dataframe
# helper: order columns alphabetically so the frames line up for concatenation
s = lambda df: df.select(sorted(df.columns))
df_bol = pl.concat([s(df_bol_2021), s(df_bol_2022), s(df_bol_2023)])

# filter to HS4 codes that we solved for
df_bol = df_bol.join(hs4_per_year, on=["year", "hs4"], coalesce=True)

Validate the input data was not altered

The disaggregated dataframe should aggregate to MDIC (Port) and MDIC (Municipality) exactly.

TODO: this should be a DBT quality check
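
Before the full SMAPE-based machinery below, the invariant can be sketched directly: at the (year, month, hs4, destination) grain that is asserted later on this page, the solved volumes should sum to the same totals as the original MDIC data. A minimal, hypothetical version of that check (the year > 2019 cut-off mirrors calculate_volume_errors below; the ±2 kg tolerance is an assumption to absorb float noise):

# Minimal sketch of the aggregation invariant (hypothetical; the functions below
# do the same comparison with a SMAPE tolerance and a nicer report).
key = ["year", "month", "hs4", "country_of_destination.name"]
solved_totals = (
    df_solved.filter(pl.col("year") > 2019)
    .group_by(key)
    .agg(pl.col("vol").sum().alias("vol_solved"))
)
port_totals = (
    df_port.filter(pl.col("year") > 2019)
    .group_by(key)
    .agg(pl.col("vol").sum().alias("vol_port"))
)
check = solved_totals.join(port_totals, on=key, how="full", coalesce=True)
assert check.filter(
    (pl.col("vol_solved") - pl.col("vol_port")).abs() > 2 + 1e-6 * pl.col("vol_port").abs()
).is_empty()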

def calculate_volume_errors(df_reference, group_by_columns):
    """
    Calculates the significant volume errors (SMAPE) between the reference and solved dataframes.
    This function computes the aggregated volumes for both dataframes, calculates the error (SMAPE),
    and returns the rows with significant errors.

    Parameters:
    df_reference (pl.DataFrame): The reference dataframe containing the original volume data.
    group_by_columns (list): List of columns to group by for volume aggregation (e.g., categories like year, month).

    Returns:
    pl.DataFrame: DataFrame containing rows with significant volume errors.
    """
    # Step 1: Aggregate the 'vol' column in the reference dataframe, grouped by the specified columns
    reference_volume = df_reference.group_by(group_by_columns).agg(pl.col("vol").sum().alias("vol_reference"))

    # Step 2: Aggregate the 'vol' column in the solved dataframe (filter for years after 2019)
    solved_volume = df_solved.group_by(group_by_columns).agg(pl.col("vol").sum().alias("vol_solved")).filter(pl.col("year") > 2019)

    # Step 3: Join the aggregated dataframes to compare volumes side by side
    comparison_df = solved_volume.join(
        reference_volume,
        on=group_by_columns,
        coalesce=True,  # Fill missing values
        how="full",     # Full outer join to include all rows
        validate="1:1", # Ensure a valid 1:1 match
        suffix="_new"   # Add suffix to distinguish the new column
    )

    # Step 4: Handle the special case where both volumes are close to zero (<= 2) to avoid errors
    zero_volume_condition = (pl.col("vol_solved").abs() <= 2) & (pl.col("vol_reference").abs() <= 2)

    # Step 5: Compute the SMAPE (Symmetric Mean Absolute Percentage Error) between the volumes
    comparison_df = comparison_df.with_columns(
        error=(
            pl.when(zero_volume_condition)  # Set error to 0 if both volumes are almost zero
                .then(0)
                .otherwise(
                    # Compute the SMAPE formula for non-zero volumes
                    (pl.col("vol_solved") - pl.col("vol_reference")) / (pl.col("vol_solved").abs() + pl.col("vol_reference").abs())
                )
        )
    )

    #print(f"Comparison df: {comparison_df}")
    # Step 6: Filter rows where the absolute error exceeds 0.1 (indicating significant error)
    significant_error_df = comparison_df.filter(pl.col("error").abs() > 0.1)

    # Step 7: Rename columns for clarity
    return significant_error_df.rename({"vol_reference": "vol_original", "vol_solved": "vol_current"})


def print_comparison_report(df_reference, exclude_columns=[]):
    """
    Prints a report comparing volumes between the reference dataframe and a solved (new) dataframe.
    The report highlights missing columns, calculates the error (SMAPE) between solved and reference volumes,
    and displays rows with significant errors.

    Parameters:
    df_reference (pl.DataFrame): The reference dataframe containing the original volume data.
    exclude_columns (list): List of columns to exclude from the comparison (e.g., non-relevant columns).
    """
    # Step 1: Identify common columns and missing columns between the solved and reference dataframes
    common_columns = set(df_solved.columns) & set(df_reference.columns)
    missing_columns = set(df_reference.columns) - set(df_solved.columns)

    # Step 2: Print the missing columns
    print(f"The following columns are missing in the solved dataframe:\n\n - " + "\n - ".join(sorted(missing_columns)))

    # Step 3: Identify the categorical columns to group by, excluding 'vol' and any columns specified in exclude_columns
    group_by_columns = common_columns - {"vol", *exclude_columns}

    # Step 4: Calculate the significant volume errors
    significant_errors_df = calculate_volume_errors(df_reference, group_by_columns)

    # Step 5: If there are rows with significant errors, print them
    if not significant_errors_df.is_empty():
        print("\nHere is a dataframe of rows with significant errors (SMAPE > 0.1):\n")
        display(significant_errors_df)

MDIC (Port)

A perfect match, but only when port.name and country_of_destination.name are excluded from the comparison.

TODO: should we investigate this?

print_comparison_report(df_port, exclude_columns=["port.name", "country_of_destination.name"])
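
One way to start on that TODO (a hypothetical snippet, not part of the original notebook) is to put port.name back into the grouping and rank the significant errors, so the mismatching ports stand out:

# Hypothetical drill-down: group by port.name as well and rank the significant
# SMAPE errors to see which ports drive the disagreement.
port_level_errors = calculate_volume_errors(
    df_port, ["year", "month", "hs4", "country_of_destination.name", "port.name"]
)
display(port_level_errors.sort(pl.col("error").abs(), descending=True).head(20))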

MDIC (Municipality)

An essentially perfect match; only a few rows show insignificant differences:

print_comparison_report(df_municipality)
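
To eyeball those small differences rather than just note they are insignificant, one could redo the same aggregation and rank by absolute difference (a hypothetical snippet; the grain matches what print_comparison_report uses for this comparison):

# Hypothetical: surface the largest (still insignificant) per-group differences
# between the solved data and MDIC (Municipality).
grain = sorted((set(df_solved.columns) & set(df_municipality.columns)) - {"vol"})
muni_totals = df_municipality.group_by(grain).agg(pl.col("vol").sum().alias("vol_original"))
solved_muni = (
    df_solved.filter(pl.col("year") > 2019)
    .group_by(grain)
    .agg(pl.col("vol").sum().alias("vol_current"))
)
display(
    solved_muni.join(muni_totals, on=grain, how="full", coalesce=True)
    .with_columns(abs_diff=(pl.col("vol_current") - pl.col("vol_original")).abs())
    .sort("abs_diff", descending=True, nulls_last=True)
    .head(10)
)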

Row counts

# check that all dataframes are equal on the common columns
common = set(df_port.columns) & set(df_municipality.columns) & set(df_solved.columns) - {"vol"}
expected = {'country_of_destination.name', 'hs4', 'month', 'year'}
assert common == expected
assert calculate_volume_errors(df_port, common).is_empty()
assert calculate_volume_errors(df_municipality, common).is_empty()
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt

# Step 1: Count the number of rows in df_port for each group of common columns (e.g., year, month, hs4)
df_port_row_counts = df_port.group_by(common).agg(pl.len().alias("port_count"))

# Step 2: Count the number of rows in df_municipality for each group of common columns (e.g., year, month, hs4)
df_municipality_row_counts = df_municipality.group_by(common).agg(pl.len().alias("municipality_count"))

# Step 3: Perform an outer join between df_port and df_municipality to combine row counts for each group
df_joined_counts = df_port_row_counts.join(
    df_municipality_row_counts,
    on=common,
    how="full",  # Perform an outer join to retain all data
)

# Step 4: Calculate the L0 norm (number of non-zero entries) per year in the 'vol' column from df
non_zero_solution_counts = df_solved.filter(pl.col("vol").abs() > 1).group_by("year").agg(pl.len().alias("solution_l0"))

# Step 5: Multiply the counts from both dataframes (port and municipality counts) for each group
df_joined_counts = df_joined_counts.with_columns(
    (pl.col("port_count") * pl.col("municipality_count")).alias("maximum_l0")
)

# Step 6: Group by 'year' and sum the 'maximum_l0' to get the total product per year
maximum_l0_per_year = df_joined_counts.group_by("year").agg(pl.col("maximum_l0").sum())

# Step 7: Join the two results - non-zero counts (solution_l0) and total product counts by year
combined_data = non_zero_solution_counts.join(maximum_l0_per_year, on="year", how="full", coalesce=True)

# Step 8: Calculate the ratio of solution_l0 (non-zero counts) to the total possible product per year
combined_data = combined_data.with_columns(
    ratio=pl.col("solution_l0") / pl.col("maximum_l0")
)

# Convert to pandas for seaborn compatibility
combined_data_pd = combined_data.to_pandas()

# Plot the ratio using seaborn
plt.figure(figsize=(8, 3))
sns.barplot(data=combined_data_pd, x="year", y="ratio", color="skyblue")
plt.title("Ratio of Non-Zero to Total Product by Year")
plt.xlabel("Year")
plt.ylabel("Ratio of Non-Zero to Total Product")
plt.tight_layout()
plt.show()

# Plot the row counts using seaborn
plt.figure(figsize=(8, 3))
sns.barplot(data=combined_data_pd, x="year", y="solution_l0", color="salmon")
plt.title("Non-Zero Solution Counts")
plt.xlabel("Year")
plt.ylabel("Non-Zero Solution Counts")
plt.tight_layout()
plt.show()

[Figure: Ratio of Non-Zero to Total Product by Year]

[Figure: Non-Zero Solution Counts]

Solved volumes over time
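
The cell below uses df_solved_per_year, which is not constructed anywhere on this page. A minimal sketch of how it could be derived from df_solved, assuming vol is in kilograms (so dividing by 1,000 gives tonnes):

# Hypothetical reconstruction of df_solved_per_year (not in the original notebook):
# total solved volume per year and success flag, converted from kg to tonnes.
df_solved_per_year = df_solved.group_by(["year", "success"]).agg(
    (pl.col("vol").sum() / 1000).alias("mass_tonnes")
)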

# Convert to pandas and pivot for stacking
df_stacked = df_solved_per_year.to_pandas().pivot(index="year", columns="success", values="mass_tonnes").fillna(0)

# Plot stacked bar chart
ax = df_stacked.plot(
    kind="bar",
    stacked=True,
    color={True: "skyblue", False: "salmon"},
    figsize=(8, 5)
)
import matplotlib.ticker as mticker
plt.title("Total Volume by Year and Solved Status")
plt.xlabel("Year")
plt.ylabel("Total Volume (tonnes)")
ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{int(x):,}"))

# Add labels on each bar segment with thousands separator
for container in ax.containers:
    ax.bar_label(container, fmt="%.0f", labels=[f"{int(v):,}" if v > 0 else "" for v in container.datavalues], label_type="center")

plt.tight_layout()
plt.show()

[Figure: Total Volume by Year and Solved Status]

Interactive manual validation

from IPython.display import Markdown
import seaborn as sns
import matplotlib.pyplot as plt

years_to_look_at = [2023]

# the columns that MDIC have in common
common_mdic = set(df_port.columns) & set(df_municipality.columns) & set(df_solved.columns) - {"vol", "fob"}

# pick one random "bucket" (country + hs4), because each bucket is an independent solve in the customs model
assert BUCKETS == ['country_of_destination.name', 'hs4']
buckets_and_year = [*BUCKETS, "year"]
country, hs4, year = one(
    # filter the solution to the year(s) we want to investigate
    df_solved.filter(pl.col("year").is_in(years_to_look_at))
    # get a list of unique buckets
    .select(buckets_and_year).unique()
    # pick a random one and convert to a tuple
    .sample(1).rows()
)


args = [
    pl.col('country_of_destination.name') == country,
    pl.col('hs4') == hs4,
    pl.col("year") == year,
]
df_port_sample = df_port.filter(*args).drop(buckets_and_year)
df_municipality_sample = df_municipality.filter(*args).drop(buckets_and_year)
df_bol_sample = df_bol.filter(*args).drop(buckets_and_year)
df_solved_sample = df_solved.filter(*args).drop(buckets_and_year)

# print to the user
display(Markdown(f"### You have chosen {year=}, {country=}, {hs4=}"))
print("Here is what your data looks like")

display(Markdown("#### Solution:"))
success, message = one(df_solved_sample.select(["success", "message"]).unique().rows())
display(Markdown(f" - **Success**: {success}\n - **Message**: {message}"))
display(df_solved_sample)

display(Markdown("#### Bills of lading:"))
display(df_bol_sample)

display(Markdown("#### MDIC (Port)"))
display(df_port_sample)

display(Markdown("#### MDIC (Municipality):"))
display(df_municipality_sample)


# ----------------------------
# Comparison to bills of lading
# ----------------------------

display(Markdown("## Comparison to bills of lading"))
bol_columns = set(df_bol_sample.columns) - {"vol"}
data = pl.concat(
    [
        s(df_bol_sample.with_columns(pl.lit("Bills of lading").alias("source"))),
        s(df_solved_sample.select([*bol_columns, "vol"]).with_columns(pl.lit("Solution").alias("source"))),
    ]
)
data = data.group_by([*bol_columns, "source"]).agg(pl.col("vol").sum())
data = data.with_columns(
    y=(pl.col("port.name") + " - " + pl.col("exporter.municipality.trase_id") + " - " + pl.col("hs6"))
)

# Convert to pandas for seaborn
data_pd = data.to_pandas()

# Seaborn scatter plot
plt.figure(figsize=(10, 12))
sns.scatterplot(
    data=data_pd,
    x="vol",
    y="y",
    hue="source",
    style="source",
    markers={"Solution": "X", "Bills of lading": "o"},
    s=100
)
plt.xlabel("Volume")
plt.ylabel("Port Name - Municipality - HS6")
plt.title("Comparison to bills of lading")
plt.legend(title="Source")
plt.tight_layout()
plt.show()

You have chosen year=2023, country='KUWAIT', hs4='0202'

Here is what your data looks like

Solution:

  • Success: True
  • Message: Solved with {'solver': 'MOSEK'}
shape: (622, 10)

| month | via | port.name | state.trase_id | hs6 | hs8 | exporter.municipality.trase_id | success | message | vol |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| i32 | str | str | str | str | str | str | bool | str | f64 |
| 1 | "01" | "PARANAGUA" | "BR-11" | "020230" | "02023000" | "BR-1100189" | true | "Solved with {'solver': 'MOSEK'… | 27871.0 |
| 1 | "01" | "PARANAGUA" | "BR-11" | "020230" | "02023000" | "BR-3503208" | true | "Solved with {'solver': 'MOSEK'… | 0.0 |
| 1 | "01" | "PARANAGUA" | "BR-11" | "020230" | "02023000" | "BR-3526704" | true | "Solved with {'solver': 'MOSEK'… | 2.5459e-10 |
| 1 | "01" | "PARANAGUA" | "BR-11" | "020230" | "02023000" | "BR-3539202" | true | "Solved with {'solver': 'MOSEK'… | 2.5396e-10 |
| 1 | "01" | "PARANAGUA" | "BR-11" | "020230" | "02023000" | "BR-5006200" | true | "Solved with {'solver': 'MOSEK'… | 0.0 |
| … | … | … | … | … | … | … | … | … | … |
| 9 | "01" | "SANTOS" | "BR-51" | "020230" | "02023000" | "BR-3505500" | true | "Solved with {'solver': 'MOSEK'… | 0.0 |
| 9 | "01" | "SANTOS" | "BR-51" | "020230" | "02023000" | "BR-3526704" | true | "Solved with {'solver': 'MOSEK'… | 0.0 |
| 9 | "01" | "SANTOS" | "BR-51" | "020230" | "02023000" | "BR-5105622" | true | "Solved with {'solver': 'MOSEK'… | 79349.664493 |
| 9 | "01" | "SANTOS" | "BR-51" | "020230" | "02023000" | "BR-5106307" | true | "Solved with {'solver': 'MOSEK'… | 2631.334351 |
| 9 | "01" | "SANTOS" | "BR-51" | "020230" | "02023000" | "BR-5106752" | true | "Solved with {'solver': 'MOSEK'… | 0.000785 |

Bills of lading:

shape: (47, 4)

| exporter.municipality.trase_id | hs6 | port.name | vol |
| --- | --- | --- | --- |
| str | str | str | f64 |
| "BR-3548500" | "020230" | "SANTOS" | 136500.0 |
| "BR-5103353" | "020230" | "PARANAGUA" | 58205.0 |
| "BR-5105622" | "020230" | "PARANAGUA" | 43251.0 |
| "BR-5007505" | "020230" | "SANTOS" | 56697.0 |
| "BR-5002704" | "020230" | "PARANAGUA" | 90835.0 |
| … | … | … | … |
| "BR-3502101" | "020230" | "SANTOS" | 29151.0 |
| "BR-5220454" | "020230" | "SANTOS" | 28903.0 |
| "BR-5107602" | "020230" | "ITAJAI" | 28177.0 |
| "BR-5103502" | "020230" | "PARANAGUA" | 142689.0 |
| "BR-1100189" | "020230" | "PARANAGUA" | 115606.0 |

MDIC (Port)

shape: (72, 13)

| month | hs8 | country_of_destination.pais | state.uf_letter | via | port.urf | vol | fob | hs6 | country_of_destination.trase_id | state.trase_id | port.name | port.trase_id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| i64 | str | str | str | str | str | i64 | i64 | str | str | str | str | str |
| 4 | "02023000" | "198" | "TO" | "01" | "0817800" | 19691 | 62216 | "020230" | "KW" | "BR-17" | "SANTOS" | "WPI-12970" |
| 1 | "02023000" | "198" | "SP" | "01" | "0817800" | 93879 | 617517 | "020230" | "KW" | "BR-35" | "SANTOS" | "WPI-12970" |
| 5 | "02023000" | "198" | "MS" | "01" | "0817800" | 26979 | 217441 | "020230" | "KW" | "BR-50" | "SANTOS" | "WPI-12970" |
| 3 | "02023000" | "198" | "MS" | "01" | "0917800" | 9896 | 41790 | "020230" | "KW" | "BR-50" | "PARANAGUA" | "WPI-12980" |
| 11 | "02023000" | "198" | "PA" | "01" | "0817800" | 27001 | 204941 | "020230" | "KW" | "BR-15" | "SANTOS" | "WPI-12970" |
| … | … | … | … | … | … | … | … | … | … | … | … | … |
| 4 | "02023000" | "198" | "GO" | "01" | "0917800" | 25711 | 121883 | "020230" | "KW" | "BR-52" | "PARANAGUA" | "WPI-12980" |
| 1 | "02023000" | "198" | "MT" | "01" | "0927800" | 27002 | 236226 | "020230" | "KW" | "BR-51" | "ITAJAI" | "WPI-13020" |
| 7 | "02023000" | "198" | "GO" | "01" | "0817800" | 39078 | 265784 | "020230" | "KW" | "BR-52" | "SANTOS" | "WPI-12970" |
| 10 | "02023000" | "198" | "SP" | "01" | "0817800" | 26972 | 83992 | "020230" | "KW" | "BR-35" | "SANTOS" | "WPI-12970" |
| 4 | "02023000" | "198" | "MT" | "01" | "0817800" | 34399 | 109159 | "020230" | "KW" | "BR-51" | "SANTOS" | "WPI-12970" |

MDIC (Municipality):

shape: (72, 13)

| month | hs8 | country_of_destination.pais | state.uf_letter | via | port.urf | vol | fob | hs6 | country_of_destination.trase_id | state.trase_id | port.name | port.trase_id |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| i64 | str | str | str | str | str | i64 | i64 | str | str | str | str | str |
| 4 | "02023000" | "198" | "TO" | "01" | "0817800" | 19691 | 62216 | "020230" | "KW" | "BR-17" | "SANTOS" | "WPI-12970" |
| 1 | "02023000" | "198" | "SP" | "01" | "0817800" | 93879 | 617517 | "020230" | "KW" | "BR-35" | "SANTOS" | "WPI-12970" |
| 5 | "02023000" | "198" | "MS" | "01" | "0817800" | 26979 | 217441 | "020230" | "KW" | "BR-50" | "SANTOS" | "WPI-12970" |
| 3 | "02023000" | "198" | "MS" | "01" | "0917800" | 9896 | 41790 | "020230" | "KW" | "BR-50" | "PARANAGUA" | "WPI-12980" |
| 11 | "02023000" | "198" | "PA" | "01" | "0817800" | 27001 | 204941 | "020230" | "KW" | "BR-15" | "SANTOS" | "WPI-12970" |
| … | … | … | … | … | … | … | … | … | … | … | … | … |
| 4 | "02023000" | "198" | "GO" | "01" | "0917800" | 25711 | 121883 | "020230" | "KW" | "BR-52" | "PARANAGUA" | "WPI-12980" |
| 1 | "02023000" | "198" | "MT" | "01" | "0927800" | 27002 | 236226 | "020230" | "KW" | "BR-51" | "ITAJAI" | "WPI-13020" |
| 7 | "02023000" | "198" | "GO" | "01" | "0817800" | 39078 | 265784 | "020230" | "KW" | "BR-52" | "SANTOS" | "WPI-12970" |
| 10 | "02023000" | "198" | "SP" | "01" | "0817800" | 26972 | 83992 | "020230" | "KW" | "BR-35" | "SANTOS" | "WPI-12970" |
| 4 | "02023000" | "198" | "MT" | "01" | "0817800" | 34399 | 109159 | "020230" | "KW" | "BR-51" | "SANTOS" | "WPI-12970" |

Comparison to bills of lading

[Figure: Comparison to bills of lading]