跳转至

Docstrings Python

snowflake_ingestion.functions

config_logger()

Configure the global logger.

Reads LOGGER_LEVEL from the module environment and configures the root logging settings (level, format, date format). Intended to be called once at application start.

No return value.

Source code in snowflake_ingestion/functions.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def config_logger() -> None:
    """Configure the global logger.

    Reads LOGGER_LEVEL from the module environment and configures the
    root logging settings (level, format, date format). Intended to be
    called once at application start.

    No return value.
    """
    logging.basicConfig(
        level=LOGGER_LEVEL,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

connect_with_role(user, password, account, role)

Create a Snowflake connection using the specified credentials and role.

Parameters:

Name Type Description Default
user str

Snowflake username.

required
password str

Snowflake password.

required
account str

Snowflake account identifier.

required
role str

Snowflake role to assume for the session.

required

Returns:

Type Description
SnowflakeConnection

snowflake.connector.connection.SnowflakeConnection: A Snowflake connection object with autocommit enabled.

Notes

This function opens a network connection to Snowflake. The caller is responsible for closing the connection when it is no longer needed.

Source code in snowflake_ingestion/functions.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def connect_with_role(user: str, password: str, account: str, role: str) -> snowflake.connector.SnowflakeConnection:
    """Create a Snowflake connection using the specified credentials and role.

    Args:
        user (str): Snowflake username.
        password (str): Snowflake password.
        account (str): Snowflake account identifier.
        role (str): Snowflake role to assume for the session.

    Returns:
        snowflake.connector.connection.SnowflakeConnection:
            A Snowflake connection object with autocommit enabled.

    Notes:
        This function opens a network connection to Snowflake. The caller
        is responsible for closing the connection when it is no longer needed.
    """
    return snowflake.connector.connect(
        user=user,
        password=password,
        account=account,
        role=role,
        autocommit=True,
    )

plural_suffix(count)

Return 's' if count is greater than or equal to 2, else return an empty string.

Parameters:

Name Type Description Default
count int

The number of items.

required

Returns:

Name Type Description
str str

's' if count >= 2, else ''.

Examples:

>>> plural_suffix(0)
''
>>> plural_suffix(1)
''
>>> plural_suffix(2)
's'
>>> plural_suffix(3)
's'
Source code in snowflake_ingestion/functions.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def plural_suffix(count: int) -> str:
    """Return 's' if count is greater than or equal to 2, else return an empty string.

    Args:
        count (int): The number of items.

    Returns:
        str: 's' if count >= 2, else ''.

    Examples:
        >>> plural_suffix(0)
        ''
        >>> plural_suffix(1)
        ''
        >>> plural_suffix(2)
        's'
        >>> plural_suffix(3)
        's'
    """
    return "s" if count >= 2 else ""

run_sql_file(cur, filepath)

Execute SQL statements from a file using VAR_PLACEHOLDER placeholders.

The function
  • Reads the SQL file.
  • Finds placeholders in the form VAR_PLACEHOLDER (captures VAR).
  • Replaces each found placeholder with the value of the corresponding global variable named VAR (stringified).
  • Masks variables containing "PASSWORD" in logger output.
  • Splits the file by semicolons and executes non-empty statements.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active cursor.

required
filepath Path or str

Path to the SQL file.

required
Notes

Placeholders that do not match a global variable are replaced with a string of the form <VAR_NOT_FOUND>.

Source code in snowflake_ingestion/functions.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def run_sql_file(cur: snowflake.connector.cursor.SnowflakeCursor, filepath: Path | str) -> None:
    """Execute SQL statements from a file using VAR_PLACEHOLDER placeholders.

    The function:
      - Reads the SQL file.
      - Finds placeholders in the form VAR_PLACEHOLDER (captures VAR).
      - Replaces each found placeholder with the value of the corresponding
        global variable named VAR (stringified).
      - Masks variables containing "PASSWORD" in logger output.
      - Splits the file by semicolons and executes non-empty statements.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active cursor.
        filepath (pathlib.Path or str): Path to the SQL file.

    Notes:
        Placeholders that do not match a global variable are replaced with
        a string of the form `<VAR_NOT_FOUND>`.
    """
    with open(filepath, "r") as f:
        sql = f.read()
        keys = re.findall(r'(?:SCHEMA_)?(\w+)_PLACEHOLDER', sql)
        variables = {k: globals().get(k, f"<{k}_NOT_FOUND>") for k in keys}
        logger.debug(f"🔎 Variables détectées dans {Path(filepath).name}: {sorted(set(keys))}")
        for key, value in variables.items():
            sql = sql.replace(f"{key}_PLACEHOLDER", str(value))
        masked_vars = {k: "*****" if "PASSWORD" in k.upper() else v for k, v in variables.items()}
        logger.debug(f"Variables utilisées : {dict(sorted(masked_vars.items()))}")

        for statement in sql.split(";"):
            statement = statement.strip()
            if statement:
                cur.execute(statement)

use_context(cur, WH_NAME, DW_NAME, RAW_SCHEMA)

Set the Snowflake session context: warehouse, database and schema.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
WH_NAME str

Warehouse name to use.

required
DW_NAME str

Database name to use.

required
RAW_SCHEMA str

Schema name to use.

required

Raises:

Type Description
SystemExit

Exits the process on any exception when setting the context.

Source code in snowflake_ingestion/functions.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def use_context(cur: snowflake.connector.cursor.SnowflakeCursor, WH_NAME: str, DW_NAME: str, RAW_SCHEMA: str) -> None:
    """Set the Snowflake session context: warehouse, database and schema.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
        WH_NAME (str): Warehouse name to use.
        DW_NAME (str): Database name to use.
        RAW_SCHEMA (str): Schema name to use.

    Raises:
        SystemExit: Exits the process on any exception when setting the context.
    """
    logger.debug(f"⚙️ Configuration du contexte: WH={WH_NAME}, DB={DW_NAME}, SCHEMA=SCHEMA_{RAW_SCHEMA}")
    try:
        cur.execute(f"USE WAREHOUSE {WH_NAME}")
        cur.execute(f"USE DATABASE {DW_NAME}")
        cur.execute(f"USE SCHEMA SCHEMA_{RAW_SCHEMA}")
    except Exception as e:
        logger.critical("❌ Erreur : Relancer l'étape Snowflake Infra Init")
        sys.exit(1)

snowflake_ingestion.init_infra_snowflake

create_roles_and_user(cur)

Create the DBT role and user in Snowflake.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/init_infra_snowflake.py
21
22
23
24
25
26
27
28
29
30
def create_roles_and_user(cur: SnowflakeCursor) -> None:
    """Create the DBT role and user in Snowflake.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("🔐 Creating roles and users...")
    sql_file = SQL_DIR / "create_roles_and_user.sql"
    functions.run_sql_file(cur, sql_file)
    logger.info("✅ Roles and users created")

grant_privileges(cur)

Grant required privileges to the TRANSFORMER role in Snowflake.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/init_infra_snowflake.py
32
33
34
35
36
37
38
39
40
41
def grant_privileges(cur: SnowflakeCursor) -> None:
    """Grant required privileges to the TRANSFORMER role in Snowflake.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("🔑 Granting privileges to the roles...")
    sql_file = SQL_DIR / "grant_privileges.sql"
    functions.run_sql_file(cur, sql_file)
    logger.info("✅ Privileges granted")

main()

Main initialization process for the Snowflake environment.

Establishes connections with appropriate roles (SYSADMIN, SECURITYADMIN, ACCOUNTADMIN) and executes setup steps in order.

Source code in snowflake_ingestion/init_infra_snowflake.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def main() -> None:
    """Main initialization process for the Snowflake environment.

    Establishes connections with appropriate roles (SYSADMIN, SECURITYADMIN, ACCOUNTADMIN)
    and executes setup steps in order.
    """
    try:
        conn = functions.connect_with_role(functions.USER, functions.PASSWORD, functions.ACCOUNT, "SYSADMIN")
        with conn.cursor() as cur:
            setup_data_warehouse(cur)
        conn.close()

        conn = functions.connect_with_role(functions.USER, functions.PASSWORD, functions.ACCOUNT, "SECURITYADMIN")
        with conn.cursor() as cur:
            create_roles_and_user(cur)
            grant_privileges(cur)
        conn.close()

        conn = functions.connect_with_role(functions.USER, functions.PASSWORD, functions.ACCOUNT, "ACCOUNTADMIN")
        with conn.cursor() as cur:
            set_data_retention(cur)
        conn.close()

        logger.info("🎯 Complete initialization finished successfully!")
    except Exception as e:
        logger.error(e)

set_data_retention(cur)

Set the data retention period for the Snowflake account.

Checks if the account is Enterprise, then applies the retention time. Logs the result in days, with pluralization handled automatically.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/init_infra_snowflake.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def set_data_retention(cur: SnowflakeCursor) -> None:
    """Set the data retention period for the Snowflake account.

    Checks if the account is Enterprise, then applies the retention time.
    Logs the result in days, with pluralization handled automatically.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("🏗️  Setting up data retention")
    sql_file = SQL_DIR / "set_data_retention.sql"
    functions.run_sql_file(cur, sql_file)
    s = functions.plural_suffix(int(functions.RETENTION_TIME))
    logger.info(f"✅ Data retention set to {functions.RETENTION_TIME} day{s}")

setup_data_warehouse(cur)

Create the data warehouse, database, and schemas in Snowflake.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/init_infra_snowflake.py
10
11
12
13
14
15
16
17
18
19
def setup_data_warehouse(cur: SnowflakeCursor) -> None:
    """Create the data warehouse, database, and schemas in Snowflake.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("🏗️  Creating warehouse, database and schemas...")
    sql_file = SQL_DIR / "setup_data_warehouse.sql"
    functions.run_sql_file(cur, sql_file)
    logger.info("✅ Warehouse and schemas created")

Scrape the NYC Taxi data page for Parquet file URLs. Sends an HTTP request to the NYC Taxi,parses the page HTML, and extracts links to Parquet files for the relevant years.

Returns:

Type Description
List[str]

list[str]: List of Parquet file URLs.

Source code in snowflake_ingestion/scrape_links.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def get_parquet_links() -> List[str]:
    """Scrape the NYC Taxi data page for Parquet file URLs.
    Sends an HTTP request to the NYC Taxi,parses the page HTML,
    and extracts links to Parquet files for the relevant years.

    Returns:
        list[str]: List of Parquet file URLs.
    """
    logger.info("🌐 Starting NYC Taxi data scraping")
    response = requests.get(scraping_url)
    tree = html.fromstring(response.content)
    xpath_query = get_xpath()
    filtered_links = tree.xpath(xpath_query)
    return [
        link.get("href")
        for link in filtered_links
        if link.get("href") and link.get("href").endswith(".parquet")
    ]

get_scraping_year()

Determine the scraping year to use based on environment settings. Uses SCRAPING_YEAR if defined and valid, otherwise selects the previous year when current month ≤ 3, or the current year otherwise.

Returns:

Name Type Description
int int

The year to scrape.

Doctests: from functions import SCRAPING_YEAR

get_scraping_year() == (int(SCRAPING_YEAR) if SCRAPING_YEAR != '' else current_year) - int(current_month <= 3) True

Source code in snowflake_ingestion/scrape_links.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def get_scraping_year() -> int:
    """Determine the scraping year to use based on environment settings.
    Uses SCRAPING_YEAR if defined and valid, otherwise selects the previous
    year when current month ≤ 3, or the current year otherwise.

    Returns:
        int: The year to scrape.

    Doctests:
    from functions import SCRAPING_YEAR
    >>> get_scraping_year() == (int(SCRAPING_YEAR) if SCRAPING_YEAR != '' else current_year) - int(current_month <= 3)
    True

    """
    default_year = current_year - 1 if current_month <= 3 else current_year
    if functions.SCRAPING_YEAR == "":
        return default_year
    else:
        try:
            int_year = int(functions.SCRAPING_YEAR)
        except ValueError:
            logger.error(f"\"SCRAPING_YEAR = {functions.SCRAPING_YEAR}\" is not a valid year!")
            logger.warning(f"Scraping year has been reset to {default_year}")
            return default_year

        if int_year < 2009 or int_year > current_year:
            logger.error(
                f"\"SCRAPING_YEAR = {functions.SCRAPING_YEAR}\" scraping year must be between 2009 and {current_year} inclusive!"
            )
            logger.warning(f"Scraping year has been reset to {default_year}")
            return default_year
        logger.info(f"Files will be scraped from year {default_year}")
        return int_year

get_xpath()

Build the XPath expression used to locate Parquet file links. The expression filters NYC Taxi data links by year, starting from the scraping year up to the current year.

Returns:

Name Type Description
str str

XPath query string.

Source code in snowflake_ingestion/scrape_links.py
50
51
52
53
54
55
56
57
58
59
60
61
62
def get_xpath() -> str:
    """Build the XPath expression used to locate Parquet file links.
    The expression filters NYC Taxi data links by year, starting from the
    scraping year up to the current year.

    Returns:
        str: XPath query string.
    """
    xpath_query = "//a[@title='Yellow Taxi Trip Records' and ("
    get_contains = lambda year: f"contains(@href, '{year}')"
    contains_list = [get_contains(year) for year in range(get_scraping_year(), current_year + 1)]
    xpath_query += " or ".join(contains_list) + ")]"
    return xpath_query

main()

Main scraping and metadata update workflow. Connects to Snowflake using the transformer role, initializes context, checks or creates the metadata table, scrapes new file URLs, and updates the metadata accordingly.

Source code in snowflake_ingestion/scrape_links.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def main() -> None:
    """Main scraping and metadata update workflow.
    Connects to Snowflake using the transformer role, initializes context,
    checks or creates the metadata table, scrapes new file URLs, and updates
    the metadata accordingly.
    """
    conn = functions.connect_with_role(
        functions.USER_DEV,
        functions.PASSWORD_DEV,
        functions.ACCOUNT,
        functions.ROLE_TRANSFORMER,
    )
    with conn.cursor() as cur:
        functions.use_context(cur, functions.WH_NAME, functions.DW_NAME, functions.RAW_SCHEMA)
        setup_meta_table(cur)

        links = get_parquet_links()
        s = functions.plural_suffix(len(links))
        logger.info(f"📎 {len(links)} link{s} found")
        new_file_detected: bool = False

        for url in links:
            filename = url.split("/")[-1]
            cur.execute(
                f"SELECT 1 FROM {functions.METADATA_TABLE} WHERE file_name = %s",
                (filename,),
            )
            if not cur.fetchone():
                logger.info(f"➕ New file detected : {filename}")
                new_file_detected = True

                parts = (
                    filename.replace("yellow_tripdata_", "")
                    .replace(".parquet", "")
                    .split("-")
                )
                year = int(parts[0]) if len(parts) > 0 else None
                month = int(parts[1]) if len(parts) > 1 else None

                logger.debug(f"🚀 Loading {functions.METADATA_TABLE}")
                cur.execute(
                    f"""
                    INSERT INTO {functions.METADATA_TABLE}
                    (file_url, file_name, year, month, rows_loaded, load_status)
                    VALUES (%s, %s, %s, %s, 0, 'SCRAPED')
                    """,
                    (url, filename, year, month),
                )
            else:
                logger.info(f"⏭️  {filename} already referenced")

            if not new_file_detected:
                logger.debug("🔍 Analyzing SCRAPED files")
                functions.run_sql_file(cur, SQL_DIR / "count_new_files.sql")
                if cur.fetchone()[0] > 0:
                    new_file_detected = True

    conn.close()

    if not new_file_detected:
        logger.warning("⚠️  No new files to load.")

    logger.info("✅ Scraping completed")

setup_meta_table(cur)

Ensure the metadata table exists in Snowflake. Executes the SQL script responsible for creating or verifying the metadata table.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/scrape_links.py
83
84
85
86
87
88
89
90
91
92
93
94
def setup_meta_table(cur: SnowflakeCursor) -> None:
    """Ensure the metadata table exists in Snowflake.
    Executes the SQL script responsible for creating or verifying the
    metadata table.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("📋 Verification/Creation of metadata table")
    sql_file = SQL_DIR / "setup_meta_table.sql"
    functions.run_sql_file(cur, sql_file)
    logger.info("✅ Metadata table ready")

snowflake_ingestion.upload_stage

download_and_upload_file(cur, file_url, filename)

Download a Parquet file from URL and upload it directly to Snowflake stage.

This function streams the file content directly to Snowflake without persisting it permanently on disk. It uses a temporary file that is automatically deleted after the upload completes, ensuring no residual files are left behind.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor used to execute the PUT command.

required
file_url str

HTTPS URL of the Parquet file to download.

required
filename str

Destination filename in the Snowflake stage.

required

Raises:

Type Description
HTTPError

If the HTTP request fails (non-200 status code).

Error

If the Snowflake PUT command fails.

Source code in snowflake_ingestion/upload_stage.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def download_and_upload_file(cur: SnowflakeCursor, file_url: str, filename: str) -> None:
    """Download a Parquet file from URL and upload it directly to Snowflake stage.

    This function streams the file content directly to Snowflake without persisting
    it permanently on disk. It uses a temporary file that is automatically deleted
    after the upload completes, ensuring no residual files are left behind.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor 
            used to execute the PUT command.
        file_url (str): HTTPS URL of the Parquet file to download.
        filename (str): Destination filename in the Snowflake stage.

    Raises:
        requests.HTTPError: If the HTTP request fails (non-200 status code).
        snowflake.connector.errors.Error: If the Snowflake PUT command fails.
    """
    logger.info(f"📥 Downloading {filename}...")
    response = requests.get(file_url)
    response.raise_for_status()
    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=True) as tmp_file:
        tmp_file.write(response.content)
        tmp_file.flush()
        logger.info("📤 Uploading to Snowflake...")
        cur.execute(f"PUT 'file://{tmp_file.name}' @~/{filename} AUTO_COMPRESS=FALSE")
    logger.info(f"✅ {filename} uploaded and temporary file cleaned")

main()

Main staging process for Parquet files.

Connects to Snowflake, retrieves metadata for scraped files, downloads each file, uploads it to the stage, and updates the metadata table with the appropriate load status.

Source code in snowflake_ingestion/upload_stage.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def main() -> None:
    """Main staging process for Parquet files.

    Connects to Snowflake, retrieves metadata for scraped files, downloads
    each file, uploads it to the stage, and updates the metadata table
    with the appropriate load status.
    """
    conn = functions.connect_with_role(
        functions.USER_DEV,
        functions.PASSWORD_DEV,
        functions.ACCOUNT,
        functions.ROLE_TRANSFORMER,
    )

    with conn.cursor() as cur:
        functions.use_context(cur, functions.WH_NAME, functions.DW_NAME, functions.RAW_SCHEMA)
        logger.debug("📥 Retrieving scraped file URLs and names")
        functions.run_sql_file(cur, SQL_DIR / "select_file_url_name_from_meta_scraped.sql")
        scraped_files = cur.fetchall()
        scraped_files_count: int = len(scraped_files)

        if scraped_files_count == 0:
            logger.warning("⚠️  No files to upload")
        else:
            logger.info(f"📦 {scraped_files_count} files to upload")

        for file_url, filename in scraped_files:
            try:
                download_and_upload_file(cur, file_url, filename)
                logger.info(f"✅ {filename} uploaded")
                cur.execute(
                    f"UPDATE {functions.METADATA_TABLE} SET load_status='STAGED' WHERE file_name=%s",
                    (filename,),
                )
                logger.debug(f"🚀 Loading {functions.METADATA_TABLE}")
            except Exception as e:
                logger.error(f"❌ Upload error {filename}: {e}")
                logger.debug(f"🚀 Loading {functions.METADATA_TABLE}")
                cur.execute(
                    f"UPDATE {functions.METADATA_TABLE} SET load_status='FAILED_STAGE' WHERE file_name=%s",
                    (filename,),
                )

    conn.close()

snowflake_ingestion.load_to_table

cleanup_stage_file(cur, filename)

Remove the processed file from the Snowflake stage. Args: cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor. filename (str): Name of the file to delete from the stage.

Source code in snowflake_ingestion/load_to_table.py
101
102
103
104
105
106
107
108
def cleanup_stage_file(cur: SnowflakeCursor, filename: str) -> None:
    """Remove the processed file from the Snowflake stage.
    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
        filename (str): Name of the file to delete from the stage.
    """
    cur.execute(f"REMOVE @~/{filename}")
    logger.info(f"✅ {filename} removed from stage")

copy_file_to_table_and_count(cur, filename, table_schema)

Load a Parquet file from stage into the RAW table and count inserted rows. Uses COPY INTO with transformation to generate TRIP_ID using sequence and maps Parquet columns using positional references.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
filename str

Name of the staged file to load.

required
table_schema list

Pre-detected schema from create_table function.

required

Returns:

Name Type Description
int int

Number of rows inserted into the RAW table.

Source code in snowflake_ingestion/load_to_table.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def copy_file_to_table_and_count(cur: SnowflakeCursor, filename: str, table_schema: List[Tuple[str, str]]) -> int:
    """Load a Parquet file from stage into the RAW table and count inserted rows.
    Uses COPY INTO with transformation to generate TRIP_ID using sequence and 
    maps Parquet columns using positional references.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
        filename (str): Name of the staged file to load.
        table_schema (list): Pre-detected schema from create_table function.

    Returns:
        int: Number of rows inserted into the RAW table.
    """
    logger.info(f"🚀 Loading {filename} into {functions.RAW_TABLE}...")
    column_names = [col[0].replace("airport_fee", "Airport_fee") for col in table_schema]
    select_columns = [f"$1:{col_name}" for col_name in column_names]
    copy_sql = f"""
        COPY INTO {functions.RAW_TABLE} (TRIP_ID, {', '.join(column_names)}, FILENAME)
        FROM (
            SELECT 
                {functions.ID_SEQUENCE}.NEXTVAL,
                {', '.join(select_columns)},
                '{filename}'
            FROM '@~/{filename}'
        )
        FILE_FORMAT=(FORMAT_NAME='{functions.DW_NAME}.SCHEMA_{functions.RAW_SCHEMA}.{functions.PARQUET_FORMAT}')
        FORCE = TRUE
    """
    cur.execute(copy_sql)
    result = cur.fetchone()
    if result and len(result) > 3:
        rows_loaded = result[3]
    else:
        rows_loaded = 0
    s = functions.plural_suffix(rows_loaded)
    logger.info(f"✅ {filename} loaded ({rows_loaded} row{s})")
    return rows_loaded

create_table(cur)

Create or verify the RAW table dynamically based on staged file schema. Executes SQL to detect the file schema in the Snowflake stage, creates the RAW table if it does not exist, and adds the filename column if needed.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required

Returns:

Name Type Description
list List[Tuple[str, str]]

The table schema detected from staged files

Source code in snowflake_ingestion/load_to_table.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def create_table(cur: SnowflakeCursor) -> List[Tuple[str, str]]:
    """Create or verify the RAW table dynamically based on staged file schema.
    Executes SQL to detect the file schema in the Snowflake stage,
    creates the RAW table if it does not exist, and adds the filename
    column if needed.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.

    Returns:
        list: The table schema detected from staged files
    """
    logger.info(f"📋 Dynamic verification/creation of table {functions.RAW_TABLE}")
    functions.run_sql_file(cur, SQL_DIR / "detect_file_schema_stage.sql")
    schema = cur.fetchall()
    seen = set()
    table_schema: List[Tuple[str, str]] = []
    for col_name, col_type in schema:
        if col_name.lower() not in seen:
            seen.add(col_name.lower())
            table_schema.append((col_name, col_type))
    if len(table_schema) == 0:
        logger.warning("⚠️  No data in STAGE")
        return table_schema
    functions.run_sql_file(cur, SQL_DIR / "create_sequence.sql")

    columns = [f"TRIP_ID NUMBER"] + [f"{col_name} {col_type}" for col_name, col_type in table_schema]
    if len(columns) != 0:
        create_sql = f"CREATE TABLE IF NOT EXISTS {functions.RAW_TABLE} ({', '.join(columns)})"
        cur.execute(create_sql)
        functions.run_sql_file(cur, SQL_DIR / "add_filename_to_raw_table.sql")
        logger.info(f"✅ Table {functions.RAW_TABLE} ready")
    else:
        logger.warning(f"⚠️  No data in STAGE")
    return table_schema

handle_loading_error(cur, filename, error)

Handle errors occurring during file loading into the RAW table. Logs the error and updates the metadata table to mark the file as failed during the load step.

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
filename str

Name of the file that failed to load.

required
error Exception

Exception raised during the loading process.

required
Source code in snowflake_ingestion/load_to_table.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def handle_loading_error(cur: SnowflakeCursor, filename: str, error: Exception) -> None:
    """Handle errors occurring during file loading into the RAW table.
    Logs the error and updates the metadata table to mark the file
    as failed during the load step.

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
        filename (str): Name of the file that failed to load.
        error (Exception): Exception raised during the loading process.
    """
    logger.error(f"❌ Loading error {filename}: {error}")
    logger.debug(f"🚀 Loading {functions.METADATA_TABLE}")
    cur.execute(
        f"UPDATE {functions.METADATA_TABLE} SET load_status='FAILED_LOAD' WHERE file_name=%s",
        (filename,),
    )

main()

Main process for loading staged Parquet files into the RAW table. Connects to Snowflake, ensures the RAW table exists, retrieves staged files, loads each into the RAW table, updates metadata, and cleans up stage files.

Source code in snowflake_ingestion/load_to_table.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def main() -> None:
    """Main process for loading staged Parquet files into the RAW table.
    Connects to Snowflake, ensures the RAW table exists, retrieves staged files,
    loads each into the RAW table, updates metadata, and cleans up stage files.
    """
    conn = functions.connect_with_role(functions.USER_DEV, functions.PASSWORD_DEV, functions.ACCOUNT, functions.ROLE_TRANSFORMER)
    with conn.cursor() as cur:
        functions.use_context(cur, functions.WH_NAME, functions.DW_NAME, functions.RAW_SCHEMA)
        table_schema = create_table(cur)

        logger.info("🔍 Analyzing files in STAGE")
        functions.run_sql_file(cur, SQL_DIR / "select_filename_from_meta_staged.sql")
        staged_files = cur.fetchall()

        for (filename,) in staged_files:
            try:
                rows_loaded = copy_file_to_table_and_count(cur, filename, table_schema)
                update_metadata(cur, filename, rows_loaded)
                cleanup_stage_file(cur, filename)
            except Exception as e:
                handle_loading_error(cur, filename, e)

    conn.close()

update_metadata(cur, filename, rows_loaded)

Update the metadata table after successful file loading. Args: cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor. filename (str): Name of the loaded file. rows_loaded (int): Number of rows successfully inserted.

Source code in snowflake_ingestion/load_to_table.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def update_metadata(cur: SnowflakeCursor, filename: str, rows_loaded: int) -> None:
    """Update the metadata table after successful file loading.
    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
        filename (str): Name of the loaded file.
        rows_loaded (int): Number of rows successfully inserted.
    """
    cur.execute(
        f"""
        UPDATE {functions.METADATA_TABLE} 
        SET rows_loaded = %s, load_status = 'SUCCESS' 
        WHERE file_name = %s
        """,
        (rows_loaded, filename),
    )
    logger.debug(f"🚀 Loading {functions.METADATA_TABLE}")

snowflake_ingestion.backup_policy

create_and_set_backup(cur)

Creates the backup policies and backup sets for the data warehouse.

Executes the SQL script to create the monthly backup policies and link them to the target objects (full database, raw table, final schema).

Parameters:

Name Type Description Default
cur SnowflakeCursor

Active Snowflake cursor.

required
Source code in snowflake_ingestion/backup_policy.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def create_and_set_backup(cur: SnowflakeCursor) -> None:
    """Creates the backup policies and backup sets for the data warehouse.

    Executes the SQL script to create the monthly backup policies and
    link them to the target objects (full database, raw table, final schema).

    Args:
        cur (snowflake.connector.cursor.SnowflakeCursor): Active Snowflake cursor.
    """
    logger.info("🔐 Creating backup policies and sets...")
    sql_file = SQL_DIR / "create_and_set_backup.sql"
    functions.run_sql_file(cur, sql_file)
    logger.info(f"✅ {functions.DW_NAME}_BACKUP retention : {functions.FULL_BACKUP_POLICY_DAYS}")
    logger.info(f"✅ {functions.RAW_TABLE}_BACKUP retention : {functions.RAW_TABLE_BACKUP_POLICY_DAYS}")
    logger.info(f"✅ {functions.FINAL_SCHEMA}_BACKUP retention : {functions.FINAL_SCHEMA_BACKUP_POLICY_DAYS}")

main()

Main initialization process for the Snowflake environment.

Establishes connections with appropriate roles (SYSADMIN, SECURITYADMIN, ACCOUNTADMIN) and executes setup steps in order.

Source code in snowflake_ingestion/backup_policy.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def main() -> None:
    """Main initialization process for the Snowflake environment.

    Establishes connections with appropriate roles (SYSADMIN, SECURITYADMIN, ACCOUNTADMIN)
    and executes setup steps in order.
    """
    try:
        conn = functions.connect_with_role(functions.USER, functions.PASSWORD, functions.ACCOUNT, "SYSADMIN")
        with conn.cursor() as cur:
            create_and_set_backup(cur)
        conn.close()


        logger.info("🎯 Complete initialization finished successfully!")
    except Exception as e:
        logger.error(e)

snowflake_ingestion.tests.test_functions

test_connect_with_role_autocommit_enabled()

Unit test verifying that autocommit is always enabled. Verifies that the autocommit=True parameter is systematically passed to the Snowflake connection, regardless of other parameters.

Source code in snowflake_ingestion/tests/test_functions.py
24
25
26
27
28
29
30
31
32
33
def test_connect_with_role_autocommit_enabled():
    """Unit test verifying that autocommit is always enabled.
    Verifies that the autocommit=True parameter is systematically passed
    to the Snowflake connection, regardless of other parameters.
    """
    mock_connection = Mock()
    with patch('snowflake_ingestion.functions.snowflake.connector.connect', return_value=mock_connection) as mock_connect:
        connect_with_role("different_user", "different_pass", "different_account", "different_role")
        call_kwargs = mock_connect.call_args.kwargs
        assert call_kwargs['autocommit'] == True

test_connect_with_role_parameters_passed_correctly()

Unit test verifying correct parameter forwarding. Verifies that all parameters (user, password, account, role) are correctly forwarded to snowflake.connector.connect without modification.

Source code in snowflake_ingestion/tests/test_functions.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def test_connect_with_role_parameters_passed_correctly():
    """Unit test verifying correct parameter forwarding.
    Verifies that all parameters (user, password, account, role) are
    correctly forwarded to snowflake.connector.connect without modification.
    """
    mock_connection = Mock()
    with patch('snowflake_ingestion.functions.snowflake.connector.connect', return_value=mock_connection) as mock_connect:
        test_params = {
            'user': 'test_user',
            'password': 'test_password', 
            'account': 'test_account',
            'role': 'test_role'
        }
        connect_with_role(**test_params)
        call_kwargs = mock_connect.call_args.kwargs
        for key, value in test_params.items():
            assert call_kwargs[key] == value
        assert call_kwargs['autocommit'] == True

test_connect_with_role_success()

Unit test for connect_with_role on success. Verifies that the function calls snowflake.connector.connect with the correct parameters, enables autocommit, and returns the connection object.

Source code in snowflake_ingestion/tests/test_functions.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
def test_connect_with_role_success():
    """Unit test for connect_with_role on success.
    Verifies that the function calls snowflake.connector.connect with the correct
    parameters, enables autocommit, and returns the connection object.
    """
    mock_connection = Mock()
    with patch('snowflake_ingestion.functions.snowflake.connector.connect', return_value=mock_connection) as mock_connect:
        result = connect_with_role("user", "pass", "account", "role")
        mock_connect.assert_called_once_with(
            user="user",
            password="pass", 
            account="account",
            role="role",
            autocommit=True
        )
        assert result == mock_connection

test_run_sql_file()

Unit test for run_sql_file with variable substitution. Verifies that SQL placeholders are correctly replaced by the values of the corresponding global variables.

Source code in snowflake_ingestion/tests/test_functions.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def test_run_sql_file():
    """Unit test for run_sql_file with variable substitution.
    Verifies that SQL placeholders are correctly replaced
    by the values of the corresponding global variables.
    """
    mock_cursor = Mock()
    sql_content = "SELECT * FROM RAW_TABLE_PLACEHOLDER WHERE user = USER_PLACEHOLDER;"
    with patch('builtins.open', mock_open(read_data=sql_content)):
        with patch('snowflake_ingestion.functions.RAW_TABLE', 'my_table'):
            with patch('snowflake_ingestion.functions.USER', 'test_user'):
                run_sql_file(mock_cursor, Path("test.sql"))
    mock_cursor.execute.assert_called_once_with("SELECT * FROM my_table WHERE user = test_user")

test_run_sql_file_multiple_statements()

Unit test for run_sql_file with multiple statements. Verifies that statements separated by semicolons are correctly split and executed individually.

Source code in snowflake_ingestion/tests/test_functions.py
116
117
118
119
120
121
122
123
124
125
def test_run_sql_file_multiple_statements():
    """Unit test for run_sql_file with multiple statements.
    Verifies that statements separated by semicolons are
    correctly split and executed individually.
    """
    mock_cursor = Mock()
    sql_content = "SELECT 1; SELECT 2; SELECT 3;"
    with patch('builtins.open', mock_open(read_data=sql_content)):
        run_sql_file(mock_cursor, Path("test.sql"))
    assert mock_cursor.execute.call_count == 3

test_run_sql_file_variable_not_found()

Unit test for run_sql_file with an unresolved variable. Verifies that placeholders with no matching global variable are replaced by the default value .

Source code in snowflake_ingestion/tests/test_functions.py
104
105
106
107
108
109
110
111
112
113
def test_run_sql_file_variable_not_found():
    """Unit test for run_sql_file with an unresolved variable.
    Verifies that placeholders with no matching global variable
    are replaced by the default value <VAR_NOT_FOUND>.
    """
    mock_cursor = Mock()
    sql_content = "SELECT * FROM UNKNOWN_PLACEHOLDER;"
    with patch('builtins.open', mock_open(read_data=sql_content)):
        run_sql_file(mock_cursor, Path("test.sql"))
    mock_cursor.execute.assert_called_once_with("SELECT * FROM <UNKNOWN_NOT_FOUND>")

test_use_context(mocker)

Unit test for the use_context function. Verifies that the function executes the 3 expected SQL commands to configure the Snowflake context (warehouse, database, schema) in the correct order.

Parameters:

Name Type Description Default
mocker Mock

pytest fixture for mocking

required
Source code in snowflake_ingestion/tests/test_functions.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def test_use_context(mocker: Mock):
    """Unit test for the use_context function.
    Verifies that the function executes the 3 expected SQL commands to configure
    the Snowflake context (warehouse, database, schema) in the correct order.

    Args:
        mocker: pytest fixture for mocking
    """
    mock_cursor = Mock()
    with patch.object(mock_cursor, 'execute') as mock_execute:
        use_context(mock_cursor, "WH", "DB", "TEST")
        assert mock_execute.call_count == 3
        calls = [call[0][0] for call in mock_execute.call_args_list]
        assert "USE WAREHOUSE WH" in calls
        assert "USE DATABASE DB" in calls  
        assert "USE SCHEMA SCHEMA_TEST" in calls

test_use_context_exception()

Unit test for error handling in use_context. Verifies that the function raises a SystemExit exception when an error occurs during SQL command execution.

Source code in snowflake_ingestion/tests/test_functions.py
76
77
78
79
80
81
82
83
84
85
def test_use_context_exception():
    """Unit test for error handling in use_context.
    Verifies that the function raises a SystemExit exception when an error
    occurs during SQL command execution.
    """
    mock_cursor = Mock()
    mock_cursor.execute.side_effect = Exception("DB error")

    with pytest.raises(SystemExit):
        use_context(mock_cursor, "WH", "DB", "SCHEMA")

snowflake_ingestion.tests.test_init_infra_snowflake

test_create_roles_and_user()

Unit test for the create_roles_and_user function. Tests the creation of Snowflake roles and users as part of infrastructure initialization. Verifies that the appropriate SQL script is executed and success/failure logs are properly recorded during the role and user creation process.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def test_create_roles_and_user():
    """
    Unit test for the create_roles_and_user function.
    Tests the creation of Snowflake roles and users as part of infrastructure initialization.
    Verifies that the appropriate SQL script is executed and success/failure logs are properly
    recorded during the role and user creation process.
    """
    mock_cursor = Mock()
    with patch('snowflake_ingestion.init_infra_snowflake.functions.run_sql_file') as mock_run_sql:
        with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
            infra.create_roles_and_user(mock_cursor)

            mock_run_sql.assert_called_once_with(mock_cursor, infra.SQL_DIR / "create_roles_and_user.sql")
            mock_logger.info.assert_any_call("🔐 Creating roles and users...")
            mock_logger.info.assert_any_call("✅ Roles and users created")

test_grant_privileges()

Unit test for the grant_privileges function. Tests the granting of privileges to the created roles in Snowflake. Verifies that the correct SQL script is executed and appropriate log messages are recorded during the privilege granting process.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def test_grant_privileges():
    """
    Unit test for the grant_privileges function.
    Tests the granting of privileges to the created roles in Snowflake.
    Verifies that the correct SQL script is executed and appropriate log messages are recorded
    during the privilege granting process.
    """
    mock_cursor = Mock()
    with patch('snowflake_ingestion.init_infra_snowflake.functions.run_sql_file') as mock_run_sql:
        with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
            infra.grant_privileges(mock_cursor)

            mock_run_sql.assert_called_once_with(mock_cursor, infra.SQL_DIR / "grant_privileges.sql")
            mock_logger.info.assert_any_call("🔑 Granting privileges to the roles...")
            mock_logger.info.assert_any_call("✅ Privileges granted")

test_main_exception()

Unit test for the main function when an exception occurs during initialization. Tests error handling by simulating a connection failure and verifying that the exception is properly caught and logged as an error.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
65
66
67
68
69
70
71
72
73
74
def test_main_exception():
    """
    Unit test for the main function when an exception occurs during initialization.
    Tests error handling by simulating a connection failure and verifying that the
    exception is properly caught and logged as an error.
    """
    with patch('snowflake_ingestion.init_infra_snowflake.functions.connect_with_role', side_effect=Exception("Connection failed")):
        with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
            infra.main()
            mock_logger.error.assert_called_once()

test_main_success()

Unit test for the main function when the infrastructure initialization completes successfully. Tests the complete initialization flow including warehouse setup, role creation, and privilege granting. Verifies that all three connections are made with the appropriate roles (ACCOUNTADMIN, SYSADMIN, SECURITYADMIN) and that all initialization functions are called in sequence with successful logging.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def test_main_success():
    """
    Unit test for the main function when the infrastructure initialization completes successfully.
    Tests the complete initialization flow including warehouse setup, role creation,
    and privilege granting. Verifies that all three connections are made with the
    appropriate roles (ACCOUNTADMIN, SYSADMIN, SECURITYADMIN) and that all
    initialization functions are called in sequence with successful logging.
    """
    mock_conns = [Mock(), Mock(), Mock()]
    for mock_conn in mock_conns:
        mock_conn.cursor.return_value.__enter__ = Mock(return_value=Mock())
        mock_conn.cursor.return_value.__exit__ = Mock(return_value=None)

    call_counter = 0
    def connect_side_effect(*args, **kwargs):
        nonlocal call_counter
        if call_counter < len(mock_conns):
            result = mock_conns[call_counter]
            call_counter += 1
            return result
        return Mock()

    with patch('snowflake_ingestion.init_infra_snowflake.functions.RETENTION_TIME', 90):
        with patch('snowflake_ingestion.init_infra_snowflake.functions.connect_with_role', 
                   side_effect=connect_side_effect) as mock_connect:
            with patch('snowflake_ingestion.init_infra_snowflake.setup_data_warehouse') as mock_setup:
                with patch('snowflake_ingestion.init_infra_snowflake.create_roles_and_user') as mock_create:
                    with patch('snowflake_ingestion.init_infra_snowflake.grant_privileges') as mock_grant:
                        with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
                            infra.main()

                            assert mock_connect.call_count == 3
                            mock_connect.assert_any_call(infra.functions.USER, infra.functions.PASSWORD, infra.functions.ACCOUNT, 'ACCOUNTADMIN')
                            mock_connect.assert_any_call(infra.functions.USER, infra.functions.PASSWORD, infra.functions.ACCOUNT, 'SYSADMIN')
                            mock_connect.assert_any_call(infra.functions.USER, infra.functions.PASSWORD, infra.functions.ACCOUNT, 'SECURITYADMIN')

                            mock_setup.assert_called_once()
                            mock_create.assert_called_once()
                            mock_grant.assert_called_once()
                            mock_logger.info.assert_called_with("🎯 Complete initialization finished successfully!")

test_set_data_retention()

Unit test for the set_data_retention function. Verifies that the function executes the correct SQL file with the specified retention time and logs appropriate messages when setting data retention policies.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def test_set_data_retention():
    """
    Unit test for the set_data_retention function.
    Verifies that the function executes the correct SQL file with the specified retention time
    and logs appropriate messages when setting data retention policies.
    """
    mock_cursor = Mock()

    with patch('snowflake_ingestion.init_infra_snowflake.functions.RETENTION_TIME', 90):
        with patch('snowflake_ingestion.init_infra_snowflake.functions.run_sql_file') as mock_run_sql:
            with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
                infra.set_data_retention(mock_cursor)

test_setup_data_warehouse()

Unit test for the setup_data_warehouse function. Tests the creation of warehouse, database, and schemas infrastructure in Snowflake. Verifies that the correct SQL script is executed and appropriate log messages are recorded during the warehouse setup process.

Source code in snowflake_ingestion/tests/test_init_infra_snowflake.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def test_setup_data_warehouse():
    """
    Unit test for the setup_data_warehouse function.
    Tests the creation of warehouse, database, and schemas infrastructure in Snowflake.
    Verifies that the correct SQL script is executed and appropriate log messages are recorded
    during the warehouse setup process.
    """
    mock_cursor = Mock()
    with patch('snowflake_ingestion.init_infra_snowflake.functions.run_sql_file') as mock_run_sql:
        with patch('snowflake_ingestion.init_infra_snowflake.logger') as mock_logger:
            infra.setup_data_warehouse(mock_cursor)

            mock_run_sql.assert_called_once_with(mock_cursor, infra.SQL_DIR / "setup_data_warehouse.sql")
            mock_logger.info.assert_any_call("🏗️  Creating warehouse, database and schemas...")
            mock_logger.info.assert_any_call("✅ Warehouse and schemas created")

Test successful extraction of parquet file links from the NYC TLC website HTML content. Verifies that only links with title 'Yellow Taxi Trip Records' are extracted and returned.

Source code in snowflake_ingestion/tests/test_scrape_links.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def test_get_parquet_links_success():
    """
    Test successful extraction of parquet file links from the NYC TLC website HTML content.
    Verifies that only links with title 'Yellow Taxi Trip Records' are extracted and returned.
    """
    mock_html_content = """
    <html>
        <a title="Yellow Taxi Trip Records" href="https://example.com/file1.parquet">Link1</a>
        <a title="Yellow Taxi Trip Records" href="https://example.com/file2.parquet">Link2</a>
        <a title="Other Title" href="https://example.com/file3.parquet">Link3</a>
    </html>
    """
    mock_response = Mock()
    mock_response.content = mock_html_content.encode()

    mock_tree = Mock()
    mock_link1 = Mock()
    mock_link1.get.return_value = "https://example.com/file1.parquet"
    mock_link2 = Mock()
    mock_link2.get.return_value = "https://example.com/file2.parquet"
    mock_tree.xpath.return_value = [mock_link1, mock_link2]

    with patch('snowflake_ingestion.scrape_links.requests.get') as mock_get:
        with patch('snowflake_ingestion.scrape_links.html.fromstring') as mock_fromstring:
            with patch('snowflake_ingestion.scrape_links.get_xpath', return_value="//a[@title='Yellow Taxi Trip Records']"):
                with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:
                    mock_get.return_value = mock_response
                    mock_fromstring.return_value = mock_tree

                    result = scrape.get_parquet_links()

                    mock_get.assert_called_once_with("https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page")
                    mock_logger.info.assert_called_with("🌐 Starting NYC Taxi data scraping")
                    assert result == ["https://example.com/file1.parquet", "https://example.com/file2.parquet"]

test_get_scraping_year_with_empty_env_early_month()

Test get_scraping_year behavior when SCRAPING_YEAR is empty and current month is early (January to March). The function should default to the previous year since current year's data may not be fully available yet.

Source code in snowflake_ingestion/tests/test_scrape_links.py
18
19
20
21
22
23
24
25
26
27
def test_get_scraping_year_with_empty_env_early_month():
    """
    Test get_scraping_year behavior when SCRAPING_YEAR is empty and current month is early (January to March).
    The function should default to the previous year since current year's data may not be fully available yet.
    """
    with patch('snowflake_ingestion.scrape_links.functions.SCRAPING_YEAR', ''):
        with patch('snowflake_ingestion.scrape_links.current_month', 1):
            result = scrape.get_scraping_year()
            expected = scrape.current_year - 1
            assert result == expected

test_get_scraping_year_with_empty_env_late_month()

Test get_scraping_year behavior when SCRAPING_YEAR is empty and current month is late (April to December). The function should default to the current year for scraping operations.

Source code in snowflake_ingestion/tests/test_scrape_links.py
29
30
31
32
33
34
35
36
37
38
def test_get_scraping_year_with_empty_env_late_month():
    """
    Test get_scraping_year behavior when SCRAPING_YEAR is empty and current month is late (April to December).
    The function should default to the current year for scraping operations.
    """
    with patch('snowflake_ingestion.scrape_links.functions.SCRAPING_YEAR', ''):
        with patch('snowflake_ingestion.scrape_links.current_month', 4):
            result = scrape.get_scraping_year()
            expected = scrape.current_year
            assert result == expected

test_get_scraping_year_with_invalid_env_early_month()

Test get_scraping_year behavior when SCRAPING_YEAR contains invalid non-numeric data in early months. The function should log an error and default to the previous year while handling the invalid input gracefully.

Source code in snowflake_ingestion/tests/test_scrape_links.py
40
41
42
43
44
45
46
47
48
49
50
51
def test_get_scraping_year_with_invalid_env_early_month():
    """
    Test get_scraping_year behavior when SCRAPING_YEAR contains invalid non-numeric data in early months.
    The function should log an error and default to the previous year while handling the invalid input gracefully.
    """
    with patch('snowflake_ingestion.scrape_links.functions.SCRAPING_YEAR', 'invalid'):
        with patch('snowflake_ingestion.scrape_links.current_month', 3):
            with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:
                result = scrape.get_scraping_year()
                expected = scrape.current_year - 1
                assert result == expected
                mock_logger.error.assert_called_once()

test_get_scraping_year_with_invalid_env_late_month()

Test get_scraping_year behavior when SCRAPING_YEAR contains invalid non-numeric data in late months. The function should log an error and default to the current year while handling the invalid input gracefully.

Source code in snowflake_ingestion/tests/test_scrape_links.py
53
54
55
56
57
58
59
60
61
62
63
64
65
def test_get_scraping_year_with_invalid_env_late_month():
    """
    Test get_scraping_year behavior when SCRAPING_YEAR contains invalid non-numeric data in late months.
    The function should log an error and default to the current year while handling the invalid input gracefully.
    """
    fixed_year = scrape.current_year
    with patch('snowflake_ingestion.scrape_links.functions.SCRAPING_YEAR', 'invalid'):
        with patch('snowflake_ingestion.scrape_links.current_year', fixed_year):
            with patch('snowflake_ingestion.scrape_links.current_month', 12):
                with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:
                    result = scrape.get_scraping_year()
                    assert result == fixed_year
                    mock_logger.error.assert_called_once()

test_get_scraping_year_with_valid_env()

Test that get_scraping_year correctly parses and returns a valid integer value from the SCRAPING_YEAR environment variable. Verifies that when SCRAPING_YEAR is set to '2023', the function returns the integer 2023.

Source code in snowflake_ingestion/tests/test_scrape_links.py
 9
10
11
12
13
14
15
16
def test_get_scraping_year_with_valid_env():
    """
    Test that get_scraping_year correctly parses and returns a valid integer value from the SCRAPING_YEAR environment variable.
    Verifies that when SCRAPING_YEAR is set to '2023', the function returns the integer 2023.
    """
    with patch('snowflake_ingestion.scrape_links.functions.SCRAPING_YEAR', '2023'):
        result = scrape.get_scraping_year()
        assert result == 2023

test_get_xpath()

Test that get_xpath generates the correct XPath expression for locating Yellow Taxi Trip Records links. The XPath should include both the scraping year and current year to capture relevant data files.

Source code in snowflake_ingestion/tests/test_scrape_links.py
67
68
69
70
71
72
73
74
75
76
def test_get_xpath():
    """
    Test that get_xpath generates the correct XPath expression for locating Yellow Taxi Trip Records links.
    The XPath should include both the scraping year and current year to capture relevant data files.
    """
    with patch('snowflake_ingestion.scrape_links.get_scraping_year', return_value=2023):
        with patch('snowflake_ingestion.scrape_links.current_year', 2024):
            result = scrape.get_xpath()
            expected = "//a[@title='Yellow Taxi Trip Records' and (contains(@href, '2023') or contains(@href, '2024'))]"
            assert result == expected

test_main_file_parsing()

Test that the main function correctly parses filename patterns to extract year and month components. Verifies that URLs like 'yellow_tripdata_2023-07.parquet' are correctly parsed into (2023, 7) tuples.

Source code in snowflake_ingestion/tests/test_scrape_links.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def test_main_file_parsing():
    """
    Test that the main function correctly parses filename patterns to extract year and month components.
    Verifies that URLs like 'yellow_tripdata_2023-07.parquet' are correctly parsed into (2023, 7) tuples.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchone.return_value = None

    with patch('snowflake_ingestion.scrape_links.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.scrape_links.functions.use_context'):
            with patch('snowflake_ingestion.scrape_links.setup_meta_table'):
                with patch('snowflake_ingestion.scrape_links.get_parquet_links') as mock_links:
                    with patch('snowflake_ingestion.scrape_links.functions.run_sql_file'):
                        with patch('snowflake_ingestion.scrape_links.logger'):
                            mock_links.return_value = ["https://example.com/yellow_tripdata_2023-07.parquet"]
                            scrape.main()

                            insert_call = None
                            for call in mock_cursor.execute.call_args_list:
                                if 'INSERT' in str(call[0][0]):
                                    insert_call = call
                                    break

                            assert insert_call is not None
                            assert insert_call[0][1] == ("https://example.com/yellow_tripdata_2023-07.parquet", 
                                                        "yellow_tripdata_2023-07.parquet", 2023, 7)

test_main_with_new_files()

Test the main scraping workflow when new parquet files are detected that don't exist in the metadata table. Verifies that new files trigger INSERT operations into the metadata table with appropriate logging.

Source code in snowflake_ingestion/tests/test_scrape_links.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def test_main_with_new_files():
    """
    Test the main scraping workflow when new parquet files are detected that don't exist in the metadata table.
    Verifies that new files trigger INSERT operations into the metadata table with appropriate logging.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchone.side_effect = [None, None, [5]]

    with patch('snowflake_ingestion.scrape_links.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.scrape_links.functions.use_context'):
            with patch('snowflake_ingestion.scrape_links.setup_meta_table'):
                with patch('snowflake_ingestion.scrape_links.get_parquet_links') as mock_links:
                    with patch('snowflake_ingestion.scrape_links.functions.run_sql_file'):
                        with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:

                            mock_links.return_value = [
                                "https://example.com/yellow_tripdata_2023-01.parquet",
                                "https://example.com/yellow_tripdata_2023-02.parquet"
                            ]

                            scrape.main()

                            assert mock_cursor.execute.call_count >= 4
                            mock_logger.info.assert_any_call("📎 2 links found")
                            mock_logger.info.assert_any_call("➕ New file detected : yellow_tripdata_2023-01.parquet")

test_main_without_new_files()

Test the main scraping workflow when all discovered parquet files already exist in the metadata table. Verifies that no INSERT operations occur and appropriate informational and warning logs are recorded.

Source code in snowflake_ingestion/tests/test_scrape_links.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def test_main_without_new_files():
    """
    Test the main scraping workflow when all discovered parquet files already exist in the metadata table.
    Verifies that no INSERT operations occur and appropriate informational and warning logs are recorded.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchone.side_effect = [[1], [0]]

    with patch('snowflake_ingestion.scrape_links.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.scrape_links.functions.use_context'):
            with patch('snowflake_ingestion.scrape_links.setup_meta_table'):
                with patch('snowflake_ingestion.scrape_links.get_parquet_links') as mock_links:
                    with patch('snowflake_ingestion.scrape_links.functions.run_sql_file'):
                        with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:
                            mock_links.return_value = ["https://example.com/yellow_tripdata_2023-01.parquet"]
                            scrape.main()
                            mock_logger.info.assert_any_call("📎 1 link found")
                            mock_logger.info.assert_any_call("⏭️  yellow_tripdata_2023-01.parquet already referenced")
                            mock_logger.info.assert_any_call("✅ Scraping completed")
                            mock_logger.warning.assert_called_with("⚠️  No new files to load.")

test_setup_meta_table()

Test that setup_meta_table correctly executes the SQL script for creating or verifying the metadata table. Verifies that the appropriate SQL file is executed and success/failure logs are recorded.

Source code in snowflake_ingestion/tests/test_scrape_links.py
113
114
115
116
117
118
119
120
121
122
123
124
def test_setup_meta_table():
    """
    Test that setup_meta_table correctly executes the SQL script for creating or verifying the metadata table.
    Verifies that the appropriate SQL file is executed and success/failure logs are recorded.
    """
    mock_cursor = Mock()
    with patch('snowflake_ingestion.scrape_links.functions.run_sql_file') as mock_run_sql:
        with patch('snowflake_ingestion.scrape_links.logger') as mock_logger:
            scrape.setup_meta_table(mock_cursor)
            mock_run_sql.assert_called_once_with(mock_cursor, scrape.SQL_DIR / "setup_meta_table.sql")
            mock_logger.info.assert_any_call("📋 Verification/Creation of metadata table")
            mock_logger.info.assert_any_call("✅ Metadata table ready")

snowflake_ingestion.tests.test_upload_stage

test_download_and_upload_file_http_error()

Unit test for download_and_upload_file in case of HTTP error. Verifies that the function raises an exception when the HTTP download fails.

Source code in snowflake_ingestion/tests/test_upload_stage.py
40
41
42
43
44
45
46
47
48
49
50
51
52
def test_download_and_upload_file_http_error():
    """Unit test for download_and_upload_file in case of HTTP error.
    Verifies that the function raises an exception when the HTTP download fails.
    """
    mock_cursor = Mock()
    mock_response = Mock()
    mock_response.raise_for_status.side_effect = requests.HTTPError("HTTP Error")

    with patch('snowflake_ingestion.upload_stage.requests.get', return_value=mock_response):
        with patch('snowflake_ingestion.upload_stage.tempfile.NamedTemporaryFile'):
            with patch('snowflake_ingestion.upload_stage.logger'):
                with pytest.raises(requests.HTTPError):
                    stage.download_and_upload_file(mock_cursor, "http://example.com/test.parquet", "test.parquet")

test_download_and_upload_file_snowflake_error()

Unit test for download_and_upload_file in case of Snowflake error. Verifies that the function raises an exception when the Snowflake upload fails.

Source code in snowflake_ingestion/tests/test_upload_stage.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def test_download_and_upload_file_snowflake_error():
    """Unit test for download_and_upload_file in case of Snowflake error.
    Verifies that the function raises an exception when the Snowflake upload fails.
    """
    mock_cursor = Mock()
    mock_response = Mock()
    mock_response.content = b"fake parquet content"
    mock_response.raise_for_status = Mock()
    mock_temp_file = Mock()
    mock_temp_file.name = "/tmp/tempfile_123.parquet"
    mock_temp_file.write = Mock()
    mock_temp_file.flush = Mock()
    mock_temp_file.__enter__ = Mock(return_value=mock_temp_file)
    mock_temp_file.__exit__ = Mock(return_value=None)

    with patch('snowflake_ingestion.upload_stage.requests.get', return_value=mock_response):
        with patch('snowflake_ingestion.upload_stage.tempfile.NamedTemporaryFile', return_value=mock_temp_file):
            with patch('snowflake_ingestion.upload_stage.logger'):
                mock_cursor.execute.side_effect = Exception("Snowflake PUT failed")       
                with pytest.raises(Exception, match="Snowflake PUT failed"):
                    stage.download_and_upload_file(mock_cursor, "http://example.com/test.parquet", "test.parquet")

test_download_and_upload_file_success()

Unit test for download_and_upload_file in case of success. Verifies that the function downloads the file from the URL, uploads it to Snowflake via PUT, and automatically cleans up the temporary file.

Source code in snowflake_ingestion/tests/test_upload_stage.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def test_download_and_upload_file_success():
    """Unit test for download_and_upload_file in case of success.
    Verifies that the function downloads the file from the URL, uploads it to Snowflake
    via PUT, and automatically cleans up the temporary file.
    """
    mock_cursor = Mock()
    mock_response = Mock()
    mock_response.content = b"fake parquet content"
    mock_response.raise_for_status = Mock()
    mock_temp_file = Mock()
    mock_temp_file.name = "/tmp/tempfile_123.parquet"
    mock_temp_file.write = Mock()
    mock_temp_file.flush = Mock()
    mock_temp_file.__enter__ = Mock(return_value=mock_temp_file)
    mock_temp_file.__exit__ = Mock(return_value=None)

    with patch('snowflake_ingestion.upload_stage.requests.get', return_value=mock_response):
        with patch('snowflake_ingestion.upload_stage.tempfile.NamedTemporaryFile') as mock_tempfile:
            with patch('snowflake_ingestion.upload_stage.logger') as mock_logger:
                mock_tempfile.return_value = mock_temp_file
                stage.download_and_upload_file(mock_cursor, "http://example.com/test.parquet", "test.parquet")

                mock_response.raise_for_status.assert_called_once()
                mock_temp_file.write.assert_called_once_with(b"fake parquet content")
                mock_temp_file.flush.assert_called_once()
                mock_cursor.execute.assert_called_once_with("PUT 'file:///tmp/tempfile_123.parquet' @~/test.parquet AUTO_COMPRESS=FALSE")
                mock_logger.info.assert_any_call("📥 Downloading test.parquet...")
                mock_logger.info.assert_any_call("📤 Uploading to Snowflake...")
                mock_logger.info.assert_any_call("✅ test.parquet uploaded and temporary file cleaned")

test_download_and_upload_file_tempfile_error()

Unit test for download_and_upload_file in case of temporary file creation error. Verifies that the function raises an exception when temporary file creation fails.

Source code in snowflake_ingestion/tests/test_upload_stage.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def test_download_and_upload_file_tempfile_error():
    """Unit test for download_and_upload_file in case of temporary file creation error.
    Verifies that the function raises an exception when temporary file creation fails.
    """
    mock_cursor = Mock()
    mock_response = Mock()
    mock_response.content = b"fake parquet content"
    mock_response.raise_for_status = Mock()

    with patch('snowflake_ingestion.upload_stage.requests.get', return_value=mock_response):
        with patch('snowflake_ingestion.upload_stage.tempfile.NamedTemporaryFile') as mock_tempfile:
            with patch('snowflake_ingestion.upload_stage.logger'):
                mock_tempfile.side_effect = OSError("Cannot create temp file")
                with pytest.raises(OSError, match="Cannot create temp file"):
                    stage.download_and_upload_file(mock_cursor, "http://example.com/test.parquet", "test.parquet")

test_main_file_processing_flow()

Unit test for the complete file processing flow. Verifies the order of operations: DB connection, metadata retrieval, download, upload, status update.

Source code in snowflake_ingestion/tests/test_upload_stage.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def test_main_file_processing_flow():
    """Unit test for the complete file processing flow.
    Verifies the order of operations: DB connection, metadata retrieval,
    download, upload, status update.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    mock_cursor.fetchall.return_value = [("http://example.com/test.parquet", "test.parquet")]
    execute_calls = []

    def track_execute(*args, **kwargs):
        execute_calls.append((args, kwargs))
        return MagicMock()

    mock_cursor.execute.side_effect = track_execute

    with patch('snowflake_ingestion.upload_stage.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.upload_stage.functions.use_context'):
            with patch('snowflake_ingestion.upload_stage.functions.run_sql_file'):
                with patch('snowflake_ingestion.upload_stage.download_and_upload_file'):
                    with patch('snowflake_ingestion.upload_stage.logger'):
                        stage.main()
                        staged_updates = []
                        for args, kwargs in execute_calls:
                            if len(args) > 0 and 'UPDATE' in args[0] and 'STAGED' in args[0]:
                                staged_updates.append((args, kwargs))
                        assert len(staged_updates) == 1
                        update_args = staged_updates[0][0]
                        assert len(update_args) >= 2
                        assert update_args[1] == ('test.parquet',)

test_main_with_files()

Unit test for the main function with files to upload. Verifies that the function retrieves the scraped files, downloads them, uploads them to Snowflake and updates the status in the metadata.

Source code in snowflake_ingestion/tests/test_upload_stage.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def test_main_with_files():
    """Unit test for the main function with files to upload.
    Verifies that the function retrieves the scraped files, downloads them,
    uploads them to Snowflake and updates the status in the metadata.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    mock_cursor.fetchall.return_value = [
        ("http://example.com/file1.parquet", "file1.parquet"),
        ("http://example.com/file2.parquet", "file2.parquet")
    ]

    with patch('snowflake_ingestion.upload_stage.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.upload_stage.functions.use_context'):
            with patch('snowflake_ingestion.upload_stage.functions.run_sql_file'):
                with patch('snowflake_ingestion.upload_stage.download_and_upload_file') as mock_download:
                    with patch('snowflake_ingestion.upload_stage.logger') as mock_logger:
                        stage.main()
                        mock_logger.info.assert_any_call("📦 2 files to upload")
                        mock_logger.info.assert_any_call("✅ file1.parquet uploaded")
                        mock_logger.info.assert_any_call("✅ file2.parquet uploaded")
                        update_calls = [call for call in mock_cursor.execute.call_args_list 
                                      if 'UPDATE' in str(call[0][0]) and 'STAGED' in str(call[0][0])]
                        assert len(update_calls) == 2

test_main_with_upload_error()

Unit test for the main function with upload error. Verifies that the function correctly handles upload errors by updating the status to FAILED_STAGE.

Source code in snowflake_ingestion/tests/test_upload_stage.py
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def test_main_with_upload_error():
    """Unit test for the main function with upload error.
    Verifies that the function correctly handles upload errors by updating
    the status to FAILED_STAGE.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = [("http://example.com/file1.parquet", "file1.parquet")]

    with patch('snowflake_ingestion.upload_stage.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.upload_stage.functions.use_context'):
            with patch('snowflake_ingestion.upload_stage.functions.run_sql_file'):
                with patch('snowflake_ingestion.upload_stage.download_and_upload_file') as mock_download:
                    with patch('snowflake_ingestion.upload_stage.logger') as mock_logger:
                        mock_download.side_effect = Exception("Upload failed")
                        stage.main()
                        mock_logger.error.assert_called_with("❌ Upload error file1.parquet: Upload failed")
                        update_calls = [call for call in mock_cursor.execute.call_args_list 
                                      if 'FAILED_STAGE' in str(call[0][0])]
                        assert len(update_calls) == 1

snowflake_ingestion.tests.test_load_to_table

test_cleanup_stage_file()

Tests the removal of a processed file from the Snowflake stage. Verifies the correct REMOVE command is executed and a success log is recorded.

Source code in snowflake_ingestion/tests/test_load_to_table.py
133
134
135
136
137
138
139
140
141
142
def test_cleanup_stage_file():
    """Tests the removal of a processed file from the Snowflake stage.
    Verifies the correct REMOVE command is executed and a success log is recorded.
    """
    mock_cursor = Mock()
    # Correction: Patcher le logger dans le module load_to_table directement
    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
        load.cleanup_stage_file(mock_cursor, "test_file.parquet")
        mock_cursor.execute.assert_called_once_with("REMOVE @~/test_file.parquet")
        mock_logger.info.assert_called_with("✅ test_file.parquet removed from stage")

test_copy_file_to_table_and_count_copy_error()

Tests the handling of an exception raised during the COPY INTO execution. Verifies that the exception is propagated and not caught within the function.

Source code in snowflake_ingestion/tests/test_load_to_table.py
100
101
102
103
104
105
106
107
108
109
110
111
112
def test_copy_file_to_table_and_count_copy_error():
    """Tests the handling of an exception raised during the COPY INTO execution.
    Verifies that the exception is propagated and not caught within the function.
    """
    mock_cursor = Mock()
    mock_cursor.execute.side_effect = Exception("COPY failed")

    with patch('snowflake_ingestion.functions.run_sql_file'):
        # Correction: Patcher le logger dans le module load_to_table directement
        with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
            table_schema = [("vendorid", "NUMBER"), ("tpep_pickup_datetime", "TIMESTAMP_NTZ")]
            with pytest.raises(Exception, match="COPY failed"):
                load.copy_file_to_table_and_count(mock_cursor, "test_file.parquet", table_schema)

test_copy_file_to_table_and_count_success()

Tests the successful execution of COPY INTO command. Verifies the command execution, correct parsing of the result, logging of success with row count, and the return of the correct number of loaded rows.

Source code in snowflake_ingestion/tests/test_load_to_table.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def test_copy_file_to_table_and_count_success():
    """Tests the successful execution of COPY INTO command.
    Verifies the command execution, correct parsing of the result,
    logging of success with row count, and the return of the correct number of loaded rows.
    """
    mock_cursor = Mock()
    mock_cursor.fetchone.return_value = ('test_file.parquet', 'LOADED', 250, 250, 1, 0, None, None, None, None)

    with patch('snowflake_ingestion.functions.run_sql_file'):
        # Correction: Patcher le logger dans le module load_to_table directement
        with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
            table_schema = [("vendorid", "NUMBER"), ("tpep_pickup_datetime", "TIMESTAMP_NTZ")]
            result = load.copy_file_to_table_and_count(mock_cursor, "test_file.parquet", table_schema)
            assert result == 250
            mock_cursor.execute.assert_called_once()
            mock_logger.info.assert_called_with("✅ test_file.parquet loaded (250 rows)")

test_copy_file_to_table_and_count_zero_loaded()

Tests the COPY INTO command when no rows are processed. Verifies the function returns 0 when the execution result indicates no files were processed.

Source code in snowflake_ingestion/tests/test_load_to_table.py
86
87
88
89
90
91
92
93
94
95
96
97
98
def test_copy_file_to_table_and_count_zero_loaded():
    """Tests the COPY INTO command when no rows are processed.
    Verifies the function returns 0 when the execution result indicates no files were processed.
    """
    mock_cursor = Mock()
    mock_cursor.fetchone.return_value = ('Copy executed with 0 files processed.',)

    with patch('snowflake_ingestion.functions.run_sql_file'):
        # Correction: Patcher le logger dans le module load_to_table directement
        with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
            table_schema = [("vendorid", "NUMBER"), ("tpep_pickup_datetime", "TIMESTAMP_NTZ")]
            result = load.copy_file_to_table_and_count(mock_cursor, "test_file.parquet", table_schema)
            assert result == 0

test_create_table_no_schema()

Tests the behavior when the stage contains no data. Verifies that only the schema detection script runs, a warning is logged, no table creation SQL is executed, and an empty schema list is returned.

Source code in snowflake_ingestion/tests/test_load_to_table.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def test_create_table_no_schema():
    """Tests the behavior when the stage contains no data.
    Verifies that only the schema detection script runs, a warning is logged,
    no table creation SQL is executed, and an empty schema list is returned.
    """
    mock_cursor = Mock()
    mock_cursor.fetchall.return_value = []

    with patch('snowflake_ingestion.functions.run_sql_file') as mock_run_sql:
        # Correction: Patcher le logger dans le module load_to_table directement
        with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
            result = load.create_table(mock_cursor)

            mock_run_sql.assert_called_once_with(mock_cursor, load.SQL_DIR / "detect_file_schema_stage.sql")
            mock_cursor.fetchall.assert_called_once()
            mock_logger.warning.assert_called_with("⚠️  No data in STAGE")
            mock_cursor.execute.assert_not_called()
            create_sequence_call = call(mock_cursor, load.SQL_DIR / "create_sequence.sql")
            add_filename_call = call(mock_cursor, load.SQL_DIR / "add_filename_to_raw_table.sql")
            assert create_sequence_call not in mock_run_sql.call_args_list
            assert add_filename_call not in mock_run_sql.call_args_list
            assert result == []

test_create_table_success()

Tests the successful creation of a table with dynamic schema detection. Verifies that all SQL files are executed in the correct order, the correct CREATE TABLE statement is generated, appropriate logs are recorded, and the correct table schema is returned.

Source code in snowflake_ingestion/tests/test_load_to_table.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def test_create_table_success():
    """Tests the successful creation of a table with dynamic schema detection.
    Verifies that all SQL files are executed in the correct order,
    the correct CREATE TABLE statement is generated, appropriate logs are recorded,
    and the correct table schema is returned.
    """
    mock_cursor = Mock()
    mock_cursor.fetchall.return_value = [
        ("vendor_id", "NUMBER"),
        ("tpep_pickup_datetime", "TIMESTAMP_NTZ"),
        ("tpep_dropoff_datetime", "TIMESTAMP_NTZ")
    ]

    with patch('snowflake_ingestion.functions.run_sql_file') as mock_run_sql:
        # Correction: Patcher le logger dans le module load_to_table directement
        with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
            result = load.create_table(mock_cursor)

            assert mock_run_sql.call_args_list[0] == call(mock_cursor, load.SQL_DIR / "detect_file_schema_stage.sql")
            assert mock_run_sql.call_args_list[1] == call(mock_cursor, load.SQL_DIR / "create_sequence.sql")
            assert mock_run_sql.call_args_list[2] == call(mock_cursor, load.SQL_DIR / "add_filename_to_raw_table.sql")

            mock_cursor.execute.assert_called_once()
            create_call = mock_cursor.execute.call_args[0][0]
            assert "CREATE TABLE IF NOT EXISTS" in create_call
            assert "vendor_id NUMBER" in create_call
            assert "tpep_pickup_datetime TIMESTAMP_NTZ" in create_call

            mock_logger.info.assert_any_call(f"📋 Dynamic verification/creation of table {load.functions.RAW_TABLE}")
            mock_logger.info.assert_any_call(f"✅ Table {load.functions.RAW_TABLE} ready")

            assert result == [
                ("vendor_id", "NUMBER"),
                ("tpep_pickup_datetime", "TIMESTAMP_NTZ"),
                ("tpep_dropoff_datetime", "TIMESTAMP_NTZ")
            ]

test_handle_loading_error()

Tests the error handling flow when a file fails to load. Verifies that an error is logged, the metadata table is updated to 'FAILED_LOAD', and a debug log is recorded.

Source code in snowflake_ingestion/tests/test_load_to_table.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def test_handle_loading_error():
    """Tests the error handling flow when a file fails to load.
    Verifies that an error is logged, the metadata table is updated to 'FAILED_LOAD',
    and a debug log is recorded.
    """
    mock_cursor = Mock()
    test_error = Exception("COPY INTO failed")
    # Correction: Patcher le logger dans le module load_to_table directement
    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
        load.handle_loading_error(mock_cursor, "test_file.parquet", test_error)
        mock_logger.error.assert_called_with(f"❌ Loading error test_file.parquet: COPY INTO failed")
        mock_logger.debug.assert_called_with(f"🚀 Loading {load.functions.METADATA_TABLE}")
        mock_cursor.execute.assert_called_once()
        update_call = mock_cursor.execute.call_args
        assert "FAILED_LOAD" in update_call[0][0]
        assert update_call[0][1] == ("test_file.parquet",)

test_main_complete_flow_with_counts()

Tests a complete successful flow with a single file, verifying precise function call sequence. Ensures update_metadata and cleanup_stage_file are called exactly once with the correct arguments.

Source code in snowflake_ingestion/tests/test_load_to_table.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def test_main_complete_flow_with_counts():
    """Tests a complete successful flow with a single file, verifying precise function call sequence.
    Ensures update_metadata and cleanup_stage_file are called exactly once with the correct arguments.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = [("file1.parquet",)]

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    with patch('snowflake_ingestion.load_to_table.copy_file_to_table_and_count', return_value=150):
                        with patch('snowflake_ingestion.load_to_table.update_metadata') as mock_update:
                            with patch('snowflake_ingestion.load_to_table.cleanup_stage_file') as mock_cleanup:
                                # Correction: Patcher le logger dans le module load_to_table directement
                                with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                                    load.main()
                                    mock_update.assert_called_once_with(ANY, "file1.parquet", 150)
                                    mock_cleanup.assert_called_once_with(ANY, "file1.parquet")

test_main_connection_error()

Tests the main flow when the initial database connection fails. Verifies that the connection exception propagates and stops the process.

Source code in snowflake_ingestion/tests/test_load_to_table.py
253
254
255
256
257
258
259
260
def test_main_connection_error():
    """Tests the main flow when the initial database connection fails.
    Verifies that the connection exception propagates and stops the process.
    """
    with patch('snowflake_ingestion.functions.connect_with_role') as mock_connect:
        mock_connect.side_effect = Exception("Connection failed")
        with pytest.raises(Exception, match="Connection failed"):
            load.main()

test_main_exception_handling_in_loop()

Tests error handling within the file processing loop. Verifies that an error on one file triggers handle_loading_error for that file, while other files continue processing normally (update and cleanup are called for successful files).

Source code in snowflake_ingestion/tests/test_load_to_table.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
def test_main_exception_handling_in_loop():
    """Tests error handling within the file processing loop.
    Verifies that an error on one file triggers handle_loading_error for that file,
    while other files continue processing normally (update and cleanup are called for successful files).
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = [("file1.parquet",), ("file2.parquet",), ("file3.parquet",)]

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    with patch('snowflake_ingestion.load_to_table.copy_file_to_table_and_count') as mock_copy:
                        with patch('snowflake_ingestion.load_to_table.update_metadata') as mock_update:
                            with patch('snowflake_ingestion.load_to_table.cleanup_stage_file') as mock_cleanup:
                                with patch('snowflake_ingestion.load_to_table.handle_loading_error') as mock_handle_error:
                                    # Correction: Patcher le logger dans le module load_to_table directement
                                    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                                        def mock_copy_side_effect(cur, filename, table_schema):
                                            if filename == "file2.parquet":
                                                raise Exception("Error on file2")
                                            return 100
                                        mock_copy.side_effect = mock_copy_side_effect
                                        load.main()
                                        assert mock_copy.call_count == 3
                                        mock_handle_error.assert_called_once_with(ANY, "file2.parquet", ANY)
                                        assert mock_update.call_count == 2
                                        assert mock_cleanup.call_count == 2

test_main_multiple_files_different_results()

Tests processing multiple files with different row counts. Verifies that update_metadata is called for each file with its respective row count.

Source code in snowflake_ingestion/tests/test_load_to_table.py
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
def test_main_multiple_files_different_results():
    """Tests processing multiple files with different row counts.
    Verifies that update_metadata is called for each file with its respective row count.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = [("small_file.parquet",), ("large_file.parquet",)]

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    with patch('snowflake_ingestion.load_to_table.copy_file_to_table_and_count') as mock_copy:
                        with patch('snowflake_ingestion.load_to_table.update_metadata') as mock_update:
                            with patch('snowflake_ingestion.load_to_table.cleanup_stage_file'):
                                # Correction: Patcher le logger dans le module load_to_table directement
                                with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                                    mock_copy.side_effect = [50, 1000]
                                    load.main()
                                    update_calls = mock_update.call_args_list
                                    assert len(update_calls) == 2
                                    assert update_calls[0][0] == (ANY, "small_file.parquet", 50)
                                    assert update_calls[1][0] == (ANY, "large_file.parquet", 1000)

test_main_no_staged_files()

Tests the main flow when no files are found in the stage. Verifies that the stage analysis log occurs but no copy operations are attempted.

Source code in snowflake_ingestion/tests/test_load_to_table.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def test_main_no_staged_files():
    """Tests the main flow when no files are found in the stage.
    Verifies that the stage analysis log occurs but no copy operations are attempted.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = []

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    # Correction: Patcher le logger dans le module load_to_table directement
                    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                        load.main()
                        mock_logger.info.assert_any_call("🔍 Analyzing files in STAGE")
                        assert not any("COPY INTO" in str(call) for call in mock_cursor.execute.call_args_list)

test_main_success_flow()

Tests the complete successful main execution flow with two files. Verifies the full sequence: connection, context setting, table creation, fetching staged files, loading each file, updating metadata, and cleanup.

Source code in snowflake_ingestion/tests/test_load_to_table.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def test_main_success_flow():
    """Tests the complete successful main execution flow with two files.
    Verifies the full sequence: connection, context setting, table creation,
    fetching staged files, loading each file, updating metadata, and cleanup.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.side_effect = [
        [("file1.parquet",), ("file2.parquet",)],
    ]

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    with patch('snowflake_ingestion.load_to_table.copy_file_to_table_and_count') as mock_copy:
                        with patch('snowflake_ingestion.load_to_table.update_metadata'):
                            with patch('snowflake_ingestion.load_to_table.cleanup_stage_file'):
                                # Correction: Patcher le logger dans le module load_to_table directement
                                with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                                    mock_copy.return_value = 100
                                    load.main()
                                    mock_logger.info.assert_any_call("🔍 Analyzing files in STAGE")
                                    assert mock_copy.call_count == 2
                                    mock_copy.assert_any_call(ANY, "file1.parquet", [("vendorid", "NUMBER")])
                                    mock_copy.assert_any_call(ANY, "file2.parquet", [("vendorid", "NUMBER")])

test_main_table_creation_error()

Tests the main flow when table creation fails. Verifies that the exception from create_table propagates and stops the main process.

Source code in snowflake_ingestion/tests/test_load_to_table.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def test_main_table_creation_error():
    """Tests the main flow when table creation fails.
    Verifies that the exception from create_table propagates and stops the main process.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table') as mock_create_table:
                mock_create_table.side_effect = Exception("Table creation failed")
                with pytest.raises(Exception, match="Table creation failed"):
                    load.main()

test_main_with_loading_error()

Tests the main flow when one file fails to load. Verifies that the error handler is called for the failed file, while successful files continue processing normally.

Source code in snowflake_ingestion/tests/test_load_to_table.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def test_main_with_loading_error():
    """Tests the main flow when one file fails to load.
    Verifies that the error handler is called for the failed file,
    while successful files continue processing normally.
    """
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    mock_cursor.fetchall.return_value = [("file1.parquet",), ("file2.parquet",)]

    with patch('snowflake_ingestion.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.functions.use_context'):
            with patch('snowflake_ingestion.load_to_table.create_table', return_value=[("vendorid", "NUMBER")]):
                with patch('snowflake_ingestion.functions.run_sql_file'):
                    with patch('snowflake_ingestion.load_to_table.copy_file_to_table_and_count') as mock_copy:
                        with patch('snowflake_ingestion.load_to_table.update_metadata'):
                            with patch('snowflake_ingestion.load_to_table.cleanup_stage_file'):
                                with patch('snowflake_ingestion.load_to_table.handle_loading_error') as mock_handle_error:
                                    # Correction: Patcher le logger dans le module load_to_table directement
                                    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
                                        def mock_copy_side_effect(cur, filename, table_schema):
                                            if filename == "file1.parquet":
                                                return 100
                                            else:
                                                raise Exception("COPY INTO failed")
                                        mock_copy.side_effect = mock_copy_side_effect
                                        load.main()
                                        assert mock_copy.call_count == 2
                                        mock_handle_error.assert_called_once_with(ANY, "file2.parquet", ANY)

test_update_metadata()

Tests the successful update of the metadata table after a file load. Verifies the correct UPDATE SQL is executed with the proper parameters and that a debug log is recorded.

Source code in snowflake_ingestion/tests/test_load_to_table.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def test_update_metadata():
    """Tests the successful update of the metadata table after a file load.
    Verifies the correct UPDATE SQL is executed with the proper parameters
    and that a debug log is recorded.
    """
    mock_cursor = Mock()
    # Correction: Patcher le logger dans le module load_to_table directement
    with patch('snowflake_ingestion.load_to_table.logger') as mock_logger:
        load.update_metadata(mock_cursor, "test_file.parquet", 250)
        mock_cursor.execute.assert_called_once()
        update_call = mock_cursor.execute.call_args
        assert "UPDATE" in update_call[0][0]
        assert "rows_loaded" in update_call[0][0]
        assert "SUCCESS" in update_call[0][0]
        assert len(update_call[0][1]) == 2
        assert update_call[0][1][0] == 250
        assert update_call[0][1][1] == "test_file.parquet"
        mock_logger.debug.assert_called_with(f"🚀 Loading {load.functions.METADATA_TABLE}")

snowflake_ingestion.tests.test_backup_policy

test_backup_configuration_values()

Test backup configuration constants are properly accessed and logged.

Ensures retention days for different backup policies are correctly formatted in log messages.

Source code in snowflake_ingestion/tests/test_backup_policy.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def test_backup_configuration_values():
    """
    Test backup configuration constants are properly accessed and logged.

    Ensures retention days for different backup policies are correctly
    formatted in log messages.
    """
    mock_cursor = Mock()

    with patch('snowflake_ingestion.backup_policy.functions.run_sql_file'):
        with patch('snowflake_ingestion.backup_policy.logger') as mock_logger:
            with patch('snowflake_ingestion.backup_policy.functions.DW_NAME', 'PRODUCTION_DW'):
                with patch('snowflake_ingestion.backup_policy.functions.RAW_TABLE', 'YELLOW_TAXI'):
                    with patch('snowflake_ingestion.backup_policy.functions.FINAL_SCHEMA', 'ANALYTICS'):
                        with patch('snowflake_ingestion.backup_policy.functions.FULL_BACKUP_POLICY_DAYS', 365):
                            with patch('snowflake_ingestion.backup_policy.functions.RAW_TABLE_BACKUP_POLICY_DAYS', 730):
                                with patch('snowflake_ingestion.backup_policy.functions.FINAL_SCHEMA_BACKUP_POLICY_DAYS', 180):

                                    backup.create_and_set_backup(mock_cursor)

                                    mock_logger.info.assert_any_call("✅ PRODUCTION_DW_BACKUP retention : 365")
                                    mock_logger.info.assert_any_call("✅ YELLOW_TAXI_BACKUP retention : 730")
                                    mock_logger.info.assert_any_call("✅ ANALYTICS_BACKUP retention : 180")

test_create_and_set_backup()

Test the create_and_set_backup function execution.

Verifies that the function runs the correct SQL file and logs appropriate messages with retention period information for each backup policy.

Source code in snowflake_ingestion/tests/test_backup_policy.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def test_create_and_set_backup():
    """
    Test the create_and_set_backup function execution.

    Verifies that the function runs the correct SQL file and logs appropriate
    messages with retention period information for each backup policy.
    """
    mock_cursor = Mock()

    with patch('snowflake_ingestion.backup_policy.functions.run_sql_file') as mock_run_sql:
        with patch('snowflake_ingestion.backup_policy.logger') as mock_logger:
            with patch('snowflake_ingestion.backup_policy.functions.DW_NAME', 'TEST_DW'):
                with patch('snowflake_ingestion.backup_policy.functions.RAW_TABLE', 'TEST_RAW'):
                    with patch('snowflake_ingestion.backup_policy.functions.FINAL_SCHEMA', 'TEST_FINAL'):
                        with patch('snowflake_ingestion.backup_policy.functions.FULL_BACKUP_POLICY_DAYS', 90):
                            with patch('snowflake_ingestion.backup_policy.functions.RAW_TABLE_BACKUP_POLICY_DAYS', 180):
                                with patch('snowflake_ingestion.backup_policy.functions.FINAL_SCHEMA_BACKUP_POLICY_DAYS', 365):

                                    backup.create_and_set_backup(mock_cursor)

                                    expected_sql_file = backup.SQL_DIR / "create_and_set_backup.sql"
                                    mock_run_sql.assert_called_once_with(mock_cursor, expected_sql_file)

                                    mock_logger.info.assert_any_call("🔐 Creating backup policies and sets...")
                                    mock_logger.info.assert_any_call("✅ TEST_DW_BACKUP retention : 90")
                                    mock_logger.info.assert_any_call("✅ TEST_RAW_BACKUP retention : 180")
                                    mock_logger.info.assert_any_call("✅ TEST_FINAL_BACKUP retention : 365")

test_main_exception()

Test error handling in main function when an exception occurs.

Verifies that exceptions are caught and logged as errors without crashing.

Source code in snowflake_ingestion/tests/test_backup_policy.py
67
68
69
70
71
72
73
74
75
76
77
78
79
def test_main_exception():
    """
    Test error handling in main function when an exception occurs.

    Verifies that exceptions are caught and logged as errors without crashing.
    """
    test_exception = Exception("Connection failed: Invalid credentials")

    with patch('snowflake_ingestion.backup_policy.functions.connect_with_role', side_effect=test_exception):
        with patch('snowflake_ingestion.backup_policy.logger') as mock_logger:

            backup.main()
            mock_logger.error.assert_called_once_with(test_exception)

test_main_success()

Test the main function with successful backup setup workflow.

Ensures connection is made with SYSADMIN role, backup function is called, connection is closed properly, and success is logged.

Source code in snowflake_ingestion/tests/test_backup_policy.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def test_main_success():
    """
    Test the main function with successful backup setup workflow.

    Ensures connection is made with SYSADMIN role, backup function is called,
    connection is closed properly, and success is logged.
    """
    mock_conn = Mock()
    mock_cursor = Mock()

    mock_conn.cursor.return_value.__enter__ = Mock(return_value=mock_cursor)
    mock_conn.cursor.return_value.__exit__ = Mock(return_value=None)

    with patch('snowflake_ingestion.backup_policy.functions.connect_with_role', return_value=mock_conn) as mock_connect:
        with patch('snowflake_ingestion.backup_policy.create_and_set_backup') as mock_create_backup:
            with patch('snowflake_ingestion.backup_policy.logger') as mock_logger:

                backup.main()

                mock_connect.assert_called_once_with(
                    backup.functions.USER,
                    backup.functions.PASSWORD,
                    backup.functions.ACCOUNT,
                    "SYSADMIN"
                )

                mock_create_backup.assert_called_once_with(mock_cursor)
                mock_conn.close.assert_called_once()
                mock_logger.info.assert_called_with("🎯 Complete initialization finished successfully!")

test_main_with_connection_context()

Test proper cursor context management in main function.

Verifies that cursor context manager protocols are followed correctly.

Source code in snowflake_ingestion/tests/test_backup_policy.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def test_main_with_connection_context():
    """
    Test proper cursor context management in main function.

    Verifies that cursor context manager protocols are followed correctly.
    """
    mock_conn = Mock()
    mock_cursor_context = Mock()
    mock_cursor = Mock()

    mock_cursor_context.__enter__ = Mock(return_value=mock_cursor)
    mock_cursor_context.__exit__ = Mock(return_value=None)
    mock_conn.cursor.return_value = mock_cursor_context

    with patch('snowflake_ingestion.backup_policy.functions.connect_with_role', return_value=mock_conn):
        with patch('snowflake_ingestion.backup_policy.create_and_set_backup'):
            with patch('snowflake_ingestion.backup_policy.logger'):

                backup.main()
                mock_conn.cursor.assert_called_once()
                mock_cursor_context.__enter__.assert_called_once()
                mock_cursor_context.__exit__.assert_called_once()