[Python-checkins] r58492 - sandbox/trunk/decimal-c/_decimal.c
mateusz.rukowicz
python-checkins at python.org
Tue Oct 16 07:11:10 CEST 2007
Author: mateusz.rukowicz
Date: Tue Oct 16 07:11:09 2007
New Revision: 58492
Modified:
sandbox/trunk/decimal-c/_decimal.c
Log:
Completely rewritten sqrt. Now it's faster and better :>.
Modified: sandbox/trunk/decimal-c/_decimal.c
==============================================================================
--- sandbox/trunk/decimal-c/_decimal.c (original)
+++ sandbox/trunk/decimal-c/_decimal.c Tue Oct 16 07:11:09 2007
@@ -212,6 +212,25 @@
}
}
+/* self = self * 10 + dig
+ *
+ * Shifts the number held in the limb array self one decimal digit to the
+ * left and adds dig into the lowest limb. Limb 0 is the least significant
+ * limb (dig lands in self[0], carries move from self[i-1] up to self[i]).
+ *
+ * self    - limb array, modified in place; when ndigits is a multiple of
+ *           LOG the result occupies one limb more than the input, so the
+ *           caller must provide room for (ndigits + LOG)/LOG limbs
+ * ndigits - number of decimal digits stored in self before the call
+ * dig     - value added to the lowest limb after the shift; presumably a
+ *           single digit 0..9 (not checked here -- TODO confirm) */
+static void
+_limb_add_one_digit(long *self, long ndigits, long dig) {
+ long i;
+ /* not a bug: after the shift we have ndigits + 1 digits */
+ long limb_count = (ndigits + LOG)/LOG;
+
+ if(ndigits % LOG == 0)
+ self[limb_count-1] = 0; /* we added an extra limb, clear it */
+
+ for(i = limb_count-1; i > 0; i--) { /* walk from the top limb down */
+ self[i] *= 10;
+ self[i] += self[i-1] / (BASE/10); /* carry in the top digit of the limb below */
+ self[i-1] %= (BASE/10); /* and remove that digit from the limb below */
+ }
+ self[0] *= 10;
+ self[0] += dig;
+}
+
static void
_limb_fill(long *self, long ndigits, long x)
{
@@ -486,7 +505,7 @@
/* XXX it's naive dividing, very slow */
/* min_new_pos tells, when we should stop dividing, useful for integer division
* make it > flimbs - 2, and it will have no impact*/
-
+/* function assumes that rest is filled with 0s */
static long
_limb_divide(long *first, long flimbs, long *second, long slimbs,
long *out, long prec, long *rest, long min_new_pos)
@@ -589,6 +608,67 @@
free(tmp);
return new_pos;
}
+
+
+/* integer division: out = floor(first / second), returns how many limbs
+ * the result has (as reported by _limb_size).
+ *
+ * first, flimbs  - dividend limbs and their count
+ * second, slimbs - divisor limbs and their count
+ * out            - receives the quotient; _limb_divide is asked for flimbs
+ *                  limbs of output, so out presumably needs room for at
+ *                  least flimbs limbs (TODO confirm against _limb_divide)
+ * remainder      - passed straight through as the rest argument of
+ *                  _limb_divide, which assumes it is filled with 0s */
+static long
+_limb_integer_divide(long *first, long flimbs, long *second, long slimbs, long *out, long *remainder) {
+ long rpos;
+ int i;
+ rpos = _limb_divide(first, flimbs, second, slimbs, out, flimbs, remainder, -1);
+ /* after the increment, rpos is the index of the first limb that is before the dot */
+ rpos ++;
+ for(i = rpos; i < flimbs; i ++) /* move the integer limbs down to index 0, dropping the limbs below rpos */
+ out[i-rpos] = out[i];
+ return _limb_size(out, flimbs - rpos);
+
+}
+/* computes floor(sqrt(first)), first and result are integers,
+ * res should have at least (flimbs+1)/2 + 1 size
+ *
+ * first, flimbs - input limbs and their count
+ * res           - output limbs, written in place
+ * returns       - rlimbs, the number of limbs of res in use at exit
+ *
+ * NOTE(review): the two malloc results below (quot, remainder) are used
+ * without a NULL check. */
+static long
+_limb_sqrt(long *first, long flimbs, long *res) {
+ /* it's Newton's method: res <- floor((res + floor(first / res)) / 2).
+ * NOTE(review): an earlier comment said "we start with x_1 = first", but the
+ * start value set below is a 1 in limb rlimbs-1 with zeros in all lower limbs */
+ int i;
+ int cmp_res;
+ int rlimbs;
+ int qlimbs;
+ long *remainder;
+ long *quot = (long*)malloc(sizeof(long) * (flimbs + 1));
+ remainder = (long*)malloc(sizeof(long) * (flimbs + 1));
+
+ /* start value, meant as an upper bound of the result: top limb 1, the rest 0 */
+ rlimbs = (flimbs + 1) / 2 + 1;
+ for(i = 0; i < rlimbs-1; i++)
+ res[i] = 0;
+ res[rlimbs-1] = 1;
+
+ while(1) {
+ /* quot = floor(first / res); remainder is cleared first because
+ * _limb_divide assumes that rest is filled with 0s */
+ memset(remainder, 0, (flimbs + 1) * sizeof(long));
+ qlimbs = _limb_integer_divide(first, flimbs, res, rlimbs, quot, remainder);
+
+ /* if res <= quot then break - we stop here because otherwise the result would grow
+ * and become ceiling instead of floor */
+ cmp_res = _limb_compare(res, rlimbs, quot, qlimbs);
+ if(cmp_res <= 0)
+ break;
+
+ /* res = res + quot; _limb_add is given sizes in digits (LOG * limbs) and its
+ * return value is rounded up to whole limbs -- presumably a digit count */
+ rlimbs = (_limb_add(res, LOG * rlimbs, quot, LOG * qlimbs, res) + LOG - 1) / LOG;
+
+ /* res = floor(res / 2), from the top limb down */
+ for(i=rlimbs-1;i>0;i--) {
+ res[i-1] += (res[i] &1) * BASE; /* an odd limb hands BASE down; it becomes BASE/2 when the lower limb is halved */
+ res[i] >>= 1;
+ }
+ res[0] >>= 1;
+ if(rlimbs > 1 &&!res[rlimbs-1]) rlimbs --; /* the top limb became 0: drop it */
+ }
+ free(remainder);
+ free(quot);
+ return rlimbs;
+}
+
/*
static long
_limb_normalize(long *first, long size) {
@@ -3059,20 +3139,17 @@
static decimalobject *
_do_decimal_sqrt(decimalobject *self, contextobject *ctx)
{
- decimalobject *ret = 0;
- decimalobject *ans = 0;
- decimalobject *tmp = 0, *tmp2 = 0, *tmp3 = 0;
- contextobject *ctx2 = 0;
- decimalobject *half = 0;
- PyObject *flags = 0;
- exp_t expadd;
- long firstprec;
- long i;
- exp_t Emax;
- exp_t Emin;
- long maxp;
- long rounding;
- exp_t prevexp;
+ decimalobject *ret = 0, *ret2 = 0;
+ long *b_tab; /* our temporal storage for integer DD */
+ long b_limbs;
+ exp_t e; /* as below */
+ int i;
+ int exact = 1; /* is result exact */
+ int shift = 0; /* in case result is exact - we need
+ to restore perfect exponent */
+ long ret_limbs;
+ int rounding;
+
if (ISSPECIAL(self)) {
decimalobject *nan;
@@ -3099,307 +3176,147 @@
return handle_InvalidOperation(self->ob_type, ctx, "sqrt(-x), x > 0", NULL);
}
- tmp = _NEW_decimalobj(self->ob_size + 1, self->sign, self->exp);
-
- if (!tmp)
- return NULL;
-
- expadd = exp_floordiv_i(tmp->exp, 2);
-
- if (exp_mod_i(tmp->exp, 2)) {
- _limb_first_n_digits(self->limbs, self->ob_size, 0, tmp->limbs, tmp->ob_size);
+ /* The idea is pretty much the same as in the Python code.
+ * Assume that we have a real a, computed with precision prec + 1,
+ * such that a <= sqrt(self) < a + 1ULP, and that we know whether
+ * a == sqrt(self) (i.e. whether it is exact). With this a, we have two cases:
+ * a) a == sqrt(self) -- just round a, and return
+ * b) otherwise -- we know that a < sqrt(self) < a + 1ULP;
+ * the only problem that arises is when a ends with 5 or 0 (this is
+ * the only case in which ROUND_HALF_EVEN would give an incorrect result,
+ * because if the result ends with 5 or 5000..0, the true value is actually bigger),
+ * so we just check the last digit: if it's 5 or 0, we add 1ULP to a, round,
+ * and return.
+ *
+ * So now we want to find a. We'll find b and e such that
+ * self = b * 10^e, e is an integer, e mod 2 == 0, and
+ * floor(sqrt(b)) has at least prec+1 significant digits -- that is,
+ * with c = floor(sqrt(b)), 10^(prec) <= c, so 10^(prec) <= sqrt(b), so
+ * 10^(2*prec) <= b. Given b and e, a = floor(sqrt(floor(b))) * 10^(e/2).
+ * We know that floor(sqrt(floor(b))) <= sqrt(b), so
+ * a <= sqrt(b) * 10^(e/2) = sqrt(self). On the other side,
+ * floor(sqrt(floor(b))) <= sqrt(b) < floor(sqrt(floor(b))) + 1
+ * [the proof is pretty straightforward: show that for any integer i there
+ * is no x in (i,i+1) such that sqrt(x) is an integer].
+ * Concluding, sqrt(self) < (floor(sqrt(floor(b))) + 1) * 10^(e/2) = a + 10^(e/2) = a + 1ULP. []
+ * Concluding (one more time), we get b with at least 2*prec + 1 digits
+ * and we make sure that e is even and self = b * 10^e; then we truncate the
+ * fractional part of b (let DD = floor(b)) and set a = floor(sqrt(DD)) * 10^(e/2) */
+
+ /* we need at least ctx->prec*2 + 1 digits; given n limbs,
+ * we have at least n*4-3 digits, so we want n*4-3 >= ctx->prec*2 + 1,
+ * i.e. n*4 >= ctx->prec * 2 + 4, i.e. n >= ctx->prec/2 + 1, so we will
+ * set n = ctx->prec/2 + 3 (because there is a possibility that
+ * we'll need to make exp even - in that case we'll extend b),
+ * but we have to fill only b_limbs - 1 limbs */
+ b_limbs = ctx->prec/2 + 3;
+ b_tab = (long*)malloc(b_limbs * sizeof(long));
+ if(b_limbs -1 > self->limb_count) {
+ /* we have to extend */
+ int diff = (b_limbs - 1) - self->limb_count;
+ shift = -diff * 4; /*exp becomes smaller */
+ for(i = 0;i < self->limb_count; i++)
+ b_tab[i+diff] = self->limbs[i];
+
+ for(i = 0;i < diff; i++)
+ b_tab[i] = 0;
}
else {
- tmp->ob_size --;
- tmp->limb_count = self->limb_count;
- for(i = 0; i < tmp->limb_count ; i++)
- tmp->limbs[i] = self->limbs[i];
- }
-
- tmp->exp = exp_from_i(0);
-
- ctx2 = context_copy(ctx);
-
- if (!ctx2) {
- Py_DECREF(tmp);
- return NULL;
+ /* we have to truncate, we'll fill b_limbs-1 limbs */
+ int diff = self->limb_count - (b_limbs - 1);
+ shift = diff * 4; /* we truncated diff limbs, so exp will become greater */
+ for(i = 0;i < b_limbs - 1; i++)
+ b_tab[i] = self->limbs[i + diff];
+ /* we'll check the remaining limbs of self, to see whether
+ * we truncated something nonzero (in which case the result is not exact) */
+ for(i = 0;i < diff; i++)
+ if(self->limbs[i]) {
+ exact = 0;
+ break;
+ }
}
- flags = context_ignore_all_flags(ctx2);
+ b_tab[b_limbs-1] = 0; /* only one we didn't set */
- if (!flags) {
- Py_DECREF(tmp);
- Py_DECREF(ctx2);
+ /* in case exp mod 2 = 1, we'll multiply b_tab by 10 and decrease shift by one
+ * [that's why we made room for another limb] */
+ if(exp_mod_i(self->exp, 2)) {
+ shift --;
+ _limb_add_one_digit(b_tab, (b_limbs-1) * LOG, 0);
+ }
+ b_limbs = _limb_size(b_tab, b_limbs);
+
+ /* ok, now our DD is b_tab, and our e is self->exp + shift */
+
+ /* how big a tab do we need? sqrt uses ceil(limbs/2) + 1, but there
+ * is a possibility that we're going to set the ideal exp.
+ * Our result exp will be (self->exp + shift)/2, and the ideal exp is
+ * floor(self->exp/2), so we might be forced to extend the mantissa by
+ * shift/2 */
+ ret_limbs = (b_limbs + 1)/2 + 2 + (shift > 0 ? (shift + 1)/2 + 1 : 0);
+ /* TODO maybe I should use some temporary tab first? */
+ ret = _NEW_decimalobj(ret_limbs * LOG, self->sign, exp_from_i(0));
+ if(!ret) {
+ free(b_tab);
return NULL;
}
- ans = _NEW_decimalobj(3, 0, exp_from_i(0));
-
- if (!ans) {
- Py_DECREF(tmp);
- Py_DECREF(ctx2);
- Py_DECREF(flags);
- }
-
- tmp2 = _NEW_decimalobj(3, 0, exp_from_i(0));
-
- if (!tmp2) {
- Py_DECREF(tmp);
- Py_DECREF(ctx2);
- Py_DECREF(flags);
- Py_DECREF(ans);
- }
-
- if (exp_mod_i(ADJUSTED(tmp), 2) == 0) {
- ans->limbs[0] = 819;
- ans->exp = exp_sub_i(ADJUSTED(tmp), 2);
- tmp2->limbs[0] = 259;
- exp_inp_sub_i(&(tmp2->exp), 2);
- }
- else {
- ans->limbs[0] = 259;
- ans->exp = exp_add_i(tmp->exp, tmp->ob_size - 3);
- tmp2->limbs[0] = 819;
- exp_inp_sub_i(&(tmp2->exp), 3);
- }
-
- firstprec = ctx2->prec;
- ctx2->prec = 3;
- /* ans += tmp * tmp2 */
-
- tmp3 = _do_decimal_multiply(tmp, tmp2, ctx2);
-
- if (!tmp3)
- goto err;
-
- Py_DECREF(tmp2);
- tmp2 = _do_decimal_add(ans, tmp3, ctx2);
- Py_DECREF(tmp3);
- tmp3 = 0;
-
- if (!tmp2)
- goto err;
-
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
-/* ans->exp -= 1 + ADJUSTED(tmp)/2;
- if (1 + ADJUSTED(tmp) < 0 && (1 + ADJUSTED(tmp)) % 2)
- ans->exp --;
- */
- exp_inp_sub(&(ans->exp),exp_add_i(exp_floordiv_i(ADJUSTED(tmp), 2), 1));
- Emax = ctx2->Emax;
- Emin = ctx2->Emin;
-
- ctx2->Emax = PyDecimal_DefaultContext->Emax;
- ctx2->Emin = PyDecimal_DefaultContext->Emin;
-
- half = _decimal_fromliteral(self->ob_type, "0.5", 3, ctx2);
-
- if (!half)
- goto err;
-
- maxp = firstprec + 2;
- rounding = ctx2->rounding;
- ctx2->rounding = ROUND_HALF_EVEN;
-
- while (1) {
- ctx2->prec = 2 * ctx2->prec - 2 < maxp ? 2 * ctx2->prec - 2: maxp;
- /* ans = half * (ans + tmp/ans) */
- tmp2 = (decimalobject*) _do_decimal__divide(tmp, ans, 0, ctx2);
- if (!tmp2)
- goto err;
- /* ans = half * (ans + tmp2) */
- tmp3 = _do_decimal_add(ans, tmp2, ctx2);
- if (!tmp3)
- goto err;
-
- /* ans = half * tmp3 */
- Py_DECREF(ans);
- ans = _do_decimal_multiply(half, tmp3, ctx2);
-
- if (!ans)
- goto err;
-
- Py_DECREF(tmp2);
- Py_DECREF(tmp3);
- tmp2 = 0;
- tmp3 = 0;
-
- if (ctx2->prec == maxp)
- break;
- }
-
- ctx2->prec = firstprec;
- prevexp = ADJUSTED(ans);
- tmp2 = _decimal_round(ans, -1, ctx2, -1);
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = _NEW_decimalobj(tmp2->ob_size + 1, tmp2->sign, tmp2->exp);
- if (!ans)
- goto err;
-
- ctx2->prec = firstprec + 1;
- if (exp_ne(prevexp, ADJUSTED(tmp2))) {
- _limb_first_n_digits(tmp2->limbs, tmp2->ob_size, 0, ans->limbs, ans->ob_size);
- exp_dec(&(ans->exp));
- }
- else {
- ans->limb_count = tmp2->limb_count;
- ans->ob_size = tmp2->ob_size;
- for (i = 0; i < ans->limb_count; i++)
- ans->limbs[i] = tmp2->limbs[i];
- }
-
- Py_DECREF(tmp2);
- tmp2 = 0;
-
- {
- int cmp;
- decimalobject *lower;
- half->exp = exp_sub_i(ans->exp, 1);
- half->limbs[0] = 5;
- lower = _do_decimal_subtract(ans, half, ctx2);
- if (!lower)
- goto err;
-
- ctx2->rounding = ROUND_UP;
- tmp2 = _do_decimal_multiply(lower, lower, ctx2);
- Py_DECREF(lower);
- lower = 0;
- if (!tmp2) {
- goto err;
- }
-
- cmp = _do_real_decimal_compare(tmp2, tmp, ctx2);
- if (PyErr_Occurred()) {
- goto err;
- }
-
- Py_DECREF(tmp2);
- tmp2 = 0;
-
- if (cmp == 1) {
- half->exp = ans->exp;
- half->limbs[0] = 1;
- tmp2 = _do_decimal_subtract(ans, half, ctx2);
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
+ ret_limbs = _limb_sqrt(b_tab, b_limbs, ret->limbs);
+ ret->limb_count = ret_limbs;
+ ret->ob_size = _limb_size_s(ret->limbs, ret_limbs * LOG);
+ /* let's check if it's exact */
+ {
+ long tmp_digs;
+ long cmp_res;
+ long *tmp = (long*)malloc((ret_limbs * 2 + 1) * sizeof(long));
+ tmp_digs = _limb_multiply(ret->limbs, ret->ob_size, ret->limbs,
+ ret->ob_size, tmp);
+
+ cmp_res = _limb_compare_un(b_tab, b_limbs, tmp, (tmp_digs +LOG-1)/LOG);
+ exact = exact && cmp_res == 0;
+ free(tmp);
+ }
+ free(b_tab);
+
+ /* now all we need is to set exp :> and add 1 ULP in case it's not exact ;> */
+
+ if (exact) {
+ long expdiff;
+ exp_t tmp_exp;
+ ret->exp = exp_floordiv_i(self->exp, 2);
+ tmp_exp = exp_floordiv_i(exp_add_i(self->exp, shift), 2);
+ expdiff = exp_to_i(exp_sub(ret->exp, tmp_exp));
+ /* TODO SLOW */
+ if(expdiff > 0) {
+ while(expdiff && ret->ob_size > 1) {
+ assert(!(ret->limbs[0] % 10)); /* Am I sure?? TODO */
+ _limb_cut_one_digit(ret->limbs, ret->ob_size);
+ ret->ob_size --;
+ expdiff --;
+ }
}
else {
- decimalobject *upper;
- half->exp = exp_sub_i(ans->exp, 1);
- half->limbs[0] = 5;
- upper = _do_decimal_add(ans, half, ctx2);
- if (!upper)
- goto err;
- ctx2->rounding = ROUND_DOWN;
-
- tmp2 = _do_decimal_multiply(upper, upper, ctx2);
-
- Py_DECREF(upper);
- upper = 0;
-
- cmp = _do_real_decimal_compare(tmp2, tmp, ctx2);
- if (PyErr_Occurred())
- goto err;
-
- Py_DECREF(tmp2);
- tmp2 = 0;
-
- if (cmp == -1) {
- half->exp = ans->exp;
- half->limbs[0] = 1;
- tmp2 = _do_decimal_add(ans, half, ctx2);
-
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
+ while(expdiff) {
+ _limb_add_one_digit(ret->limbs, ret->ob_size, 0);
+ ret->ob_size ++;
+ expdiff ++;
}
}
}
-
- exp_inp_add(&(ans->exp), expadd);
- ctx2->rounding = rounding;
-
- tmp2 = _decimal_fix(ans, ctx2);
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
-
- rounding = ctx2->rounding_dec;
- ctx2->rounding_dec = NEVER_ROUND;
-
- {
- int cmp;
- tmp2 = _do_decimal_multiply(ans, ans, ctx2);
- if (!tmp2)
- goto err;
-
- cmp = _do_real_decimal_compare(tmp2, self, ctx2);
-
- if (PyErr_Occurred())
- goto err;
-
- Py_DECREF(tmp2);
- tmp2 = 0;
-
- if (cmp != 0) {
- if (handle_Rounded(ctx, NULL))
- goto err;
-
- if (handle_Inexact(ctx, NULL))
- goto err;
- }
-
- else {
-/* long exp = self->exp / 2;
- if (self->exp < 0 && self->exp % 2)
- exp --;*/
- exp_t exp = exp_floordiv_i(self->exp, 2);
- ctx2->prec += exp_to_i(exp_sub(ans->exp, exp));
- tmp2 = _decimal_rescale(ans, exp, ctx2, -1, 1);
-
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
- ctx2->prec = firstprec;
- }
+ else {
+ if(ret->limbs[0] % 5 == 0) ret->limbs[0] ++;
+ ret-> exp = exp_floordiv_i(exp_add_i(self->exp, shift), 2);
}
- tmp2 = _decimal_fix(ans, ctx);
- if (!tmp2)
- goto err;
- Py_DECREF(ans);
- ans = tmp2;
- tmp2 = 0;
-
- Py_DECREF(flags);
- Py_DECREF(ctx2);
- Py_DECREF(half);
- Py_DECREF(tmp);
-
- return ans;
-
-err:
- Py_XDECREF(tmp);
- Py_XDECREF(tmp2);
- Py_XDECREF(tmp3);
- Py_XDECREF(ans);
- Py_XDECREF(flags);
- Py_XDECREF(ctx2);
- Py_XDECREF(ret);
- Py_XDECREF(half);
+ /* TODO it's not thread safe */
+ rounding = ctx->rounding;
+ ctx->rounding = ROUND_HALF_EVEN;
+
+ ret2 = _decimal_fix(ret, ctx);
+ ctx->rounding = rounding;
+ Py_DECREF(ret);
+ return ret2;
+
return NULL;
}
DECIMAL_UNARY_FUNC(sqrt)
More information about the Python-checkins
mailing list