diff --git a/setup.py b/setup.py index 06f29abe1f8ccd42d0fa74f54d17608b46d07361..8abb33a7ae531d318d03198544fc1afee93c151a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="stripe-checkout", - version="0.2", + version="0.3", author="GEANT", author_email="swd@geant.org", description="Stripe custom checkout support service", diff --git a/stripe_checkout/settings/base.py b/stripe_checkout/settings/base.py index 798ff849cc3e700f73d5bf6918cd4b1ef22cc3cc..10bf54f2e047202171b4ee9a9ca2ea0b9f7fa2e9 100644 --- a/stripe_checkout/settings/base.py +++ b/stripe_checkout/settings/base.py @@ -136,3 +136,5 @@ LOGGING = { "level": "INFO", }, } + +STRIPE_WEBHOOK_ALLOWED_IPS = ["*"] diff --git a/stripe_checkout/settings/prod.py b/stripe_checkout/settings/prod.py index d4f7007b270d61ed96881ab80b7650ed0de2d819..6edf70c831cfbfaa7276a1904e6723c48394f38d 100644 --- a/stripe_checkout/settings/prod.py +++ b/stripe_checkout/settings/prod.py @@ -30,3 +30,18 @@ STATIC_URL = os.getenv("STATIC_URL", "/static/") # noqa: F405 STATIC_ROOT = os.getenv("STATIC_ROOT", "staticfiles/") # noqa: F405 SESSION_COOKIE_SECURE = True CSRF_COOKIE_SECURE = True + +STRIPE_WEBHOOK_ALLOWED_IPS = [ + "3.18.12.63", + "3.130.192.231", + "13.235.14.237", + "13.235.122.149", + "18.211.135.69", + "35.154.171.200", + "52.15.183.38", + "54.88.130.119", + "54.88.130.237", + "54.187.174.169", + "54.187.205.235", + "54.187.216.72", +] diff --git a/stripe_checkout/stripe_checkout/admin.py b/stripe_checkout/stripe_checkout/admin.py index 1e5c1b49887003876fd18f720639b4bc97cc1dce..1af591e6b889e8996654e860e3581a4b82895a54 100644 --- a/stripe_checkout/stripe_checkout/admin.py +++ b/stripe_checkout/stripe_checkout/admin.py @@ -1,5 +1,5 @@ from django.contrib import admin -from .models import PricedItem, Event, Order +from .models import ExchangeRate, PricedItem, Event, Order @admin.register(PricedItem) @@ -18,3 +18,8 @@ class EventAdmin(admin.ModelAdmin): @admin.register(Order) class OrderAdmin(admin.ModelAdmin): list_display = ["visitor_id", "stripe_id", "paid"] + + +@admin.register(ExchangeRate) +class ExchangeRateAdmin(admin.ModelAdmin): + list_display = ["date", "rate"] diff --git a/stripe_checkout/stripe_checkout/management/commands/getexchangerate.py b/stripe_checkout/stripe_checkout/management/commands/getexchangerate.py new file mode 100644 index 0000000000000000000000000000000000000000..847bfa97057ba02c8fa5dcd7333b06bd82cca773 --- /dev/null +++ b/stripe_checkout/stripe_checkout/management/commands/getexchangerate.py @@ -0,0 +1,45 @@ +import csv +import datetime +from django.core.management.base import BaseCommand +import requests +from stripe_checkout.stripe_checkout.models import ExchangeRate +from django.utils import timezone + +URL = ( + "https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/" + "monthly_csv_{year}-{month}.csv" +) + + +class Command(BaseCommand): + def handle(self, *args, **options): + date = timezone.now() + month_str = date.strftime("%Y-%m") + exchange_rate = ExchangeRate.objects.filter(date__month=date.month).first() + if exchange_rate: + self.stdout.write( + f"Exchange rate for {month_str}" + f" already exists ({exchange_rate.rate:.4f})" + ) + return + data = self.get_data(date) + rate = self.extract_exchange_rate(data) + ExchangeRate.objects.create(rate=rate, date=date.date()) + self.stdout.write( + self.style.SUCCESS( + f"Updated exchange rate for {date.strftime('%Y-%m')}: {rate:.4f}" + ) + ) + + def get_data(self, date: datetime.date): + url = URL.format(month=date.month, year=date.year) + result = requests.get(url) + result.raise_for_status() + return result.content.decode() + + def extract_exchange_rate(self, content: str): + reader = csv.DictReader(content.split("\n")) + for row in reader: + if row["Currency Code"] == "EUR": + return 1 / float(row["Currency Units per £1"]) + raise ValueError("") diff --git a/stripe_checkout/stripe_checkout/migrations/0004_alter_order_items.py b/stripe_checkout/stripe_checkout/migrations/0004_alter_order_items.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b054b5a14c6365db500d32fe8931d3a303df54 --- /dev/null +++ b/stripe_checkout/stripe_checkout/migrations/0004_alter_order_items.py @@ -0,0 +1,23 @@ +# Generated by Django 5.1.4 on 2025-01-14 14:38 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "stripe_checkout", + "0003_rename_stripe_id_priceditem_stripe_product_id_and_more", + ), + ] + + operations = [ + migrations.AlterField( + model_name="order", + name="items", + field=models.ManyToManyField( + related_name="order", to="stripe_checkout.priceditem" + ), + ), + ] diff --git a/stripe_checkout/stripe_checkout/migrations/0005_exchangerate_order_exchange_rate.py b/stripe_checkout/stripe_checkout/migrations/0005_exchangerate_order_exchange_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..8fea9c1c264a5b3d7b36726b3aef2e76c13372f6 --- /dev/null +++ b/stripe_checkout/stripe_checkout/migrations/0005_exchangerate_order_exchange_rate.py @@ -0,0 +1,41 @@ +# Generated by Django 5.1.4 on 2025-01-14 15:39 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("stripe_checkout", "0004_alter_order_items"), + ] + + operations = [ + migrations.CreateModel( + name="ExchangeRate", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("rate", models.FloatField()), + ("date", models.DateField()), + ], + ), + migrations.AddField( + model_name="order", + name="exchange_rate", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="orders", + to="stripe_checkout.exchangerate", + ), + ), + ] diff --git a/stripe_checkout/stripe_checkout/models.py b/stripe_checkout/stripe_checkout/models.py index 13b0d7c54e96016fbbfb33cb703308e33ae3f2a5..b81a4d5947198e6be4a4d8a1d1de1a5f22182874 100644 --- a/stripe_checkout/stripe_checkout/models.py +++ b/stripe_checkout/stripe_checkout/models.py @@ -56,11 +56,23 @@ class PricedItem(models.Model): ) +class ExchangeRate(models.Model): + rate = models.FloatField() + date = models.DateField() + + class Order(models.Model): visitor_id = models.CharField(max_length=40, blank=False) stripe_id = models.CharField(max_length=255, blank=False, unique=True) items = models.ManyToManyField(to=PricedItem, related_name="order") paid = models.BooleanField(default=False, db_default=False) + exchange_rate = models.ForeignKey( + ExchangeRate, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="orders", + ) class Event(models.Model): diff --git a/stripe_checkout/stripe_checkout/stripe.py b/stripe_checkout/stripe_checkout/stripe.py index b81e5f39f7369f963636842a3caa958542410519..9a8fcc386baf4ed9339c4097ea3ea51bdcba6300 100644 --- a/stripe_checkout/stripe_checkout/stripe.py +++ b/stripe_checkout/stripe_checkout/stripe.py @@ -9,6 +9,7 @@ import stripe.error from stripe.error import StripeError # noqa F401 from django.conf import settings +from stripe_checkout.stripe_checkout.models import ExchangeRate from stripe_checkout.stripe_checkout.visit import Visitor if TYPE_CHECKING: @@ -40,7 +41,11 @@ def get_or_create_customer(visitor: Visitor) -> Optional[str]: def create_invoice( - shopping_cart: ShoppingCart, customer_id, purchase_order=None, vat_number=None + shopping_cart: ShoppingCart, + customer_id, + purchase_order=None, + vat_number=None, + gbp_exchange_rate: Optional[ExchangeRate] = None, ): stripe.api_key = settings.STRIPE_API_KEY custom_fields = [] @@ -48,6 +53,12 @@ def create_invoice( custom_fields.append({"name": "Purchase Order", "value": purchase_order}) if vat_number: custom_fields.append({"name": "VAT number", "value": vat_number}) + if gbp_exchange_rate: + rate = gbp_exchange_rate.rate + vat = shopping_cart.vat * rate + custom_fields.append( + {"name": "GBP VAT Rate", "value": f"GBP {vat:.2f} ({rate:.4f})"} + ) invoice = stripe.Invoice.create( customer=customer_id, diff --git a/stripe_checkout/stripe_checkout/utils.py b/stripe_checkout/stripe_checkout/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd58c721426a2f67f4f9fde3e9a31adfb06d0508 --- /dev/null +++ b/stripe_checkout/stripe_checkout/utils.py @@ -0,0 +1,39 @@ +import functools + +from django.conf import settings +from django.http import HttpResponseForbidden + +ALLOW_ALL = "*" + + +def whitelist_ips(func=None, by_setting=None): + if not func: + return functools.partial(whitelist_ips, by_setting=by_setting) + + @functools.wraps(func) + def _whitelisted_view(request): + allowed_ips = _read_whitelist_setting(by_setting, [ALLOW_ALL]) + if not allowed_ips & {ALLOW_ALL, get_client_ip(request)}: + + return HttpResponseForbidden("forbidden") + return func(request) + + return _whitelisted_view + + +def _read_whitelist_setting(setting: str, default=None) -> set[str]: + result = getattr(settings, setting, default) + assert isinstance(result, list) and all( + isinstance(s, str) for s in result + ), f"{setting} must be a list of strings" + return set(result) + + +def get_client_ip(request): + # cf. https://stackoverflow.com/a/4581997 + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") + if x_forwarded_for: + ip = x_forwarded_for.split(",")[0] + else: + ip = request.META.get("REMOTE_ADDR") + return ip diff --git a/stripe_checkout/stripe_checkout/visit_views.py b/stripe_checkout/stripe_checkout/visit_views.py index eb9f70059e31d0eb7a2e04c51e0a20ded168af35..463e715864f494348820b9e444483e308f83a0d8 100644 --- a/stripe_checkout/stripe_checkout/visit_views.py +++ b/stripe_checkout/stripe_checkout/visit_views.py @@ -1,16 +1,17 @@ import logging from typing import Union + +import requests from django import forms from django.http import Http404, HttpResponse from django.shortcuts import redirect, render -from django.views.decorators.http import require_POST, require_http_methods, require_GET from django.views.decorators.csrf import csrf_exempt -import requests +from django.views.decorators.http import require_GET, require_http_methods, require_POST -from stripe_checkout.stripe_checkout.shopping_cart import ShoppingCart - -from .models import Event from . import stripe +from .models import Event, ExchangeRate +from .shopping_cart import ShoppingCart +from .utils import whitelist_ips from .visit import VisitorAPI logger = logging.getLogger(__name__) @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) class PaymentDetailsForm(forms.Form): purchase_order = forms.CharField( - label="Purchase Order (optional)", max_length=10, required=False + label="Purchase Order (optional)", max_length=100, required=False ) vat_number = forms.CharField( label="VAT Number (optional)", max_length=100, required=False @@ -60,11 +61,14 @@ def checkout(request, visitor_id): def create_invoice(visitor, data): shopping_cart = get_shopping_cart(visitor) customer = stripe.get_or_create_customer(visitor) + exchange_rate = ExchangeRate.objects.order_by("-date").first() + return stripe.create_invoice( shopping_cart, customer, purchase_order=data["purchase_order"], vat_number=data["vat_number"], + gbp_exchange_rate=exchange_rate, ) @@ -85,6 +89,7 @@ def checkout_success(request, visitor_id): @csrf_exempt @require_POST +@whitelist_ips(by_setting="STRIPE_WEBHOOK_ALLOWED_IPS") def stripe_event(request): try: event = stripe.read_event(request.body, request.headers.get("Stripe-Signature")) diff --git a/test/conftest.py b/test/conftest.py index 9d4833720bc8a09970a7b06f7d75d4c79a5d61cd..c49d56d1ecb0c096758c1e829868017de31534c9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -9,9 +9,9 @@ import responses from django.conf import settings from django.contrib.auth import get_user_model from django.test import Client - +from django.utils import timezone from stripe_checkout import config -from stripe_checkout.stripe_checkout.models import ItemKind, PricedItem +from stripe_checkout.stripe_checkout.models import ExchangeRate, ItemKind, PricedItem VISIT_RESPONSES_DIR = pathlib.Path(__file__).parent / "visit-responses" @@ -240,3 +240,8 @@ def config_file(stripe_api_key, visit_api_key, visit_expo_id, tmp_path): @pytest.fixture def client(setup_django, mock_visit, mock_stripe): return Client() + + +@pytest.fixture +def default_exchange_rate(): + return ExchangeRate.objects.create(date=timezone.now().date(), rate=0.8) diff --git a/test/test_getexchangerate.py b/test/test_getexchangerate.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8e01c2ea5bb1840d9f6d6d1c5c1116fdb28ddf --- /dev/null +++ b/test/test_getexchangerate.py @@ -0,0 +1,45 @@ +import re +import pytest +import responses +from django.core.management import call_command +from django.utils import timezone +from stripe_checkout.stripe_checkout.models import ExchangeRate + + +def fake_csv(): + return "\n".join( + [ + "Currency Code,Currency Units per £1", + "EUR, 1.250", + ] + ) + + +@pytest.fixture(autouse=True) +def mock_rates_api(): + responses.add( + responses.GET, + re.compile( + r"https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/.+$" + ), + body=fake_csv(), + ) + + +@responses.activate +@pytest.mark.django_db +def test_fetches_data(): + call_command("getexchangerate") + exchange_rates = ExchangeRate.objects.all() + assert len(exchange_rates) == 1 + assert exchange_rates[0].date == timezone.now().date() + assert exchange_rates[0].rate == 0.8 + + +@responses.activate +@pytest.mark.django_db +def test_fetches_data_only_once(): + call_command("getexchangerate") + call_command("getexchangerate") + exchange_rates = ExchangeRate.objects.all() + assert len(exchange_rates) == 1 diff --git a/test/test_visit.py b/test/test_visit.py index 454629c6df9d1bd4616155756e4e8728a284c27b..535a93d969f115157e260432a486b8d9cd7ee308 100644 --- a/test/test_visit.py +++ b/test/test_visit.py @@ -2,6 +2,7 @@ import json from unittest.mock import patch import pytest import responses +import stripe from stripe_checkout.stripe_checkout.models import Event, Order @@ -35,6 +36,17 @@ def test_create_invoice(client, visitor_id): assert Order.objects.count() == 1 +@responses.activate +@pytest.mark.django_db +def test_exchange_rate(client, default_exchange_rate, visitor_id): + client.post(f"/checkout/{visitor_id}/", data={"payment_method": "invoice"}) + call_args = stripe.Invoice.create.call_args[1] + assert call_args["custom_fields"][0] == { + "name": "GBP VAT Rate", + "value": "GBP 1.60 (0.8000)", + } + + @responses.activate @pytest.mark.django_db def test_event_webhook(client): @@ -50,3 +62,23 @@ def test_event_webhook(client): ) assert rv.status_code == 200 assert Event.objects.exists() + + +@responses.activate +@pytest.mark.django_db +def test_event_webhook_disallowed_when_not_whitelisted(client, settings): + settings.STRIPE_WEBHOOK_ALLOWED_IPS = ["1.1.1.1"] + with patch( + "stripe.Webhook.construct_event", side_effect=lambda b, *_: json.loads(b) + ): + rv = client.post( + "/stripe-event-webhook/", + json.dumps( + { + "type": "invoice.paid", + "data": {"object": {"id": "stripe-invoice"}}, + } + ), + content_type="application/json", + ) + assert rv.status_code == 403