321 lines
13 KiB
Python
321 lines
13 KiB
Python
|
|
"""
|
||
|
|
Tests for the LLM normalization module.

No real network calls: the LLM client is mocked via unittest.mock.patch.
DB tests use tmp_path (an isolated SQLite database per test).
|
||
|
|
"""
|
||
|
|
|
||
|
|
from datetime import date
|
||
|
|
from pathlib import Path
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from tickettracker.db import repository, schema
|
||
|
|
from tickettracker.llm.client import LLMError, LLMUnavailable
|
||
|
|
from tickettracker.llm import normalizer
|
||
|
|
from tickettracker.models.receipt import Item, Receipt
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# DB fixtures (same pattern as test_db.py)
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Provide an isolated SQLite file on which ``schema.init_db`` has been run."""
    database_file = tmp_path.joinpath("test_normalizer.db")
    schema.init_db(database_file)
    return database_file
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def db_conn(db_path: Path):
    """Yield a connection to the test database, then close it on teardown."""
    connection = schema.get_connection(db_path)
    yield connection
    connection.close()
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def db_with_items(db_path: Path) -> Path:
    """Return the path of a database pre-filled with one three-item receipt.

    The receipt is inserted and committed through ``repository.insert_receipt``.
    The tests in this module expect the three inserted items to be reported by
    ``repository.fetch_unnormalized`` (i.e. ``name_normalized`` is still NULL).
    """
    receipt = Receipt(
        store="leclerc",
        date=date(2025, 11, 8),
        total=15.00,
        items=[
            Item("NOIX CAJOU", 1, "pièce", 5.12, 5.12, "EPICERIE SALEE"),
            Item("COCA COLA CHERRY 1.25L", 1, "pièce", 6.72, 6.72, "BOISSONS"),
            Item("PQ LOTUS CONFORT X6", 1, "pièce", 3.10, 3.10, "HYGIENE"),
        ],
    )
    conn = schema.get_connection(db_path)
    try:
        repository.insert_receipt(conn, receipt)
        conn.commit()
    finally:
        # Always release the handle, even when the insert or the commit
        # raises, so a failing fixture does not leave the database file open.
        conn.close()
    return db_path
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Tests for parsing the LLM response
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestParseNormalizedLine:
    """Unit tests for ``normalizer._parse_normalized_line``."""

    def test_valid_line(self):
        """A line numbered with 'N.' and three fields comes back without its number."""
        raw_line = "1. Crème fraîche épaisse | MDD | 50cl"
        assert (
            normalizer._parse_normalized_line(raw_line)
            == "Crème fraîche épaisse | MDD | 50cl"
        )

    def test_valid_line_with_parenthesis_number(self):
        """A line numbered with 'N)' is accepted as well."""
        raw_line = "2) Coca-Cola Cherry | Coca-Cola | 1,25L"
        assert (
            normalizer._parse_normalized_line(raw_line)
            == "Coca-Cola Cherry | Coca-Cola | 1,25L"
        )

    def test_valid_line_quantity_absent(self):
        """A dash '-' is a valid quantity (quantity absent from the raw name)."""
        raw_line = "3. Noix de cajou | MDD | -"
        assert normalizer._parse_normalized_line(raw_line) == "Noix de cajou | MDD | -"

    def test_invalid_no_pipes(self):
        """A line without any '|' separator yields None."""
        raw_line = "1. Juste un nom sans format"
        assert normalizer._parse_normalized_line(raw_line) is None

    def test_invalid_only_one_pipe(self):
        """A single '|' yields None (two are required)."""
        raw_line = "1. Produit | MDD"
        assert normalizer._parse_normalized_line(raw_line) is None

    def test_invalid_empty_field(self):
        """An empty field yields None."""
        raw_line = "1. | MDD | 50cl"
        assert normalizer._parse_normalized_line(raw_line) is None

    def test_invalid_no_number(self):
        """An unnumbered line yields None."""
        raw_line = "Crème fraîche | MDD | 50cl"
        assert normalizer._parse_normalized_line(raw_line) is None

    def test_strips_extra_spaces(self):
        """Whitespace around the fields is normalized."""
        raw_line = "1. Noix | MDD | 200g "
        assert normalizer._parse_normalized_line(raw_line) == "Noix | MDD | 200g"
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Tests for normalize_product_name (single-item call)
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestNormalizeProductName:
    """Tests for ``normalizer.normalize_product_name`` with the LLM call mocked."""

    # Dotted path of the LLM entry point as looked up by the normalizer module.
    _CALL_LLM = "tickettracker.llm.normalizer.call_llm"

    def test_success(self):
        """The mocked LLM returns a valid line → the normalized name is returned."""
        with patch(self._CALL_LLM, return_value="1. Noix de cajou | MDD | 200g"):
            normalized = normalizer.normalize_product_name("NOIX CAJOU")
        assert normalized == "Noix de cajou | MDD | 200g"

    def test_llm_error_returns_none(self):
        """LLMError → None is returned and the error is not propagated."""
        with patch(self._CALL_LLM, side_effect=LLMError("HTTP 500")):
            normalized = normalizer.normalize_product_name("NOIX CAJOU")
        assert normalized is None

    def test_llm_unavailable_returns_none(self):
        """LLMUnavailable → None is returned and the error is not propagated."""
        with patch(self._CALL_LLM, side_effect=LLMUnavailable("Timeout")):
            normalized = normalizer.normalize_product_name("NOIX CAJOU")
        assert normalized is None

    def test_unparsable_response_returns_none(self):
        """An LLM response that cannot be parsed → None."""
        with patch(self._CALL_LLM, return_value="Désolé, je ne comprends pas."):
            normalized = normalizer.normalize_product_name("NOIX CAJOU")
        assert normalized is None

    def test_passes_raw_name_to_llm(self):
        """The raw product name must appear in the user message sent to the LLM."""
        with patch(
            self._CALL_LLM,
            return_value="1. Coca-Cola Cherry | Coca-Cola | 1,25L",
        ) as llm_mock:
            normalizer.normalize_product_name("COCA COLA CHERRY 1.25L")
        # First positional argument of the recorded call is the messages list.
        sent_messages = llm_mock.call_args[0][0]
        user_content = next(
            message["content"]
            for message in sent_messages
            if message["role"] == "user"
        )
        assert "COCA COLA CHERRY 1.25L" in user_content
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Tests for normalize_batch
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestNormalizeBatch:
    """Tests for ``normalizer.normalize_batch`` with the LLM call mocked."""

    # Dotted path of the LLM entry point as looked up by the normalizer module.
    _CALL_LLM = "tickettracker.llm.normalizer.call_llm"

    def test_success_full_batch(self):
        """3 names → 3 valid lines returned, in order."""
        raw_names = [
            "NOIX CAJOU",
            "COCA COLA CHERRY 1.25L",
            "PQ LOTUS CONFORT X6",
        ]
        expected = [
            "Noix de cajou | MDD | 200g",
            "Coca-Cola Cherry | Coca-Cola | 1,25L",
            "Papier toilette confort | Lotus | x6",
        ]
        llm_response = (
            "1. Noix de cajou | MDD | 200g\n"
            "2. Coca-Cola Cherry | Coca-Cola | 1,25L\n"
            "3. Papier toilette confort | Lotus | x6"
        )
        with patch(self._CALL_LLM, return_value=llm_response):
            results = normalizer.normalize_batch(raw_names)
        assert len(results) == 3
        for position, wanted in enumerate(expected):
            assert results[position] == wanted

    def test_wrong_count_returns_all_none(self):
        """The LLM returns 2 lines for 3 items → [None, None, None]."""
        raw_names = [
            "NOIX CAJOU",
            "COCA COLA CHERRY 1.25L",
            "PQ LOTUS CONFORT X6",
        ]
        llm_response = (
            "1. Noix de cajou | MDD | 200g\n"
            "2. Coca-Cola Cherry | Coca-Cola | 1,25L"
        )
        with patch(self._CALL_LLM, return_value=llm_response):
            results = normalizer.normalize_batch(raw_names)
        assert results == [None, None, None]

    def test_llm_error_returns_all_none(self):
        """LLMError on the batch call → [None, None, None]."""
        with patch(self._CALL_LLM, side_effect=LLMError("HTTP 429")):
            results = normalizer.normalize_batch(["A", "B", "C"])
        assert results == [None, None, None]

    def test_llm_unavailable_propagated(self):
        """LLMUnavailable is propagated (not silenced) so normalize_all_in_db can stop."""
        unavailable = LLMUnavailable("Connexion refusée")
        with patch(self._CALL_LLM, side_effect=unavailable), pytest.raises(LLMUnavailable):
            normalizer.normalize_batch(["A", "B"])

    def test_empty_list(self):
        """Empty list → empty list, and the LLM is never called."""
        with patch(self._CALL_LLM) as llm_mock:
            results = normalizer.normalize_batch([])
        assert results == []
        llm_mock.assert_not_called()

    def test_fallback_when_batch_fails(self):
        """If normalize_batch returns [None, None, None], normalize_all_in_db
        must attempt the unit fallback for each item."""
        # Intentional placeholder: the behavior is exercised by
        # TestNormalizeAllInDb.test_fallback_to_unit_when_batch_returns_all_none.
        pass
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Tests for normalize_all_in_db
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestNormalizeAllInDb:
    """Tests for ``normalizer.normalize_all_in_db`` against a temporary database.

    The LLM entry point is always patched, so no network call is made.
    """

    # Dotted path of the LLM entry point as looked up by the normalizer module.
    _CALL_LLM = "tickettracker.llm.normalizer.call_llm"

    # Well-formed batch reply covering the three items inserted by the
    # ``db_with_items`` fixture.
    _FULL_BATCH_REPLY = (
        "1. Noix de cajou | MDD | 200g\n"
        "2. Coca-Cola Cherry | Coca-Cola | 1,25L\n"
        "3. Papier toilette confort | Lotus | x6"
    )

    @staticmethod
    def _fetch_unnormalized(db_file: Path):
        """Return ``repository.fetch_unnormalized`` for *db_file*.

        The connection is closed in a ``finally`` clause so a failing query
        does not leak the handle.
        """
        conn = schema.get_connection(db_file)
        try:
            return repository.fetch_unnormalized(conn)
        finally:
            conn.close()

    def test_dry_run_does_not_modify_db(self, db_with_items: Path):
        """With dry_run=True, no item is updated in the database."""
        with patch(self._CALL_LLM, return_value=self._FULL_BATCH_REPLY):
            nb_ok, nb_err = normalizer.normalize_all_in_db(
                db_with_items, batch_size=20, dry_run=True
            )

        # The database must be left untouched.
        still_null = self._fetch_unnormalized(db_with_items)

        assert len(still_null) == 3  # still 3 unnormalized rows
        assert nb_ok == 3  # yet 3 items are counted as normalized
        assert nb_err == 0

    def test_updates_db_when_not_dry_run(self, db_with_items: Path):
        """Without dry-run, the items are updated in the database."""
        with patch(self._CALL_LLM, return_value=self._FULL_BATCH_REPLY):
            nb_ok, nb_err = normalizer.normalize_all_in_db(
                db_with_items, batch_size=20, dry_run=False
            )

        still_null = self._fetch_unnormalized(db_with_items)

        assert len(still_null) == 0  # no unnormalized rows left
        assert nb_ok == 3
        assert nb_err == 0

    def test_no_items_to_normalize(self, db_path: Path):
        """Empty database (no items) → the LLM is never called and (0, 0) is returned."""
        with patch(self._CALL_LLM) as llm_mock:
            nb_ok, nb_err = normalizer.normalize_all_in_db(db_path)

        llm_mock.assert_not_called()
        assert nb_ok == 0
        assert nb_err == 0

    def test_fallback_to_unit_when_batch_returns_all_none(self, db_with_items: Path):
        """If the batch yields only None, the per-item fallback is attempted."""
        # The batch reply has 1 line for 3 items → wrong count → all None.
        batch_response = "1. Un seul | truc | 200g"
        # Replies served to the three unit (per-item) calls that follow.
        unit_responses = [
            "1. Noix de cajou | MDD | 200g",
            "1. Coca-Cola Cherry | Coca-Cola | 1,25L",
            "1. Papier toilette confort | Lotus | x6",
        ]
        # Call k (0-based) receives scripted[k]: the batch reply first, then
        # the unit replies in order. A fifth call would raise IndexError.
        scripted = [batch_response, *unit_responses]
        received = []

        def fake_call_llm(messages, **kwargs):
            received.append(messages)
            return scripted[len(received) - 1]

        with patch(self._CALL_LLM, side_effect=fake_call_llm):
            nb_ok, nb_err = normalizer.normalize_all_in_db(
                db_with_items, batch_size=20, dry_run=False
            )

        # 1 batch call + 3 unit calls = 4 calls in total.
        assert len(received) == 4
        assert nb_ok == 3
        assert nb_err == 0

    def test_error_items_stay_null(self, db_with_items: Path):
        """Items whose normalization fails stay NULL in the database."""
        # Every LLM call raises, so the batch attempt and the unit fallback
        # both fail.
        with patch(self._CALL_LLM, side_effect=LLMError("HTTP 500")):
            nb_ok, nb_err = normalizer.normalize_all_in_db(
                db_with_items, batch_size=20, dry_run=False
            )

        still_null = self._fetch_unnormalized(db_with_items)

        assert len(still_null) == 3
        assert nb_ok == 0
        assert nb_err == 3
|