
From: Oleg Nesterov <oleg@tv-sign.ru>

sys_get_mempolicy() accesses user memory with mmap_sem held.  If I
understand correctly, this can cause deadlock:

sys_get_mempolicy:		Another thread, same mm:

down_read(mmap_sem);
				down_write(mmap_sem);
put_user();
do_page_fault:
down_read(mmap_sem);

Compile tested only, I have no NUMA machine.

Signed-off-by: Oleg Nesterov <oleg@tv-sign.ru>
Signed-off-by: Andrew Morton <akpm@osdl.org>
---

 25-akpm/mm/mempolicy.c |   54 +++++++++++++++++++++++++++----------------------
 1 files changed, 30 insertions(+), 24 deletions(-)

diff -puN mm/mempolicy.c~fix-put_user-under-mmap_sem-in-sys_get_mempolicy mm/mempolicy.c
--- 25/mm/mempolicy.c~fix-put_user-under-mmap_sem-in-sys_get_mempolicy	2005-01-25 21:08:39.989207680 -0800
+++ 25-akpm/mm/mempolicy.c	2005-01-25 21:08:39.993207072 -0800
@@ -482,26 +482,38 @@ asmlinkage long sys_get_mempolicy(int __
 				  unsigned long maxnode,
 				  unsigned long addr, unsigned long flags)
 {
-	int err, pval;
-	struct mm_struct *mm = current->mm;
-	struct vm_area_struct *vma = NULL;
+	int err, pval = 0; /* make compiler happy */
 	struct mempolicy *pol = current->mempolicy;
 
 	if (flags & ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR))
 		return -EINVAL;
 	if (nmask != NULL && maxnode < MAX_NUMNODES)
 		return -EINVAL;
+
 	if (flags & MPOL_F_ADDR) {
+		struct mm_struct *mm = current->mm;
+		struct vm_area_struct *vma;
+
+		err = 0;
 		down_read(&mm->mmap_sem);
 		vma = find_vma_intersection(mm, addr, addr+1);
-		if (!vma) {
-			up_read(&mm->mmap_sem);
-			return -EFAULT;
+		if (!vma)
+			err = -EFAULT;
+		else {
+			if (vma->vm_ops && vma->vm_ops->get_policy)
+				pol = vma->vm_ops->get_policy(vma, addr);
+			else
+				pol = vma->vm_policy;
+
+			if (flags & MPOL_F_NODE) {
+				pval = lookup_node(mm, addr);
+				if (pval < 0)
+					err = pval;
+			}
 		}
-		if (vma->vm_ops && vma->vm_ops->get_policy)
-			pol = vma->vm_ops->get_policy(vma, addr);
-		else
-			pol = vma->vm_policy;
+		up_read(&mm->mmap_sem);
+		if (err)
+			goto out;
 	} else if (addr)
 		return -EINVAL;
 
@@ -509,17 +521,14 @@ asmlinkage long sys_get_mempolicy(int __
 		pol = &default_policy;
 
 	if (flags & MPOL_F_NODE) {
-		if (flags & MPOL_F_ADDR) {
-			err = lookup_node(mm, addr);
-			if (err < 0)
+		if (!(flags & MPOL_F_ADDR)) {
+			if (pol == current->mempolicy &&
+					pol->policy == MPOL_INTERLEAVE) {
+				pval = current->il_next;
+			} else {
+				err = -EINVAL;
 				goto out;
-			pval = err;
-		} else if (pol == current->mempolicy &&
-				pol->policy == MPOL_INTERLEAVE) {
-			pval = current->il_next;
-		} else {
-			err = -EINVAL;
-			goto out;
+			}
 		}
 	} else
 		pval = pol->policy;
@@ -534,10 +543,7 @@ asmlinkage long sys_get_mempolicy(int __
 		get_zonemask(pol, nodes);
 		err = copy_nodes_to_user(nmask, maxnode, nodes, sizeof(nodes));
 	}
-
- out:
-	if (vma)
-		up_read(&current->mm->mmap_sem);
+out:
 	return err;
 }
 
_
