[Python-checkins] r58492 - sandbox/trunk/decimal-c/_decimal.c

mateusz.rukowicz python-checkins at python.org
Tue Oct 16 07:11:10 CEST 2007


Author: mateusz.rukowicz
Date: Tue Oct 16 07:11:09 2007
New Revision: 58492

Modified:
   sandbox/trunk/decimal-c/_decimal.c
Log:
Completely rewrote sqrt. Now it's faster and better :>.


Modified: sandbox/trunk/decimal-c/_decimal.c
==============================================================================
--- sandbox/trunk/decimal-c/_decimal.c	(original)
+++ sandbox/trunk/decimal-c/_decimal.c	Tue Oct 16 07:11:09 2007
@@ -212,6 +212,25 @@
     }
 }
 
+/* self = self * 10 + dig; self holds ndigits digits and, when ndigits % LOG == 0, must have room for one more limb */
+static void
+_limb_add_one_digit(long *self, long ndigits, long dig) {
+    long i;
+    /* not a bug: the result has ndigits + 1 digits, so it needs ceil((ndigits + 1) / LOG) limbs */
+    long limb_count = (ndigits + LOG)/LOG;
+    
+    if(ndigits % LOG == 0)
+        self[limb_count-1] = 0;     /* the result spills into a new top limb; clear it before shifting */
+
+    for(i = limb_count-1; i > 0; i--) {   /* top limb down: shift left one digit, pulling in the top digit of the limb below */
+        self[i] *= 10;
+        self[i] += self[i-1] / (BASE/10);
+        self[i-1] %= (BASE/10);
+    }
+    self[0] *= 10;
+    self[0] += dig;   /* no carry is propagated from here, so dig must be a single digit 0-9 */
+}
+
 static void
 _limb_fill(long *self, long ndigits, long x)
 {
@@ -486,7 +505,7 @@
 /* XXX it's naive dividing, very slow */
 /* min_new_pos tells, when we should stop dividing, useful for integer division
  * make it > flimbs - 2, and it will have no impact*/
-
+/* function assumes that rest is filled with 0s */
 static long
 _limb_divide(long *first, long flimbs, long *second, long slimbs,
         long *out, long prec, long *rest, long min_new_pos)
@@ -589,6 +608,67 @@
     free(tmp);
     return new_pos;
 }
+
+
+/* integer division: stores the integer part of first / second in out[0..] and returns its limb count; remainder is handed to _limb_divide and should be zero-filled by the caller */
+static long     
+_limb_integer_divide(long *first, long flimbs, long *second, long slimbs, long *out, long *remainder) {
+    long rpos; 
+    int i;      
+    rpos = _limb_divide(first, flimbs, second, slimbs, out, flimbs, remainder, -1);
+    /* after the increment, rpos is the index of the first limb before the decimal point */
+    rpos ++;    
+    for(i = rpos; i < flimbs; i ++)   /* shift the integer limbs down so they start at out[0] */
+        out[i-rpos] = out[i];
+    return _limb_size(out, flimbs - rpos);
+                
+}               
+/* computes floor(sqrt(first)) into res and returns the limb count of the result; first and result are integers,
+ * res should have room for at least (flimbs+1)/2 + 1 limbs */
+static long 
+_limb_sqrt(long *first, long flimbs, long *res) {
+    /* integer Newton iteration x' = floor((x + first/x) / 2), started from an upper bound and stopped when it would no longer decrease */
+    int i;
+    int cmp_res;
+    int rlimbs;
+    int qlimbs;
+    long *remainder;
+    long *quot = (long*)malloc(sizeof(long) * (flimbs + 1));
+    remainder = (long*)malloc(sizeof(long) * (flimbs + 1));   /* NOTE(review): neither malloc result is checked for NULL */
+        
+    /* upper bound: a single 1 in the top limb, i.e. res = BASE^((flimbs+1)/2), which is >= sqrt(first) */
+    rlimbs = (flimbs + 1) / 2 + 1;
+    for(i = 0; i < rlimbs-1; i++)
+        res[i] = 0;
+    res[rlimbs-1] = 1;
+    /* NOTE(review): if first == 0, res is halved down to 0 and the next division is by zero; presumably callers never pass 0 -- confirm */
+    while(1) {
+        /* quot = floor(first / res) */
+        memset(remainder, 0, (flimbs + 1) * sizeof(long));
+        qlimbs = _limb_integer_divide(first, flimbs, res, rlimbs, quot, remainder);
+        
+        /* if res <= quot then break - we stop here because otherwise the next iterate would not be smaller
+         * (the result would grow) and we would get the ceiling instead of the floor */
+        cmp_res = _limb_compare(res, rlimbs, quot, qlimbs);
+        if(cmp_res <= 0)
+            break;
+        
+        /* res = res + quot */
+        rlimbs = (_limb_add(res, LOG * rlimbs, quot, LOG * qlimbs, res) + LOG - 1) / LOG;
+        /* NOTE(review): _limb_add is given digit counts LOG * limbs; confirm it never writes a carry limb past res's capacity */
+        /* res = floor(res / 2), limb by limb from the top: an odd limb carries BASE into the limb below */
+        for(i=rlimbs-1;i>0;i--) {
+            res[i-1] += (res[i] &1) * BASE;
+            res[i] >>= 1;
+        }       
+        res[0] >>= 1;
+        if(rlimbs > 1 &&!res[rlimbs-1]) rlimbs --;   /* drop a top limb that became zero */
+    }           
+    free(remainder); 
+    free(quot); 
+    return rlimbs;
+}
+
 /*
 static long
 _limb_normalize(long *first, long size) {
@@ -3059,20 +3139,17 @@
 static decimalobject *
 _do_decimal_sqrt(decimalobject *self, contextobject *ctx)
 {
-    decimalobject *ret = 0;
-    decimalobject *ans = 0;
-    decimalobject *tmp = 0, *tmp2 = 0, *tmp3 = 0;
-    contextobject *ctx2 = 0;
-    decimalobject *half = 0;
-    PyObject *flags = 0;
-    exp_t expadd;
-    long firstprec;
-    long i;
-    exp_t Emax;
-    exp_t Emin;
-    long maxp;
-    long rounding;
-    exp_t prevexp;
+    decimalobject *ret = 0, *ret2 = 0;
+    long *b_tab;    /* our temporal storage for integer DD */
+    long b_limbs;
+    exp_t e;        /* as below */
+    int i;
+    int exact = 1;      /* is result exact */
+    int shift = 0;      /* in case result is exact - we need 
+                           to restore perfect exponent */
+    long ret_limbs;
+    int rounding;
+    
     
     if (ISSPECIAL(self)) {
         decimalobject *nan;
@@ -3099,307 +3176,147 @@
         return handle_InvalidOperation(self->ob_type, ctx, "sqrt(-x), x > 0", NULL);
     }
 
-    tmp = _NEW_decimalobj(self->ob_size + 1, self->sign, self->exp);
-
-    if (!tmp)
-        return NULL;
-
-    expadd = exp_floordiv_i(tmp->exp, 2);
-    
-    if (exp_mod_i(tmp->exp, 2)) {
-        _limb_first_n_digits(self->limbs, self->ob_size, 0, tmp->limbs, tmp->ob_size);
+    /* The idea is much the same as in the Python implementation.
+     * Assume we have a real number a, computed with precision prec + 1,
+     * such that a <= sqrt(self) < a + 1ULP, and we know whether a == sqrt(self)
+     * (i.e. whether it is exact). Given such an a, there are two cases:
+     *  a) a == sqrt(self) -- just round a and return;
+     *  b) otherwise we know that a < sqrt(self) < a + 1ULP.
+     *  The only problem arises when a ends in 5 or 0: these are the only
+     *  cases in which ROUND_HALF_EVEN can give an incorrect result,
+     *  because a value ending in 5 or 5000..0 is really slightly bigger.
+     *  So we just check the last digit; if it is 0 or 5, we add 1ULP to a,
+     *  round, and return.
+     *
+     * Now we want to find a. We find b and e such that
+     * self = b * 10^e, e is an even integer, and
+     * floor(sqrt(b)) has at least prec+1 significant digits, i.e.
+     * 10^prec <= floor(sqrt(b)), so 10^prec <= sqrt(b), so
+     * 10^(2*prec) <= b. Let c = floor(sqrt(floor(b))) and a = c * 10^(e/2).
+     * Since c <= sqrt(b), we get
+     * a <= sqrt(b) * 10^(e/2) = sqrt(self). On the other side,
+     * c <= sqrt(b) < c + 1
+     * [straightforward: for any integer i there is no x in (i, i+1)
+     * such that sqrt(x) is an integer].
+     * Hence sqrt(self) < (c+1) * 10^(e/2) = a + 10^(e/2) = a + 1ULP.
+     * To summarize: we obtain b with at least 2*prec + 1 digits,
+     * make sure that e is even and self = b * 10^e; then we truncate the
+     * fractional part of b (DD = floor(b)) and set a = floor(sqrt(DD)) * 10^(e/2). */
+
+    /* We need at least ctx->prec*2 + 1 digits. Given n limbs,
+     * we have at least n*4-3 digits, so n*4-3 >= ctx->prec*2 + 1,
+     * n*4 >= ctx->prec*2 + 4, n >= ctx->prec/2 + 1. We therefore
+     * set n = ctx->prec/2 + 3 (we may need to make the exponent even,
+     * in which case b is extended by one more digit),
+     * but only b_limbs - 1 limbs are filled here. */
+    b_limbs = ctx->prec/2 + 3;
+    b_tab = (long*)malloc(b_limbs * sizeof(long));
+    if(b_limbs -1 > self->limb_count) {
+        /* we have to extend */
+        int diff = (b_limbs - 1) - self->limb_count;
+        shift = -diff * 4;  /*exp becomes smaller */
+        for(i = 0;i < self->limb_count; i++) 
+            b_tab[i+diff] = self->limbs[i];
+        
+        for(i = 0;i < diff; i++) 
+            b_tab[i] = 0;
     }
     else {
-        tmp->ob_size --;
-        tmp->limb_count = self->limb_count;
-        for(i = 0; i < tmp->limb_count ; i++)
-            tmp->limbs[i] = self->limbs[i];
-    }
-
-    tmp->exp = exp_from_i(0);
-    
-    ctx2 = context_copy(ctx);
-
-    if (!ctx2) {
-        Py_DECREF(tmp);
-        return NULL;
+        /* we have to truncate, we'll fill b_limbs-1 limbs */
+        int diff = self->limb_count - (b_limbs - 1);
+        shift = diff * 4;   /* we truncated diff limbs, so exp will become greater */
+        for(i = 0;i < b_limbs - 1; i++) 
+            b_tab[i] = self->limbs[i + diff];
+        /* check the remaining (dropped) low limbs of self to see
+         * whether we truncated any nonzero digits */
+        for(i = 0;i < diff; i++)
+            if(self->limbs[i]) {
+                exact = 0;
+                break;
+            }
     }
 
-    flags = context_ignore_all_flags(ctx2);
+    b_tab[b_limbs-1] = 0;   /* only one we didn't set */
 
-    if (!flags) {
-        Py_DECREF(tmp);
-        Py_DECREF(ctx2);
+    /* if self->exp is odd, multiply b_tab by 10 and decrease shift by one
+     * [that's why we made room for another limb] */
+    if(exp_mod_i(self->exp, 2)) {
+        shift --;
+        _limb_add_one_digit(b_tab, (b_limbs-1) * LOG, 0);
+    }
+    b_limbs = _limb_size(b_tab, b_limbs);
+    
+    /* ok, now our DD is b_tab, and our e is self->exp + shift */
+
+    /* How big a table we need: sqrt uses ceil(limbs/2) + 1 limbs, but we
+     * may also have to set the ideal exponent.
+     * Our result exponent will be (self->exp + shift)/2 and the ideal one is
+     * floor(self->exp/2), so we might have to extend the mantissa by
+     * shift/2 digits */
+    ret_limbs = (b_limbs + 1)/2 + 2 + (shift > 0 ? (shift + 1)/2 + 1 : 0);
+    /* TODO maybe I should use some temporary tab first? */
+    ret = _NEW_decimalobj(ret_limbs * LOG, self->sign, exp_from_i(0));
+    if(!ret) {
+        free(b_tab);
         return NULL;
     }
 
-    ans = _NEW_decimalobj(3, 0, exp_from_i(0));
-
-    if (!ans) {
-        Py_DECREF(tmp);
-        Py_DECREF(ctx2);
-        Py_DECREF(flags);
-    }
-
-    tmp2 = _NEW_decimalobj(3, 0, exp_from_i(0));
-
-    if (!tmp2) {
-        Py_DECREF(tmp);
-        Py_DECREF(ctx2);
-        Py_DECREF(flags);
-        Py_DECREF(ans);
-    }
-
-    if (exp_mod_i(ADJUSTED(tmp), 2) == 0) {
-        ans->limbs[0] = 819;
-        ans->exp = exp_sub_i(ADJUSTED(tmp), 2);
-        tmp2->limbs[0] = 259;
-        exp_inp_sub_i(&(tmp2->exp), 2);
-    }
-    else {
-        ans->limbs[0] = 259;
-        ans->exp = exp_add_i(tmp->exp, tmp->ob_size - 3);
-        tmp2->limbs[0] = 819;
-        exp_inp_sub_i(&(tmp2->exp), 3);
-    }
-
-    firstprec = ctx2->prec;
-    ctx2->prec = 3;
-    /* ans += tmp * tmp2 */
-
-    tmp3 = _do_decimal_multiply(tmp, tmp2, ctx2);
-
-    if (!tmp3)
-        goto err;
-
-    Py_DECREF(tmp2);
-    tmp2 = _do_decimal_add(ans, tmp3, ctx2);
-    Py_DECREF(tmp3);
-    tmp3 = 0;
-
-    if (!tmp2)
-        goto err;
-
-    Py_DECREF(ans);
-    ans = tmp2;
-    tmp2 = 0;
-/*    ans->exp -= 1 + ADJUSTED(tmp)/2;
-    if (1 + ADJUSTED(tmp) < 0 && (1 + ADJUSTED(tmp)) % 2)
-        ans->exp --;
-  */
-    exp_inp_sub(&(ans->exp),exp_add_i(exp_floordiv_i(ADJUSTED(tmp), 2), 1));  
-    Emax = ctx2->Emax;
-    Emin = ctx2->Emin;
-
-    ctx2->Emax = PyDecimal_DefaultContext->Emax;
-    ctx2->Emin = PyDecimal_DefaultContext->Emin;
-
-    half = _decimal_fromliteral(self->ob_type, "0.5", 3, ctx2);
-
-    if (!half)
-        goto err;
-
-    maxp = firstprec + 2;
-    rounding = ctx2->rounding;
-    ctx2->rounding = ROUND_HALF_EVEN;    
-
-    while (1) {
-        ctx2->prec = 2 * ctx2->prec - 2 < maxp ? 2 * ctx2->prec - 2: maxp;
-        /* ans = half * (ans + tmp/ans) */
-        tmp2 = (decimalobject*) _do_decimal__divide(tmp, ans, 0, ctx2);
-        if (!tmp2)
-            goto err;
-        /* ans = half * (ans + tmp2) */
-        tmp3 = _do_decimal_add(ans, tmp2, ctx2);
-        if (!tmp3)
-            goto err;
-
-        /* ans = half * tmp3 */
-        Py_DECREF(ans);
-        ans = _do_decimal_multiply(half, tmp3, ctx2);
-
-        if (!ans)
-            goto err;
-
-        Py_DECREF(tmp2);
-        Py_DECREF(tmp3);
-        tmp2 = 0;
-        tmp3 = 0;
-
-        if (ctx2->prec == maxp)
-            break;
-    }
-
-    ctx2->prec = firstprec;
-    prevexp = ADJUSTED(ans);
-    tmp2 = _decimal_round(ans, -1, ctx2, -1);
-    if (!tmp2)
-        goto err;
-    Py_DECREF(ans);
-    ans = _NEW_decimalobj(tmp2->ob_size + 1, tmp2->sign, tmp2->exp);
-    if (!ans)
-        goto err;
-
-    ctx2->prec = firstprec + 1;
-    if (exp_ne(prevexp, ADJUSTED(tmp2))) {
-        _limb_first_n_digits(tmp2->limbs, tmp2->ob_size, 0, ans->limbs, ans->ob_size);
-        exp_dec(&(ans->exp));
-    }
-    else {
-        ans->limb_count = tmp2->limb_count;
-        ans->ob_size = tmp2->ob_size;
-        for (i = 0; i < ans->limb_count; i++) 
-            ans->limbs[i] = tmp2->limbs[i];
-    }
-
-    Py_DECREF(tmp2);
-    tmp2 = 0;
-
-    {
-        int cmp;
-        decimalobject *lower;
-        half->exp = exp_sub_i(ans->exp, 1);
-        half->limbs[0] = 5;
-        lower = _do_decimal_subtract(ans, half, ctx2);
-        if (!lower)
-            goto err;
-
-        ctx2->rounding = ROUND_UP;
-        tmp2 = _do_decimal_multiply(lower, lower, ctx2);
-        Py_DECREF(lower);
-        lower = 0;
-        if (!tmp2) {
-            goto err;
-        }
-
-        cmp = _do_real_decimal_compare(tmp2, tmp, ctx2);
-        if (PyErr_Occurred()) {
-            goto err;
-        }
-
-        Py_DECREF(tmp2);
-        tmp2 = 0;
-
-        if (cmp == 1) {
-            half->exp = ans->exp;
-            half->limbs[0] = 1;
-            tmp2 = _do_decimal_subtract(ans, half, ctx2);
-            if (!tmp2)
-                goto err;
-            Py_DECREF(ans);
-            ans = tmp2;
-            tmp2 = 0;
+    ret_limbs = _limb_sqrt(b_tab, b_limbs, ret->limbs);
+    ret->limb_count = ret_limbs;
+    ret->ob_size = _limb_size_s(ret->limbs, ret_limbs * LOG);
+    /* let's check if it's exact */
+    {
+        long tmp_digs;
+        long cmp_res;
+        long *tmp = (long*)malloc((ret_limbs * 2 + 1) * sizeof(long));
+        tmp_digs = _limb_multiply(ret->limbs, ret->ob_size, ret->limbs,
+                ret->ob_size, tmp);
+       
+        cmp_res = _limb_compare_un(b_tab, b_limbs, tmp, (tmp_digs +LOG-1)/LOG);
+        exact = exact && cmp_res == 0;
+        free(tmp);
+    } 
+    free(b_tab);
+
+    /* now all we need is to set exp :> and add 1 ULP in case it's not exact ;> */
+
+    if (exact) {
+        long expdiff;
+        exp_t tmp_exp;
+        ret->exp = exp_floordiv_i(self->exp, 2);
+        tmp_exp = exp_floordiv_i(exp_add_i(self->exp, shift), 2);
+        expdiff = exp_to_i(exp_sub(ret->exp, tmp_exp));
+        /* TODO SLOW */
+        if(expdiff > 0) {
+            while(expdiff && ret->ob_size > 1) {
+                assert(!(ret->limbs[0] % 10));  /* Am I sure?? TODO */
+                _limb_cut_one_digit(ret->limbs, ret->ob_size);
+                ret->ob_size --;
+                expdiff --;
+            }
         }
         else {
-            decimalobject *upper;
-            half->exp = exp_sub_i(ans->exp, 1);
-            half->limbs[0] = 5;
-            upper = _do_decimal_add(ans, half, ctx2);
-            if (!upper)
-                goto err;
-            ctx2->rounding = ROUND_DOWN;
-
-            tmp2 = _do_decimal_multiply(upper, upper, ctx2);
-
-            Py_DECREF(upper);
-            upper = 0;
-
-            cmp = _do_real_decimal_compare(tmp2, tmp, ctx2);
-            if (PyErr_Occurred())
-                goto err;
-            
-            Py_DECREF(tmp2);
-            tmp2 = 0;
-
-            if (cmp == -1) {
-                half->exp = ans->exp;
-                half->limbs[0] = 1;
-                tmp2 = _do_decimal_add(ans, half, ctx2);
-                
-                if (!tmp2)
-                    goto err;
-                Py_DECREF(ans);
-                ans = tmp2;
-                tmp2 = 0;
+            while(expdiff) {
+                _limb_add_one_digit(ret->limbs, ret->ob_size, 0);
+                ret->ob_size ++;
+                expdiff ++;
             }
         }
     }
-
-    exp_inp_add(&(ans->exp), expadd);
-    ctx2->rounding = rounding;
-
-    tmp2 = _decimal_fix(ans, ctx2);
-    if (!tmp2)
-        goto err;
-    Py_DECREF(ans);
-    ans = tmp2;
-    tmp2 = 0;
-
-    rounding = ctx2->rounding_dec;
-    ctx2->rounding_dec = NEVER_ROUND;
-
-    {
-        int cmp;
-        tmp2 = _do_decimal_multiply(ans, ans, ctx2);
-        if (!tmp2)
-            goto err;
-
-        cmp = _do_real_decimal_compare(tmp2, self, ctx2);
-
-        if (PyErr_Occurred()) 
-            goto err;
-
-        Py_DECREF(tmp2);
-        tmp2 = 0;
-
-        if (cmp != 0) {
-            if (handle_Rounded(ctx, NULL))
-                goto err;
-
-            if (handle_Inexact(ctx, NULL))
-                goto err;
-        }
-
-        else {
-/*            long exp = self->exp / 2;
-            if (self->exp < 0 && self->exp % 2)
-                exp --;*/
-            exp_t exp = exp_floordiv_i(self->exp, 2);
-            ctx2->prec += exp_to_i(exp_sub(ans->exp, exp));
-            tmp2 = _decimal_rescale(ans, exp, ctx2, -1, 1);
-
-            if (!tmp2)
-                goto err;
-            Py_DECREF(ans);
-            ans = tmp2;
-            tmp2 = 0;
-            ctx2->prec = firstprec;
-        }
+    else {
+        if(ret->limbs[0] % 5 == 0) ret->limbs[0] ++;
+        ret-> exp = exp_floordiv_i(exp_add_i(self->exp, shift), 2);
     }
 
-    tmp2 = _decimal_fix(ans, ctx);
-    if (!tmp2)
-        goto err;
-    Py_DECREF(ans);
-    ans = tmp2;
-    tmp2 = 0;
-
-    Py_DECREF(flags);
-    Py_DECREF(ctx2);
-    Py_DECREF(half);
-    Py_DECREF(tmp);
-
-    return ans;
-
-err:
-    Py_XDECREF(tmp);
-    Py_XDECREF(tmp2);
-    Py_XDECREF(tmp3);
-    Py_XDECREF(ans);
-    Py_XDECREF(flags);
-    Py_XDECREF(ctx2);
-    Py_XDECREF(ret);
-    Py_XDECREF(half);
+    /* TODO it's not thread safe */
+    rounding = ctx->rounding;
+    ctx->rounding = ROUND_HALF_EVEN;
+    
+    ret2 = _decimal_fix(ret, ctx);
+    ctx->rounding = rounding;
+    Py_DECREF(ret);
+    return ret2;
+    
     return NULL;
 }
 DECIMAL_UNARY_FUNC(sqrt)


More information about the Python-checkins mailing list