diff --git a/ciphers/diffie.py b/ciphers/diffie.py
index 4ff90be009c1..1e1e868999b6 100644
--- a/ciphers/diffie.py
+++ b/ciphers/diffie.py
@@ -1,11 +1,28 @@
 from __future__ import annotations
 
 
-def find_primitive(n: int) -> int | None:
-    for r in range(1, n):
+def find_primitive(modulus: int) -> int | None:
+    """
+    Find a primitive root modulo modulus, if one exists.
+
+    Args:
+        modulus : The modulus for which to find a primitive root.
+
+    Returns:
+        The primitive root if one exists, or None if there is none.
+
+    Examples:
+    >>> find_primitive(7)  # Modulo 7 has primitive root 3
+    3
+    >>> find_primitive(11)  # Modulo 11 has primitive root 2
+    2
+    >>> find_primitive(8) == None # Modulo 8 has no primitive root
+    True
+    """
+    for r in range(1, modulus):
         li = []
-        for x in range(n - 1):
-            val = pow(r, x, n)
+        for x in range(modulus - 1):
+            val = pow(r, x, modulus)
             if val in li:
                 break
             li.append(val)
@@ -15,18 +32,22 @@ def find_primitive(n: int) -> int | None:
 
 
 if __name__ == "__main__":
-    q = int(input("Enter a prime number q: "))
-    a = find_primitive(q)
-    if a is None:
-        print(f"Cannot find the primitive for the value: {a!r}")
+    import doctest
+
+    doctest.testmod()
+
+    prime = int(input("Enter a prime number q: "))
+    primitive_root = find_primitive(prime)
+    if primitive_root is None:
+        print(f"Cannot find the primitive for the value: {primitive_root!r}")
     else:
         a_private = int(input("Enter private key of A: "))
-        a_public = pow(a, a_private, q)
+        a_public = pow(primitive_root, a_private, prime)
         b_private = int(input("Enter private key of B: "))
-        b_public = pow(a, b_private, q)
+        b_public = pow(primitive_root, b_private, prime)
 
-        a_secret = pow(b_public, a_private, q)
-        b_secret = pow(a_public, b_private, q)
+        a_secret = pow(b_public, a_private, prime)
+        b_secret = pow(a_public, b_private, prime)
 
         print("The key value generated by A is: ", a_secret)
         print("The key value generated by B is: ", b_secret)